use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard, Condvar};
use std::thread::ThreadId;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
use std::fmt;
/// Entry point of the declarative locking scheme.
///
/// A thread first declares the resource kinds it intends to use (see
/// `DeclarativeLock::declare`); declaring blocks while any of those kinds is
/// declared by another thread. Clones are cheap and share the same internal
/// state through an `Arc`.
#[derive(Debug, Clone)]
pub struct DeclarativeLock<E>
where
    E: Eq + Hash + Clone + Debug,
{
    core: Arc<Core<E>>,
}
/// State shared by every handle created from one `DeclarativeLock`.
#[derive(Debug)]
struct Core<E>
where
    E: Eq + Hash + Clone + Debug,
{
    /// Resource kinds currently declared, keyed by the declaring thread.
    declared: Mutex<HashMap<ThreadId, HashSet<E>>>,
    /// Notified when a declaration is released so blocked declarers re-check.
    condvar: Condvar,
    /// Per-resource-kind lock flag: set to 1 by `DeclarativeLocker::lock` and
    /// reset to 0 when the corresponding `LockGuard` is dropped.
    counter: Mutex<HashMap<E, isize>>,
}
impl<E> Core<E>
where
    E: Eq + Hash + Clone + Debug,
{
    /// Reports whether the calling thread's declaration contains `resource`.
    ///
    /// Returns `false` when the calling thread has no declaration at all.
    ///
    /// # Panics
    ///
    /// Panics if the `declared` mutex is poisoned.
    fn is_declared(&self, resource: &E) -> bool {
        let me = std::thread::current().id();
        let table = self.declared.lock().unwrap();
        match table.get(&me) {
            Some(set) => set.contains(resource),
            None => false,
        }
    }
}
impl<E: Eq + Hash + Clone + Debug> DeclarativeLock<E> {
pub fn new() -> Self {
Self {
core: Arc::new(Core {
declared: Mutex::new(HashMap::new()),
condvar: Condvar::new(),
counter: Mutex::new(HashMap::new()),
}),
}
}
pub fn declare(&self, resources: &[E]) -> Result<DeclarationGuard<E>, DeclareError> {
let tid = std::thread::current().id();
let mut declared = self.core.declared.lock().unwrap();
if !declared.get(&tid).is_none() {
return Err(DeclareError::AlreadyDeclared)
}
let is_declared_by_other_context = |declared: &MutexGuard<'_, HashMap<ThreadId, HashSet<E>>>| -> bool {
resources.iter().any(|r| {
declared.iter().any(|(thread_id, set)| {
*thread_id != tid && set.contains(&r)
})
})
};
while is_declared_by_other_context(&declared) {
declared = self.core.condvar.wait(declared).unwrap();
}
let entry = declared.entry(tid).or_insert_with(HashSet::new);
for r in resources {
entry.insert(r.clone());
}
Ok(DeclarationGuard {
core: self.core.clone(),
_not_send_sync: std::marker::PhantomData,
})
}
pub fn is_declared(&self, resource: &E) -> bool {
self.core.is_declared(resource)
}
}
/// Error returned by `DeclarativeLock::declare`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeclareError {
    /// The calling thread already holds a declaration on this lock.
    AlreadyDeclared,
}

impl fmt::Display for DeclareError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::AlreadyDeclared => "Resources have already been declared for this thread",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DeclareError {}
/// Token for the resource kinds declared by the current thread.
///
/// Dropping it removes the current thread's declaration and wakes threads
/// blocked in `DeclarativeLock::declare`. The `Rc` marker makes the guard
/// neither `Send` nor `Sync`, so it cannot be moved to another thread.
#[derive(Debug)]
#[must_use]
pub struct DeclarationGuard<E>
where
    E: Eq + Hash + Clone + Debug,
{
    core: Arc<Core<E>>,
    _not_send_sync: std::marker::PhantomData<std::rc::Rc<()>>,
}
impl<E: Eq + Hash + Clone + Debug> DeclarationGuard<E> {
pub fn is_declared(&self, resource: &E) -> bool {
self.core.is_declared(resource)
}
}
impl<E> Drop for DeclarationGuard<E>
where
    E: Eq + Hash + Clone + Debug,
{
    /// Removes the current thread's declaration and wakes every waiting declarer.
    fn drop(&mut self) {
        let owner = std::thread::current().id();
        let mut table = self.core.declared.lock().unwrap();
        table.remove(&owner);
        // Waiters re-check their conflict predicate once they re-acquire the map.
        self.core.condvar.notify_all();
    }
}
/// Shared handle to a resource of type `R`, tagged with the resource kind
/// `resource_type` that a thread must have declared before it may lock it.
///
/// Clones share the same underlying `Arc<Mutex<R>>` and the same lock state.
#[derive(Debug)]
pub struct DeclarativeLocker<E: Eq + Hash + Clone + Debug, R> {
    core: Arc<Core<E>>,
    resource_type: E,
    resource: Arc<Mutex<R>>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add an
// `R: Clone` bound, although cloning only bumps `Arc` reference counts and
// never clones the protected value itself.
impl<E: Eq + Hash + Clone + Debug, R> Clone for DeclarativeLocker<E, R> {
    fn clone(&self) -> Self {
        Self {
            core: Arc::clone(&self.core),
            resource_type: self.resource_type.clone(),
            resource: Arc::clone(&self.resource),
        }
    }
}
impl<E: Eq + Hash + Clone + Debug, R> DeclarativeLocker<E, R> {
pub fn new(
locker: &DeclarativeLock<E>,
resource_type: E,
resource: R,
) -> Self {
Self {
core: locker.core.clone(),
resource_type,
resource: Arc::new(Mutex::new(resource)),
}
}
pub fn lock(&self) -> Result<LockGuard<'_, E, R>, LockError> {
if !self.core.is_declared(&self.resource_type) {
return Err(LockError::NotDeclared)
}
use std::collections::hash_map::Entry;
match self.core.counter.lock().unwrap().entry(self.resource_type.clone()) {
Entry::Occupied(e) if *e.get() != 0 => return Err(LockError::AlreadyLocked),
Entry::Occupied(mut e) => { e.insert(1); },
Entry::Vacant(e) => { e.insert(1); },
}
let guard = self.resource.lock().map_err(|_| LockError::LockFailed)?;
Ok(LockGuard {
guard,
resource_type: self.resource_type.clone(),
core: self.core.clone(),
})
}
}
/// Error returned by `DeclarativeLocker::lock`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LockError {
    /// The current thread has not declared the locker's resource type.
    NotDeclared,
    /// A lock flag for this resource type is already set.
    AlreadyLocked,
    /// The underlying std `Mutex::lock` returned an error (poisoned mutex).
    LockFailed,
}

impl fmt::Display for LockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::NotDeclared => "Resource has not been declared for this thread",
            Self::AlreadyLocked => "Resource is already locked",
            Self::LockFailed => "Failed to acquire lock on resource",
        };
        f.write_str(text)
    }
}

impl std::error::Error for LockError {}
/// Exclusive access to a resource obtained from `DeclarativeLocker::lock`.
///
/// Dereferences to the protected value. On drop, the resource type's lock flag
/// is reset to 0 first, and the inner mutex guard is released afterwards
/// (fields are dropped after `Drop::drop` runs).
#[derive(Debug)]
#[must_use]
pub struct LockGuard<'a, E, R>
where
    E: Eq + Hash + Clone + Debug,
{
    guard: MutexGuard<'a, R>,
    resource_type: E,
    core: Arc<Core<E>>,
}
impl<E, R> Deref for LockGuard<'_, E, R>
where
    E: Eq + Hash + Clone + Debug,
{
    type Target = R;

    /// Borrows the protected value.
    fn deref(&self) -> &R {
        &*self.guard
    }
}
impl<E, R> DerefMut for LockGuard<'_, E, R>
where
    E: Eq + Hash + Clone + Debug,
{
    /// Mutably borrows the protected value.
    fn deref_mut(&mut self) -> &mut R {
        &mut *self.guard
    }
}
impl<E, R> Drop for LockGuard<'_, E, R>
where
    E: Eq + Hash + Clone + Debug,
{
    /// Clears the lock flag for this guard's resource type so it can be locked again.
    fn drop(&mut self) {
        let mut flags = self.core.counter.lock().unwrap();
        flags.insert(self.resource_type.clone(), 0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::{Duration, Instant};

    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ResourceType {
        Foo,
        Bar,
    }

    #[test]
    fn test_declaration_and_locking() {
        let lock = DeclarativeLock::<ResourceType>::new();
        let guard = lock.declare(&[ResourceType::Foo]).expect("Declaration failed");
        assert!(guard.is_declared(&ResourceType::Foo));
        assert!(lock.is_declared(&ResourceType::Foo));
        let locker = DeclarativeLocker::new(&lock, ResourceType::Foo, 42);
        let lock_guard = locker.lock().expect("Lock failed");
        assert_eq!(*lock_guard, 42);
        drop(lock_guard);
        let mut lock_guard = locker.lock().expect("Lock failed");
        *lock_guard = 100;
        assert_eq!(*lock_guard, 100);
        drop(guard);
        assert!(!lock.is_declared(&ResourceType::Foo));
    }

    #[test]
    fn test_double_declare_fails() {
        let lock = DeclarativeLock::<ResourceType>::new();
        let guard1 = lock.declare(&[ResourceType::Foo]).expect("First declaration should succeed");
        let result = lock.declare(&[ResourceType::Foo]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), DeclareError::AlreadyDeclared);
        drop(guard1);
    }

    #[test]
    fn test_declare_different_resources_in_threads() {
        let start = Instant::now();
        let lock = DeclarativeLock::<ResourceType>::new();
        let lock1 = lock.clone();
        let handle1 = thread::spawn(move || {
            let _d1 = lock1.declare(&[ResourceType::Foo]).expect("Thread 1 should declare Foo");
            thread::sleep(Duration::from_millis(100));
        });
        let lock2 = lock.clone();
        let handle2 = thread::spawn(move || {
            let _d2 = lock2.declare(&[ResourceType::Bar]).expect("Thread 2 should declare Bar");
            thread::sleep(Duration::from_millis(100));
        });
        handle1.join().expect("Thread 1 panicked");
        handle2.join().expect("Thread 2 panicked");
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100), "Elapsed: {:?}", elapsed);
        // Disjoint declarations must run concurrently; a serialized run would take
        // at least 200ms. A tight bound (e.g. 110ms) is flaky on loaded machines.
        assert!(elapsed < Duration::from_millis(200), "Elapsed: {:?}", elapsed);
    }

    #[test]
    fn test_declare_same_resources_in_threads() {
        let start = Instant::now();
        let lock = DeclarativeLock::<ResourceType>::new();
        let lock1 = lock.clone();
        let handle1 = thread::spawn(move || {
            let _d1 = lock1.declare(&[ResourceType::Foo]).expect("Thread 1 should declare Foo");
            thread::sleep(Duration::from_millis(100));
        });
        let lock2 = lock.clone();
        let handle2 = thread::spawn(move || {
            let _d2 = lock2.declare(&[ResourceType::Foo]).expect("Thread 2 should declare Foo");
            thread::sleep(Duration::from_millis(100));
        });
        handle1.join().expect("Thread 1 panicked");
        handle2.join().expect("Thread 2 panicked");
        let elapsed = start.elapsed();
        // Conflicting declarations must serialize. No upper bound is asserted:
        // scheduler delays make any tight bound flaky.
        assert!(elapsed >= Duration::from_millis(200), "Elapsed: {:?}", elapsed);
    }

    #[test]
    fn test_lock_without_declaration_fails() {
        let lock = DeclarativeLock::<ResourceType>::new();
        let locker = DeclarativeLocker::new(&lock, ResourceType::Bar, 100);
        let result = locker.lock();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LockError::NotDeclared);
    }

    #[test]
    fn test_double_lock_fails() {
        let lock = DeclarativeLock::<ResourceType>::new();
        let _guard = lock.declare(&[ResourceType::Foo]).unwrap();
        let locker = DeclarativeLocker::new(&lock, ResourceType::Foo, 1);
        let g1 = locker.lock().unwrap();
        let g2 = locker.lock();
        assert!(g2.is_err());
        assert_eq!(g2.unwrap_err(), LockError::AlreadyLocked);
        drop(g1);
        let g3 = locker.lock();
        assert!(g3.is_ok());
    }

    #[test]
    fn test_multithreaded_declaration_and_locking() {
        let lock = DeclarativeLock::<ResourceType>::new();
        let locker = DeclarativeLocker::new(&lock, ResourceType::Bar, 999);
        let handle = {
            let lock = lock.clone();
            let locker = locker.clone();
            thread::spawn(move || {
                let guard = lock.declare(&[ResourceType::Bar]).unwrap();
                let g = locker.lock().unwrap();
                assert_eq!(*g, 999);
                drop(g);
                drop(guard);
            })
        };
        handle.join().unwrap();
        assert!(!lock.is_declared(&ResourceType::Bar));
    }

    #[test]
    fn test_declare_conflict_waits() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::{Arc, Barrier};
        let lock = DeclarativeLock::<ResourceType>::new();
        let barrier = Arc::new(Barrier::new(2));
        let released = Arc::new(AtomicBool::new(false));
        let lock1 = lock.clone();
        let barrier1 = Arc::clone(&barrier);
        let released1 = Arc::clone(&released);
        let t1 = thread::spawn(move || {
            let guard = lock1.declare(&[ResourceType::Foo]).unwrap();
            barrier1.wait();
            thread::sleep(Duration::from_millis(200));
            // Publish the release before dropping the declaration, so any thread
            // that manages to declare Foo afterwards must observe `true`.
            released1.store(true, Ordering::SeqCst);
            drop(guard);
        });
        let lock2 = lock.clone();
        let barrier2 = Arc::clone(&barrier);
        let released2 = Arc::clone(&released);
        let t2 = thread::spawn(move || {
            barrier2.wait();
            let guard = lock2.declare(&[ResourceType::Foo]).unwrap();
            assert!(
                released2.load(Ordering::SeqCst),
                "declared Foo while another thread still held it"
            );
            assert!(guard.is_declared(&ResourceType::Foo));
        });
        t1.join().unwrap();
        t2.join().unwrap();
    }
}