prax_query/connection/
config.rs

1//! Database configuration.
2
3use super::{
4    ConnectionError, ConnectionOptions, ConnectionResult, ConnectionString, Driver, MySqlOptions,
5    PoolOptions, PostgresOptions, SqliteOptions, SslMode,
6};
7use std::collections::HashMap;
8use std::time::Duration;
9use tracing::info;
10
11/// Complete database configuration.
12#[derive(Debug, Clone)]
13pub struct DatabaseConfig {
14    /// Database driver.
15    pub driver: Driver,
16    /// Connection URL.
17    pub url: String,
18    /// Host (if not in URL).
19    pub host: Option<String>,
20    /// Port (if not in URL).
21    pub port: Option<u16>,
22    /// Database name (if not in URL).
23    pub database: Option<String>,
24    /// Username (if not in URL).
25    pub user: Option<String>,
26    /// Password (if not in URL).
27    pub password: Option<String>,
28    /// Connection options.
29    pub connection: ConnectionOptions,
30    /// Pool options.
31    pub pool: PoolOptions,
32    /// PostgreSQL-specific options.
33    pub postgres: Option<PostgresOptions>,
34    /// MySQL-specific options.
35    pub mysql: Option<MySqlOptions>,
36    /// SQLite-specific options.
37    pub sqlite: Option<SqliteOptions>,
38}
39
40impl DatabaseConfig {
41    /// Create a new PostgreSQL configuration builder.
42    pub fn postgres() -> DatabaseConfigBuilder {
43        DatabaseConfigBuilder::new(Driver::Postgres)
44    }
45
46    /// Create a new MySQL configuration builder.
47    pub fn mysql() -> DatabaseConfigBuilder {
48        DatabaseConfigBuilder::new(Driver::MySql)
49    }
50
51    /// Create a new SQLite configuration builder.
52    pub fn sqlite() -> DatabaseConfigBuilder {
53        DatabaseConfigBuilder::new(Driver::Sqlite)
54    }
55
56    /// Create configuration from a connection string.
57    pub fn from_url(url: &str) -> ConnectionResult<Self> {
58        let conn = ConnectionString::parse(url)?;
59        let opts = ConnectionOptions::from_params(conn.params());
60
61        let config = Self {
62            driver: conn.driver(),
63            url: url.to_string(),
64            host: conn.host().map(String::from),
65            port: conn.port(),
66            database: conn.database().map(String::from),
67            user: conn.user().map(String::from),
68            password: conn.password().map(String::from),
69            connection: opts,
70            pool: PoolOptions::default(),
71            postgres: if conn.driver() == Driver::Postgres {
72                Some(PostgresOptions::new())
73            } else {
74                None
75            },
76            mysql: if conn.driver() == Driver::MySql {
77                Some(MySqlOptions::new())
78            } else {
79                None
80            },
81            sqlite: if conn.driver() == Driver::Sqlite {
82                Some(SqliteOptions::new())
83            } else {
84                None
85            },
86        };
87
88        info!(
89            driver = %config.driver.name(),
90            host = ?config.host,
91            database = ?config.database,
92            "DatabaseConfig loaded from URL"
93        );
94
95        Ok(config)
96    }
97
98    /// Create configuration from DATABASE_URL environment variable.
99    pub fn from_env() -> ConnectionResult<Self> {
100        info!("Loading database configuration from DATABASE_URL");
101        let url = std::env::var("DATABASE_URL")
102            .map_err(|_| ConnectionError::EnvNotFound("DATABASE_URL".to_string()))?;
103        Self::from_url(&url)
104    }
105
106    /// Build a connection URL from the configuration.
107    pub fn to_url(&self) -> String {
108        if !self.url.is_empty() {
109            return self.url.clone();
110        }
111
112        let mut url = format!("{}://", self.driver.name());
113
114        if let Some(ref user) = self.user {
115            url.push_str(user);
116            if let Some(ref pass) = self.password {
117                url.push(':');
118                url.push_str(pass);
119            }
120            url.push('@');
121        }
122
123        if let Some(ref host) = self.host {
124            url.push_str(host);
125            if let Some(port) = self.port {
126                url.push(':');
127                url.push_str(&port.to_string());
128            }
129        }
130
131        if let Some(ref db) = self.database {
132            url.push('/');
133            url.push_str(db);
134        }
135
136        url
137    }
138}
139
/// Builder for database configuration.
///
/// Obtained from `DatabaseConfig::postgres()`, `DatabaseConfig::mysql()`,
/// `DatabaseConfig::sqlite()` or `DatabaseConfigBuilder::new`, and consumed by `build()`.
/// Each field mirrors the `DatabaseConfig` field of the same name. The exception is `url`,
/// which is optional here and becomes an empty string in the built config when never set.
pub struct DatabaseConfigBuilder {
    // Driver chosen at construction; no setter changes it.
    driver: Driver,
    // Full connection URL; when set, `build()` skips the host/database presence checks.
    url: Option<String>,
    host: Option<String>,
    port: Option<u16>,
    database: Option<String>,
    user: Option<String>,
    password: Option<String>,
    connection: ConnectionOptions,
    pool: PoolOptions,
    // Only the options for `driver` start out as `Some`; the others stay `None`.
    postgres: Option<PostgresOptions>,
    mysql: Option<MySqlOptions>,
    sqlite: Option<SqliteOptions>,
}
155
156impl DatabaseConfigBuilder {
157    /// Create a new builder for the given driver.
158    pub fn new(driver: Driver) -> Self {
159        Self {
160            driver,
161            url: None,
162            host: None,
163            port: None,
164            database: None,
165            user: None,
166            password: None,
167            connection: ConnectionOptions::default(),
168            pool: PoolOptions::default(),
169            postgres: if driver == Driver::Postgres {
170                Some(PostgresOptions::new())
171            } else {
172                None
173            },
174            mysql: if driver == Driver::MySql {
175                Some(MySqlOptions::new())
176            } else {
177                None
178            },
179            sqlite: if driver == Driver::Sqlite {
180                Some(SqliteOptions::new())
181            } else {
182                None
183            },
184        }
185    }
186
187    /// Set the connection URL (overrides other connection settings).
188    pub fn url(mut self, url: impl Into<String>) -> Self {
189        self.url = Some(url.into());
190        self
191    }
192
193    /// Set the host.
194    pub fn host(mut self, host: impl Into<String>) -> Self {
195        self.host = Some(host.into());
196        self
197    }
198
199    /// Set the port.
200    pub fn port(mut self, port: u16) -> Self {
201        self.port = Some(port);
202        self
203    }
204
205    /// Set the database name.
206    pub fn database(mut self, db: impl Into<String>) -> Self {
207        self.database = Some(db.into());
208        self
209    }
210
211    /// Set the username.
212    pub fn user(mut self, user: impl Into<String>) -> Self {
213        self.user = Some(user.into());
214        self
215    }
216
217    /// Set the password.
218    pub fn password(mut self, password: impl Into<String>) -> Self {
219        self.password = Some(password.into());
220        self
221    }
222
223    /// Set connection timeout.
224    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
225        self.connection.connect_timeout = timeout;
226        self
227    }
228
229    /// Set SSL mode.
230    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
231        self.connection.ssl.mode = mode;
232        self
233    }
234
235    /// Set application name.
236    pub fn application_name(mut self, name: impl Into<String>) -> Self {
237        self.connection.application_name = Some(name.into());
238        self
239    }
240
241    /// Set max connections.
242    pub fn max_connections(mut self, n: u32) -> Self {
243        self.pool.max_connections = n;
244        self
245    }
246
247    /// Set min connections.
248    pub fn min_connections(mut self, n: u32) -> Self {
249        self.pool.min_connections = n;
250        self
251    }
252
253    /// Set idle timeout.
254    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
255        self.pool.idle_timeout = Some(timeout);
256        self
257    }
258
259    /// Set max lifetime.
260    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
261        self.pool.max_lifetime = Some(lifetime);
262        self
263    }
264
265    /// Configure PostgreSQL options.
266    pub fn postgres_options<F>(mut self, f: F) -> Self
267    where
268        F: FnOnce(PostgresOptions) -> PostgresOptions,
269    {
270        if let Some(opts) = self.postgres.take() {
271            self.postgres = Some(f(opts));
272        }
273        self
274    }
275
276    /// Configure MySQL options.
277    pub fn mysql_options<F>(mut self, f: F) -> Self
278    where
279        F: FnOnce(MySqlOptions) -> MySqlOptions,
280    {
281        if let Some(opts) = self.mysql.take() {
282            self.mysql = Some(f(opts));
283        }
284        self
285    }
286
287    /// Configure SQLite options.
288    pub fn sqlite_options<F>(mut self, f: F) -> Self
289    where
290        F: FnOnce(SqliteOptions) -> SqliteOptions,
291    {
292        if let Some(opts) = self.sqlite.take() {
293            self.sqlite = Some(f(opts));
294        }
295        self
296    }
297
298    /// Build the configuration.
299    pub fn build(self) -> ConnectionResult<DatabaseConfig> {
300        // Validate required fields based on driver
301        if self.url.is_none() {
302            match self.driver {
303                Driver::Postgres | Driver::MySql => {
304                    if self.host.is_none() {
305                        return Err(ConnectionError::MissingField("host".to_string()));
306                    }
307                }
308                Driver::Sqlite => {
309                    if self.database.is_none() {
310                        return Err(ConnectionError::MissingField(
311                            "database (file path)".to_string(),
312                        ));
313                    }
314                }
315            }
316        }
317
318        Ok(DatabaseConfig {
319            driver: self.driver,
320            url: self.url.unwrap_or_default(),
321            host: self.host,
322            port: self.port,
323            database: self.database,
324            user: self.user,
325            password: self.password,
326            connection: self.connection,
327            pool: self.pool,
328            postgres: self.postgres,
329            mysql: self.mysql,
330            sqlite: self.sqlite,
331        })
332    }
333}
334
/// Configuration for multiple databases.
///
/// Groups an optional primary, an ordered list of read replicas, and additional databases
/// addressed by name. `Default` (also `MultiDatabaseConfig::new()`) yields no primary, no
/// replicas, no named databases and `LoadBalanceStrategy::RoundRobin`.
#[derive(Debug, Clone, Default)]
pub struct MultiDatabaseConfig {
    /// Primary database configuration.
    pub primary: Option<DatabaseConfig>,
    /// Read replica configurations, in the order they were added.
    pub replicas: Vec<DatabaseConfig>,
    /// Named database configurations; adding a name that already exists replaces its entry.
    pub databases: HashMap<String, DatabaseConfig>,
    /// Load balancing strategy for replicas.
    pub load_balance: LoadBalanceStrategy,
}
347
/// Load balancing strategy for read replicas.
///
/// This file only records the chosen strategy (see `MultiDatabaseConfig::load_balance`); the
/// code that picks a replica according to it lives elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Round-robin between replicas (the default).
    #[default]
    RoundRobin,
    /// Random selection.
    Random,
    /// Use the first available replica.
    First,
    /// Use the replica with lowest latency.
    LeastLatency,
}
361
362impl MultiDatabaseConfig {
363    /// Create a new multi-database configuration.
364    pub fn new() -> Self {
365        Self::default()
366    }
367
368    /// Set the primary database.
369    pub fn primary(mut self, config: DatabaseConfig) -> Self {
370        self.primary = Some(config);
371        self
372    }
373
374    /// Add a read replica.
375    pub fn replica(mut self, config: DatabaseConfig) -> Self {
376        self.replicas.push(config);
377        self
378    }
379
380    /// Add a named database.
381    pub fn database(mut self, name: impl Into<String>, config: DatabaseConfig) -> Self {
382        self.databases.insert(name.into(), config);
383        self
384    }
385
386    /// Set load balancing strategy.
387    pub fn load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
388        self.load_balance = strategy;
389        self
390    }
391
392    /// Get the primary database configuration.
393    pub fn get_primary(&self) -> Option<&DatabaseConfig> {
394        self.primary.as_ref()
395    }
396
397    /// Get a named database configuration.
398    pub fn get(&self, name: &str) -> Option<&DatabaseConfig> {
399        self.databases.get(name)
400    }
401
402    /// Check if replicas are configured.
403    pub fn has_replicas(&self) -> bool {
404        !self.replicas.is_empty()
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_from_url() {
        let url = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let config = DatabaseConfig::from_url(url).unwrap();

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.port, Some(5432));
        assert_eq!(config.database.as_deref(), Some("mydb"));
        assert_eq!(config.user.as_deref(), Some("user"));
        assert!(config.postgres.is_some());
    }

    #[test]
    fn test_postgres_builder() {
        let built = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .max_connections(20)
            .ssl_mode(SslMode::Require)
            .build();
        let config = built.unwrap();

        assert_eq!(config.driver, Driver::Postgres);
        assert_eq!(config.host.as_deref(), Some("localhost"));
        assert_eq!(config.pool.max_connections, 20);
        assert_eq!(config.connection.ssl.mode, SslMode::Require);
    }

    #[test]
    fn test_mysql_builder() {
        let config = DatabaseConfig::mysql()
            .host("127.0.0.1")
            .database("testdb")
            .user("root")
            .mysql_options(|opts| opts.charset("utf8mb4"))
            .build()
            .unwrap();

        assert_eq!(config.driver, Driver::MySql);
        let mysql = config
            .mysql
            .expect("a MySQL config must carry MySQL options");
        assert_eq!(mysql.charset.as_deref(), Some("utf8mb4"));
    }

    #[test]
    fn test_sqlite_builder() {
        let config = DatabaseConfig::sqlite()
            .database("./data/app.db")
            .sqlite_options(|opts| opts.foreign_keys(true))
            .build()
            .unwrap();

        assert_eq!(config.driver, Driver::Sqlite);
        let sqlite = config
            .sqlite
            .expect("a SQLite config must carry SQLite options");
        assert!(sqlite.foreign_keys);
    }

    #[test]
    fn test_builder_validation() {
        let missing_host = DatabaseConfig::postgres().database("mydb").build();
        assert!(missing_host.is_err(), "PostgreSQL without a host must be rejected");

        let missing_path = DatabaseConfig::sqlite().build();
        assert!(missing_path.is_err(), "SQLite without a file path must be rejected");
    }

    #[test]
    fn test_multi_database_config() {
        let pg = |db: &str| {
            DatabaseConfig::from_url(&format!("postgres://localhost/{}", db)).unwrap()
        };

        let config = MultiDatabaseConfig::new()
            .primary(pg("primary"))
            .replica(pg("replica1"))
            .replica(pg("replica2"))
            .database("analytics", pg("analytics"))
            .load_balance(LoadBalanceStrategy::RoundRobin);

        assert!(config.get_primary().is_some());
        assert_eq!(config.replicas.len(), 2);
        assert!(config.get("analytics").is_some());
        assert!(config.has_replicas());
    }

    #[test]
    fn test_to_url() {
        let config = DatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("user")
            .password("pass")
            .build()
            .unwrap();

        let url = config.to_url();
        for fragment in ["postgres://", "user:pass@", "localhost:5432", "/mydb"] {
            assert!(url.contains(fragment), "expected {:?} in {:?}", fragment, url);
        }
    }
}