1use super::{
4 ConnectionError, ConnectionOptions, ConnectionResult, ConnectionString, Driver, MySqlOptions,
5 PoolOptions, PostgresOptions, SqliteOptions, SslMode,
6};
7use std::collections::HashMap;
8use std::time::Duration;
9use tracing::info;
10
11#[derive(Debug, Clone)]
13pub struct DatabaseConfig {
14 pub driver: Driver,
16 pub url: String,
18 pub host: Option<String>,
20 pub port: Option<u16>,
22 pub database: Option<String>,
24 pub user: Option<String>,
26 pub password: Option<String>,
28 pub connection: ConnectionOptions,
30 pub pool: PoolOptions,
32 pub postgres: Option<PostgresOptions>,
34 pub mysql: Option<MySqlOptions>,
36 pub sqlite: Option<SqliteOptions>,
38}
39
40impl DatabaseConfig {
41 pub fn postgres() -> DatabaseConfigBuilder {
43 DatabaseConfigBuilder::new(Driver::Postgres)
44 }
45
46 pub fn mysql() -> DatabaseConfigBuilder {
48 DatabaseConfigBuilder::new(Driver::MySql)
49 }
50
51 pub fn sqlite() -> DatabaseConfigBuilder {
53 DatabaseConfigBuilder::new(Driver::Sqlite)
54 }
55
56 pub fn from_url(url: &str) -> ConnectionResult<Self> {
58 let conn = ConnectionString::parse(url)?;
59 let opts = ConnectionOptions::from_params(conn.params());
60
61 let config = Self {
62 driver: conn.driver(),
63 url: url.to_string(),
64 host: conn.host().map(String::from),
65 port: conn.port(),
66 database: conn.database().map(String::from),
67 user: conn.user().map(String::from),
68 password: conn.password().map(String::from),
69 connection: opts,
70 pool: PoolOptions::default(),
71 postgres: if conn.driver() == Driver::Postgres {
72 Some(PostgresOptions::new())
73 } else {
74 None
75 },
76 mysql: if conn.driver() == Driver::MySql {
77 Some(MySqlOptions::new())
78 } else {
79 None
80 },
81 sqlite: if conn.driver() == Driver::Sqlite {
82 Some(SqliteOptions::new())
83 } else {
84 None
85 },
86 };
87
88 info!(
89 driver = %config.driver.name(),
90 host = ?config.host,
91 database = ?config.database,
92 "DatabaseConfig loaded from URL"
93 );
94
95 Ok(config)
96 }
97
98 pub fn from_env() -> ConnectionResult<Self> {
100 info!("Loading database configuration from DATABASE_URL");
101 let url = std::env::var("DATABASE_URL")
102 .map_err(|_| ConnectionError::EnvNotFound("DATABASE_URL".to_string()))?;
103 Self::from_url(&url)
104 }
105
106 pub fn to_url(&self) -> String {
108 if !self.url.is_empty() {
109 return self.url.clone();
110 }
111
112 let mut url = format!("{}://", self.driver.name());
113
114 if let Some(ref user) = self.user {
115 url.push_str(user);
116 if let Some(ref pass) = self.password {
117 url.push(':');
118 url.push_str(pass);
119 }
120 url.push('@');
121 }
122
123 if let Some(ref host) = self.host {
124 url.push_str(host);
125 if let Some(port) = self.port {
126 url.push(':');
127 url.push_str(&port.to_string());
128 }
129 }
130
131 if let Some(ref db) = self.database {
132 url.push('/');
133 url.push_str(db);
134 }
135
136 url
137 }
138}
139
140pub struct DatabaseConfigBuilder {
142 driver: Driver,
143 url: Option<String>,
144 host: Option<String>,
145 port: Option<u16>,
146 database: Option<String>,
147 user: Option<String>,
148 password: Option<String>,
149 connection: ConnectionOptions,
150 pool: PoolOptions,
151 postgres: Option<PostgresOptions>,
152 mysql: Option<MySqlOptions>,
153 sqlite: Option<SqliteOptions>,
154}
155
156impl DatabaseConfigBuilder {
157 pub fn new(driver: Driver) -> Self {
159 Self {
160 driver,
161 url: None,
162 host: None,
163 port: None,
164 database: None,
165 user: None,
166 password: None,
167 connection: ConnectionOptions::default(),
168 pool: PoolOptions::default(),
169 postgres: if driver == Driver::Postgres {
170 Some(PostgresOptions::new())
171 } else {
172 None
173 },
174 mysql: if driver == Driver::MySql {
175 Some(MySqlOptions::new())
176 } else {
177 None
178 },
179 sqlite: if driver == Driver::Sqlite {
180 Some(SqliteOptions::new())
181 } else {
182 None
183 },
184 }
185 }
186
187 pub fn url(mut self, url: impl Into<String>) -> Self {
189 self.url = Some(url.into());
190 self
191 }
192
193 pub fn host(mut self, host: impl Into<String>) -> Self {
195 self.host = Some(host.into());
196 self
197 }
198
199 pub fn port(mut self, port: u16) -> Self {
201 self.port = Some(port);
202 self
203 }
204
205 pub fn database(mut self, db: impl Into<String>) -> Self {
207 self.database = Some(db.into());
208 self
209 }
210
211 pub fn user(mut self, user: impl Into<String>) -> Self {
213 self.user = Some(user.into());
214 self
215 }
216
217 pub fn password(mut self, password: impl Into<String>) -> Self {
219 self.password = Some(password.into());
220 self
221 }
222
223 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
225 self.connection.connect_timeout = timeout;
226 self
227 }
228
229 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
231 self.connection.ssl.mode = mode;
232 self
233 }
234
235 pub fn application_name(mut self, name: impl Into<String>) -> Self {
237 self.connection.application_name = Some(name.into());
238 self
239 }
240
241 pub fn max_connections(mut self, n: u32) -> Self {
243 self.pool.max_connections = n;
244 self
245 }
246
247 pub fn min_connections(mut self, n: u32) -> Self {
249 self.pool.min_connections = n;
250 self
251 }
252
253 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
255 self.pool.idle_timeout = Some(timeout);
256 self
257 }
258
259 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
261 self.pool.max_lifetime = Some(lifetime);
262 self
263 }
264
265 pub fn postgres_options<F>(mut self, f: F) -> Self
267 where
268 F: FnOnce(PostgresOptions) -> PostgresOptions,
269 {
270 if let Some(opts) = self.postgres.take() {
271 self.postgres = Some(f(opts));
272 }
273 self
274 }
275
276 pub fn mysql_options<F>(mut self, f: F) -> Self
278 where
279 F: FnOnce(MySqlOptions) -> MySqlOptions,
280 {
281 if let Some(opts) = self.mysql.take() {
282 self.mysql = Some(f(opts));
283 }
284 self
285 }
286
287 pub fn sqlite_options<F>(mut self, f: F) -> Self
289 where
290 F: FnOnce(SqliteOptions) -> SqliteOptions,
291 {
292 if let Some(opts) = self.sqlite.take() {
293 self.sqlite = Some(f(opts));
294 }
295 self
296 }
297
298 pub fn build(self) -> ConnectionResult<DatabaseConfig> {
300 if self.url.is_none() {
302 match self.driver {
303 Driver::Postgres | Driver::MySql => {
304 if self.host.is_none() {
305 return Err(ConnectionError::MissingField("host".to_string()));
306 }
307 }
308 Driver::Sqlite => {
309 if self.database.is_none() {
310 return Err(ConnectionError::MissingField(
311 "database (file path)".to_string(),
312 ));
313 }
314 }
315 }
316 }
317
318 Ok(DatabaseConfig {
319 driver: self.driver,
320 url: self.url.unwrap_or_default(),
321 host: self.host,
322 port: self.port,
323 database: self.database,
324 user: self.user,
325 password: self.password,
326 connection: self.connection,
327 pool: self.pool,
328 postgres: self.postgres,
329 mysql: self.mysql,
330 sqlite: self.sqlite,
331 })
332 }
333}
334
335#[derive(Debug, Clone, Default)]
337pub struct MultiDatabaseConfig {
338 pub primary: Option<DatabaseConfig>,
340 pub replicas: Vec<DatabaseConfig>,
342 pub databases: HashMap<String, DatabaseConfig>,
344 pub load_balance: LoadBalanceStrategy,
346}
347
/// Strategy for spreading work across multiple configured databases.
///
/// NOTE(review): the code that acts on these variants is not in this file.
/// The per-variant descriptions below are inferred from the names — confirm
/// against the pool/router implementation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceStrategy {
    /// The default. Presumably cycles through candidates in turn.
    #[default]
    RoundRobin,
    /// Presumably picks a candidate at random.
    Random,
    /// Presumably always uses the first candidate.
    First,
    /// Presumably prefers the candidate with the lowest observed latency.
    LeastLatency,
}
361
362impl MultiDatabaseConfig {
363 pub fn new() -> Self {
365 Self::default()
366 }
367
368 pub fn primary(mut self, config: DatabaseConfig) -> Self {
370 self.primary = Some(config);
371 self
372 }
373
374 pub fn replica(mut self, config: DatabaseConfig) -> Self {
376 self.replicas.push(config);
377 self
378 }
379
380 pub fn database(mut self, name: impl Into<String>, config: DatabaseConfig) -> Self {
382 self.databases.insert(name.into(), config);
383 self
384 }
385
386 pub fn load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
388 self.load_balance = strategy;
389 self
390 }
391
392 pub fn get_primary(&self) -> Option<&DatabaseConfig> {
394 self.primary.as_ref()
395 }
396
397 pub fn get(&self, name: &str) -> Option<&DatabaseConfig> {
399 self.databases.get(name)
400 }
401
402 pub fn has_replicas(&self) -> bool {
404 !self.replicas.is_empty()
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a fixture URL, panicking on failure (fixtures are known-good).
    fn parsed(url: &str) -> DatabaseConfig {
        DatabaseConfig::from_url(url).unwrap()
    }

    #[test]
    fn test_config_from_url() {
        let config = parsed("postgres://user:pass@localhost:5432/mydb?sslmode=require");

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.port, Some(5432));
        assert_eq!(config.database.as_deref(), Some("mydb"));
        assert_eq!(config.user.as_deref(), Some("user"));
        assert!(config.postgres.is_some());
    }

    #[test]
    fn test_postgres_builder() {
        let builder = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .max_connections(20)
            .ssl_mode(SslMode::Require);
        let config = builder.build().unwrap();

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.pool.max_connections, 20);
        assert_eq!(config.connection.ssl.mode, SslMode::Require);
    }

    #[test]
    fn test_mysql_builder() {
        let config = DatabaseConfig::mysql()
            .host("127.0.0.1")
            .database("testdb")
            .user("root")
            .mysql_options(|opts| opts.charset("utf8mb4"))
            .build()
            .unwrap();

        assert_eq!(config.driver, Driver::MySql);
        let mysql = config
            .mysql
            .expect("MySql driver must carry MySqlOptions");
        assert_eq!(mysql.charset.as_deref(), Some("utf8mb4"));
    }

    #[test]
    fn test_sqlite_builder() {
        let config = DatabaseConfig::sqlite()
            .database("./data/app.db")
            .sqlite_options(|opts| opts.foreign_keys(true))
            .build()
            .unwrap();

        assert_eq!(config.driver, Driver::Sqlite);
        let sqlite = config
            .sqlite
            .expect("Sqlite driver must carry SqliteOptions");
        assert!(sqlite.foreign_keys);
    }

    #[test]
    fn test_builder_validation() {
        // Without a URL, Postgres needs a host.
        assert!(DatabaseConfig::postgres().database("mydb").build().is_err());
        // Without a URL, SQLite needs a database file path.
        assert!(DatabaseConfig::sqlite().build().is_err());
    }

    #[test]
    fn test_multi_database_config() {
        let config = MultiDatabaseConfig::new()
            .primary(parsed("postgres://localhost/primary"))
            .replica(parsed("postgres://localhost/replica1"))
            .replica(parsed("postgres://localhost/replica2"))
            .database("analytics", parsed("postgres://localhost/analytics"))
            .load_balance(LoadBalanceStrategy::RoundRobin);

        assert!(config.get_primary().is_some());
        assert_eq!(config.replicas.len(), 2);
        assert!(config.get("analytics").is_some());
        assert!(config.has_replicas());
    }

    #[test]
    fn test_to_url() {
        let config = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .build()
            .unwrap();

        let url = config.to_url();
        for fragment in ["postgres://", "user:pass@", "localhost:5432", "/mydb"] {
            assert!(url.contains(fragment), "{} should contain {}", url, fragment);
        }
    }
}