Skip to main content

modkit_db/
options.rs

1//! Database connection options and configuration types.
2
3use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5use thiserror::Error;
6
// Pool configuration (`PoolCfg`) is defined in `crate::config`.
8
/// Database connection options using typed sqlx `ConnectOptions`.
///
/// Each variant wraps the engine-specific sqlx options type and only exists when the
/// matching cargo feature (`sqlite`, `pg`, `mysql`) is enabled.
#[derive(Clone, Debug)]
pub enum DbConnectOptions {
    /// `SQLite` options (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(sea_orm::sqlx::sqlite::SqliteConnectOptions),
    /// `PostgreSQL` options (feature `pg`).
    #[cfg(feature = "pg")]
    Postgres(sea_orm::sqlx::postgres::PgConnectOptions),
    /// `MySQL` options (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(sea_orm::sqlx::mysql::MySqlConnectOptions),
}
19
20/// Errors that can occur during connection option building.
21#[derive(Debug, Error)]
22pub enum ConnectionOptionsError {
23    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
24    InvalidSqlitePragma { key: String, message: String },
25
26    #[error("Unknown SQLite PRAGMA parameter: {0}")]
27    UnknownSqlitePragma(String),
28
29    #[error("Invalid connection parameter: {0}")]
30    InvalidParameter(String),
31
32    #[error("Feature not enabled: {0}")]
33    FeatureDisabled(&'static str),
34
35    #[error("IO error: {0}")]
36    Io(#[from] std::io::Error),
37
38    #[error("URL parsing error: {0}")]
39    UrlParse(#[from] url::ParseError),
40
41    #[error("Environment variable error: {0}")]
42    EnvVar(#[from] std::env::VarError),
43}
44
45impl std::fmt::Display for DbConnectOptions {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            #[cfg(feature = "sqlite")]
49            DbConnectOptions::Sqlite(opts) => {
50                let filename = opts.get_filename().display().to_string();
51                if filename.is_empty() {
52                    write!(f, "sqlite://memory")
53                } else {
54                    write!(f, "sqlite://{filename}")
55                }
56            }
57            #[cfg(feature = "pg")]
58            DbConnectOptions::Postgres(opts) => {
59                write!(
60                    f,
61                    "postgresql://<redacted>@{}:{}/{}",
62                    opts.get_host(),
63                    opts.get_port(),
64                    opts.get_database().unwrap_or("")
65                )
66            }
67            #[cfg(feature = "mysql")]
68            DbConnectOptions::MySql(_opts) => {
69                write!(f, "mysql://<redacted>@...")
70            }
71            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
72            _ => {
73                unreachable!("No database features enabled")
74            }
75        }
76    }
77}
78
79impl DbConnectOptions {
80    /// Connect to the database using the configured options.
81    ///
82    /// # Errors
83    /// Returns an error if the database connection fails.
84    pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
85        match self {
86            #[cfg(feature = "sqlite")]
87            DbConnectOptions::Sqlite(opts) => {
88                let pool_opts = pool.apply_sqlite(sea_orm::sqlx::sqlite::SqlitePoolOptions::new());
89
90                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
91
92                let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool.clone());
93
94                let filename = opts.get_filename().display().to_string();
95                let handle = DbHandle {
96                    engine: crate::DbEngine::Sqlite,
97                    pool: crate::DbPool::Sqlite(sqlx_pool),
98                    dsn: format!("sqlite://{filename}"),
99                    sea,
100                };
101
102                Ok(handle)
103            }
104            #[cfg(feature = "pg")]
105            DbConnectOptions::Postgres(opts) => {
106                let pool_opts = pool.apply_pg(sea_orm::sqlx::postgres::PgPoolOptions::new());
107
108                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
109
110                let sea =
111                    sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool.clone());
112
113                let handle = DbHandle {
114                    engine: crate::DbEngine::Postgres,
115                    pool: crate::DbPool::Postgres(sqlx_pool),
116                    dsn: format!(
117                        "postgresql://<redacted>@{}:{}/{}",
118                        opts.get_host(),
119                        opts.get_port(),
120                        opts.get_database().unwrap_or("")
121                    ),
122                    sea,
123                };
124
125                Ok(handle)
126            }
127            #[cfg(feature = "mysql")]
128            DbConnectOptions::MySql(opts) => {
129                let pool_opts = pool.apply_mysql(sea_orm::sqlx::mysql::MySqlPoolOptions::new());
130
131                let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
132
133                let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool.clone());
134
135                let handle = DbHandle {
136                    engine: crate::DbEngine::MySql,
137                    pool: crate::DbPool::MySql(sqlx_pool),
138                    dsn: "mysql://<redacted>@...".to_owned(),
139                    sea,
140                };
141
142                Ok(handle)
143            }
144            #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
145            _ => {
146                unreachable!("No database features enabled")
147            }
148        }
149    }
150}
151
/// `SQLite` PRAGMA whitelist and validation.
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::hash_map::Entry;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    /// Whitelisted `SQLite` PRAGMA parameters (matched case-insensitively).
    const ALLOWED_PRAGMAS: &[&str] = &["wal", "synchronous", "busy_timeout", "journal_mode"];

    /// Validate and apply `SQLite` PRAGMA parameters to connection options.
    ///
    /// Keys are matched case-insensitively. `wal` is a boolean shorthand for
    /// `journal_mode` (`true`/`1` → `WAL`, `false`/`0` → `DELETE`). A `HashMap` has no
    /// defined iteration order. When several keys resolve to the same PRAGMA (e.g. `wal`
    /// and `journal_mode`, or `WAL` and `wal`), they must agree on its value. Otherwise the
    /// surviving value would be arbitrary, so the combination is rejected.
    ///
    /// # Errors
    /// Returns `DbError::UnknownSqlitePragma` if an unsupported pragma is provided.
    /// Returns `DbError::InvalidSqlitePragma` if a pragma value is invalid.
    /// Returns `DbError::ConfigConflict` if two keys set the same PRAGMA to different values.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sea_orm::sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sea_orm::sqlx::sqlite::SqliteConnectOptions> {
        // PRAGMA name -> (config key that set it, normalized value), for conflict detection.
        let mut applied: HashMap<&'static str, (&str, String)> = HashMap::new();

        for (key, value) in params {
            let key_lower = key.to_lowercase();

            if !ALLOWED_PRAGMAS.contains(&key_lower.as_str()) {
                return Err(DbError::UnknownSqlitePragma(key.clone()));
            }

            let (pragma, normalized): (&'static str, String) = match key_lower.as_str() {
                "wal" => ("journal_mode", validate_wal_pragma(value)?.to_owned()),
                "journal_mode" => ("journal_mode", validate_journal_mode_pragma(value)?),
                "synchronous" => ("synchronous", validate_synchronous_pragma(value)?),
                "busy_timeout" => (
                    "busy_timeout",
                    validate_busy_timeout_pragma(value)?.to_string(),
                ),
                _ => unreachable!("Checked against whitelist above"),
            };

            match applied.entry(pragma) {
                Entry::Occupied(previous) => {
                    let (previous_key, previous_value) = previous.get();
                    if *previous_value != normalized {
                        return Err(DbError::ConfigConflict(format!(
                            "SQLite PRAGMA '{pragma}' is set to '{previous_value}' by \
                             '{previous_key}' but to '{normalized}' by '{key}'"
                        )));
                    }
                }
                Entry::Vacant(slot) => {
                    opts = opts.pragma(pragma, normalized.clone());
                    slot.insert((key.as_str(), normalized));
                }
            }
        }

        Ok(opts)
    }

    /// Validate the `wal` shorthand: `true`/`1` → `WAL`, `false`/`0` → `DELETE`
    /// (case-insensitive).
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validate `synchronous` (case-insensitive); returns the upper-cased mode.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(mode.as_str(), "OFF" | "NORMAL" | "FULL" | "EXTRA") {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            })
        }
    }

    /// Validate `busy_timeout`: must parse as a non-negative `i64`.
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validate `journal_mode` (case-insensitive); returns the upper-cased mode.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(
            mode.as_str(),
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF"
        ) {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            })
        }
    }
}
257
258/// Build a database handle from configuration.
259/// This is the unified entry point for building database handles from configuration.
260///
261/// # Errors
262/// Returns an error if the database connection fails or configuration is invalid.
263pub async fn build_db_handle(
264    mut cfg: DbConnConfig,
265    _global: Option<&GlobalDatabaseConfig>,
266) -> Result<DbHandle> {
267    // Expand environment variables in DSN and password
268    if let Some(dsn) = &cfg.dsn {
269        cfg.dsn = Some(expand_env_vars(dsn)?);
270    }
271    if let Some(password) = &cfg.password {
272        cfg.password = Some(resolve_password(password)?);
273    }
274
275    // Expand environment variables in params
276    if let Some(ref mut params) = cfg.params {
277        for (_, value) in params.iter_mut() {
278            if value.contains("${") {
279                *value = expand_env_vars(value)?;
280            }
281        }
282    }
283
284    // Validate configuration for conflicts
285    validate_config_consistency(&cfg)?;
286
287    // Determine database engine and build connection options.
288    let engine = determine_engine(&cfg)?;
289    let connect_options = match engine {
290        DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
291        DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
292    };
293
294    // Build pool configuration
295    let pool_cfg = cfg.pool.unwrap_or_default();
296
297    // Log connection attempt (without credentials)
298    let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
299    tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
300
301    // Connect to database
302    let handle = connect_options.connect(pool_cfg).await?;
303
304    Ok(handle)
305}
306
307fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
308    // If both engine and DSN are provided, validate they don't conflict.
309    // (We do the same check in validate_config_consistency, but keep this here to ensure
310    // determine_engine() never returns a misleading value.)
311    if let Some(engine) = cfg.engine {
312        if let Some(dsn) = cfg.dsn.as_deref() {
313            let inferred = engine_from_dsn(dsn)?;
314            if inferred != engine {
315                return Err(DbError::ConfigConflict(format!(
316                    "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
317                )));
318            }
319        }
320        return Ok(engine);
321    }
322
323    // If DSN is not provided, engine is required.
324    //
325    // Rationale:
326    // - Without DSN we cannot reliably distinguish Postgres vs MySQL.
327    // - For SQLite we also want explicit intent (file/path alone is not a transport selector).
328    if cfg.dsn.is_none() {
329        return Err(DbError::InvalidParameter(
330            "Missing 'engine': required when 'dsn' is not provided".to_owned(),
331        ));
332    }
333
334    // Infer from DSN scheme when present.
335    let Some(dsn) = cfg.dsn.as_deref() else {
336        // SAFETY: guarded above by `cfg.dsn.is_none()`.
337        return Err(DbError::InvalidParameter(
338            "Missing 'dsn': required to infer database engine".to_owned(),
339        ));
340    };
341    engine_from_dsn(dsn)
342}
343
344fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
345    let s = dsn.trim_start();
346    if s.starts_with("postgres://") || s.starts_with("postgresql://") {
347        Ok(DbEngineCfg::Postgres)
348    } else if s.starts_with("mysql://") {
349        Ok(DbEngineCfg::Mysql)
350    } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
351        Ok(DbEngineCfg::Sqlite)
352    } else {
353        Err(DbError::UnknownDsn(dsn.to_owned()))
354    }
355}
356
/// Build `SQLite` connection options from configuration.
///
/// The database location comes from the DSN, else from `path`. `file` must already have
/// been resolved into `path` by the caller. The special location `:memory:` selects a
/// shared in-memory database. For on-disk databases the parent directory is created, and
/// the file is created if missing. Whitelisted PRAGMAs from `params` are applied last.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = if let Some(dsn) = &cfg.dsn {
        parse_sqlite_path_from_dsn(dsn)?
    } else if let Some(path) = &cfg.path {
        path.clone()
    } else if cfg.file.is_some() {
        // This should not happen as manager.rs should have resolved file to path
        return Err(DbError::InvalidParameter(
            "File path should have been resolved to absolute path".to_owned(),
        ));
    } else {
        return Err(DbError::InvalidParameter(
            "SQLite connection requires either DSN, path, or file".to_owned(),
        ));
    };

    let mut opts = if db_path.as_os_str() == ":memory:" {
        // Passing `:memory:` as a plain filename gives every pooled connection its own
        // private, empty database. sqlx's own `sqlite::memory:` parsing instead names a
        // shared-cache in-memory database, so all connections in the pool see the same data.
        "sqlite::memory:"
            .parse::<sea_orm::sqlx::sqlite::SqliteConnectOptions>()
            .map_err(|e| DbError::InvalidParameter(e.to_string()))?
            .create_if_missing(true)
    } else {
        // Ensure parent directory exists
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        sea_orm::sqlx::sqlite::SqliteConnectOptions::new()
            .filename(&db_path)
            .create_if_missing(true)
    };

    // Apply PRAGMA parameters with whitelist validation
    if let Some(params) = &cfg.params {
        opts = sqlite_pragma::apply_pragmas(opts, params)?;
    }

    Ok(DbConnectOptions::Sqlite(opts))
}
391
392#[cfg(not(feature = "sqlite"))]
393fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
394    Err(DbError::FeatureDisabled("SQLite feature not enabled"))
395}
396
397/// Build server-based connection options from configuration.
398fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
399    // When neither `pg` nor `mysql` features are enabled, the match arms that would use `cfg`
400    // are compiled out, but the function still needs to compile cleanly under `-D warnings`.
401    #[cfg(not(any(feature = "pg", feature = "mysql")))]
402    let _ = cfg;
403
404    match engine {
405        DbEngineCfg::Postgres => {
406            #[cfg(feature = "pg")]
407            {
408                let mut opts = if let Some(dsn) = &cfg.dsn {
409                    dsn.parse::<sea_orm::sqlx::postgres::PgConnectOptions>()
410                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
411                } else {
412                    sea_orm::sqlx::postgres::PgConnectOptions::new()
413                };
414
415                // Override with individual fields
416                if let Some(host) = &cfg.host {
417                    opts = opts.host(host);
418                }
419                if let Some(port) = cfg.port {
420                    opts = opts.port(port);
421                }
422                if let Some(user) = &cfg.user {
423                    opts = opts.username(user);
424                }
425                if let Some(password) = &cfg.password {
426                    opts = opts.password(password);
427                }
428                if let Some(dbname) = &cfg.dbname {
429                    opts = opts.database(dbname);
430                } else if cfg.dsn.is_none() {
431                    return Err(DbError::InvalidParameter(
432                        "dbname is required for PostgreSQL connections".to_owned(),
433                    ));
434                }
435
436                // Apply additional parameters
437                if let Some(params) = &cfg.params {
438                    for (key, value) in params {
439                        opts = opts.options([(key.as_str(), value.as_str())]);
440                    }
441                }
442
443                Ok(DbConnectOptions::Postgres(opts))
444            }
445            #[cfg(not(feature = "pg"))]
446            {
447                Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
448            }
449        }
450        DbEngineCfg::Mysql => {
451            #[cfg(feature = "mysql")]
452            {
453                let mut opts = if let Some(dsn) = &cfg.dsn {
454                    dsn.parse::<sea_orm::sqlx::mysql::MySqlConnectOptions>()
455                        .map_err(|e| DbError::InvalidParameter(e.to_string()))?
456                } else {
457                    sea_orm::sqlx::mysql::MySqlConnectOptions::new()
458                };
459
460                // Override with individual fields
461                if let Some(host) = &cfg.host {
462                    opts = opts.host(host);
463                }
464                if let Some(port) = cfg.port {
465                    opts = opts.port(port);
466                }
467                if let Some(user) = &cfg.user {
468                    opts = opts.username(user);
469                }
470                if let Some(password) = &cfg.password {
471                    opts = opts.password(password);
472                }
473                if let Some(dbname) = &cfg.dbname {
474                    opts = opts.database(dbname);
475                } else if cfg.dsn.is_none() {
476                    return Err(DbError::InvalidParameter(
477                        "dbname is required for MySQL connections".to_owned(),
478                    ));
479                }
480
481                Ok(DbConnectOptions::MySql(opts))
482            }
483            #[cfg(not(feature = "mysql"))]
484            {
485                Err(DbError::FeatureDisabled("MySQL feature not enabled"))
486            }
487        }
488        DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
489            "build_server_options called with sqlite engine".to_owned(),
490        )),
491    }
492}
493
/// Parse `SQLite` path from DSN.
///
/// Accepts both `sqlite:path` and `sqlite://path`. For example, `sqlite:///tmp/a.db`
/// yields `/tmp/a.db` and `sqlite::memory:` yields `:memory:`. Anything after the first
/// `?` (query parameters) is dropped; no percent-decoding is performed.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };
    let rest = rest.strip_prefix("//").unwrap_or(rest);

    // Remove query parameters
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);

    Ok(std::path::PathBuf::from(path))
}
523
524/// Expand environment variables in a string.
525fn expand_env_vars(input: &str) -> Result<String> {
526    let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
527        .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
528    let mut result = input.to_owned();
529
530    for caps in re.captures_iter(input) {
531        let full_match = &caps[0];
532        let var_name = &caps[1];
533        let value = std::env::var(var_name)?;
534        result = result.replace(full_match, &value);
535    }
536
537    Ok(result)
538}
539
540/// Resolve password from environment variable if it starts with ${VAR}.
541fn resolve_password(password: &str) -> Result<String> {
542    if password.starts_with("${") && password.ends_with('}') {
543        let var_name = &password[2..password.len() - 1];
544        Ok(std::env::var(var_name)?)
545    } else {
546        Ok(password.to_owned())
547    }
548}
549
550/// Validate configuration for consistency and detect conflicts.
551fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
552    // Validate engine against DSN if both are present
553    if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
554        let inferred = engine_from_dsn(dsn)?;
555        if inferred != engine {
556            return Err(DbError::ConfigConflict(format!(
557                "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
558            )));
559        }
560    }
561
562    // Check for SQLite vs server engine conflicts
563    if let Some(dsn) = &cfg.dsn {
564        let is_sqlite_dsn = dsn.starts_with("sqlite");
565        let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
566        let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
567
568        if is_sqlite_dsn && has_server_fields {
569            return Err(DbError::ConfigConflict(
570                "SQLite DSN cannot be used with host/port fields".to_owned(),
571            ));
572        }
573
574        if !is_sqlite_dsn && has_sqlite_fields {
575            return Err(DbError::ConfigConflict(
576                "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
577            ));
578        }
579
580        // Check for server vs non-server DSN conflicts
581        if !is_sqlite_dsn
582            && cfg.server.is_some()
583            && (cfg.host.is_some()
584                || cfg.port.is_some()
585                || cfg.user.is_some()
586                || cfg.password.is_some()
587                || cfg.dbname.is_some())
588        {
589            // This is actually allowed - server provides base config, DSN can override
590            // Fields here override DSN parts intentionally.
591        }
592    }
593
594    // Check for SQLite-specific conflicts
595    if cfg.file.is_some() && cfg.path.is_some() {
596        return Err(DbError::ConfigConflict(
597            "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
598        ));
599    }
600
601    if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
602        return Err(DbError::ConfigConflict(
603            "SQLite file/path fields cannot be used with host/port fields".to_owned(),
604        ));
605    }
606
607    // If engine explicitly says SQLite, reject server connection fields early (even without DSN)
608    if cfg.engine == Some(DbEngineCfg::Sqlite)
609        && (cfg.host.is_some()
610            || cfg.port.is_some()
611            || cfg.user.is_some()
612            || cfg.password.is_some()
613            || cfg.dbname.is_some())
614    {
615        return Err(DbError::ConfigConflict(
616            "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
617        ));
618    }
619
620    // If engine explicitly says server-based, reject sqlite file/path early (even without DSN)
621    if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
622        && (cfg.file.is_some() || cfg.path.is_some())
623    {
624        return Err(DbError::ConfigConflict(
625            "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
626        ));
627    }
628
629    Ok(())
630}
631
632/// Redact credentials from DSN for logging.
633#[must_use]
634pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
635    match dsn {
636        Some(dsn) if dsn.contains('@') => {
637            if let Ok(mut parsed) = url::Url::parse(dsn) {
638                if parsed.password().is_some() {
639                    let _ = parsed.set_password(Some("***"));
640                }
641                parsed.to_string()
642            } else {
643                "***".to_owned()
644            }
645        }
646        Some(dsn) => dsn.to_owned(),
647        None => "none".to_owned(),
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only `engine` and `dsn` set; everything else defaulted.
    fn config(engine: Option<DbEngineCfg>, dsn: Option<&str>) -> DbConnConfig {
        DbConnConfig {
            engine,
            dsn: dsn.map(str::to_owned),
            ..Default::default()
        }
    }

    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let err = determine_engine(&config(None, None)).unwrap_err();
        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let engine = determine_engine(&config(None, Some("sqlite::memory:"))).unwrap();
        assert_eq!(engine, DbEngineCfg::Sqlite);
    }

    #[test]
    fn engine_and_dsn_match_ok() {
        let cases = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (engine, dsn) in cases {
            let cfg = config(Some(engine), Some(dsn));
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), engine);
        }
    }

    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let cases = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in cases {
            let err = validate_config_consistency(&config(Some(engine), Some(dsn))).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    #[test]
    fn unknown_dsn_is_error() {
        let cfg = config(None, Some("unknown://localhost/db"));

        // Consistency validation doesn't validate unknown schemes unless `engine` is set,
        // but engine determination must fail.
        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }

    #[test]
    fn sqlite_dsn_with_host_is_conflict() {
        let cfg = DbConnConfig {
            host: Some("localhost".to_owned()),
            ..config(None, Some("sqlite::memory:"))
        };
        let err = validate_config_consistency(&cfg).unwrap_err();
        assert!(matches!(err, DbError::ConfigConflict(_)));
    }

    #[test]
    fn sqlite_engine_with_dbname_is_conflict() {
        let cfg = DbConnConfig {
            dbname: Some("app".to_owned()),
            ..config(Some(DbEngineCfg::Sqlite), None)
        };
        let err = validate_config_consistency(&cfg).unwrap_err();
        assert!(matches!(err, DbError::ConfigConflict(_)));
    }

    #[test]
    fn redact_hides_url_password() {
        let redacted = redact_credentials_in_dsn(Some("postgres://user:secret@localhost/db"));
        assert!(!redacted.contains("secret"));
        assert!(redacted.contains("***"));
    }

    #[test]
    fn redact_passes_through_dsn_without_credentials() {
        assert_eq!(
            redact_credentials_in_dsn(Some("sqlite:///tmp/test.db")),
            "sqlite:///tmp/test.db"
        );
        assert_eq!(redact_credentials_in_dsn(None), "none");
    }

    #[cfg(feature = "sqlite")]
    #[test]
    fn sqlite_dsn_paths_are_parsed() {
        use std::path::PathBuf;

        let cases = [
            ("sqlite:///tmp/test.db", "/tmp/test.db"),
            ("sqlite://./data/app.db?mode=rwc", "./data/app.db"),
            ("sqlite:relative.db", "relative.db"),
            ("sqlite::memory:", ":memory:"),
        ];
        for (dsn, expected) in cases {
            assert_eq!(
                parse_sqlite_path_from_dsn(dsn).unwrap(),
                PathBuf::from(expected),
                "{dsn}"
            );
        }
        assert!(matches!(
            parse_sqlite_path_from_dsn("postgres://localhost/db"),
            Err(DbError::InvalidParameter(_))
        ));
    }

    #[cfg(feature = "sqlite")]
    #[test]
    fn sqlite_pragmas_are_whitelisted_and_validated() {
        use sea_orm::sqlx::sqlite::SqliteConnectOptions;
        use std::collections::HashMap;

        let unknown = HashMap::from([("cache_size".to_owned(), "100".to_owned())]);
        let err = sqlite_pragma::apply_pragmas(SqliteConnectOptions::new(), &unknown).unwrap_err();
        assert!(matches!(err, DbError::UnknownSqlitePragma(_)));

        let bad = HashMap::from([("wal".to_owned(), "maybe".to_owned())]);
        let err = sqlite_pragma::apply_pragmas(SqliteConnectOptions::new(), &bad).unwrap_err();
        assert!(matches!(err, DbError::InvalidSqlitePragma { .. }));

        let good = HashMap::from([
            ("WAL".to_owned(), "true".to_owned()),
            ("synchronous".to_owned(), "normal".to_owned()),
            ("busy_timeout".to_owned(), "5000".to_owned()),
        ]);
        sqlite_pragma::apply_pragmas(SqliteConnectOptions::new(), &good).unwrap();
    }
}