dataforge/filling/
connection.rs

1//! 数据库连接管理模块
2
3use std::collections::HashMap;
4use crate::error::{DataForgeError, Result};
5
/// Supported database backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// MySQL (URL scheme `mysql://`, default port 3306).
    MySQL,
    /// PostgreSQL (URL scheme `postgresql://`, default port 5432).
    PostgreSQL,
    /// SQLite (URL scheme `sqlite://`; file-based, no host or credentials).
    SQLite,
}

/// Database connection configuration.
///
/// Build one with `ConnectionConfig::sqlite`, `ConnectionConfig::mysql` or
/// `ConnectionConfig::postgres`, then refine it with the `with_*` builder methods.
///
/// The [`Debug`] implementation never prints the password (it shows `"<redacted>"`
/// when one is set), so logging a configuration with `{:?}` does not leak credentials.
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Which backend this configuration targets.
    pub database_type: DatabaseType,
    /// Server host name or address; `None` for SQLite.
    pub host: Option<String>,
    /// Server port; `None` falls back to the backend default when building the URL.
    pub port: Option<u16>,
    /// Database name (MySQL/PostgreSQL) or database file path (SQLite).
    pub database: String,
    /// Login user; required for MySQL/PostgreSQL.
    pub username: Option<String>,
    /// Login password; required for MySQL/PostgreSQL. Redacted in `Debug` output.
    pub password: Option<String>,
    /// Maximum number of pooled connections; must be greater than 0.
    pub connection_pool_size: usize,
    /// Connection timeout (presumably in seconds — TODO confirm against the pool/driver).
    pub connection_timeout: u64,
    /// Extra `key=value` query parameters appended to MySQL/PostgreSQL URLs.
    pub additional_params: HashMap<String, String>,
}

impl std::fmt::Debug for ConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionConfig")
            .field("database_type", &self.database_type)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            // Reveal only whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("connection_pool_size", &self.connection_pool_size)
            .field("connection_timeout", &self.connection_timeout)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
27
28impl ConnectionConfig {
29    /// 创建SQLite连接配置
30    pub fn sqlite<P: AsRef<str>>(database_path: P) -> Self {
31        Self {
32            database_type: DatabaseType::SQLite,
33            host: None,
34            port: None,
35            database: database_path.as_ref().to_string(),
36            username: None,
37            password: None,
38            connection_pool_size: 10,
39            connection_timeout: 30,
40            additional_params: HashMap::new(),
41        }
42    }
43
44    /// 创建MySQL连接配置
45    pub fn mysql<H, D, U, P>(host: H, database: D, username: U, password: P) -> Self
46    where
47        H: AsRef<str>,
48        D: AsRef<str>,
49        U: AsRef<str>,
50        P: AsRef<str>,
51    {
52        Self {
53            database_type: DatabaseType::MySQL,
54            host: Some(host.as_ref().to_string()),
55            port: Some(3306),
56            database: database.as_ref().to_string(),
57            username: Some(username.as_ref().to_string()),
58            password: Some(password.as_ref().to_string()),
59            connection_pool_size: 10,
60            connection_timeout: 30,
61            additional_params: HashMap::new(),
62        }
63    }
64
65    /// 创建PostgreSQL连接配置
66    pub fn postgres<H, D, U, P>(host: H, database: D, username: U, password: P) -> Self
67    where
68        H: AsRef<str>,
69        D: AsRef<str>,
70        U: AsRef<str>,
71        P: AsRef<str>,
72    {
73        Self {
74            database_type: DatabaseType::PostgreSQL,
75            host: Some(host.as_ref().to_string()),
76            port: Some(5432),
77            database: database.as_ref().to_string(),
78            username: Some(username.as_ref().to_string()),
79            password: Some(password.as_ref().to_string()),
80            connection_pool_size: 10,
81            connection_timeout: 30,
82            additional_params: HashMap::new(),
83        }
84    }
85
86    /// 设置端口
87    pub fn with_port(mut self, port: u16) -> Self {
88        self.port = Some(port);
89        self
90    }
91
92    /// 设置连接池大小
93    pub fn with_pool_size(mut self, size: usize) -> Self {
94        self.connection_pool_size = size;
95        self
96    }
97
98    /// 设置连接超时
99    pub fn with_timeout(mut self, timeout: u64) -> Self {
100        self.connection_timeout = timeout;
101        self
102    }
103
104    /// 添加额外参数
105    pub fn with_param<K, V>(mut self, key: K, value: V) -> Self
106    where
107        K: AsRef<str>,
108        V: AsRef<str>,
109    {
110        self.additional_params.insert(
111            key.as_ref().to_string(),
112            value.as_ref().to_string(),
113        );
114        self
115    }
116
117    /// 生成连接字符串
118    pub fn to_connection_string(&self) -> Result<String> {
119        match self.database_type {
120            DatabaseType::SQLite => {
121                Ok(format!("sqlite://{}", self.database))
122            },
123            DatabaseType::MySQL => {
124                let host = self.host.as_ref()
125                    .ok_or_else(|| DataForgeError::config("MySQL host is required"))?;
126                let port = self.port.unwrap_or(3306);
127                let username = self.username.as_ref()
128                    .ok_or_else(|| DataForgeError::config("MySQL username is required"))?;
129                let password = self.password.as_ref()
130                    .ok_or_else(|| DataForgeError::config("MySQL password is required"))?;
131
132                let mut conn_str = format!(
133                    "mysql://{}:{}@{}:{}/{}",
134                    username, password, host, port, self.database
135                );
136
137                if !self.additional_params.is_empty() {
138                    let params: Vec<String> = self.additional_params
139                        .iter()
140                        .map(|(k, v)| format!("{}={}", k, v))
141                        .collect();
142                    conn_str.push('?');
143                    conn_str.push_str(&params.join("&"));
144                }
145
146                Ok(conn_str)
147            },
148            DatabaseType::PostgreSQL => {
149                let host = self.host.as_ref()
150                    .ok_or_else(|| DataForgeError::config("PostgreSQL host is required"))?;
151                let port = self.port.unwrap_or(5432);
152                let username = self.username.as_ref()
153                    .ok_or_else(|| DataForgeError::config("PostgreSQL username is required"))?;
154                let password = self.password.as_ref()
155                    .ok_or_else(|| DataForgeError::config("PostgreSQL password is required"))?;
156
157                let mut conn_str = format!(
158                    "postgresql://{}:{}@{}:{}/{}",
159                    username, password, host, port, self.database
160                );
161
162                if !self.additional_params.is_empty() {
163                    let params: Vec<String> = self.additional_params
164                        .iter()
165                        .map(|(k, v)| format!("{}={}", k, v))
166                        .collect();
167                    conn_str.push('?');
168                    conn_str.push_str(&params.join("&"));
169                }
170
171                Ok(conn_str)
172            },
173        }
174    }
175
176    /// 验证配置
177    pub fn validate(&self) -> Result<()> {
178        match self.database_type {
179            DatabaseType::SQLite => {
180                if self.database.is_empty() {
181                    return Err(DataForgeError::config("SQLite database path cannot be empty"));
182                }
183            },
184            DatabaseType::MySQL | DatabaseType::PostgreSQL => {
185                if self.host.is_none() {
186                    return Err(DataForgeError::config("Host is required for MySQL/PostgreSQL"));
187                }
188                if self.username.is_none() {
189                    return Err(DataForgeError::config("Username is required for MySQL/PostgreSQL"));
190                }
191                if self.password.is_none() {
192                    return Err(DataForgeError::config("Password is required for MySQL/PostgreSQL"));
193                }
194                if self.database.is_empty() {
195                    return Err(DataForgeError::config("Database name cannot be empty"));
196                }
197            },
198        }
199
200        if self.connection_pool_size == 0 {
201            return Err(DataForgeError::config("Connection pool size must be greater than 0"));
202        }
203
204        Ok(())
205    }
206}
207
208/// 连接管理器
209pub struct ConnectionManager {
210    config: ConnectionConfig,
211    connection_string: String,
212}
213
214impl ConnectionManager {
215    /// 创建新的连接管理器
216    pub fn new(config: ConnectionConfig) -> Result<Self> {
217        config.validate()?;
218        let connection_string = config.to_connection_string()?;
219        
220        Ok(Self {
221            config,
222            connection_string,
223        })
224    }
225
226    /// 获取连接字符串
227    pub fn connection_string(&self) -> &str {
228        &self.connection_string
229    }
230    
231    /// 获取连接字符串(返回所有权)
232    pub fn get_connection_string(&self) -> Result<String> {
233        Ok(self.connection_string.clone())
234    }
235
236    /// 获取数据库类型
237    pub fn database_type(&self) -> &DatabaseType {
238        &self.config.database_type
239    }
240
241    /// 获取配置
242    pub fn config(&self) -> &ConnectionConfig {
243        &self.config
244    }
245
246    /// 测试连接
247    pub fn test_connection(&self) -> Result<()> {
248        // TODO 这里应该实际测试数据库连接,由于我们没有实际的数据库驱动,这里只做配置验证
249        self.config.validate()
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sqlite_config() {
        let config = ConnectionConfig::sqlite("test.db");
        assert_eq!(config.database_type, DatabaseType::SQLite);
        assert_eq!(config.database, "test.db");
        assert!(config.validate().is_ok());

        let conn_str = config.to_connection_string().unwrap();
        assert_eq!(conn_str, "sqlite://test.db");
    }

    #[test]
    fn test_mysql_config() {
        let config = ConnectionConfig::mysql("localhost", "testdb", "user", "pass")
            .with_port(3307)
            .with_param("charset", "utf8mb4");

        assert_eq!(config.database_type, DatabaseType::MySQL);
        assert_eq!(config.port, Some(3307));
        assert!(config.validate().is_ok());

        let conn_str = config.to_connection_string().unwrap();
        assert!(conn_str.contains("mysql://user:pass@localhost:3307/testdb"));
        assert!(conn_str.contains("charset=utf8mb4"));
    }

    #[test]
    fn test_postgres_config() {
        let config = ConnectionConfig::postgres("localhost", "testdb", "user", "pass")
            .with_pool_size(20)
            .with_timeout(60);

        assert_eq!(config.database_type, DatabaseType::PostgreSQL);
        assert_eq!(config.connection_pool_size, 20);
        assert_eq!(config.connection_timeout, 60);
        assert!(config.validate().is_ok());

        let conn_str = config.to_connection_string().unwrap();
        assert!(conn_str.contains("postgresql://user:pass@localhost:5432/testdb"));
    }

    #[test]
    fn test_default_ports() {
        let mysql = ConnectionConfig::mysql("db.local", "app", "user", "pass");
        assert_eq!(mysql.port, Some(3306));

        let pg = ConnectionConfig::postgres("db.local", "app", "user", "pass");
        assert_eq!(pg.port, Some(5432));

        let sqlite = ConnectionConfig::sqlite("app.db");
        assert_eq!(sqlite.port, None);
        assert_eq!(sqlite.host, None);
    }

    #[test]
    fn test_missing_port_falls_back_to_default() {
        let mut config = ConnectionConfig::postgres("localhost", "testdb", "user", "pass");
        config.port = None;

        let conn_str = config.to_connection_string().unwrap();
        assert!(conn_str.starts_with("postgresql://user:pass@localhost:5432/testdb"));
    }

    #[test]
    fn test_missing_credentials_rejected() {
        let mut config = ConnectionConfig::mysql("localhost", "testdb", "user", "pass");
        config.password = None;
        assert!(config.validate().is_err());
        assert!(config.to_connection_string().is_err());

        let mut config = ConnectionConfig::postgres("localhost", "testdb", "user", "pass");
        config.host = None;
        assert!(config.validate().is_err());
        assert!(config.to_connection_string().is_err());
    }

    #[test]
    fn test_empty_sqlite_path_rejected() {
        let config = ConnectionConfig::sqlite("");
        assert!(config.validate().is_err());
        assert!(ConnectionManager::new(config).is_err());
    }

    #[test]
    fn test_connection_manager() {
        let config = ConnectionConfig::sqlite("test.db");
        let manager = ConnectionManager::new(config).expect("valid SQLite config");

        assert_eq!(manager.connection_string(), "sqlite://test.db");
        assert_eq!(*manager.database_type(), DatabaseType::SQLite);
    }

    #[test]
    fn test_get_connection_string_matches_borrowed() {
        let config = ConnectionConfig::postgres("localhost", "testdb", "user", "pass");
        let manager = ConnectionManager::new(config).expect("valid PostgreSQL config");

        assert_eq!(
            manager.get_connection_string().unwrap(),
            manager.connection_string()
        );
        assert!(manager.test_connection().is_ok());
    }

    #[test]
    fn test_invalid_config() {
        let mut config = ConnectionConfig::mysql("localhost", "testdb", "user", "pass");
        config.connection_pool_size = 0;

        assert!(config.validate().is_err());
    }
}