use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::sync::Arc;
/// A handle to a single data-less reader–writer lock.
///
/// Handles are normally obtained from [`LockRegistry::get_lock`]; every handle
/// for the same registered name shares one underlying lock. Cloning a handle
/// is cheap (an `Arc` reference-count bump) and yields a handle to the same lock.
#[derive(Clone, Debug)]
pub struct LockHandle {
    /// Shared lock; all clones of this `Arc` refer to the same `RwLock`.
    lock: Arc<RwLock<()>>,
}
impl Default for LockHandle {
fn default() -> Self {
Self::new()
}
}
impl LockHandle {
pub fn new() -> Self {
LockHandle {
lock: Arc::new(RwLock::new(())),
}
}
pub fn read(&self) -> RwLockReadGuard<'_, ()> {
self.lock.read()
}
pub fn write(&self) -> RwLockWriteGuard<'_, ()> {
self.lock.write()
}
}
/// A map from names to shared reader–writer locks.
///
/// The map itself sits behind an `Arc<RwLock<_>>`, so cloning a registry is
/// cheap and every clone observes and mutates the same set of named locks.
#[derive(Clone, Debug)]
pub struct LockRegistry {
    /// Name -> lock. Each value is an `Arc` so handles can outlive the map lookup.
    locks: Arc<RwLock<HashMap<String, Arc<RwLock<()>>>>>,
}
impl LockRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        LockRegistry {
            locks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a handle to the lock registered under `name`.
    ///
    /// If no lock is registered for `name`, a new one is created and
    /// registered first. While `name` stays registered, all handles returned
    /// for it share a single underlying lock.
    pub fn get_lock(&self, name: &str) -> LockHandle {
        // Fast path: most calls hit an existing entry. Looking it up under a
        // shared guard lets concurrent callers proceed in parallel and avoids
        // allocating an owned key. The read guard is a temporary that is
        // dropped at the end of this statement, before the write lock below.
        let existing = self.locks.read().get(name).cloned();
        if let Some(lock) = existing {
            return LockHandle { lock };
        }

        // Slow path: take the exclusive guard. `entry` re-checks the map, so if
        // another thread inserted `name` after our read, its lock is reused
        // rather than replaced. The write guard is also a statement temporary.
        let lock = Arc::clone(
            self.locks
                .write()
                .entry(name.to_owned())
                .or_insert_with(|| Arc::new(RwLock::new(()))),
        );
        LockHandle { lock }
    }

    /// Unregisters the lock stored under `name`.
    ///
    /// Returns `true` if an entry was present, `false` otherwise.
    ///
    /// Handles already handed out keep their lock alive and usable, but they
    /// become detached from the registry: a later [`get_lock`](Self::get_lock)
    /// for the same name creates a new, independent lock. Old and new handles
    /// then do NOT exclude each other, so only remove a name once no handles
    /// for it are still in use.
    pub fn remove_lock(&self, name: &str) -> bool {
        self.locks.write().remove(name).is_some()
    }

    /// Returns the number of names currently registered.
    pub fn lock_count(&self) -> usize {
        self.locks.read().len()
    }
}
impl Default for LockRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn test_new_lock_registry() {
        let lock_registry = LockRegistry::new();
        assert_eq!(lock_registry.lock_count(), 0);
    }

    #[test]
    fn test_get_lock() {
        let lock_registry = LockRegistry::new();
        let handle = lock_registry.get_lock("resource1");
        let _read_guard = handle.read();
        assert_eq!(lock_registry.lock_count(), 1);
    }

    #[test]
    fn test_get_lock_write() {
        let lock_registry = LockRegistry::new();
        let handle = lock_registry.get_lock("resource1");
        let _write_guard = handle.write();
        assert_eq!(lock_registry.lock_count(), 1);
    }

    #[test]
    fn test_same_name_shares_lock() {
        let lock_registry = LockRegistry::new();
        let first = lock_registry.get_lock("resource1");
        let second = lock_registry.get_lock("resource1");
        assert!(Arc::ptr_eq(&first.lock, &second.lock));
        let _write_guard = first.write();
        // The exclusive lock taken through `first` must be visible via `second`.
        assert!(second.lock.try_read().is_none());
        assert_eq!(lock_registry.lock_count(), 1);
    }

    #[test]
    fn test_different_names_are_independent() {
        let lock_registry = LockRegistry::new();
        let a = lock_registry.get_lock("a");
        let b = lock_registry.get_lock("b");
        let _write_guard = a.write();
        assert!(b.lock.try_write().is_some());
        assert_eq!(lock_registry.lock_count(), 2);
    }

    #[test]
    fn test_multiple_read_locks_same_name() {
        const READERS: usize = 3;
        let lock_registry = LockRegistry::new();
        let counter = Arc::new(AtomicUsize::new(0));
        // Each reader waits here while still holding its read guard, so the
        // test can only finish if all read guards are held simultaneously.
        let barrier = Arc::new(Barrier::new(READERS));
        let handles: Vec<_> = (0..READERS)
            .map(|_| {
                let registry = lock_registry.clone();
                let cnt = Arc::clone(&counter);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    let lock_handle = registry.get_lock("resource1");
                    let _read_guard = lock_handle.read();
                    barrier.wait();
                    cnt.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), READERS);
        assert_eq!(lock_registry.lock_count(), 1);
    }

    #[test]
    fn test_remove_lock() {
        let lock_registry = LockRegistry::new();
        let handle = lock_registry.get_lock("resource1");
        let _read_guard = handle.read();
        assert_eq!(lock_registry.lock_count(), 1);
        let removed = lock_registry.remove_lock("resource1");
        assert!(removed);
        assert_eq!(lock_registry.lock_count(), 0);
    }

    #[test]
    fn test_get_after_remove_creates_new_lock() {
        let lock_registry = LockRegistry::new();
        let old = lock_registry.get_lock("resource1");
        assert!(lock_registry.remove_lock("resource1"));
        let new = lock_registry.get_lock("resource1");
        // A removed name is re-registered with a fresh, independent lock.
        assert!(!Arc::ptr_eq(&old.lock, &new.lock));
        assert_eq!(lock_registry.lock_count(), 1);
    }

    #[test]
    fn test_remove_nonexistent_lock() {
        let lock_registry = LockRegistry::new();
        let removed = lock_registry.remove_lock("nonexistent");
        assert!(!removed);
    }

    #[test]
    fn test_default() {
        let lock_registry = LockRegistry::default();
        assert_eq!(lock_registry.lock_count(), 0);
    }
}