use owning_ref_lockable::OwningHandle;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::{Arc, Mutex, MutexGuard};
use crate::error::{PoisonError, TryLockError, UnpoisonError};
use crate::guard::{Guard, GuardImpl};
use crate::mutex::{LockError, MutexImpl};
/// Interface of a pool of locks that are addressed by a key of type `K`.
///
/// Each key is locked independently of the others. In [`LockPoolImpl`] every key
/// gets its own mutex, created on demand and stored in an internal map.
pub trait LockPool<K>: Default
where
K: Eq + PartialEq + Hash + Clone + Debug,
{
/// Guard returned by [`LockPool::lock`] and [`LockPool::try_lock`].
/// It may borrow the pool for `'a` (in [`LockPoolImpl`] it holds `&'a Self`).
type Guard<'a>: Guard<K>
where
Self: 'a;
/// Guard returned by the `*_owned` methods. It has no lifetime parameter, so it is
/// not tied to a borrow of the pool (in [`LockPoolImpl`] it holds an `Arc` of the pool).
type OwnedGuard: Guard<K>;
/// Creates an empty pool; equivalent to `Self::default()`.
#[inline]
fn new() -> Self {
Self::default()
}
/// Returns how many keys currently have state in the pool.
///
/// [`LockPoolImpl`] returns the number of entries in its key map. That counts keys
/// that are locked, keys other callers are currently trying to lock, and keys left
/// poisoned.
fn num_locked_or_poisoned(&self) -> usize;
/// Locks `key`, waiting until it becomes available, and returns a guard borrowing the pool.
///
/// # Errors
///
/// Returns a [`PoisonError`] when the lock for `key` was poisoned. The error carries the
/// key and a guard that does hold the lock.
fn lock(&self, key: K) -> Result<Self::Guard<'_>, PoisonError<K, Self::Guard<'_>>>;
/// Like [`LockPool::lock`], but takes the pool through an `Arc` and returns an
/// [`LockPool::OwnedGuard`].
///
/// # Errors
///
/// Same as [`LockPool::lock`].
fn lock_owned(
self: &Arc<Self>,
key: K,
) -> Result<Self::OwnedGuard, PoisonError<K, Self::OwnedGuard>>;
/// Tries to lock `key` without waiting.
///
/// # Errors
///
/// - [`TryLockError::WouldBlock`] if the key is currently locked.
/// - [`TryLockError::Poisoned`] if the key's lock was poisoned. It carries a guard that
///   holds the lock.
fn try_lock(&self, key: K) -> Result<Self::Guard<'_>, TryLockError<K, Self::Guard<'_>>>;
/// Like [`LockPool::try_lock`], but takes the pool through an `Arc` and returns an
/// [`LockPool::OwnedGuard`].
///
/// # Errors
///
/// Same as [`LockPool::try_lock`].
fn try_lock_owned(
self: &Arc<Self>,
key: K,
) -> Result<Self::OwnedGuard, TryLockError<K, Self::OwnedGuard>>;
/// Clears the poisoned state of `key`.
///
/// # Errors
///
/// - [`UnpoisonError::NotPoisoned`] if `key` is not poisoned.
/// - [`UnpoisonError::OtherThreadsBlockedOnMutex`] if someone else still references the
///   key's mutex (as implemented in [`LockPoolImpl`]).
///
/// # Panics
///
/// [`LockPoolImpl`] panics if its mutex type does not support poisoning.
fn unpoison(&self, key: K) -> Result<(), UnpoisonError>;
}
/// Default [`LockPool`] implementation, generic over the per-key mutex type `M`.
///
/// A key gets an entry in an internal map (key -> `Arc<M>`) the first time it is locked.
/// The entry is removed again when the last user unlocks it without panicking or
/// poisoning it, or when a poisoned key is unpoisoned.
pub struct LockPoolImpl<K, M>
where
K: Eq + PartialEq + Hash + Clone + Debug,
M: MutexImpl,
{
/// Key -> mutex for that key.
/// The outer `Mutex` serializes every access to the map. The strong count of an
/// entry's `Arc` shows whether anyone besides the map still holds, or is waiting
/// for, that key's mutex.
currently_locked: Mutex<HashMap<K, Arc<M>>>,
// NOTE(review): `M` already appears in `currently_locked`, so this marker looks redundant.
// It is kept because other code in this module may construct the struct with it.
_p: PhantomData<M>,
}
impl<K, M> Default for LockPoolImpl<K, M>
where
K: Eq + PartialEq + Hash + Clone + Debug,
M: MutexImpl,
{
#[inline]
fn default() -> Self {
Self {
currently_locked: Mutex::new(HashMap::new()),
_p: PhantomData,
}
}
}
impl<K, M> LockPool<K> for LockPoolImpl<K, M>
where
    K: Eq + PartialEq + Hash + Clone + Debug + 'static,
    M: MutexImpl + 'static,
{
    /// Guard that holds a `&'a` borrow of the pool.
    type Guard<'a> = GuardImpl<'a, K, M, &'a Self>;
    /// Guard that holds its own `Arc` of the pool.
    type OwnedGuard = GuardImpl<'static, K, M, Arc<Self>>;

    /// Number of entries currently stored in the key map.
    #[inline]
    fn num_locked_or_poisoned(&self) -> usize {
        let map = self._currently_locked();
        map.len()
    }

    /// Blocking lock whose guard borrows `self`.
    fn lock(&self, key: K) -> Result<Self::Guard<'_>, PoisonError<K, Self::Guard<'_>>> {
        let pool: &Self = self;
        Self::_lock(pool, key)
    }

    /// Blocking lock whose guard owns a fresh `Arc` of the pool.
    fn lock_owned(
        self: &Arc<Self>,
        key: K,
    ) -> Result<Self::OwnedGuard, PoisonError<K, Self::OwnedGuard>> {
        let pool: Arc<Self> = Arc::clone(self);
        Self::_lock(pool, key)
    }

    /// Non-blocking lock whose guard borrows `self`.
    fn try_lock(&self, key: K) -> Result<Self::Guard<'_>, TryLockError<K, Self::Guard<'_>>> {
        let pool: &Self = self;
        Self::_try_lock(pool, key)
    }

    /// Non-blocking lock whose guard owns a fresh `Arc` of the pool.
    fn try_lock_owned(
        self: &Arc<Self>,
        key: K,
    ) -> Result<Self::OwnedGuard, TryLockError<K, Self::OwnedGuard>> {
        let pool: Arc<Self> = Arc::clone(self);
        Self::_try_lock(pool, key)
    }

    /// Removes the poisoned entry for `key` so that the next lock creates a fresh mutex.
    ///
    /// # Errors
    ///
    /// - `NotPoisoned` if `key` has no entry, or if its mutex locks without a poison error.
    /// - `OtherThreadsBlockedOnMutex` if anything besides the map references the mutex.
    ///
    /// # Panics
    ///
    /// Panics if `M` does not support poisoning.
    fn unpoison(&self, key: K) -> Result<(), UnpoisonError> {
        assert!(
            M::SUPPORTS_POISONING,
            "This lock pool doesn't support poisoning"
        );
        let mut map = self._currently_locked();
        let Some(entry) = map.get(&key) else {
            return Err(UnpoisonError::NotPoisoned);
        };
        // Check the count before taking our own clone below; the map itself accounts for 1.
        if Arc::strong_count(entry) > 1 {
            return Err(UnpoisonError::OtherThreadsBlockedOnMutex);
        }
        // Lock through a private clone of the `Arc` so the guard doesn't borrow `map`.
        // The guard (and that clone) stay alive until the end of this `let` statement,
        // i.e. across the removal below.
        let outcome = match Arc::clone(entry).lock() {
            Err(_) => {
                let removed = map.remove(&key);
                assert!(
                    removed.is_some(),
                    "We just got this entry above from the hash map, it cannot have vanished since then"
                );
                Ok(())
            }
            Ok(_) => Err(UnpoisonError::NotPoisoned),
        };
        outcome
    }
}
impl<K, M> LockPoolImpl<K, M>
where
    K: Eq + PartialEq + Hash + Clone + Debug,
    M: MutexImpl,
{
    /// Locks the global mutex and returns the map from key to per-key mutex.
    ///
    /// # Panics
    ///
    /// Panics if the global mutex is poisoned.
    fn _currently_locked(&self) -> MutexGuard<'_, HashMap<K, Arc<M>>> {
        self.currently_locked
            .lock()
            .expect("The global mutex protecting the lock pool is poisoned. This shouldn't happen since there shouldn't be any user code running while this lock is held so no thread should ever panic with it")
    }

    /// Returns the mutex registered for `key`. If `key` has no entry yet, a new mutex is
    /// created and registered first.
    ///
    /// The returned `Arc` is an extra strong reference. While the caller holds it, the
    /// entry's strong count stays above 1, so `_unlock` will not remove the entry.
    pub(super) fn _load_or_insert_mutex_for_key(&self, key: &K) -> Arc<M> {
        let mut currently_locked = self._currently_locked();
        // Fast path: a read-only lookup. The key is only cloned when a new entry has to
        // be inserted.
        if let Some(mutex) = currently_locked.get(key) {
            return Arc::clone(mutex);
        }
        let mutex = Arc::new(M::new());
        let insert_result = currently_locked.insert(key.clone(), Arc::clone(&mutex));
        assert!(
            insert_result.is_none(),
            "We just checked that the entry doesn't exist, why does it exist now?"
        );
        mutex
    }

    /// Waits for the mutex of `key` and wraps the acquired lock in a [`GuardImpl`].
    ///
    /// `this` is whatever the guard should hold on to: `&Self` for a borrowing guard or
    /// `Arc<Self>` for an owned one.
    ///
    /// # Errors
    ///
    /// If the mutex was poisoned, the lock is still acquired. The guard, marked as
    /// poisoned, is returned inside a [`PoisonError`].
    fn _lock<'a, S: 'a + Deref<Target = Self>>(this: S, key: K) -> LockResult<'a, K, M, S> {
        // The global map lock is released again before we wait on the per-key mutex.
        let mutex = this._load_or_insert_mutex_for_key(&key);
        let mut poisoned = false;
        let guard = OwningHandle::new_with_fn(mutex, |mutex: *const M| {
            // SAFETY: relies on `OwningHandle` keeping the `Arc<M>` it was given alive for
            // as long as the handle (and the guard created here) exists. The pointee of
            // an `Arc` has a stable address.
            let mutex: &M = unsafe { &*mutex };
            match mutex.lock() {
                Ok(guard) => guard,
                Err(poison_error) => {
                    poisoned = true;
                    poison_error.into_inner()
                }
            }
        });
        if poisoned {
            let guard = GuardImpl::new(this, key.clone(), guard, true);
            Err(PoisonError { key, guard })
        } else {
            let guard = GuardImpl::new(this, key, guard, false);
            Ok(guard)
        }
    }

    /// Non-blocking variant of `_lock`.
    ///
    /// # Errors
    ///
    /// - `TryLockError::WouldBlock` if the mutex of `key` is currently held.
    /// - `TryLockError::Poisoned` (with a guard marked as poisoned) if it was poisoned.
    fn _try_lock<'a, S: 'a + Deref<Target = Self>>(this: S, key: K) -> TryLockResult<'a, K, M, S> {
        let mutex = this._load_or_insert_mutex_for_key(&key);
        let mut poisoned = false;
        // NOTE(review): on the `WouldBlock` path, the `Arc<M>` handed to `try_new` is dropped
        // without the global map lock held. If the current holder runs `_unlock` just before
        // that drop, it sees a strong count > 1 and keeps the entry. The entry then stays in
        // the map (unlocked, counted by `num_locked_or_poisoned`) until the key is locked and
        // unlocked again. Confirm whether that is acceptable.
        let guard = OwningHandle::try_new(mutex, |mutex: *const M| {
            // SAFETY: relies on `OwningHandle` keeping the `Arc<M>` it was given alive for
            // as long as the handle (and the guard created here) exists. The pointee of
            // an `Arc` has a stable address.
            let mutex: &M = unsafe { &*mutex };
            match mutex.try_lock() {
                Ok(guard) => Ok(guard),
                Err(std::sync::TryLockError::Poisoned(poison_error)) => {
                    poisoned = true;
                    Ok(poison_error.into_inner())
                }
                Err(std::sync::TryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
            }
        })?;
        if poisoned {
            let guard = GuardImpl::new(this, key.clone(), guard, true);
            Err(TryLockError::Poisoned(PoisonError { key, guard }))
        } else {
            let guard = GuardImpl::new(this, key, guard, false);
            Ok(guard)
        }
    }

    /// Releases `guard` for `key` and removes the key's map entry if nobody else uses it.
    ///
    /// With `poisoned == true`, the guard is just dropped when this function returns. The
    /// entry stays in the map (presumably the guard passes the flag it was created with
    /// in `_lock`/`_try_lock`).
    ///
    /// # Panics
    ///
    /// Panics if `key` has no entry. That would mean `guard` was not created by this pool.
    pub(super) fn _unlock(
        &self,
        key: &K,
        guard: OwningHandle<Arc<M>, M::Guard<'_>>,
        poisoned: bool,
    ) {
        if poisoned {
            return;
        }
        let mut currently_locked = self._currently_locked();
        let mutex: &Arc<M> = currently_locked
            .get(key)
            .expect("This entry must exist or the guard passed in as a parameter shouldn't exist");
        // Release the per-key lock (and our `Arc`) while the global map lock is held.
        // No other thread can fetch this entry between the release and the count check,
        // because fetching also needs the global lock.
        std::mem::drop(guard);
        // A count of 1 means only the map still references the mutex: nobody holds it and
        // nobody obtained it to wait on it.
        // NOTE(review): while panicking, the entry is kept. For mutex types that poison, the
        // drop above presumably poisons it. For types that don't, this leaves an unlocked
        // entry behind. Confirm that this is intended.
        if Arc::strong_count(mutex) == 1 && !std::thread::panicking() {
            let remove_result = currently_locked.remove(key);
            assert!(
                remove_result.is_some(),
                "We just got this entry above from the hash map, it cannot have vanished since then"
            );
        }
    }
}
/// Result of a blocking lock on a [`LockPoolImpl`]. On success it holds the guard. If
/// the key's mutex was poisoned, it holds a [`PoisonError`] carrying the key and a guard
/// that holds the lock.
pub type LockResult<'a, K, M, S> =
Result<GuardImpl<'a, K, M, S>, PoisonError<K, GuardImpl<'a, K, M, S>>>;
/// Result of a non-blocking lock attempt on a [`LockPoolImpl`]. On success it holds the
/// guard. Otherwise it holds a [`TryLockError`]: `WouldBlock` if the key is held, or
/// `Poisoned` carrying a guard.
pub type TryLockResult<'a, K, M, S> =
Result<GuardImpl<'a, K, M, S>, TryLockError<K, GuardImpl<'a, K, M, S>>>;
#[cfg(test)]
mod tests;
#[cfg(feature = "tokio")]
pub mod pool_async;
pub mod pool_sync;