Skip to main content

modkit_db/
options.rs

1//! Database connection options and configuration types.
2
3use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5
// Pool configuration (`PoolCfg`) is defined in `config.rs`.
7
/// Typed sqlx connection options for the engine a configuration resolved to.
///
/// Only the variants whose Cargo feature (`sqlite`, `pg`, `mysql`) is enabled exist.
#[derive(Clone, Debug)]
pub(crate) enum DbConnectOptions {
    /// `SQLite` connection options (requires the `sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::sqlite::SqliteConnectOptions),
    /// `PostgreSQL` connection options (requires the `pg` feature).
    #[cfg(feature = "pg")]
    Postgres(sqlx::postgres::PgConnectOptions),
    /// `MySQL` connection options (requires the `mysql` feature).
    #[cfg(feature = "mysql")]
    MySql(sqlx::mysql::MySqlConnectOptions),
}
18
19impl std::fmt::Display for DbConnectOptions {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            #[cfg(feature = "sqlite")]
23            DbConnectOptions::Sqlite(opts) => {
24                let filename = opts.get_filename().display().to_string();
25                if filename.is_empty() {
26                    write!(f, "sqlite://memory")
27                } else {
28                    write!(f, "sqlite://{filename}")
29                }
30            }
31            #[cfg(feature = "pg")]
32            DbConnectOptions::Postgres(opts) => {
33                write!(
34                    f,
35                    "postgresql://<redacted>@{}:{}/{}",
36                    opts.get_host(),
37                    opts.get_port(),
38                    opts.get_database().unwrap_or("")
39                )
40            }
41            #[cfg(feature = "mysql")]
42            DbConnectOptions::MySql(_opts) => {
43                write!(f, "mysql://<redacted>@...")
44            }
45            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
46            _ => {
47                unreachable!("No database features enabled")
48            }
49        }
50    }
51}
52
/// Return `true` if `path` is one of the spellings this module treats as an in-memory
/// `SQLite` database.
///
/// Recognised names, compared after trimming surrounding whitespace: the empty path,
/// `:memory:`, `memory:`, `file::memory:` and `file:memory:`. A path that is not valid
/// UTF-8 is never considered in-memory.
#[cfg(feature = "sqlite")]
fn is_memory_filename(path: &std::path::Path) -> bool {
    const IN_MEMORY_NAMES: [&str; 5] = ["", ":memory:", "memory:", "file::memory:", "file:memory:"];

    // An empty `OsStr` is always valid UTF-8, so the empty path hits the "" entry.
    matches!(path.to_str(), Some(raw) if IN_MEMORY_NAMES.contains(&raw.trim()))
}
67
68impl DbConnectOptions {
69    /// Connect to the database using the configured options.
70    ///
71    /// # Errors
72    /// Returns an error if the database connection fails.
73    pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
74        match self {
75            #[cfg(feature = "sqlite")]
76            DbConnectOptions::Sqlite(opts) => {
77                let mut pool_opts = pool.apply_sqlite(sqlx::sqlite::SqlitePoolOptions::new());
78
79                if is_memory_filename(opts.get_filename()) {
80                    pool_opts = pool_opts.max_connections(1).min_connections(1);
81                    tracing::info!("Using single connection pool for in-memory SQLite database");
82                }
83
84                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
85
86                let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool);
87
88                let filename = opts.get_filename().display().to_string();
89                let handle = DbHandle {
90                    engine: crate::DbEngine::Sqlite,
91                    dsn: format!("sqlite://{filename}"),
92                    sea,
93                };
94
95                Ok(handle)
96            }
97            #[cfg(feature = "pg")]
98            DbConnectOptions::Postgres(opts) => {
99                let pool_opts = pool.apply_pg(sqlx::postgres::PgPoolOptions::new());
100
101                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
102
103                let sea = sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool);
104
105                let handle = DbHandle {
106                    engine: crate::DbEngine::Postgres,
107                    dsn: format!(
108                        "postgresql://<redacted>@{}:{}/{}",
109                        opts.get_host(),
110                        opts.get_port(),
111                        opts.get_database().unwrap_or("")
112                    ),
113                    sea,
114                };
115
116                Ok(handle)
117            }
118            #[cfg(feature = "mysql")]
119            DbConnectOptions::MySql(opts) => {
120                let pool_opts = pool.apply_mysql(sqlx::mysql::MySqlPoolOptions::new());
121
122                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
123
124                let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool);
125
126                let handle = DbHandle {
127                    engine: crate::DbEngine::MySql,
128                    dsn: "mysql://<redacted>@...".to_owned(),
129                    sea,
130                };
131
132                Ok(handle)
133            }
134            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
135            _ => {
136                unreachable!("No database features enabled")
137            }
138        }
139    }
140}
141
/// `SQLite` PRAGMA whitelist and validation.
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::btree_map::Entry;
    use std::collections::{BTreeMap, HashMap};
    use std::hash::BuildHasher;

    /// Whitelisted `SQLite` PRAGMA config keys (matched case-insensitively).
    const ALLOWED_PRAGMAS: &[&str] = &["wal", "synchronous", "busy_timeout", "journal_mode"];

    /// Validate and apply `SQLite` PRAGMA parameters to connection options.
    ///
    /// Keys are matched case-insensitively against the whitelist:
    /// - `wal` (`true`/`false`/`1`/`0`) sets PRAGMA `journal_mode` to `WAL` or `DELETE`;
    /// - `journal_mode` sets PRAGMA `journal_mode` directly;
    /// - `synchronous` and `busy_timeout` set the PRAGMA of the same name.
    ///
    /// Several entries may target the same PRAGMA: `wal` and `journal_mode` do, and so do
    /// keys that differ only by case. They are accepted only if they resolve to the same
    /// value. Otherwise the outcome would depend on `HashMap` iteration order, so a
    /// conflict error is returned instead.
    ///
    /// # Errors
    /// Returns `DbError::UnknownSqlitePragma` if an unsupported pragma is provided.
    /// Returns `DbError::InvalidSqlitePragma` if a pragma value is invalid, or if two
    /// entries set the same PRAGMA to different values.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sqlx::sqlite::SqliteConnectOptions> {
        // PRAGMA name -> (config key that set it, validated value). A BTreeMap so the
        // PRAGMAs are applied in a fixed order regardless of `params` iteration order.
        let mut resolved: BTreeMap<&'static str, (&str, String)> = BTreeMap::new();

        for (key, value) in params {
            let key_lower = key.to_lowercase();

            if !ALLOWED_PRAGMAS.contains(&key_lower.as_str()) {
                return Err(DbError::UnknownSqlitePragma(key.clone()));
            }

            let (pragma, setting): (&'static str, String) = match key_lower.as_str() {
                "wal" => ("journal_mode", validate_wal_pragma(value)?.to_owned()),
                "journal_mode" => ("journal_mode", validate_journal_mode_pragma(value)?),
                "synchronous" => ("synchronous", validate_synchronous_pragma(value)?),
                "busy_timeout" => (
                    "busy_timeout",
                    validate_busy_timeout_pragma(value)?.to_string(),
                ),
                // Not reachable after the whitelist check; fail closed instead of panicking.
                _ => return Err(DbError::UnknownSqlitePragma(key.clone())),
            };

            match resolved.entry(pragma) {
                Entry::Vacant(slot) => {
                    slot.insert((key.as_str(), setting));
                }
                Entry::Occupied(existing) => {
                    let (previous_key, previous_setting) = existing.get();
                    if *previous_setting != setting {
                        return Err(DbError::InvalidSqlitePragma {
                            key: key.clone(),
                            message: format!(
                                "conflicts with '{previous_key}': both set PRAGMA {pragma}, to '{previous_setting}' and '{setting}'"
                            ),
                        });
                    }
                }
            }
        }

        for (pragma, (_, setting)) in resolved {
            opts = opts.pragma(pragma, setting);
        }

        Ok(opts)
    }

    /// Validate WAL PRAGMA value.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validate synchronous PRAGMA value.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        match value.to_uppercase().as_str() {
            "OFF" | "NORMAL" | "FULL" | "EXTRA" => Ok(value.to_uppercase()),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            }),
        }
    }

    /// Validate `busy_timeout` PRAGMA value.
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validate `journal_mode` PRAGMA value.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        match value.to_uppercase().as_str() {
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF" => {
                Ok(value.to_uppercase())
            }
            _ => Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            }),
        }
    }
}
247
248/// Build a database handle from configuration (internal).
249///
250/// This is an internal entry point used by `DbManager` / runtime wiring. Module code must
251/// never observe `DbHandle`; it should use `Db` or `DBProvider<E>` only.
252///
253/// # Errors
254/// Returns an error if the database connection fails or configuration is invalid.
255pub(crate) async fn build_db_handle(
256    mut cfg: DbConnConfig,
257    _global: Option<&GlobalDatabaseConfig>,
258) -> Result<DbHandle> {
259    // Expand environment variables in DSN and password
260    if let Some(dsn) = &cfg.dsn {
261        cfg.dsn = Some(expand_env_vars(dsn)?);
262    }
263    if let Some(password) = &cfg.password {
264        cfg.password = Some(resolve_password(password)?);
265    }
266
267    // Expand environment variables in params
268    if let Some(ref mut params) = cfg.params {
269        for (_, value) in params.iter_mut() {
270            if value.contains("${") {
271                *value = expand_env_vars(value)?;
272            }
273        }
274    }
275
276    // Validate configuration for conflicts
277    validate_config_consistency(&cfg)?;
278
279    // Determine database engine and build connection options.
280    let engine = determine_engine(&cfg)?;
281    let connect_options = match engine {
282        DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
283        DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
284    };
285
286    // Build pool configuration
287    let pool_cfg = cfg.pool.unwrap_or_default();
288
289    // Log connection attempt (without credentials)
290    let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
291    tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
292
293    // Connect to database
294    let handle = connect_options.connect(pool_cfg).await?;
295
296    Ok(handle)
297}
298
299fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
300    // If both engine and DSN are provided, validate they don't conflict.
301    // (We do the same check in validate_config_consistency, but keep this here to ensure
302    // determine_engine() never returns a misleading value.)
303    if let Some(engine) = cfg.engine {
304        if let Some(dsn) = cfg.dsn.as_deref() {
305            let inferred = engine_from_dsn(dsn)?;
306            if inferred != engine {
307                return Err(DbError::ConfigConflict(format!(
308                    "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
309                )));
310            }
311        }
312        return Ok(engine);
313    }
314
315    // If DSN is not provided, engine is required.
316    //
317    // Rationale:
318    // - Without DSN we cannot reliably distinguish Postgres vs MySQL.
319    // - For SQLite we also want explicit intent (file/path alone is not a transport selector).
320    if cfg.dsn.is_none() {
321        return Err(DbError::InvalidParameter(
322            "Missing 'engine': required when 'dsn' is not provided".to_owned(),
323        ));
324    }
325
326    // Infer from DSN scheme when present.
327    let Some(dsn) = cfg.dsn.as_deref() else {
328        // SAFETY: guarded above by `cfg.dsn.is_none()`.
329        return Err(DbError::InvalidParameter(
330            "Missing 'dsn': required to infer database engine".to_owned(),
331        ));
332    };
333    engine_from_dsn(dsn)
334}
335
336fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
337    let s = dsn.trim_start();
338    if s.starts_with("postgres://") || s.starts_with("postgresql://") {
339        Ok(DbEngineCfg::Postgres)
340    } else if s.starts_with("mysql://") {
341        Ok(DbEngineCfg::Mysql)
342    } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
343        Ok(DbEngineCfg::Sqlite)
344    } else {
345        Err(DbError::UnknownDsn(dsn.to_owned()))
346    }
347}
348
/// Build `SQLite` connection options from configuration.
///
/// The database location is taken from `dsn` (via `parse_sqlite_path_from_dsn`) if
/// present, otherwise from `path`. A bare `file` is rejected, since the manager is
/// expected to have resolved it into `path` already.
///
/// The function then:
/// - creates the path's parent directory if needed;
/// - enables `create_if_missing`;
/// - applies any `params` as whitelisted PRAGMAs.
///
/// # Errors
/// - `DbError::InvalidParameter` when no usable location is configured.
/// - The converted I/O error if the parent directory cannot be created.
/// - The DSN or PRAGMA validation error from the helpers above.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = match (&cfg.dsn, &cfg.path, &cfg.file) {
        (Some(dsn), _, _) => parse_sqlite_path_from_dsn(dsn)?,
        (None, Some(path), _) => path.clone(),
        (None, None, Some(_)) => {
            // This should not happen as manager.rs should have resolved file to path
            return Err(DbError::InvalidParameter(
                "File path should have been resolved to absolute path".to_owned(),
            ));
        }
        (None, None, None) => {
            return Err(DbError::InvalidParameter(
                "SQLite connection requires either DSN, path, or file".to_owned(),
            ));
        }
    };

    // SQLite can create the database file but not missing directories above it.
    if let Some(parent_dir) = db_path.parent() {
        std::fs::create_dir_all(parent_dir)?;
    }

    let base = sqlx::sqlite::SqliteConnectOptions::new()
        .filename(&db_path)
        .create_if_missing(true);

    let opts = match &cfg.params {
        Some(params) => sqlite_pragma::apply_pragmas(base, params)?,
        None => base,
    };

    Ok(DbConnectOptions::Sqlite(opts))
}
383
384#[cfg(not(feature = "sqlite"))]
385fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
386    Err(DbError::FeatureDisabled("SQLite feature not enabled"))
387}
388
/// Apply PostgreSQL-specific parameters, distinguishing connection-level params from
/// server runtime options.
///
/// Recognised keys (case-insensitive) configure the connection itself:
/// - TLS: `sslmode`/`ssl_mode`, `sslrootcert`/`ssl_root_cert`,
///   `sslcert`/`ssl_client_cert`, `sslkey`/`ssl_client_key`;
/// - `application_name`, `statement_cache_capacity`, `extra_float_digits`.
///
/// Every other key is forwarded as a server runtime parameter through
/// `PgConnectOptions::options`. The value is escaped so that whitespace and backslashes
/// survive the server's option splitting (e.g. `search_path = "public, app"`).
///
/// # Errors
/// Returns `DbError::InvalidParameter` if `sslmode`, `statement_cache_capacity` or
/// `extra_float_digits` has an unparsable or out-of-range value.
#[cfg(feature = "pg")]
fn apply_pg_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::postgres::PgConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::postgres::PgConnectOptions> {
    use sqlx::postgres::PgSslMode;

    for (key, value) in params {
        let key_lower = key.to_lowercase();
        match key_lower.as_str() {
            // Connection-level SSL parameters
            "sslmode" | "ssl_mode" => {
                let mode = value.parse::<PgSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disable, allow, prefer, require, verify-ca, or verify-full"
                    ))
                })?;
                opts = opts.ssl_mode(mode);
            }
            "sslrootcert" | "ssl_root_cert" => {
                opts = opts.ssl_root_cert(value.as_str());
            }
            "sslcert" | "ssl_client_cert" => {
                opts = opts.ssl_client_cert(value.as_str());
            }
            "sslkey" | "ssl_client_key" => {
                opts = opts.ssl_client_key(value.as_str());
            }
            // Other connection-level parameters
            "application_name" => {
                opts = opts.application_name(value);
            }
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts = opts.statement_cache_capacity(capacity);
            }
            "extra_float_digits" => {
                // Parse failure and out-of-range share one error.
                let digits = value
                    .parse::<i8>()
                    .ok()
                    .filter(|digits| (-15..=3).contains(digits))
                    .ok_or_else(|| {
                        DbError::InvalidParameter(format!(
                            "Invalid extra_float_digits '{value}': expected integer between -15 and 3"
                        ))
                    })?;
                opts = opts.extra_float_digits(digits);
            }
            // Server runtime parameters go to options()
            _ => {
                // NOTE(review): assumes sqlx inserts option values verbatim (true for
                // sqlx 0.7/0.8); re-check on upgrade to avoid double escaping.
                let escaped = escape_pg_option_value(value);
                opts = opts.options([(key.as_str(), escaped.as_str())]);
            }
        }
    }

    Ok(opts)
}

/// Escape a value for PostgreSQL's startup `options` string.
///
/// The server splits that string into arguments on unescaped whitespace and treats `\` as
/// an escape character, so both are prefixed with `\`.
#[cfg(feature = "pg")]
fn escape_pg_option_value(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if ch == '\\' || ch.is_whitespace() {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
452
/// Apply `MySQL`-specific connection parameters.
///
/// Unlike `PostgreSQL`, `MySQL` has no free-form runtime options. Every key
/// (case-insensitive, with `_`/`-` spellings for SSL and timeout keys) must be one of the
/// settings matched below, and anything else is rejected. `connect_timeout` is validated
/// but does not currently change behaviour.
///
/// # Errors
/// Returns `DbError::InvalidParameter` for an unknown key or an unparsable value.
#[cfg(feature = "mysql")]
fn apply_mysql_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::mysql::MySqlConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::mysql::MySqlConnectOptions> {
    use sqlx::mysql::MySqlSslMode;

    for (key, value) in params {
        opts = match key.to_lowercase().as_str() {
            // SSL parameters
            "sslmode" | "ssl_mode" | "ssl-mode" => {
                let mode = value.parse::<MySqlSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disabled, preferred, required, verify_ca, or verify_identity"
                    ))
                })?;
                opts.ssl_mode(mode)
            }
            "sslca" | "ssl_ca" | "ssl-ca" => opts.ssl_ca(value.as_str()),
            "sslcert" | "ssl_client_cert" | "ssl-cert" => opts.ssl_client_cert(value.as_str()),
            "sslkey" | "ssl_client_key" | "ssl-key" => opts.ssl_client_key(value.as_str()),
            // Connection parameters
            "charset" => opts.charset(value),
            "collation" => opts.collation(value),
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts.statement_cache_capacity(capacity)
            }
            "connect_timeout" | "connect-timeout" => {
                // `MySqlConnectOptions` has no typed connect-timeout setter. The value is
                // still validated for DSN-style compatibility, then ignored.
                let _secs = value.parse::<u64>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid connect_timeout '{value}': expected non-negative integer seconds"
                    ))
                })?;
                opts
            }
            "socket" => opts.socket(value.as_str()),
            "timezone" => {
                // Empty or "none" (any case) clears the session time zone setting.
                let tz = (!value.is_empty() && !value.eq_ignore_ascii_case("none"))
                    .then(|| value.clone());
                opts.timezone(tz)
            }
            "pipes_as_concat" => opts.pipes_as_concat(parse_bool_param("pipes_as_concat", value)?),
            "no_engine_substitution" => {
                opts.no_engine_substitution(parse_bool_param("no_engine_substitution", value)?)
            }
            "enable_cleartext_plugin" => {
                opts.enable_cleartext_plugin(parse_bool_param("enable_cleartext_plugin", value)?)
            }
            "set_names" => opts.set_names(parse_bool_param("set_names", value)?),
            // Unknown parameters - MySQL doesn't support arbitrary runtime params
            _ => {
                return Err(DbError::InvalidParameter(format!(
                    "Unknown MySQL connection parameter: '{key}'"
                )));
            }
        };
    }

    Ok(opts)
}
547
/// Parse a boolean connection-parameter value, case-insensitively.
///
/// Accepts `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`. `name` only appears in the
/// error message.
///
/// # Errors
/// Returns `DbError::InvalidParameter` for any other value.
#[cfg(feature = "mysql")]
fn parse_bool_param(name: &str, value: &str) -> Result<bool> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 4] = ["false", "0", "no", "off"];

    let normalized = value.to_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Ok(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Ok(false)
    } else {
        Err(DbError::InvalidParameter(format!(
            "Invalid {name} '{value}': expected true/false/1/0/yes/no/on/off"
        )))
    }
}
559
560/// Build server-based connection options from configuration.
561fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
562    // When neither `pg` nor `mysql` features are enabled, the match arms that would use `cfg`
563    // are compiled out, but the function still needs to compile cleanly under `-D warnings`.
564    #[cfg(not(any(feature = "pg", feature = "mysql")))]
565    let _ = cfg;
566
567    match engine {
568        DbEngineCfg::Postgres => {
569            #[cfg(feature = "pg")]
570            {
571                let mut opts = if let Some(dsn) = &cfg.dsn {
572                    dsn.parse::<sqlx::postgres::PgConnectOptions>()
573                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
574                } else {
575                    sqlx::postgres::PgConnectOptions::new()
576                };
577
578                // Override with individual fields
579                if let Some(host) = &cfg.host {
580                    opts = opts.host(host);
581                }
582                if let Some(port) = cfg.port {
583                    opts = opts.port(port);
584                }
585                if let Some(user) = &cfg.user {
586                    opts = opts.username(user);
587                }
588                if let Some(password) = &cfg.password {
589                    opts = opts.password(password);
590                }
591                if let Some(dbname) = &cfg.dbname {
592                    opts = opts.database(dbname);
593                } else if cfg.dsn.is_none() {
594                    return Err(DbError::InvalidParameter(
595                        "dbname is required for PostgreSQL connections".to_owned(),
596                    ));
597                }
598
599                // Apply additional parameters
600                if let Some(params) = &cfg.params {
601                    opts = apply_pg_params(opts, params)?;
602                }
603
604                Ok(DbConnectOptions::Postgres(opts))
605            }
606            #[cfg(not(feature = "pg"))]
607            {
608                Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
609            }
610        }
611        DbEngineCfg::Mysql => {
612            #[cfg(feature = "mysql")]
613            {
614                let mut opts = if let Some(dsn) = &cfg.dsn {
615                    dsn.parse::<sqlx::mysql::MySqlConnectOptions>()
616                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
617                } else {
618                    sqlx::mysql::MySqlConnectOptions::new()
619                };
620
621                // Override with individual fields
622                if let Some(host) = &cfg.host {
623                    opts = opts.host(host);
624                }
625                if let Some(port) = cfg.port {
626                    opts = opts.port(port);
627                }
628                if let Some(user) = &cfg.user {
629                    opts = opts.username(user);
630                }
631                if let Some(password) = &cfg.password {
632                    opts = opts.password(password);
633                }
634                if let Some(dbname) = &cfg.dbname {
635                    opts = opts.database(dbname);
636                } else if cfg.dsn.is_none() {
637                    return Err(DbError::InvalidParameter(
638                        "dbname is required for MySQL connections".to_owned(),
639                    ));
640                }
641
642                // Apply additional parameters
643                if let Some(params) = &cfg.params {
644                    opts = apply_mysql_params(opts, params)?;
645                }
646
647                Ok(DbConnectOptions::MySql(opts))
648            }
649            #[cfg(not(feature = "mysql"))]
650            {
651                Err(DbError::FeatureDisabled("MySQL feature not enabled"))
652            }
653        }
654        DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
655            "build_server_options called with sqlite engine".to_owned(),
656        )),
657    }
658}
659
/// Extract the `SQLite` database path from a `sqlite:` DSN.
///
/// Both `sqlite:<path>` and `sqlite://<path>` are accepted. Leading whitespace is
/// ignored, matching `engine_from_dsn`, and any `?query` suffix is discarded. The path is
/// not percent-decoded.
///
/// Examples:
/// - `sqlite::memory:` → `:memory:`
/// - `sqlite:///tmp/app.db` → `/tmp/app.db`
/// - `sqlite://data/app.db?mode=rwc` → `data/app.db`
///
/// # Errors
/// Returns `DbError::InvalidParameter` if the DSN does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.trim_start().strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    // `sqlite://path` and `sqlite:path` name the same location.
    let rest = rest.strip_prefix("//").unwrap_or(rest);

    // Query parameters are not interpreted here; drop them from the path.
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);

    Ok(std::path::PathBuf::from(path))
}
689
690/// Expand environment variables in a string.
691fn expand_env_vars(input: &str) -> Result<String> {
692    let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
693        .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
694    let mut result = input.to_owned();
695
696    for caps in re.captures_iter(input) {
697        let full_match = &caps[0];
698        let var_name = &caps[1];
699        let value = std::env::var(var_name)?;
700        result = result.replace(full_match, &value);
701    }
702
703    Ok(result)
704}
705
706/// Resolve password from environment variable if it starts with ${VAR}.
707fn resolve_password(password: &str) -> Result<String> {
708    if password.starts_with("${") && password.ends_with('}') {
709        let var_name = &password[2..password.len() - 1];
710        Ok(std::env::var(var_name)?)
711    } else {
712        Ok(password.to_owned())
713    }
714}
715
716/// Validate configuration for consistency and detect conflicts.
717fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
718    // Validate engine against DSN if both are present
719    if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
720        let inferred = engine_from_dsn(dsn)?;
721        if inferred != engine {
722            return Err(DbError::ConfigConflict(format!(
723                "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
724            )));
725        }
726    }
727
728    // Check for SQLite vs server engine conflicts
729    if let Some(dsn) = &cfg.dsn {
730        let is_sqlite_dsn = dsn.starts_with("sqlite");
731        let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
732        let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
733
734        if is_sqlite_dsn && has_server_fields {
735            return Err(DbError::ConfigConflict(
736                "SQLite DSN cannot be used with host/port fields".to_owned(),
737            ));
738        }
739
740        if !is_sqlite_dsn && has_sqlite_fields {
741            return Err(DbError::ConfigConflict(
742                "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
743            ));
744        }
745
746        // Check for server vs non-server DSN conflicts
747        if !is_sqlite_dsn
748            && cfg.server.is_some()
749            && (cfg.host.is_some()
750                || cfg.port.is_some()
751                || cfg.user.is_some()
752                || cfg.password.is_some()
753                || cfg.dbname.is_some())
754        {
755            // This is actually allowed - server provides base config, DSN can override
756            // Fields here override DSN parts intentionally.
757        }
758    }
759
760    // Check for SQLite-specific conflicts
761    if cfg.file.is_some() && cfg.path.is_some() {
762        return Err(DbError::ConfigConflict(
763            "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
764        ));
765    }
766
767    if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
768        return Err(DbError::ConfigConflict(
769            "SQLite file/path fields cannot be used with host/port fields".to_owned(),
770        ));
771    }
772
773    // If engine explicitly says SQLite, reject server connection fields early (even without DSN)
774    if cfg.engine == Some(DbEngineCfg::Sqlite)
775        && (cfg.host.is_some()
776            || cfg.port.is_some()
777            || cfg.user.is_some()
778            || cfg.password.is_some()
779            || cfg.dbname.is_some())
780    {
781        return Err(DbError::ConfigConflict(
782            "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
783        ));
784    }
785
786    // If engine explicitly says server-based, reject sqlite file/path early (even without DSN)
787    if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
788        && (cfg.file.is_some() || cfg.path.is_some())
789    {
790        return Err(DbError::ConfigConflict(
791            "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
792        ));
793    }
794
795    Ok(())
796}
797
798/// Redact credentials from DSN for logging.
799#[must_use]
800pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
801    match dsn {
802        Some(dsn) if dsn.contains('@') => {
803            if let Ok(mut parsed) = url::Url::parse(dsn) {
804                if parsed.password().is_some() {
805                    _ = parsed.set_password(Some("***"));
806                }
807                parsed.to_string()
808            } else {
809                "***".to_owned()
810            }
811        }
812        Some(dsn) => dsn.to_owned(),
813        None => "none".to_owned(),
814    }
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only `engine` and `dsn` populated; every other field is defaulted.
    fn config(engine: Option<DbEngineCfg>, dsn: Option<&str>) -> DbConnConfig {
        DbConnConfig {
            engine,
            dsn: dsn.map(str::to_owned),
            ..Default::default()
        }
    }

    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let err = determine_engine(&config(None, None)).unwrap_err();

        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let engine = determine_engine(&config(None, Some("sqlite::memory:"))).unwrap();

        assert_eq!(engine, DbEngineCfg::Sqlite);
    }

    #[test]
    fn engine_and_dsn_match_ok() {
        let matching = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (engine, dsn) in matching {
            let cfg = config(Some(engine), Some(dsn));
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), engine);
        }
    }

    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let mismatched = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in mismatched {
            let err = validate_config_consistency(&config(Some(engine), Some(dsn))).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    #[test]
    fn unknown_dsn_is_error() {
        let cfg = config(None, Some("unknown://localhost/db"));

        // Consistency validation does not reject unknown schemes unless `engine` is set,
        // but engine determination must.
        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }
}