secra_database/
connection.rs

1//! 数据库连接管理服务
2//!
3//! 提供数据库连接的创建、管理和配置功能
4
5use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DatabaseBackend};
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{error, info, warn};
10
11/// 数据库连接错误
12#[derive(Error, Debug)]
13pub enum DatabaseError {
14    #[error("数据库连接失败: {0}")]
15    ConnectionFailed(String),
16    #[error("设置 Schema 失败: {0}")]
17    SchemaSetFailed(String),
18    #[error("无效的配置: {0}")]
19    InvalidConfig(String),
20}
21
22/// 数据库配置
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct DatabaseConfig {
25    /// 数据库类型(postgres, mysql, sqlite)
26    pub database_type: String,
27    /// 数据库主机
28    pub host: String,
29    /// 数据库端口
30    pub port: u16,
31    /// 数据库用户名
32    pub username: String,
33    /// 数据库密码
34    pub password: String,
35    /// 数据库名称
36    pub database_name: String,
37    /// 默认 Schema(PostgreSQL)
38    #[serde(default = "default_schema")]
39    pub schema: String,
40    /// 日志级别
41    #[serde(default = "default_logging_level")]
42    pub logging_level: String,
43    /// 是否使用 PgBouncer(如果使用 PgBouncer,则跳过设置某些运行时参数)
44    #[serde(default = "default_use_pgbouncer")]
45    pub use_pgbouncer: bool,
46}
47
48fn default_schema() -> String {
49    "public".to_string()
50}
51
52fn default_logging_level() -> String {
53    "info".to_string()
54}
55
56fn default_use_pgbouncer() -> bool {
57    false
58}
59
60/// 连接选项配置
61#[derive(Debug, Clone)]
62pub struct ConnectionOptions {
63    /// 最大连接数
64    pub max_connections: u32,
65    /// 最小连接数
66    pub min_connections: u32,
67    /// 连接超时时间(秒)
68    pub connect_timeout: u64,
69    /// 获取连接超时时间(秒)
70    pub acquire_timeout: u64,
71    /// 空闲连接超时时间(秒)
72    pub idle_timeout: u64,
73    /// 连接最大生命周期(秒)
74    pub max_lifetime: u64,
75    /// 是否启用 SQL 日志
76    pub sqlx_logging: bool,
77}
78
79impl Default for ConnectionOptions {
80    fn default() -> Self {
81        Self {
82            max_connections: 300,
83            min_connections: 5,
84            connect_timeout: 8,
85            acquire_timeout: 8,
86            idle_timeout: 600,  // 10分钟(优化:增加空闲连接超时时间)
87            max_lifetime: 1800, // 30分钟(优化:增加连接最大生命周期,避免频繁重建连接)
88            sqlx_logging: true,
89        }
90    }
91}
92
93/// 数据库服务
94pub struct DatabaseService;
95
96impl DatabaseService {
97    /// 构建数据库连接 URL
98    ///
99    /// # 安全性
100    /// 使用 URL 编码确保特殊字符(如密码中的特殊字符)被正确编码
101    pub fn build_database_url(config: &DatabaseConfig) -> String {
102        // 对用户名和密码进行 URL 编码,避免特殊字符导致的问题
103        // 对用户名、密码、主机和数据库名进行 URL 编码
104        // 注意:只编码必要的部分,避免过度编码
105        let encoded_username = urlencoding::encode(&config.username);
106        let encoded_password = urlencoding::encode(&config.password);
107        // 主机名处理:IPv6 地址需要特殊处理
108        let encoded_host = if config.host.contains(':') && !config.host.starts_with('[') {
109            // IPv6 地址需要用方括号包裹
110            format!("[{}]", config.host)
111        } else {
112            config.host.clone()
113        };
114        let encoded_database = urlencoding::encode(&config.database_name);
115        
116        format!(
117            "{}://{}:{}@{}:{}/{}",
118            config.database_type,
119            encoded_username,
120            encoded_password,
121            encoded_host,
122            config.port,
123            encoded_database
124        )
125    }
126
127    /// 创建数据库连接(支持指定 schema)
128    pub async fn create_connection(
129        config: &DatabaseConfig,
130        schema: Option<&str>,
131        options: Option<ConnectionOptions>,
132    ) -> Result<DatabaseConnection, DatabaseError> {
133        let database_url = Self::build_database_url(config);
134        let target_schema = schema.unwrap_or(&config.schema);
135
136        info!("正在连接数据库...");
137        // 日志脱敏:不记录完整 URL,避免密码泄露
138        info!(
139            "数据库地址: {}@{}:{}",
140            config.database_name, config.host, config.port
141        );
142        info!("使用 Schema: {}", target_schema);
143
144        let opts = options.unwrap_or_default();
145        let mut connect_options = ConnectOptions::new(&database_url);
146        let log_level = Self::parse_log_level(&config.logging_level);
147
148        connect_options
149            .max_connections(opts.max_connections)
150            .min_connections(opts.min_connections)
151            .connect_timeout(Duration::from_secs(opts.connect_timeout))
152            .acquire_timeout(Duration::from_secs(opts.acquire_timeout))
153            .idle_timeout(Duration::from_secs(opts.idle_timeout))
154            .max_lifetime(Duration::from_secs(opts.max_lifetime))
155            .sqlx_logging(opts.sqlx_logging)
156            .sqlx_logging_level(log_level);
157
158        let db = Database::connect(connect_options).await.map_err(|e| {
159            error!("数据库连接失败: {}", e);
160            DatabaseError::ConnectionFailed(format!("数据库连接失败: {}", e))
161        })?;
162
163        // 数据库特定配置
164        let backend = db.get_database_backend();
165        match backend {
166            DatabaseBackend::Postgres => {
167                // 禁用 extra_float_digits 参数(PgBouncer 不支持运行时参数设置,需要跳过)
168                if !config.use_pgbouncer {
169                    Self::set_extra_float_digits(&db).await?;
170                } else {
171                    info!("使用 PgBouncer,跳过设置 extra_float_digits");
172                }
173
174                // 设置 search_path(PostgreSQL 的 schema)
175                // 注意:如果使用 PgBouncer 的 transaction 模式,SET 命令在事务结束后会失效
176                // 建议:1) 使用 session 模式,或 2) 在数据库层面设置默认 schema
177                match Self::set_schema(&db, target_schema).await {
178                    Ok(_) => {}
179                    Err(e) => {
180                        if config.use_pgbouncer {
181                            warn!(
182                                "设置 schema 失败(可能是 PgBouncer transaction 模式导致): {}",
183                                e
184                            );
185                            warn!(
186                                "建议:1) 改用 session 模式,或 2) 在数据库层面设置: ALTER DATABASE {} SET search_path TO {}",
187                                config.database_name, target_schema
188                            );
189                        } else {
190                            return Err(e);
191                        }
192                    }
193                }
194            }
195            DatabaseBackend::MySql => {
196                // MySQL 特定配置可以在这里添加
197                info!("MySQL 数据库连接已建立");
198            }
199            DatabaseBackend::Sqlite => {
200                // SQLite 特定配置可以在这里添加
201                info!("SQLite 数据库连接已建立");
202            }
203            _ => {
204                // 其他数据库类型
205                info!("数据库连接已建立");
206            }
207        }
208
209        info!("✓ 数据库连接成功");
210        Ok(db)
211    }
212
213    /// 初始化数据库连接(使用配置中的默认 schema)
214    pub async fn init(
215        config: &DatabaseConfig,
216        options: Option<ConnectionOptions>,
217    ) -> Result<DatabaseConnection, DatabaseError> {
218        Self::create_connection(config, None, options).await
219    }
220
221    /// 设置 PostgreSQL extra_float_digits 参数
222    pub async fn set_extra_float_digits(db: &DatabaseConnection) -> Result<(), DatabaseError> {
223        db.execute_unprepared("SET extra_float_digits = 0")
224            .await
225            .map_err(|e| {
226                error!("设置 extra_float_digits 失败: {}", e);
227                DatabaseError::SchemaSetFailed(format!("设置 extra_float_digits 失败: {}", e))
228            })?;
229        info!("✓ 已设置 extra_float_digits = 0");
230        Ok(())
231    }
232
233    /// 验证 schema 名称安全性
234    ///
235    /// # 安全性
236    /// 确保 schema 名称只包含合法字符,防止 SQL 注入攻击
237    fn validate_schema_name(schema: &str) -> Result<(), DatabaseError> {
238        // Schema 名称只能包含字母、数字、下划线和连字符
239        // PostgreSQL 允许的标识符字符:字母、数字、下划线、美元符号
240        // 为了安全,我们限制为:字母、数字、下划线、连字符
241        if schema.is_empty() {
242            return Err(DatabaseError::InvalidConfig(
243                "Schema 名称不能为空".to_string(),
244            ));
245        }
246
247        // 检查是否包含非法字符
248        if !schema
249            .chars()
250            .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
251        {
252            return Err(DatabaseError::InvalidConfig(format!(
253                "Schema 名称包含非法字符: {},只允许字母、数字、下划线和连字符",
254                schema
255            )));
256        }
257
258        // 检查长度(PostgreSQL 标识符最大长度为 63)
259        if schema.len() > 63 {
260            return Err(DatabaseError::InvalidConfig(format!(
261                "Schema 名称过长: {},最大长度为 63 字符",
262                schema
263            )));
264        }
265
266        Ok(())
267    }
268
269    /// 设置 PostgreSQL schema
270    ///
271    /// # 安全性
272    /// 验证 schema 名称,防止 SQL 注入攻击
273    pub async fn set_schema(db: &DatabaseConnection, schema: &str) -> Result<(), DatabaseError> {
274        // 验证 schema 名称安全性
275        Self::validate_schema_name(schema)?;
276
277        // 使用参数化查询或验证后的字符串
278        // 注意:PostgreSQL 的 SET 命令不支持参数化查询,但我们已经验证了 schema 名称的安全性
279        let sql = format!("SET search_path TO {}", schema);
280        db.execute_unprepared(&sql).await.map_err(|e| {
281            error!("设置 schema 失败: {}", e);
282            DatabaseError::SchemaSetFailed(format!("设置 schema 失败: {}", e))
283        })?;
284        info!("✓ 已设置 search_path 到 schema: {}", schema);
285        Ok(())
286    }
287
288    /// 测试数据库连接
289    pub async fn test_connection(db: &DatabaseConnection) -> Result<(), DatabaseError> {
290        db.execute_unprepared("SELECT 1").await.map_err(|e| {
291            error!("数据库连接测试失败: {}", e);
292            DatabaseError::ConnectionFailed(format!("数据库连接测试失败: {}", e))
293        })?;
294        info!("✓ 数据库连接测试成功");
295        Ok(())
296    }
297
298    /// 解析日志级别字符串为 log::LevelFilter
299    fn parse_log_level(level: &str) -> log::LevelFilter {
300        match level.to_lowercase().as_str() {
301            "off" => log::LevelFilter::Off,
302            "error" => log::LevelFilter::Error,
303            "warn" => log::LevelFilter::Warn,
304            "info" => log::LevelFilter::Info,
305            "debug" => log::LevelFilter::Debug,
306            "trace" => log::LevelFilter::Trace,
307            _ => {
308                warn!("未知的日志级别 '{}', 使用默认级别 Info", level);
309                log::LevelFilter::Info
310            }
311        }
312    }
313
314    /// 验证数据库配置
315    ///
316    /// # 验证项
317    /// - 主机地址不能为空
318    /// - 端口必须在有效范围内(1-65535)
319    /// - 数据库名称不能为空
320    /// - 用户名不能为空(SQLite 除外)
321    /// - 数据库类型必须支持
322    pub fn validate_config(config: &DatabaseConfig) -> Result<(), DatabaseError> {
323        if config.host.is_empty() {
324            return Err(DatabaseError::InvalidConfig(
325                "数据库主机不能为空".to_string(),
326            ));
327        }
328        
329        if config.port == 0 {
330            return Err(DatabaseError::InvalidConfig(
331                format!("数据库端口无效: {},有效范围: 1-65535", config.port)
332            ));
333        }
334        
335        if config.database_name.is_empty() {
336            return Err(DatabaseError::InvalidConfig(
337                "数据库名称不能为空".to_string(),
338            ));
339        }
340        
341        // SQLite 不需要用户名和密码验证
342        if config.database_type != "sqlite" {
343            if config.username.is_empty() {
344                return Err(DatabaseError::InvalidConfig(
345                    "数据库用户名不能为空".to_string(),
346                ));
347            }
348        }
349        
350        if !["postgres", "mysql", "sqlite"].contains(&config.database_type.as_str()) {
351            return Err(DatabaseError::InvalidConfig(
352                format!(
353                    "不支持的数据库类型: {},支持的类型: postgres, mysql, sqlite",
354                    config.database_type
355                )
356            ));
357        }
358        
359        // 验证日志级别
360        let valid_log_levels = ["off", "error", "warn", "info", "debug", "trace"];
361        if !valid_log_levels.contains(&config.logging_level.to_lowercase().as_str()) {
362            warn!(
363                "无效的日志级别: {},将使用默认级别 info",
364                config.logging_level
365            );
366        }
367        
368        Ok(())
369    }
370}
371
372#[cfg(test)]
373mod tests {
374    use super::*;
375
376    #[test]
377    fn test_build_database_url() {
378        let config = DatabaseConfig {
379            database_type: "postgres".to_string(),
380            host: "localhost".to_string(),
381            port: 5432,
382            username: "user".to_string(),
383            password: "pass".to_string(),
384            database_name: "testdb".to_string(),
385            schema: "public".to_string(),
386            logging_level: "info".to_string(),
387            use_pgbouncer: false,
388        };
389
390        let url = DatabaseService::build_database_url(&config);
391        assert_eq!(url, "postgres://user:pass@localhost:5432/testdb");
392    }
393
394    #[test]
395    fn test_validate_config() {
396        let valid_config = DatabaseConfig {
397            database_type: "postgres".to_string(),
398            host: "localhost".to_string(),
399            port: 5432,
400            username: "user".to_string(),
401            password: "pass".to_string(),
402            database_name: "testdb".to_string(),
403            schema: "public".to_string(),
404            logging_level: "info".to_string(),
405            use_pgbouncer: false,
406        };
407
408        assert!(DatabaseService::validate_config(&valid_config).is_ok());
409
410        let invalid_config = DatabaseConfig {
411            database_type: "postgres".to_string(),
412            host: "".to_string(),
413            port: 5432,
414            username: "user".to_string(),
415            password: "pass".to_string(),
416            database_name: "testdb".to_string(),
417            schema: "public".to_string(),
418            logging_level: "info".to_string(),
419            use_pgbouncer: false,
420        };
421
422        assert!(DatabaseService::validate_config(&invalid_config).is_err());
423    }
424}