use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Floor for pruning the [`per_file_lock`] registry: idle entries are never
/// pruned while the registry holds fewer than this many paths.
const MAX_ENTRIES: usize = 500;
/// Returns the process-wide mutex associated with `path`.
///
/// Paths are compared as plain strings, with no canonicalisation. For
/// example, `"/tmp/a"` and `"/tmp/./a"` map to different mutexes.
///
/// While an `Arc` returned by an earlier call for the same `path` is still
/// alive, every call returns a clone of that same `Arc`. Holding its guard
/// therefore serializes work on that path across all threads of this process:
///
/// ```ignore
/// let lock = per_file_lock("/tmp/out.txt");
/// let _guard = lock.lock().unwrap();
/// // ... no other thread holding a guard for "/tmp/out.txt" runs here ...
/// ```
///
/// Entries referenced only by the registry are pruned lazily, when the
/// registry reaches a high-water mark. The mark starts at [`MAX_ENTRIES`].
/// After every prune it is reset to twice the number of surviving entries,
/// and never below `MAX_ENTRIES`. As a result, a registry full of in-use
/// locks is not rescanned on every call. Idle entries may accumulate up to
/// the current mark.
///
/// If the registry's own mutex is poisoned, a warning is logged and the
/// registry is used anyway instead of panicking.
pub fn per_file_lock(path: &str) -> Arc<Mutex<()>> {
    /// State guarded by the single registry mutex.
    struct Registry {
        /// One mutex per path string.
        locks: HashMap<String, Arc<Mutex<()>>>,
        /// Registry size at which the next prune runs; never below `MAX_ENTRIES`.
        prune_at: usize,
    }

    static FILE_LOCKS: std::sync::OnceLock<Mutex<Registry>> = std::sync::OnceLock::new();

    let registry = FILE_LOCKS.get_or_init(|| {
        Mutex::new(Registry {
            locks: HashMap::new(),
            prune_at: MAX_ENTRIES,
        })
    });
    let mut registry = registry.lock().unwrap_or_else(|poisoned| {
        tracing::warn!("path_locks registry poisoned; recovering");
        poisoned.into_inner()
    });

    // Fast path: hand out the existing mutex without allocating an owned key.
    if let Some(lock) = registry.locks.get(path) {
        return Arc::clone(lock);
    }

    if registry.locks.len() >= registry.prune_at {
        // A strong count of 1 means only the registry references the entry.
        // No caller can then hold (or later take) its guard. No new clone can
        // appear either, because clones of a registry-only Arc are created
        // here, under the registry lock.
        registry.locks.retain(|_, lock| Arc::strong_count(lock) > 1);
        // Raise the mark past the survivors so that a registry full of
        // in-use locks is not rescanned on every miss.
        let survivors = registry.locks.len();
        registry.prune_at = survivors.saturating_mul(2).max(MAX_ENTRIES);
    }

    let lock = Arc::new(Mutex::new(()));
    registry.locks.insert(path.to_owned(), Arc::clone(&lock));
    lock
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Barrier;

    #[test]
    fn same_path_returns_same_mutex() {
        let a1 = per_file_lock("/tmp/path_locks_same.txt");
        let a2 = per_file_lock("/tmp/path_locks_same.txt");
        assert!(Arc::ptr_eq(&a1, &a2));
    }

    #[test]
    fn different_paths_return_different_mutexes() {
        let a = per_file_lock("/tmp/path_locks_a.txt");
        let b = per_file_lock("/tmp/path_locks_b.txt");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn serializes_concurrent_access_to_same_path() {
        let counter = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let barrier = Arc::new(Barrier::new(8));
        let path = "/tmp/path_locks_serialize.txt";
        let mut handles = Vec::new();
        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            let barrier = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                barrier.wait();
                let lock = per_file_lock(path);
                let _guard = lock.lock().unwrap();
                let active = counter.fetch_add(1, Ordering::SeqCst) + 1;
                max_concurrent.fetch_max(active, Ordering::SeqCst);
                std::thread::sleep(std::time::Duration::from_millis(5));
                counter.fetch_sub(1, Ordering::SeqCst);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(
            max_concurrent.load(Ordering::SeqCst),
            1,
            "per-file lock must serialize same-path access"
        );
    }

    /// Takes every guard with `try_lock` while the earlier guards are still
    /// held. The result does not depend on thread scheduling. If two of
    /// these paths ever shared a mutex, the test fails instead of flaking
    /// or hanging.
    #[test]
    fn allows_parallel_access_to_different_paths() {
        let locks: Vec<_> = (0..8)
            .map(|i| per_file_lock(&format!("/tmp/path_locks_parallel_{i}.txt")))
            .collect();
        let guards: Vec<_> = locks.iter().map(|lock| lock.try_lock()).collect();
        assert!(
            guards.iter().all(Result::is_ok),
            "different paths must be allowed to run in parallel"
        );
    }

    /// Pushes enough idle paths through the registry to trigger pruning.
    /// A lock that a caller still holds must come back unchanged afterwards.
    #[test]
    fn pruning_keeps_locks_that_are_still_referenced() {
        let keep_path = "/tmp/path_locks_prune_keep.txt";
        let kept = per_file_lock(keep_path);
        for i in 0..=2 * MAX_ENTRIES {
            drop(per_file_lock(&format!("/tmp/path_locks_prune_idle_{i}.txt")));
        }
        assert!(
            Arc::ptr_eq(&kept, &per_file_lock(keep_path)),
            "pruning must not discard a lock that a caller still holds"
        );
    }
}