Skip to main content

musefs_core/
db_pool.rs

//! Hands a read connection to whichever thread needs one.
//!
//! - File-backed DB → the pool owns one read-only connection per thread in an
//!   internal map. Each thread lazily opens its own connection, and dropping
//!   the pool drops the map and closes every connection it owns.
//! - In-memory DB (tests) cannot be reopened by path, so a single connection is
//!   shared behind a mutex.
//!
//! Dropping the pool closes every connection it owns, from whatever thread
//! drops it (#127). A thread that dies while the pool lives leaves its
//! connection in the map until the pool is dropped; that bound is the pool's
//! lifetime, not the thread's. Each pool has its own map, so multiple mounts
//! (or test DBs) on the same thread don't collide.
14
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::thread::ThreadId;
18
19use dashmap::DashMap;
20use parking_lot::ReentrantMutex;
21
22use musefs_db::{Db, ReadOnly};
23
24use crate::error::{CoreError, Result};
25
26/// `with` and `with_poll` may nest freely, on any variant: `PerThread` reads
27/// hand out cloned `Arc`s from the pool-owned map, and the connection locks
28/// are reentrant.
29///
30/// The `poll`/`conns` asymmetry is deliberate. `poll` is uniquely owned (the
31/// `Box` only keeps the variant small) because `with_poll` locks it in place
32/// and takes no other lock. `conns` is a `DashMap`, so `with` never
33/// serializes concurrent reads on a single map lock — a steady-state hit
34/// takes only a shard read lock. Values are `Arc`-wrapped so `with` can clone
35/// a handle and release the shard guard *before* running the caller's closure
36/// — holding a (non-reentrant) shard guard across it would deadlock a nested
37/// `with` whose thread hashes to the same shard. The inner `ReentrantMutex` is
38/// never contended (only its owning thread locks it) but is load-bearing for
39/// the type system: `Db` is `Send + !Sync`, so the mutex wrapper is what keeps
40/// the map values, and therefore `DbPool`, `Send + Sync`.
41pub enum DbPool {
42    PerThread {
43        path: PathBuf,
44        poll: Box<ReentrantMutex<Db<ReadOnly>>>,
45        conns: DashMap<ThreadId, Arc<ReentrantMutex<Db<ReadOnly>>>>,
46    },
47    Shared(Arc<ReentrantMutex<Db<ReadOnly>>>),
48}
49
50impl DbPool {
51    /// Build a pool from the DB used to construct the mount. File-backed DBs
52    /// become per-thread pools (the passed connection becomes the poll
53    /// connection — workers open their own); in-memory DBs are wrapped in a
54    /// shared mutex.
55    pub fn new(db: Db) -> Result<DbPool> {
56        let db = db.into_read_only();
57        match db.path() {
58            Some(p) => Ok(DbPool::PerThread {
59                path: p.to_path_buf(),
60                poll: Box::new(ReentrantMutex::new(db)),
61                conns: DashMap::new(),
62            }),
63            None => Ok(DbPool::Shared(Arc::new(ReentrantMutex::new(db)))),
64        }
65    }
66
67    /// Run `f` with the persistent poll connection.
68    ///
69    /// For `PerThread` pools, `PRAGMA data_version` is connection-relative: a fresh
70    /// thread-local connection starts at 0, so it can't detect changes that happened
71    /// before it opened. The poll connection is the original writer Db, kept alive
72    /// precisely so it can observe incremental changes from other connections.
73    /// For `Shared` pools (in-memory), the single shared connection serves both roles.
74    pub fn with_poll<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
75        match self {
76            DbPool::PerThread { poll, .. } => f(&poll.lock()),
77            DbPool::Shared(m) => f(&m.lock()),
78        }
79    }
80
81    /// Run `f` with a read connection.
82    pub fn with<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
83        match self {
84            DbPool::PerThread { path, conns, .. } => {
85                let tid = std::thread::current().id();
86                // Clone the handle and release every DashMap shard guard before
87                // running `f`: holding a shard guard across the closure would
88                // deadlock a nested `with` whose thread hashes to the same shard.
89                // The steady-state hit takes only a shard *read* lock (`get`), so
90                // concurrent reads never serialize on a single global lock.
91                let db = if let Some(existing) = conns.get(&tid) {
92                    Arc::clone(existing.value())
93                } else {
94                    match conns.entry(tid) {
95                        dashmap::Entry::Occupied(entry) => Arc::clone(entry.get()),
96                        dashmap::Entry::Vacant(entry) => {
97                            let db =
98                                Db::open_readonly(path).map_err(|source| CoreError::DbOpen {
99                                    path: path.clone(),
100                                    source,
101                                })?;
102                            Arc::clone(&entry.insert(Arc::new(ReentrantMutex::new(db))))
103                        }
104                    }
105                };
106                let guard = db.lock();
107                f(&guard)
108            }
109            DbPool::Shared(m) => {
110                let db = m.lock();
111                f(&db)
112            }
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use musefs_db::Db;

    /// Create the database file `name` inside `dir` by opening it once with a
    /// writer connection (create + migrate, sets WAL), close that connection
    /// again, and return the file's path.
    fn seeded_db(dir: &tempfile::TempDir, name: &str) -> std::path::PathBuf {
        let path = dir.path().join(name);
        Db::open(&path).unwrap();
        path
    }

    /// Closure body shared by most tests: read the connection's data version.
    ///
    /// `db.data_version()` returns the DB crate's error type, so it is wrapped
    /// as `Ok(...?)` to convert it into the core `Result` that `with` and
    /// `with_poll` expect.
    fn read_data_version(db: &Db<ReadOnly>) -> Result<i64> {
        Ok(db.data_version()?)
    }

    /// Number of this process's open fds whose link target starts with
    /// `db_path`.
    ///
    /// The prefix match is deliberate: a WAL reader holds up to three fds
    /// (`db`, `db-wal`, `db-shm`). Linux-only, because it reads
    /// `/proc/self/fd`, which FreeBSD has no equivalent for by default; the
    /// two fd-leak tests that rely on it are gated to Linux as well.
    #[cfg(target_os = "linux")]
    fn db_fd_count(db_path: &std::path::Path) -> usize {
        let prefix = db_path.to_str().unwrap();
        std::fs::read_dir("/proc/self/fd")
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .filter(|link| match std::fs::read_link(link) {
                Ok(target) => target.to_string_lossy().starts_with(prefix),
                // Entries whose link cannot be read are simply not counted.
                Err(_) => false,
            })
            .count()
    }

    /// Dropping the pool must close the connections opened by worker threads
    /// even though those threads are still alive.
    // Linux-only: asserts fd closure via `db_fd_count` (/proc/self/fd).
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_closes_connections_opened_by_live_threads() {
        const WORKERS: usize = 2;

        let dir = tempfile::tempdir().unwrap();
        let path = seeded_db(&dir, "d.db");
        let baseline = db_fd_count(&path);

        let pool = Arc::new(DbPool::new(Db::open(&path).unwrap()).unwrap());
        // The workers plus the main thread; workers park on the barrier until
        // main has made its assertion.
        let barrier = Arc::new(std::sync::Barrier::new(WORKERS + 1));
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let workers: Vec<_> = (0..WORKERS)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let barrier = Arc::clone(&barrier);
                let done = done_tx.clone();
                std::thread::spawn(move || {
                    pool.with(read_data_version).unwrap();
                    drop(pool); // this thread's Arc clone; main's is then the last
                    done.send(()).unwrap();
                    barrier.wait();
                })
            })
            .collect();
        drop(done_tx);
        // Receive exactly one signal per worker. Draining until the channel
        // errors would hang: the parked workers still own their sender clones.
        for _ in 0..WORKERS {
            done_rx.recv().unwrap();
        }

        // Every worker has finished with the pool but is still alive: parked
        // at, or on its way to, the barrier, which it cannot pass until main
        // waits on it too.
        drop(pool);
        assert_eq!(
            db_fd_count(&path),
            baseline,
            "pool drop must close all threads' connections while those threads are alive"
        );

        barrier.wait();
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Dropping the pool on a thread that never used it must still close the
    /// connection opened by another thread.
    // Linux-only: asserts fd closure via `db_fd_count` (/proc/self/fd).
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_on_foreign_thread_closes_all_connections() {
        let dir = tempfile::tempdir().unwrap();
        let path = seeded_db(&dir, "x.db");
        let baseline = db_fd_count(&path);

        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // Opens this thread's connection.
        pool.with(read_data_version).unwrap();

        // DbPool is Send: hand it to a thread that never opened a connection
        // and let it be dropped there.
        let dropper = std::thread::spawn(move || drop(pool));
        dropper.join().unwrap();

        assert_eq!(
            db_fd_count(&path),
            baseline,
            "drop on a foreign thread must still close every connection"
        );
    }

    /// An in-memory DB yields a pool that answers repeated reads consistently.
    #[test]
    fn shared_pool_for_in_memory_db() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let first = pool.with(read_data_version).unwrap();
        let second = pool.with(read_data_version).unwrap();
        assert_eq!(first, second);
    }

    /// Two pools used from the same thread each hand out a connection to
    /// their own database file.
    #[test]
    fn same_thread_two_pools_keyed_by_path() {
        let dir = tempfile::tempdir().unwrap();
        let path_a = seeded_db(&dir, "a.db");
        let path_b = seeded_db(&dir, "b.db");

        let pool_a = DbPool::new(Db::open(&path_a).unwrap()).unwrap();
        let pool_b = DbPool::new(Db::open(&path_b).unwrap()).unwrap();

        for (pool, expected) in [(&pool_a, &path_a), (&pool_b, &path_b)] {
            pool.with(|db| {
                assert_eq!(db.path().unwrap(), *expected);
                Ok(())
            })
            .unwrap();
        }
    }

    /// A file-backed pool can be used from a thread other than its creator.
    #[test]
    fn per_thread_pool_for_file_db() {
        let dir = tempfile::tempdir().unwrap();
        let path = seeded_db(&dir, "m.db");
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // Used from a different thread: that thread opens its own read connection.
        let version = std::thread::scope(|scope| {
            let worker = scope.spawn(|| pool.with(read_data_version).unwrap());
            worker.join().unwrap()
        });
        assert!(version >= 0);
    }

    /// `with` nested inside `with` on a file-backed pool succeeds.
    #[test]
    fn reentrant_with_does_not_panic() {
        let dir = tempfile::tempdir().unwrap();
        let path = seeded_db(&dir, "re.db");
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let nested: Result<i64> = pool.with(|_outer| pool.with(read_data_version));
        assert!(nested.is_ok(), "re-entrant with() must not panic or error");
    }

    /// A failure to open a thread's connection reports the offending path.
    #[test]
    fn with_open_failure_includes_path_in_error() {
        const BAD_PATH: &str = "/nonexistent-musefs-dir/does-not-exist.db";

        // Built by hand so the pool's path can point at a file that cannot be
        // opened, while the poll connection is still a valid in-memory one.
        let pool = DbPool::PerThread {
            path: std::path::PathBuf::from(BAD_PATH),
            poll: Box::new(ReentrantMutex::new(
                Db::open_in_memory().unwrap().into_read_only(),
            )),
            conns: DashMap::new(),
        };
        let err = pool.with(|_db| Ok(())).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains(BAD_PATH),
            "open error must name the failing path, got: {msg}"
        );
    }

    /// `with` nested inside `with` on an in-memory pool succeeds.
    #[test]
    fn nested_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let nested: Result<i64> = pool.with(|_outer| pool.with(read_data_version));
        assert!(nested.is_ok(), "nested with on Shared must not deadlock");
    }

    /// `with_poll` nested inside `with` on an in-memory pool succeeds.
    #[test]
    fn with_poll_inside_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let nested: Result<i64> = pool.with(|_outer| pool.with_poll(read_data_version));
        assert!(
            nested.is_ok(),
            "with_poll inside with on Shared must not deadlock"
        );
    }

    /// `with_poll` nested inside `with_poll` on a file-backed pool succeeds.
    #[test]
    fn nested_with_poll_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = seeded_db(&dir, "np.db");
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let nested: Result<i64> = pool.with_poll(|_outer| pool.with_poll(read_data_version));
        assert!(
            nested.is_ok(),
            "nested with_poll on PerThread must not deadlock"
        );
    }
}