use std::fs::OpenOptions;
use std::path::Path;
use fs2::FileExt;
use crate::error::VaultdbError;
/// Name of the per-vault metadata directory, relative to the vault root.
/// `with_vault_lock` creates it on demand and keeps its lock file inside it.
pub(crate) const META_DIR: &str = ".vaultdb";
/// File name, inside [`META_DIR`], of the file used for the vault-wide exclusive lock.
const LOCK_FILE: &str = "lock";
/// Runs `op` while holding an exclusive file lock on the vault.
///
/// The lock is taken on `<vault_root>/.vaultdb/lock`. The directory and the file
/// are created if missing; an existing lock file is never truncated. The call
/// waits until the exclusive lock is acquired.
///
/// The lock is released as soon as `op` finishes, in every case:
/// - `op` returns `Ok`;
/// - `op` returns `Err`;
/// - `op` panics. A drop guard performs the unlock during unwinding.
///
/// # Errors
///
/// Returns `E::from(io_error)` if any of these fail:
/// - creating the metadata directory;
/// - opening the lock file;
/// - acquiring the lock.
///
/// Otherwise returns `op`'s result unchanged.
pub fn with_vault_lock<F, R, E>(vault_root: &Path, op: F) -> Result<R, E>
where
    F: FnOnce() -> Result<R, E>,
    E: From<std::io::Error>,
{
    /// Owns the locked file and releases the lock when dropped. This also covers
    /// unwinding out of a panicking `op`, rather than relying only on the handle
    /// being closed.
    struct LockGuard(std::fs::File);

    impl Drop for LockGuard {
        fn drop(&mut self) {
            // Best-effort: `Drop` cannot report errors. Closing the file right
            // after this also drops the lock, so a failed unlock is not fatal.
            let _ = FileExt::unlock(&self.0);
        }
    }

    let lock_dir = vault_root.join(META_DIR);
    // `?` converts `std::io::Error` into `E` via the `From` bound.
    std::fs::create_dir_all(&lock_dir)?;

    let lock_file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        // Never clobber an existing lock file; only its lock state matters.
        .truncate(false)
        .open(lock_dir.join(LOCK_FILE))?;
    lock_file.lock_exclusive()?;

    // Create the guard only after the lock is actually held, so a failed
    // acquisition never triggers an unlock.
    let _guard = LockGuard(lock_file);
    op()
}
/// Crate-internal shorthand for [`with_vault_lock`] with the error type fixed to
/// [`VaultdbError`].
///
/// # Errors
///
/// Returns an I/O failure from setting up or acquiring the lock (converted into
/// `VaultdbError`), or whatever error `op` itself returns.
pub(crate) fn with_lock<F, R>(vault_root: &Path, op: F) -> Result<R, VaultdbError>
where
    F: FnOnce() -> Result<R, VaultdbError>,
{
    with_vault_lock::<F, R, VaultdbError>(vault_root, op)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;
    use tempfile::TempDir;

    /// Creates a throwaway vault directory containing an `.obsidian` folder.
    fn make_vault() -> TempDir {
        let vault = TempDir::new().unwrap();
        let marker = vault.path().join(".obsidian");
        std::fs::create_dir(marker).unwrap();
        vault
    }

    #[test]
    fn lock_file_and_meta_dir_are_created() {
        let vault = make_vault();
        with_lock(vault.path(), || Ok::<(), VaultdbError>(())).unwrap();

        let meta = vault.path().join(".vaultdb");
        assert!(meta.is_dir());
        assert!(meta.join("lock").is_file());
    }

    #[test]
    fn lock_serializes_concurrent_callers() {
        const THREADS: usize = 8;

        let vault = make_vault();
        let root = vault.path().to_path_buf();
        let completed = Arc::new(AtomicUsize::new(0));
        let overlaps = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let root = root.clone();
                let completed = Arc::clone(&completed);
                let overlaps = Arc::clone(&overlaps);
                thread::spawn(move || {
                    with_lock(&root, || {
                        // If another caller finishes its critical section while we
                        // sleep, the counter moves under us: the lock did not hold.
                        let seen_on_entry = completed.load(Ordering::SeqCst);
                        thread::sleep(Duration::from_millis(20));
                        if completed.load(Ordering::SeqCst) != seen_on_entry {
                            overlaps.fetch_add(1, Ordering::SeqCst);
                        }
                        completed.fetch_add(1, Ordering::SeqCst);
                        Ok::<(), VaultdbError>(())
                    })
                    .unwrap();
                })
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        assert_eq!(
            completed.load(Ordering::SeqCst),
            THREADS,
            "every thread should have incremented exactly once"
        );
        assert_eq!(
            overlaps.load(Ordering::SeqCst),
            0,
            "lock failed to serialize: counter changed during a critical section"
        );
    }

    #[test]
    fn op_error_propagates_and_lock_releases() {
        let vault = make_vault();

        let failed = with_lock(vault.path(), || -> Result<(), VaultdbError> {
            Err(VaultdbError::SchemaError("intentional".into()))
        });
        assert!(matches!(
            failed,
            Err(VaultdbError::SchemaError(ref msg)) if msg == "intentional"
        ));

        // Re-acquiring must not block, which shows the failed call released the lock.
        with_lock(vault.path(), || Ok::<(), VaultdbError>(())).unwrap();
    }
}