1use super::{
4 ConnectionError, ConnectionOptions, ConnectionResult, ConnectionString, Driver,
5 MySqlOptions, PoolOptions, PostgresOptions, SqliteOptions, SslMode,
6};
7use std::collections::HashMap;
8use std::time::Duration;
9use tracing::info;
10
11#[derive(Debug, Clone)]
13pub struct DatabaseConfig {
14 pub driver: Driver,
16 pub url: String,
18 pub host: Option<String>,
20 pub port: Option<u16>,
22 pub database: Option<String>,
24 pub user: Option<String>,
26 pub password: Option<String>,
28 pub connection: ConnectionOptions,
30 pub pool: PoolOptions,
32 pub postgres: Option<PostgresOptions>,
34 pub mysql: Option<MySqlOptions>,
36 pub sqlite: Option<SqliteOptions>,
38}
39
40impl DatabaseConfig {
41 pub fn postgres() -> DatabaseConfigBuilder {
43 DatabaseConfigBuilder::new(Driver::Postgres)
44 }
45
46 pub fn mysql() -> DatabaseConfigBuilder {
48 DatabaseConfigBuilder::new(Driver::MySql)
49 }
50
51 pub fn sqlite() -> DatabaseConfigBuilder {
53 DatabaseConfigBuilder::new(Driver::Sqlite)
54 }
55
56 pub fn from_url(url: &str) -> ConnectionResult<Self> {
58 let conn = ConnectionString::parse(url)?;
59 let opts = ConnectionOptions::from_params(conn.params());
60
61 let config = Self {
62 driver: conn.driver(),
63 url: url.to_string(),
64 host: conn.host().map(String::from),
65 port: conn.port(),
66 database: conn.database().map(String::from),
67 user: conn.user().map(String::from),
68 password: conn.password().map(String::from),
69 connection: opts,
70 pool: PoolOptions::default(),
71 postgres: if conn.driver() == Driver::Postgres {
72 Some(PostgresOptions::new())
73 } else {
74 None
75 },
76 mysql: if conn.driver() == Driver::MySql {
77 Some(MySqlOptions::new())
78 } else {
79 None
80 },
81 sqlite: if conn.driver() == Driver::Sqlite {
82 Some(SqliteOptions::new())
83 } else {
84 None
85 },
86 };
87
88 info!(
89 driver = %config.driver.name(),
90 host = ?config.host,
91 database = ?config.database,
92 "DatabaseConfig loaded from URL"
93 );
94
95 Ok(config)
96 }
97
98 pub fn from_env() -> ConnectionResult<Self> {
100 info!("Loading database configuration from DATABASE_URL");
101 let url = std::env::var("DATABASE_URL")
102 .map_err(|_| ConnectionError::EnvNotFound("DATABASE_URL".to_string()))?;
103 Self::from_url(&url)
104 }
105
106 pub fn to_url(&self) -> String {
108 if !self.url.is_empty() {
109 return self.url.clone();
110 }
111
112 let mut url = format!("{}://", self.driver.name());
113
114 if let Some(ref user) = self.user {
115 url.push_str(user);
116 if let Some(ref pass) = self.password {
117 url.push(':');
118 url.push_str(pass);
119 }
120 url.push('@');
121 }
122
123 if let Some(ref host) = self.host {
124 url.push_str(host);
125 if let Some(port) = self.port {
126 url.push(':');
127 url.push_str(&port.to_string());
128 }
129 }
130
131 if let Some(ref db) = self.database {
132 url.push('/');
133 url.push_str(db);
134 }
135
136 url
137 }
138}
139
140pub struct DatabaseConfigBuilder {
142 driver: Driver,
143 url: Option<String>,
144 host: Option<String>,
145 port: Option<u16>,
146 database: Option<String>,
147 user: Option<String>,
148 password: Option<String>,
149 connection: ConnectionOptions,
150 pool: PoolOptions,
151 postgres: Option<PostgresOptions>,
152 mysql: Option<MySqlOptions>,
153 sqlite: Option<SqliteOptions>,
154}
155
156impl DatabaseConfigBuilder {
157 pub fn new(driver: Driver) -> Self {
159 Self {
160 driver,
161 url: None,
162 host: None,
163 port: None,
164 database: None,
165 user: None,
166 password: None,
167 connection: ConnectionOptions::default(),
168 pool: PoolOptions::default(),
169 postgres: if driver == Driver::Postgres {
170 Some(PostgresOptions::new())
171 } else {
172 None
173 },
174 mysql: if driver == Driver::MySql {
175 Some(MySqlOptions::new())
176 } else {
177 None
178 },
179 sqlite: if driver == Driver::Sqlite {
180 Some(SqliteOptions::new())
181 } else {
182 None
183 },
184 }
185 }
186
187 pub fn url(mut self, url: impl Into<String>) -> Self {
189 self.url = Some(url.into());
190 self
191 }
192
193 pub fn host(mut self, host: impl Into<String>) -> Self {
195 self.host = Some(host.into());
196 self
197 }
198
199 pub fn port(mut self, port: u16) -> Self {
201 self.port = Some(port);
202 self
203 }
204
205 pub fn database(mut self, db: impl Into<String>) -> Self {
207 self.database = Some(db.into());
208 self
209 }
210
211 pub fn user(mut self, user: impl Into<String>) -> Self {
213 self.user = Some(user.into());
214 self
215 }
216
217 pub fn password(mut self, password: impl Into<String>) -> Self {
219 self.password = Some(password.into());
220 self
221 }
222
223 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
225 self.connection.connect_timeout = timeout;
226 self
227 }
228
229 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
231 self.connection.ssl.mode = mode;
232 self
233 }
234
235 pub fn application_name(mut self, name: impl Into<String>) -> Self {
237 self.connection.application_name = Some(name.into());
238 self
239 }
240
241 pub fn max_connections(mut self, n: u32) -> Self {
243 self.pool.max_connections = n;
244 self
245 }
246
247 pub fn min_connections(mut self, n: u32) -> Self {
249 self.pool.min_connections = n;
250 self
251 }
252
253 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
255 self.pool.idle_timeout = Some(timeout);
256 self
257 }
258
259 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
261 self.pool.max_lifetime = Some(lifetime);
262 self
263 }
264
265 pub fn postgres_options<F>(mut self, f: F) -> Self
267 where
268 F: FnOnce(PostgresOptions) -> PostgresOptions,
269 {
270 if let Some(opts) = self.postgres.take() {
271 self.postgres = Some(f(opts));
272 }
273 self
274 }
275
276 pub fn mysql_options<F>(mut self, f: F) -> Self
278 where
279 F: FnOnce(MySqlOptions) -> MySqlOptions,
280 {
281 if let Some(opts) = self.mysql.take() {
282 self.mysql = Some(f(opts));
283 }
284 self
285 }
286
287 pub fn sqlite_options<F>(mut self, f: F) -> Self
289 where
290 F: FnOnce(SqliteOptions) -> SqliteOptions,
291 {
292 if let Some(opts) = self.sqlite.take() {
293 self.sqlite = Some(f(opts));
294 }
295 self
296 }
297
298 pub fn build(self) -> ConnectionResult<DatabaseConfig> {
300 if self.url.is_none() {
302 match self.driver {
303 Driver::Postgres | Driver::MySql => {
304 if self.host.is_none() {
305 return Err(ConnectionError::MissingField("host".to_string()));
306 }
307 }
308 Driver::Sqlite => {
309 if self.database.is_none() {
310 return Err(ConnectionError::MissingField("database (file path)".to_string()));
311 }
312 }
313 }
314 }
315
316 Ok(DatabaseConfig {
317 driver: self.driver,
318 url: self.url.unwrap_or_default(),
319 host: self.host,
320 port: self.port,
321 database: self.database,
322 user: self.user,
323 password: self.password,
324 connection: self.connection,
325 pool: self.pool,
326 postgres: self.postgres,
327 mysql: self.mysql,
328 sqlite: self.sqlite,
329 })
330 }
331}
332
333#[derive(Debug, Clone, Default)]
335pub struct MultiDatabaseConfig {
336 pub primary: Option<DatabaseConfig>,
338 pub replicas: Vec<DatabaseConfig>,
340 pub databases: HashMap<String, DatabaseConfig>,
342 pub load_balance: LoadBalanceStrategy,
344}
345
/// How a connection target is chosen when several are available.
///
/// The selection logic itself is not in this file. The per-variant notes
/// below are inferred from the names and should be confirmed against the
/// consumer.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceStrategy {
    /// Presumably cycles through targets in order. This is the default.
    #[default]
    RoundRobin,
    /// Presumably picks a target at random.
    Random,
    /// Presumably always uses the first available target.
    First,
    /// Presumably prefers the target with the lowest observed latency.
    LeastLatency,
}
359
360impl MultiDatabaseConfig {
361 pub fn new() -> Self {
363 Self::default()
364 }
365
366 pub fn primary(mut self, config: DatabaseConfig) -> Self {
368 self.primary = Some(config);
369 self
370 }
371
372 pub fn replica(mut self, config: DatabaseConfig) -> Self {
374 self.replicas.push(config);
375 self
376 }
377
378 pub fn database(mut self, name: impl Into<String>, config: DatabaseConfig) -> Self {
380 self.databases.insert(name.into(), config);
381 self
382 }
383
384 pub fn load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
386 self.load_balance = strategy;
387 self
388 }
389
390 pub fn get_primary(&self) -> Option<&DatabaseConfig> {
392 self.primary.as_ref()
393 }
394
395 pub fn get(&self, name: &str) -> Option<&DatabaseConfig> {
397 self.databases.get(name)
398 }
399
400 pub fn has_replicas(&self) -> bool {
402 !self.replicas.is_empty()
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_from_url() {
        let url = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let config = DatabaseConfig::from_url(url).expect("valid postgres URL should parse");

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.port, Some(5432));
        assert_eq!(config.database.as_deref(), Some("mydb"));
        assert_eq!(config.user.as_deref(), Some("user"));
        assert!(config.postgres.is_some());
    }

    #[test]
    fn test_postgres_builder() {
        let builder = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .max_connections(20)
            .ssl_mode(SslMode::Require);
        let config = builder.build().expect("complete postgres builder should build");

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.pool.max_connections, 20);
        assert_eq!(config.connection.ssl.mode, SslMode::Require);
    }

    #[test]
    fn test_mysql_builder() {
        let config = DatabaseConfig::mysql()
            .host("127.0.0.1")
            .database("testdb")
            .user("root")
            .mysql_options(|opts| opts.charset("utf8mb4"))
            .build()
            .expect("complete mysql builder should build");

        assert_eq!(config.driver, Driver::MySql);
        let mysql = config
            .mysql
            .expect("a MySQL builder should carry MySqlOptions");
        assert_eq!(mysql.charset, Some(String::from("utf8mb4")));
    }

    #[test]
    fn test_sqlite_builder() {
        let config = DatabaseConfig::sqlite()
            .database("./data/app.db")
            .sqlite_options(|opts| opts.foreign_keys(true))
            .build()
            .expect("sqlite builder with a path should build");

        assert_eq!(config.driver, Driver::Sqlite);
        let sqlite = config
            .sqlite
            .expect("a SQLite builder should carry SqliteOptions");
        assert!(sqlite.foreign_keys);
    }

    #[test]
    fn test_builder_validation() {
        let missing_host = DatabaseConfig::postgres().database("mydb").build();
        assert!(
            missing_host.is_err(),
            "postgres config without a host must be rejected"
        );

        let missing_path = DatabaseConfig::sqlite().build();
        assert!(
            missing_path.is_err(),
            "sqlite config without a database path must be rejected"
        );
    }

    #[test]
    fn test_multi_database_config() {
        let parse = |url: &str| DatabaseConfig::from_url(url).expect("test URL should parse");

        let config = MultiDatabaseConfig::new()
            .primary(parse("postgres://localhost/primary"))
            .replica(parse("postgres://localhost/replica1"))
            .replica(parse("postgres://localhost/replica2"))
            .database("analytics", parse("postgres://localhost/analytics"))
            .load_balance(LoadBalanceStrategy::RoundRobin);

        assert!(config.get_primary().is_some());
        assert_eq!(config.replicas.len(), 2);
        assert!(config.get("analytics").is_some());
        assert!(config.has_replicas());
    }

    #[test]
    fn test_to_url() {
        let config = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .build()
            .expect("complete postgres builder should build");

        let url = config.to_url();
        for fragment in ["postgres://", "user:pass@", "localhost:5432", "/mydb"] {
            assert!(url.contains(fragment), "expected `{}` in {}", fragment, url);
        }
    }
}
517
518