1use crate::config::{DbConnConfig, GlobalDatabaseConfig, PoolCfg};
4use crate::{DbError, DbHandle, Result};
5use thiserror::Error;
6
/// Driver-specific sqlx connect options for one of the database engines
/// compiled into this crate.
///
/// Each variant exists only when its cargo feature (`sqlite`, `pg`, `mysql`)
/// is enabled. The `Display` impl for this type renders a log-safe form with
/// credentials replaced by `<redacted>`.
#[derive(Debug, Clone)]
pub enum DbConnectOptions {
    /// SQLite options (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(sea_orm::sqlx::sqlite::SqliteConnectOptions),
    /// PostgreSQL options (feature `pg`).
    #[cfg(feature = "pg")]
    Postgres(sea_orm::sqlx::postgres::PgConnectOptions),
    /// MySQL options (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(sea_orm::sqlx::mysql::MySqlConnectOptions),
}
19
/// Errors that can arise while turning configuration into connect options.
///
/// NOTE(review): this enum is not referenced by the code in this section; the
/// builders here return `DbError` variants with the same names
/// (`InvalidSqlitePragma`, `UnknownSqlitePragma`, `InvalidParameter`,
/// `FeatureDisabled`). Confirm whether it is still used elsewhere.
#[derive(Debug, Error)]
pub enum ConnectionOptionsError {
    /// A SQLite PRAGMA parameter whose value was rejected; `message` says why.
    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
    InvalidSqlitePragma { key: String, message: String },

    /// A SQLite PRAGMA parameter name that is not supported.
    #[error("Unknown SQLite PRAGMA parameter: {0}")]
    UnknownSqlitePragma(String),

    /// Any other invalid connection parameter.
    #[error("Invalid connection parameter: {0}")]
    InvalidParameter(String),

    /// The requested engine's cargo feature is not compiled in.
    #[error("Feature not enabled: {0}")]
    FeatureDisabled(&'static str),

    /// Wrapped I/O error (e.g. while creating a database directory).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped URL parsing error (e.g. a malformed DSN).
    #[error("URL parsing error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// Wrapped environment-variable lookup error.
    #[error("Environment variable error: {0}")]
    EnvVar(#[from] std::env::VarError),
}
44
45impl std::fmt::Display for DbConnectOptions {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 #[cfg(feature = "sqlite")]
49 DbConnectOptions::Sqlite(opts) => {
50 let filename = opts.get_filename().display().to_string();
51 if filename.is_empty() {
52 write!(f, "sqlite://memory")
53 } else {
54 write!(f, "sqlite://{filename}")
55 }
56 }
57 #[cfg(feature = "pg")]
58 DbConnectOptions::Postgres(opts) => {
59 write!(
60 f,
61 "postgresql://<redacted>@{}:{}/{}",
62 opts.get_host(),
63 opts.get_port(),
64 opts.get_database().unwrap_or("")
65 )
66 }
67 #[cfg(feature = "mysql")]
68 DbConnectOptions::MySql(_opts) => {
69 write!(f, "mysql://<redacted>@...")
70 }
71 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
72 _ => {
73 unreachable!("No database features enabled")
74 }
75 }
76 }
77}
78
79impl DbConnectOptions {
80 pub async fn connect(&self, pool: PoolCfg) -> Result<DbHandle> {
85 match self {
86 #[cfg(feature = "sqlite")]
87 DbConnectOptions::Sqlite(opts) => {
88 let pool_opts = pool.apply_sqlite(sea_orm::sqlx::sqlite::SqlitePoolOptions::new());
89
90 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
91
92 let sea = sea_orm::SqlxSqliteConnector::from_sqlx_sqlite_pool(sqlx_pool.clone());
93
94 let filename = opts.get_filename().display().to_string();
95 let handle = DbHandle {
96 engine: crate::DbEngine::Sqlite,
97 pool: crate::DbPool::Sqlite(sqlx_pool),
98 dsn: format!("sqlite://{filename}"),
99 sea,
100 };
101
102 Ok(handle)
103 }
104 #[cfg(feature = "pg")]
105 DbConnectOptions::Postgres(opts) => {
106 let pool_opts = pool.apply_pg(sea_orm::sqlx::postgres::PgPoolOptions::new());
107
108 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
109
110 let sea =
111 sea_orm::SqlxPostgresConnector::from_sqlx_postgres_pool(sqlx_pool.clone());
112
113 let handle = DbHandle {
114 engine: crate::DbEngine::Postgres,
115 pool: crate::DbPool::Postgres(sqlx_pool),
116 dsn: format!(
117 "postgresql://<redacted>@{}:{}/{}",
118 opts.get_host(),
119 opts.get_port(),
120 opts.get_database().unwrap_or("")
121 ),
122 sea,
123 };
124
125 Ok(handle)
126 }
127 #[cfg(feature = "mysql")]
128 DbConnectOptions::MySql(opts) => {
129 let pool_opts = pool.apply_mysql(sea_orm::sqlx::mysql::MySqlPoolOptions::new());
130
131 let sqlx_pool = pool_opts.connect_with(opts.clone()).await?;
132
133 let sea = sea_orm::SqlxMySqlConnector::from_sqlx_mysql_pool(sqlx_pool.clone());
134
135 let handle = DbHandle {
136 engine: crate::DbEngine::MySql,
137 pool: crate::DbPool::MySql(sqlx_pool),
138 dsn: "mysql://<redacted>@...".to_owned(),
139 sea,
140 };
141
142 Ok(handle)
143 }
144 #[cfg(not(any(feature = "sqlite", feature = "pg", feature = "mysql")))]
145 _ => {
146 unreachable!("No database features enabled")
147 }
148 }
149 }
150}
151
/// Validation and application of SQLite `PRAGMA` settings supplied through the
/// connection `params` map.
///
/// Only an allow-list of keys is accepted (matched case-insensitively): `wal`,
/// `journal_mode`, `synchronous` and `busy_timeout`. Every value is validated
/// before it is handed to sqlx, so arbitrary text never reaches a PRAGMA.
#[cfg(feature = "sqlite")]
pub mod sqlite_pragma {
    use crate::DbError;
    use std::collections::btree_map::Entry;
    use std::collections::BTreeMap;
    use std::collections::HashMap;
    use std::hash::BuildHasher;

    /// Validates `params` and applies them to `opts` as SQLite PRAGMAs.
    ///
    /// `wal=true|1` is shorthand for `journal_mode=WAL`, and `wal=false|0` for
    /// `journal_mode=DELETE`. Several keys may target the same PRAGMA, e.g.
    /// `wal` together with `journal_mode`, or `Synchronous` together with
    /// `synchronous`. Such keys are accepted only when they resolve to the same
    /// value. Resolved PRAGMAs are applied in a fixed (sorted by PRAGMA name)
    /// order.
    ///
    /// # Errors
    ///
    /// * [`DbError::UnknownSqlitePragma`] if a key is not on the allow-list.
    /// * [`DbError::InvalidSqlitePragma`] if a value fails validation, or if two
    ///   keys set the same PRAGMA to different values.
    pub fn apply_pragmas<S: BuildHasher>(
        mut opts: sea_orm::sqlx::sqlite::SqliteConnectOptions,
        params: &HashMap<String, String, S>,
    ) -> crate::Result<sea_orm::sqlx::sqlite::SqliteConnectOptions> {
        // `params` iterates in an unspecified order. Resolving per target PRAGMA
        // first means two keys aimed at the same PRAGMA cannot silently race,
        // where whichever one happened to come last would otherwise win.
        // PRAGMA name -> (user key that set it, validated value).
        let mut resolved: BTreeMap<&'static str, (&str, String)> = BTreeMap::new();

        for (key, value) in params {
            let (pragma, setting) = resolve_pragma(key, value)?;
            match resolved.entry(pragma) {
                Entry::Vacant(slot) => {
                    slot.insert((key.as_str(), setting));
                }
                Entry::Occupied(slot) => {
                    let (first_key, first_setting) = slot.get();
                    if *first_setting != setting {
                        return Err(DbError::InvalidSqlitePragma {
                            key: key.clone(),
                            message: format!(
                                "conflicts with '{first_key}': PRAGMA {pragma} would be both \
                                 '{first_setting}' and '{setting}'"
                            ),
                        });
                    }
                }
            }
        }

        for (pragma, (_, setting)) in resolved {
            opts = opts.pragma(pragma, setting);
        }

        Ok(opts)
    }

    /// Maps one user-supplied `key`/`value` pair to the PRAGMA it controls and
    /// the validated value to set.
    fn resolve_pragma(key: &str, value: &str) -> crate::Result<(&'static str, String)> {
        match key.to_lowercase().as_str() {
            "wal" => Ok(("journal_mode", validate_wal_pragma(value)?.to_owned())),
            "journal_mode" => Ok(("journal_mode", validate_journal_mode_pragma(value)?)),
            "synchronous" => Ok(("synchronous", validate_synchronous_pragma(value)?)),
            "busy_timeout" => Ok((
                "busy_timeout",
                validate_busy_timeout_pragma(value)?.to_string(),
            )),
            _ => Err(DbError::UnknownSqlitePragma(key.to_owned())),
        }
    }

    /// Maps the `wal` flag to a journal mode, case-insensitively.
    ///
    /// `true`/`1` gives `WAL` and `false`/`0` gives `DELETE`.
    fn validate_wal_pragma(value: &str) -> crate::Result<&'static str> {
        match value.to_lowercase().as_str() {
            "true" | "1" => Ok("WAL"),
            "false" | "0" => Ok("DELETE"),
            _ => Err(DbError::InvalidSqlitePragma {
                key: "wal".to_owned(),
                message: format!("must be true/false/1/0, got '{value}'"),
            }),
        }
    }

    /// Accepts `OFF`/`NORMAL`/`FULL`/`EXTRA` (case-insensitive) and returns the
    /// upper-cased mode.
    fn validate_synchronous_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(mode.as_str(), "OFF" | "NORMAL" | "FULL" | "EXTRA") {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "synchronous".to_owned(),
                message: format!("must be OFF/NORMAL/FULL/EXTRA, got '{value}'"),
            })
        }
    }

    /// Parses a non-negative busy timeout (an integer, in the units PRAGMA
    /// `busy_timeout` expects).
    fn validate_busy_timeout_pragma(value: &str) -> crate::Result<i64> {
        let timeout = value
            .parse::<i64>()
            .map_err(|_| DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be a non-negative integer, got '{value}'"),
            })?;

        if timeout < 0 {
            return Err(DbError::InvalidSqlitePragma {
                key: "busy_timeout".to_owned(),
                message: format!("must be non-negative, got '{timeout}'"),
            });
        }

        Ok(timeout)
    }

    /// Accepts `DELETE`/`WAL`/`MEMORY`/`TRUNCATE`/`PERSIST`/`OFF`
    /// (case-insensitive) and returns the upper-cased mode.
    fn validate_journal_mode_pragma(value: &str) -> crate::Result<String> {
        let mode = value.to_uppercase();
        if matches!(
            mode.as_str(),
            "DELETE" | "WAL" | "MEMORY" | "TRUNCATE" | "PERSIST" | "OFF"
        ) {
            Ok(mode)
        } else {
            Err(DbError::InvalidSqlitePragma {
                key: "journal_mode".to_owned(),
                message: format!("must be DELETE/WAL/MEMORY/TRUNCATE/PERSIST/OFF, got '{value}'"),
            })
        }
    }
}
257
258pub async fn build_db_handle(
264 mut cfg: DbConnConfig,
265 _global: Option<&GlobalDatabaseConfig>,
266) -> Result<DbHandle> {
267 if let Some(dsn) = &cfg.dsn {
269 cfg.dsn = Some(expand_env_vars(dsn)?);
270 }
271 if let Some(password) = &cfg.password {
272 cfg.password = Some(resolve_password(password)?);
273 }
274
275 if let Some(ref mut params) = cfg.params {
277 for (_, value) in params.iter_mut() {
278 if value.contains("${") {
279 *value = expand_env_vars(value)?;
280 }
281 }
282 }
283
284 validate_config_consistency(&cfg)?;
286
287 let is_sqlite = cfg.file.is_some()
289 || cfg.path.is_some()
290 || cfg
291 .dsn
292 .as_ref()
293 .is_some_and(|dsn| dsn.starts_with("sqlite"))
294 || (cfg.server.is_none() && cfg.dsn.is_none());
295
296 let connect_options = if is_sqlite {
297 build_sqlite_options(&cfg)?
298 } else {
299 build_server_options(&cfg)?
300 };
301
302 let pool_cfg = cfg.pool.unwrap_or_default();
304
305 let log_dsn = redact_credentials_in_dsn(cfg.dsn.as_deref());
307 tracing::debug!(
308 dsn = log_dsn,
309 is_sqlite = is_sqlite,
310 "Building database connection"
311 );
312
313 let handle = connect_options.connect(pool_cfg).await?;
315
316 Ok(handle)
317}
318
/// Builds SQLite connect options from `cfg`.
///
/// The database location comes from the DSN, or else from `path`. A bare
/// `file` is rejected because it is expected to have been resolved into a
/// `path` before this point.
///
/// For an on-disk database, the parent directory is created if needed and the
/// file is created if missing.
///
/// The special location `:memory:` (e.g. DSN `sqlite::memory:`) is handled
/// differently. Opening it as a plain filename would give every pooled
/// connection its own private, empty database. Instead it goes through sqlx's
/// `FromStr` for `SqliteConnectOptions`, which maps `sqlite::memory:` to a
/// named shared-cache in-memory database that the pool's connections share.
/// NOTE(review): that sqlx behavior is documented upstream, not in this file.
///
/// Any `params` are applied as validated PRAGMAs by
/// `sqlite_pragma::apply_pragmas`.
///
/// # Errors
///
/// * [`DbError::InvalidParameter`] for a missing or unresolved location, or a
///   malformed DSN.
/// * An I/O error if the parent directory cannot be created.
/// * Any error returned by `sqlite_pragma::apply_pragmas`.
#[cfg(feature = "sqlite")]
fn build_sqlite_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
    let db_path = if let Some(dsn) = &cfg.dsn {
        parse_sqlite_path_from_dsn(dsn)?
    } else if let Some(path) = &cfg.path {
        path.clone()
    } else if cfg.file.is_some() {
        // `file` is expected to be turned into an absolute `path` upstream.
        return Err(DbError::InvalidParameter(
            "File path should have been resolved to absolute path".to_owned(),
        ));
    } else {
        return Err(DbError::InvalidParameter(
            "SQLite connection requires either DSN, path, or file".to_owned(),
        ));
    };

    let mut opts = if db_path.as_os_str() == ":memory:" {
        "sqlite::memory:"
            .parse::<sea_orm::sqlx::sqlite::SqliteConnectOptions>()
            .map_err(|e| DbError::InvalidParameter(e.to_string()))?
    } else {
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        sea_orm::sqlx::sqlite::SqliteConnectOptions::new()
            .filename(&db_path)
            .create_if_missing(true)
    };

    if let Some(params) = &cfg.params {
        opts = sqlite_pragma::apply_pragmas(opts, params)?;
    }

    Ok(DbConnectOptions::Sqlite(opts))
}
353
354#[cfg(not(feature = "sqlite"))]
355fn build_sqlite_options(_: &DbConnConfig) -> Result<DbConnectOptions> {
356 Err(DbError::FeatureDisabled("SQLite feature not enabled"))
357}
358
359fn build_server_options(cfg: &DbConnConfig) -> Result<DbConnectOptions> {
361 let scheme = if let Some(dsn) = &cfg.dsn {
363 let parsed = url::Url::parse(dsn)?;
364 parsed.scheme().to_owned()
365 } else {
366 "postgresql".to_owned()
367 };
368
369 match scheme.as_str() {
370 "postgresql" | "postgres" => {
371 #[cfg(feature = "pg")]
372 {
373 let mut opts = if let Some(dsn) = &cfg.dsn {
374 dsn.parse::<sea_orm::sqlx::postgres::PgConnectOptions>()
375 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
376 } else {
377 sea_orm::sqlx::postgres::PgConnectOptions::new()
378 };
379
380 if let Some(host) = &cfg.host {
382 opts = opts.host(host);
383 }
384 if let Some(port) = cfg.port {
385 opts = opts.port(port);
386 }
387 if let Some(user) = &cfg.user {
388 opts = opts.username(user);
389 }
390 if let Some(password) = &cfg.password {
391 opts = opts.password(password);
392 }
393 if let Some(dbname) = &cfg.dbname {
394 opts = opts.database(dbname);
395 } else if cfg.dsn.is_none() {
396 return Err(DbError::InvalidParameter(
397 "dbname is required for PostgreSQL connections".to_owned(),
398 ));
399 }
400
401 if let Some(params) = &cfg.params {
403 for (key, value) in params {
404 opts = opts.options([(key.as_str(), value.as_str())]);
405 }
406 }
407
408 Ok(DbConnectOptions::Postgres(opts))
409 }
410 #[cfg(not(feature = "pg"))]
411 {
412 Err(DbError::FeatureDisabled("PostgreSQL feature not enabled"))
413 }
414 }
415 "mysql" => {
416 #[cfg(feature = "mysql")]
417 {
418 let mut opts = if let Some(dsn) = &cfg.dsn {
419 dsn.parse::<sea_orm::sqlx::mysql::MySqlConnectOptions>()
420 .map_err(|e| DbError::InvalidParameter(e.to_string()))?
421 } else {
422 sea_orm::sqlx::mysql::MySqlConnectOptions::new()
423 };
424
425 if let Some(host) = &cfg.host {
427 opts = opts.host(host);
428 }
429 if let Some(port) = cfg.port {
430 opts = opts.port(port);
431 }
432 if let Some(user) = &cfg.user {
433 opts = opts.username(user);
434 }
435 if let Some(password) = &cfg.password {
436 opts = opts.password(password);
437 }
438 if let Some(dbname) = &cfg.dbname {
439 opts = opts.database(dbname);
440 } else if cfg.dsn.is_none() {
441 return Err(DbError::InvalidParameter(
442 "dbname is required for MySQL connections".to_owned(),
443 ));
444 }
445
446 Ok(DbConnectOptions::MySql(opts))
447 }
448 #[cfg(not(feature = "mysql"))]
449 {
450 Err(DbError::FeatureDisabled("MySQL feature not enabled"))
451 }
452 }
453 _ => Err(DbError::InvalidParameter(format!(
454 "Unsupported database scheme: {scheme}"
455 ))),
456 }
457}
458
/// Extracts the filesystem path from a SQLite DSN.
///
/// Accepts both `sqlite:<path>` and `sqlite://<path>`. Everything from the
/// first `?` onward (the query string) is discarded, and no percent-decoding
/// is performed.
///
/// # Errors
///
/// Returns [`DbError::InvalidParameter`] if `dsn` does not start with `sqlite:`.
#[cfg(feature = "sqlite")]
fn parse_sqlite_path_from_dsn(dsn: &str) -> Result<std::path::PathBuf> {
    let Some(after_scheme) = dsn.strip_prefix("sqlite:") else {
        return Err(DbError::InvalidParameter(format!(
            "Invalid SQLite DSN: {dsn}"
        )));
    };

    let location = after_scheme.strip_prefix("//").unwrap_or(after_scheme);
    let path = match location.split_once('?') {
        Some((before_query, _query)) => before_query,
        None => location,
    };

    Ok(std::path::PathBuf::from(path))
}
488
489fn expand_env_vars(input: &str) -> Result<String> {
491 let re = regex::Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
492 .map_err(|e| DbError::InvalidParameter(e.to_string()))?;
493 let mut result = input.to_owned();
494
495 for caps in re.captures_iter(input) {
496 let full_match = &caps[0];
497 let var_name = &caps[1];
498 let value = std::env::var(var_name)?;
499 result = result.replace(full_match, &value);
500 }
501
502 Ok(result)
503}
504
505fn resolve_password(password: &str) -> Result<String> {
507 if password.starts_with("${") && password.ends_with('}') {
508 let var_name = &password[2..password.len() - 1];
509 Ok(std::env::var(var_name)?)
510 } else {
511 Ok(password.to_owned())
512 }
513}
514
515fn validate_config_consistency(cfg: &DbConnConfig) -> Result<()> {
517 if let Some(dsn) = &cfg.dsn {
519 let is_sqlite_dsn = dsn.starts_with("sqlite");
520 let has_sqlite_fields = cfg.file.is_some() || cfg.path.is_some();
521 let has_server_fields = cfg.host.is_some() || cfg.port.is_some();
522
523 if is_sqlite_dsn && has_server_fields {
524 return Err(DbError::ConfigConflict(
525 "SQLite DSN cannot be used with host/port fields".to_owned(),
526 ));
527 }
528
529 if !is_sqlite_dsn && has_sqlite_fields {
530 return Err(DbError::ConfigConflict(
531 "Non-SQLite DSN cannot be used with file/path fields".to_owned(),
532 ));
533 }
534
535 if !is_sqlite_dsn
537 && cfg.server.is_some()
538 && (cfg.host.is_some()
539 || cfg.port.is_some()
540 || cfg.user.is_some()
541 || cfg.password.is_some()
542 || cfg.dbname.is_some())
543 {
544 }
547 }
548
549 if cfg.file.is_some() && cfg.path.is_some() {
551 return Err(DbError::ConfigConflict(
552 "Cannot specify both 'file' and 'path' for SQLite - use one or the other".to_owned(),
553 ));
554 }
555
556 if (cfg.file.is_some() || cfg.path.is_some()) && (cfg.host.is_some() || cfg.port.is_some()) {
557 return Err(DbError::ConfigConflict(
558 "SQLite file/path fields cannot be used with host/port fields".to_owned(),
559 ));
560 }
561
562 Ok(())
563}
564
565#[must_use]
567pub fn redact_credentials_in_dsn(dsn: Option<&str>) -> String {
568 match dsn {
569 Some(dsn) if dsn.contains('@') => {
570 if let Ok(mut parsed) = url::Url::parse(dsn) {
571 if parsed.password().is_some() {
572 let _ = parsed.set_password(Some("***"));
573 }
574 parsed.to_string()
575 } else {
576 "***".to_owned()
577 }
578 }
579 Some(dsn) => dsn.to_owned(),
580 None => "none".to_owned(),
581 }
582}