1use std::path::PathBuf;
16use std::sync::Arc;
17use std::thread::ThreadId;
18
19use dashmap::DashMap;
20use parking_lot::ReentrantMutex;
21
22use musefs_db::{Db, ReadOnly};
23
24use crate::error::{CoreError, Result};
25
/// Pool of read-only database connections.
///
/// Built by [`DbPool::new`], which picks the variant from whether the database
/// handle it is given has a filesystem path.
pub enum DbPool {
    /// File-backed database: each calling thread gets its own connection.
    PerThread {
        /// Location of the database file; per-thread connections are opened
        /// from this path with `Db::open_readonly`.
        path: PathBuf,
        /// The (read-only) connection passed to `DbPool::new`, reached through
        /// `with_poll` and guarded by a reentrant mutex.
        /// NOTE(review): the name suggests it is reserved for change polling
        /// (e.g. `data_version`), and the `Box` presumably keeps this variant
        /// small — confirm both.
        poll: Box<ReentrantMutex<Db<ReadOnly>>>,
        /// Lazily opened connections, one per thread that has called `with`,
        /// keyed by that thread's `ThreadId`. `DbPool`'s methods never remove
        /// an entry when its thread exits; the connections are closed when the
        /// map is dropped together with the pool.
        conns: DashMap<ThreadId, Arc<ReentrantMutex<Db<ReadOnly>>>>,
    },
    /// Database without a path (e.g. one from `Db::open_in_memory`): a single
    /// connection shared by every thread and serialised by a reentrant mutex.
    Shared(Arc<ReentrantMutex<Db<ReadOnly>>>),
}
49
50impl DbPool {
51 pub fn new(db: Db) -> Result<DbPool> {
56 let db = db.into_read_only();
57 match db.path() {
58 Some(p) => Ok(DbPool::PerThread {
59 path: p.to_path_buf(),
60 poll: Box::new(ReentrantMutex::new(db)),
61 conns: DashMap::new(),
62 }),
63 None => Ok(DbPool::Shared(Arc::new(ReentrantMutex::new(db)))),
64 }
65 }
66
67 pub fn with_poll<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
75 match self {
76 DbPool::PerThread { poll, .. } => f(&poll.lock()),
77 DbPool::Shared(m) => f(&m.lock()),
78 }
79 }
80
81 pub fn with<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
83 match self {
84 DbPool::PerThread { path, conns, .. } => {
85 let tid = std::thread::current().id();
86 let db = if let Some(existing) = conns.get(&tid) {
92 Arc::clone(existing.value())
93 } else {
94 match conns.entry(tid) {
95 dashmap::Entry::Occupied(entry) => Arc::clone(entry.get()),
96 dashmap::Entry::Vacant(entry) => {
97 let db =
98 Db::open_readonly(path).map_err(|source| CoreError::DbOpen {
99 path: path.clone(),
100 source,
101 })?;
102 Arc::clone(&entry.insert(Arc::new(ReentrantMutex::new(db))))
103 }
104 }
105 };
106 let guard = db.lock();
107 f(&guard)
108 }
109 DbPool::Shared(m) => {
110 let db = m.lock();
111 f(&db)
112 }
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use musefs_db::Db;

    /// Counts this process's open file descriptors that point at `db_path`, or
    /// at any file whose path starts with it (e.g. `-wal`, `-shm`, `-journal`).
    ///
    /// `/proc/self/fd` links resolve to canonical paths, so the prefix is
    /// canonicalised first. Otherwise a temp dir reached through a symlink
    /// would make every count zero, and the fd tests would pass vacuously.
    /// `db_path` must already exist.
    #[cfg(target_os = "linux")]
    fn db_fd_count(db_path: &std::path::Path) -> usize {
        let canonical = std::fs::canonicalize(db_path).unwrap();
        let prefix = canonical.to_string_lossy().into_owned();
        std::fs::read_dir("/proc/self/fd")
            .unwrap()
            .filter_map(|e| std::fs::read_link(e.unwrap().path()).ok())
            .filter(|target| target.to_string_lossy().starts_with(prefix.as_str()))
            .count()
    }

    /// Connections opened by worker threads must be closed when the pool is
    /// dropped, even while those threads are still alive.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_closes_connections_opened_by_live_threads() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("d.db");
        // Create the file, then measure with no connection open.
        Db::open(&path).unwrap();
        let baseline = db_fd_count(&path);

        let pool = Arc::new(DbPool::new(Db::open(&path).unwrap()).unwrap());
        // Two workers plus this thread. The workers park here so they are
        // still alive when the pool is dropped.
        let barrier = Arc::new(std::sync::Barrier::new(3));
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let mut handles = Vec::new();
        for _ in 0..2 {
            let pool = Arc::clone(&pool);
            let barrier = Arc::clone(&barrier);
            let done = done_tx.clone();
            handles.push(std::thread::spawn(move || {
                pool.with(|db| Ok(db.data_version()?)).unwrap();
                drop(pool);
                done.send(()).unwrap();
                barrier.wait();
            }));
        }
        drop(done_tx);
        // Wait until both workers have opened their connection and released
        // their handle on the pool.
        for _ in 0..2 {
            done_rx.recv().unwrap();
        }
        assert!(
            db_fd_count(&path) > baseline,
            "fd counting must observe the pool's open connections"
        );

        // `pool` is now the last handle, so this drops the pool itself.
        drop(pool);
        assert_eq!(
            db_fd_count(&path),
            baseline,
            "pool drop must close all threads' connections while those threads are alive"
        );

        barrier.wait();
        for h in handles {
            h.join().unwrap();
        }
    }

    /// Dropping the pool on a thread that never used it still closes every
    /// connection, including the one cached for the creating thread.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_on_foreign_thread_closes_all_connections() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("x.db");
        Db::open(&path).unwrap();
        let baseline = db_fd_count(&path);

        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // Opens this thread's own connection in addition to `poll`.
        pool.with(|db| Ok(db.data_version()?)).unwrap();
        assert!(
            db_fd_count(&path) > baseline,
            "fd counting must observe the pool's open connections"
        );
        std::thread::spawn(move || drop(pool)).join().unwrap();

        assert_eq!(
            db_fd_count(&path),
            baseline,
            "drop on a foreign thread must still close every connection"
        );
    }

    #[test]
    fn shared_pool_for_in_memory_db() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        assert!(
            matches!(pool, DbPool::Shared(_)),
            "in-memory db must use a Shared pool"
        );
        let v = pool.with(|db| Ok(db.data_version()?)).unwrap();
        let v2 = pool.with(|db| Ok(db.data_version()?)).unwrap();
        assert_eq!(v, v2);
    }

    #[test]
    fn same_thread_two_pools_keyed_by_path() {
        let dir = tempfile::tempdir().unwrap();
        let path_a = dir.path().join("a.db");
        let path_b = dir.path().join("b.db");
        Db::open(&path_a).unwrap();
        Db::open(&path_b).unwrap();

        let pool_a = DbPool::new(Db::open(&path_a).unwrap()).unwrap();
        let pool_b = DbPool::new(Db::open(&path_b).unwrap()).unwrap();

        pool_a
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_a);
                Ok(())
            })
            .unwrap();
        pool_b
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_b);
                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn per_thread_pool_for_file_db() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("m.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r = std::thread::scope(|s| {
            s.spawn(|| pool.with(|db| Ok(db.data_version()?)).unwrap())
                .join()
                .unwrap()
        });
        assert!(r >= 0);
        match &pool {
            DbPool::PerThread { conns, .. } => assert_eq!(
                conns.len(),
                1,
                "the spawned thread must have cached its own connection"
            ),
            DbPool::Shared(_) => panic!("file-backed db must use a PerThread pool"),
        }
    }

    #[test]
    fn reentrant_with_does_not_panic() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("re.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "re-entrant with() must not panic or error");
    }

    #[test]
    fn with_open_failure_includes_path_in_error() {
        let bad = std::path::PathBuf::from("/nonexistent-musefs-dir/does-not-exist.db");
        let pool = DbPool::PerThread {
            path: bad.clone(),
            poll: Box::new(ReentrantMutex::new(
                Db::open_in_memory().unwrap().into_read_only(),
            )),
            conns: DashMap::new(),
        };
        let msg = pool.with(|_db| Ok(())).unwrap_err().to_string();
        assert!(
            msg.contains("/nonexistent-musefs-dir/does-not-exist.db"),
            "open error must name the failing path, got: {msg}"
        );
    }

    #[test]
    fn nested_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with on Shared must not deadlock");
    }

    #[test]
    fn with_poll_inside_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(
            r.is_ok(),
            "with_poll inside with on Shared must not deadlock"
        );
    }

    #[test]
    fn nested_with_poll_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("np.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with_poll(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with_poll on PerThread must not deadlock");
    }
}