use std::path::PathBuf;
use std::sync::Arc;
use std::thread::ThreadId;
use dashmap::DashMap;
use parking_lot::ReentrantMutex;
use musefs_db::{Db, ReadOnly};
use crate::error::{CoreError, Result};
/// Pool of read-only [`Db`] handles, each guarded by a [`ReentrantMutex`].
///
/// [`DbPool::new`] picks the variant from whether the database reports a path.
pub enum DbPool {
/// Used when the database reports a path: each thread calling
/// [`DbPool::with`] is given a connection of its own.
PerThread {
/// Path handed to `Db::open_readonly` the first time a thread calls [`DbPool::with`].
path: PathBuf,
/// The handle the pool was constructed from; [`DbPool::with_poll`] runs against it.
/// NOTE(review): the name suggests it is reserved for change polling — confirm with callers.
poll: Box<ReentrantMutex<Db<ReadOnly>>>,
/// Connections opened by [`DbPool::with`], keyed by the id of the calling thread.
/// Nothing in the visible code removes entries, so a connection appears to stay
/// here until the pool itself is dropped.
conns: DashMap<ThreadId, Arc<ReentrantMutex<Db<ReadOnly>>>>,
},
/// Used when the database reports no path: every caller, on every thread,
/// goes through this single handle.
Shared(Arc<ReentrantMutex<Db<ReadOnly>>>),
}
impl DbPool {
pub fn new(db: Db) -> Result<DbPool> {
let db = db.into_read_only();
match db.path() {
Some(p) => Ok(DbPool::PerThread {
path: p.to_path_buf(),
poll: Box::new(ReentrantMutex::new(db)),
conns: DashMap::new(),
}),
None => Ok(DbPool::Shared(Arc::new(ReentrantMutex::new(db)))),
}
}
pub fn with_poll<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
match self {
DbPool::PerThread { poll, .. } => f(&poll.lock()),
DbPool::Shared(m) => f(&m.lock()),
}
}
pub fn with<R>(&self, f: impl FnOnce(&Db<ReadOnly>) -> Result<R>) -> Result<R> {
match self {
DbPool::PerThread { path, conns, .. } => {
let tid = std::thread::current().id();
let db = if let Some(existing) = conns.get(&tid) {
Arc::clone(existing.value())
} else {
match conns.entry(tid) {
dashmap::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::Entry::Vacant(entry) => {
let db =
Db::open_readonly(path).map_err(|source| CoreError::DbOpen {
path: path.clone(),
source,
})?;
Arc::clone(&entry.insert(Arc::new(ReentrantMutex::new(db))))
}
}
};
let guard = db.lock();
f(&guard)
}
DbPool::Shared(m) => {
let db = m.lock();
f(&db)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use musefs_db::Db;

    /// Counts this process's open descriptors whose link target starts with
    /// `db_path`.
    ///
    /// Matching is by string prefix, so any descriptor whose target merely
    /// begins with the database path is counted too.
    #[cfg(target_os = "linux")]
    fn db_fd_count(db_path: &std::path::Path) -> usize {
        let prefix = db_path.to_str().unwrap();
        let mut matching = 0;
        for entry in std::fs::read_dir("/proc/self/fd").unwrap() {
            let fd_link = entry.unwrap().path();
            // Descriptors whose link cannot be read are skipped, not treated as failures.
            if let Ok(target) = std::fs::read_link(fd_link) {
                if target.to_string_lossy().starts_with(prefix) {
                    matching += 1;
                }
            }
        }
        matching
    }

    /// Dropping the last pool handle must close connections that were opened
    /// by threads which are still running.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_closes_connections_opened_by_live_threads() {
        use std::sync::{mpsc, Barrier};

        const WORKERS: usize = 2;

        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("d.db");
        // This handle is a temporary and is dropped at the end of the statement,
        // before the baseline is taken.
        Db::open(&db_file).unwrap();
        let baseline = db_fd_count(&db_file);

        let pool = Arc::new(DbPool::new(Db::open(&db_file).unwrap()).unwrap());
        // The barrier is sized for the workers plus the test thread; it keeps
        // the workers alive until the assertion below has run.
        let gate = Arc::new(Barrier::new(WORKERS + 1));
        let (ready_tx, ready_rx) = mpsc::channel();

        let workers: Vec<_> = (0..WORKERS)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let gate = Arc::clone(&gate);
                let ready = ready_tx.clone();
                std::thread::spawn(move || {
                    pool.with(|db| Ok(db.data_version()?)).unwrap();
                    drop(pool);
                    ready.send(()).unwrap();
                    gate.wait();
                })
            })
            .collect();
        drop(ready_tx);

        // Wait until every worker has used the pool and released its handle.
        for _ in 0..WORKERS {
            ready_rx.recv().unwrap();
        }

        drop(pool);
        assert_eq!(
            db_fd_count(&db_file),
            baseline,
            "pool drop must close all threads' connections while those threads are alive"
        );

        gate.wait();
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Dropping the pool on a thread other than the one that used it must
    /// still close every connection.
    #[cfg(target_os = "linux")]
    #[test]
    fn drop_on_foreign_thread_closes_all_connections() {
        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("x.db");
        Db::open(&db_file).unwrap();
        let baseline = db_fd_count(&db_file);

        let pool = DbPool::new(Db::open(&db_file).unwrap()).unwrap();
        pool.with(|db| Ok(db.data_version()?)).unwrap();

        let dropper = std::thread::spawn(move || drop(pool));
        dropper.join().unwrap();

        assert_eq!(
            db_fd_count(&db_file),
            baseline,
            "drop on a foreign thread must still close every connection"
        );
    }

    /// Two reads through a pool built from an in-memory database agree.
    #[test]
    fn shared_pool_for_in_memory_db() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let read_version = || pool.with(|db| Ok(db.data_version()?)).unwrap();
        let first = read_version();
        let second = read_version();
        assert_eq!(first, second);
    }

    /// Two pools used from one thread each hand out a connection to their own file.
    #[test]
    fn same_thread_two_pools_keyed_by_path() {
        let tmp = tempfile::tempdir().unwrap();
        let path_a = tmp.path().join("a.db");
        let path_b = tmp.path().join("b.db");
        Db::open(&path_a).unwrap();
        Db::open(&path_b).unwrap();
        let pool_a = DbPool::new(Db::open(&path_a).unwrap()).unwrap();
        let pool_b = DbPool::new(Db::open(&path_b).unwrap()).unwrap();

        // Shared check: the connection handed out by `pool` reports `want`.
        let expect_path = |pool: &DbPool, want: &std::path::PathBuf| {
            pool.with(|db| {
                assert_eq!(db.path().unwrap(), *want);
                Ok(())
            })
            .unwrap();
        };
        expect_path(&pool_a, &path_a);
        expect_path(&pool_b, &path_b);
    }

    /// A file-backed pool can be used from a thread other than its creator.
    #[test]
    fn per_thread_pool_for_file_db() {
        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("m.db");
        Db::open(&db_file).unwrap();
        let pool = DbPool::new(Db::open(&db_file).unwrap()).unwrap();

        let version = std::thread::scope(|scope| {
            let worker = scope.spawn(|| pool.with(|db| Ok(db.data_version()?)).unwrap());
            worker.join().unwrap()
        });
        assert!(version >= 0);
    }

    /// Calling `with` from inside `with` on a file-backed pool succeeds.
    #[test]
    fn reentrant_with_does_not_panic() {
        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("re.db");
        Db::open(&db_file).unwrap();
        let pool = DbPool::new(Db::open(&db_file).unwrap()).unwrap();

        let outcome: Result<i64> = pool.with(|_outer| {
            // Second acquisition on the same thread while the first is still held.
            pool.with(|inner| Ok(inner.data_version()?))
        });
        assert!(outcome.is_ok(), "re-entrant with() must not panic or error");
    }

    /// A failed open reports the offending path in the error text.
    #[test]
    fn with_open_failure_includes_path_in_error() {
        let missing = std::path::PathBuf::from("/nonexistent-musefs-dir/does-not-exist.db");
        // The poll handle is never used by `with`; an in-memory one fills the field.
        let placeholder = Db::open_in_memory().unwrap().into_read_only();
        let pool = DbPool::PerThread {
            path: missing,
            poll: Box::new(ReentrantMutex::new(placeholder)),
            conns: DashMap::new(),
        };

        let err = pool.with(|_db| Ok(())).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("/nonexistent-musefs-dir/does-not-exist.db"),
            "open error must name the failing path, got: {msg}"
        );
    }

    /// Calling `with` from inside `with` on a shared pool succeeds.
    #[test]
    fn nested_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let outcome: Result<i64> = pool.with(|_outer| {
            pool.with(|inner| Ok(inner.data_version()?))
        });
        assert!(outcome.is_ok(), "nested with on Shared must not deadlock");
    }

    /// Calling `with_poll` from inside `with` on a shared pool succeeds.
    #[test]
    fn with_poll_inside_with_on_shared_pool() {
        let pool = DbPool::new(Db::open_in_memory().unwrap()).unwrap();
        let outcome: Result<i64> = pool.with(|_outer| {
            pool.with_poll(|inner| Ok(inner.data_version()?))
        });
        assert!(
            outcome.is_ok(),
            "with_poll inside with on Shared must not deadlock"
        );
    }

    /// Calling `with_poll` from inside `with_poll` on a file-backed pool succeeds.
    #[test]
    fn nested_with_poll_on_per_thread_pool() {
        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("np.db");
        Db::open(&db_file).unwrap();
        let pool = DbPool::new(Db::open(&db_file).unwrap()).unwrap();

        let outcome: Result<i64> = pool.with_poll(|_outer| {
            pool.with_poll(|inner| Ok(inner.data_version()?))
        });
        assert!(
            outcome.is_ok(),
            "nested with_poll on PerThread must not deadlock"
        );
    }
}