secra-database 0.1.0

基于 SeaORM 的 Rust 数据库连接和管理库
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
//! 数据库连接管理服务
//!
//! 提供数据库连接的创建、管理和配置功能

use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DatabaseBackend};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use tracing::{error, info, warn};

/// 数据库连接错误
#[derive(Error, Debug)]
pub enum DatabaseError {
    #[error("数据库连接失败: {0}")]
    ConnectionFailed(String),
    #[error("设置 Schema 失败: {0}")]
    SchemaSetFailed(String),
    #[error("无效的配置: {0}")]
    InvalidConfig(String),
}

/// 数据库配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// 数据库类型(postgres, mysql, sqlite)
    pub database_type: String,
    /// 数据库主机
    pub host: String,
    /// 数据库端口
    pub port: u16,
    /// 数据库用户名
    pub username: String,
    /// 数据库密码
    pub password: String,
    /// 数据库名称
    pub database_name: String,
    /// 默认 Schema(PostgreSQL)
    #[serde(default = "default_schema")]
    pub schema: String,
    /// 日志级别
    #[serde(default = "default_logging_level")]
    pub logging_level: String,
    /// 是否使用 PgBouncer(如果使用 PgBouncer,则跳过设置某些运行时参数)
    #[serde(default = "default_use_pgbouncer")]
    pub use_pgbouncer: bool,
}

fn default_schema() -> String {
    "public".to_string()
}

fn default_logging_level() -> String {
    "info".to_string()
}

fn default_use_pgbouncer() -> bool {
    false
}

/// 连接选项配置
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// 最大连接数
    pub max_connections: u32,
    /// 最小连接数
    pub min_connections: u32,
    /// 连接超时时间(秒)
    pub connect_timeout: u64,
    /// 获取连接超时时间(秒)
    pub acquire_timeout: u64,
    /// 空闲连接超时时间(秒)
    pub idle_timeout: u64,
    /// 连接最大生命周期(秒)
    pub max_lifetime: u64,
    /// 是否启用 SQL 日志
    pub sqlx_logging: bool,
}

impl Default for ConnectionOptions {
    fn default() -> Self {
        Self {
            max_connections: 300,
            min_connections: 5,
            connect_timeout: 8,
            acquire_timeout: 8,
            idle_timeout: 600,  // 10分钟(优化:增加空闲连接超时时间)
            max_lifetime: 1800, // 30分钟(优化:增加连接最大生命周期,避免频繁重建连接)
            sqlx_logging: true,
        }
    }
}

/// 数据库服务
pub struct DatabaseService;

impl DatabaseService {
    /// 构建数据库连接 URL
    ///
    /// # 安全性
    /// 使用 URL 编码确保特殊字符(如密码中的特殊字符)被正确编码
    pub fn build_database_url(config: &DatabaseConfig) -> String {
        // 对用户名和密码进行 URL 编码,避免特殊字符导致的问题
        // 对用户名、密码、主机和数据库名进行 URL 编码
        // 注意:只编码必要的部分,避免过度编码
        let encoded_username = urlencoding::encode(&config.username);
        let encoded_password = urlencoding::encode(&config.password);
        // 主机名处理:IPv6 地址需要特殊处理
        let encoded_host = if config.host.contains(':') && !config.host.starts_with('[') {
            // IPv6 地址需要用方括号包裹
            format!("[{}]", config.host)
        } else {
            config.host.clone()
        };
        let encoded_database = urlencoding::encode(&config.database_name);
        
        format!(
            "{}://{}:{}@{}:{}/{}",
            config.database_type,
            encoded_username,
            encoded_password,
            encoded_host,
            config.port,
            encoded_database
        )
    }

    /// 创建数据库连接(支持指定 schema)
    pub async fn create_connection(
        config: &DatabaseConfig,
        schema: Option<&str>,
        options: Option<ConnectionOptions>,
    ) -> Result<DatabaseConnection, DatabaseError> {
        let database_url = Self::build_database_url(config);
        let target_schema = schema.unwrap_or(&config.schema);

        info!("正在连接数据库...");
        // 日志脱敏:不记录完整 URL,避免密码泄露
        info!(
            "数据库地址: {}@{}:{}",
            config.database_name, config.host, config.port
        );
        info!("使用 Schema: {}", target_schema);

        let opts = options.unwrap_or_default();
        let mut connect_options = ConnectOptions::new(&database_url);
        let log_level = Self::parse_log_level(&config.logging_level);

        connect_options
            .max_connections(opts.max_connections)
            .min_connections(opts.min_connections)
            .connect_timeout(Duration::from_secs(opts.connect_timeout))
            .acquire_timeout(Duration::from_secs(opts.acquire_timeout))
            .idle_timeout(Duration::from_secs(opts.idle_timeout))
            .max_lifetime(Duration::from_secs(opts.max_lifetime))
            .sqlx_logging(opts.sqlx_logging)
            .sqlx_logging_level(log_level);

        let db = Database::connect(connect_options).await.map_err(|e| {
            error!("数据库连接失败: {}", e);
            DatabaseError::ConnectionFailed(format!("数据库连接失败: {}", e))
        })?;

        // 数据库特定配置
        let backend = db.get_database_backend();
        match backend {
            DatabaseBackend::Postgres => {
                // 禁用 extra_float_digits 参数(PgBouncer 不支持运行时参数设置,需要跳过)
                if !config.use_pgbouncer {
                    Self::set_extra_float_digits(&db).await?;
                } else {
                    info!("使用 PgBouncer,跳过设置 extra_float_digits");
                }

                // 设置 search_path(PostgreSQL 的 schema)
                // 注意:如果使用 PgBouncer 的 transaction 模式,SET 命令在事务结束后会失效
                // 建议:1) 使用 session 模式,或 2) 在数据库层面设置默认 schema
                match Self::set_schema(&db, target_schema).await {
                    Ok(_) => {}
                    Err(e) => {
                        if config.use_pgbouncer {
                            warn!(
                                "设置 schema 失败(可能是 PgBouncer transaction 模式导致): {}",
                                e
                            );
                            warn!(
                                "建议:1) 改用 session 模式,或 2) 在数据库层面设置: ALTER DATABASE {} SET search_path TO {}",
                                config.database_name, target_schema
                            );
                        } else {
                            return Err(e);
                        }
                    }
                }
            }
            DatabaseBackend::MySql => {
                // MySQL 特定配置可以在这里添加
                info!("MySQL 数据库连接已建立");
            }
            DatabaseBackend::Sqlite => {
                // SQLite 特定配置可以在这里添加
                info!("SQLite 数据库连接已建立");
            }
            _ => {
                // 其他数据库类型
                info!("数据库连接已建立");
            }
        }

        info!("✓ 数据库连接成功");
        Ok(db)
    }

    /// 初始化数据库连接(使用配置中的默认 schema)
    pub async fn init(
        config: &DatabaseConfig,
        options: Option<ConnectionOptions>,
    ) -> Result<DatabaseConnection, DatabaseError> {
        Self::create_connection(config, None, options).await
    }

    /// 设置 PostgreSQL extra_float_digits 参数
    pub async fn set_extra_float_digits(db: &DatabaseConnection) -> Result<(), DatabaseError> {
        db.execute_unprepared("SET extra_float_digits = 0")
            .await
            .map_err(|e| {
                error!("设置 extra_float_digits 失败: {}", e);
                DatabaseError::SchemaSetFailed(format!("设置 extra_float_digits 失败: {}", e))
            })?;
        info!("✓ 已设置 extra_float_digits = 0");
        Ok(())
    }

    /// 验证 schema 名称安全性
    ///
    /// # 安全性
    /// 确保 schema 名称只包含合法字符,防止 SQL 注入攻击
    fn validate_schema_name(schema: &str) -> Result<(), DatabaseError> {
        // Schema 名称只能包含字母、数字、下划线和连字符
        // PostgreSQL 允许的标识符字符:字母、数字、下划线、美元符号
        // 为了安全,我们限制为:字母、数字、下划线、连字符
        if schema.is_empty() {
            return Err(DatabaseError::InvalidConfig(
                "Schema 名称不能为空".to_string(),
            ));
        }

        // 检查是否包含非法字符
        if !schema
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
        {
            return Err(DatabaseError::InvalidConfig(format!(
                "Schema 名称包含非法字符: {},只允许字母、数字、下划线和连字符",
                schema
            )));
        }

        // 检查长度(PostgreSQL 标识符最大长度为 63)
        if schema.len() > 63 {
            return Err(DatabaseError::InvalidConfig(format!(
                "Schema 名称过长: {},最大长度为 63 字符",
                schema
            )));
        }

        Ok(())
    }

    /// 设置 PostgreSQL schema
    ///
    /// # 安全性
    /// 验证 schema 名称,防止 SQL 注入攻击
    pub async fn set_schema(db: &DatabaseConnection, schema: &str) -> Result<(), DatabaseError> {
        // 验证 schema 名称安全性
        Self::validate_schema_name(schema)?;

        // 使用参数化查询或验证后的字符串
        // 注意:PostgreSQL 的 SET 命令不支持参数化查询,但我们已经验证了 schema 名称的安全性
        let sql = format!("SET search_path TO {}", schema);
        db.execute_unprepared(&sql).await.map_err(|e| {
            error!("设置 schema 失败: {}", e);
            DatabaseError::SchemaSetFailed(format!("设置 schema 失败: {}", e))
        })?;
        info!("✓ 已设置 search_path 到 schema: {}", schema);
        Ok(())
    }

    /// 测试数据库连接
    pub async fn test_connection(db: &DatabaseConnection) -> Result<(), DatabaseError> {
        db.execute_unprepared("SELECT 1").await.map_err(|e| {
            error!("数据库连接测试失败: {}", e);
            DatabaseError::ConnectionFailed(format!("数据库连接测试失败: {}", e))
        })?;
        info!("✓ 数据库连接测试成功");
        Ok(())
    }

    /// 解析日志级别字符串为 log::LevelFilter
    fn parse_log_level(level: &str) -> log::LevelFilter {
        match level.to_lowercase().as_str() {
            "off" => log::LevelFilter::Off,
            "error" => log::LevelFilter::Error,
            "warn" => log::LevelFilter::Warn,
            "info" => log::LevelFilter::Info,
            "debug" => log::LevelFilter::Debug,
            "trace" => log::LevelFilter::Trace,
            _ => {
                warn!("未知的日志级别 '{}', 使用默认级别 Info", level);
                log::LevelFilter::Info
            }
        }
    }

    /// 验证数据库配置
    ///
    /// # 验证项
    /// - 主机地址不能为空
    /// - 端口必须在有效范围内(1-65535)
    /// - 数据库名称不能为空
    /// - 用户名不能为空(SQLite 除外)
    /// - 数据库类型必须支持
    pub fn validate_config(config: &DatabaseConfig) -> Result<(), DatabaseError> {
        if config.host.is_empty() {
            return Err(DatabaseError::InvalidConfig(
                "数据库主机不能为空".to_string(),
            ));
        }
        
        if config.port == 0 {
            return Err(DatabaseError::InvalidConfig(
                format!("数据库端口无效: {},有效范围: 1-65535", config.port)
            ));
        }
        
        if config.database_name.is_empty() {
            return Err(DatabaseError::InvalidConfig(
                "数据库名称不能为空".to_string(),
            ));
        }
        
        // SQLite 不需要用户名和密码验证
        if config.database_type != "sqlite" {
            if config.username.is_empty() {
                return Err(DatabaseError::InvalidConfig(
                    "数据库用户名不能为空".to_string(),
                ));
            }
        }
        
        if !["postgres", "mysql", "sqlite"].contains(&config.database_type.as_str()) {
            return Err(DatabaseError::InvalidConfig(
                format!(
                    "不支持的数据库类型: {},支持的类型: postgres, mysql, sqlite",
                    config.database_type
                )
            ));
        }
        
        // 验证日志级别
        let valid_log_levels = ["off", "error", "warn", "info", "debug", "trace"];
        if !valid_log_levels.contains(&config.logging_level.to_lowercase().as_str()) {
            warn!(
                "无效的日志级别: {},将使用默认级别 info",
                config.logging_level
            );
        }
        
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_database_url() {
        let config = DatabaseConfig {
            database_type: "postgres".to_string(),
            host: "localhost".to_string(),
            port: 5432,
            username: "user".to_string(),
            password: "pass".to_string(),
            database_name: "testdb".to_string(),
            schema: "public".to_string(),
            logging_level: "info".to_string(),
            use_pgbouncer: false,
        };

        let url = DatabaseService::build_database_url(&config);
        assert_eq!(url, "postgres://user:pass@localhost:5432/testdb");
    }

    #[test]
    fn test_validate_config() {
        let valid_config = DatabaseConfig {
            database_type: "postgres".to_string(),
            host: "localhost".to_string(),
            port: 5432,
            username: "user".to_string(),
            password: "pass".to_string(),
            database_name: "testdb".to_string(),
            schema: "public".to_string(),
            logging_level: "info".to_string(),
            use_pgbouncer: false,
        };

        assert!(DatabaseService::validate_config(&valid_config).is_ok());

        let invalid_config = DatabaseConfig {
            database_type: "postgres".to_string(),
            host: "".to_string(),
            port: 5432,
            username: "user".to_string(),
            password: "pass".to_string(),
            database_name: "testdb".to_string(),
            schema: "public".to_string(),
            logging_level: "info".to_string(),
            use_pgbouncer: false,
        };

        assert!(DatabaseService::validate_config(&invalid_config).is_err());
    }
}