1use crate::config::{DbConnConfig, DbEngineCfg, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5use thiserror::Error;
6
/// Engine-specific sqlx connection options, one variant per compiled-in database backend.
///
/// Each variant exists only when the matching cargo feature (`sqlite`, `pg`, `mysql`) is
/// enabled. In this file, values are produced by the option builders used by
/// [`build_db_handle`] and turned into a live pool by [`DbConnectOptions::connect`].
///
/// The `Display` impl renders a credential-free description suitable for logs.
/// NOTE(review): the derived `Debug` delegates to sqlx's option types, whose `Debug` output
/// may include credentials — prefer `Display` when logging; confirm against the sqlx version.
#[derive(Debug, Clone)]
pub enum DbConnectOptions {
    /// SQLite options (database file, create-if-missing, PRAGMAs).
    #[cfg(feature = "sqlite")]
    Sqlite(sea_orm::sqlx::sqlite::SqliteConnectOptions),
    /// PostgreSQL options (host, port, credentials, database, startup options).
    #[cfg(feature = "pg")]
    Postgres(sea_orm::sqlx::postgres::PgConnectOptions),
    /// MySQL options (host, port, credentials, database).
    #[cfg(feature = "mysql")]
    MySql(sea_orm::sqlx::mysql::MySqlConnectOptions),
}
19
/// Errors that can arise while assembling database connection options.
///
/// NOTE(review): the functions in this file report failures through `crate::DbError`
/// (which has variants with the same names); this type is not constructed anywhere in the
/// visible code — confirm whether it is still used elsewhere before extending it.
#[derive(Debug, Error)]
pub enum ConnectionOptionsError {
    /// A supported SQLite PRAGMA key was given a value outside its allowed set.
    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
    InvalidSqlitePragma { key: String, message: String },

    /// A SQLite PRAGMA key that is not on the supported whitelist.
    #[error("Unknown SQLite PRAGMA parameter: {0}")]
    UnknownSqlitePragma(String),

    /// Any other invalid connection parameter; the payload is a human-readable reason.
    #[error("Invalid connection parameter: {0}")]
    InvalidParameter(String),

    /// The requested backend's cargo feature is not compiled in.
    #[error("Feature not enabled: {0}")]
    FeatureDisabled(&'static str),

    /// Filesystem error (converted from `std::io::Error` via `#[from]`).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// URL parsing failure (converted from `url::ParseError` via `#[from]`).
    #[error("URL parsing error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// Environment-variable lookup failure (converted from `std::env::VarError` via `#[from]`).
    #[error("Environment variable error: {0}")]
    EnvVar(#[from] std::env::VarError),
}
44
45impl std::fmt::Display for DbConnectOptions {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 #[cfg(feature = "sqlite")]
49 DbConnectOptions::Sqlite(opts) => {
50 let filename = opts.get_filename().display().to_string();
51 if filename.is_empty() {
52 write!(f, "sqlite://memory")
53 } else {
54 write!(f, "sqlite://{filename}")
55 }
56 }
57 #[cfg(feature = "pg")]
58 DbConnectOptions::Postgres(opts) => {
59 write!(
60 f,
61 "postgresql://<redacted>@{}:{}/{}",
62 opts.get_host(),
63 opts.get_port(),
64 opts.get_database().unwrap_or("")
65 )
66 }
67 #[cfg(feature = "mysql")]
68 DbConnectOptions::MySql(_opts) => {
69 write!(f, "mysql://<redacted>@...")
70 }
71 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
72 _ => {
73 unreachable!("No database features enabled")
74 }
75 }
76 }
77}
78
79impl DbConnectOptions {
80 pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
85 match self {
86 #[cfg(feature = "sqlite")]
87 DbConnectOptions::Sqlite(opts) => {
88 let pool_opts = pool.apply_sqlite(sea_orm::sqlx::sqlite::SqlitePoolOptions::new());
89
90 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
91
92 let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool.clone());
93
94 let filename = opts.get_filename().display().to_string();
95 let handle = DbHandle {
96 engine: crate::DbEngine::Sqlite,
97 pool: crate::DbPool::Sqlite(sqlx_pool),
98 dsn: format!("sqlite://{filename}"),
99 sea,
100 };
101
102 Ok(handle)
103 }
104 #[cfg(feature = "pg")]
105 DbConnectOptions::Postgres(opts) => {
106 let pool_opts = pool.apply_pg(sea_orm::sqlx::postgres::PgPoolOptions::new());
107
108 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
109
110 let sea =
111 sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool.clone());
112
113 let handle = DbHandle {
114 engine: crate::DbEngine::Postgres,
115 pool: crate::DbPool::Postgres(sqlx_pool),
116 dsn: format!(
117 "postgresql://<redacted>@{}:{}/{}",
118 opts.get_host(),
119 opts.get_port(),
120 opts.get_database().unwrap_or("")
121 ),
122 sea,
123 };
124
125 Ok(handle)
126 }
127 #[cfg(feature = "mysql")]
128 DbConnectOptions::MySql(opts) => {
129 let pool_opts = pool.apply_mysql(sea_orm::sqlx::mysql::MySqlPoolOptions::new());
130
131 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
132
133 let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool.clone());
134
135 let handle = DbHandle {
136 engine: crate::DbEngine::MySql,
137 pool: crate::DbPool::MySql(sqlx_pool),
138 dsn: "mysql://<redacted>@...".to_owned(),
139 sea,
140 };
141
142 Ok(handle)
143 }
144 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
145 _ => {
146 unreachable!("No database features enabled")
147 }
148 }
149 }
150}
151
#[cfg(feature = "sqlite")]
/// Validation and application of the whitelisted SQLite PRAGMA parameters that may appear
/// in a connection config's `params` map.
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::btree_map::{BTreeMap, Entry};
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    /// Largest accepted `busy_timeout` in milliseconds. SQLite's `sqlite3_busy_timeout`
    /// takes a C `int`, so larger values cannot be represented. The cast is a lossless
    /// widening.
    const MAX_BUSY_TIMEOUT_MS: i64 = i32::MAX as i64;

    /// Validates `params` and applies them as PRAGMAs on `opts`.
    ///
    /// Keys are matched case-insensitively:
    /// - `wal`: `true`/`1` → `journal_mode=WAL`, `false`/`0` → `journal_mode=DELETE`
    /// - `journal_mode`: `DELETE`, `WAL`, `MEMORY`, `TRUNCATE`, `PERSIST` or `OFF`
    /// - `synchronous`: `OFF`, `NORMAL`, `FULL` or `EXTRA`
    /// - `busy_timeout`: milliseconds in `0..=i32::MAX`
    ///
    /// Several keys can target the same PRAGMA: `wal` and `journal_mode`, or keys that
    /// differ only in case. They are accepted when they agree and rejected when they
    /// disagree, so the outcome never depends on `HashMap` iteration order. For the same
    /// reason, PRAGMAs are applied in sorted name order.
    ///
    /// # Errors
    /// - `DbError::UnknownSqlitePragma` for a key outside the list above.
    /// - `DbError::InvalidSqlitePragma` for an invalid value or for conflicting keys.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sea_orm::sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sea_orm::sqlx::sqlite::SqliteConnectOptions> {
        // PRAGMA name -> (normalized value, config key that supplied it).
        let mut resolved: BTreeMap<&'static str, (String, &str)> = BTreeMap::new();

        for (key, value) in params {
            // The match doubles as the whitelist: anything unlisted is rejected.
            let (pragma, normalized) = match key.to_lowercase().as_str() {
                "wal" => ("journal_mode", validate_wal_pragma(value)?.to_owned()),
                "journal_mode" => ("journal_mode", validate_journal_mode_pragma(value)?),
                "synchronous" => ("synchronous", validate_synchronous_pragma(value)?),
                "busy_timeout" => (
                    "busy_timeout",
                    validate_busy_timeout_pragma(value)?.to_string(),
                ),
                _ => return Err(DbError::UnknownSqlitePragma(key.clone())),
            };

            match resolved.entry(pragma) {
                Entry::Vacant(slot) => {
                    slot.insert((normalized, key.as_str()));
                }
                Entry::Occupied(slot) => {
                    let (existing, first_key) = slot.get();
                    if *existing != normalized {
                        return Err(DbError::InvalidSqlitePragma {
                            key: key.clone(),
                            message: format!(
                                "conflicts with '{first_key}': both set PRAGMA {pragma} \
                                 ('{existing}' vs '{normalized}')"
                            ),
                        });
                    }
                }
            }
        }

        for (pragma, (value, _)) in resolved {
            opts = opts.pragma(pragma, value);
        }

        Ok(opts)
    }

    /// Maps the boolean-style `wal` setting to a `journal_mode` value.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Upper-cases a `synchronous` level, rejecting levels outside the allowed set.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        if matches!(upper.as_str(), "OFF" | "NORMAL" | "FULL" | "EXTRA") {
            Ok(upper)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            })
        }
    }

    /// Parses a `busy_timeout` in milliseconds; it must be non-negative and fit a C `int`.
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }
        if timeout > MAX_BUSY_TIMEOUT_MS {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be at most {MAX_BUSY_TIMEOUT_MS} ms, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Upper-cases a `journal_mode`, rejecting modes outside the allowed set.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let upper = value.to_uppercase();
        if matches!(
            upper.as_str(),
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF"
        ) {
            Ok(upper)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            })
        }
    }
}
257
258pub async fn build_db_handle(
264 mut cfg: DbConnConfig,
265 _global: Option<&GlobalDatabaseConfig>,
266) -> Result<DbHandle> {
267 if let Some(dsn) = &cfg.dsn {
269 cfg.dsn = Some(expand_env_vars(dsn)?);
270 }
271 if let Some(password) = &cfg.password {
272 cfg.password = Some(resolve_password(password)?);
273 }
274
275 if let Some(ref mut params) = cfg.params {
277 for (_, value) in params.iter_mut() {
278 if value.contains("${") {
279 *value = expand_env_vars(value)?;
280 }
281 }
282 }
283
284 validate_config_consistency(&cfg)?;
286
287 let engine = determine_engine(&cfg)?;
289 let connect_options = match engine {
290 DbEngineCfg::Sqlite => build_sqlite_options(&cfg)?,
291 DbEngineCfg::Postgres | DbEngineCfg::Mysql => build_server_options(&cfg, engine)?,
292 };
293
294 let pool_cfg = cfg.pool.unwrap_or_default();
296
297 let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
299 tracing::debug!(dsn = log_dsn, engine = ?engine, "Building database connection");
300
301 let handle = connect_options.connect(pool_cfg).await?;
303
304 Ok(handle)
305}
306
307fn determine_engine(cfg: &DbConnConfig) -> Result<DbEngineCfg> {
308 if let Some(engine) = cfg.engine {
312 if let Some(dsn) = cfg.dsn.as_deref() {
313 let inferred = engine_from_dsn(dsn)?;
314 if inferred != engine {
315 return Err(DbError::ConfigConflict(format!(
316 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
317 )));
318 }
319 }
320 return Ok(engine);
321 }
322
323 if cfg.dsn.is_none() {
329 return Err(DbError::InvalidParameter(
330 "Missing 'engine': required when 'dsn' is not provided".to_owned(),
331 ));
332 }
333
334 let Some(dsn) = cfg.dsn.as_deref() else {
336 return Err(DbError::InvalidParameter(
338 "Missing 'dsn': required to infer database engine".to_owned(),
339 ));
340 };
341 engine_from_dsn(dsn)
342}
343
344fn engine_from_dsn(dsn: &str) -> Result<DbEngineCfg> {
345 let s = dsn.trim_start();
346 if s.starts_with("postgres://") || s.starts_with("postgresql://") {
347 Ok(DbEngineCfg::Postgres)
348 } else if s.starts_with("mysql://") {
349 Ok(DbEngineCfg::Mysql)
350 } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
351 Ok(DbEngineCfg::Sqlite)
352 } else {
353 Err(DbError::UnknownDsn(dsn.to_owned()))
354 }
355}
356
#[cfg(feature = "sqlite")]
/// Builds SQLite connection options from a DSN or an explicit `path`.
///
/// The DSN takes precedence over `path`. A `file` without a resolved `path` is rejected,
/// because it is expected to have been turned into a `path` before this point. For an
/// on-disk database, missing parent directories are created and the database file is
/// created if missing.
///
/// An in-memory target (`:memory:`, e.g. from the DSN `sqlite::memory:`) is built with
/// sqlx's own `sqlite::memory:` parser, which gives the pool a shared in-memory database.
/// Opening the raw `:memory:` filename would instead give every pooled connection its own
/// separate, empty database.
///
/// PRAGMA-style `params` are validated and applied via `sqlite_pragma::apply_pragmas`.
///
/// # Errors
/// Fails on an invalid DSN, missing path information, a directory-creation I/O error,
/// or invalid PRAGMA parameters.
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = if let Some(dsn) = &cfg.dsn {
        parse_sqlite_path_from_dsn(dsn)?
    } else if let Some(path) = &cfg.path {
        path.clone()
    } else if cfg.file.is_some() {
        return Err(DbError::InvalidParameter(
            "File path should have been resolved to absolute path".to_owned(),
        ));
    } else {
        return Err(DbError::InvalidParameter(
            "SQLite connection requires either DSN, path, or file".to_owned(),
        ));
    };

    let mut opts = if db_path.as_os_str() == ":memory:" {
        "sqlite::memory:"
            .parse::<sea_orm::sqlx::sqlite::SqliteConnectOptions>()
            .map_err(|e| DbError::InvalidParameter(e.to_string()))?
    } else {
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        sea_orm::sqlx::sqlite::SqliteConnectOptions::new()
            .filename(&db_path)
            .create_if_missing(true)
    };

    if let Some(params) = &cfg.params {
        opts = sqlite_pragma::apply_pragmas(opts, params)?;
    }

    Ok(DbConnectOptions::Sqlite(opts))
}
391
392#[cfg(not(feature = "sqlite"))]
393fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
394 Err(DbError::FeatureDisabled("SQLite feature not enabled"))
395}
396
397fn build_server_options(cfg: &DbConnConfig, engine: DbEngineCfg) -> Result<DbConnectOptions> {
399 #[cfg(not(any(feature = "pg", feature = "mysql")))]
402 let _ = cfg;
403
404 match engine {
405 DbEngineCfg::Postgres => {
406 #[cfg(feature = "pg")]
407 {
408 let mut opts = if let Some(dsn) = &cfg.dsn {
409 dsn.parse::<sea_orm::sqlx::postgres::PgConnectOptions>()
410 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
411 } else {
412 sea_orm::sqlx::postgres::PgConnectOptions::new()
413 };
414
415 if let Some(host) = &cfg.host {
417 opts = opts.host(host);
418 }
419 if let Some(port) = cfg.port {
420 opts = opts.port(port);
421 }
422 if let Some(user) = &cfg.user {
423 opts = opts.username(user);
424 }
425 if let Some(password) = &cfg.password {
426 opts = opts.password(password);
427 }
428 if let Some(dbname) = &cfg.dbname {
429 opts = opts.database(dbname);
430 } else if cfg.dsn.is_none() {
431 return Err(DbError::InvalidParameter(
432 "dbname is required for PostgreSQL connections".to_owned(),
433 ));
434 }
435
436 if let Some(params) = &cfg.params {
438 for (key, value) in params {
439 opts = opts.options([(key.as_str(), value.as_str())]);
440 }
441 }
442
443 Ok(DbConnectOptions::Postgres(opts))
444 }
445 #[cfg(not(feature = "pg"))]
446 {
447 Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
448 }
449 }
450 DbEngineCfg::Mysql => {
451 #[cfg(feature = "mysql")]
452 {
453 let mut opts = if let Some(dsn) = &cfg.dsn {
454 dsn.parse::<sea_orm::sqlx::mysql::MySqlConnectOptions>()
455 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
456 } else {
457 sea_orm::sqlx::mysql::MySqlConnectOptions::new()
458 };
459
460 if let Some(host) = &cfg.host {
462 opts = opts.host(host);
463 }
464 if let Some(port) = cfg.port {
465 opts = opts.port(port);
466 }
467 if let Some(user) = &cfg.user {
468 opts = opts.username(user);
469 }
470 if let Some(password) = &cfg.password {
471 opts = opts.password(password);
472 }
473 if let Some(dbname) = &cfg.dbname {
474 opts = opts.database(dbname);
475 } else if cfg.dsn.is_none() {
476 return Err(DbError::InvalidParameter(
477 "dbname is required for MySQL connections".to_owned(),
478 ));
479 }
480
481 Ok(DbConnectOptions::MySql(opts))
482 }
483 #[cfg(not(feature = "mysql"))]
484 {
485 Err(DbError::FeatureDisabled("MySQL feature not enabled"))
486 }
487 }
488 DbEngineCfg::Sqlite => Err(DbError::InvalidParameter(
489 "build_server_options called with sqlite engine".to_owned(),
490 )),
491 }
492}
493
#[cfg(feature = "sqlite")]
/// Extracts the database path from a `sqlite:` DSN.
///
/// Accepts `sqlite:path`, `sqlite://path` and `sqlite:///absolute/path`; `sqlite::memory:`
/// yields the path `:memory:`. Leading whitespace is tolerated, matching how
/// `engine_from_dsn` recognises SQLite DSNs. Any `?query` suffix is discarded, so query
/// options in the DSN are not applied.
///
/// # Errors
/// Returns `DbError::InvalidParameter` if the DSN does not start with `sqlite:`.
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(rest) = dsn.trim_start().strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    // `sqlite://x` and `sqlite:x` name the same path; `sqlite:///abs` keeps the leading `/`.
    let rest = rest.strip_prefix("//").unwrap_or(rest);
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);

    Ok(std::path::PathBuf::from(path))
}
523
524fn expand_env_vars(input: &str) -> Result<String> {
526 let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
527 .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
528 let mut result = input.to_owned();
529
530 for caps in re.captures_iter(input) {
531 let full_match = &caps[0];
532 let var_name = &caps[1];
533 let value = std::env::var(var_name)?;
534 result = result.replace(full_match, &value);
535 }
536
537 Ok(result)
538}
539
540fn resolve_password(password: &str) -> Result<String> {
542 if password.starts_with("${") && password.ends_with('}') {
543 let var_name = &password[2..password.len() - 1];
544 Ok(std::env::var(var_name)?)
545 } else {
546 Ok(password.to_owned())
547 }
548}
549
550fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
552 if let (Some(engine), Some(dsn)) = (cfg.engine, cfg.dsn.as_deref()) {
554 let inferred = engine_from_dsn(dsn)?;
555 if inferred != engine {
556 return Err(DbError::ConfigConflict(format!(
557 "engine='{engine:?}' conflicts with DSN scheme inferred as '{inferred:?}'"
558 )));
559 }
560 }
561
562 if let Some(dsn) = &cfg.dsn {
564 let is_sqlite_dsn = dsn.starts_with("sqlite");
565 let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
566 let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
567
568 if is_sqlite_dsn && has_server_fields {
569 return Err(DbError::ConfigConflict(
570 "SQLite DSN cannot be used with host/port fields".to_owned(),
571 ));
572 }
573
574 if !is_sqlite_dsn && has_sqlite_fields {
575 return Err(DbError::ConfigConflict(
576 "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
577 ));
578 }
579
580 if !is_sqlite_dsn
582 && cfg.server.is_some()
583 && (cfg.host.is_some()
584 || cfg.port.is_some()
585 || cfg.user.is_some()
586 || cfg.password.is_some()
587 || cfg.dbname.is_some())
588 {
589 }
592 }
593
594 if cfg.file.is_some() && cfg.path.is_some() {
596 return Err(DbError::ConfigConflict(
597 "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
598 ));
599 }
600
601 if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
602 return Err(DbError::ConfigConflict(
603 "SQLite file/path fields cannot be used with host/port fields".to_owned(),
604 ));
605 }
606
607 if cfg.engine == Some(DbEngineCfg::Sqlite)
609 && (cfg.host.is_some()
610 || cfg.port.is_some()
611 || cfg.user.is_some()
612 || cfg.password.is_some()
613 || cfg.dbname.is_some())
614 {
615 return Err(DbError::ConfigConflict(
616 "engine=sqlite cannot be used with host/port/user/password/dbname fields".to_owned(),
617 ));
618 }
619
620 if matches!(cfg.engine, Some(DbEngineCfg::Postgres | DbEngineCfg::Mysql))
622 && (cfg.file.is_some() || cfg.path.is_some())
623 {
624 return Err(DbError::ConfigConflict(
625 "engine=postgres/mysql cannot be used with file/path fields".to_owned(),
626 ));
627 }
628
629 Ok(())
630}
631
632#[must_use]
634pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
635 match dsn {
636 Some(dsn) if dsn.contains('@') => {
637 if let Ok(mut parsed) = url::Url::parse(dsn) {
638 if parsed.password().is_some() {
639 let _ = parsed.set_password(Some("***"));
640 }
641 parsed.to_string()
642 } else {
643 "***".to_owned()
644 }
645 }
646 Some(dsn) => dsn.to_owned(),
647 None => "none".to_owned(),
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    // Without a DSN there is nothing to infer the engine from, so `engine` is mandatory.
    #[test]
    fn determine_engine_requires_engine_when_dsn_missing() {
        let cfg = DbConnConfig {
            dsn: None,
            engine: None,
            ..Default::default()
        };

        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::InvalidParameter(_)));
        assert!(err.to_string().contains("Missing 'engine'"));
    }

    // With no explicit engine, the DSN scheme (`sqlite:`) decides.
    #[test]
    fn determine_engine_infers_from_dsn_when_engine_missing() {
        let cfg = DbConnConfig {
            engine: None,
            dsn: Some("sqlite::memory:".to_owned()),
            ..Default::default()
        };

        let engine = determine_engine(&cfg).unwrap();
        assert_eq!(engine, DbEngineCfg::Sqlite);
    }

    // Each supported scheme (including both PostgreSQL spellings) agrees with its engine
    // in both the consistency check and engine resolution.
    #[test]
    fn engine_and_dsn_match_ok() {
        let cases = [
            (DbEngineCfg::Postgres, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Postgres, "postgresql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "sqlite::memory:"),
            (DbEngineCfg::Sqlite, "sqlite:///tmp/test.db"),
        ];

        for (engine, dsn) in cases {
            let cfg = DbConnConfig {
                engine: Some(engine),
                dsn: Some(dsn.to_owned()),
                ..Default::default()
            };
            validate_config_consistency(&cfg).unwrap();
            assert_eq!(determine_engine(&cfg).unwrap(), engine);
        }
    }

    // An explicit engine that contradicts the DSN scheme is reported as a ConfigConflict.
    #[test]
    fn engine_and_dsn_mismatch_is_error() {
        let cases = [
            (DbEngineCfg::Postgres, "mysql://user:pass@localhost/db"),
            (DbEngineCfg::Mysql, "postgres://user:pass@localhost/db"),
            (DbEngineCfg::Sqlite, "postgresql://user:pass@localhost/db"),
        ];

        for (engine, dsn) in cases {
            let cfg = DbConnConfig {
                engine: Some(engine),
                dsn: Some(dsn.to_owned()),
                ..Default::default()
            };

            let err = validate_config_consistency(&cfg).unwrap_err();
            assert!(matches!(err, DbError::ConfigConflict(_)));
        }
    }

    // An unrecognised scheme passes the consistency check when no engine is set; engine
    // resolution is what reports it as UnknownDsn.
    #[test]
    fn unknown_dsn_is_error() {
        let cfg = DbConnConfig {
            engine: None,
            dsn: Some("unknown://localhost/db".to_owned()),
            ..Default::default()
        };

        validate_config_consistency(&cfg).unwrap();
        let err = determine_engine(&cfg).unwrap_err();
        assert!(matches!(err, DbError::UnknownDsn(_)));
    }
}
735}