1use std::collections::{HashMap, hash_map::Entry};
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::thread::ThreadId;
19
20use parking_lot::{Mutex, ReentrantMutex};
21
22use musefs_db::{Db, ReadOnly};
23
24use crate::error::{CoreError, Result};
25
/// A pool of read-only [`Db`] connections.
///
/// [`DbPool::new`] picks the variant from `db.path()`:
/// * `Some(path)` → [`DbPool::PerThread`];
/// * `None` → [`DbPool::Shared`].
pub enum DbPool {
    /// Database with a filesystem path: one long-lived connection plus lazily opened
    /// per-thread connections.
    PerThread {
        /// Path from which `DbPool::with` opens each thread's read-only connection.
        path: PathBuf,
        /// The (read-only-converted) connection handed to `DbPool::new`; used by
        /// `DbPool::with_poll`.
        // NOTE(review): boxed, presumably to keep this variant's size down — confirm.
        poll: Box<ReentrantMutex<Db<ReadOnly>>>,
        /// Connections used by `DbPool::with`, keyed by the id of the thread that opened them.
        conns: Mutex<HashMap<ThreadId, Arc<ReentrantMutex<Db<ReadOnly>>>>>,
    },
    /// Database without a path (presumably in-memory — confirm against `musefs_db`).
    ///
    /// A single connection serves both `with` and `with_poll` for every thread.
    Shared(Arc<ReentrantMutex<Db<ReadOnly>>>),
}
47
48impl DbPool {
49 pub fn new(db: Db) -> Result<DbPool> {
54 let db = db.into_read_only();
55 match db.path() {
56 Some(p) => Ok(DbPool::PerThread {
57 path: p.to_path_buf(),
58 poll: Box::new(ReentrantMutex::new(db)),
59 conns: Mutex::new(HashMap::new()),
60 }),
61 None => Ok(DbPool::Shared(Arc::new(ReentrantMutex::new(db)))),
62 }
63 }
64
65 pub fn with_poll<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
73 match self {
74 DbPool::PerThread { poll, .. } => f(&poll.lock()),
75 DbPool::Shared(m) => f(&m.lock()),
76 }
77 }
78
79 pub fn with<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
81 match self {
82 DbPool::PerThread { path, conns, .. } => {
83 let tid = std::thread::current().id();
84 let db = {
85 let mut map = conns.lock();
86 match map.entry(tid) {
87 Entry::Occupied(entry) => Arc::clone(entry.get()),
88 Entry::Vacant(entry) => {
89 let db =
90 Db::open_readonly(path).map_err(|source| CoreError::DbOpen {
91 path: path.clone(),
92 source,
93 })?;
94 Arc::clone(entry.insert(Arc::new(ReentrantMutex::new(db))))
95 }
96 }
97 };
98 let guard = db.lock();
99 f(&guard)
100 }
101 DbPool::Shared(m) => {
102 let db = m.lock();
103 f(&db)
104 }
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use musefs_db::Db;

    /// Counts this process's open file descriptors whose `/proc/self/fd` link target starts
    /// with `db_path`.
    ///
    /// * Matching is by string prefix, so files named `<db_path>…` (e.g. SQLite sidecar files,
    ///   if any are open) are counted too.
    /// * Descriptors whose link cannot be read are skipped (`.ok()`).
    /// * Linux-only because it reads `/proc/self/fd`.
    #[cfg(target_os = "linux")]
    fn db_fd_count(db_path: &std::path::Path) -> usize {
        let prefix = db_path.to_str().unwrap();
        std::fs::read_dir("/proc/self/fd")
            .unwrap()
            .filter_map(|e| std::fs::read_link(e.unwrap().path()).ok())
            .filter(|target| target.to_string_lossy().starts_with(prefix))
            .count()
    }

    /// Dropping the pool must close the per-thread connections of threads that are still
    /// alive.
    ///
    /// Only the number of descriptors matters, not which thread opened them.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_closes_connections_opened_by_live_threads() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("d.db");
        // The connection opened here is a temporary dropped at the end of the statement, so
        // `baseline` is measured with it already closed.
        Db::open(&path).unwrap();
        let baseline = db_fd_count(&path);

        let pool = Arc::new(DbPool::new(Db::open(&path).unwrap()).unwrap());
        // Two workers plus the main thread. Workers park on this barrier so they are still
        // alive when the main thread checks the descriptor count.
        let barrier = Arc::new(std::sync::Barrier::new(3));
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        let mut handles = Vec::new();
        for _ in 0..2 {
            let pool = Arc::clone(&pool);
            let barrier = Arc::clone(&barrier);
            let done = done_tx.clone();
            handles.push(std::thread::spawn(move || {
                // Forces this worker's per-thread connection to be opened.
                pool.with(|db| Ok(db.data_version()?)).unwrap();
                // Release this worker's handle BEFORE signalling, so the main thread's
                // handle is the last one once both signals arrive.
                drop(pool);
                done.send(()).unwrap();
                barrier.wait();
            }));
        }
        // Drop the original sender so only the workers' clones remain.
        drop(done_tx);
        for _ in 0..2 {
            done_rx.recv().unwrap();
        }

        // Main now holds the only `Arc<DbPool>`, so this drops the pool itself.
        drop(pool);
        assert_eq!(
            db_fd_count(&path),
            baseline,
            "pool drop must close all threads' connections while those threads are alive"
        );

        // Let the workers exit, then join them.
        barrier.wait();
        for h in handles {
            h.join().unwrap();
        }
    }

    /// Dropping the pool on a thread other than the one that opened a per-thread connection
    /// must still close every connection.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_on_foreign_thread_closes_all_connections() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("x.db");
        Db::open(&path).unwrap();
        let baseline = db_fd_count(&path);

        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // Opens this thread's per-thread connection in addition to the `poll` connection.
        pool.with(|db| Ok(db.data_version()?)).unwrap();
        // Move the pool to a fresh thread and drop it there.
        std::thread::spawn(move || drop(pool)).join().unwrap();

        assert_eq!(
            db_fd_count(&path),
            baseline,
            "drop on a foreign thread must still close every connection"
        );
    }

    /// An in-memory database (presumably path-less, hence `Shared`) serves repeated `with`
    /// calls.
    ///
    /// Two reads with no intervening writes report the same data version.
    #[test]
    fn shared_pool_for_in_memory_db() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let v = pool.with(|db| Ok(db.data_version()?)).unwrap();
        let v2 = pool.with(|db| Ok(db.data_version()?)).unwrap();
        assert_eq!(v, v2);
    }

    /// Two pools used from the same thread must each hand out a connection to their own file.
    ///
    /// Per-thread caching is per pool, not global.
    #[test]
    fn same_thread_two_pools_keyed_by_path() {
        let dir = tempfile::tempdir().unwrap();
        let path_a = dir.path().join("a.db");
        let path_b = dir.path().join("b.db");
        Db::open(&path_a).unwrap();
        Db::open(&path_b).unwrap();

        let pool_a = DbPool::new(Db::open(&path_a).unwrap()).unwrap();
        let pool_b = DbPool::new(Db::open(&path_b).unwrap()).unwrap();

        pool_a
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_a);
                Ok(())
            })
            .unwrap();
        pool_b
            .with(|db| {
                assert_eq!(db.path().unwrap(), path_b);
                Ok(())
            })
            .unwrap();
    }

    /// A file-backed pool can open and use a per-thread connection from a spawned thread.
    #[test]
    fn per_thread_pool_for_file_db() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("m.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        // A scoped thread lets the closure borrow `pool` without an `Arc`.
        let r = std::thread::scope(|s| {
            s.spawn(|| pool.with(|db| Ok(db.data_version()?)).unwrap())
                .join()
                .unwrap()
        });
        assert!(r >= 0);
    }

    /// Nested `with` on the same thread of a `PerThread` pool re-locks the thread's own
    /// `ReentrantMutex` instead of deadlocking.
    #[test]
    fn reentrant_with_does_not_panic() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("re.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "re-entrant with() must not panic or error");
    }

    /// When a per-thread open fails, the error message must name the path.
    ///
    /// The pool is built by hand with an unopenable `path`. `poll` is an in-memory
    /// placeholder that `with` never touches.
    #[test]
    fn with_open_failure_includes_path_in_error() {
        let bad = std::path::PathBuf::from("/nonexistent-musefs-dir/does-not-exist.db");
        let pool = DbPool::PerThread {
            path: bad.clone(),
            poll: Box::new(ReentrantMutex::new(
                Db::open_in_memory().unwrap().into_read_only(),
            )),
            conns: Mutex::new(HashMap::new()),
        };
        let msg = pool.with(|_db| Ok(())).unwrap_err().to_string();
        assert!(
            msg.contains("/nonexistent-musefs-dir/does-not-exist.db"),
            "open error must name the failing path, got: {msg}"
        );
    }

    /// Nested `with` on a `Shared` pool re-locks the same `ReentrantMutex` on the same thread.
    #[test]
    fn nested_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with on Shared must not deadlock");
    }

    /// On a `Shared` pool, `with` and `with_poll` lock the same mutex.
    ///
    /// Nesting them on one thread relies on that mutex being re-entrant.
    #[test]
    fn with_poll_inside_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let r: Result<i64> = pool.with(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(
            r.is_ok(),
            "with_poll inside with on Shared must not deadlock"
        );
    }

    /// Nested `with_poll` on a `PerThread` pool re-locks the `poll` `ReentrantMutex`.
    #[test]
    fn nested_with_poll_on_per_thread_pool() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("np.db");
        Db::open(&path).unwrap();
        let pool = DbPool::new(Db::open(&path).unwrap()).unwrap();
        let r: Result<i64> = pool.with_poll(|_outer| pool.with_poll(|db| Ok(db.data_version()?)));
        assert!(r.is_ok(), "nested with_poll on PerThread must not deadlock");
    }
}