1use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DatabaseBackend};
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{error, info, warn};
10
11#[derive(Error, Debug)]
13pub enum DatabaseError {
14 #[error("数据库连接失败: {0}")]
15 ConnectionFailed(String),
16 #[error("设置 Schema 失败: {0}")]
17 SchemaSetFailed(String),
18 #[error("无效的配置: {0}")]
19 InvalidConfig(String),
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct DatabaseConfig {
25 pub database_type: String,
27 pub host: String,
29 pub port: u16,
31 pub username: String,
33 pub password: String,
35 pub database_name: String,
37 #[serde(default = "default_schema")]
39 pub schema: String,
40 #[serde(default = "default_logging_level")]
42 pub logging_level: String,
43 #[serde(default = "default_use_pgbouncer")]
45 pub use_pgbouncer: bool,
46}
47
/// Serde default for `DatabaseConfig::schema`: PostgreSQL's `public` schema.
fn default_schema() -> String {
    String::from("public")
}
51
/// Serde default for `DatabaseConfig::logging_level`.
fn default_logging_level() -> String {
    "info".to_owned()
}
55
/// Serde default for `DatabaseConfig::use_pgbouncer`: direct connections (no PgBouncer).
fn default_use_pgbouncer() -> bool {
    bool::default()
}
59
/// Pool tuning values forwarded to SeaORM's `ConnectOptions` when a connection is opened.
///
/// The timeout and lifetime fields are whole seconds; they are converted with
/// `Duration::from_secs` when the pool is configured.
#[derive(Clone, Debug)]
pub struct ConnectionOptions {
    /// Forwarded to `ConnectOptions::max_connections`.
    pub max_connections: u32,
    /// Forwarded to `ConnectOptions::min_connections`.
    pub min_connections: u32,
    /// Seconds; forwarded to `ConnectOptions::connect_timeout`.
    pub connect_timeout: u64,
    /// Seconds; forwarded to `ConnectOptions::acquire_timeout`.
    pub acquire_timeout: u64,
    /// Seconds; forwarded to `ConnectOptions::idle_timeout`.
    pub idle_timeout: u64,
    /// Seconds; forwarded to `ConnectOptions::max_lifetime`.
    pub max_lifetime: u64,
    /// Forwarded to `ConnectOptions::sqlx_logging`.
    pub sqlx_logging: bool,
}
78
79impl Default for ConnectionOptions {
80 fn default() -> Self {
81 Self {
82 max_connections: 300,
83 min_connections: 5,
84 connect_timeout: 8,
85 acquire_timeout: 8,
86 idle_timeout: 600, max_lifetime: 1800, sqlx_logging: true,
89 }
90 }
91}
92
93pub struct DatabaseService;
95
96impl DatabaseService {
97 pub fn build_database_url(config: &DatabaseConfig) -> String {
102 let encoded_username = urlencoding::encode(&config.username);
106 let encoded_password = urlencoding::encode(&config.password);
107 let encoded_host = if config.host.contains(':') && !config.host.starts_with('[') {
109 format!("[{}]", config.host)
111 } else {
112 config.host.clone()
113 };
114 let encoded_database = urlencoding::encode(&config.database_name);
115
116 format!(
117 "{}://{}:{}@{}:{}/{}",
118 config.database_type,
119 encoded_username,
120 encoded_password,
121 encoded_host,
122 config.port,
123 encoded_database
124 )
125 }
126
127 pub async fn create_connection(
129 config: &DatabaseConfig,
130 schema: Option<&str>,
131 options: Option<ConnectionOptions>,
132 ) -> Result<DatabaseConnection, DatabaseError> {
133 let database_url = Self::build_database_url(config);
134 let target_schema = schema.unwrap_or(&config.schema);
135
136 info!("正在连接数据库...");
137 info!(
139 "数据库地址: {}@{}:{}",
140 config.database_name, config.host, config.port
141 );
142 info!("使用 Schema: {}", target_schema);
143
144 let opts = options.unwrap_or_default();
145 let mut connect_options = ConnectOptions::new(&database_url);
146 let log_level = Self::parse_log_level(&config.logging_level);
147
148 connect_options
149 .max_connections(opts.max_connections)
150 .min_connections(opts.min_connections)
151 .connect_timeout(Duration::from_secs(opts.connect_timeout))
152 .acquire_timeout(Duration::from_secs(opts.acquire_timeout))
153 .idle_timeout(Duration::from_secs(opts.idle_timeout))
154 .max_lifetime(Duration::from_secs(opts.max_lifetime))
155 .sqlx_logging(opts.sqlx_logging)
156 .sqlx_logging_level(log_level);
157
158 let db = Database::connect(connect_options).await.map_err(|e| {
159 error!("数据库连接失败: {}", e);
160 DatabaseError::ConnectionFailed(format!("数据库连接失败: {}", e))
161 })?;
162
163 let backend = db.get_database_backend();
165 match backend {
166 DatabaseBackend::Postgres => {
167 if !config.use_pgbouncer {
169 Self::set_extra_float_digits(&db).await?;
170 } else {
171 info!("使用 PgBouncer,跳过设置 extra_float_digits");
172 }
173
174 match Self::set_schema(&db, target_schema).await {
178 Ok(_) => {}
179 Err(e) => {
180 if config.use_pgbouncer {
181 warn!(
182 "设置 schema 失败(可能是 PgBouncer transaction 模式导致): {}",
183 e
184 );
185 warn!(
186 "建议:1) 改用 session 模式,或 2) 在数据库层面设置: ALTER DATABASE {} SET search_path TO {}",
187 config.database_name, target_schema
188 );
189 } else {
190 return Err(e);
191 }
192 }
193 }
194 }
195 DatabaseBackend::MySql => {
196 info!("MySQL 数据库连接已建立");
198 }
199 DatabaseBackend::Sqlite => {
200 info!("SQLite 数据库连接已建立");
202 }
203 _ => {
204 info!("数据库连接已建立");
206 }
207 }
208
209 info!("✓ 数据库连接成功");
210 Ok(db)
211 }
212
213 pub async fn init(
215 config: &DatabaseConfig,
216 options: Option<ConnectionOptions>,
217 ) -> Result<DatabaseConnection, DatabaseError> {
218 Self::create_connection(config, None, options).await
219 }
220
221 pub async fn set_extra_float_digits(db: &DatabaseConnection) -> Result<(), DatabaseError> {
223 db.execute_unprepared("SET extra_float_digits = 0")
224 .await
225 .map_err(|e| {
226 error!("设置 extra_float_digits 失败: {}", e);
227 DatabaseError::SchemaSetFailed(format!("设置 extra_float_digits 失败: {}", e))
228 })?;
229 info!("✓ 已设置 extra_float_digits = 0");
230 Ok(())
231 }
232
233 fn validate_schema_name(schema: &str) -> Result<(), DatabaseError> {
238 if schema.is_empty() {
242 return Err(DatabaseError::InvalidConfig(
243 "Schema 名称不能为空".to_string(),
244 ));
245 }
246
247 if !schema
249 .chars()
250 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
251 {
252 return Err(DatabaseError::InvalidConfig(format!(
253 "Schema 名称包含非法字符: {},只允许字母、数字、下划线和连字符",
254 schema
255 )));
256 }
257
258 if schema.len() > 63 {
260 return Err(DatabaseError::InvalidConfig(format!(
261 "Schema 名称过长: {},最大长度为 63 字符",
262 schema
263 )));
264 }
265
266 Ok(())
267 }
268
269 pub async fn set_schema(db: &DatabaseConnection, schema: &str) -> Result<(), DatabaseError> {
274 Self::validate_schema_name(schema)?;
276
277 let sql = format!("SET search_path TO {}", schema);
280 db.execute_unprepared(&sql).await.map_err(|e| {
281 error!("设置 schema 失败: {}", e);
282 DatabaseError::SchemaSetFailed(format!("设置 schema 失败: {}", e))
283 })?;
284 info!("✓ 已设置 search_path 到 schema: {}", schema);
285 Ok(())
286 }
287
288 pub async fn test_connection(db: &DatabaseConnection) -> Result<(), DatabaseError> {
290 db.execute_unprepared("SELECT 1").await.map_err(|e| {
291 error!("数据库连接测试失败: {}", e);
292 DatabaseError::ConnectionFailed(format!("数据库连接测试失败: {}", e))
293 })?;
294 info!("✓ 数据库连接测试成功");
295 Ok(())
296 }
297
298 fn parse_log_level(level: &str) -> log::LevelFilter {
300 match level.to_lowercase().as_str() {
301 "off" => log::LevelFilter::Off,
302 "error" => log::LevelFilter::Error,
303 "warn" => log::LevelFilter::Warn,
304 "info" => log::LevelFilter::Info,
305 "debug" => log::LevelFilter::Debug,
306 "trace" => log::LevelFilter::Trace,
307 _ => {
308 warn!("未知的日志级别 '{}', 使用默认级别 Info", level);
309 log::LevelFilter::Info
310 }
311 }
312 }
313
314 pub fn validate_config(config: &DatabaseConfig) -> Result<(), DatabaseError> {
323 if config.host.is_empty() {
324 return Err(DatabaseError::InvalidConfig(
325 "数据库主机不能为空".to_string(),
326 ));
327 }
328
329 if config.port == 0 {
330 return Err(DatabaseError::InvalidConfig(
331 format!("数据库端口无效: {},有效范围: 1-65535", config.port)
332 ));
333 }
334
335 if config.database_name.is_empty() {
336 return Err(DatabaseError::InvalidConfig(
337 "数据库名称不能为空".to_string(),
338 ));
339 }
340
341 if config.database_type != "sqlite" {
343 if config.username.is_empty() {
344 return Err(DatabaseError::InvalidConfig(
345 "数据库用户名不能为空".to_string(),
346 ));
347 }
348 }
349
350 if !["postgres", "mysql", "sqlite"].contains(&config.database_type.as_str()) {
351 return Err(DatabaseError::InvalidConfig(
352 format!(
353 "不支持的数据库类型: {},支持的类型: postgres, mysql, sqlite",
354 config.database_type
355 )
356 ));
357 }
358
359 let valid_log_levels = ["off", "error", "warn", "info", "debug", "trace"];
361 if !valid_log_levels.contains(&config.logging_level.to_lowercase().as_str()) {
362 warn!(
363 "无效的日志级别: {},将使用默认级别 info",
364 config.logging_level
365 );
366 }
367
368 Ok(())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed PostgreSQL configuration shared by the tests below.
    fn postgres_config() -> DatabaseConfig {
        DatabaseConfig {
            database_type: String::from("postgres"),
            host: String::from("localhost"),
            port: 5432,
            username: String::from("user"),
            password: String::from("pass"),
            database_name: String::from("testdb"),
            schema: String::from("public"),
            logging_level: String::from("info"),
            use_pgbouncer: false,
        }
    }

    #[test]
    fn test_build_database_url() {
        let config = postgres_config();
        let url = DatabaseService::build_database_url(&config);
        assert_eq!(url, "postgres://user:pass@localhost:5432/testdb");
    }

    #[test]
    fn test_validate_config() {
        assert!(DatabaseService::validate_config(&postgres_config()).is_ok());

        // Same configuration, but with the required host left empty.
        let missing_host = DatabaseConfig {
            host: String::new(),
            ..postgres_config()
        };
        assert!(DatabaseService::validate_config(&missing_host).is_err());
    }
}