Skip to main content

modkit_db/
options.rs

1//! Database connection options and configuration types.
2
3use crate::config::{DbConnConfig, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5use thiserror::Error;
6
// Pool configuration (`PoolCfg`) is defined in `crate::config`.
8
/// Database connection options using typed sqlx `ConnectOptions`.
///
/// One variant per database engine. Each variant only exists when the matching
/// cargo feature (`sqlite`, `pg`, `mysql`) is enabled.
///
/// NOTE(review): the derived `Debug` delegates to sqlx's `Debug` impls for the
/// inner options, which may include credentials; prefer the `Display` impl for
/// logging — confirm against the sqlx version in use.
#[derive(Debug, Clone)]
pub enum DbConnectOptions {
    /// `SQLite` connection options (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(sea_orm::sqlx::sqlite::SqliteConnectOptions),
    /// PostgreSQL connection options (feature `pg`).
    #[cfg(feature = "pg")]
    Postgres(sea_orm::sqlx::postgres::PgConnectOptions),
    /// MySQL connection options (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(sea_orm::sqlx::mysql::MySqlConnectOptions),
}
19
/// Errors that can occur during connection option building.
///
/// NOTE(review): the option builders in this file return `DbError` variants with
/// matching names (e.g. `DbError::InvalidSqlitePragma`), not this type — confirm
/// whether `ConnectionOptionsError` is still used elsewhere.
#[derive(Debug, Error)]
pub enum ConnectionOptionsError {
    /// A `SQLite` PRAGMA parameter had an invalid value.
    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
    InvalidSqlitePragma {
        /// Name of the offending parameter.
        key: String,
        /// Human-readable reason the value was rejected.
        message: String,
    },

    /// A `SQLite` PRAGMA parameter that is not recognized.
    #[error("Unknown SQLite PRAGMA parameter: {0}")]
    UnknownSqlitePragma(String),

    /// A connection parameter was invalid; the payload describes the problem.
    #[error("Invalid connection parameter: {0}")]
    InvalidParameter(String),

    /// The requested functionality needs a cargo feature that is not compiled in.
    #[error("Feature not enabled: {0}")]
    FeatureDisabled(&'static str),

    /// Wrapped I/O error (created via `From<std::io::Error>`).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped URL parsing error (created via `From<url::ParseError>`).
    #[error("URL parsing error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// Wrapped environment-variable lookup error (created via `From<std::env::VarError>`).
    #[error("Environment variable error: {0}")]
    EnvVar(#[from] std::env::VarError),
}
44
45impl std::fmt::Display for DbConnectOptions {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            #[cfg(feature = "sqlite")]
49            DbConnectOptions::Sqlite(opts) => {
50                let filename = opts.get_filename().display().to_string();
51                if filename.is_empty() {
52                    write!(f, "sqlite://memory")
53                } else {
54                    write!(f, "sqlite://{filename}")
55                }
56            }
57            #[cfg(feature = "pg")]
58            DbConnectOptions::Postgres(opts) => {
59                write!(
60                    f,
61                    "postgresql://<redacted>@{}:{}/{}",
62                    opts.get_host(),
63                    opts.get_port(),
64                    opts.get_database().unwrap_or("")
65                )
66            }
67            #[cfg(feature = "mysql")]
68            DbConnectOptions::MySql(_opts) => {
69                write!(f, "mysql://<redacted>@...")
70            }
71            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
72            _ => {
73                unreachable!("No database features enabled")
74            }
75        }
76    }
77}
78
79impl DbConnectOptions {
80    /// Connect to the database using the configured options.
81    ///
82    /// # Errors
83    /// Returns an error if the database connection fails.
84    pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
85        match self {
86            #[cfg(feature = "sqlite")]
87            DbConnectOptions::Sqlite(opts) => {
88                let pool_opts = pool.apply_sqlite(sea_orm::sqlx::sqlite::SqlitePoolOptions::new());
89
90                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
91
92                let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool.clone());
93
94                let filename = opts.get_filename().display().to_string();
95                let handle = DbHandle {
96                    engine: crate::DbEngine::Sqlite,
97                    pool: crate::DbPool::Sqlite(sqlx_pool),
98                    dsn: format!("sqlite://{filename}"),
99                    sea,
100                };
101
102                Ok(handle)
103            }
104            #[cfg(feature = "pg")]
105            DbConnectOptions::Postgres(opts) => {
106                let pool_opts = pool.apply_pg(sea_orm::sqlx::postgres::PgPoolOptions::new());
107
108                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
109
110                let sea =
111                    sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool.clone());
112
113                let handle = DbHandle {
114                    engine: crate::DbEngine::Postgres,
115                    pool: crate::DbPool::Postgres(sqlx_pool),
116                    dsn: format!(
117                        "postgresql://<redacted>@{}:{}/{}",
118                        opts.get_host(),
119                        opts.get_port(),
120                        opts.get_database().unwrap_or("")
121                    ),
122                    sea,
123                };
124
125                Ok(handle)
126            }
127            #[cfg(feature = "mysql")]
128            DbConnectOptions::MySql(opts) => {
129                let pool_opts = pool.apply_mysql(sea_orm::sqlx::mysql::MySqlPoolOptions::new());
130
131                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
132
133                let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool.clone());
134
135                let handle = DbHandle {
136                    engine: crate::DbEngine::MySql,
137                    pool: crate::DbPool::MySql(sqlx_pool),
138                    dsn: "mysql://<redacted>@...".to_owned(),
139                    sea,
140                };
141
142                Ok(handle)
143            }
144            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
145            _ => {
146                unreachable!("No database features enabled")
147            }
148        }
149    }
150}
151
/// `SQLite` PRAGMA whitelist and validation.
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    /// Whitelisted `SQLite` PRAGMA parameters (keys are matched case-insensitively).
    const ALLOWED_PRAGMAS: &[&str] = &["wal", "synchronous", "busy_timeout", "journal_mode"];

    /// Validate and apply `SQLite` PRAGMA parameters to connection options.
    ///
    /// Keys are lower-cased and checked against [`ALLOWED_PRAGMAS`]. Entries are
    /// processed in sorted key order, so the result (and which error is reported
    /// first) does not depend on `HashMap` iteration order.
    ///
    /// `wal` (`true`/`1` → `WAL`, `false`/`0` → `DELETE`) and `journal_mode` both set
    /// the `journal_mode` PRAGMA. They may be combined, or repeated with different
    /// key casing, only when they resolve to the same mode. A conflict is an error
    /// rather than "whichever entry happened to be iterated last wins".
    ///
    /// # Errors
    /// Returns `DbError::UnknownSqlitePragma` if an unsupported pragma is provided.
    /// Returns `DbError::InvalidSqlitePragma` if a pragma value is invalid or if two
    /// entries request different journal modes.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sea_orm::sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sea_orm::sqlx::sqlite::SqliteConnectOptions> {
        // Deterministic order: by lower-cased key, then original key, then value.
        let mut entries: Vec<(String, &String, &String)> = params
            .iter()
            .map(|(key, value)| (key.to_lowercase(), key, value))
            .collect();
        entries.sort();

        // First (original key, resolved mode) that requested a journal mode.
        let mut journal_mode: Option<(&str, String)> = None;

        for (key_lower, key, value) in entries {
            if !ALLOWED_PRAGMAS.contains(&key_lower.as_str()) {
                return Err(DbError::UnknownSqlitePragma(key.clone()));
            }

            match key_lower.as_str() {
                "wal" => {
                    let mode = validate_wal_pragma(value)?.to_owned();
                    set_journal_mode(&mut journal_mode, key, mode)?;
                }
                "journal_mode" => {
                    let mode = validate_journal_mode_pragma(value)?;
                    set_journal_mode(&mut journal_mode, key, mode)?;
                }
                "synchronous" => {
                    let sync_mode = validate_synchronous_pragma(value)?;
                    opts = opts.pragma("synchronous", sync_mode);
                }
                "busy_timeout" => {
                    let timeout = validate_busy_timeout_pragma(value)?;
                    opts = opts.pragma("busy_timeout", timeout.to_string());
                }
                _ => unreachable!("Checked against whitelist above"),
            }
        }

        if let Some((_, mode)) = journal_mode {
            opts = opts.pragma("journal_mode", mode);
        }

        Ok(opts)
    }

    /// Record the journal mode requested by `key`, rejecting a different mode that an
    /// earlier entry already requested.
    fn set_journal_mode<'a>(
        current: &mut Option<(&'a str, String)>,
        key: &'a str,
        mode: String,
    ) -> crate::Result<()> {
        if let Some((prev_key, prev_mode)) = current {
            if *prev_mode != mode {
                return Err(DbError::InvalidSqlitePragma {
                    key: key.to_owned(),
                    message: format!(
                        "journal mode '{mode}' conflicts with '{prev_mode}' set by '{prev_key}'"
                    ),
                });
            }
            return Ok(());
        }
        *current = Some((key, mode));
        Ok(())
    }

    /// Validate WAL PRAGMA value: `true`/`1` → `WAL`, `false`/`0` → `DELETE`.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validate synchronous PRAGMA value; returns it upper-cased.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        if matches!(upper.as_str(), "OFF" | "NORMAL" | "FULL" | "EXTRA") {
            Ok(upper)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            })
        }
    }

    /// Validate `busy_timeout` PRAGMA value (a non-negative integer).
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validate `journal_mode` PRAGMA value; returns it upper-cased.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        if matches!(
            upper.as_str(),
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF"
        ) {
            Ok(upper)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            })
        }
    }
}
257
258/// Build a database handle from configuration.
259/// This is the unified entry point for building database handles from configuration.
260///
261/// # Errors
262/// Returns an error if the database connection fails or configuration is invalid.
263pub async fn build_db_handle(
264    mut cfg: DbConnConfig,
265    _global: Option<&GlobalDatabaseConfig>,
266) -> Result<DbHandle> {
267    // Expand environment variables in DSN and password
268    if let Some(dsn) = &cfg.dsn {
269        cfg.dsn = Some(expand_env_vars(dsn)?);
270    }
271    if let Some(password) = &cfg.password {
272        cfg.password = Some(resolve_password(password)?);
273    }
274
275    // Expand environment variables in params
276    if let Some(ref mut params) = cfg.params {
277        for (_, value) in params.iter_mut() {
278            if value.contains("${") {
279                *value = expand_env_vars(value)?;
280            }
281        }
282    }
283
284    // Validate configuration for conflicts
285    validate_config_consistency(&cfg)?;
286
287    // Determine database type and build connection options
288    let is_sqlite = cfg.file.is_some()
289        || cfg.path.is_some()
290        || cfg
291            .dsn
292            .as_ref()
293            .is_some_and(|dsn| dsn.starts_with("sqlite"))
294        || (cfg.server.is_none() && cfg.dsn.is_none());
295
296    let connect_options = if is_sqlite {
297        build_sqlite_options(&cfg)?
298    } else {
299        build_server_options(&cfg)?
300    };
301
302    // Build pool configuration
303    let pool_cfg = cfg.pool.unwrap_or_default();
304
305    // Log connection attempt (without credentials)
306    let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
307    tracing::debug!(
308        dsn = log_dsn,
309        is_sqlite = is_sqlite,
310        "Building database connection"
311    );
312
313    // Connect to database
314    let handle = connect_options.connect(pool_cfg).await?;
315
316    Ok(handle)
317}
318
/// Build `SQLite` connection options from configuration.
///
/// The database location is taken from `dsn` when present, otherwise from `path`.
/// A bare `file` is rejected because it is expected to have been resolved into
/// `path` before this point. The parent directory of the database file is created
/// if missing, the file is opened with `create_if_missing(true)`, and any `params`
/// are applied through the PRAGMA whitelist.
///
/// # Errors
/// Returns an error for a malformed DSN, a missing location, a failure to create the
/// parent directory, or an unknown/invalid PRAGMA parameter.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = match (&cfg.dsn, &cfg.path, &cfg.file) {
        (Some(dsn), _, _) => parse_sqlite_path_from_dsn(dsn)?,
        (None, Some(path), _) => path.clone(),
        (None, None, Some(_)) => {
            // manager.rs is expected to resolve `file` into an absolute `path` first.
            return Err(DbError::InvalidParameter(
                "File path should have been resolved to absolute path".to_owned(),
            ));
        }
        (None, None, None) => {
            return Err(DbError::InvalidParameter(
                "SQLite connection requires either DSN, path, or file".to_owned(),
            ));
        }
    };

    if let Some(parent_dir) = db_path.parent() {
        std::fs::create_dir_all(parent_dir)?;
    }

    let base = sea_orm::sqlx::sqlite::SqliteConnectOptions::new()
        .filename(&db_path)
        .create_if_missing(true);

    let opts = match &cfg.params {
        Some(params) => sqlite_pragma::apply_pragmas(base, params)?,
        None => base,
    };

    Ok(DbConnectOptions::Sqlite(opts))
}
353
354#[cfg(not(feature = "sqlite"))]
355fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
356    Err(DbError::FeatureDisabled("SQLite feature not enabled"))
357}
358
359/// Build server-based connection options from configuration.
360fn build_server_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
361    // Determine the database type from DSN or default to PostgreSQL
362    let scheme = if let Some(dsn) = &cfg.dsn {
363        let parsed = url::Url::parse(dsn)?;
364        parsed.scheme().to_owned()
365    } else {
366        "postgresql".to_owned()
367    };
368
369    match scheme.as_str() {
370        "postgresql" | "postgres" => {
371            #[cfg(feature = "pg")]
372            {
373                let mut opts = if let Some(dsn) = &cfg.dsn {
374                    dsn.parse::<sea_orm::sqlx::postgres::PgConnectOptions>()
375                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
376                } else {
377                    sea_orm::sqlx::postgres::PgConnectOptions::new()
378                };
379
380                // Override with individual fields
381                if let Some(host) = &cfg.host {
382                    opts = opts.host(host);
383                }
384                if let Some(port) = cfg.port {
385                    opts = opts.port(port);
386                }
387                if let Some(user) = &cfg.user {
388                    opts = opts.username(user);
389                }
390                if let Some(password) = &cfg.password {
391                    opts = opts.password(password);
392                }
393                if let Some(dbname) = &cfg.dbname {
394                    opts = opts.database(dbname);
395                } else if cfg.dsn.is_none() {
396                    return Err(DbError::InvalidParameter(
397                        "dbname is required for PostgreSQL connections".to_owned(),
398                    ));
399                }
400
401                // Apply additional parameters
402                if let Some(params) = &cfg.params {
403                    for (key, value) in params {
404                        opts = opts.options([(key.as_str(), value.as_str())]);
405                    }
406                }
407
408                Ok(DbConnectOptions::Postgres(opts))
409            }
410            #[cfg(not(feature = "pg"))]
411            {
412                Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
413            }
414        }
415        "mysql" => {
416            #[cfg(feature = "mysql")]
417            {
418                let mut opts = if let Some(dsn) = &cfg.dsn {
419                    dsn.parse::<sea_orm::sqlx::mysql::MySqlConnectOptions>()
420                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
421                } else {
422                    sea_orm::sqlx::mysql::MySqlConnectOptions::new()
423                };
424
425                // Override with individual fields
426                if let Some(host) = &cfg.host {
427                    opts = opts.host(host);
428                }
429                if let Some(port) = cfg.port {
430                    opts = opts.port(port);
431                }
432                if let Some(user) = &cfg.user {
433                    opts = opts.username(user);
434                }
435                if let Some(password) = &cfg.password {
436                    opts = opts.password(password);
437                }
438                if let Some(dbname) = &cfg.dbname {
439                    opts = opts.database(dbname);
440                } else if cfg.dsn.is_none() {
441                    return Err(DbError::InvalidParameter(
442                        "dbname is required for MySQL connections".to_owned(),
443                    ));
444                }
445
446                Ok(DbConnectOptions::MySql(opts))
447            }
448            #[cfg(not(feature = "mysql"))]
449            {
450                Err(DbError::FeatureDisabled("MySQL feature not enabled"))
451            }
452        }
453        _ => Err(DbError::InvalidParameter(format!(
454            "Unsupported database scheme: {scheme}"
455        ))),
456    }
457}
458
/// Parse `SQLite` path from DSN.
///
/// Accepts `sqlite:<path>` and `sqlite://<path>`. Everything from the first `?`
/// (query parameters) onward is discarded, and the remainder is returned verbatim
/// as a path (no percent-decoding).
///
/// # Errors
/// Returns `DbError::InvalidParameter` if the DSN does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(after_scheme) = dsn.strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    let location = after_scheme.strip_prefix("//").unwrap_or(after_scheme);
    let file_part = location
        .split_once('?')
        .map_or(location, |(before_query, _)| before_query);

    Ok(std::path::PathBuf::from(file_part))
}
488
489/// Expand environment variables in a string.
490fn expand_env_vars(input: &str) -> Result<String> {
491    let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
492        .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
493    let mut result = input.to_owned();
494
495    for caps in re.captures_iter(input) {
496        let full_match = &caps[0];
497        let var_name = &caps[1];
498        let value = std::env::var(var_name)?;
499        result = result.replace(full_match, &value);
500    }
501
502    Ok(result)
503}
504
505/// Resolve password from environment variable if it starts with ${VAR}.
506fn resolve_password(password: &str) -> Result<String> {
507    if password.starts_with("${") && password.ends_with('}') {
508        let var_name = &password[2..password.len() - 1];
509        Ok(std::env::var(var_name)?)
510    } else {
511        Ok(password.to_owned())
512    }
513}
514
515/// Validate configuration for consistency and detect conflicts.
516fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
517    // Check for SQLite vs server engine conflicts
518    if let Some(dsn) = &cfg.dsn {
519        let is_sqlite_dsn = dsn.starts_with("sqlite");
520        let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
521        let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
522
523        if is_sqlite_dsn && has_server_fields {
524            return Err(DbError::ConfigConflict(
525                "SQLite DSN cannot be used with host/port fields".to_owned(),
526            ));
527        }
528
529        if !is_sqlite_dsn && has_sqlite_fields {
530            return Err(DbError::ConfigConflict(
531                "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
532            ));
533        }
534
535        // Check for server vs non-server DSN conflicts
536        if !is_sqlite_dsn
537            && cfg.server.is_some()
538            && (cfg.host.is_some()
539                || cfg.port.is_some()
540                || cfg.user.is_some()
541                || cfg.password.is_some()
542                || cfg.dbname.is_some())
543        {
544            // This is actually allowed - server provides base config, DSN can override
545            // Fields here override DSN parts intentionally.
546        }
547    }
548
549    // Check for SQLite-specific conflicts
550    if cfg.file.is_some() && cfg.path.is_some() {
551        return Err(DbError::ConfigConflict(
552            "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
553        ));
554    }
555
556    if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
557        return Err(DbError::ConfigConflict(
558            "SQLite file/path fields cannot be used with host/port fields".to_owned(),
559        ));
560    }
561
562    Ok(())
563}
564
565/// Redact credentials from DSN for logging.
566#[must_use]
567pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
568    match dsn {
569        Some(dsn) if dsn.contains('@') => {
570            if let Ok(mut parsed) = url::Url::parse(dsn) {
571                if parsed.password().is_some() {
572                    let _ = parsed.set_password(Some("***"));
573                }
574                parsed.to_string()
575            } else {
576                "***".to_owned()
577            }
578        }
579        Some(dsn) => dsn.to_owned(),
580        None => "none".to_owned(),
581    }
582}