1use modkit_utils::var_expand::expand_env_vars;
4
5use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
6use crate::{DbError, DbHandle, Result};
7
/// Engine-specific sqlx connection options built from a `DbConnConfig`.
///
/// Each variant exists only when the matching Cargo feature is enabled, so with no
/// database feature enabled this enum has no variants.
#[derive(Debug, Clone)]
pub(crate) enum DbConnectOptions {
    /// SQLite connection options (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::sqlite::SqliteConnectOptions),
    /// PostgreSQL connection options (feature `pg`).
    #[cfg(feature = "pg")]
    Postgres(sqlx::postgres::PgConnectOptions),
    /// MySQL connection options (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(sqlx::mysql::MySqlConnectOptions),
}
20
21impl std::fmt::Display for DbConnectOptions {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 #[cfg(feature = "sqlite")]
25 DbConnectOptions::Sqlite(opts) => {
26 let filename = opts.get_filename().display().to_string();
27 if filename.is_empty() {
28 write!(f, "sqlite://memory")
29 } else {
30 write!(f, "sqlite://{filename}")
31 }
32 }
33 #[cfg(feature = "pg")]
34 DbConnectOptions::Postgres(opts) => {
35 write!(
36 f,
37 "postgresql://<redacted>@{}:{}/{}",
38 opts.get_host(),
39 opts.get_port(),
40 opts.get_database().unwrap_or("")
41 )
42 }
43 #[cfg(feature = "mysql")]
44 DbConnectOptions::MySql(_opts) => {
45 write!(f, "mysql://<redacted>@...")
46 }
47 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
48 _ => {
49 unreachable!("No database features enabled")
50 }
51 }
52 }
53}
54
/// Returns `true` when `path` names an in-memory SQLite database.
///
/// An empty path counts as in-memory. Otherwise the UTF-8 path, with surrounding
/// whitespace trimmed, must be one of `:memory:`, `memory:`, `file::memory:`,
/// `file:memory:`, or empty. Non-UTF-8 paths are never considered in-memory.
#[cfg(feature = "sqlite")]
fn is_memory_filename(path: &std::path::Path) -> bool {
    path.as_os_str().is_empty()
        || matches!(
            path.to_str().map(str::trim),
            Some(":memory:" | "memory:" | "file::memory:" | "file:memory:" | "")
        )
}
69
70impl DbConnectOptions {
71 pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
76 match self {
77 #[cfg(feature = "sqlite")]
78 DbConnectOptions::Sqlite(opts) => {
79 let mut pool_opts = pool.apply_sqlite(sqlx::sqlite::SqlitePoolOptions::new());
80
81 if is_memory_filename(opts.get_filename()) {
82 pool_opts = pool_opts.max_connections(1).min_connections(1);
83 tracing::info!("Using single connection pool for in-memory SQLite database");
84 }
85
86 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
87
88 let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool);
89
90 let filename = opts.get_filename().display().to_string();
91 let handle = DbHandle {
92 engine: crate::DbEngine::Sqlite,
93 dsn: format!("sqlite://{filename}"),
94 sea,
95 };
96
97 Ok(handle)
98 }
99 #[cfg(feature = "pg")]
100 DbConnectOptions::Postgres(opts) => {
101 let pool_opts = pool.apply_pg(sqlx::postgres::PgPoolOptions::new());
102
103 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
104
105 let sea = sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool);
106
107 let handle = DbHandle {
108 engine: crate::DbEngine::Postgres,
109 dsn: format!(
110 "postgresql://<redacted>@{}:{}/{}",
111 opts.get_host(),
112 opts.get_port(),
113 opts.get_database().unwrap_or("")
114 ),
115 sea,
116 };
117
118 Ok(handle)
119 }
120 #[cfg(feature = "mysql")]
121 DbConnectOptions::MySql(opts) => {
122 let pool_opts = pool.apply_mysql(sqlx::mysql::MySqlPoolOptions::new());
123
124 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
125
126 let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool);
127
128 let handle = DbHandle {
129 engine: crate::DbEngine::MySql,
130 dsn: "mysql://<redacted>@...".to_owned(),
131 sea,
132 };
133
134 Ok(handle)
135 }
136 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
137 _ => {
138 unreachable!("No database features enabled")
139 }
140 }
141 }
142}
143
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    //! Validation and application of the SQLite PRAGMA settings that may be supplied
    //! through connection `params`.

    use crate::DbError;
    use std::collections::{BTreeMap, HashMap};
    use std::hash::BuildHasher;

    /// Validates `params` and applies them to `opts` as SQLite PRAGMAs.
    ///
    /// Accepted keys (case-insensitive):
    /// * `wal`: `true`/`1` sets `journal_mode=WAL`; `false`/`0` sets `journal_mode=DELETE`.
    /// * `journal_mode`: one of `DELETE`/`WAL`/`MEMORY`/`TRUNCATE`/`PERSIST`/`OFF`.
    /// * `synchronous`: one of `OFF`/`NORMAL`/`FULL`/`EXTRA`.
    /// * `busy_timeout`: a non-negative integer.
    ///
    /// Keys are processed in sorted order, so the outcome (and any error) does not depend
    /// on `HashMap` iteration order. Several keys may resolve to the same PRAGMA, e.g.
    /// `wal` and `journal_mode`, or `WAL` and `wal`. They are accepted only when they
    /// agree on the value.
    ///
    /// # Errors
    /// * [`DbError::UnknownSqlitePragma`] for a key outside the list above.
    /// * [`DbError::InvalidSqlitePragma`] for an invalid value, or when two keys set the
    ///   same PRAGMA to different values.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sqlx::sqlite::SqliteConnectOptions> {
        let mut entries: Vec<(&String, &String)> = params.iter().collect();
        entries.sort_unstable();

        // PRAGMA name -> (config key that set it, validated value).
        let mut resolved: BTreeMap<&'static str, (&str, String)> = BTreeMap::new();

        for (key, value) in entries {
            let (pragma, pragma_value) = match key.to_lowercase().as_str() {
                "wal" => ("journal_mode", validate_wal_pragma(value)?.to_owned()),
                "journal_mode" => ("journal_mode", validate_journal_mode_pragma(value)?),
                "synchronous" => ("synchronous", validate_synchronous_pragma(value)?),
                "busy_timeout" => (
                    "busy_timeout",
                    validate_busy_timeout_pragma(value)?.to_string(),
                ),
                _ => return Err(DbError::UnknownSqlitePragma(key.clone())),
            };

            if let Some((previous_key, previous_value)) = resolved.get(pragma) {
                if *previous_value != pragma_value {
                    return Err(DbError::InvalidSqlitePragma {
                        key: key.clone(),
                        message: format!(
                            "conflicts with '{previous_key}': both set PRAGMA {pragma} \
                             ('{previous_value}' vs '{pragma_value}')"
                        ),
                    });
                }
                // Same effective value from an alias: nothing more to do.
                continue;
            }
            resolved.insert(pragma, (key.as_str(), pragma_value));
        }

        for (pragma, (_, value)) in resolved {
            opts = opts.pragma(pragma, value);
        }

        Ok(opts)
    }

    /// Maps the `wal` flag to a journal mode: `true`/`1` gives `WAL`, and `false`/`0`
    /// gives `DELETE` (case-insensitive).
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validates a `synchronous` level (case-insensitive) and returns it upper-cased.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        match value.to_uppercase().as_str() {
            "OFF" | "NORMAL" | "FULL" | "EXTRA" => Ok(value.to_uppercase()),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            }),
        }
    }

    /// Parses `busy_timeout` as a non-negative integer.
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validates a `journal_mode` (case-insensitive) and returns it upper-cased.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        match value.to_uppercase().as_str() {
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF" => {
                Ok(value.to_uppercase())
            }
            _ => Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            }),
        }
    }
}
249
250pub(crate) async fn build_db_handle(
258 mut cfg: DbConnConfig,
259 _global: Option<&GlobalDatabaseConfig>,
260) -> Result<DbHandle> {
261 if let Some(dsn) = &cfg.dsn {
263 cfg.dsn = Some(expand_env_vars(dsn)?);
264 }
265 if let Some(password) = &cfg.password {
266 cfg.password = Some(resolve_password(password)?);
267 }
268
269 if let Some(ref mut params) = cfg.params {
271 for (_, value) in params.iter_mut() {
272 if value.contains("${") {
273 *value = expand_env_vars(value)?;
274 }
275 }
276 }
277
278 validate_config_consistency(&cfg)?;
280
281 let engine = determine_engine(&cfg)?;
283 let connect_options = match engine {
284 DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
285 DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
286 };
287
288 let pool_cfg = cfg.pool.unwrap_or_default();
290
291 let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
293 tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
294
295 let handle = connect_options.connect(pool_cfg).await?;
297
298 Ok(handle)
299}
300
301fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
302 if let Some(engine) = cfg.engine {
306 if let Some(dsn) = cfg.dsn.as_deref() {
307 let inferred = engine_from_dsn(dsn)?;
308 if inferred != engine {
309 return Err(DbError::ConfigConflict(format!(
310 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
311 )));
312 }
313 }
314 return Ok(engine);
315 }
316
317 if cfg.dsn.is_none() {
323 return Err(DbError::InvalidParameter(
324 "Missing 'engine': required when 'dsn' is not provided".to_owned(),
325 ));
326 }
327
328 let Some(dsn) = cfg.dsn.as_deref() else {
330 return Err(DbError::InvalidParameter(
332 "Missing 'dsn': required to infer database engine".to_owned(),
333 ));
334 };
335 engine_from_dsn(dsn)
336}
337
338fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
339 let s = dsn.trim_start();
340 if s.starts_with("postgres://") || s.starts_with("postgresql://") {
341 Ok(DbEngineCfg::Postgres)
342 } else if s.starts_with("mysql://") {
343 Ok(DbEngineCfg::Mysql)
344 } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
345 Ok(DbEngineCfg::Sqlite)
346 } else {
347 Err(DbError::UnknownDsn(dsn.to_owned()))
348 }
349}
350
/// Builds SQLite connect options from `cfg`.
///
/// The database path comes from the DSN if present, otherwise from `path`. A config
/// that only has `file` is rejected, because `file` is expected to have been resolved
/// into a path already. The parent directory of the path is created if needed. The
/// database file is created if missing, and any `params` are applied as PRAGMAs.
///
/// # Errors
/// Returns [`DbError::InvalidParameter`] when no usable path source is present. Also
/// propagates DSN-parsing, directory-creation and PRAGMA-validation errors.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let database_path = match (&cfg.dsn, &cfg.path, &cfg.file) {
        (Some(dsn), _, _) => parse_sqlite_path_from_dsn(dsn)?,
        (None, Some(explicit_path), _) => explicit_path.clone(),
        (None, None, Some(_)) => {
            return Err(DbError::InvalidParameter(
                "File path should have been resolved to absolute path".to_owned(),
            ));
        }
        (None, None, None) => {
            return Err(DbError::InvalidParameter(
                "SQLite connection requires either DSN, path, or file".to_owned(),
            ));
        }
    };

    if let Some(parent_dir) = database_path.parent() {
        std::fs::create_dir_all(parent_dir)?;
    }

    let base_opts = sqlx::sqlite::SqliteConnectOptions::new()
        .filename(&database_path)
        .create_if_missing(true);

    let final_opts = match &cfg.params {
        Some(params) => sqlite_pragma::apply_pragmas(base_opts, params)?,
        None => base_opts,
    };

    Ok(DbConnectOptions::Sqlite(final_opts))
}
385
386#[cfg(not(feature = "sqlite"))]
387fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
388 Err(DbError::FeatureDisabled("SQLite feature not enabled"))
389}
390
/// Applies PostgreSQL-specific connection `params` to the given options.
///
/// Keys are matched case-insensitively. Recognised keys cover the SSL options (with
/// libpq-style and snake_case spellings), `application_name`,
/// `statement_cache_capacity` and `extra_float_digits` (-15..=3). Any other key is
/// forwarded unchanged as a startup option via `PgConnectOptions::options`.
///
/// # Errors
/// Returns [`DbError::InvalidParameter`] when a recognised key has an unparsable or
/// out-of-range value.
#[cfg(feature = "pg")]
fn apply_pg_params<S: std::hash::BuildHasher>(
    opts: sqlx::postgres::PgConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::postgres::PgConnectOptions> {
    use sqlx::postgres::PgSslMode;

    params.iter().try_fold(opts, |acc, (key, value)| {
        let updated = match key.to_lowercase().as_str() {
            "sslmode" | "ssl_mode" => {
                let mode = value.parse::<PgSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disable, allow, prefer, require, verify-ca, or verify-full"
                    ))
                })?;
                acc.ssl_mode(mode)
            }
            "sslrootcert" | "ssl_root_cert" => acc.ssl_root_cert(value.as_str()),
            "sslcert" | "ssl_client_cert" => acc.ssl_client_cert(value.as_str()),
            "sslkey" | "ssl_client_key" => acc.ssl_client_key(value.as_str()),
            "application_name" => acc.application_name(value),
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                acc.statement_cache_capacity(capacity)
            }
            "extra_float_digits" => {
                // Parse failures and out-of-range values report the same message.
                let invalid = || {
                    DbError::InvalidParameter(format!(
                        "Invalid extra_float_digits '{value}': expected integer between -15 and 3"
                    ))
                };
                let digits = value
                    .parse::<i8>()
                    .ok()
                    .filter(|digits| (-15..=3).contains(digits))
                    .ok_or_else(invalid)?;
                acc.extra_float_digits(digits)
            }
            _ => acc.options([(key.as_str(), value.as_str())]),
        };
        Ok::<_, DbError>(updated)
    })
}
454
/// Applies MySQL-specific connection `params` to `opts`.
///
/// Keys are matched case-insensitively, and several spellings are accepted for the SSL
/// and timeout options (e.g. `sslmode`, `ssl_mode`, `ssl-mode`). Unknown keys are
/// rejected.
///
/// NOTE: if two spellings of the same option are both present, the one processed last
/// wins, and that order follows `HashMap` iteration order.
///
/// `connect_timeout` is validated but not applied to `opts`. A warning is logged so
/// the setting is not silently ignored.
///
/// # Errors
/// Returns [`DbError::InvalidParameter`] for an unknown key or an unparsable value.
#[cfg(feature = "mysql")]
fn apply_mysql_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::mysql::MySqlConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::mysql::MySqlConnectOptions> {
    use sqlx::mysql::MySqlSslMode;

    for (key, value) in params {
        let key_lower = key.to_lowercase();
        match key_lower.as_str() {
            "sslmode" | "ssl_mode" | "ssl-mode" => {
                let mode = value.parse::<MySqlSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disabled, preferred, required, verify_ca, or verify_identity"
                    ))
                })?;
                opts = opts.ssl_mode(mode);
            }
            "sslca" | "ssl_ca" | "ssl-ca" => {
                opts = opts.ssl_ca(value.as_str());
            }
            "sslcert" | "ssl_client_cert" | "ssl-cert" => {
                opts = opts.ssl_client_cert(value.as_str());
            }
            "sslkey" | "ssl_client_key" | "ssl-key" => {
                opts = opts.ssl_client_key(value.as_str());
            }
            "charset" => {
                opts = opts.charset(value);
            }
            "collation" => {
                opts = opts.collation(value);
            }
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts = opts.statement_cache_capacity(capacity);
            }
            "connect_timeout" | "connect-timeout" => {
                let secs = value.parse::<u64>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid connect_timeout '{value}': expected non-negative integer seconds"
                    ))
                })?;
                // Nothing in `opts` consumes this value. Still validate it so typos fail
                // fast, but warn instead of dropping the setting silently.
                tracing::warn!(
                    connect_timeout_secs = secs,
                    "MySQL parameter 'connect_timeout' is accepted but not applied to the connection options"
                );
            }
            "socket" => {
                opts = opts.socket(value.as_str());
            }
            "timezone" => {
                // `none` (any case) or an empty value clears the session time zone setting.
                let tz = if value.eq_ignore_ascii_case("none") || value.is_empty() {
                    None
                } else {
                    Some(value.clone())
                };
                opts = opts.timezone(tz);
            }
            "pipes_as_concat" => {
                let flag = parse_bool_param("pipes_as_concat", value)?;
                opts = opts.pipes_as_concat(flag);
            }
            "no_engine_substitution" => {
                let flag = parse_bool_param("no_engine_substitution", value)?;
                opts = opts.no_engine_substitution(flag);
            }
            "enable_cleartext_plugin" => {
                let flag = parse_bool_param("enable_cleartext_plugin", value)?;
                opts = opts.enable_cleartext_plugin(flag);
            }
            "set_names" => {
                let flag = parse_bool_param("set_names", value)?;
                opts = opts.set_names(flag);
            }
            _ => {
                return Err(DbError::InvalidParameter(format!(
                    "Unknown MySQL connection parameter: '{key}'"
                )));
            }
        }
    }

    Ok(opts)
}
549
/// Parses a boolean-valued MySQL connection parameter.
///
/// Accepts, case-insensitively, `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`.
///
/// # Errors
/// Returns [`DbError::InvalidParameter`] naming the parameter `name` and the rejected
/// `value`.
#[cfg(feature = "mysql")]
fn parse_bool_param(name: &str, value: &str) -> Result<bool> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 4] = ["false", "0", "no", "off"];

    let lowered = value.to_lowercase();
    if TRUTHY.contains(&lowered.as_str()) {
        Ok(true)
    } else if FALSY.contains(&lowered.as_str()) {
        Ok(false)
    } else {
        Err(DbError::InvalidParameter(format!(
            "Invalid {name} '{value}': expected true/false/1/0/yes/no/on/off"
        )))
    }
}
561
562fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
564 #[cfg(not(any(feature = "pg", feature = "mysql")))]
567 let _ = cfg;
568
569 match engine {
570 DbEngineCfg::Postgres => {
571 #[cfg(feature = "pg")]
572 {
573 let mut opts = if let Some(dsn) = &cfg.dsn {
574 dsn.parse::<sqlx::postgres::PgConnectOptions>()
575 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
576 } else {
577 sqlx::postgres::PgConnectOptions::new()
578 };
579
580 if let Some(host) = &cfg.host {
582 opts = opts.host(host);
583 }
584 if let Some(port) = cfg.port {
585 opts = opts.port(port);
586 }
587 if let Some(user) = &cfg.user {
588 opts = opts.username(user);
589 }
590 if let Some(password) = &cfg.password {
591 opts = opts.password(password);
592 }
593 if let Some(dbname) = &cfg.dbname {
594 opts = opts.database(dbname);
595 } else if cfg.dsn.is_none() {
596 return Err(DbError::InvalidParameter(
597 "dbname is required for PostgreSQL connections".to_owned(),
598 ));
599 }
600
601 if let Some(params) = &cfg.params {
603 opts = apply_pg_params(opts, params)?;
604 }
605
606 Ok(DbConnectOptions::Postgres(opts))
607 }
608 #[cfg(not(feature = "pg"))]
609 {
610 Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
611 }
612 }
613 DbEngineCfg::Mysql => {
614 #[cfg(feature = "mysql")]
615 {
616 let mut opts = if let Some(dsn) = &cfg.dsn {
617 dsn.parse::<sqlx::mysql::MySqlConnectOptions>()
618 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
619 } else {
620 sqlx::mysql::MySqlConnectOptions::new()
621 };
622
623 if let Some(host) = &cfg.host {
625 opts = opts.host(host);
626 }
627 if let Some(port) = cfg.port {
628 opts = opts.port(port);
629 }
630 if let Some(user) = &cfg.user {
631 opts = opts.username(user);
632 }
633 if let Some(password) = &cfg.password {
634 opts = opts.password(password);
635 }
636 if let Some(dbname) = &cfg.dbname {
637 opts = opts.database(dbname);
638 } else if cfg.dsn.is_none() {
639 return Err(DbError::InvalidParameter(
640 "dbname is required for MySQL connections".to_owned(),
641 ));
642 }
643
644 if let Some(params) = &cfg.params {
646 opts = apply_mysql_params(opts, params)?;
647 }
648
649 Ok(DbConnectOptions::MySql(opts))
650 }
651 #[cfg(not(feature = "mysql"))]
652 {
653 Err(DbError::FeatureDisabled("MySQL feature not enabled"))
654 }
655 }
656 DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
657 "build_server_options called with sqlite engine".to_owned(),
658 )),
659 }
660}
661
/// Extracts the database file path from a `sqlite:` DSN.
///
/// Accepts both `sqlite:<path>` and `sqlite://<path>`. For example,
/// `sqlite:///abs/path.db` yields `/abs/path.db` and `sqlite::memory:` yields
/// `:memory:`. Leading whitespace is ignored, matching how the engine is inferred from
/// the DSN. Everything from the first `?` onward is discarded, so query options in the
/// DSN are not applied.
///
/// # Errors
/// Returns [`DbError::InvalidParameter`] when the DSN does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.trim_start().strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    // `sqlite://path` and `sqlite:path` name the same file.
    let rest = rest.strip_prefix("//").unwrap_or(rest);
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);

    Ok(std::path::PathBuf::from(path))
}
691
692fn resolve_password(password: &str) -> Result<String> {
694 if password.starts_with("${") && password.ends_with('}') {
695 let var_name = &password[2..password.len() - 1];
696 std::env::var(var_name).map_err(|source| DbError::EnvVar {
697 name: var_name.to_owned(),
698 source,
699 })
700 } else {
701 Ok(password.to_owned())
702 }
703}
704
705fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
707 if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
709 let inferred = engine_from_dsn(dsn)?;
710 if inferred != engine {
711 return Err(DbError::ConfigConflict(format!(
712 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
713 )));
714 }
715 }
716
717 if let Some(dsn) = &cfg.dsn {
719 let is_sqlite_dsn = dsn.starts_with("sqlite");
720 let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
721 let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
722
723 if is_sqlite_dsn && has_server_fields {
724 return Err(DbError::ConfigConflict(
725 "SQLite DSN cannot be used with host/port fields".to_owned(),
726 ));
727 }
728
729 if !is_sqlite_dsn && has_sqlite_fields {
730 return Err(DbError::ConfigConflict(
731 "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
732 ));
733 }
734
735 if !is_sqlite_dsn
737 && cfg.server.is_some()
738 && (cfg.host.is_some()
739 || cfg.port.is_some()
740 || cfg.user.is_some()
741 || cfg.password.is_some()
742 || cfg.dbname.is_some())
743 {
744 }
747 }
748
749 if cfg.file.is_some() && cfg.path.is_some() {
751 return Err(DbError::ConfigConflict(
752 "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
753 ));
754 }
755
756 if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
757 return Err(DbError::ConfigConflict(
758 "SQLite file/path fields cannot be used with host/port fields".to_owned(),
759 ));
760 }
761
762 if cfg.engine == Some(DbEngineCfg::Sqlite)
764 && (cfg.host.is_some()
765 || cfg.port.is_some()
766 || cfg.user.is_some()
767 || cfg.password.is_some()
768 || cfg.dbname.is_some())
769 {
770 return Err(DbError::ConfigConflict(
771 "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
772 ));
773 }
774
775 if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
777 && (cfg.file.is_some() || cfg.path.is_some())
778 {
779 return Err(DbError::ConfigConflict(
780 "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
781 ));
782 }
783
784 Ok(())
785}
786
787#[must_use]
789pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
790 match dsn {
791 Some(dsn) if dsn.contains('@') => {
792 if let Ok(mut parsed) = url::Url::parse(dsn) {
793 if parsed.password().is_some() {
794 _ = parsed.set_password(Some("***"));
795 }
796 parsed.to_string()
797 } else {
798 "***".to_owned()
799 }
800 }
801 Some(dsn) => dsn.to_owned(),
802 None => "none".to_owned(),
803 }
804}
805
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only `engine` and `dsn` populated; every other field is default.
    fn cfg_with(engine: Option<DbEngineCfg>, dsn: Option<&str>) -> DbConnConfig {
        DbConnConfig {
            engine,
            dsn: dsn.map(str::to_owned),
            ..Default::default()
        }
    }

    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let err = determine_engine(&cfg_with(None, None)).unwrap_err();
        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let cfg = cfg_with(None, Some("sqlite::memory:"));
        assert_eq!(determine_engine(&cfg).unwrap(), DbEngineCfg::Sqlite);
    }

    #[test]
    fn engine_and_dsn_match_ok() {
        let matching = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (expected, dsn) in matching {
            let cfg = cfg_with(Some(expected), Some(dsn));
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), expected);
        }
    }

    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let mismatched = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in mismatched {
            let err = validate_config_consistency(&cfg_with(Some(engine), Some(dsn))).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    #[test]
    fn unknown_dsn_is_error() {
        let cfg = cfg_with(None, Some("unknown://localhost/db"));

        // Consistency checks pass; the unknown scheme surfaces during engine resolution.
        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }
}