1use std::collections::HashMap;
4use crate::error::{DataForgeError, Result};
5
/// Database engines that a [`ConnectionConfig`] can describe.
///
/// The enum is a payload-free tag, so it is `Copy` and can be compared,
/// hashed and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// MySQL; rendered with the `mysql://` URL scheme.
    MySQL,
    /// PostgreSQL; rendered with the `postgresql://` URL scheme.
    PostgreSQL,
    /// SQLite; rendered with the `sqlite://` URL scheme.
    SQLite,
}
13
14#[derive(Debug, Clone)]
16pub struct ConnectionConfig {
17 pub database_type: DatabaseType,
18 pub host: Option<String>,
19 pub port: Option<u16>,
20 pub database: String,
21 pub username: Option<String>,
22 pub password: Option<String>,
23 pub connection_pool_size: usize,
24 pub connection_timeout: u64,
25 pub additional_params: HashMap<String, String>,
26}
27
28impl ConnectionConfig {
29 pub fn sqlite<P: AsRef<str>>(database_path: P) -> Self {
31 Self {
32 database_type: DatabaseType::SQLite,
33 host: None,
34 port: None,
35 database: database_path.as_ref().to_string(),
36 username: None,
37 password: None,
38 connection_pool_size: 10,
39 connection_timeout: 30,
40 additional_params: HashMap::new(),
41 }
42 }
43
44 pub fn mysql<H, D, U, P>(host: H, database: D, username: U, password: P) -> Self
46 where
47 H: AsRef<str>,
48 D: AsRef<str>,
49 U: AsRef<str>,
50 P: AsRef<str>,
51 {
52 Self {
53 database_type: DatabaseType::MySQL,
54 host: Some(host.as_ref().to_string()),
55 port: Some(3306),
56 database: database.as_ref().to_string(),
57 username: Some(username.as_ref().to_string()),
58 password: Some(password.as_ref().to_string()),
59 connection_pool_size: 10,
60 connection_timeout: 30,
61 additional_params: HashMap::new(),
62 }
63 }
64
65 pub fn postgres<H, D, U, P>(host: H, database: D, username: U, password: P) -> Self
67 where
68 H: AsRef<str>,
69 D: AsRef<str>,
70 U: AsRef<str>,
71 P: AsRef<str>,
72 {
73 Self {
74 database_type: DatabaseType::PostgreSQL,
75 host: Some(host.as_ref().to_string()),
76 port: Some(5432),
77 database: database.as_ref().to_string(),
78 username: Some(username.as_ref().to_string()),
79 password: Some(password.as_ref().to_string()),
80 connection_pool_size: 10,
81 connection_timeout: 30,
82 additional_params: HashMap::new(),
83 }
84 }
85
86 pub fn with_port(mut self, port: u16) -> Self {
88 self.port = Some(port);
89 self
90 }
91
92 pub fn with_pool_size(mut self, size: usize) -> Self {
94 self.connection_pool_size = size;
95 self
96 }
97
98 pub fn with_timeout(mut self, timeout: u64) -> Self {
100 self.connection_timeout = timeout;
101 self
102 }
103
104 pub fn with_param<K, V>(mut self, key: K, value: V) -> Self
106 where
107 K: AsRef<str>,
108 V: AsRef<str>,
109 {
110 self.additional_params.insert(
111 key.as_ref().to_string(),
112 value.as_ref().to_string(),
113 );
114 self
115 }
116
117 pub fn to_connection_string(&self) -> Result<String> {
119 match self.database_type {
120 DatabaseType::SQLite => {
121 Ok(format!("sqlite://{}", self.database))
122 },
123 DatabaseType::MySQL => {
124 let host = self.host.as_ref()
125 .ok_or_else(|| DataForgeError::config("MySQL host is required"))?;
126 let port = self.port.unwrap_or(3306);
127 let username = self.username.as_ref()
128 .ok_or_else(|| DataForgeError::config("MySQL username is required"))?;
129 let password = self.password.as_ref()
130 .ok_or_else(|| DataForgeError::config("MySQL password is required"))?;
131
132 let mut conn_str = format!(
133 "mysql://{}:{}@{}:{}/{}",
134 username, password, host, port, self.database
135 );
136
137 if !self.additional_params.is_empty() {
138 let params: Vec<String> = self.additional_params
139 .iter()
140 .map(|(k, v)| format!("{}={}", k, v))
141 .collect();
142 conn_str.push('?');
143 conn_str.push_str(¶ms.join("&"));
144 }
145
146 Ok(conn_str)
147 },
148 DatabaseType::PostgreSQL => {
149 let host = self.host.as_ref()
150 .ok_or_else(|| DataForgeError::config("PostgreSQL host is required"))?;
151 let port = self.port.unwrap_or(5432);
152 let username = self.username.as_ref()
153 .ok_or_else(|| DataForgeError::config("PostgreSQL username is required"))?;
154 let password = self.password.as_ref()
155 .ok_or_else(|| DataForgeError::config("PostgreSQL password is required"))?;
156
157 let mut conn_str = format!(
158 "postgresql://{}:{}@{}:{}/{}",
159 username, password, host, port, self.database
160 );
161
162 if !self.additional_params.is_empty() {
163 let params: Vec<String> = self.additional_params
164 .iter()
165 .map(|(k, v)| format!("{}={}", k, v))
166 .collect();
167 conn_str.push('?');
168 conn_str.push_str(¶ms.join("&"));
169 }
170
171 Ok(conn_str)
172 },
173 }
174 }
175
176 pub fn validate(&self) -> Result<()> {
178 match self.database_type {
179 DatabaseType::SQLite => {
180 if self.database.is_empty() {
181 return Err(DataForgeError::config("SQLite database path cannot be empty"));
182 }
183 },
184 DatabaseType::MySQL | DatabaseType::PostgreSQL => {
185 if self.host.is_none() {
186 return Err(DataForgeError::config("Host is required for MySQL/PostgreSQL"));
187 }
188 if self.username.is_none() {
189 return Err(DataForgeError::config("Username is required for MySQL/PostgreSQL"));
190 }
191 if self.password.is_none() {
192 return Err(DataForgeError::config("Password is required for MySQL/PostgreSQL"));
193 }
194 if self.database.is_empty() {
195 return Err(DataForgeError::config("Database name cannot be empty"));
196 }
197 },
198 }
199
200 if self.connection_pool_size == 0 {
201 return Err(DataForgeError::config("Connection pool size must be greater than 0"));
202 }
203
204 Ok(())
205 }
206}
207
208pub struct ConnectionManager {
210 config: ConnectionConfig,
211 connection_string: String,
212}
213
214impl ConnectionManager {
215 pub fn new(config: ConnectionConfig) -> Result<Self> {
217 config.validate()?;
218 let connection_string = config.to_connection_string()?;
219
220 Ok(Self {
221 config,
222 connection_string,
223 })
224 }
225
226 pub fn connection_string(&self) -> &str {
228 &self.connection_string
229 }
230
231 pub fn get_connection_string(&self) -> Result<String> {
233 Ok(self.connection_string.clone())
234 }
235
236 pub fn database_type(&self) -> &DatabaseType {
238 &self.config.database_type
239 }
240
241 pub fn config(&self) -> &ConnectionConfig {
243 &self.config
244 }
245
246 pub fn test_connection(&self) -> Result<()> {
248 self.config.validate()
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// SQLite configs carry only a path and render as `sqlite://<path>`.
    #[test]
    fn test_sqlite_config() {
        let config = ConnectionConfig::sqlite("test.db");
        assert_eq!(config.database_type, DatabaseType::SQLite);
        assert_eq!(config.database, "test.db");
        assert!(config.host.is_none());
        assert!(config.port.is_none());
        assert!(config.validate().is_ok());

        let conn_str = config.to_connection_string().unwrap();
        assert_eq!(conn_str, "sqlite://test.db");
    }

    /// Port overrides and extra parameters show up in the MySQL URL.
    #[test]
    fn test_mysql_config() {
        let config = ConnectionConfig::mysql("localhost", "testdb", "user", "pass")
            .with_port(3307)
            .with_param("charset", "utf8mb4");

        assert_eq!(config.database_type, DatabaseType::MySQL);
        assert_eq!(config.port, Some(3307));
        assert!(config.validate().is_ok());

        // Exact match (not `contains`) so stray prefixes/suffixes are caught.
        let conn_str = config.to_connection_string().unwrap();
        assert_eq!(
            conn_str,
            "mysql://user:pass@localhost:3307/testdb?charset=utf8mb4"
        );
    }

    /// Pool size and timeout builders update the config; default port is 5432.
    #[test]
    fn test_postgres_config() {
        let config = ConnectionConfig::postgres("localhost", "testdb", "user", "pass")
            .with_pool_size(20)
            .with_timeout(60);

        assert_eq!(config.database_type, DatabaseType::PostgreSQL);
        assert_eq!(config.connection_pool_size, 20);
        assert_eq!(config.connection_timeout, 60);
        assert!(config.validate().is_ok());

        let conn_str = config.to_connection_string().unwrap();
        assert_eq!(conn_str, "postgresql://user:pass@localhost:5432/testdb");
    }

    /// A valid config yields a manager exposing the same connection string.
    #[test]
    fn test_connection_manager() {
        let config = ConnectionConfig::sqlite("test.db");
        let manager = ConnectionManager::new(config);

        assert!(manager.is_ok());
        let manager = manager.ok().unwrap();
        assert_eq!(manager.connection_string(), "sqlite://test.db");
        assert_eq!(manager.get_connection_string().unwrap(), "sqlite://test.db");
        assert_eq!(*manager.database_type(), DatabaseType::SQLite);
        assert_eq!(manager.config().database, "test.db");
        assert!(manager.test_connection().is_ok());
    }

    /// A pool size of zero is rejected by validation and by the manager.
    #[test]
    fn test_invalid_config() {
        let mut config = ConnectionConfig::mysql("localhost", "testdb", "user", "pass");
        config.connection_pool_size = 0;

        assert!(config.validate().is_err());
        assert!(ConnectionManager::new(config).is_err());
    }

    /// An empty SQLite path or database name fails validation.
    #[test]
    fn test_empty_database_is_rejected() {
        assert!(ConnectionConfig::sqlite("").validate().is_err());
        assert!(ConnectionConfig::mysql("localhost", "", "user", "pass")
            .validate()
            .is_err());
    }

    /// Missing host, user name or password is reported by both checks.
    #[test]
    fn test_missing_network_fields() {
        let mut no_host = ConnectionConfig::mysql("localhost", "testdb", "user", "pass");
        no_host.host = None;
        assert!(no_host.validate().is_err());
        assert!(no_host.to_connection_string().is_err());

        let mut no_user = ConnectionConfig::postgres("localhost", "testdb", "user", "pass");
        no_user.username = None;
        assert!(no_user.validate().is_err());
        assert!(no_user.to_connection_string().is_err());

        let mut no_password = ConnectionConfig::postgres("localhost", "testdb", "user", "pass");
        no_password.password = None;
        assert!(no_password.validate().is_err());
        assert!(no_password.to_connection_string().is_err());
    }
}