1use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5
/// Engine-specific connection options, produced from a `DbConnConfig` by the
/// option builders in this module and consumed by [`DbConnectOptions::connect`].
///
/// Each variant exists only when its database feature is compiled in.
#[derive(Clone, Debug)]
pub(crate) enum DbConnectOptions {
    /// SQLite connection options (requires the `sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::sqlite::SqliteConnectOptions),
    /// PostgreSQL connection options (requires the `pg` feature).
    #[cfg(feature = "pg")]
    Postgres(sqlx::postgres::PgConnectOptions),
    /// MySQL connection options (requires the `mysql` feature).
    #[cfg(feature = "mysql")]
    MySql(sqlx::mysql::MySqlConnectOptions),
}
18
19impl std::fmt::Display for DbConnectOptions {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 match self {
22 #[cfg(feature = "sqlite")]
23 DbConnectOptions::Sqlite(opts) => {
24 let filename = opts.get_filename().display().to_string();
25 if filename.is_empty() {
26 write!(f, "sqlite://memory")
27 } else {
28 write!(f, "sqlite://{filename}")
29 }
30 }
31 #[cfg(feature = "pg")]
32 DbConnectOptions::Postgres(opts) => {
33 write!(
34 f,
35 "postgresql://<redacted>@{}:{}/{}",
36 opts.get_host(),
37 opts.get_port(),
38 opts.get_database().unwrap_or("")
39 )
40 }
41 #[cfg(feature = "mysql")]
42 DbConnectOptions::MySql(_opts) => {
43 write!(f, "mysql://<redacted>@...")
44 }
45 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
46 _ => {
47 unreachable!("No database features enabled")
48 }
49 }
50 }
51}
52
/// Returns `true` when `path` should be treated as an in-memory SQLite database.
///
/// An empty path qualifies, as does any UTF-8 path that, after trimming
/// surrounding whitespace, equals `:memory:`, `memory:`, `file::memory:`,
/// `file:memory:` or the empty string. Non-UTF-8 paths never qualify.
#[cfg(feature = "sqlite")]
fn is_memory_filename(path: &std::path::Path) -> bool {
    const IN_MEMORY_NAMES: [&str; 5] = [":memory:", "memory:", "file::memory:", "file:memory:", ""];

    if path.as_os_str().is_empty() {
        return true;
    }

    matches!(path.to_str(), Some(raw) if IN_MEMORY_NAMES.contains(&raw.trim()))
}
67
68impl DbConnectOptions {
69 pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
74 match self {
75 #[cfg(feature = "sqlite")]
76 DbConnectOptions::Sqlite(opts) => {
77 let mut pool_opts = pool.apply_sqlite(sqlx::sqlite::SqlitePoolOptions::new());
78
79 if is_memory_filename(opts.get_filename()) {
80 pool_opts = pool_opts.max_connections(1).min_connections(1);
81 tracing::info!("Using single connection pool for in-memory SQLite database");
82 }
83
84 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
85
86 let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool);
87
88 let filename = opts.get_filename().display().to_string();
89 let handle = DbHandle {
90 engine: crate::DbEngine::Sqlite,
91 dsn: format!("sqlite://{filename}"),
92 sea,
93 };
94
95 Ok(handle)
96 }
97 #[cfg(feature = "pg")]
98 DbConnectOptions::Postgres(opts) => {
99 let pool_opts = pool.apply_pg(sqlx::postgres::PgPoolOptions::new());
100
101 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
102
103 let sea = sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool);
104
105 let handle = DbHandle {
106 engine: crate::DbEngine::Postgres,
107 dsn: format!(
108 "postgresql://<redacted>@{}:{}/{}",
109 opts.get_host(),
110 opts.get_port(),
111 opts.get_database().unwrap_or("")
112 ),
113 sea,
114 };
115
116 Ok(handle)
117 }
118 #[cfg(feature = "mysql")]
119 DbConnectOptions::MySql(opts) => {
120 let pool_opts = pool.apply_mysql(sqlx::mysql::MySqlPoolOptions::new());
121
122 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
123
124 let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool);
125
126 let handle = DbHandle {
127 engine: crate::DbEngine::MySql,
128 dsn: "mysql://<redacted>@...".to_owned(),
129 sea,
130 };
131
132 Ok(handle)
133 }
134 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
135 _ => {
136 unreachable!("No database features enabled")
137 }
138 }
139 }
140}
141
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    //! Validation and application of whitelisted SQLite PRAGMAs supplied as
    //! connection `params`.

    use crate::DbError;
    use std::collections::btree_map::Entry;
    use std::collections::{BTreeMap, HashMap};
    use std::hash::BuildHasher;

    /// Parameter names (compared after lowercasing) accepted by `apply_pragmas`.
    const ALLOWED_PRAGMAS: &[&str] = &["wal", "synchronous", "busy_timeout", "journal_mode"];

    /// Validates `params` and applies them to `opts` as SQLite PRAGMAs.
    ///
    /// Keys are matched case-insensitively:
    /// - `wal`: `true`/`1` sets `journal_mode=WAL`, `false`/`0` sets `journal_mode=DELETE`
    /// - `journal_mode`: DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF
    /// - `synchronous`: OFF/NORMAL/FULL/EXTRA
    /// - `busy_timeout`: non-negative integer
    ///
    /// Several keys can resolve to the same PRAGMA, e.g. `wal` and `journal_mode`,
    /// or `WAL` and `wal`. Because `params` is a hash map, applying them one by one
    /// would let hash iteration order pick the winner. All values are therefore
    /// resolved first and accepted only when they agree. They are then applied in
    /// a deterministic order.
    ///
    /// # Errors
    /// - `DbError::UnknownSqlitePragma` for a key outside the whitelist.
    /// - `DbError::InvalidSqlitePragma` for an invalid value, or when two keys
    ///   request different values for the same PRAGMA.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sqlx::sqlite::SqliteConnectOptions> {
        // PRAGMA name -> (config key that requested it, validated value).
        let mut resolved: BTreeMap<&'static str, (&str, String)> = BTreeMap::new();

        for (key, value) in params {
            let key_lower = key.to_lowercase();

            if !ALLOWED_PRAGMAS.contains(&key_lower.as_str()) {
                return Err(DbError::UnknownSqlitePragma(key.clone()));
            }

            let (pragma, pragma_value): (&'static str, String) = match key_lower.as_str() {
                "wal" => ("journal_mode", validate_wal_pragma(value)?.to_owned()),
                "journal_mode" => ("journal_mode", validate_journal_mode_pragma(value)?),
                "synchronous" => ("synchronous", validate_synchronous_pragma(value)?),
                "busy_timeout" => (
                    "busy_timeout",
                    validate_busy_timeout_pragma(value)?.to_string(),
                ),
                _ => unreachable!("Checked against whitelist above"),
            };

            match resolved.entry(pragma) {
                Entry::Vacant(slot) => {
                    slot.insert((key.as_str(), pragma_value));
                }
                Entry::Occupied(existing) => {
                    let (prev_key, prev_value) = existing.get();
                    if *prev_value != pragma_value {
                        return Err(DbError::InvalidSqlitePragma {
                            key: key.clone(),
                            message: format!(
                                "conflicts with '{prev_key}': both set PRAGMA {pragma} \
                                 ('{prev_value}' vs '{pragma_value}')"
                            ),
                        });
                    }
                }
            }
        }

        for (pragma, (_, value)) in resolved {
            opts = opts.pragma(pragma, value);
        }

        Ok(opts)
    }

    /// Maps a boolean-like `wal` value to a journal mode, case-insensitively:
    /// `true`/`1` becomes `WAL` and `false`/`0` becomes `DELETE`.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Validates a `synchronous` level case-insensitively and returns it uppercased.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        match upper.as_str() {
            "OFF" | "NORMAL" | "FULL" | "EXTRA" => Ok(upper),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            }),
        }
    }

    /// Parses `busy_timeout` as a non-negative integer.
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Validates a `journal_mode` case-insensitively and returns it uppercased.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        match upper.as_str() {
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF" => Ok(upper),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            }),
        }
    }
}
247
248pub(crate) async fn build_db_handle(
256 mut cfg: DbConnConfig,
257 _global: Option<&GlobalDatabaseConfig>,
258) -> Result<DbHandle> {
259 if let Some(dsn) = &cfg.dsn {
261 cfg.dsn = Some(expand_env_vars(dsn)?);
262 }
263 if let Some(password) = &cfg.password {
264 cfg.password = Some(resolve_password(password)?);
265 }
266
267 if let Some(ref mut params) = cfg.params {
269 for (_, value) in params.iter_mut() {
270 if value.contains("${") {
271 *value = expand_env_vars(value)?;
272 }
273 }
274 }
275
276 validate_config_consistency(&cfg)?;
278
279 let engine = determine_engine(&cfg)?;
281 let connect_options = match engine {
282 DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
283 DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
284 };
285
286 let pool_cfg = cfg.pool.unwrap_or_default();
288
289 let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
291 tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
292
293 let handle = connect_options.connect(pool_cfg).await?;
295
296 Ok(handle)
297}
298
299fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
300 if let Some(engine) = cfg.engine {
304 if let Some(dsn) = cfg.dsn.as_deref() {
305 let inferred = engine_from_dsn(dsn)?;
306 if inferred != engine {
307 return Err(DbError::ConfigConflict(format!(
308 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
309 )));
310 }
311 }
312 return Ok(engine);
313 }
314
315 if cfg.dsn.is_none() {
321 return Err(DbError::InvalidParameter(
322 "Missing 'engine': required when 'dsn' is not provided".to_owned(),
323 ));
324 }
325
326 let Some(dsn) = cfg.dsn.as_deref() else {
328 return Err(DbError::InvalidParameter(
330 "Missing 'dsn': required to infer database engine".to_owned(),
331 ));
332 };
333 engine_from_dsn(dsn)
334}
335
336fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
337 let s = dsn.trim_start();
338 if s.starts_with("postgres://") || s.starts_with("postgresql://") {
339 Ok(DbEngineCfg::Postgres)
340 } else if s.starts_with("mysql://") {
341 Ok(DbEngineCfg::Mysql)
342 } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
343 Ok(DbEngineCfg::Sqlite)
344 } else {
345 Err(DbError::UnknownDsn(dsn.to_owned()))
346 }
347}
348
/// Builds SQLite connect options from a connection config.
///
/// The database location comes from `dsn` first, then `path`. A bare `file`
/// is rejected because it should already have been resolved into an absolute
/// path. The parent directory of the database path is created if missing.
/// The options use `create_if_missing(true)`, and any `params` are applied
/// through `sqlite_pragma::apply_pragmas`.
///
/// # Errors
/// - `DbError::InvalidParameter` for a bad DSN, an unresolved `file`, or when
///   no location is given.
/// - The converted I/O error if the parent directory cannot be created.
/// - Any PRAGMA validation error from `apply_pragmas`.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = match (&cfg.dsn, &cfg.path, &cfg.file) {
        (Some(dsn), _, _) => parse_sqlite_path_from_dsn(dsn)?,
        (None, Some(path), _) => path.clone(),
        (None, None, Some(_)) => {
            return Err(DbError::InvalidParameter(
                "File path should have been resolved to absolute path".to_owned(),
            ));
        }
        (None, None, None) => {
            return Err(DbError::InvalidParameter(
                "SQLite connection requires either DSN, path, or file".to_owned(),
            ));
        }
    };

    if let Some(parent) = db_path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let base = sqlx::sqlite::SqliteConnectOptions::new()
        .filename(&db_path)
        .create_if_missing(true);

    let opts = match &cfg.params {
        Some(params) => sqlite_pragma::apply_pragmas(base, params)?,
        None => base,
    };

    Ok(DbConnectOptions::Sqlite(opts))
}
383
384#[cfg(not(feature = "sqlite"))]
385fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
386 Err(DbError::FeatureDisabled("SQLite feature not enabled"))
387}
388
/// Applies connection `params` to PostgreSQL options.
///
/// Known keys are matched case-insensitively:
/// - `sslmode`/`ssl_mode`
/// - `sslrootcert`/`ssl_root_cert`
/// - `sslcert`/`ssl_client_cert`
/// - `sslkey`/`ssl_client_key`
/// - `application_name`
/// - `statement_cache_capacity`
/// - `extra_float_digits`, which must be in -15..=3
///
/// Any other key is passed through unchanged, with its original casing, as a
/// server runtime option via `PgConnectOptions::options`.
///
/// # Errors
/// `DbError::InvalidParameter` for an unparsable SSL mode, cache capacity, or
/// out-of-range `extra_float_digits`.
#[cfg(feature = "pg")]
fn apply_pg_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::postgres::PgConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::postgres::PgConnectOptions> {
    use sqlx::postgres::PgSslMode;

    for (key, value) in params {
        opts = match key.to_lowercase().as_str() {
            "sslmode" | "ssl_mode" => {
                let mode: PgSslMode = value.parse().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disable, allow, prefer, require, verify-ca, or verify-full"
                    ))
                })?;
                opts.ssl_mode(mode)
            }
            "sslrootcert" | "ssl_root_cert" => opts.ssl_root_cert(value.as_str()),
            "sslcert" | "ssl_client_cert" => opts.ssl_client_cert(value.as_str()),
            "sslkey" | "ssl_client_key" => opts.ssl_client_key(value.as_str()),
            "application_name" => opts.application_name(value),
            "statement_cache_capacity" => {
                let capacity: usize = value.parse().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts.statement_cache_capacity(capacity)
            }
            "extra_float_digits" => {
                // Parse failures and out-of-range values share one error message.
                let digits = value
                    .parse::<i8>()
                    .ok()
                    .filter(|d| (-15..=3).contains(d))
                    .ok_or_else(|| {
                        DbError::InvalidParameter(format!(
                            "Invalid extra_float_digits '{value}': expected integer between -15 and 3"
                        ))
                    })?;
                opts.extra_float_digits(digits)
            }
            _ => opts.options([(key.as_str(), value.as_str())]),
        };
    }

    Ok(opts)
}
452
/// Applies connection `params` to MySQL options.
///
/// Keys are matched case-insensitively, and most accept `_`, `-` or
/// unseparated spellings:
/// - SSL: `ssl_mode`, `ssl_ca`, `ssl_client_cert`, `ssl_client_key`
/// - `charset`, `collation`, `statement_cache_capacity`, `socket`
/// - `timezone`: empty or `none` clears it
/// - boolean flags: `pipes_as_concat`, `no_engine_substitution`,
///   `enable_cleartext_plugin`, `set_names`
///
/// `connect_timeout` is validated but not applied to the options. A warning is
/// logged instead of discarding it silently.
///
/// # Errors
/// `DbError::InvalidParameter` for an invalid value or an unknown key.
#[cfg(feature = "mysql")]
fn apply_mysql_params<S: std::hash::BuildHasher>(
    mut opts: sqlx::mysql::MySqlConnectOptions,
    params: &std::collections::HashMap<String, String, S>,
) -> Result<sqlx::mysql::MySqlConnectOptions> {
    use sqlx::mysql::MySqlSslMode;

    for (key, value) in params {
        let key_lower = key.to_lowercase();
        opts = match key_lower.as_str() {
            "sslmode" | "ssl_mode" | "ssl-mode" => {
                let mode = value.parse::<MySqlSslMode>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid ssl_mode '{value}': expected disabled, preferred, required, verify_ca, or verify_identity"
                    ))
                })?;
                opts.ssl_mode(mode)
            }
            "sslca" | "ssl_ca" | "ssl-ca" => opts.ssl_ca(value.as_str()),
            "sslcert" | "ssl_client_cert" | "ssl-cert" => opts.ssl_client_cert(value.as_str()),
            "sslkey" | "ssl_client_key" | "ssl-key" => opts.ssl_client_key(value.as_str()),
            "charset" => opts.charset(value),
            "collation" => opts.collation(value),
            "statement_cache_capacity" => {
                let capacity = value.parse::<usize>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid statement_cache_capacity '{value}': expected positive integer"
                    ))
                })?;
                opts.statement_cache_capacity(capacity)
            }
            "connect_timeout" | "connect-timeout" => {
                let secs = value.parse::<u64>().map_err(|_| {
                    DbError::InvalidParameter(format!(
                        "Invalid connect_timeout '{value}': expected non-negative integer seconds"
                    ))
                })?;
                // No MySqlConnectOptions setter is used for this value (presumably the
                // driver has none; verify against the sqlx version in use). Warn so the
                // user knows the setting has no effect.
                tracing::warn!(
                    connect_timeout_secs = secs,
                    "MySQL 'connect_timeout' parameter is validated but not applied; it has no effect"
                );
                opts
            }
            "socket" => opts.socket(value.as_str()),
            "timezone" => {
                let tz = if value.eq_ignore_ascii_case("none") || value.is_empty() {
                    None
                } else {
                    Some(value.clone())
                };
                opts.timezone(tz)
            }
            "pipes_as_concat" => {
                let flag = parse_bool_param("pipes_as_concat", value)?;
                opts.pipes_as_concat(flag)
            }
            "no_engine_substitution" => {
                let flag = parse_bool_param("no_engine_substitution", value)?;
                opts.no_engine_substitution(flag)
            }
            "enable_cleartext_plugin" => {
                let flag = parse_bool_param("enable_cleartext_plugin", value)?;
                opts.enable_cleartext_plugin(flag)
            }
            "set_names" => {
                let flag = parse_bool_param("set_names", value)?;
                opts.set_names(flag)
            }
            _ => {
                return Err(DbError::InvalidParameter(format!(
                    "Unknown MySQL connection parameter: '{key}'"
                )));
            }
        };
    }

    Ok(opts)
}
547
/// Parses the boolean connection parameter `name` from `value`.
///
/// Accepts, case-insensitively, `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`.
///
/// # Errors
/// `DbError::InvalidParameter` naming the parameter for any other value.
#[cfg(feature = "mysql")]
fn parse_bool_param(name: &str, value: &str) -> Result<bool> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 4] = ["false", "0", "no", "off"];

    let lowered = value.to_lowercase();
    let word = lowered.as_str();

    if TRUTHY.contains(&word) {
        Ok(true)
    } else if FALSY.contains(&word) {
        Ok(false)
    } else {
        Err(DbError::InvalidParameter(format!(
            "Invalid {name} '{value}': expected true/false/1/0/yes/no/on/off"
        )))
    }
}
559
560fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
562 #[cfg(not(any(feature = "pg", feature = "mysql")))]
565 let _ = cfg;
566
567 match engine {
568 DbEngineCfg::Postgres => {
569 #[cfg(feature = "pg")]
570 {
571 let mut opts = if let Some(dsn) = &cfg.dsn {
572 dsn.parse::<sqlx::postgres::PgConnectOptions>()
573 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
574 } else {
575 sqlx::postgres::PgConnectOptions::new()
576 };
577
578 if let Some(host) = &cfg.host {
580 opts = opts.host(host);
581 }
582 if let Some(port) = cfg.port {
583 opts = opts.port(port);
584 }
585 if let Some(user) = &cfg.user {
586 opts = opts.username(user);
587 }
588 if let Some(password) = &cfg.password {
589 opts = opts.password(password);
590 }
591 if let Some(dbname) = &cfg.dbname {
592 opts = opts.database(dbname);
593 } else if cfg.dsn.is_none() {
594 return Err(DbError::InvalidParameter(
595 "dbname is required for PostgreSQL connections".to_owned(),
596 ));
597 }
598
599 if let Some(params) = &cfg.params {
601 opts = apply_pg_params(opts, params)?;
602 }
603
604 Ok(DbConnectOptions::Postgres(opts))
605 }
606 #[cfg(not(feature = "pg"))]
607 {
608 Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
609 }
610 }
611 DbEngineCfg::Mysql => {
612 #[cfg(feature = "mysql")]
613 {
614 let mut opts = if let Some(dsn) = &cfg.dsn {
615 dsn.parse::<sqlx::mysql::MySqlConnectOptions>()
616 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
617 } else {
618 sqlx::mysql::MySqlConnectOptions::new()
619 };
620
621 if let Some(host) = &cfg.host {
623 opts = opts.host(host);
624 }
625 if let Some(port) = cfg.port {
626 opts = opts.port(port);
627 }
628 if let Some(user) = &cfg.user {
629 opts = opts.username(user);
630 }
631 if let Some(password) = &cfg.password {
632 opts = opts.password(password);
633 }
634 if let Some(dbname) = &cfg.dbname {
635 opts = opts.database(dbname);
636 } else if cfg.dsn.is_none() {
637 return Err(DbError::InvalidParameter(
638 "dbname is required for MySQL connections".to_owned(),
639 ));
640 }
641
642 if let Some(params) = &cfg.params {
644 opts = apply_mysql_params(opts, params)?;
645 }
646
647 Ok(DbConnectOptions::MySql(opts))
648 }
649 #[cfg(not(feature = "mysql"))]
650 {
651 Err(DbError::FeatureDisabled("MySQL feature not enabled"))
652 }
653 }
654 DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
655 "build_server_options called with sqlite engine".to_owned(),
656 )),
657 }
658}
659
/// Extracts the database path from a SQLite DSN.
///
/// Leading whitespace is ignored, as in `engine_from_dsn`, so every DSN
/// classified as SQLite there is accepted here too. Both `sqlite:<path>` and
/// `sqlite://<path>` are accepted; only one leading `//` is stripped. Any query
/// string (`?…`) is dropped, and a warning is logged because those options
/// have no effect.
///
/// # Errors
/// `DbError::InvalidParameter` if the DSN does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.trim_start().strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };
    let rest = rest.strip_prefix("//").unwrap_or(rest);

    let path_part = match rest.split_once('?') {
        Some((path, _query)) => {
            // The query content is deliberately not logged in case it holds secrets.
            tracing::warn!(
                "Ignoring query parameters in SQLite DSN; use connection 'params' to set pragmas"
            );
            path
        }
        None => rest,
    };

    Ok(std::path::PathBuf::from(path_part))
}
689
690fn expand_env_vars(input: &str) -> Result<String> {
692 let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
693 .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
694 let mut result = input.to_owned();
695
696 for caps in re.captures_iter(input) {
697 let full_match = &caps[0];
698 let var_name = &caps[1];
699 let value = std::env::var(var_name)?;
700 result = result.replace(full_match, &value);
701 }
702
703 Ok(result)
704}
705
706fn resolve_password(password: &str) -> Result<String> {
708 if password.starts_with("${") && password.ends_with('}') {
709 let var_name = &password[2..password.len() - 1];
710 Ok(std::env::var(var_name)?)
711 } else {
712 Ok(password.to_owned())
713 }
714}
715
716fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
718 if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
720 let inferred = engine_from_dsn(dsn)?;
721 if inferred != engine {
722 return Err(DbError::ConfigConflict(format!(
723 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
724 )));
725 }
726 }
727
728 if let Some(dsn) = &cfg.dsn {
730 let is_sqlite_dsn = dsn.starts_with("sqlite");
731 let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
732 let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
733
734 if is_sqlite_dsn && has_server_fields {
735 return Err(DbError::ConfigConflict(
736 "SQLite DSN cannot be used with host/port fields".to_owned(),
737 ));
738 }
739
740 if !is_sqlite_dsn && has_sqlite_fields {
741 return Err(DbError::ConfigConflict(
742 "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
743 ));
744 }
745
746 if !is_sqlite_dsn
748 && cfg.server.is_some()
749 && (cfg.host.is_some()
750 || cfg.port.is_some()
751 || cfg.user.is_some()
752 || cfg.password.is_some()
753 || cfg.dbname.is_some())
754 {
755 }
758 }
759
760 if cfg.file.is_some() && cfg.path.is_some() {
762 return Err(DbError::ConfigConflict(
763 "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
764 ));
765 }
766
767 if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
768 return Err(DbError::ConfigConflict(
769 "SQLite file/path fields cannot be used with host/port fields".to_owned(),
770 ));
771 }
772
773 if cfg.engine == Some(DbEngineCfg::Sqlite)
775 && (cfg.host.is_some()
776 || cfg.port.is_some()
777 || cfg.user.is_some()
778 || cfg.password.is_some()
779 || cfg.dbname.is_some())
780 {
781 return Err(DbError::ConfigConflict(
782 "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
783 ));
784 }
785
786 if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
788 && (cfg.file.is_some() || cfg.path.is_some())
789 {
790 return Err(DbError::ConfigConflict(
791 "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
792 ));
793 }
794
795 Ok(())
796}
797
798#[must_use]
800pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
801 match dsn {
802 Some(dsn) if dsn.contains('@') => {
803 if let Ok(mut parsed) = url::Url::parse(dsn) {
804 if parsed.password().is_some() {
805 _ = parsed.set_password(Some("***"));
806 }
807 parsed.to_string()
808 } else {
809 "***".to_owned()
810 }
811 }
812 Some(dsn) => dsn.to_owned(),
813 None => "none".to_owned(),
814 }
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with only `engine` and `dsn` set; every other field is defaulted.
    fn cfg_with(engine: Option<DbEngineCfg>, dsn: Option<&str>) -> DbConnConfig {
        DbConnConfig {
            engine,
            dsn: dsn.map(str::to_owned),
            ..Default::default()
        }
    }

    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let err = determine_engine(&cfg_with(None, None)).unwrap_err();

        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let engine = determine_engine(&cfg_with(None, Some("sqlite::memory:"))).unwrap();

        assert_eq!(engine, DbEngineCfg::Sqlite);
    }

    #[test]
    fn engine_and_dsn_match_ok() {
        let matching = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (engine, dsn) in matching {
            let cfg = cfg_with(Some(engine), Some(dsn));
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), engine);
        }
    }

    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let mismatched = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in mismatched {
            let err = validate_config_consistency(&cfg_with(Some(engine), Some(dsn))).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    #[test]
    fn unknown_dsn_is_error() {
        let cfg = cfg_with(None, Some("unknown://localhost/db"));

        // Without an explicit engine, consistency validation does not reject an
        // unknown scheme; engine determination does.
        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }
}