1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2#![cfg_attr(
58 not(any(feature = "pg", feature = "mysql", feature = "sqlite")),
59 allow(
60 unused_imports,
61 unused_variables,
62 dead_code,
63 unreachable_code,
64 unused_lifetimes,
65 clippy::unused_async,
66 )
67)]
68
69pub use advisory_locks::{DbLockGuard, LockConfig};
71
72pub use sea_orm::ConnectionTrait as DbConnTrait;
73
74pub mod advisory_locks;
76pub mod config;
77pub mod manager;
78pub mod odata;
79pub mod options;
80pub mod secure;
81
82mod pool_opts;
84#[cfg(feature = "sqlite")]
85mod sqlite;
86
87pub use config::{DbConnConfig, GlobalDatabaseConfig, PoolCfg};
89pub use manager::DbManager;
90pub use options::{
91 ConnectionOptionsError, DbConnectOptions, build_db_handle, redact_credentials_in_dsn,
92};
93
94use std::time::Duration;
95
96#[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
98use pool_opts::ApplyPoolOpts;
99#[cfg(feature = "sqlite")]
100use sqlite::{Pragmas, extract_sqlite_pragmas, is_memory_dsn, prepare_sqlite_path};
101
102#[cfg(feature = "mysql")]
105use sea_orm::sqlx::{MySql, MySqlPool, mysql::MySqlPoolOptions};
106#[cfg(feature = "pg")]
107use sea_orm::sqlx::{PgPool, Postgres, postgres::PgPoolOptions};
108#[cfg(feature = "sqlite")]
109use sea_orm::sqlx::{Sqlite, SqlitePool, sqlite::SqlitePoolOptions};
110
111use sea_orm::DatabaseConnection;
112#[cfg(feature = "mysql")]
113use sea_orm::SqlxMySqlConnector;
114#[cfg(feature = "pg")]
115use sea_orm::SqlxPostgresConnector;
116#[cfg(feature = "sqlite")]
117use sea_orm::SqlxSqliteConnector;
118
119use thiserror::Error;
120
/// Crate-wide result alias: `std::result::Result` with the error type fixed to [`DbError`].
pub type Result<T> = std::result::Result<T, DbError>;
123
/// Errors produced by this crate's database layer.
///
/// Variants carrying `#[from]` are created automatically by `?` from the wrapped error type;
/// `#[error(transparent)]` variants defer their message and source to the wrapped error.
#[derive(Debug, Error)]
pub enum DbError {
    /// The DSN scheme was not recognized by `DbHandle::detect`; carries the DSN as supplied.
    #[error("Unknown DSN: {0}")]
    UnknownDsn(String),

    /// A backend was requested whose Cargo feature is not compiled in. The `with_*_tx`
    /// helpers also return this when the handle's pool belongs to a different engine.
    #[error("Feature not enabled: {0}")]
    FeatureDisabled(&'static str),

    /// A configuration value was rejected (constructed outside the code visible in this file).
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Configuration settings contradict each other (constructed outside this file's visible code).
    #[error("Configuration conflict: {0}")]
    ConfigConflict(String),

    /// A SQLite PRAGMA parameter named `key` had an unacceptable value; `message` explains why.
    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
    InvalidSqlitePragma { key: String, message: String },

    /// A SQLite PRAGMA parameter name was not recognized.
    #[error("Unknown SQLite PRAGMA parameter: {0}")]
    UnknownSqlitePragma(String),

    /// A connection parameter was rejected.
    #[error("Invalid connection parameter: {0}")]
    InvalidParameter(String),

    /// A SQLite pragma-related failure described by the contained message.
    #[error("SQLite pragma error: {0}")]
    SqlitePragma(String),

    /// Reading an environment variable failed.
    #[error("Environment variable error: {0}")]
    EnvVar(#[from] std::env::VarError),

    /// Parsing a URL failed.
    #[error("URL parsing error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// An error reported by sqlx; only present when at least one backend feature is enabled.
    #[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
    #[error(transparent)]
    Sqlx(#[from] sea_orm::sqlx::Error),

    /// An error reported by SeaORM.
    #[error(transparent)]
    Sea(#[from] sea_orm::DbErr),

    /// An I/O error from the standard library.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// An error from the advisory-lock subsystem.
    #[error(transparent)]
    Lock(#[from] advisory_locks::DbLockError),

    /// An error raised while building connection options.
    #[error(transparent)]
    ConnectionOptions(#[from] ConnectionOptionsError),

    /// Any other error, carried as an `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
178
/// Database engine behind a handle, as selected from the DSN scheme by `DbHandle::detect`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DbEngine {
    /// Chosen for `postgres://` and `postgresql://` DSNs.
    Postgres,
    /// Chosen for `mysql://` DSNs.
    MySql,
    /// Chosen for DSNs starting with `sqlite:`.
    Sqlite,
}
186
/// Pool and connection settings accepted by `DbHandle::connect`.
///
/// `connect` passes the struct to `ApplyPoolOpts::apply` to configure the sqlx pool builder;
/// how each pool field maps onto sqlx is decided there and is not visible in this file.
// NOTE(review): the field descriptions below follow the field names — confirm against `pool_opts`.
#[derive(Clone, Debug)]
pub struct ConnectOpts {
    /// Presumably the maximum number of pooled connections (`Default`: `Some(10)`).
    pub max_conns: Option<u32>,
    /// Presumably the minimum number of connections kept open (`Default`: `None`).
    pub min_conns: Option<u32>,
    /// Presumably how long to wait for a connection from the pool (`Default`: 30 seconds).
    pub acquire_timeout: Option<Duration>,
    /// Presumably how long a connection may sit idle before being closed (`Default`: `None`).
    pub idle_timeout: Option<Duration>,
    /// Presumably the maximum lifetime of a single connection (`Default`: `None`).
    pub max_lifetime: Option<Duration>,
    /// Presumably whether a connection is checked before being handed out (`Default`: `false`).
    pub test_before_acquire: bool,
    /// Forwarded to `prepare_sqlite_path` when connecting to SQLite (`Default`: `true`).
    pub create_sqlite_dirs: bool,
}
206impl Default for ConnectOpts {
207 fn default() -> Self {
208 Self {
209 max_conns: Some(10),
210 min_conns: None,
211 acquire_timeout: Some(Duration::from_secs(30)),
212 idle_timeout: None,
213 max_lifetime: None,
214 test_before_acquire: false,
215
216 create_sqlite_dirs: true,
217 }
218 }
219}
220
/// Backend-specific sqlx connection pool.
///
/// Each variant exists only when its Cargo feature is enabled, so with no backend feature
/// this enum has no variants.
#[derive(Clone, Debug)]
pub enum DbPool {
    /// sqlx PostgreSQL pool.
    #[cfg(feature = "pg")]
    Postgres(PgPool),
    /// sqlx MySQL pool.
    #[cfg(feature = "mysql")]
    MySql(MySqlPool),
    /// sqlx SQLite pool.
    #[cfg(feature = "sqlite")]
    Sqlite(SqlitePool),
}
231
/// Backend-specific sqlx transaction, as returned by `DbHandle::begin`.
///
/// Finish it with `commit` or `rollback`; both consume the value.
pub enum DbTransaction<'a> {
    /// Transaction on a PostgreSQL connection.
    #[cfg(feature = "pg")]
    Postgres(sea_orm::sqlx::Transaction<'a, Postgres>),
    /// Transaction on a MySQL connection.
    #[cfg(feature = "mysql")]
    MySql(sea_orm::sqlx::Transaction<'a, MySql>),
    /// Transaction on a SQLite connection.
    #[cfg(feature = "sqlite")]
    Sqlite(sea_orm::sqlx::Transaction<'a, Sqlite>),
    /// Placeholder compiled only when no backend feature is enabled, so that the `'a`
    /// lifetime parameter is still used by the type.
    #[cfg(not(any(feature = "pg", feature = "mysql", feature = "sqlite")))]
    _Phantom(std::marker::PhantomData<&'a ()>),
}
245
246impl DbTransaction<'_> {
247 pub async fn commit(self) -> Result<()> {
252 match self {
253 #[cfg(feature = "pg")]
254 DbTransaction::Postgres(tx) => tx.commit().await.map_err(Into::into),
255 #[cfg(feature = "mysql")]
256 DbTransaction::MySql(tx) => tx.commit().await.map_err(Into::into),
257 #[cfg(feature = "sqlite")]
258 DbTransaction::Sqlite(tx) => tx.commit().await.map_err(Into::into),
259 #[cfg(not(any(feature = "pg", feature = "mysql", feature = "sqlite")))]
260 DbTransaction::_Phantom(_) => Ok(()),
261 }
262 }
263
264 pub async fn rollback(self) -> Result<()> {
269 match self {
270 #[cfg(feature = "pg")]
271 DbTransaction::Postgres(tx) => tx.rollback().await.map_err(Into::into),
272 #[cfg(feature = "mysql")]
273 DbTransaction::MySql(tx) => tx.rollback().await.map_err(Into::into),
274 #[cfg(feature = "sqlite")]
275 DbTransaction::Sqlite(tx) => tx.rollback().await.map_err(Into::into),
276 #[cfg(not(any(feature = "pg", feature = "mysql", feature = "sqlite")))]
277 DbTransaction::_Phantom(_) => Ok(()),
278 }
279 }
280}
281
/// Handle to one connected database: the detected engine, the sqlx pool, the DSN that was
/// connected to, and a SeaORM connection built from the same pool.
#[derive(Debug, Clone)]
pub struct DbHandle {
    /// Engine detected from the DSN at connect time.
    engine: DbEngine,
    /// Backend-specific sqlx pool.
    pool: DbPool,
    /// DSN used to connect; for SQLite this is the DSN after pragma parameters were extracted.
    dsn: String,
    /// SeaORM connection created from a clone of `pool`.
    sea: DatabaseConnection,
}
290
/// Fallback for `Pragmas::busy_timeout_ms` (so, milliseconds) applied to non-memory SQLite
/// connections when the DSN does not provide a busy timeout.
#[cfg(feature = "sqlite")]
const DEFAULT_SQLITE_BUSY_TIMEOUT: i32 = 5000;
293
294impl DbHandle {
295 pub fn detect(dsn: &str) -> Result<DbEngine> {
302 let s = dsn.trim_start();
304
305 if s.starts_with("postgres://") || s.starts_with("postgresql://") {
308 Ok(DbEngine::Postgres)
309 } else if s.starts_with("mysql://") {
310 Ok(DbEngine::MySql)
311 } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
312 Ok(DbEngine::Sqlite)
313 } else {
314 Err(DbError::UnknownDsn(dsn.to_owned()))
315 }
316 }
317
318 pub async fn connect(dsn: &str, opts: ConnectOpts) -> Result<Self> {
323 let engine = Self::detect(dsn)?;
324 match engine {
325 #[cfg(feature = "pg")]
326 DbEngine::Postgres => {
327 let o = PgPoolOptions::new().apply(&opts);
328 let pool = o.connect(dsn).await?;
329 let sea = SqlxPostgresConnector::from_sqlx_postgres_pool(pool.clone());
330 Ok(Self {
331 engine,
332 pool: DbPool::Postgres(pool),
333 dsn: dsn.to_owned(),
334 sea,
335 })
336 }
337 #[cfg(not(feature = "pg"))]
338 DbEngine::Postgres => Err(DbError::FeatureDisabled("PostgreSQL feature not enabled")),
339 #[cfg(feature = "mysql")]
340 DbEngine::MySql => {
341 let o = MySqlPoolOptions::new().apply(&opts);
342 let pool = o.connect(dsn).await?;
343 let sea = SqlxMySqlConnector::from_sqlx_mysql_pool(pool.clone());
344 Ok(Self {
345 engine,
346 pool: DbPool::MySql(pool),
347 dsn: dsn.to_owned(),
348 sea,
349 })
350 }
351 #[cfg(not(feature = "mysql"))]
352 DbEngine::MySql => Err(DbError::FeatureDisabled("MySQL feature not enabled")),
353 #[cfg(feature = "sqlite")]
354 DbEngine::Sqlite => {
355 let dsn = prepare_sqlite_path(dsn, opts.create_sqlite_dirs)?;
356
357 let (clean_dsn, pairs) = extract_sqlite_pragmas(&dsn);
359 let pragmas = Pragmas::from_pairs(&pairs);
360
361 let mut o = SqlitePoolOptions::new().apply(&opts);
363
364 let is_memory = is_memory_dsn(&clean_dsn);
366 o = o.after_connect(move |conn, _meta| {
367 let pragmas = pragmas.clone();
368 Box::pin(async move {
369 let journal_mode = if let Some(mode) = &pragmas.journal_mode {
371 mode.as_sql()
372 } else if let Some(wal_toggle) = pragmas.wal_toggle {
373 if wal_toggle { "WAL" } else { "DELETE" }
374 } else if is_memory {
375 "DELETE"
377 } else {
378 "WAL"
379 };
380
381 let stmt = format!("PRAGMA journal_mode = {journal_mode}");
382 sea_orm::sqlx::query(&stmt).execute(&mut *conn).await?;
383
384 let sync_mode = pragmas
386 .synchronous
387 .as_ref()
388 .map_or("NORMAL", |s| s.as_sql());
389 let stmt = format!("PRAGMA synchronous = {sync_mode}");
390 sea_orm::sqlx::query(&stmt).execute(&mut *conn).await?;
391
392 if !is_memory {
394 let timeout = pragmas
395 .busy_timeout_ms
396 .unwrap_or(DEFAULT_SQLITE_BUSY_TIMEOUT.into());
397 sea_orm::sqlx::query("PRAGMA busy_timeout = ?")
398 .bind(timeout)
399 .execute(&mut *conn)
400 .await?;
401 }
402
403 Ok(())
404 })
405 });
406
407 let pool = o.connect(&clean_dsn).await?;
408 let sea = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool.clone());
409
410 Ok(Self {
411 engine,
412 pool: DbPool::Sqlite(pool),
413 dsn: clean_dsn,
414 sea,
415 })
416 }
417 #[cfg(not(feature = "sqlite"))]
418 DbEngine::Sqlite => Err(DbError::FeatureDisabled("SQLite feature not enabled")),
419 }
420 }
421
422 pub async fn close(self) {
424 match self.pool {
425 #[cfg(feature = "pg")]
426 DbPool::Postgres(p) => p.close().await,
427 #[cfg(feature = "mysql")]
428 DbPool::MySql(p) => p.close().await,
429 #[cfg(feature = "sqlite")]
430 DbPool::Sqlite(p) => p.close().await,
431 }
432 }
433
434 #[must_use]
436 pub fn engine(&self) -> DbEngine {
437 self.engine
438 }
439
440 #[must_use]
442 pub fn dsn(&self) -> &str {
443 &self.dsn
444 }
445
446 #[cfg(feature = "pg")]
448 #[must_use]
449 pub fn sqlx_postgres(&self) -> Option<&PgPool> {
450 match self.pool {
451 DbPool::Postgres(ref p) => Some(p),
452 #[cfg(any(feature = "mysql", feature = "sqlite"))]
453 _ => None,
454 }
455 }
456 #[cfg(feature = "mysql")]
457 #[must_use]
458 pub fn sqlx_mysql(&self) -> Option<&MySqlPool> {
459 match self.pool {
460 DbPool::MySql(ref p) => Some(p),
461 #[cfg(any(feature = "pg", feature = "sqlite"))]
462 _ => None,
463 }
464 }
465 #[cfg(feature = "sqlite")]
466 #[must_use]
467 pub fn sqlx_sqlite(&self) -> Option<&SqlitePool> {
468 match self.pool {
469 DbPool::Sqlite(ref p) => Some(p),
470 #[cfg(any(feature = "pg", feature = "mysql"))]
471 _ => None,
472 }
473 }
474
475 #[must_use]
498 pub fn sea_secure(&self) -> crate::secure::SecureConn {
499 crate::secure::SecureConn::new(self.sea.clone())
500 }
501
502 #[cfg(feature = "insecure-escape")]
528 pub fn sea(&self) -> DatabaseConnection {
529 tracing::warn!(
530 target: "security",
531 "DbHandle::sea() called - bypassing secure ORM layer"
532 );
533 self.sea.clone()
534 }
535
536 #[cfg(feature = "pg")]
543 pub async fn with_pg_tx<F, T>(&self, f: F) -> Result<T>
544 where
545 F: for<'a> FnOnce(
546 &'a mut sea_orm::sqlx::Transaction<'_, Postgres>,
547 ) -> std::pin::Pin<
548 Box<dyn std::future::Future<Output = Result<T>> + Send + 'a>,
549 >,
550 {
551 let pool = self
552 .sqlx_postgres()
553 .ok_or(DbError::FeatureDisabled("not a postgres pool"))?;
554 let mut tx = pool.begin().await?;
555 let res = f(&mut tx).await;
556 match res {
557 Ok(v) => {
558 tx.commit().await?;
559 Ok(v)
560 }
561 Err(e) => {
562 let _ = tx.rollback().await;
564 Err(e)
565 }
566 }
567 }
568
569 #[cfg(feature = "mysql")]
574 pub async fn with_mysql_tx<F, T>(&self, f: F) -> Result<T>
575 where
576 F: for<'a> FnOnce(
577 &'a mut sea_orm::sqlx::Transaction<'_, MySql>,
578 ) -> std::pin::Pin<
579 Box<dyn std::future::Future<Output = Result<T>> + Send + 'a>,
580 >,
581 {
582 let pool = self
583 .sqlx_mysql()
584 .ok_or(DbError::FeatureDisabled("not a mysql pool"))?;
585 let mut tx = pool.begin().await?;
586 let res = f(&mut tx).await;
587 match res {
588 Ok(v) => {
589 tx.commit().await?;
590 Ok(v)
591 }
592 Err(e) => {
593 let _ = tx.rollback().await;
594 Err(e)
595 }
596 }
597 }
598
599 #[cfg(feature = "sqlite")]
604 pub async fn with_sqlite_tx<F, T>(&self, f: F) -> Result<T>
605 where
606 F: for<'a> FnOnce(
607 &'a mut sea_orm::sqlx::Transaction<'_, Sqlite>,
608 ) -> std::pin::Pin<
609 Box<dyn std::future::Future<Output = Result<T>> + Send + 'a>,
610 >,
611 {
612 let pool = self
613 .sqlx_sqlite()
614 .ok_or(DbError::FeatureDisabled("not a sqlite pool"))?;
615 let mut tx = pool.begin().await?;
616 let res = f(&mut tx).await;
617 match res {
618 Ok(v) => {
619 tx.commit().await?;
620 Ok(v)
621 }
622 Err(e) => {
623 let _ = tx.rollback().await;
624 Err(e)
625 }
626 }
627 }
628
629 pub async fn lock(&self, module: &str, key: &str) -> Result<DbLockGuard> {
636 let lock_manager =
637 advisory_locks::LockManager::new(self.engine, self.pool.clone(), self.dsn.clone());
638 let guard = lock_manager.lock(module, key).await?;
639 Ok(guard)
640 }
641
642 pub async fn try_lock(
647 &self,
648 module: &str,
649 key: &str,
650 config: LockConfig,
651 ) -> Result<Option<DbLockGuard>> {
652 let lock_manager =
653 advisory_locks::LockManager::new(self.engine, self.pool.clone(), self.dsn.clone());
654 let res = lock_manager.try_lock(module, key, config).await?;
655 Ok(res)
656 }
657
658 pub async fn begin(&self) -> Result<DbTransaction<'_>> {
665 match &self.pool {
666 #[cfg(feature = "pg")]
667 DbPool::Postgres(pool) => {
668 let tx = pool.begin().await?;
669 Ok(DbTransaction::Postgres(tx))
670 }
671 #[cfg(feature = "mysql")]
672 DbPool::MySql(pool) => {
673 let tx = pool.begin().await?;
674 Ok(DbTransaction::MySql(tx))
675 }
676 #[cfg(feature = "sqlite")]
677 DbPool::Sqlite(pool) => {
678 let tx = pool.begin().await?;
679 Ok(DbTransaction::Sqlite(tx))
680 }
681 #[cfg(not(any(feature = "pg", feature = "mysql", feature = "sqlite")))]
682 _ => Err(DbError::FeatureDisabled("no database backends enabled")),
683 }
684 }
685}
686
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    #[cfg(feature = "sqlite")]
    use tokio::time::Duration;

    /// Builds a lock-key prefix that differs between runs by appending the current
    /// UNIX time in nanoseconds (0 if the clock is before the epoch).
    #[cfg(feature = "sqlite")]
    fn unique_test_id(prefix: &str) -> String {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        format!("{prefix}_{now}")
    }

    /// Connecting to an in-memory SQLite DSN yields a SQLite handle.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection() -> Result<()> {
        let dsn = "sqlite::memory:";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);
        Ok(())
    }

    /// Pragma parameters in the DSN are accepted and the resulting pool is usable.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection_with_pragma_parameters() -> Result<()> {
        let dsn = "sqlite::memory:?wal=true&synchronous=NORMAL&busy_timeout=5000&journal_mode=WAL";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);

        // The stored DSN must still point at the in-memory database.
        assert!(db.dsn().starts_with("sqlite::memory:"));

        let pool = db.sqlx_sqlite().unwrap();
        sea_orm::sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(pool)
            .await?;
        sea_orm::sqlx::query("INSERT INTO test (name) VALUES (?)")
            .bind("test_value")
            .execute(pool)
            .await?;

        let row: (i64, String) = sea_orm::sqlx::query_as("SELECT id, name FROM test WHERE id = 1")
            .fetch_one(pool)
            .await?;

        assert_eq!(row.0, 1);
        assert_eq!(row.1, "test_value");

        Ok(())
    }

    /// `detect` maps each supported scheme to its engine and rejects everything else.
    #[test]
    fn test_backend_detection() {
        assert_eq!(
            DbHandle::detect("sqlite::memory:").unwrap(),
            DbEngine::Sqlite
        );
        assert_eq!(
            DbHandle::detect("sqlite://some/path.db").unwrap(),
            DbEngine::Sqlite
        );
        assert_eq!(
            DbHandle::detect("postgres://localhost/test").unwrap(),
            DbEngine::Postgres
        );
        assert_eq!(
            DbHandle::detect("postgresql://localhost/test").unwrap(),
            DbEngine::Postgres
        );
        assert_eq!(
            DbHandle::detect("mysql://localhost/test").unwrap(),
            DbEngine::MySql
        );

        // Leading whitespace is ignored for detection.
        assert_eq!(
            DbHandle::detect("  \n postgres://localhost/test").unwrap(),
            DbEngine::Postgres
        );

        // Scheme matching is case-sensitive.
        assert!(DbHandle::detect("POSTGRES://localhost/test").is_err());

        assert!(matches!(
            DbHandle::detect("unknown://test"),
            Err(DbError::UnknownDsn(_))
        ));
    }

    /// Distinct (module, key) pairs can be locked together, and a released key can be re-locked.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_sqlite() -> Result<()> {
        let dsn = "sqlite:file:memdb1?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let test_id = unique_test_id("test_basic");

        let guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db
            .lock("different_module", &format!("{test_id}_key1"))
            .await?;

        guard1.release().await;
        let _guard4 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    /// Locks on different keys or different modules do not block each other.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_different_keys() -> Result<()> {
        let dsn = "sqlite:file:memdb_diff_keys?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let test_id = unique_test_id("test_diff");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db.lock("other_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    /// `try_lock` with an explicit config succeeds for a key nobody else holds.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_try_lock_with_config() -> Result<()> {
        let dsn = "sqlite:file:memdb2?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let test_id = unique_test_id("test_config");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key")).await?;

        let config = LockConfig {
            max_wait: Some(Duration::from_millis(200)),
            initial_backoff: Duration::from_millis(50),
            max_attempts: Some(3),
            ..Default::default()
        };

        let result = db
            .try_lock("test_module", &format!("{test_id}_different_key"), config)
            .await?;
        assert!(
            result.is_some(),
            "expected lock acquisition for different key"
        );
        Ok(())
    }

    /// A transaction can be started and committed on an in-memory database.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_transaction() -> Result<()> {
        let dsn = "sqlite::memory:";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;
        let tx = db.begin().await?;
        tx.commit().await?;
        Ok(())
    }

    /// The secure connection wrapper can be obtained from a connected handle.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_secure_conn() -> Result<()> {
        let dsn = "sqlite::memory:";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let _secure_conn = db.sea_secure();
        Ok(())
    }

    /// The raw connection escape hatch is available when `insecure-escape` is enabled.
    #[cfg(all(feature = "sqlite", feature = "insecure-escape"))]
    #[tokio::test]
    async fn test_insecure_sea_access() -> Result<()> {
        let dsn = "sqlite::memory:";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let _raw = db.sea();
        Ok(())
    }
}