Skip to main content

modkit_db/
lib.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2//! `ModKit` Database abstraction crate.
3//!
4//! This crate provides a unified interface for working with different databases
5//! (`SQLite`, `PostgreSQL`, `MySQL`) through `SQLx`, with optional `SeaORM` integration.
6//! It emphasizes typed connection options over DSN string manipulation and
7//! implements strict security controls (e.g., `SQLite` PRAGMA whitelist).
8//!
9//! # Features
10//! - `pg`, `mysql`, `sqlite`: enable `SQLx` backends
11//! - `sea-orm`: add `SeaORM` integration for type-safe operations
12//! - `preview-outbox`: enable the transactional outbox pipeline (experimental — API may change)
13//!
//! # Architecture
//! The crate provides:
16//! - Typed `DbConnectOptions` using sqlx `ConnectOptions` (no DSN string building)
17//! - Per-module database factories with configuration merging
18//! - `SQLite` PRAGMA whitelist for security
19//! - Environment variable expansion in passwords and DSNs
20//!
21//! # Example (`DbManager` API)
22//! ```rust,no_run
23//! use modkit_db::{DbManager, GlobalDatabaseConfig, DbConnConfig};
24//! use figment::{Figment, providers::Serialized};
25//! use std::path::PathBuf;
26//! use std::sync::Arc;
27//!
28//! // Create configuration using Figment
29//! let figment = Figment::new()
30//!     .merge(Serialized::defaults(serde_json::json!({
31//!         "db": {
32//!             "servers": {
33//!                 "main": {
34//!                     "host": "localhost",
35//!                     "port": 5432,
36//!                     "user": "app",
37//!                     "password": "${DB_PASSWORD}",
38//!                     "dbname": "app_db"
39//!                 }
40//!             }
41//!         },
42//!         "test_module": {
43//!             "database": {
44//!                 "server": "main",
45//!                 "dbname": "module_db"
46//!             }
47//!         }
48//!     })));
49//!
50//! // Create DbManager
51//! let home_dir = PathBuf::from("/app/data");
52//! let db_manager = Arc::new(DbManager::from_figment(figment, home_dir).unwrap());
53//!
54//! // Use in runtime with DbOptions::Manager(db_manager)
55//! // Modules can then use: ctx.db_required_async().await?
56//! ```
57
58#![cfg_attr(
59    not(any(feature = "pg", feature = "mysql", feature = "sqlite")),
60    allow(
61        unused_imports,
62        unused_variables,
63        dead_code,
64        unreachable_code,
65        unused_lifetimes,
66        clippy::unused_async,
67    )
68)]
69
70// Re-export key types for public API
71pub use advisory_locks::{DbLockGuard, LockConfig};
72
73// Re-export sea_orm_migration for modules that implement DatabaseCapability
74pub use sea_orm_migration;
75
76// Core modules
77pub mod advisory_locks;
78pub mod config;
79pub mod manager;
80pub mod migration_runner;
81pub mod odata;
82pub mod options;
83
84#[cfg(feature = "preview-outbox")]
85pub mod outbox;
86pub mod secure;
87
88mod db_provider;
89
90// Internal modules
91mod pool_opts;
92#[cfg(feature = "sqlite")]
93mod sqlite;
94
95// Re-export important types from new modules
96pub use config::{DbConnConfig, GlobalDatabaseConfig, PoolCfg};
97pub use manager::DbManager;
98pub use options::redact_credentials_in_dsn;
99
100// Re-export secure database types for convenience
101pub use secure::{Db, DbConn, DbTx};
102
103// Re-export service-friendly provider
104pub use db_provider::DBProvider;
105
106/// Connect and return a secure `Db` (no `DbHandle` exposure).
107///
108/// This is the public constructor intended for module code and tests.
109///
110/// # Errors
111///
112/// Returns `DbError` if the connection fails or the DSN/options are invalid.
113pub async fn connect_db(dsn: &str, opts: ConnectOpts) -> Result<Db> {
114    let handle = DbHandle::connect(dsn, opts).await?;
115    Ok(Db::new(handle))
116}
117
118/// Build a secure `Db` from config (no `DbHandle` exposure).
119///
120/// # Errors
121///
122/// Returns `DbError` if configuration is invalid or connection fails.
123pub async fn build_db(cfg: DbConnConfig, global: Option<&GlobalDatabaseConfig>) -> Result<Db> {
124    let handle = options::build_db_handle(cfg, global).await?;
125    Ok(Db::new(handle))
126}
127
128use std::time::Duration;
129
130// Internal imports
131#[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
132use pool_opts::ApplyPoolOpts;
133#[cfg(feature = "sqlite")]
134use sqlite::{Pragmas, extract_sqlite_pragmas, is_memory_dsn, prepare_sqlite_path};
135
136// Used for parsing SQLite DSN query parameters
137
138#[cfg(feature = "mysql")]
139use sqlx::mysql::MySqlPoolOptions;
140#[cfg(feature = "pg")]
141use sqlx::postgres::PgPoolOptions;
142#[cfg(feature = "sqlite")]
143use sqlx::sqlite::SqlitePoolOptions;
144#[cfg(feature = "sqlite")]
145use std::str::FromStr;
146
147use sea_orm::DatabaseConnection;
148#[cfg(feature = "mysql")]
149use sea_orm::SqlxMySqlConnector;
150#[cfg(feature = "pg")]
151use sea_orm::SqlxPostgresConnector;
152#[cfg(feature = "sqlite")]
153use sea_orm::SqlxSqliteConnector;
154
155use thiserror::Error;
156
/// Library-local result type: shorthand for `std::result::Result<T, DbError>`.
pub type Result<T> = std::result::Result<T, DbError>;

/// Typed error for the DB handle and helpers.
///
/// Variants whose field is marked `#[from]` are produced automatically by `?`;
/// each variant's `Display` text comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum DbError {
    /// The DSN scheme was not recognized.
    ///
    /// NOTE(review): the payload is rendered verbatim by `Display`, so producers
    /// should avoid putting a full, credential-bearing DSN into it.
    #[error("Unknown DSN: {0}")]
    UnknownDsn(String),

    /// The DSN targets a backend whose cargo feature (`pg`, `mysql`, `sqlite`)
    /// was not compiled in.
    #[error("Feature not enabled: {0}")]
    FeatureDisabled(&'static str),

    /// Configuration was rejected as invalid; the payload describes the problem.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Configuration values conflict; the payload describes the conflict.
    #[error("Configuration conflict: {0}")]
    ConfigConflict(String),

    /// The `SQLite` PRAGMA parameter `key` was rejected; `message` explains why.
    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
    InvalidSqlitePragma { key: String, message: String },

    /// A `SQLite` PRAGMA parameter that is not recognized (the crate only accepts
    /// a whitelisted set of PRAGMAs).
    #[error("Unknown SQLite PRAGMA parameter: {0}")]
    UnknownSqlitePragma(String),

    /// A connection parameter was invalid; the payload describes it.
    #[error("Invalid connection parameter: {0}")]
    InvalidParameter(String),

    /// Other `SQLite` PRAGMA failure; the payload describes it.
    #[error("SQLite pragma error: {0}")]
    SqlitePragma(String),

    /// Reading an environment variable failed (from `std::env::VarError`).
    #[error("Environment variable error: {0}")]
    EnvVar(#[from] std::env::VarError),

    /// A URL failed to parse (from `url::ParseError`).
    #[error("URL parsing error: {0}")]
    UrlParse(#[from] url::ParseError),

    /// Error from `SQLx`. This variant only exists when at least one backend
    /// feature (`pg`, `mysql`, `sqlite`) is enabled.
    #[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
    #[error(transparent)]
    Sqlx(#[from] sqlx::Error),

    /// Error from `SeaORM`.
    #[error(transparent)]
    Sea(#[from] sea_orm::DbErr),

    /// I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// Advisory-lock error; `#[from]` makes `advisory_locks` errors flow into
    /// `DbError` via `?`.
    #[error(transparent)]
    Lock(#[from] advisory_locks::DbLockError),

    /// Catch-all for other errors. `secure::ScopeError` is also converted into
    /// this variant by a hand-written `From` impl in this file.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Attempted to create a non-transactional connection inside an active transaction.
    ///
    /// This error occurs when `Db::conn()` is called from within a transaction closure.
    /// The transaction guard prevents this to avoid accidental data bypass where writes
    /// would persist outside the transaction scope.
    ///
    /// # Resolution
    ///
    /// Use the transaction runner (`tx`) provided to the closure instead of creating
    /// a new connection:
    ///
    /// ```ignore
    /// // Wrong - fails with ConnRequestedInsideTx
    /// db.transaction(|_tx| {
    ///     let conn = some_db.conn()?;  // Error!
    ///     ...
    /// });
    ///
    /// // Correct - use the transaction runner
    /// db.transaction(|tx| {
    ///     Entity::find().secure().scope_with(&scope).one(tx).await?;
    ///     ...
    /// });
    /// ```
    #[error("Cannot create non-transactional connection inside an active transaction")]
    ConnRequestedInsideTx,
}
237
238impl From<crate::secure::ScopeError> for DbError {
239    fn from(value: crate::secure::ScopeError) -> Self {
240        // Scope errors are not infra connection errors, but they still originate from the DB
241        // access layer. We keep the wrapper thin and preserve the message for callers.
242        DbError::Other(anyhow::Error::new(value))
243    }
244}
245
/// Supported database engines.
///
/// Determined from the DSN scheme. Besides `Copy`/`Eq`, the type derives `Hash`
/// so engines can be used as `HashMap`/`HashSet` keys (e.g. per-engine settings).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DbEngine {
    /// `PostgreSQL` (`postgres://` or `postgresql://` DSNs).
    Postgres,
    /// `MySQL` (`mysql://` DSNs).
    MySql,
    /// `SQLite` (`sqlite:` / `sqlite://` DSNs).
    Sqlite,
}
253
/// Connection options.
///
/// These mirror the common `sqlx` pool settings; each driver applies whichever
/// subset it supports.
#[derive(Clone, Debug)]
pub struct ConnectOpts {
    /// Upper bound on pooled connections (`None` = driver default).
    pub max_conns: Option<u32>,
    /// Lower bound on pooled connections (`None` = driver default).
    pub min_conns: Option<u32>,
    /// How long to wait when acquiring a connection from the pool.
    pub acquire_timeout: Option<Duration>,
    /// How long a connection may sit idle before it is closed.
    pub idle_timeout: Option<Duration>,
    /// Maximum age of a connection before it is recycled.
    pub max_lifetime: Option<Duration>,
    /// Whether to health-check a connection before handing it out.
    pub test_before_acquire: bool,
    /// For `SQLite` file DSNs: create missing parent directories.
    pub create_sqlite_dirs: bool,
}

impl Default for ConnectOpts {
    /// Defaults: up to 10 connections, a 30 second acquire timeout, no min/idle/lifetime
    /// limits, no pre-acquire health check, and parent directories created for `SQLite`
    /// files.
    fn default() -> Self {
        const DEFAULT_MAX_CONNS: u32 = 10;
        const DEFAULT_ACQUIRE_TIMEOUT_SECS: u64 = 30;

        let acquire_timeout = Duration::from_secs(DEFAULT_ACQUIRE_TIMEOUT_SECS);

        Self {
            create_sqlite_dirs: true,
            test_before_acquire: false,
            max_lifetime: None,
            idle_timeout: None,
            acquire_timeout: Some(acquire_timeout),
            min_conns: None,
            max_conns: Some(DEFAULT_MAX_CONNS),
        }
    }
}
287
/// Main handle: a live `SeaORM` connection together with the engine and DSN it was
/// opened with.
///
/// Crate-internal (`pub(crate)`); the public constructors in this file only hand it
/// out wrapped in `Db`.
///
/// NOTE(review): the derived `Debug` prints `dsn` verbatim, and for Postgres/MySQL
/// that DSN may contain credentials — avoid `{:?}`-logging this type.
#[derive(Debug, Clone)]
pub(crate) struct DbHandle {
    /// Backend detected from the DSN scheme.
    engine: DbEngine,
    /// DSN of this connection. For `SQLite` it is the cleaned DSN (PRAGMA query
    /// parameters removed). It is also what the advisory-lock manager is built from.
    dsn: String,
    /// `SeaORM` connection wrapping the `SQLx` pool.
    sea: DatabaseConnection,
}

/// Default `SQLite` busy timeout in milliseconds, used when the DSN does not set one.
/// It is only applied to non-memory databases.
#[cfg(feature = "sqlite")]
const DEFAULT_SQLITE_BUSY_TIMEOUT: i32 = 5000;
298
299impl DbHandle {
300    /// Detect engine by DSN.
301    ///
302    /// Note: we only check scheme prefixes and don't mutate the tail (credentials etc.).
303    ///
304    /// # Errors
305    /// Returns `DbError::UnknownDsn` if the DSN scheme is not recognized.
306    pub(crate) fn detect(dsn: &str) -> Result<DbEngine> {
307        // Trim only leading spaces/newlines to be forgiving with env files.
308        let s = dsn.trim_start();
309
310        // Explicit, case-sensitive checks for common schemes.
311        // Add more variants as needed (e.g., postgres+unix://).
312        if s.starts_with("postgres://") || s.starts_with("postgresql://") {
313            Ok(DbEngine::Postgres)
314        } else if s.starts_with("mysql://") {
315            Ok(DbEngine::MySql)
316        } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
317            Ok(DbEngine::Sqlite)
318        } else {
319            Err(DbError::UnknownDsn(dsn.to_owned()))
320        }
321    }
322
323    /// Connect and build handle.
324    ///
325    /// # Errors
326    /// Returns an error if the connection fails or the DSN is invalid.
327    pub(crate) async fn connect(dsn: &str, opts: ConnectOpts) -> Result<Self> {
328        let engine = Self::detect(dsn)?;
329        match engine {
330            #[cfg(feature = "pg")]
331            DbEngine::Postgres => {
332                let o = PgPoolOptions::new().apply(&opts);
333                let pool = o.connect(dsn).await?;
334                let sea = SqlxPostgresConnector::from_sqlx_postgres_pool(pool);
335                Ok(Self {
336                    engine,
337                    dsn: dsn.to_owned(),
338                    sea,
339                })
340            }
341            #[cfg(not(feature = "pg"))]
342            DbEngine::Postgres => Err(DbError::FeatureDisabled("PostgreSQL feature not enabled")),
343            #[cfg(feature = "mysql")]
344            DbEngine::MySql => {
345                let o = MySqlPoolOptions::new().apply(&opts);
346                let pool = o.connect(dsn).await?;
347                let sea = SqlxMySqlConnector::from_sqlx_mysql_pool(pool);
348                Ok(Self {
349                    engine,
350                    dsn: dsn.to_owned(),
351                    sea,
352                })
353            }
354            #[cfg(not(feature = "mysql"))]
355            DbEngine::MySql => Err(DbError::FeatureDisabled("MySQL feature not enabled")),
356            #[cfg(feature = "sqlite")]
357            DbEngine::Sqlite => {
358                let dsn = prepare_sqlite_path(dsn, opts.create_sqlite_dirs)?;
359
360                // Extract SQLite PRAGMA parameters from DSN
361                let (clean_dsn, pairs) = extract_sqlite_pragmas(&dsn);
362                let pragmas = Pragmas::from_pairs(&pairs);
363
364                // Build pool options with shared trait
365                let o = SqlitePoolOptions::new().apply(&opts);
366
367                // Apply SQLite pragmas using typed `sqlx` connect options (no raw SQL).
368                let is_memory = is_memory_dsn(&clean_dsn);
369                let mut conn_opts = sqlx::sqlite::SqliteConnectOptions::from_str(&clean_dsn)?;
370
371                let journal_mode = if let Some(mode) = &pragmas.journal_mode {
372                    match mode {
373                        sqlite::pragmas::JournalMode::Delete => {
374                            sqlx::sqlite::SqliteJournalMode::Delete
375                        }
376                        sqlite::pragmas::JournalMode::Wal => sqlx::sqlite::SqliteJournalMode::Wal,
377                        sqlite::pragmas::JournalMode::Memory => {
378                            sqlx::sqlite::SqliteJournalMode::Memory
379                        }
380                        sqlite::pragmas::JournalMode::Truncate => {
381                            sqlx::sqlite::SqliteJournalMode::Truncate
382                        }
383                        sqlite::pragmas::JournalMode::Persist => {
384                            sqlx::sqlite::SqliteJournalMode::Persist
385                        }
386                        sqlite::pragmas::JournalMode::Off => sqlx::sqlite::SqliteJournalMode::Off,
387                    }
388                } else if let Some(wal_toggle) = pragmas.wal_toggle {
389                    if wal_toggle {
390                        sqlx::sqlite::SqliteJournalMode::Wal
391                    } else {
392                        sqlx::sqlite::SqliteJournalMode::Delete
393                    }
394                } else if is_memory {
395                    sqlx::sqlite::SqliteJournalMode::Delete
396                } else {
397                    sqlx::sqlite::SqliteJournalMode::Wal
398                };
399                conn_opts = conn_opts.journal_mode(journal_mode);
400
401                let sync_mode = pragmas.synchronous.as_ref().map_or(
402                    sqlx::sqlite::SqliteSynchronous::Normal,
403                    |s| match s {
404                        sqlite::pragmas::SyncMode::Off => sqlx::sqlite::SqliteSynchronous::Off,
405                        sqlite::pragmas::SyncMode::Normal => {
406                            sqlx::sqlite::SqliteSynchronous::Normal
407                        }
408                        sqlite::pragmas::SyncMode::Full => sqlx::sqlite::SqliteSynchronous::Full,
409                        sqlite::pragmas::SyncMode::Extra => sqlx::sqlite::SqliteSynchronous::Extra,
410                    },
411                );
412                conn_opts = conn_opts.synchronous(sync_mode);
413
414                if !is_memory {
415                    let busy_timeout_ms_i64 = pragmas
416                        .busy_timeout_ms
417                        .unwrap_or(DEFAULT_SQLITE_BUSY_TIMEOUT.into())
418                        .max(0);
419                    let busy_timeout_ms = u64::try_from(busy_timeout_ms_i64).unwrap_or(0);
420                    conn_opts =
421                        conn_opts.busy_timeout(std::time::Duration::from_millis(busy_timeout_ms));
422                }
423
424                let pool = o.connect_with(conn_opts).await?;
425                let sea = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
426
427                Ok(Self {
428                    engine,
429                    dsn: clean_dsn,
430                    sea,
431                })
432            }
433            #[cfg(not(feature = "sqlite"))]
434            DbEngine::Sqlite => Err(DbError::FeatureDisabled("SQLite feature not enabled")),
435        }
436    }
437
438    /// Get the backend.
439    #[must_use]
440    pub fn engine(&self) -> DbEngine {
441        self.engine
442    }
443
444    /// Get the DSN used for this connection.
445    #[must_use]
446    pub fn dsn(&self) -> &str {
447        &self.dsn
448    }
449
450    // NOTE: We intentionally do not expose raw `SQLx` pools from `DbHandle`.
451    // Use `SecureConn` for all application-level DB access.
452
453    // --- SeaORM accessor ---
454
455    /// Create a secure database wrapper for module code.
456    ///
457    /// This returns a `Db` which provides controlled access to the database
458    /// via `conn()` and `transaction()` methods.
459    ///
460    /// # Security
461    ///
462    /// **INTERNAL**: Get raw `SeaORM` connection for internal runtime operations.
463    ///
464    /// This is `pub(crate)` and should **only** be used by:
465    /// - The migration runner (for executing module migrations)
466    /// - Internal infrastructure code within `modkit-db`
467    ///
468    #[must_use]
469    pub(crate) fn sea_internal(&self) -> DatabaseConnection {
470        self.sea.clone()
471    }
472
473    /// **INTERNAL**: Get a reference to the raw `SeaORM` connection.
474    ///
475    /// This is `pub(crate)` and should **only** be used by:
476    /// - The `Db` wrapper for creating runners
477    /// - Internal infrastructure code within `modkit-db`
478    ///
479    /// **NEVER expose this to modules.**
480    #[must_use]
481    pub(crate) fn sea_internal_ref(&self) -> &DatabaseConnection {
482        &self.sea
483    }
484
485    // --- Advisory locks ---
486
487    /// Acquire an advisory lock with the given key and module namespace.
488    ///
489    /// # Errors
490    /// Returns an error if the lock cannot be acquired.
491    pub async fn lock(&self, module: &str, key: &str) -> Result<DbLockGuard> {
492        let lock_manager = advisory_locks::LockManager::new(self.dsn.clone());
493        let guard = lock_manager.lock(module, key).await?;
494        Ok(guard)
495    }
496
497    /// Try to acquire an advisory lock with configurable retry/backoff policy.
498    ///
499    /// # Errors
500    /// Returns an error if an unrecoverable lock error occurs.
501    pub async fn try_lock(
502        &self,
503        module: &str,
504        key: &str,
505        config: LockConfig,
506    ) -> Result<Option<DbLockGuard>> {
507        let lock_manager = advisory_locks::LockManager::new(self.dsn.clone());
508        let res = lock_manager.try_lock(module, key, config).await?;
509        Ok(res)
510    }
511
512    // NOTE: We intentionally do not expose raw SQL transactions from `DbHandle`.
513    // Use `SecureConn::transaction` for application-level atomic operations.
514}
515
516// ===================== tests =====================
517
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    #[cfg(feature = "sqlite")]
    use tokio::time::Duration;

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection() -> Result<()> {
        let dsn = "sqlite::memory:";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection_with_pragma_parameters() -> Result<()> {
        // SQLite connections must accept PRAGMA parameters in the DSN.
        let dsn = "sqlite::memory:?wal=true&synchronous=NORMAL&busy_timeout=5000&journal_mode=WAL";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);

        // The stored DSN must be the cleaned one: still an in-memory SQLite DSN, but with
        // every PRAGMA parameter stripped (the previous `== x || starts_with(x)` check
        // passed even when the parameters were left in place).
        assert!(
            db.dsn().starts_with("sqlite::memory:"),
            "unexpected DSN: {}",
            db.dsn()
        );
        for key in ["wal=", "synchronous=", "busy_timeout=", "journal_mode="] {
            assert!(
                !db.dsn().contains(key),
                "PRAGMA parameter `{key}` left in stored DSN: {}",
                db.dsn()
            );
        }

        Ok(())
    }

    /// Engine detection is pure string matching, so no async runtime is needed.
    #[test]
    fn test_backend_detection() {
        assert_eq!(
            DbHandle::detect("sqlite::memory:").unwrap(),
            DbEngine::Sqlite
        );
        assert_eq!(
            DbHandle::detect("sqlite://data/app.db").unwrap(),
            DbEngine::Sqlite
        );
        assert_eq!(
            DbHandle::detect("postgres://localhost/test").unwrap(),
            DbEngine::Postgres
        );
        assert_eq!(
            DbHandle::detect("postgresql://localhost/test").unwrap(),
            DbEngine::Postgres
        );
        assert_eq!(
            DbHandle::detect("mysql://localhost/test").unwrap(),
            DbEngine::MySql
        );
        // Leading whitespace (e.g. from env files) is tolerated.
        assert_eq!(
            DbHandle::detect("  \npostgres://localhost/test").unwrap(),
            DbEngine::Postgres
        );
        assert!(DbHandle::detect("unknown://test").is_err());
        // Scheme matching is case-sensitive.
        assert!(DbHandle::detect("POSTGRES://localhost/test").is_err());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_sqlite() -> Result<()> {
        let dsn = "sqlite:file:memdb1?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_basic_{now}");

        let guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db
            .lock("different_module", &format!("{test_id}_key1"))
            .await?;

        // Deterministic unlock to avoid races with async Drop cleanup
        guard1.release().await;
        let _guard4 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_different_keys() -> Result<()> {
        let dsn = "sqlite:file:memdb_diff_keys?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_diff_{now}");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db.lock("other_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_try_lock_with_config() -> Result<()> {
        let dsn = "sqlite:file:memdb2?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_config_{now}");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key")).await?;

        let config = LockConfig {
            max_wait: Some(Duration::from_millis(200)),
            initial_backoff: Duration::from_millis(50),
            max_attempts: Some(3),
            ..Default::default()
        };

        let result = db
            .try_lock("test_module", &format!("{test_id}_different_key"), config)
            .await?;
        assert!(
            result.is_some(),
            "expected lock acquisition for different key"
        );
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sea_internal_access() -> Result<()> {
        let dsn = "sqlite::memory:";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        // Internal method for migrations
        let _raw = db.sea_internal();
        Ok(())
    }
}