use parking_lot::{ArcMutexGuard, Mutex, RawMutex};
use std::{collections::HashMap, hash::Hash, sync::Arc};
pub struct Guard<'k, K: Eq + Hash + Clone> {
key: K,
_guard: ArcMutexGuard<RawMutex, ()>,
keyed_lock: &'k KeyedLock<K>,
}
impl<'k, K: Eq + Hash + Clone> Drop for Guard<'k, K> {
fn drop(&mut self) {
let mut registry = self.keyed_lock.0.lock();
if let Some(arc_mutex) = registry.get(&self.key) {
if Arc::strong_count(arc_mutex) == 2 {
registry.remove(&self.key);
}
}
}
}
pub struct OwnedGuard<K: Eq + Hash + Clone> {
key: K,
_guard: ArcMutexGuard<RawMutex, ()>,
keyed_lock: Arc<KeyedLock<K>>,
}
impl<K: Eq + Hash + Clone> Drop for OwnedGuard<K> {
fn drop(&mut self) {
let mut registry = self.keyed_lock.0.lock();
if let Some(arc_mutex) = registry.get(&self.key) {
if Arc::strong_count(arc_mutex) == 2 {
registry.remove(&self.key);
}
}
}
}
pub struct KeyedLock<K: Eq + Hash + Clone>(Mutex<HashMap<K, Arc<Mutex<()>>>>);
impl<K: Eq + Hash + Clone> KeyedLock<K> {
#[must_use]
pub fn new() -> Self {
Self(Mutex::new(HashMap::new()))
}
pub fn lock(&self, key: K) -> Guard<'_, K> {
let _guard = self.lock_inner(&key);
Guard {
key,
_guard,
keyed_lock: self,
}
}
pub fn lock_owned(self: &Arc<Self>, key: K) -> OwnedGuard<K> {
let _guard = self.lock_inner(&key);
OwnedGuard {
key,
_guard,
keyed_lock: self.clone(),
}
}
fn lock_inner(&self, key: &K) -> ArcMutexGuard<RawMutex, ()> {
let key_lock = {
let mut registry = self.0.lock();
if let Some(notifies) = registry.get_mut(key) {
Arc::clone(notifies)
} else {
let new = Arc::new(Mutex::new(()));
registry.insert(key.clone(), new.clone());
new
}
};
key_lock.lock_arc()
}
#[cfg(test)]
fn registry_len(&self) -> usize {
self.0.lock().len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
#[test]
fn test_basic_lock() {
let keyed_lock = KeyedLock::new();
let _guard = keyed_lock.lock(1);
}
#[test]
fn test_concurrent_access() {
let keyed_lock = Arc::new(KeyedLock::new());
let mut handles = vec![];
for _ in 0..10 {
let keyed_lock_clone = Arc::clone(&keyed_lock);
let handle = thread::spawn(move || {
let _guard = keyed_lock_clone.lock(1);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_lock_is_released() {
let keyed_lock = KeyedLock::new();
let guard = keyed_lock.lock(1);
drop(guard);
let _guard2 = keyed_lock.lock(1);
}
#[test]
fn test_lock_reuse() {
let keyed_lock = KeyedLock::new();
let guard1 = keyed_lock.lock(1);
drop(guard1);
let guard2 = keyed_lock.lock(1);
drop(guard2);
}
#[test]
fn test_locks_different_keys() {
let keyed_lock = KeyedLock::new();
let _guard1 = keyed_lock.lock(1);
let _guard2 = keyed_lock.lock(2);
}
#[test]
fn test_multiple_keys_concurrently() {
let keyed_lock = Arc::new(KeyedLock::new());
let mut handles = vec![];
for i in 0..10 {
let keyed_lock_clone = Arc::clone(&keyed_lock);
let handle = thread::spawn(move || {
let _guard = keyed_lock_clone.lock(i);
thread::sleep(Duration::from_millis(10));
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_non_reentrant_lock() {
let keyed_lock = Arc::new(KeyedLock::new());
let keyed_lock_clone = Arc::clone(&keyed_lock);
let _guard = keyed_lock.lock(1);
let handle = thread::spawn(move || {
let now = Instant::now();
let _guard = keyed_lock_clone.lock(1);
assert!(now.elapsed() >= Duration::from_secs(3));
});
std::thread::sleep(Duration::from_secs(4));
drop(_guard);
handle.join().unwrap();
}
#[test]
fn test_registry_cleanup() {
let keyed_lock = KeyedLock::new();
assert_eq!(keyed_lock.registry_len(), 0);
let guard = keyed_lock.lock(1);
assert_eq!(keyed_lock.registry_len(), 1);
drop(guard);
assert_eq!(keyed_lock.registry_len(), 0);
}
#[test]
fn test_registry_cleanup_concurrent() {
let keyed_lock = Arc::new(KeyedLock::new());
assert_eq!(keyed_lock.registry_len(), 0);
let guard1 = keyed_lock.lock(1);
assert_eq!(keyed_lock.registry_len(), 1);
let keyed_lock_clone = Arc::clone(&keyed_lock);
let handle = thread::spawn(move || {
let guard2 = keyed_lock_clone.lock(1);
assert_eq!(keyed_lock_clone.registry_len(), 1);
drop(guard2);
});
assert_eq!(keyed_lock.registry_len(), 1);
drop(guard1);
handle.join().unwrap();
assert_eq!(keyed_lock.registry_len(), 0);
}
#[test]
fn test_registry_cleanup_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
assert_eq!(keyed_lock.registry_len(), 0);
let guard = keyed_lock.lock_owned(1);
assert_eq!(keyed_lock.registry_len(), 1);
drop(guard);
assert_eq!(keyed_lock.registry_len(), 0);
}
#[test]
fn test_lock_arc_concurrently() {
let keyed_lock = Arc::new(KeyedLock::new());
let mut handles = vec![];
for i in 0..10 {
let keyed_lock_clone = Arc::clone(&keyed_lock);
let handle = thread::spawn(move || {
let _guard = keyed_lock_clone.lock_owned(i);
thread::sleep(Duration::from_millis(10));
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[cfg(feature = "send_guard")]
#[test]
fn test_non_reentrant_lock_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
let _guard = keyed_lock.lock_owned(1);
let handle = thread::spawn(move || {
std::thread::sleep(Duration::from_secs(4));
drop(_guard);
});
let now = Instant::now();
let _guard = keyed_lock.lock(1);
assert!(now.elapsed() >= Duration::from_secs(4));
handle.join().unwrap();
}
#[test]
fn test_basic_lock_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
let _guard = keyed_lock.lock_owned(1);
}
#[test]
fn test_lock_is_released_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
let guard = keyed_lock.lock_owned(1);
drop(guard);
let _guard2 = keyed_lock.lock_owned(1);
}
#[test]
fn test_lock_reuse_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
let guard1 = keyed_lock.lock_owned(1);
drop(guard1);
let guard2 = keyed_lock.lock_owned(1);
drop(guard2);
}
#[test]
fn test_locks_different_keys_arc() {
let keyed_lock = Arc::new(KeyedLock::new());
let _guard1 = keyed_lock.lock_owned(1);
let _guard2 = keyed_lock.lock_owned(2);
}
}