#![warn(missing_docs)]
pub mod lock_in_order;
pub use lock_in_order::lock;
mod sync;
use sync::MutexGuard as StdMutexGuard;
use sync::*;
use std::cell::RefCell;
use std::sync::atomic::AtomicUsize;
use std::sync::{PoisonError, TryLockError};
/// Global counter that hands out lock-scope ids. Every `fetch_add` yields a new, larger
/// value, so a smaller id identifies a scope that obtained its id earlier.
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
sync::thread_local!(
    // Each thread owns one lock scope, identified by the next id drawn from `THREAD_ID`.
    static THIS_SCOPE: RefCell<LockScope> = {
        let first_id = THREAD_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        RefCell::new(LockScope::new(first_id))
    };
);
/// A mutual-exclusion lock that cooperates with other `CoopMutex`es to avoid deadlock.
///
/// Every thread locks through its own lock scope, which carries a numeric id; a lower id
/// belongs to an older scope and wins conflicts. When a thread that already holds other
/// `CoopMutex` guards contends for a mutex held by an older scope, [`CoopMutex::lock`]
/// returns [`Retry`] instead of blocking. The caller is then expected to drop its guards and
/// start over, typically via [`retry_loop`]. In every other contended case the caller blocks.
#[derive(Default)]
pub struct CoopMutex<T> {
    /// The mutex that actually protects the data.
    native: Mutex<T>,
    /// Bookkeeping: scope id of the current holder and of the primary waiter, if any.
    held_waiter: Mutex<HeldWaiter>,
    /// Notified (all) on release; secondary waiters and retrying threads wait here.
    waiters: Condvar,
    /// Notified (one) to wake the single primary waiter.
    primary_waiter: Condvar,
}
/// `(holder, primary_waiter)`: the scope id currently holding a `CoopMutex`'s native lock and
/// the scope id queued as its primary waiter, each `None` when absent.
type HeldWaiter = (Option<usize>, Option<usize>);
impl<T> CoopMutex<T> {
    /// Creates a new, unlocked `CoopMutex` protecting `item`.
    pub fn new(item: T) -> Self {
        CoopMutex {
            native: Mutex::new(item),
            held_waiter: Mutex::new((None, None)),
            waiters: Condvar::new(),
            primary_waiter: Condvar::new(),
        }
    }

    /// Acquires the mutex on behalf of the calling thread's lock scope.
    ///
    /// If the mutex is free it is taken immediately. If another thread holds it, the calling
    /// thread either blocks until it is released or, to break a potential deadlock, gives up
    /// and returns [`Retry`]. See [`retry_loop`] for how to handle that case.
    ///
    /// # Errors
    ///
    /// Returns `Err(Retry)` when the calling thread already holds at least one other
    /// `CoopMutex` guard and the mutex is currently held by an older (lower-id) scope. The
    /// caller is expected to release its guards, e.g. by returning the `Retry` out of the
    /// closure given to [`retry_loop`]. The inner `LockResult` is `Err` when the protected data
    /// is poisoned; the guard is still available through the `PoisonError`.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread already holds this mutex.
    pub fn lock(&self) -> Result<LockResult<MutexGuard<'_, T>>, Retry<'_>> {
        THIS_SCOPE.with(|scope| scope.borrow().lock(self))
    }

    /// Returns a mutable reference to the protected data.
    ///
    /// Exclusive access is guaranteed statically by `&mut self`, so no locking takes place.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the protected data is poisoned; the reference is still available
    /// through the `PoisonError`.
    #[cfg(not(feature = "loom-tests"))]
    pub fn get_mut(&mut self) -> LockResult<&mut T> {
        self.native.get_mut()
    }
}
impl<T: core::fmt::Debug> core::fmt::Debug for CoopMutex<T> {
    /// Formats exactly like the underlying native mutex.
    fn fmt(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
        core::fmt::Debug::fmt(&self.native, formatter)
    }
}
impl<T> From<T> for CoopMutex<T> {
fn from(t: T) -> Self {
CoopMutex::new(t)
}
}
/// Per-thread locking context: an id used to arbitrate conflicts (a lower id is older and
/// takes precedence) plus a reference count of the guards this scope currently has alive.
struct LockScope {
    /// Cloned into every guard this scope hands out; `strong_count - 1` is the number of
    /// guards still alive.
    lock_count: Arc<()>,
    /// Identifier compared against other scopes' ids when locks conflict.
    id: usize,
}
impl LockScope {
    /// Locks `mutex` on behalf of this scope.
    ///
    /// On success the native guard is wrapped in a [`MutexGuard`] and this scope is recorded
    /// as the holder. If the protected data is poisoned, the wrapped guard is still handed back
    /// inside a `PoisonError`, mirroring `std::sync::Mutex::lock`.
    ///
    /// # Errors
    ///
    /// Returns [`Retry`] when this scope has to release its other locks and start over (see
    /// `lock_native` for the exact rules).
    ///
    /// # Panics
    ///
    /// Panics if this scope already holds `mutex`.
    fn lock<'m, T>(
        &self,
        mutex: &'m CoopMutex<T>,
    ) -> Result<LockResult<MutexGuard<'m, T>>, Retry<'m>> {
        self.lock_native(mutex).map(|result| match result {
            Ok(g) => Ok(self.guard(g, mutex)),
            Err(p) => Err(PoisonError::new(self.guard(p.into_inner(), mutex))),
        })
    }

    /// Acquires the native mutex inside `mutex`, cooperating with other scopes so that threads
    /// holding several `CoopMutex`es cannot deadlock each other.
    ///
    /// Every attempt starts with a non-blocking `try_lock`. If another scope holds the mutex,
    /// the `held_waiter` bookkeeping decides what to do (a lower id means an older scope):
    ///
    /// 1. This scope holds no other guards: wait on `waiters` until a release.
    /// 2. This scope is younger than the holder: give up and return [`Retry`].
    /// 3. An older scope is already the primary waiter: wait on `waiters`.
    /// 4. Otherwise: become the primary waiter, notify `primary_waiter` so a displaced primary
    ///    waiter re-evaluates, then wait on `primary_waiter`.
    ///
    /// Whenever the holder slot is observed empty, control goes back to `try_lock`.
    ///
    /// # Panics
    ///
    /// Panics if this scope already holds `mutex`.
    fn lock_native<'m, T>(
        &self,
        mutex: &'m CoopMutex<T>,
    ) -> Result<LockResult<StdMutexGuard<'m, T>>, Retry<'m>> {
        loop {
            match mutex.native.try_lock() {
                Ok(g) => return Ok(Ok(g)),
                Err(TryLockError::Poisoned(p)) => return Ok(Err(p)),
                Err(TryLockError::WouldBlock) => {}
            }
            let mut lock = mutex.held_waiter.lock().unwrap();
            loop {
                lock = match &mut *lock {
                    // Released (or not yet recorded as held): retry the native lock.
                    (None, _) => break,
                    (Some(holder), _) if self.id == *holder => {
                        // Release the bookkeeping lock before panicking. Panicking while it is
                        // held poisons `held_waiter`; the unwinding drop of this thread's
                        // existing guard would then panic again (aborting the process), and
                        // every other thread would panic on its next `lock().unwrap()`.
                        drop(lock);
                        panic!("Attempted to lock a CoopMutex already held by this thread")
                    }
                    (Some(holder), Some(waiter)) if holder == waiter => {
                        drop(lock);
                        unreachable!("Held and waited by same thread")
                    }
                    // Holding nothing else, this scope cannot be part of a deadlock cycle.
                    _ if self.active_locks() == 0 => mutex.waiters.wait(lock).unwrap(),
                    // Younger than the holder while holding other locks: back off.
                    (Some(holder), _) if self.id > *holder => return Err(self.retry(mutex)),
                    // An older scope is already first in line.
                    (_, Some(waiter)) if self.id > *waiter => mutex.waiters.wait(lock).unwrap(),
                    _ => {
                        lock.1 = Some(self.id);
                        mutex.primary_waiter.notify_one();
                        mutex.primary_waiter.wait(lock).unwrap()
                    }
                }
            }
        }
    }

    /// Builds the [`Retry`] token that lets the caller wait until `mutex` is released.
    fn retry<'m, T>(&self, mutex: &'m CoopMutex<T>) -> Retry<'m> {
        Retry {
            waiters: &mutex.waiters,
            mutex: &mutex.held_waiter,
        }
    }

    /// Records this scope as the holder of `mutex` (clearing the primary-waiter slot) and
    /// wraps `native` in a [`MutexGuard`] that counts towards this scope's active locks.
    fn guard<'m, T>(
        &self,
        native: StdMutexGuard<'m, T>,
        mutex: &'m CoopMutex<T>,
    ) -> MutexGuard<'m, T> {
        let mut held_waiter = mutex.held_waiter.lock().unwrap();
        held_waiter.0 = Some(self.id);
        held_waiter.1 = None;
        MutexGuard {
            native,
            mutex,
            _lock_count: Arc::clone(&self.lock_count),
        }
    }

    /// Creates a scope with the given id and no active locks.
    fn new(id: usize) -> LockScope {
        LockScope {
            id,
            lock_count: Arc::new(()),
        }
    }

    /// Number of guards created by this scope that are still alive (each holds a clone of
    /// `lock_count`).
    fn active_locks(&self) -> usize {
        Arc::strong_count(&self.lock_count) - 1
    }

    /// Draws a new id from `THREAD_ID` if this scope currently holds no guards; with guards
    /// alive the id is left unchanged.
    fn update_id_for_fairness(&mut self) {
        if self.active_locks() == 0 {
            self.id = THREAD_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }
    }
}
/// Guard returned by [`CoopMutex::lock`], giving access to the protected data.
///
/// Dropping the guard releases the lock and wakes threads waiting for it. The guard also
/// counts as an active lock of the scope that created it, which influences how that scope
/// behaves when it contends for further `CoopMutex`es.
#[must_use = "if unused the CoopMutex will immediately unlock"]
pub struct MutexGuard<'m, T> {
    /// Guard of the underlying native mutex; provides the data access.
    native: StdMutexGuard<'m, T>,
    /// The mutex this guard belongs to; its bookkeeping is cleared on drop.
    mutex: &'m CoopMutex<T>,
    /// Clone of the owning scope's counter; keeps that scope's active-lock count up while alive.
    _lock_count: Arc<()>,
}
impl<'m, T> Drop for MutexGuard<'m, T> {
fn drop(&mut self) {
let mut held_waiter = self.mutex.held_waiter.lock().unwrap();
held_waiter.0 = None;
self.mutex.primary_waiter.notify_one();
self.mutex.waiters.notify_all();
drop(held_waiter);
}
}
impl<'m, T> core::ops::Deref for MutexGuard<'m, T> {
    type Target = T;
    /// Shared access to the data protected by the held mutex.
    fn deref(&self) -> &T {
        &*self.native
    }
}
impl<'m, T> core::ops::DerefMut for MutexGuard<'m, T> {
    /// Exclusive access to the data protected by the held mutex.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.native
    }
}
/// Signal from [`CoopMutex::lock`] that the calling thread must back off to avoid deadlock.
///
/// The thread should release every `CoopMutex` guard it holds and start over. Returning the
/// `Retry` out of the closure passed to [`retry_loop`] does this: the loop waits until the
/// contended mutex has no holder and then runs the closure again.
#[derive(Debug)]
pub struct Retry<'m> {
    /// Bookkeeping mutex of the contended `CoopMutex`.
    mutex: &'m Mutex<HeldWaiter>,
    /// Condvar notified when the contended `CoopMutex` is released.
    waiters: &'m Condvar,
}
impl<'m> Retry<'m> {
    /// Blocks until the contended `CoopMutex` has no recorded holder.
    fn wait(self) {
        let mut lock = self.mutex.lock().unwrap();
        while lock.0.is_some() {
            lock = self.waiters.wait(lock).unwrap();
        }
    }
}
/// Repeatedly runs `f` until it completes without a lock conflict, and returns its result.
///
/// `f` should acquire its `CoopMutex`es with `mutex.lock()?`. A [`Retry`] returned by
/// [`CoopMutex::lock`] then propagates out of the closure, dropping every guard acquired
/// inside it. `retry_loop` blocks until the contended mutex has been released and calls `f`
/// again.
///
/// The thread keeps its scope id across retries, so backing off does not cost it its place
/// relative to threads that obtained their ids later. Once `f` succeeds and the thread holds no
/// `CoopMutex` guards, its scope draws a new id from the global counter.
///
/// Guards acquired *outside* `f` are not released by a retry; they stay locked while
/// `retry_loop` waits.
pub fn retry_loop<'m, T, F: FnMut() -> Result<T, Retry<'m>>>(mut f: F) -> T {
    loop {
        match f() {
            Ok(t) => {
                THIS_SCOPE.with(|s| s.borrow_mut().update_id_for_fairness());
                return t;
            }
            Err(retry) => retry.wait(),
        }
    }
}
// Unit tests that drive `LockScope` directly with fixed ids (a lower id is the older scope)
// instead of going through the thread-local `THIS_SCOPE`.
#[cfg(all(test, not(feature = "loom-tests")))]
mod tests {
use super::*;
use std::time::Duration;
// Scope 1 holds `b` and asks for `a`, which is held by the older scope 0. Because scope 1
// already holds a lock and is younger, it must get `Retry`. Meanwhile the spawned thread may
// block on `b` as scope 0 (which still counts `x1` as active) until both guards are dropped.
#[test]
fn second_thread_retries() {
let a = CoopMutex::new(42);
let b = CoopMutex::new(43);
let s1 = LockScope::new(0);
let s2 = LockScope::new(1);
crossbeam::thread::scope(|s| {
let x1 = s1.lock(&a).unwrap();
let x2 = s2.lock(&b).unwrap();
s.spawn(|_| {
let _ = s1.lock(&b).unwrap();
});
assert!(s2.lock(&a).is_err());
drop((x1, x2));
})
.unwrap();
}
// The older scope 0, holding nothing else, waits for the younger holder (scope 1) to release
// the mutex instead of retrying, and then reads the value.
#[test]
fn first_thread_blocks() {
let mutex = CoopMutex::new(42);
let s1 = LockScope::new(0);
let s2 = LockScope::new(1);
crossbeam::thread::scope(|s| {
let lock = s2.lock(&mutex).unwrap();
s.spawn(|_| {
assert_eq!(*s1.lock(&mutex).unwrap().unwrap(), 42);
});
std::thread::sleep(Duration::from_millis(100));
drop(lock);
})
.unwrap();
}
// The younger scope 1 holds no other locks, so it waits for the older holder instead of
// receiving `Retry` (the "no active locks" rule is checked before the id comparison).
#[test]
fn second_waits_if_not_holding_other_locks() {
let mutex = CoopMutex::new(42);
let s1 = LockScope::new(0);
let s2 = LockScope::new(1);
crossbeam::thread::scope(|s| {
s.spawn(|_| {
let lock = s1.lock(&mutex);
std::thread::sleep(Duration::from_millis(100));
drop(lock);
});
// Give the spawned thread time to take the lock first.
std::thread::sleep(Duration::from_millis(10));
assert_eq!(*s2.lock(&mutex).unwrap().unwrap(), 42);
})
.unwrap();
}
}
// Model-checking tests run under loom (enabled with the `loom-tests` feature).
#[cfg(all(test, feature = "loom-tests"))]
mod loom_tests {
use super::*;
use loom::{self, sync::Arc};
// Two threads lock `a` and `b` in opposite orders inside `retry_loop`. With plain mutexes
// this lock-order inversion can deadlock; with `CoopMutex` every explored interleaving is
// expected to finish.
// NOTE(review): marked #[ignore]; the reason is not stated here (possibly the cost of exploring
// all interleavings, or brief spin loops in `lock_native`). Confirm before relying on it.
#[test]
#[ignore]
fn loom_deadlock() {
loom::model(|| {
let a = Arc::new(CoopMutex::new(42));
let b = Arc::new(CoopMutex::new(43));
let t1 = {
let a = a.clone();
let b = b.clone();
loom::thread::spawn(move || {
retry_loop(|| {
let a = a.lock()?.unwrap();
let mut b = b.lock()?.unwrap();
*b += *a;
Ok(())
});
})
};
let t2 = {
let a = a.clone();
let b = b.clone();
loom::thread::spawn(move || {
retry_loop(|| {
let b = b.lock()?.unwrap();
let mut a = a.lock()?.unwrap();
*a += *b;
Ok(())
});
})
};
t1.join().unwrap();
t2.join().unwrap();
});
}
}