Skip to main content

modkit_db/
options.rs

1//! Database connection options and configuration types.
2
3use modkit_utils::var_expand::expand_env_vars;
4
5use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
6use crate::{DbError, DbHandle, Result};
7
8// Pool configuration moved to config.rs
9
/// Typed, driver-level connection options for a single database engine.
///
/// Each variant wraps the matching sqlx `ConnectOptions` type and only exists when the
/// corresponding cargo feature (`sqlite`, `pg`, `mysql`) is enabled.
#[derive(Clone, Debug)]
pub(crate) enum DbConnectOptions {
    /// `SQLite` connection options (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::sqlite::SqliteConnectOptions),
    /// `PostgreSQL` connection options (feature `pg`).
    #[cfg(feature = "pg")]
    Postgres(sqlx::postgres::PgConnectOptions),
    /// `MySQL` connection options (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(sqlx::mysql::MySqlConnectOptions),
}
20
21impl std::fmt::Display for DbConnectOptions {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            #[cfg(feature = "sqlite")]
25            DbConnectOptions::Sqlite(opts) => {
26                let filename = opts.get_filename().display().to_string();
27                if filename.is_empty() {
28                    write!(f, "sqlite://memory")
29                } else {
30                    write!(f, "sqlite://{filename}")
31                }
32            }
33            #[cfg(feature = "pg")]
34            DbConnectOptions::Postgres(opts) => {
35                write!(
36                    f,
37                    "postgresql://<redacted>@{}:{}/{}",
38                    opts.get_host(),
39                    opts.get_port(),
40                    opts.get_database().unwrap_or("")
41                )
42            }
43            #[cfg(feature = "mysql")]
44            DbConnectOptions::MySql(_opts) => {
45                write!(f, "mysql://<redacted>@...")
46            }
47            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
48            _ => {
49                unreachable!("No database features enabled")
50            }
51        }
52    }
53}
54
/// Return `true` when `path` denotes an in-memory `SQLite` database rather than a file.
///
/// The following count as in-memory:
/// - an empty path;
/// - after trimming surrounding whitespace, a blank string or one of `:memory:`, `memory:`,
///   `file::memory:`, `file:memory:`.
///
/// A path that is not valid UTF-8 is never considered in-memory.
#[cfg(feature = "sqlite")]
fn is_memory_filename(path: &std::path::Path) -> bool {
    const MEMORY_SPELLINGS: [&str; 4] = [":memory:", "memory:", "file::memory:", "file:memory:"];

    if path.as_os_str().is_empty() {
        return true;
    }

    match path.to_str() {
        None => false,
        Some(text) => {
            let candidate = text.trim();
            candidate.is_empty() || MEMORY_SPELLINGS.contains(&candidate)
        }
    }
}
69
70impl DbConnectOptions {
71    /// Connect to the database using the configured options.
72    ///
73    /// # Errors
74    /// Returns an error if the database connection fails.
75    pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
76        match self {
77            #[cfg(feature = "sqlite")]
78            DbConnectOptions::Sqlite(opts) => {
79                let mut pool_opts = pool.apply_sqlite(sqlx::sqlite::SqlitePoolOptions::new());
80
81                if is_memory_filename(opts.get_filename()) {
82                    pool_opts = pool_opts.max_connections(1).min_connections(1);
83                    tracing::info!("Using single connection pool for in-memory SQLite database");
84                }
85
86                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
87
88                let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool);
89
90                let filename = opts.get_filename().display().to_string();
91                let handle = DbHandle {
92                    engine: crate::DbEngine::Sqlite,
93                    dsn: format!("sqlite://{filename}"),
94                    sea,
95                };
96
97                Ok(handle)
98            }
99            #[cfg(feature = "pg")]
100            DbConnectOptions::Postgres(opts) => {
101                let pool_opts = pool.apply_pg(sqlx::postgres::PgPoolOptions::new());
102
103                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
104
105                let sea = sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool);
106
107                let handle = DbHandle {
108                    engine: crate::DbEngine::Postgres,
109                    dsn: format!(
110                        "postgresql://<redacted>@{}:{}/{}",
111                        opts.get_host(),
112                        opts.get_port(),
113                        opts.get_database().unwrap_or("")
114                    ),
115                    sea,
116                };
117
118                Ok(handle)
119            }
120            #[cfg(feature = "mysql")]
121            DbConnectOptions::MySql(opts) => {
122                let pool_opts = pool.apply_mysql(sqlx::mysql::MySqlPoolOptions::new());
123
124                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
125
126                let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool);
127
128                let handle = DbHandle {
129                    engine: crate::DbEngine::MySql,
130                    dsn: "mysql://<redacted>@...".to_owned(),
131                    sea,
132                };
133
134                Ok(handle)
135            }
136            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
137            _ => {
138                unreachable!("No database features enabled")
139            }
140        }
141    }
142}
143
/// `SQLite` PRAGMA whitelist and validation.
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    /// Validate and apply `SQLite` PRAGMA parameters to connection options.
    ///
    /// Supported keys, matched case-insensitively (this set is the whitelist):
    /// - `wal`: `true`/`1` sets `journal_mode=WAL`; `false`/`0` sets `journal_mode=DELETE`.
    /// - `journal_mode`: one of `DELETE`/`WAL`/`MEMORY`/`TRUNCATE`/`PERSIST`/`OFF`.
    /// - `synchronous`: one of `OFF`/`NORMAL`/`FULL`/`EXTRA`.
    /// - `busy_timeout`: a non-negative integer.
    ///
    /// Several keys can target the same PRAGMA. For example, `wal` and `journal_mode` both set
    /// `journal_mode`, and keys differing only in case map to the same PRAGMA. `params` is a
    /// hash map with unspecified iteration order, so such duplicates are accepted only when
    /// they resolve to the same value. Otherwise the outcome would depend on hash order.
    ///
    /// # Errors
    /// - Returns `DbError::UnknownSqlitePragma` if an unsupported key is provided.
    /// - Returns `DbError::InvalidSqlitePragma` if a value is invalid, or if two keys assign
    ///   different values to the same PRAGMA.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sqlx::sqlite::SqliteConnectOptions> {
        // PRAGMA name -> (config key that set it, resolved value).
        let mut applied: HashMap<&'static str, (&str, String)> = HashMap::new();

        for (key, value) in params {
            let (pragma, resolved) = resolve_pragma(key, value)?;

            if let Some((prev_key, prev_value)) = applied.get(pragma) {
                if *prev_value != resolved {
                    return Err(DbError::InvalidSqlitePragma {
                        key: key.clone(),
                        message: format!(
                            "conflicts with '{prev_key}': both set {pragma}, to '{prev_value}' and '{resolved}'"
                        ),
                    });
                }
                // Identical value already applied; nothing more to do.
                continue;
            }

            opts = opts.pragma(pragma, resolved.clone());
            applied.insert(pragma, (key.as_str(), resolved));
        }

        Ok(opts)
    }

    /// Map one configuration entry to the PRAGMA it sets and its validated value.
    fn resolve_pragma(key: &str, value: &str) -> crate::Result<(&'static str, String)> {
        match key.to_lowercase().as_str() {
            "wal" => Ok(("journal_mode", validate_wal_pragma(value)?.to_owned())),
            "journal_mode" => Ok(("journal_mode", validate_journal_mode_pragma(value)?)),
            "synchronous" => Ok(("synchronous", validate_synchronous_pragma(value)?)),
            "busy_timeout" => Ok((
                "busy_timeout",
                validate_busy_timeout_pragma(value)?.to_string(),
            )),
            _ => Err(DbError::UnknownSqlitePragma(key.to_owned())),
        }
    }

    /// Validate WAL PRAGMA value and translate it into a `journal_mode`.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validate synchronous PRAGMA value; returns it upper-cased.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(mode.as_str(), "OFF" | "NORMAL" | "FULL" | "EXTRA") {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            })
        }
    }

    /// Validate `busy_timeout` PRAGMA value (must parse as a non-negative `i64`).
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validate `journal_mode` PRAGMA value; returns it upper-cased.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(
            mode.as_str(),
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF"
        ) {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            })
        }
    }
}
249
250/// Build a database handle from configuration (internal).
251///
252/// This is an internal entry point used by `DbManager` / runtime wiring. Module code must
253/// never observe `DbHandle`; it should use `Db` or `DBProvider<E>` only.
254///
255/// # Errors
256/// Returns an error if the database connection fails or configuration is invalid.
257pub(crate) async fn build_db_handle(
258    mut cfg: DbConnConfig,
259    _global: Option<&GlobalDatabaseConfig>,
260) -> Result<DbHandle> {
261    // Expand environment variables in DSN and password
262    if let Some(dsn) = &cfg.dsn {
263        cfg.dsn = Some(expand_env_vars(dsn)?);
264    }
265    if let Some(password) = &cfg.password {
266        cfg.password = Some(resolve_password(password)?);
267    }
268
269    // Expand environment variables in params
270    if let Some(ref mut params) = cfg.params {
271        for (_, value) in params.iter_mut() {
272            if value.contains("${") {
273                *value = expand_env_vars(value)?;
274            }
275        }
276    }
277
278    // Validate configuration for conflicts
279    validate_config_consistency(&cfg)?;
280
281    // Determine database engine and build connection options.
282    let engine = determine_engine(&cfg)?;
283    let connect_options = match engine {
284        DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
285        DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
286    };
287
288    // Build pool configuration
289    let pool_cfg = cfg.pool.unwrap_or_default();
290
291    // Log connection attempt (without credentials)
292    let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
293    tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
294
295    // Connect to database
296    let handle = connect_options.connect(pool_cfg).await?;
297
298    Ok(handle)
299}
300
301fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
302    // If both engine and DSN are provided, validate they don't conflict.
303    // (We do the same check in validate_config_consistency, but keep this here to ensure
304    // determine_engine() never returns a misleading value.)
305    if let Some(engine) = cfg.engine {
306        if let Some(dsn) = cfg.dsn.as_deref() {
307            let inferred = engine_from_dsn(dsn)?;
308            if inferred != engine {
309                return Err(DbError::ConfigConflict(format!(
310                    "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
311                )));
312            }
313        }
314        return Ok(engine);
315    }
316
317    // If DSN is not provided, engine is required.
318    //
319    // Rationale:
320    // - Without DSN we cannot reliably distinguish Postgres vs MySQL.
321    // - For SQLite we also want explicit intent (file/path alone is not a transport selector).
322    if cfg.dsn.is_none() {
323        return Err(DbError::InvalidParameter(
324            "Missing 'engine': required when 'dsn' is not provided".to_owned(),
325        ));
326    }
327
328    // Infer from DSN scheme when present.
329    let Some(dsn) = cfg.dsn.as_deref() else {
330        // SAFETY: guarded above by `cfg.dsn.is_none()`.
331        return Err(DbError::InvalidParameter(
332            "Missing 'dsn': required to infer database engine".to_owned(),
333        ));
334    };
335    engine_from_dsn(dsn)
336}
337
338fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
339    let s = dsn.trim_start();
340    if s.starts_with("postgres://") || s.starts_with("postgresql://") {
341        Ok(DbEngineCfg::Postgres)
342    } else if s.starts_with("mysql://") {
343        Ok(DbEngineCfg::Mysql)
344    } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
345        Ok(DbEngineCfg::Sqlite)
346    } else {
347        Err(DbError::UnknownDsn(dsn.to_owned()))
348    }
349}
350
/// Build `SQLite` connection options from configuration.
///
/// The database location comes from the DSN if present, otherwise from `path`. For on-disk
/// databases the parent directory is created. The file itself is created on first connect
/// if missing. Whitelisted PRAGMAs from `params` are applied last.
///
/// # Errors
/// - `DbError::InvalidParameter` if no location is configured, if only the unresolved
///   `file` field is set, or if the DSN is malformed.
/// - An I/O error if the parent directory cannot be created.
/// - A PRAGMA validation error from `sqlite_pragma::apply_pragmas`.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = match (&cfg.dsn, &cfg.path, &cfg.file) {
        (Some(dsn), _, _) => parse_sqlite_path_from_dsn(dsn)?,
        (None, Some(path), _) => path.clone(),
        (None, None, Some(_)) => {
            // `manager.rs` is expected to resolve `file` into `path` before we get here.
            return Err(DbError::InvalidParameter(
                "File path should have been resolved to absolute path".to_owned(),
            ));
        }
        (None, None, None) => {
            return Err(DbError::InvalidParameter(
                "SQLite connection requires either DSN, path, or file".to_owned(),
            ));
        }
    };

    // In-memory databases have no backing file, so there is no directory to prepare.
    if !is_memory_filename(&db_path) {
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
    }

    let base = sqlx::sqlite::SqliteConnectOptions::new()
        .filename(&db_path)
        .create_if_missing(true);

    let opts = match &cfg.params {
        Some(params) => sqlite_pragma::apply_pragmas(base, params)?,
        None => base,
    };

    Ok(DbConnectOptions::Sqlite(opts))
}
385
386#[cfg(not(feature = "sqlite"))]
387fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
388    Err(DbError::FeatureDisabled("SQLite feature not enabled"))
389}
390
/// Apply PostgreSQL-specific `params` on top of `opts`.
///
/// Keys are matched case-insensitively. Keys with a typed `PgConnectOptions` setter are
/// applied directly: SSL settings, `application_name`, `statement_cache_capacity` and
/// `extra_float_digits`. Every other key is forwarded unchanged as a server runtime
/// parameter via `PgConnectOptions::options`.
///
/// # Errors
/// Returns `DbError::InvalidParameter` when a typed setting has an unparseable or
/// out-of-range value.
#[cfg(feature = "pg")]
fn apply_pg_params<S: std::hash::BuildHasher>(
    opts: sqlx::postgres::PgConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::postgres::PgConnectOptions> {
    use sqlx::postgres::PgSslMode;

    let mut out = opts;
    for (key, value) in params {
        out = match key.to_lowercase().as_str() {
            // Connection-level SSL settings.
            "sslmode" | "ssl_mode" => {
                let Ok(mode) = value.parse::<PgSslMode>() else {
                    return Err(DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disable, allow, prefer, require, verify-ca, or verify-full"
                    )));
                };
                out.ssl_mode(mode)
            }
            "sslrootcert" | "ssl_root_cert" => out.ssl_root_cert(value.as_str()),
            "sslcert" | "ssl_client_cert" => out.ssl_client_cert(value.as_str()),
            "sslkey" | "ssl_client_key" => out.ssl_client_key(value.as_str()),
            // Other typed connection-level settings.
            "application_name" => out.application_name(value),
            "statement_cache_capacity" => {
                let Ok(capacity) = value.parse::<usize>() else {
                    return Err(DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    )));
                };
                out.statement_cache_capacity(capacity)
            }
            "extra_float_digits" => {
                // Unparseable and out-of-range values share one error message.
                let digits = value
                    .parse::<i8>()
                    .ok()
                    .filter(|d| (-15..=3).contains(d))
                    .ok_or_else(|| {
                        DbError::InvalidParameter(format!(
                            "Invalid extra_float_digits '{value}': expected integer between -15 and 3"
                        ))
                    })?;
                out.extra_float_digits(digits)
            }
            // Anything else is a server runtime parameter.
            _ => out.options([(key.as_str(), value.as_str())]),
        };
    }

    Ok(out)
}
454
/// Apply `MySQL`-specific parameters. `MySQL` has no runtime options like `PostgreSQL`,
/// so all params are connection-level settings.
///
/// Keys are matched case-insensitively. Several settings accept multiple spellings
/// (e.g. `sslmode`, `ssl_mode`, `ssl-mode`).
///
/// # Errors
/// Returns `DbError::InvalidParameter` for an unknown key or an unparseable value.
#[cfg(feature = "mysql")]
fn apply_mysql_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::mysql::MySqlConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::mysql::MySqlConnectOptions> {
    use sqlx::mysql::MySqlSslMode;

    for (key, value) in params {
        let key_lower = key.to_lowercase();
        match key_lower.as_str() {
            // SSL parameters
            "sslmode" | "ssl_mode" | "ssl-mode" => {
                let mode = value.parse::<MySqlSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disabled, preferred, required, verify_ca, or verify_identity"
                    ))
                })?;
                opts = opts.ssl_mode(mode);
            }
            "sslca" | "ssl_ca" | "ssl-ca" => {
                opts = opts.ssl_ca(value.as_str());
            }
            "sslcert" | "ssl_client_cert" | "ssl-cert" => {
                opts = opts.ssl_client_cert(value.as_str());
            }
            "sslkey" | "ssl_client_key" | "ssl-key" => {
                opts = opts.ssl_client_key(value.as_str());
            }
            // Connection parameters
            "charset" => {
                opts = opts.charset(value);
            }
            "collation" => {
                opts = opts.collation(value);
            }
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts = opts.statement_cache_capacity(capacity);
            }
            "connect_timeout" | "connect-timeout" => {
                // `sqlx::mysql::MySqlConnectOptions` exposes no typed connect-timeout setter.
                // The value is still validated for compatibility with DSN-style configuration
                // and integration tests, but it does not change runtime behavior. Warn so the
                // setting is not silently ignored.
                let _secs = value.parse::<u64>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid connect_timeout '{value}': expected non-negative integer seconds"
                    ))
                })?;
                tracing::warn!(
                    param = %key,
                    "MySQL connect_timeout is accepted but has no effect: \
                     sqlx MySqlConnectOptions exposes no connect-timeout setting"
                );
            }
            "socket" => {
                opts = opts.socket(value.as_str());
            }
            "timezone" => {
                // `none` or an empty value clears the session time zone override.
                let tz = if value.eq_ignore_ascii_case("none") || value.is_empty() {
                    None
                } else {
                    Some(value.clone())
                };
                opts = opts.timezone(tz);
            }
            "pipes_as_concat" => {
                let flag = parse_bool_param("pipes_as_concat", value)?;
                opts = opts.pipes_as_concat(flag);
            }
            "no_engine_substitution" => {
                let flag = parse_bool_param("no_engine_substitution", value)?;
                opts = opts.no_engine_substitution(flag);
            }
            "enable_cleartext_plugin" => {
                let flag = parse_bool_param("enable_cleartext_plugin", value)?;
                opts = opts.enable_cleartext_plugin(flag);
            }
            "set_names" => {
                let flag = parse_bool_param("set_names", value)?;
                opts = opts.set_names(flag);
            }
            // Unknown parameters - MySQL doesn't support arbitrary runtime params
            _ => {
                return Err(DbError::InvalidParameter(format!(
                    "Unknown MySQL connection parameter: '{key}'"
                )));
            }
        }
    }

    Ok(opts)
}
549
/// Interpret a boolean-valued connection parameter.
///
/// Matching is case-insensitive. `true`/`1`/`yes`/`on` give `true`, and
/// `false`/`0`/`no`/`off` give `false`. `name` is used only in the error message.
///
/// # Errors
/// Returns `DbError::InvalidParameter` for any other value.
#[cfg(feature = "mysql")]
fn parse_bool_param(name: &str, value: &str) -> Result<bool> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 4] = ["false", "0", "no", "off"];

    let normalized = value.to_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Ok(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Ok(false)
    } else {
        Err(DbError::InvalidParameter(format!(
            "Invalid {name} '{value}': expected true/false/1/0/yes/no/on/off"
        )))
    }
}
561
562/// Build server-based connection options from configuration.
563fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
564    // When neither `pg` nor `mysql` features are enabled, the match arms that would use `cfg`
565    // are compiled out, but the function still needs to compile cleanly under `-D warnings`.
566    #[cfg(not(any(feature = "pg", feature = "mysql")))]
567    let _ = cfg;
568
569    match engine {
570        DbEngineCfg::Postgres => {
571            #[cfg(feature = "pg")]
572            {
573                let mut opts = if let Some(dsn) = &cfg.dsn {
574                    dsn.parse::<sqlx::postgres::PgConnectOptions>()
575                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
576                } else {
577                    sqlx::postgres::PgConnectOptions::new()
578                };
579
580                // Override with individual fields
581                if let Some(host) = &cfg.host {
582                    opts = opts.host(host);
583                }
584                if let Some(port) = cfg.port {
585                    opts = opts.port(port);
586                }
587                if let Some(user) = &cfg.user {
588                    opts = opts.username(user);
589                }
590                if let Some(password) = &cfg.password {
591                    opts = opts.password(password);
592                }
593                if let Some(dbname) = &cfg.dbname {
594                    opts = opts.database(dbname);
595                } else if cfg.dsn.is_none() {
596                    return Err(DbError::InvalidParameter(
597                        "dbname is required for PostgreSQL connections".to_owned(),
598                    ));
599                }
600
601                // Apply additional parameters
602                if let Some(params) = &cfg.params {
603                    opts = apply_pg_params(opts, params)?;
604                }
605
606                Ok(DbConnectOptions::Postgres(opts))
607            }
608            #[cfg(not(feature = "pg"))]
609            {
610                Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
611            }
612        }
613        DbEngineCfg::Mysql => {
614            #[cfg(feature = "mysql")]
615            {
616                let mut opts = if let Some(dsn) = &cfg.dsn {
617                    dsn.parse::<sqlx::mysql::MySqlConnectOptions>()
618                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
619                } else {
620                    sqlx::mysql::MySqlConnectOptions::new()
621                };
622
623                // Override with individual fields
624                if let Some(host) = &cfg.host {
625                    opts = opts.host(host);
626                }
627                if let Some(port) = cfg.port {
628                    opts = opts.port(port);
629                }
630                if let Some(user) = &cfg.user {
631                    opts = opts.username(user);
632                }
633                if let Some(password) = &cfg.password {
634                    opts = opts.password(password);
635                }
636                if let Some(dbname) = &cfg.dbname {
637                    opts = opts.database(dbname);
638                } else if cfg.dsn.is_none() {
639                    return Err(DbError::InvalidParameter(
640                        "dbname is required for MySQL connections".to_owned(),
641                    ));
642                }
643
644                // Apply additional parameters
645                if let Some(params) = &cfg.params {
646                    opts = apply_mysql_params(opts, params)?;
647                }
648
649                Ok(DbConnectOptions::MySql(opts))
650            }
651            #[cfg(not(feature = "mysql"))]
652            {
653                Err(DbError::FeatureDisabled("MySQL feature not enabled"))
654            }
655        }
656        DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
657            "build_server_options called with sqlite engine".to_owned(),
658        )),
659    }
660}
661
/// Parse the database path out of a `SQLite` DSN.
///
/// Accepts both `sqlite:<path>` and `sqlite://<path>` (so `sqlite:///tmp/a.db` yields
/// `/tmp/a.db`). Leading whitespace is tolerated, matching how `engine_from_dsn` classifies
/// DSNs. Anything after the first `?` is treated as query parameters and is not part of the
/// returned path.
///
/// # Errors
/// Returns `DbError::InvalidParameter` if the DSN does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.trim_start().strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    // `sqlite://x` and `sqlite:x` name the same path.
    let rest = rest.strip_prefix("//").unwrap_or(rest);

    // Drop query parameters.
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);

    Ok(std::path::PathBuf::from(path))
}
691
692/// Resolve password from environment variable if it starts with ${VAR}.
693fn resolve_password(password: &str) -> Result<String> {
694    if password.starts_with("${") && password.ends_with('}') {
695        let var_name = &password[2..password.len() - 1];
696        std::env::var(var_name).map_err(|source| DbError::EnvVar {
697            name: var_name.to_owned(),
698            source,
699        })
700    } else {
701        Ok(password.to_owned())
702    }
703}
704
705/// Validate configuration for consistency and detect conflicts.
706fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
707    // Validate engine against DSN if both are present
708    if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
709        let inferred = engine_from_dsn(dsn)?;
710        if inferred != engine {
711            return Err(DbError::ConfigConflict(format!(
712                "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
713            )));
714        }
715    }
716
717    // Check for SQLite vs server engine conflicts
718    if let Some(dsn) = &cfg.dsn {
719        let is_sqlite_dsn = dsn.starts_with("sqlite");
720        let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
721        let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
722
723        if is_sqlite_dsn && has_server_fields {
724            return Err(DbError::ConfigConflict(
725                "SQLite DSN cannot be used with host/port fields".to_owned(),
726            ));
727        }
728
729        if !is_sqlite_dsn && has_sqlite_fields {
730            return Err(DbError::ConfigConflict(
731                "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
732            ));
733        }
734
735        // Check for server vs non-server DSN conflicts
736        if !is_sqlite_dsn
737            && cfg.server.is_some()
738            && (cfg.host.is_some()
739                || cfg.port.is_some()
740                || cfg.user.is_some()
741                || cfg.password.is_some()
742                || cfg.dbname.is_some())
743        {
744            // This is actually allowed - server provides base config, DSN can override
745            // Fields here override DSN parts intentionally.
746        }
747    }
748
749    // Check for SQLite-specific conflicts
750    if cfg.file.is_some() && cfg.path.is_some() {
751        return Err(DbError::ConfigConflict(
752            "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
753        ));
754    }
755
756    if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
757        return Err(DbError::ConfigConflict(
758            "SQLite file/path fields cannot be used with host/port fields".to_owned(),
759        ));
760    }
761
762    // If engine explicitly says SQLite, reject server connection fields early (even without DSN)
763    if cfg.engine == Some(DbEngineCfg::Sqlite)
764        && (cfg.host.is_some()
765            || cfg.port.is_some()
766            || cfg.user.is_some()
767            || cfg.password.is_some()
768            || cfg.dbname.is_some())
769    {
770        return Err(DbError::ConfigConflict(
771            "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
772        ));
773    }
774
775    // If engine explicitly says server-based, reject sqlite file/path early (even without DSN)
776    if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
777        && (cfg.file.is_some() || cfg.path.is_some())
778    {
779        return Err(DbError::ConfigConflict(
780            "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
781        ));
782    }
783
784    Ok(())
785}
786
787/// Redact credentials from DSN for logging.
788#[must_use]
789pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
790    match dsn {
791        Some(dsn) if dsn.contains('@') => {
792            if let Ok(mut parsed) = url::Url::parse(dsn) {
793                if parsed.password().is_some() {
794                    _ = parsed.set_password(Some("***"));
795                }
796                parsed.to_string()
797            } else {
798                "***".to_owned()
799            }
800        }
801        Some(dsn) => dsn.to_owned(),
802        None => "none".to_owned(),
803    }
804}
805
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let cfg = DbConnConfig {
            dsn: None,
            engine: None,
            ..Default::default()
        };

        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let cfg = DbConnConfig {
            engine: None,
            dsn: Some("sqlite::memory:".to_owned()),
            ..Default::default()
        };

        let engine = determine_engine(&cfg).unwrap();
        assert_eq!(engine, DbEngineCfg::Sqlite);
    }

    #[test]
    fn determine_engine_returns_explicit_engine_without_dsn() {
        let cfg = DbConnConfig {
            engine: Some(DbEngineCfg::Postgres),
            dsn: None,
            ..Default::default()
        };

        assert_eq!(determine_engine(&cfg).unwrap(), DbEngineCfg::Postgres);
    }

    #[test]
    fn engine_from_dsn_ignores_leading_whitespace() {
        assert_eq!(
            engine_from_dsn("  postgres://localhost/db").unwrap(),
            DbEngineCfg::Postgres
        );
        assert_eq!(
            engine_from_dsn("\tmysql://localhost/db").unwrap(),
            DbEngineCfg::Mysql
        );
        assert_eq!(
            engine_from_dsn(" sqlite://relative.db").unwrap(),
            DbEngineCfg::Sqlite
        );
    }

    #[test]
    fn engine_and_dsn_match_ok() {
        let cases = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (engine, dsn) in cases {
            let cfg = DbConnConfig {
                engine: Some(engine),
                dsn: Some(dsn.to_owned()),
                ..Default::default()
            };
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), engine);
        }
    }

    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let cases = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in cases {
            let cfg = DbConnConfig {
                engine: Some(engine),
                dsn: Some(dsn.to_owned()),
                ..Default::default()
            };

            let err = validate_config_consistency(&cfg).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    #[test]
    fn sqlite_engine_rejects_server_fields() {
        let cfg = DbConnConfig {
            engine: Some(DbEngineCfg::Sqlite),
            host: Some("localhost".to_owned()),
            ..Default::default()
        };

        let err = validate_config_consistency(&cfg).unwrap_err();
        assert!(matches!(err, DbError::ConfigConflict(_)));
    }

    #[test]
    fn server_engine_rejects_sqlite_path() {
        let cfg = DbConnConfig {
            engine: Some(DbEngineCfg::Postgres),
            path: Some(std::path::PathBuf::from("/tmp/test.db")),
            ..Default::default()
        };

        let err = validate_config_consistency(&cfg).unwrap_err();
        assert!(matches!(err, DbError::ConfigConflict(_)));
    }

    #[test]
    fn plain_password_is_returned_verbatim() {
        // Only a whole-value `${VAR}` is treated as an environment reference.
        assert_eq!(resolve_password("p@ss${word").unwrap(), "p@ss${word");
        assert_eq!(resolve_password("").unwrap(), "");
    }

    #[test]
    fn unknown_dsn_is_error() {
        let cfg = DbConnConfig {
            engine: None,
            dsn: Some("unknown://localhost/db".to_owned()),
            ..Default::default()
        };

        // Consistency validation doesn't validate unknown schemes unless `engine` is set,
        // but engine determination must fail.
        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }
}