use lru::LruCache;
use parking_lot::RwLock;
use rocksdb::DB;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock};
use crate::core::db;
use crate::core::db::data_frames::DataFrameError;
use crate::error::OxenError;
/// Upper bound on the number of changes databases kept in the cache at once.
///
/// When the cache is full, inserting a new path evicts the least recently used entry.
const CHANGES_DB_CACHE_SIZE: NonZeroUsize = match NonZeroUsize::new(100) {
    Some(size) => size,
    None => unreachable!(),
};

/// Cache mapping a database path (as passed by the caller) to its shared RocksDB handle.
type ChangesDbCache = LruCache<PathBuf, Arc<DB>>;

/// Process-wide cache of open changes databases, built lazily on first access.
///
/// Sharing a single `Arc<DB>` per path lets concurrent callers reuse one open handle
/// instead of each calling `DB::open` on the same directory.
static CHANGES_DB_INSTANCES: LazyLock<RwLock<ChangesDbCache>> =
    LazyLock::new(|| RwLock::new(ChangesDbCache::new(CHANGES_DB_CACHE_SIZE)));
/// Evicts the cached handle for `db_path`, if one is present.
///
/// Only the cache's reference is released; callers that still hold an
/// `Arc<DB>` for this path keep a working handle.
///
/// # Errors
/// Currently never fails; the `Result` is kept for API compatibility.
pub fn remove_from_cache(db_path: impl AsRef<Path>) -> Result<(), OxenError> {
    let mut instances = CHANGES_DB_INSTANCES.write();
    // `PathBuf: Borrow<Path>`, so the lookup needs no owned key.
    // The popped handle is dropped at the end of this statement, while the
    // write lock is still held: if it was the last reference, it is released
    // before any other thread can take the lock and reopen the same path.
    let _ = instances.pop(db_path.as_ref());
    Ok(())
}
/// Evicts every cached handle whose path is `prefix` itself or lies beneath it.
///
/// Matching uses [`Path::starts_with`], which compares whole path components,
/// so a prefix of `/a/b` matches `/a/b/c` but not `/a/bc`.
///
/// # Errors
/// Currently never fails; the `Result` is kept for API compatibility.
pub fn remove_from_cache_with_children(prefix: impl AsRef<Path>) -> Result<(), OxenError> {
    let prefix = prefix.as_ref();
    let mut instances = CHANGES_DB_INSTANCES.write();
    // The cache cannot be mutated while it is iterated, so collect the
    // matching keys first. Filtering before cloning means only the keys
    // being removed are allocated.
    let to_remove: Vec<PathBuf> = instances
        .iter()
        .filter(|(key, _)| key.starts_with(prefix))
        .map(|(key, _)| key.clone())
        .collect();
    for key in &to_remove {
        // Dropped under the write lock, same as `remove_from_cache`.
        let _ = instances.pop(key);
    }
    Ok(())
}
/// Returns the shared RocksDB handle for `db_path`, opening it on a cache miss.
///
/// On a miss the directory is created if it does not exist. While the entry
/// stays cached, repeated calls for the same path return clones of one `Arc<DB>`.
///
/// # Errors
/// Propagates any failure from creating the directory or opening the database.
pub fn get_changes_db(db_path: &Path) -> Result<Arc<DB>, DataFrameError> {
    match lookup_cached(db_path) {
        Some(cached) => Ok(cached),
        None => open_and_cache(db_path),
    }
}
/// Like `get_changes_db`, but never creates a database that is missing on disk.
///
/// Returns `Ok(None)` when nothing is cached for `db_path` and the path does
/// not exist. Otherwise returns the cached handle, or opens and caches the
/// existing database.
///
/// NOTE(review): the existence check runs before `open_and_cache` takes the
/// write lock. If the directory is removed in between, `open_and_cache` will
/// recreate it. Confirm whether that window matters to callers.
///
/// # Errors
/// Propagates any failure from opening an existing database.
pub fn try_get_changes_db(db_path: &Path) -> Result<Option<Arc<DB>>, DataFrameError> {
    match lookup_cached(db_path) {
        Some(cached) => Ok(Some(cached)),
        None if db_path.exists() => open_and_cache(db_path).map(Some),
        None => Ok(None),
    }
}
/// Returns a clone of the cached handle for `db_path`, taking only the read lock.
///
/// This uses `peek`, which does not change the entry's recency. A hit here
/// therefore does not protect the entry from LRU eviction; only inserts and the
/// re-check in `open_and_cache` (which uses `get`) update the LRU order.
fn lookup_cached(db_path: &Path) -> Option<Arc<DB>> {
    // `PathBuf: Borrow<Path>`, so no owned key is allocated on this hot path.
    CHANGES_DB_INSTANCES.read().peek(db_path).map(Arc::clone)
}
/// Opens the RocksDB database at `db_path` and inserts the handle into the cache.
///
/// The write lock is held across the whole check, create, open and insert
/// sequence. Two threads racing on a cold path therefore cannot both call
/// `DB::open` for it: the second one finds the first one's handle in the
/// re-check. A side effect is that opens of *different* paths are also
/// serialized behind this lock.
///
/// The directory (and any missing parents) is created if it does not exist yet.
///
/// NOTE(review): the cache key is `db_path` as given, while `DB::open`
/// receives `dunce::simplified(db_path)`. Two spellings of one directory
/// (for example with and without a Windows `\\?\` prefix) would get separate
/// cache entries. Confirm that callers normalize paths.
///
/// # Errors
/// Returns `DataFrameError::FailCreateDfDbDir` if the directory cannot be
/// created, or the error from `DB::open` (converted via `?`) if opening fails.
fn open_and_cache(db_path: &Path) -> Result<Arc<DB>, DataFrameError> {
    let mut cache_w = CHANGES_DB_INSTANCES.write();
    // Re-check under the write lock: another thread may have opened this path
    // between our read-locked lookup and acquiring the write lock.
    if let Some(db) = cache_w.get(db_path) {
        return Ok(Arc::clone(db));
    }
    if !db_path.exists() {
        std::fs::create_dir_all(db_path).map_err(DataFrameError::FailCreateDfDbDir)?;
    }
    let opts = db::key_val::opts::default();
    let db = Arc::new(DB::open(&opts, dunce::simplified(db_path))?);
    // Allocate the owned key only when we actually insert. At capacity, `put`
    // evicts the least recently used entry while the lock is still held.
    cache_w.put(db_path.to_path_buf(), Arc::clone(&db));
    Ok(db)
}
/// Tests for the process-wide changes-DB handle cache.
///
/// `CHANGES_DB_INSTANCES` is a `static`, so its contents are shared by every
/// test in the process. Each test therefore evicts its own path with
/// `remove_from_cache` before returning.
#[cfg(test)]
mod tests {
use super::*;
use crate::test;
use std::thread;
/// Opening the same path twice must return the very same `Arc<DB>` allocation.
#[test]
fn test_get_changes_db_shares_one_handle_per_path() -> Result<(), OxenError> {
test::run_empty_dir_test(|data_dir| {
let db_path = data_dir.join("row_changes");
let a = get_changes_db(&db_path)?;
let b = get_changes_db(&db_path)?;
assert!(
Arc::ptr_eq(&a, &b),
"repeated opens of the same path must reuse one cached handle"
);
// Leave the shared global cache without this test's entry.
remove_from_cache(&db_path)?;
Ok(())
})
}
/// Many threads fetching and writing through the same path concurrently must
/// all succeed, and every write must be visible through the shared handle.
#[test]
fn test_concurrent_opens_do_not_hit_the_lock_error() -> Result<(), OxenError> {
// NOTE(review): every thread uses the same path, so exceeding the cache
// size does not by itself cause any eviction here. Confirm whether an
// eviction scenario was intended.
const NUM_THREADS: usize = CHANGES_DB_CACHE_SIZE.get() + 8;
test::run_empty_dir_test(|data_dir| {
let db_path = data_dir.join("row_changes");
// Open the database up front and keep a reference alive for the whole test.
let _held = get_changes_db(&db_path)?;
// Scoped threads: all spawned workers are joined before `scope` returns.
thread::scope(|scope| {
let handles: Vec<_> = (0..NUM_THREADS)
.map(|i| {
let db_path = db_path.clone();
// Each worker looks the DB up itself and writes a distinct key.
scope.spawn(move || -> Result<(), OxenError> {
let db = get_changes_db(&db_path)?;
db.put(format!("key-{i}"), format!("val-{i}"))?;
Ok(())
})
})
.collect();
// Panic if any worker panicked or returned an error.
for handle in handles {
handle
.join()
.expect("worker thread panicked")
.expect("concurrent open/write must not hit the RocksDB lock error");
}
});
// Every worker's write must be readable through the shared handle.
let db = get_changes_db(&db_path)?;
for i in 0..NUM_THREADS {
let value = db.get(format!("key-{i}"))?;
assert_eq!(value.as_deref(), Some(format!("val-{i}").as_bytes()));
}
remove_from_cache(&db_path)?;
Ok(())
})
}
/// `try_get_changes_db` returns `None` for a missing database without creating
/// it, and returns the database once `get_changes_db` has created it.
#[test]
fn test_try_get_changes_db_does_not_create_missing_db() -> Result<(), OxenError> {
test::run_empty_dir_test(|data_dir| {
let db_path = data_dir.join("row_changes");
assert!(try_get_changes_db(&db_path)?.is_none());
assert!(
!db_path.exists(),
"try_get_changes_db must not create the db directory on a read miss"
);
// `get_changes_db` creates the database; write a value through it.
get_changes_db(&db_path)?.put("k", "v")?;
let db = try_get_changes_db(&db_path)?.expect("db exists after write");
assert_eq!(db.get("k")?.as_deref(), Some(b"v".as_ref()));
remove_from_cache(&db_path)?;
Ok(())
})
}
}