#![cfg_attr(feature = "nightly", feature(const_fn))]
use lock_api::{RawReentrantMutex, RawMutex, GetThreadId};
use std::fmt;
use std::mem::ManuallyDrop;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::cell::{RefMut, RefCell};
/// A mutex whose guards can temporarily *lend* the protected data back to
/// code running on the same thread without releasing the lock.
///
/// Other threads are excluded by a reentrant raw mutex. On the owning thread,
/// overlapping guards are excluded by a `RefCell` borrow flag. A nested
/// `lock` therefore only succeeds while an outer guard is inside `lend`.
pub struct LendableMutex<R, G, T: ?Sized> {
    /// Reentrant raw lock that keeps other threads out.
    raw: RawReentrantMutex<R, G>,
    /// The protected value. Its borrow flag tracks which guard on the owning
    /// thread currently has access. It has to be the last field because `T`
    /// may be unsized.
    data: RefCell<T>,
}
// SAFETY: moving the mutex moves the raw lock, the thread-id source and the
// protected value, so all of them must be `Send`. Values of this type can only
// be built through `new`, which additionally requires `R: RawMutex` and
// `G: GetThreadId`.
unsafe impl<R: Send, G: Send, T: ?Sized + Send> Send for LendableMutex<R, G, T> {}
// SAFETY: a shared `&LendableMutex` lets whichever thread wins the lock obtain
// `&mut T` through a guard. That is equivalent to sending `T` between threads,
// so `T: Send` is required. `T: Sync` is not. The raw lock and thread-id source
// are used through `&` from several threads, so they must be `Sync`. The
// `RefCell` borrow flag is only touched by the thread that currently holds the
// raw lock.
unsafe impl<R: Sync, G: Sync, T: ?Sized + Send> Sync for LendableMutex<R, G, T> {}
impl<R: RawMutex, G: GetThreadId, T: ?Sized> LendableMutex<R, G, T> {
    /// Creates a new, unlocked mutex wrapping `v` (a `const fn` under the
    /// `nightly` feature).
    #[cfg(feature = "nightly")]
    pub const fn new(v: T) -> Self where T: Sized {
        Self {
            raw: RawReentrantMutex::INIT,
            data: RefCell::new(v),
        }
    }

    /// Creates a new, unlocked mutex wrapping `v`.
    #[cfg(not(feature = "nightly"))]
    pub fn new(v: T) -> Self where T: Sized {
        Self {
            raw: RawReentrantMutex::INIT,
            data: RefCell::new(v),
        }
    }

    /// Consumes the mutex and returns the wrapped value.
    pub fn into_inner(self) -> T where T: Sized {
        self.data.into_inner()
    }

    /// Wraps a raw-lock acquisition that the current thread has *just* made in
    /// a guard.
    ///
    /// If the data is already borrowed, the guard cannot be created. This
    /// happens when this thread holds another guard that is not inside `lend`.
    /// In that case the fresh acquisition is released before panicking.
    /// Otherwise the panic would leave the reentrant lock count permanently
    /// raised and every other thread would deadlock.
    ///
    /// # Panics
    /// Panics if the data is already borrowed by a guard on this thread.
    #[track_caller]
    #[inline]
    fn guard(&self) -> LendableMutexGuard<'_, R, G, T> {
        match self.data.try_borrow_mut() {
            Ok(refmut) => LendableMutexGuard {
                mutex: self,
                refmut: ManuallyDrop::new(refmut),
                marker: PhantomData,
            },
            Err(_) => {
                // SAFETY: every caller of `guard` has just acquired `raw` on
                // this thread, and no guard owns that acquisition yet, so
                // releasing it here keeps lock/unlock balanced.
                unsafe { self.force_unlock() };
                panic!("LendableMutex: data is already borrowed by a guard on this thread that is not lending it");
            }
        }
    }

    /// Releases one level of the underlying reentrant raw lock.
    ///
    /// This does not touch the `RefCell` borrow. A borrow held by a leaked
    /// guard stays in place.
    ///
    /// # Safety
    /// The current thread must hold the raw lock. The released acquisition
    /// must not belong to a live guard that will release it again when
    /// dropped.
    pub unsafe fn force_unlock(&self) {
        self.raw.unlock();
    }

    /// Returns the underlying raw reentrant mutex.
    ///
    /// # Safety
    /// NOTE(review): presumably `unsafe` because the raw mutex can be locked or
    /// unlocked directly, bypassing the guards. Callers must keep their raw
    /// lock/unlock calls balanced with this mutex's guards.
    pub unsafe fn raw(&self) -> &RawReentrantMutex<R, G> {
        &self.raw
    }

    /// Acquires the lock, blocking until it is available, and returns a guard.
    ///
    /// The raw lock is reentrant. A nested `lock` on the thread that already
    /// holds it succeeds only while the outer guard is inside
    /// [`LendableMutexGuard::lend`].
    ///
    /// # Panics
    /// Panics if this thread already holds a guard that is not lending. The
    /// reentrant acquisition made by this call is released first.
    #[track_caller]
    pub fn lock(&self) -> LendableMutexGuard<'_, R, G, T> {
        self.raw.lock();
        self.guard()
    }

    /// Attempts to acquire the lock without blocking.
    ///
    /// Returns `None` if the raw lock could not be acquired, e.g. because
    /// another thread holds it.
    ///
    /// # Panics
    /// Panics under the same condition as [`lock`](Self::lock). The
    /// acquisition made by this call is released first.
    #[track_caller]
    pub fn try_lock(&self) -> Option<LendableMutexGuard<'_, R, G, T>> {
        if self.raw.try_lock() {
            Some(self.guard())
        } else {
            None
        }
    }
}
/// Guard returned by `LendableMutex::lock` / `try_lock`.
///
/// Dereferences to the protected data. Dropping it first releases the data
/// borrow and then one level of the reentrant raw lock.
pub struct LendableMutexGuard<'a, R: RawMutex, G: GetThreadId, T: ?Sized> {
    /// The mutex this guard was created from.
    mutex: &'a LendableMutex<R, G, T>,
    /// Active borrow of the data. It is wrapped in `ManuallyDrop` so that
    /// `lend` can release and re-acquire it in place.
    refmut: ManuallyDrop<RefMut<'a, T>>,
    /// Behaves like `&'a mut T` for variance. Also carries `R::GuardMarker`,
    /// lock_api's marker for whether guards may be sent between threads.
    marker: PhantomData<(&'a mut T, R::GuardMarker)>,
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Deref for LendableMutexGuard<'a, R, G, T> {
type Target = T;
fn deref(&self) -> &T { &*self.refmut }
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> DerefMut for LendableMutexGuard<'a, R, G, T> {
fn deref_mut(&mut self) -> &mut T { &mut *self.refmut }
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Debug> Debug for LendableMutexGuard<'a, R, G, T> {
    /// Formats the guarded value itself, not the guard.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <T as Debug>::fmt(self, f)
    }
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized + Display> Display for LendableMutexGuard<'a, R, G, T> {
    /// Displays the guarded value itself, not the guard.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <T as Display>::fmt(self, f)
    }
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> Drop for LendableMutexGuard<'a, R, G, T> {
    /// Releases the data borrow first and the raw lock second. This ordering
    /// ensures no other thread can acquire the lock while this guard's
    /// `RefMut` is still alive.
    fn drop(&mut self) {
        // SAFETY: the guard is being destroyed, so `refmut` is never used again.
        unsafe { ManuallyDrop::drop(&mut self.refmut) };
        // SAFETY: this guard owns the raw-lock acquisition made just before it
        // was constructed by `lock` / `try_lock`.
        unsafe { self.mutex.force_unlock() };
    }
}
impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> LendableMutexGuard<'a, R, G, T> {
    /// Returns the mutex this guard locks.
    pub fn mutex(&self) -> &'a LendableMutex<R, G, T> {
        self.mutex
    }

    /// Temporarily gives up this guard's borrow of the data and runs `f`,
    /// while keeping the raw lock held.
    ///
    /// Other threads still cannot acquire the mutex. Code in `f` on this
    /// thread can re-`lock` it (reentrantly) and use the data. The borrow is
    /// re-acquired when `f` returns or unwinds.
    ///
    /// # Aborts
    /// If `f` leaves the data borrowed, this guard's borrow cannot be restored.
    /// That happens, for example, when `f` moves a nested guard out of the
    /// closure or leaks it with `mem::forget`. The process is then aborted.
    /// Unwinding instead would let this guard's `Drop` release an
    /// already-released borrow a second time.
    pub fn lend(&mut self, f: impl FnOnce()) {
        // SAFETY: `refmut` is re-initialised by `_restore` before this function
        // returns or unwinds (or the process aborts), so the dropped value is
        // never read or dropped again.
        unsafe { ManuallyDrop::drop(&mut self.refmut) };
        let _restore = defer::defer(|| match self.mutex.data.try_borrow_mut() {
            Ok(refmut) => self.refmut = ManuallyDrop::new(refmut),
            Err(_) => {
                eprintln!("LendableMutexGuard::lend: data is still borrowed after the lending closure finished; aborting");
                std::process::abort();
            }
        });
        f();
    }
}
/// `LendableMutex` backed by `parking_lot`'s raw mutex and thread-id source.
pub type PlLendableMutex<T> =
    LendableMutex<parking_lot::RawMutex, parking_lot::RawThreadId, T>;

/// Guard produced by locking a `PlLendableMutex`.
pub type PlLendableMutexGuard<'a, T> =
    LendableMutexGuard<'a, parking_lot::RawMutex, parking_lot::RawThreadId, T>;
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;
    use super::*;

    /// Concurrent increments through `lock` must not be lost.
    #[test]
    fn basic_mutex() {
        let m = Arc::new(PlLendableMutex::new(0));
        let handles: Vec<_> = (0..100)
            .map(|_| {
                let m2 = Arc::clone(&m);
                thread::spawn(move || {
                    let mut l = m2.lock();
                    *l += 1;
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(*m.lock(), 100);
    }

    /// While a guard is lending, other threads stay locked out, but the same
    /// thread can re-lock and mutate the data. The outer guard sees that
    /// mutation once `lend` returns.
    #[test]
    fn stays_locked() {
        let m = Arc::new(PlLendableMutex::new(0));
        let handles: Vec<_> = (0..100)
            .map(|_| {
                let m2 = Arc::clone(&m);
                thread::spawn(move || {
                    let id = thread::current().id();
                    println!("[{:?}] locking", id);
                    let mut l = m2.lock();
                    println!("[{:?}] locked", id);
                    let old = *l;
                    *l += 1;
                    println!("[{:?}] lending", id);
                    l.lend(|| {
                        println!("[{:?}] lent", id);
                        // Leave a window in which another thread could wrongly
                        // grab the lock if lending released it.
                        thread::sleep(Duration::from_millis(10));
                        let mut l2 = m2.lock();
                        assert_eq!(*l2, old + 1);
                        *l2 += 1;
                        println!("[{:?}] end lend", id);
                    });
                    assert_eq!(*l, old + 2);
                    println!("[{:?}] end of thread", id);
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(*m.lock(), 200);
    }

    /// `try_lock` from another thread fails while the lock is held.
    #[test]
    fn try_lock_fails_while_held_by_other_thread() {
        let m = Arc::new(PlLendableMutex::new(0));
        let _held = m.lock();
        let m2 = Arc::clone(&m);
        let locked_out = thread::spawn(move || {
            let attempt = m2.try_lock();
            attempt.is_none()
        })
        .join()
        .unwrap();
        assert!(locked_out);
    }

    /// Lending keeps the raw lock held, so other threads remain locked out.
    #[test]
    fn lend_keeps_other_threads_out() {
        let m = Arc::new(PlLendableMutex::new(0));
        let mut g = m.lock();
        let m2 = Arc::clone(&m);
        g.lend(move || {
            let locked_out = thread::spawn(move || {
                let attempt = m2.try_lock();
                attempt.is_none()
            })
            .join()
            .unwrap();
            assert!(locked_out);
        });
    }
}