Skip to main content

modkit_db/
lib.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2//! `ModKit` Database abstraction crate.
3//!
4//! This crate provides a unified interface for working with different databases
5//! (`SQLite`, `PostgreSQL`, `MySQL`) through `SQLx`, with optional `SeaORM` integration.
6//! It emphasizes typed connection options over DSN string manipulation and
7//! implements strict security controls (e.g., `SQLite` PRAGMA whitelist).
8//!
9//! # Features
10//! - `pg`, `mysql`, `sqlite`: enable `SQLx` backends
11//! - `sea-orm`: add `SeaORM` integration for type-safe operations
12//!
//! # Architecture
//! The crate provides:
15//! - Typed `DbConnectOptions` using sqlx `ConnectOptions` (no DSN string building)
16//! - Per-module database factories with configuration merging
17//! - `SQLite` PRAGMA whitelist for security
18//! - Environment variable expansion in passwords and DSNs
19//!
20//! # Example (`DbManager` API)
21//! ```rust,no_run
22//! use modkit_db::{DbManager, GlobalDatabaseConfig, DbConnConfig};
23//! use figment::{Figment, providers::Serialized};
24//! use std::path::PathBuf;
25//! use std::sync::Arc;
26//!
27//! // Create configuration using Figment
28//! let figment = Figment::new()
29//!     .merge(Serialized::defaults(serde_json::json!({
30//!         "db": {
31//!             "servers": {
32//!                 "main": {
33//!                     "host": "localhost",
34//!                     "port": 5432,
35//!                     "user": "app",
36//!                     "password": "${DB_PASSWORD}",
37//!                     "dbname": "app_db"
38//!                 }
39//!             }
40//!         },
41//!         "test_module": {
42//!             "database": {
43//!                 "server": "main",
44//!                 "dbname": "module_db"
45//!             }
46//!         }
47//!     })));
48//!
49//! // Create DbManager
50//! let home_dir = PathBuf::from("/app/data");
51//! let db_manager = Arc::new(DbManager::from_figment(figment, home_dir).unwrap());
52//!
53//! // Use in runtime with DbOptions::Manager(db_manager)
54//! // Modules can then use: ctx.db_required_async().await?
55//! ```
56
57#![cfg_attr(
58    not(any(feature = "pg", feature = "mysql", feature = "sqlite")),
59    allow(
60        unused_imports,
61        unused_variables,
62        dead_code,
63        unreachable_code,
64        unused_lifetimes,
65        clippy::unused_async,
66    )
67)]
68
69// Re-export key types for public API
70pub use advisory_locks::{DbLockGuard, LockConfig};
71
72// Re-export sea_orm_migration for modules that implement DatabaseCapability
73pub use sea_orm_migration;
74
75// Core modules
76pub mod advisory_locks;
77pub mod config;
78pub mod manager;
79pub mod migration_runner;
80pub mod odata;
81pub mod options;
82
83pub mod secure;
84
85mod db_provider;
86
87// Internal modules
88mod pool_opts;
89#[cfg(feature = "sqlite")]
90mod sqlite;
91
92// Re-export important types from new modules
93pub use config::{DbConnConfig, GlobalDatabaseConfig, PoolCfg};
94pub use manager::DbManager;
95pub use options::redact_credentials_in_dsn;
96
97// Re-export secure database types for convenience
98pub use secure::{Db, DbConn, DbTx};
99
100// Re-export service-friendly provider
101pub use db_provider::DBProvider;
102
103/// Connect and return a secure `Db` (no `DbHandle` exposure).
104///
105/// This is the public constructor intended for module code and tests.
106///
107/// # Errors
108///
109/// Returns `DbError` if the connection fails or the DSN/options are invalid.
110pub async fn connect_db(dsn: &str, opts: ConnectOpts) -> Result<Db> {
111    let handle = DbHandle::connect(dsn, opts).await?;
112    Ok(Db::new(handle))
113}
114
115/// Build a secure `Db` from config (no `DbHandle` exposure).
116///
117/// # Errors
118///
119/// Returns `DbError` if configuration is invalid or connection fails.
120pub async fn build_db(cfg: DbConnConfig, global: Option<&GlobalDatabaseConfig>) -> Result<Db> {
121    let handle = options::build_db_handle(cfg, global).await?;
122    Ok(Db::new(handle))
123}
124
125use std::time::Duration;
126
127// Internal imports
128#[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
129use pool_opts::ApplyPoolOpts;
130#[cfg(feature = "sqlite")]
131use sqlite::{Pragmas, extract_sqlite_pragmas, is_memory_dsn, prepare_sqlite_path};
132
133// Used for parsing SQLite DSN query parameters
134
135#[cfg(feature = "mysql")]
136use sqlx::mysql::MySqlPoolOptions;
137#[cfg(feature = "pg")]
138use sqlx::postgres::PgPoolOptions;
139#[cfg(feature = "sqlite")]
140use sqlx::sqlite::SqlitePoolOptions;
141#[cfg(feature = "sqlite")]
142use std::str::FromStr;
143
144use sea_orm::DatabaseConnection;
145#[cfg(feature = "mysql")]
146use sea_orm::SqlxMySqlConnector;
147#[cfg(feature = "pg")]
148use sea_orm::SqlxPostgresConnector;
149#[cfg(feature = "sqlite")]
150use sea_orm::SqlxSqliteConnector;
151
152use thiserror::Error;
153
154/// Library-local result type.
155pub type Result<T> = std::result::Result<T, DbError>;
156
157/// Typed error for the DB handle and helpers.
158#[derive(Debug, Error)]
159pub enum DbError {
160    #[error("Unknown DSN: {0}")]
161    UnknownDsn(String),
162
163    #[error("Feature not enabled: {0}")]
164    FeatureDisabled(&'static str),
165
166    #[error("Invalid configuration: {0}")]
167    InvalidConfig(String),
168
169    #[error("Configuration conflict: {0}")]
170    ConfigConflict(String),
171
172    #[error("Invalid SQLite PRAGMA parameter '{key}': {message}")]
173    InvalidSqlitePragma { key: String, message: String },
174
175    #[error("Unknown SQLite PRAGMA parameter: {0}")]
176    UnknownSqlitePragma(String),
177
178    #[error("Invalid connection parameter: {0}")]
179    InvalidParameter(String),
180
181    #[error("SQLite pragma error: {0}")]
182    SqlitePragma(String),
183
184    #[error("Environment variable error: {0}")]
185    EnvVar(#[from] std::env::VarError),
186
187    #[error("URL parsing error: {0}")]
188    UrlParse(#[from] url::ParseError),
189
190    #[cfg(any(feature = "pg", feature = "mysql", feature = "sqlite"))]
191    #[error(transparent)]
192    Sqlx(#[from] sqlx::Error),
193
194    #[error(transparent)]
195    Sea(#[from] sea_orm::DbErr),
196
197    #[error(transparent)]
198    Io(#[from] std::io::Error),
199
200    // make advisory_locks errors flow into DbError via `?`
201    #[error(transparent)]
202    Lock(#[from] advisory_locks::DbLockError),
203
204    #[error(transparent)]
205    Other(#[from] anyhow::Error),
206
207    /// Attempted to create a non-transactional connection inside an active transaction.
208    ///
209    /// This error occurs when `Db::conn()` is called from within a transaction closure.
210    /// The transaction guard prevents this to avoid accidental data bypass where writes
211    /// would persist outside the transaction scope.
212    ///
213    /// # Resolution
214    ///
215    /// Use the transaction runner (`tx`) provided to the closure instead of creating
216    /// a new connection:
217    ///
218    /// ```ignore
219    /// // Wrong - fails with ConnRequestedInsideTx
220    /// db.transaction(|_tx| {
221    ///     let conn = some_db.conn()?;  // Error!
222    ///     ...
223    /// });
224    ///
225    /// // Correct - use the transaction runner
226    /// db.transaction(|tx| {
227    ///     Entity::find().secure().scope_with(&scope).one(tx).await?;
228    ///     ...
229    /// });
230    /// ```
231    #[error("Cannot create non-transactional connection inside an active transaction")]
232    ConnRequestedInsideTx,
233}
234
235impl From<crate::secure::ScopeError> for DbError {
236    fn from(value: crate::secure::ScopeError) -> Self {
237        // Scope errors are not infra connection errors, but they still originate from the DB
238        // access layer. We keep the wrapper thin and preserve the message for callers.
239        DbError::Other(anyhow::Error::new(value))
240    }
241}
242
/// Database engines supported by this crate, detected from the DSN scheme.
///
/// `Hash` is derived alongside `Eq` so the engine can key maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DbEngine {
    /// `PostgreSQL` (`postgres://` / `postgresql://` DSNs; `pg` feature).
    Postgres,
    /// `MySQL` (`mysql://` DSNs; `mysql` feature).
    MySql,
    /// `SQLite` (`sqlite:` DSNs; `sqlite` feature).
    Sqlite,
}
250
/// Pool tuning for a single connection.
///
/// Covers the common `sqlx` pool knobs; each driver applies whichever subset it
/// supports. Most fields are `Option`s so a setting can be left unspecified.
#[derive(Clone, Debug)]
pub struct ConnectOpts {
    /// Upper bound on the number of pooled connections.
    pub max_conns: Option<u32>,
    /// Lower bound on the number of pooled connections.
    pub min_conns: Option<u32>,
    /// How long to wait when acquiring a connection from the pool.
    pub acquire_timeout: Option<Duration>,
    /// How long a connection may sit idle before it is closed.
    pub idle_timeout: Option<Duration>,
    /// Maximum lifetime of a single connection.
    pub max_lifetime: Option<Duration>,
    /// Whether to health-check a connection before handing it out.
    pub test_before_acquire: bool,
    /// For `SQLite` file DSNs, create missing parent directories.
    pub create_sqlite_dirs: bool,
}

impl Default for ConnectOpts {
    /// Up to 10 pooled connections and a 30-second acquire timeout, no health check
    /// before acquire, and `SQLite` parent directories created on demand; every
    /// other setting is left as `None`.
    fn default() -> Self {
        Self {
            create_sqlite_dirs: true,
            test_before_acquire: false,
            max_conns: Some(10),
            acquire_timeout: Some(Duration::from_secs(30)),
            min_conns: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }
}
284
285/// Main handle.
286#[derive(Debug, Clone)]
287pub(crate) struct DbHandle {
288    engine: DbEngine,
289    dsn: String,
290    sea: DatabaseConnection,
291}
292
293#[cfg(feature = "sqlite")]
294const DEFAULT_SQLITE_BUSY_TIMEOUT: i32 = 5000;
295
296impl DbHandle {
297    /// Detect engine by DSN.
298    ///
299    /// Note: we only check scheme prefixes and don't mutate the tail (credentials etc.).
300    ///
301    /// # Errors
302    /// Returns `DbError::UnknownDsn` if the DSN scheme is not recognized.
303    pub(crate) fn detect(dsn: &str) -> Result<DbEngine> {
304        // Trim only leading spaces/newlines to be forgiving with env files.
305        let s = dsn.trim_start();
306
307        // Explicit, case-sensitive checks for common schemes.
308        // Add more variants as needed (e.g., postgres+unix://).
309        if s.starts_with("postgres://") || s.starts_with("postgresql://") {
310            Ok(DbEngine::Postgres)
311        } else if s.starts_with("mysql://") {
312            Ok(DbEngine::MySql)
313        } else if s.starts_with("sqlite:") || s.starts_with("sqlite://") {
314            Ok(DbEngine::Sqlite)
315        } else {
316            Err(DbError::UnknownDsn(dsn.to_owned()))
317        }
318    }
319
320    /// Connect and build handle.
321    ///
322    /// # Errors
323    /// Returns an error if the connection fails or the DSN is invalid.
324    pub(crate) async fn connect(dsn: &str, opts: ConnectOpts) -> Result<Self> {
325        let engine = Self::detect(dsn)?;
326        match engine {
327            #[cfg(feature = "pg")]
328            DbEngine::Postgres => {
329                let o = PgPoolOptions::new().apply(&opts);
330                let pool = o.connect(dsn).await?;
331                let sea = SqlxPostgresConnector::from_sqlx_postgres_pool(pool);
332                Ok(Self {
333                    engine,
334                    dsn: dsn.to_owned(),
335                    sea,
336                })
337            }
338            #[cfg(not(feature = "pg"))]
339            DbEngine::Postgres => Err(DbError::FeatureDisabled("PostgreSQL feature not enabled")),
340            #[cfg(feature = "mysql")]
341            DbEngine::MySql => {
342                let o = MySqlPoolOptions::new().apply(&opts);
343                let pool = o.connect(dsn).await?;
344                let sea = SqlxMySqlConnector::from_sqlx_mysql_pool(pool);
345                Ok(Self {
346                    engine,
347                    dsn: dsn.to_owned(),
348                    sea,
349                })
350            }
351            #[cfg(not(feature = "mysql"))]
352            DbEngine::MySql => Err(DbError::FeatureDisabled("MySQL feature not enabled")),
353            #[cfg(feature = "sqlite")]
354            DbEngine::Sqlite => {
355                let dsn = prepare_sqlite_path(dsn, opts.create_sqlite_dirs)?;
356
357                // Extract SQLite PRAGMA parameters from DSN
358                let (clean_dsn, pairs) = extract_sqlite_pragmas(&dsn);
359                let pragmas = Pragmas::from_pairs(&pairs);
360
361                // Build pool options with shared trait
362                let o = SqlitePoolOptions::new().apply(&opts);
363
364                // Apply SQLite pragmas using typed `sqlx` connect options (no raw SQL).
365                let is_memory = is_memory_dsn(&clean_dsn);
366                let mut conn_opts = sqlx::sqlite::SqliteConnectOptions::from_str(&clean_dsn)?;
367
368                let journal_mode = if let Some(mode) = &pragmas.journal_mode {
369                    match mode {
370                        sqlite::pragmas::JournalMode::Delete => {
371                            sqlx::sqlite::SqliteJournalMode::Delete
372                        }
373                        sqlite::pragmas::JournalMode::Wal => sqlx::sqlite::SqliteJournalMode::Wal,
374                        sqlite::pragmas::JournalMode::Memory => {
375                            sqlx::sqlite::SqliteJournalMode::Memory
376                        }
377                        sqlite::pragmas::JournalMode::Truncate => {
378                            sqlx::sqlite::SqliteJournalMode::Truncate
379                        }
380                        sqlite::pragmas::JournalMode::Persist => {
381                            sqlx::sqlite::SqliteJournalMode::Persist
382                        }
383                        sqlite::pragmas::JournalMode::Off => sqlx::sqlite::SqliteJournalMode::Off,
384                    }
385                } else if let Some(wal_toggle) = pragmas.wal_toggle {
386                    if wal_toggle {
387                        sqlx::sqlite::SqliteJournalMode::Wal
388                    } else {
389                        sqlx::sqlite::SqliteJournalMode::Delete
390                    }
391                } else if is_memory {
392                    sqlx::sqlite::SqliteJournalMode::Delete
393                } else {
394                    sqlx::sqlite::SqliteJournalMode::Wal
395                };
396                conn_opts = conn_opts.journal_mode(journal_mode);
397
398                let sync_mode = pragmas.synchronous.as_ref().map_or(
399                    sqlx::sqlite::SqliteSynchronous::Normal,
400                    |s| match s {
401                        sqlite::pragmas::SyncMode::Off => sqlx::sqlite::SqliteSynchronous::Off,
402                        sqlite::pragmas::SyncMode::Normal => {
403                            sqlx::sqlite::SqliteSynchronous::Normal
404                        }
405                        sqlite::pragmas::SyncMode::Full => sqlx::sqlite::SqliteSynchronous::Full,
406                        sqlite::pragmas::SyncMode::Extra => sqlx::sqlite::SqliteSynchronous::Extra,
407                    },
408                );
409                conn_opts = conn_opts.synchronous(sync_mode);
410
411                if !is_memory {
412                    let busy_timeout_ms_i64 = pragmas
413                        .busy_timeout_ms
414                        .unwrap_or(DEFAULT_SQLITE_BUSY_TIMEOUT.into())
415                        .max(0);
416                    let busy_timeout_ms = u64::try_from(busy_timeout_ms_i64).unwrap_or(0);
417                    conn_opts =
418                        conn_opts.busy_timeout(std::time::Duration::from_millis(busy_timeout_ms));
419                }
420
421                let pool = o.connect_with(conn_opts).await?;
422                let sea = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
423
424                Ok(Self {
425                    engine,
426                    dsn: clean_dsn,
427                    sea,
428                })
429            }
430            #[cfg(not(feature = "sqlite"))]
431            DbEngine::Sqlite => Err(DbError::FeatureDisabled("SQLite feature not enabled")),
432        }
433    }
434
435    /// Get the backend.
436    #[must_use]
437    pub fn engine(&self) -> DbEngine {
438        self.engine
439    }
440
441    /// Get the DSN used for this connection.
442    #[must_use]
443    pub fn dsn(&self) -> &str {
444        &self.dsn
445    }
446
447    // NOTE: We intentionally do not expose raw `SQLx` pools from `DbHandle`.
448    // Use `SecureConn` for all application-level DB access.
449
450    // --- SeaORM accessor ---
451
452    /// Create a secure database wrapper for module code.
453    ///
454    /// This returns a `Db` which provides controlled access to the database
455    /// via `conn()` and `transaction()` methods.
456    ///
457    /// # Security
458    ///
459    /// **INTERNAL**: Get raw `SeaORM` connection for internal runtime operations.
460    ///
461    /// This is `pub(crate)` and should **only** be used by:
462    /// - The migration runner (for executing module migrations)
463    /// - Internal infrastructure code within `modkit-db`
464    ///
465    #[must_use]
466    pub(crate) fn sea_internal(&self) -> DatabaseConnection {
467        self.sea.clone()
468    }
469
470    /// **INTERNAL**: Get a reference to the raw `SeaORM` connection.
471    ///
472    /// This is `pub(crate)` and should **only** be used by:
473    /// - The `Db` wrapper for creating runners
474    /// - Internal infrastructure code within `modkit-db`
475    ///
476    /// **NEVER expose this to modules.**
477    #[must_use]
478    pub(crate) fn sea_internal_ref(&self) -> &DatabaseConnection {
479        &self.sea
480    }
481
482    // --- Advisory locks ---
483
484    /// Acquire an advisory lock with the given key and module namespace.
485    ///
486    /// # Errors
487    /// Returns an error if the lock cannot be acquired.
488    pub async fn lock(&self, module: &str, key: &str) -> Result<DbLockGuard> {
489        let lock_manager = advisory_locks::LockManager::new(self.dsn.clone());
490        let guard = lock_manager.lock(module, key).await?;
491        Ok(guard)
492    }
493
494    /// Try to acquire an advisory lock with configurable retry/backoff policy.
495    ///
496    /// # Errors
497    /// Returns an error if an unrecoverable lock error occurs.
498    pub async fn try_lock(
499        &self,
500        module: &str,
501        key: &str,
502        config: LockConfig,
503    ) -> Result<Option<DbLockGuard>> {
504        let lock_manager = advisory_locks::LockManager::new(self.dsn.clone());
505        let res = lock_manager.try_lock(module, key, config).await?;
506        Ok(res)
507    }
508
509    // NOTE: We intentionally do not expose raw SQL transactions from `DbHandle`.
510    // Use `SecureConn::transaction` for application-level atomic operations.
511}
512
513// ===================== tests =====================
514
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    #[cfg(feature = "sqlite")]
    use tokio::time::Duration;

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection() -> Result<()> {
        let dsn = "sqlite::memory:";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_connection_with_pragma_parameters() -> Result<()> {
        // SQLite connections must accept PRAGMA parameters in the DSN.
        let dsn = "sqlite::memory:?wal=true&synchronous=NORMAL&busy_timeout=5000&journal_mode=WAL";
        let opts = ConnectOpts::default();
        let db = DbHandle::connect(dsn, opts).await?;
        assert_eq!(db.engine(), DbEngine::Sqlite);

        // The stored DSN is the cleaned one: still an in-memory DSN, with every
        // PRAGMA parameter removed.
        assert!(
            db.dsn().starts_with("sqlite::memory:"),
            "unexpected DSN: {}",
            db.dsn()
        );
        for key in ["wal=", "synchronous=", "busy_timeout=", "journal_mode="] {
            assert!(
                !db.dsn().contains(key),
                "PRAGMA parameter `{key}` left in DSN: {}",
                db.dsn()
            );
        }

        Ok(())
    }

    #[test]
    fn test_backend_detection() {
        let cases = [
            ("sqlite::memory:", DbEngine::Sqlite),
            ("sqlite://data/app.db", DbEngine::Sqlite),
            ("sqlite:file:memdb1?mode=memory&cache=shared", DbEngine::Sqlite),
            ("postgres://localhost/test", DbEngine::Postgres),
            ("postgresql://localhost/test", DbEngine::Postgres),
            ("mysql://localhost/test", DbEngine::MySql),
            // Leading whitespace (e.g. from env files) is tolerated.
            ("  \npostgres://localhost/test", DbEngine::Postgres),
        ];
        for (dsn, expected) in cases {
            assert_eq!(DbHandle::detect(dsn).unwrap(), expected, "dsn: {dsn:?}");
        }

        assert!(matches!(
            DbHandle::detect("unknown://test"),
            Err(DbError::UnknownDsn(_))
        ));
        // Scheme checks are case-sensitive.
        assert!(DbHandle::detect("POSTGRES://localhost/test").is_err());
    }

    #[test]
    fn test_connect_opts_defaults() {
        let opts = ConnectOpts::default();
        assert_eq!(opts.max_conns, Some(10));
        assert_eq!(opts.min_conns, None);
        assert_eq!(
            opts.acquire_timeout,
            Some(std::time::Duration::from_secs(30))
        );
        assert!(opts.idle_timeout.is_none());
        assert!(opts.max_lifetime.is_none());
        assert!(!opts.test_before_acquire);
        assert!(opts.create_sqlite_dirs);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_sqlite() -> Result<()> {
        let dsn = "sqlite:file:memdb1?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_basic_{now}");

        let guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db
            .lock("different_module", &format!("{test_id}_key1"))
            .await?;

        // Deterministic unlock to avoid races with async Drop cleanup
        guard1.release().await;
        let _guard4 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_advisory_lock_different_keys() -> Result<()> {
        let dsn = "sqlite:file:memdb_diff_keys?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_diff_{now}");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key1")).await?;
        let _guard2 = db.lock("test_module", &format!("{test_id}_key2")).await?;
        let _guard3 = db.lock("other_module", &format!("{test_id}_key1")).await?;
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_try_lock_with_config() -> Result<()> {
        let dsn = "sqlite:file:memdb2?mode=memory&cache=shared";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_nanos());
        let test_id = format!("test_config_{now}");

        let _guard1 = db.lock("test_module", &format!("{test_id}_key")).await?;

        let config = LockConfig {
            max_wait: Some(Duration::from_millis(200)),
            initial_backoff: Duration::from_millis(50),
            max_attempts: Some(3),
            ..Default::default()
        };

        let result = db
            .try_lock("test_module", &format!("{test_id}_different_key"), config)
            .await?;
        assert!(
            result.is_some(),
            "expected lock acquisition for different key"
        );
        Ok(())
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sea_internal_access() -> Result<()> {
        let dsn = "sqlite::memory:";
        let db = DbHandle::connect(dsn, ConnectOpts::default()).await?;

        // Internal method for migrations
        let _raw = db.sea_internal();
        Ok(())
    }
}