Skip to main content

modo/db/
pool.rs

1use std::collections::HashMap;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::Path;
4use std::sync::{Arc, RwLock};
5
6use crate::error::{Error, Result};
7
8use super::config::Config;
9use super::connect::connect;
10use super::database::Database;
11
12struct ShardedMap {
13    shards: Vec<RwLock<HashMap<String, Database>>>,
14}
15
16impl ShardedMap {
17    fn new(num_shards: usize) -> Self {
18        let mut shards = Vec::with_capacity(num_shards);
19        for _ in 0..num_shards {
20            shards.push(RwLock::new(HashMap::new()));
21        }
22        Self { shards }
23    }
24
25    fn shard_index(&self, key: &str) -> usize {
26        let mut hasher = DefaultHasher::new();
27        key.hash(&mut hasher);
28        hasher.finish() as usize % self.shards.len()
29    }
30
31    /// Look up a cached `Database` by key. Returns a clone (cheap Arc bump)
32    /// or `None` if the key is not present.
33    fn get(&self, key: &str) -> Option<Database> {
34        let idx = self.shard_index(key);
35        let shard = &self.shards[idx];
36        let read = shard.read().expect("pool shard lock poisoned");
37        read.get(key).cloned()
38    }
39
40    /// Insert a `Database` under `key`. If the key already exists the old
41    /// value is replaced (last writer wins).
42    fn insert(&self, key: String, db: Database) {
43        let idx = self.shard_index(&key);
44        let shard = &self.shards[idx];
45        let mut write = shard.write().expect("pool shard lock poisoned");
46        write.insert(key, db);
47    }
48}
49
50/// Multi-database connection pool with lazy shard opening.
51///
52/// Wraps a default [`Database`] (the main database) plus a sharded cache of
53/// lazily-opened shard databases. All shards share the same PRAGMAs and
54/// migrations from the parent [`Config`].
55///
56/// Cloning is cheap (reference count increment via `Arc`).
57///
58/// # Examples
59///
60/// ```rust,no_run
61/// use modo::db::{self, ConnExt, DatabasePool};
62///
63/// # async fn example() -> modo::Result<()> {
64/// let config = db::Config {
65///     pool: Some(db::PoolConfig::default()),
66///     ..Default::default()
67/// };
68/// let pool = DatabasePool::new(&config).await?;
69///
70/// // Default database (no shard).
71/// let default_db = pool.conn(None).await?;
72/// default_db.conn().execute_raw("SELECT 1", ()).await?;
73///
74/// // Tenant shard (lazy open + cache).
75/// let tenant_db = pool.conn(Some("tenant_abc")).await?;
76/// tenant_db.conn().execute_raw("SELECT 1", ()).await?;
77/// # Ok(()) }
78/// ```
79#[derive(Clone)]
80pub struct DatabasePool {
81    inner: Arc<Inner>,
82}
83
84struct Inner {
85    default: Database,
86    config: Config,
87    shards: ShardedMap,
88}
89
90impl DatabasePool {
91    /// Create a new pool from the given config.
92    ///
93    /// Opens the default database immediately. Shard databases are opened
94    /// lazily on first [`conn`](Self::conn) call.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if `config.pool` is `None` or the default database
99    /// fails to open.
100    pub async fn new(config: &Config) -> Result<Self> {
101        let pool_config = config
102            .pool
103            .as_ref()
104            .ok_or_else(|| Error::internal("database pool config is required"))?;
105
106        if pool_config.lock_shards == 0 {
107            return Err(Error::internal("pool lock_shards must be greater than 0"));
108        }
109
110        let default = connect(config).await?;
111        let shards = ShardedMap::new(pool_config.lock_shards);
112
113        Ok(Self {
114            inner: Arc::new(Inner {
115                default,
116                config: config.clone(),
117                shards,
118            }),
119        })
120    }
121
122    /// Get a database connection by shard name.
123    ///
124    /// - `None` — returns the default database (instant, no lock).
125    /// - `Some("name")` — returns the cached shard database, opening it on
126    ///   first access at `{base_path}/{name}.db`.
127    ///
128    /// Concurrent first-access to the same shard may open duplicate
129    /// connections; the last writer wins and the extra connection is dropped.
130    /// This is benign because `connect` is idempotent (PRAGMAs are
131    /// re-applied, migrations use checksum tracking).
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if the shard name is invalid (empty, starts with `.`,
136    /// or contains path separators or a null byte) or if the shard database
137    /// fails to open.
138    pub async fn conn(&self, shard: Option<&str>) -> Result<Database> {
139        let Some(name) = shard else {
140            return Ok(self.inner.default.clone());
141        };
142
143        if name.is_empty()
144            || name.starts_with('.')
145            || name.contains('/')
146            || name.contains('\\')
147            || name.contains('\0')
148        {
149            return Err(Error::bad_request(format!("invalid shard name: {name:?}")));
150        }
151
152        if let Some(db) = self.inner.shards.get(name) {
153            return Ok(db);
154        }
155
156        // Safety: pool config is validated as Some in new()
157        let pool_config = self.inner.config.pool.as_ref().unwrap();
158        let shard_path = if pool_config.base_path == ":memory:" {
159            ":memory:".to_string()
160        } else {
161            Path::new(&pool_config.base_path)
162                .join(format!("{name}.db"))
163                .to_string_lossy()
164                .into_owned()
165        };
166        let shard_config = Config {
167            path: shard_path,
168            pool: None,
169            ..self.inner.config.clone()
170        };
171
172        let db = connect(&shard_config).await.map_err(|e| {
173            Error::internal(format!("failed to open shard database: {name}")).chain(e)
174        })?;
175
176        self.inner.shards.insert(name.to_string(), db.clone());
177        Ok(db)
178    }
179}
180
181/// Wrapper for graceful shutdown integration with [`crate::run!`].
182///
183/// Wraps a [`DatabasePool`] so it can be registered as a [`Task`](crate::runtime::Task)
184/// with the modo runtime. On shutdown all database handles (default and shards)
185/// are dropped.
186///
187/// Created by [`managed_pool`].
188pub struct ManagedDatabasePool(DatabasePool);
189
190impl crate::runtime::Task for ManagedDatabasePool {
191    async fn shutdown(self) -> Result<()> {
192        drop(self.0);
193        Ok(())
194    }
195}
196
197/// Wrap a [`DatabasePool`] for use with [`crate::run!`].
198///
199/// # Examples
200///
201/// ```rust,no_run
202/// use modo::db;
203///
204/// # async fn example() -> modo::Result<()> {
205/// let config = db::Config::default();
206/// let pool = db::DatabasePool::new(&config).await?;
207/// let task = db::managed_pool(pool.clone());
208/// // Register `task` with modo::run!() for graceful shutdown
209/// # Ok(())
210/// # }
211/// ```
212pub fn managed_pool(pool: DatabasePool) -> ManagedDatabasePool {
213    ManagedDatabasePool(pool)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory database to use as a cache value.
    async fn memory_db() -> Database {
        let cfg = Config {
            path: ":memory:".to_string(),
            ..Default::default()
        };
        connect(&cfg).await.unwrap()
    }

    #[test]
    fn sharded_map_get_returns_none_for_missing_key() {
        let empty = ShardedMap::new(4);
        let found = empty.get("missing");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_and_get() {
        let cache = ShardedMap::new(4);
        cache.insert("tenant_a".into(), memory_db().await);
        assert!(cache.get("tenant_a").is_some());
    }

    #[tokio::test]
    async fn sharded_map_different_keys_independent() {
        let cache = ShardedMap::new(4);
        cache.insert("tenant_a".into(), memory_db().await);
        let (hit, miss) = (cache.get("tenant_a"), cache.get("tenant_b"));
        assert!(hit.is_some());
        assert!(miss.is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_idempotent() {
        let cache = ShardedMap::new(4);
        let first = memory_db().await;
        let second = memory_db().await;
        for db in [first, second] {
            cache.insert("key".to_owned(), db);
        }
        assert!(cache.get("key").is_some());
    }
}