Skip to main content

musefs_core/
db_pool.rs

//! Hands a read connection to whichever thread needs one.
//!
//! - File-backed DB → the pool owns one read-only connection per thread in an
//!   internal map, plus a dedicated poll connection (the one the pool was
//!   built from). Each thread lazily opens its own connection, and dropping
//!   the pool drops the map and closes every connection it owns.
//! - In-memory DB (tests) cannot be reopened by path, so a single connection is
//!   shared behind a (reentrant) mutex.
//!
//! Dropping the pool closes every connection it owns, from whatever thread
//! drops it (#127). A thread that exits while the pool lives leaves its
//! connection in the map until the pool is dropped: a cached connection's
//! lifetime is bounded by the pool's, not by its thread's. Each pool has its
//! own map, so multiple mounts (or test DBs) on the same thread don't collide.

15use std::collections::{HashMap, hash_map::Entry};
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::thread::ThreadId;
19
20use parking_lot::{Mutex, ReentrantMutex};
21
22use musefs_db::{Db, ReadOnly};
23
24use crate::error::{CoreError, Result};
25
26/// `with` and `with_poll` may nest freely, on any variant: `PerThread` reads
27/// hand out cloned `Arc`s from the pool-owned map, and the connection locks
28/// are reentrant.
29///
30/// The `poll`/`conns` asymmetry is deliberate. `poll` is uniquely owned (the
31/// `Box` only keeps the variant small) because `with_poll` locks it in place
32/// and takes no other lock. `conns` values are `Arc`-wrapped so `with` can
33/// clone a handle and release the map guard *before* running the caller's
34/// closure — holding the (non-reentrant) map guard across it would deadlock
35/// nested `with`. The inner `ReentrantMutex` is never contended (only its
36/// owning thread locks it) but is load-bearing for the type system: `Db` is
37/// `Send + !Sync`, so the mutex wrapper is what keeps the map field, and
38/// therefore `DbPool`, `Send + Sync`.
39pub enum DbPool {
40    PerThread {
41        path: PathBuf,
42        poll: Box<ReentrantMutex<Db<ReadOnly>>>,
43        conns: Mutex<HashMap<ThreadId, Arc<ReentrantMutex<Db<ReadOnly>>>>>,
44    },
45    Shared(Arc<ReentrantMutex<Db<ReadOnly>>>),
46}
47
48impl DbPool {
49    /// Build a pool from the DB used to construct the mount. File-backed DBs
50    /// become per-thread pools (the passed connection becomes the poll
51    /// connection — workers open their own); in-memory DBs are wrapped in a
52    /// shared mutex.
53    pub fn new(db: Db) -> Result<DbPool> {
54        let db = db.into_read_only();
55        match db.path() {
56            Some(p) => Ok(DbPool::PerThread {
57                path: p.to_path_buf(),
58                poll: Box::new(ReentrantMutex::new(db)),
59                conns: Mutex::new(HashMap::new()),
60            }),
61            None => Ok(DbPool::Shared(Arc::new(ReentrantMutex::new(db)))),
62        }
63    }
64
65    /// Run `f` with the persistent poll connection.
66    ///
67    /// For `PerThread` pools, `PRAGMA data_version` is connection-relative: a fresh
68    /// thread-local connection starts at 0, so it can't detect changes that happened
69    /// before it opened. The poll connection is the original writer Db, kept alive
70    /// precisely so it can observe incremental changes from other connections.
71    /// For `Shared` pools (in-memory), the single shared connection serves both roles.
72    pub fn with_poll<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
73        match self {
74            DbPool::PerThread { poll, .. } => f(&poll.lock()),
75            DbPool::Shared(m) => f(&m.lock()),
76        }
77    }
78
79    /// Run `f` with a read connection.
80    pub fn with<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
81        match self {
82            DbPool::PerThread { path, conns, .. } => {
83                let tid = std::thread::current().id();
84                let db = {
85                    let mut map = conns.lock();
86                    match map.entry(tid) {
87                        Entry::Occupied(entry) => Arc::clone(entry.get()),
88                        Entry::Vacant(entry) => {
89                            let db =
90                                Db::open_readonly(path).map_err(|source| CoreError::DbOpen {
91                                    path: path.clone(),
92                                    source,
93                                })?;
94                            Arc::clone(entry.insert(Arc::new(ReentrantMutex::new(db))))
95                        }
96                    }
97                };
98                let guard = db.lock();
99                f(&guard)
100            }
101            DbPool::Shared(m) => {
102                let db = m.lock();
103                f(&db)
104            }
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use musefs_db::Db;

    /// Count this process's open fds whose target path starts with `db_path`.
    /// Prefix match deliberately: a WAL reader holds up to three fds
    /// (`db`, `db-wal`, `db-shm`). Linux-only: it reads `/proc/self/fd`, which
    /// FreeBSD has no equivalent for by default — so the two fd-leak tests that
    /// use it are gated to Linux as well.
    #[cfg(target_os = "linux")]
    fn db_fd_count(db_path: &std::path::Path) -> usize {
        let prefix = db_path.to_str().unwrap();
        std::fs::read_dir("/proc/self/fd")
            .unwrap()
            .filter_map(|e| std::fs::read_link(e.unwrap().path()).ok())
            .filter(|target| target.to_string_lossy().starts_with(prefix))
            .count()
    }

    /// Threads that currently have a cached connection in a `PerThread` pool.
    /// Panics on a `Shared` pool, which has no per-thread map.
    fn cached_threads(pool: &DbPool) -> Vec<ThreadId> {
        let DbPool::PerThread { conns, .. } = pool else {
            panic!("expected a PerThread pool");
        };
        conns.lock().keys().copied().collect()
    }

    // Linux-only: asserts fd closure via `db_fd_count` (/proc/self/fd).
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_closes_connections_opened_by_live_threads() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("d.db");
        Db::open(&path).unwrap(); // create + migrate (writer, sets WAL)
        let baseline = db_fd_count(&path);

        let pool = Arc::new(DbPool::new(Db::open(&path).unwrap()).unwrap());
        // 2 workers + the main thread; workers park here until main has asserted.
        let barrier = Arc::new(std::sync::Barrier::new(3));
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let mut handles = Vec::new();
        for _ in 0..2 {
            let pool = Arc::clone(&pool);
            let barrier = Arc::clone(&barrier);
            let done = done_tx.clone();
            handles.push(std::thread::spawn(move || {
                pool.with(|db| Ok(db.data_version()?)).unwrap();
                drop(pool); // this thread's Arc clone; main's is then the last
                done.send(()).unwrap();
                barrier.wait();
            }));
        }
        // Count exactly two done-signals; don't drain-until-Err — the workers
        // still hold their sender clones while parked at the barrier.
        drop(done_tx);
        for _ in 0..2 {
            done_rx.recv().unwrap();
        }

        // Both workers are done using the pool but still alive (parked at, or
        // headed to, the barrier — they cannot pass it until main waits too).
        drop(pool);
        assert_eq!(
            db_fd_count(&path),
            baseline,
            "pool drop must close all threads' connections while those threads are alive"
        );

        barrier.wait();
        for h in handles {
            h.join().unwrap();
        }
    }

    // Linux-only: asserts fd closure via `db_fd_count` (/proc/self/fd).
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_on_foreign_thread_closes_all_connections() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("x.db");
        Db::open(&path).unwrap();
        let baseline = db_fd_count(&path);

        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        pool.with(|db| Ok(db.data_version()?)).unwrap(); // opens this thread's connection

        // DbPool is Send: drop it on a thread that never opened a connection.
        std::thread::spawn(move || drop(pool)).join().unwrap();

        assert_eq!(
            db_fd_count(&path),
            baseline,
            "drop on a foreign thread must still close every connection"
        );
    }

    #[test]
    fn shared_pool_for_in_memory_db() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        // NOTE: db.data_version() returns the DB crate's error type, so wrap with
        // Ok(...?) to convert it into the core Result the closure must return.
        let v = pool.with(|db| Ok(db.data_version()?)).unwrap();
        let v2 = pool.with(|db| Ok(db.data_version()?)).unwrap();
        assert_eq!(v, v2);
    }

    #[test]
    fn same_thread_two_pools_do_not_collide() {
        let dir = tempfile::tempdir().unwrap();
        let path_a = dir.path().join("a.db");
        let path_b = dir.path().join("b.db");
        Db::open(&path_a).unwrap();
        Db::open(&path_b).unwrap();

        let pool_a = DbPool::new(Db::open(&path_a).unwrap()).unwrap();
        let pool_b = DbPool::new(Db::open(&path_b).unwrap()).unwrap();

        pool_a
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_a);
                Ok(())
            })
            .unwrap();
        pool_b
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_b);
                Ok(())
            })
            .unwrap();

        // Each pool owns its own map, so each holds exactly this thread's entry.
        let me = std::thread::current().id();
        assert_eq!(cached_threads(&pool_a), vec![me]);
        assert_eq!(cached_threads(&pool_b), vec![me]);
    }

    #[test]
    fn per_thread_pool_for_file_db() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("m.db");
        Db::open(&path).unwrap(); // create + migrate (writer, sets WAL)
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // Used from a different thread: that thread opens its own read
        // connection, which must be cached under that thread's id only.
        let worker = std::thread::scope(|s| {
            s.spawn(|| {
                pool.with(|db| {
                    assert_eq!(db.path().unwrap(), path);
                    Ok(())
                })
                .unwrap();
                std::thread::current().id()
            })
            .join()
            .unwrap()
        });
        assert_eq!(
            cached_threads(&pool),
            vec![worker],
            "only the worker thread should have opened a connection"
        );
    }

    #[test]
    fn reentrant_with_does_not_panic() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("re.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "re-entrant with() must not panic or error");
    }

    #[test]
    fn with_open_failure_includes_path_in_error() {
        let bad = std::path::PathBuf::from("/nonexistent-musefs-dir/does-not-exist.db");
        let pool = DbPool::PerThread {
            path: bad.clone(),
            poll: Box::new(ReentrantMutex::new(
                Db::open_in_memory().unwrap().into_read_only(),
            )),
            conns: Mutex::new(HashMap::new()),
        };
        let msg = pool.with(|_db| Ok(())).unwrap_err().to_string();
        assert!(
            msg.contains("/nonexistent-musefs-dir/does-not-exist.db"),
            "open error must name the failing path, got: {msg}"
        );
        assert!(
            cached_threads(&pool).is_empty(),
            "a failed open must not cache anything, so the next call retries"
        );
    }

    #[test]
    fn nested_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with on Shared must not deadlock");
    }

    #[test]
    fn with_poll_inside_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(
            r.is_ok(),
            "with_poll inside with on Shared must not deadlock"
        );
    }

    #[test]
    fn nested_with_poll_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("np.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with_poll(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with_poll on PerThread must not deadlock");
    }

    #[test]
    fn with_poll_inside_with_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("pw.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(
            r.is_ok(),
            "with_poll inside with on PerThread must not deadlock"
        );
    }

    #[test]
    fn with_inside_with_poll_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("wp.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with_poll(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(
            r.is_ok(),
            "with inside with_poll on PerThread must not deadlock"
        );
    }
}