Skip to main content

modo/db/
pool.rs

1use std::collections::HashMap;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::Path;
4use std::sync::{Arc, RwLock};
5
6use crate::error::{Error, Result};
7
8use super::config::Config;
9use super::connect::connect;
10use super::database::Database;
11
12struct ShardedMap {
13    shards: Vec<RwLock<HashMap<String, Database>>>,
14}
15
16impl ShardedMap {
17    fn new(num_shards: usize) -> Self {
18        let mut shards = Vec::with_capacity(num_shards);
19        for _ in 0..num_shards {
20            shards.push(RwLock::new(HashMap::new()));
21        }
22        Self { shards }
23    }
24
25    fn shard_index(&self, key: &str) -> usize {
26        let mut hasher = DefaultHasher::new();
27        key.hash(&mut hasher);
28        hasher.finish() as usize % self.shards.len()
29    }
30
31    /// Look up a cached `Database` by key. Returns a clone (cheap Arc bump)
32    /// or `None` if the key is not present.
33    fn get(&self, key: &str) -> Option<Database> {
34        let idx = self.shard_index(key);
35        let shard = &self.shards[idx];
36        let read = shard.read().expect("pool shard lock poisoned");
37        read.get(key).cloned()
38    }
39
40    /// Insert a `Database` under `key`. If the key already exists the old
41    /// value is replaced (last writer wins).
42    fn insert(&self, key: String, db: Database) {
43        let idx = self.shard_index(&key);
44        let shard = &self.shards[idx];
45        let mut write = shard.write().expect("pool shard lock poisoned");
46        write.insert(key, db);
47    }
48}
49
50/// Multi-database connection pool with lazy shard opening.
51///
52/// Wraps a default [`Database`] (the main database) plus a sharded cache of
53/// lazily-opened shard databases. All shards share the same PRAGMAs and
54/// migrations from the parent [`Config`].
55///
56/// Cloning is cheap (reference count increment via `Arc`).
57///
58/// # Examples
59///
60/// ```rust,ignore
61/// use modo::db::{self, ConnExt, ConnQueryExt, DatabasePool};
62///
63/// let pool = DatabasePool::new(&config).await?;
64///
65/// // Default database:
66/// let user: User = pool.conn(None).await?
67///     .conn()
68///     .query_one("SELECT id, name FROM users WHERE id = ?1", libsql::params!["u1"])
69///     .await?;
70///
71/// // Tenant shard (lazy open + cache):
72/// let user: User = pool.conn(tenant.db_shard.as_deref()).await?
73///     .conn()
74///     .query_one("SELECT id, name FROM users WHERE id = ?1", libsql::params!["u1"])
75///     .await?;
76/// ```
77#[derive(Clone)]
78pub struct DatabasePool {
79    inner: Arc<Inner>,
80}
81
82struct Inner {
83    default: Database,
84    config: Config,
85    shards: ShardedMap,
86}
87
88impl DatabasePool {
89    /// Create a new pool from the given config.
90    ///
91    /// Opens the default database immediately. Shard databases are opened
92    /// lazily on first [`conn`](Self::conn) call.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if `config.pool` is `None` or the default database
97    /// fails to open.
98    pub async fn new(config: &Config) -> Result<Self> {
99        let pool_config = config
100            .pool
101            .as_ref()
102            .ok_or_else(|| Error::internal("database pool config is required"))?;
103
104        if pool_config.lock_shards == 0 {
105            return Err(Error::internal("pool lock_shards must be greater than 0"));
106        }
107
108        let default = connect(config).await?;
109        let shards = ShardedMap::new(pool_config.lock_shards);
110
111        Ok(Self {
112            inner: Arc::new(Inner {
113                default,
114                config: config.clone(),
115                shards,
116            }),
117        })
118    }
119
120    /// Get a database connection by shard name.
121    ///
122    /// - `None` — returns the default database (instant, no lock).
123    /// - `Some("name")` — returns the cached shard database, opening it on
124    ///   first access at `{base_path}/{name}.db`.
125    ///
126    /// Concurrent first-access to the same shard may open duplicate
127    /// connections; the last writer wins and the extra connection is dropped.
128    /// This is benign because `connect` is idempotent (PRAGMAs are
129    /// re-applied, migrations use checksum tracking).
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the shard name is invalid (empty, starts with `.`,
134    /// or contains path separators or a null byte) or if the shard database
135    /// fails to open.
136    pub async fn conn(&self, shard: Option<&str>) -> Result<Database> {
137        let Some(name) = shard else {
138            return Ok(self.inner.default.clone());
139        };
140
141        if name.is_empty()
142            || name.starts_with('.')
143            || name.contains('/')
144            || name.contains('\\')
145            || name.contains('\0')
146        {
147            return Err(Error::bad_request(format!("invalid shard name: {name:?}")));
148        }
149
150        if let Some(db) = self.inner.shards.get(name) {
151            return Ok(db);
152        }
153
154        // Safety: pool config is validated as Some in new()
155        let pool_config = self.inner.config.pool.as_ref().unwrap();
156        let shard_path = if pool_config.base_path == ":memory:" {
157            ":memory:".to_string()
158        } else {
159            Path::new(&pool_config.base_path)
160                .join(format!("{name}.db"))
161                .to_string_lossy()
162                .into_owned()
163        };
164        let shard_config = Config {
165            path: shard_path,
166            pool: None,
167            ..self.inner.config.clone()
168        };
169
170        let db = connect(&shard_config).await.map_err(|e| {
171            Error::internal(format!("failed to open shard database: {name}")).chain(e)
172        })?;
173
174        self.inner.shards.insert(name.to_string(), db.clone());
175        Ok(db)
176    }
177}
178
179/// Wrapper for graceful shutdown integration with [`crate::run!`].
180///
181/// Wraps a [`DatabasePool`] so it can be registered as a [`Task`](crate::runtime::Task)
182/// with the modo runtime. On shutdown all database handles (default and shards)
183/// are dropped.
184///
185/// Created by [`managed_pool`].
186pub struct ManagedDatabasePool(DatabasePool);
187
188impl crate::runtime::Task for ManagedDatabasePool {
189    async fn shutdown(self) -> Result<()> {
190        drop(self.0);
191        Ok(())
192    }
193}
194
195/// Wrap a [`DatabasePool`] for use with [`crate::run!`].
196///
197/// # Examples
198///
199/// ```rust,no_run
200/// use modo::db;
201///
202/// # async fn example() -> modo::Result<()> {
203/// let config = db::Config::default();
204/// let pool = db::DatabasePool::new(&config).await?;
205/// let task = db::managed_pool(pool.clone());
206/// // Register `task` with modo::run!() for graceful shutdown
207/// # Ok(())
208/// # }
209/// ```
210pub fn managed_pool(pool: DatabasePool) -> ManagedDatabasePool {
211    ManagedDatabasePool(pool)
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a throwaway in-memory database to use as a cache value.
    async fn make_test_db() -> Database {
        let cfg = super::super::config::Config {
            path: String::from(":memory:"),
            ..Default::default()
        };
        super::super::connect::connect(&cfg).await.unwrap()
    }

    #[test]
    fn sharded_map_get_returns_none_for_missing_key() {
        let cache = ShardedMap::new(4);
        let found = cache.get("missing");
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_and_get() {
        let cache = ShardedMap::new(4);
        let handle = make_test_db().await;
        cache.insert(String::from("tenant_a"), handle);
        let found = cache.get("tenant_a");
        assert!(found.is_some());
    }

    #[tokio::test]
    async fn sharded_map_different_keys_independent() {
        let cache = ShardedMap::new(4);
        let handle = make_test_db().await;
        cache.insert(String::from("tenant_a"), handle);
        let (hit, miss) = (cache.get("tenant_a"), cache.get("tenant_b"));
        assert!(hit.is_some());
        assert!(miss.is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_idempotent() {
        let cache = ShardedMap::new(4);
        let first = make_test_db().await;
        let second = make_test_db().await;
        for handle in [first, second] {
            cache.insert(String::from("key"), handle);
        }
        assert!(cache.get("key").is_some());
    }
}