use core::ops::{Deref, DerefMut};
use crate::cfg::sync::Arc;
/// A type that can lend out a shared view that dereferences to `Target`.
///
/// Unlike [`Deref`], the view is an associated type (`Deref<'a>`) instead of a
/// plain `&Target`, so an implementor may return any type that itself
/// implements `Deref<Target = Self::Target>`. Types that already implement
/// `Deref` get this trait from the blanket impl in this file, where the view
/// is simply `&Target`.
pub trait AsDeref {
    /// The type being viewed; may be unsized.
    type Target: ?Sized;
    /// The borrowed view returned by [`AsDeref::as_deref`].
    type Deref<'a>: Deref<Target = Self::Target>
    where
        Self: 'a,
        Self::Target: 'a;
    /// Borrows `self` and returns a view that dereferences to `Target`.
    fn as_deref(&self) -> Self::Deref<'_>;
}
/// Mutable counterpart of [`AsDeref`]: lends out an exclusive view that
/// mutably dereferences to `Target`.
pub trait AsDerefMut: AsDeref {
    /// The exclusive view returned by [`AsDerefMut::as_deref_mut`].
    type DerefMut<'a>: DerefMut<Target = Self::Target>
    where
        Self: 'a,
        Self::Target: 'a;
    /// Mutably borrows `self` and returns a view that dereferences to `Target`.
    fn as_deref_mut(&mut self) -> Self::DerefMut<'_>;
}
/// Construction of a lock that owns a `Target` value.
pub trait LockNew {
    /// The protected data type; may be unsized.
    type Target: ?Sized;
    /// Creates a lock wrapping `value`.
    ///
    /// Only callable for sized targets, because `value` is taken by value.
    fn new(value: Self::Target) -> Self
    where
        Self::Target: Sized;
}
/// A lock whose acquisition takes caller-owned `Node` state and whose critical
/// section is expressed as a closure receiving a guard.
///
/// The counting tests in this file (`lots_and_lots_lock`) expect accesses made
/// through the guard to be mutually exclusive across threads.
pub trait LockWithThen: LockNew {
    /// Caller-owned state passed to [`LockWithThen::lock_with_then`].
    /// `Default` lets callers (including [`LockThen`]) create one on demand.
    type Node: Default;
    /// Guard handed to the closure; provides shared and mutable access to
    /// `Target` through [`AsDeref`] / [`AsDerefMut`].
    type Guard<'a>: AsDerefMut<Target = Self::Target>
    where
        Self: 'a,
        Self::Target: 'a;
    /// Runs `f` with a guard for the protected data, using `node` for this
    /// acquisition, and returns whatever `f` returns.
    fn lock_with_then<F, Ret>(&self, node: &mut Self::Node, f: F) -> Ret
    where
        F: FnOnce(Self::Guard<'_>) -> Ret;
}
/// Convenience layer over [`LockWithThen`] for callers that do not want to
/// manage a `Node` themselves.
pub trait LockThen: LockWithThen {
    /// Same as [`LockWithThen::lock_with_then`], but with a freshly defaulted
    /// `Node` that lives only for the duration of this call.
    fn lock_then<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce(Self::Guard<'_>) -> Ret,
    {
        let mut node: Self::Node = Default::default();
        self.lock_with_then(&mut node, f)
    }
}
/// Fallible-acquisition counterpart of [`LockWithThen`].
pub trait TryLockWithThen: LockWithThen {
    /// Attempts to acquire the lock using `node` and calls `f` with the
    /// outcome: `Some(guard)` on success, `None` otherwise. Returns whatever
    /// `f` returns. Callers in this file (`try_inc_inner`) only touch the data
    /// in the `Some` case.
    fn try_lock_with_then<F, Ret>(&self, node: &mut Self::Node, f: F) -> Ret
    where
        F: FnOnce(Option<Self::Guard<'_>>) -> Ret;
    /// Reports whether the lock is currently held. `test_try_lock` expects
    /// `true` inside a critical section and `false` after it ends.
    // Unused by the loom test configuration, hence the conditional allow.
    #[cfg_attr(all(loom, test), allow(dead_code))]
    fn is_locked(&self) -> bool;
}
/// Convenience layer over [`TryLockWithThen`] that supplies its own `Node`.
pub trait TryLockThen: TryLockWithThen + LockThen {
    /// Same as [`TryLockWithThen::try_lock_with_then`], but with a freshly
    /// defaulted `Node` that lives only for the duration of this call.
    // Unused by the loom test configuration, hence the conditional allow.
    #[cfg_attr(all(loom, test), allow(dead_code))]
    fn try_lock_then<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce(Option<Self::Guard<'_>>) -> Ret,
    {
        let mut node: Self::Node = Default::default();
        self.try_lock_with_then(&mut node, f)
    }
}
/// Accessors that reach the protected data through ownership of, or an
/// exclusive borrow of, the lock itself rather than through a guard.
/// Not compiled under `loom`.
#[cfg(not(loom))]
pub trait LockData: LockNew {
    /// Consumes the lock and returns the protected value (sized targets only).
    fn into_inner(self) -> Self::Target
    where
        Self::Target: Sized;
    /// Returns a mutable reference to the data; the `&mut self` receiver
    /// guarantees no other reference to the lock exists.
    fn get_mut(&mut self) -> &mut Self::Target;
}
/// Every `Deref` type is `AsDeref`: its view is a plain shared reference to
/// `<T as Deref>::Target`.
impl<T: Deref> AsDeref for T {
    type Target = <Self as Deref>::Target;

    type Deref<'a>
        = &'a <Self as Deref>::Target
    where
        Self: 'a,
        Self::Target: 'a;

    fn as_deref(&self) -> Self::Deref<'_> {
        // `self: &T`; `**self` goes through `T: Deref` to reach the target.
        &**self
    }
}
/// Every `DerefMut` type is `AsDerefMut`: its view is a plain exclusive
/// reference to `<T as Deref>::Target`.
impl<T: DerefMut> AsDerefMut for T {
    type DerefMut<'a>
        = &'a mut <Self as Deref>::Target
    where
        Self: 'a,
        Self::Target: 'a;

    fn as_deref_mut(&mut self) -> Self::DerefMut<'_> {
        // `self: &mut T`; reborrowing `**self` mutably uses `T: DerefMut`.
        &mut **self
    }
}
/// Counter type stored in the locks used by the increment helpers and the
/// shared tests below.
pub type Int = u32;
/// Locks `mutex` and returns a copy of the protected value.
pub fn get<L>(mutex: &Arc<L>) -> L::Target
where
    L: LockThen<Target: Sized + Copy>,
{
    // `&Arc<L>` deref-coerces to the `&L` receiver expected by `lock_then`.
    L::lock_then(mutex, |guard| *guard.as_deref())
}
/// Acquires `mutex` through [`LockThen::lock_then`] and adds one to the
/// protected counter.
pub fn inc<L>(mutex: &Arc<L>)
where
    L: LockThen<Target = Int>,
{
    L::lock_then(mutex, inc_inner::<L>);
}
/// Acquires `mutex` with the caller-provided `node` and adds one to the
/// protected counter.
#[cfg(not(all(loom, test)))]
pub fn inc_with<L>(mutex: &Arc<L>, node: &mut L::Node)
where
    L: LockWithThen<Target = Int>,
{
    L::lock_with_then(mutex, node, inc_inner::<L>);
}
/// Tries once to acquire `mutex` with the caller-provided `node`; adds one
/// to the counter only if the attempt yielded a guard.
#[cfg(not(all(loom, test)))]
pub fn try_inc_with<L>(mutex: &Arc<L>, node: &mut L::Node)
where
    L: TryLockWithThen<Target = Int>,
{
    L::try_lock_with_then(mutex, node, try_inc_inner::<L>);
}
/// Tries once to acquire `mutex` (with an internally created node); adds one
/// to the counter only if the attempt yielded a guard. Loom tests only.
#[cfg(all(loom, test))]
pub fn try_inc<L>(mutex: &Arc<L>)
where
    L: TryLockThen<Target = Int>,
{
    L::try_lock_then(mutex, try_inc_inner::<L>);
}
/// Adds one to the counter behind `guard`.
fn inc_inner<L>(mut guard: L::Guard<'_>)
where
    L: LockWithThen<Target = Int>,
{
    // The view must be a mutable binding: `*counter` goes through `DerefMut`.
    let mut counter = guard.as_deref_mut();
    *counter += 1;
}
fn try_inc_inner<L>(guard: Option<L::Guard<'_>>)
where
L: TryLockWithThen<Target = Int>,
{
guard.map(inc_inner::<L>);
}
/// Generic test bodies shared across lock implementations.
///
/// Nothing in here is a `#[test]` itself: the `pub fn`s are generic over the
/// lock type and are meant to be instantiated by concrete implementations'
/// own test functions.
#[cfg(all(not(loom), test))]
pub mod tests {
    use core::ops::RangeInclusive;
    use std::fmt::Debug;
    use std::format;
    use std::string::ToString;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    #[cfg(feature = "barging")]
    use std::fmt::Display;

    use super::{get, inc, inc_with, try_inc_with, Int};
    use super::{AsDeref, AsDerefMut};
    use super::{LockData, LockThen, LockWithThen, TryLockThen, TryLockWithThen};

    /// Non-`Copy` payload, used to check APIs that move the data out.
    #[derive(Eq, PartialEq, Debug)]
    pub struct NonCopy(u32);

    /// Increments the shared counter when dropped, so tests can observe when
    /// (and how many times) the value is dropped.
    pub struct Foo(Arc<AtomicUsize>);

    impl Drop for Foo {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Increments attempted by each thread in the `lots_and_lots_*` tests.
    const ITERS: Int = 1000;
    /// Threads spawned by the `lots_and_lots_*` tests.
    const THREADS: Int = 4;
    /// Final counter value when every increment succeeds.
    const EXPECTED_VALUE: Int = ITERS * THREADS;
    /// Accepted final values when some `try_lock` attempts may fail: at least
    /// one attempt must succeed, and never more than all of them.
    const EXPECTED_RANGE: RangeInclusive<Int> = 1..=EXPECTED_VALUE;

    /// Performs `END` increments through `lock_with_then`, reusing one node.
    fn inc_for<L, const END: Int>(mutex: &Arc<L>)
    where
        L: LockWithThen<Target = Int>,
    {
        let mut node = L::Node::default();
        for _ in 0..END {
            inc_with::<L>(mutex, &mut node);
        }
    }

    /// Performs `END` try-lock increment attempts, reusing one node; failed
    /// attempts leave the counter untouched.
    fn try_inc_for<L, const END: Int>(mutex: &Arc<L>)
    where
        L: TryLockWithThen<Target = Int>,
    {
        let mut node = L::Node::default();
        for _ in 0..END {
            try_inc_with::<L>(mutex, &mut node);
        }
    }

    /// Performs `END` increment attempts, alternating between the locking
    /// (even rounds) and try-locking (odd rounds) paths with one shared node.
    fn mixed_inc_for<L, const END: Int>(mutex: &Arc<L>)
    where
        L: TryLockWithThen<Target = Int>,
    {
        let mut node = L::Node::default();
        for round in 0..END {
            if round % 2 == 0 {
                inc_with::<L>(mutex, &mut node);
            } else {
                try_inc_with::<L>(mutex, &mut node);
            }
        }
    }

    /// Spawns `THREADS` threads that each run `f` on one shared lock
    /// initialised to 0. Waits until every thread has signalled completion,
    /// then returns the final value.
    fn lots_and_lots<L, const THREADS: Int>(f: fn(&Arc<L>)) -> Int
    where
        L: LockThen<Target = Int> + Send + Sync + 'static,
    {
        let mutex = Arc::new(L::new(0));
        let (tx, rx) = channel();
        for _ in 0..THREADS {
            let c_mutex = Arc::clone(&mutex);
            let c_tx = tx.clone();
            thread::spawn(move || {
                f(&c_mutex);
                c_tx.send(()).unwrap();
            });
        }
        // Drop the original sender so that a worker which panics (and drops
        // its sender without sending) makes `recv` fail instead of hanging.
        drop(tx);
        for _ in 0..THREADS {
            rx.recv().unwrap();
        }
        get(&mutex)
    }

    /// Asserts that neither `W` nor the raw `MutexNode<W>` / `MutexNodeInit<W>`
    /// types need to be dropped.
    pub fn node_waiter_drop_does_not_matter<W>() {
        use crate::inner::raw::{MutexNode, MutexNodeInit};
        assert!(!core::mem::needs_drop::<W>());
        assert!(!core::mem::needs_drop::<MutexNode<W>>());
        assert!(!core::mem::needs_drop::<MutexNodeInit<W>>());
    }

    /// Concurrent blocking increments must all be observed.
    pub fn lots_and_lots_lock<L>()
    where
        L: LockThen<Target = Int> + Send + Sync + 'static,
    {
        let value = lots_and_lots::<L, THREADS>(inc_for::<L, ITERS>);
        assert_eq!(value, EXPECTED_VALUE);
    }

    /// Concurrent try-lock increments: some may fail, never too many succeed.
    pub fn lots_and_lots_try_lock<L>()
    where
        L: TryLockThen<Target = Int> + Send + Sync + 'static,
    {
        let value = lots_and_lots::<L, THREADS>(try_inc_for::<L, ITERS>);
        assert!(EXPECTED_RANGE.contains(&value));
    }

    /// Concurrent mix of blocking and try-lock increments.
    pub fn lots_and_lots_mixed_lock<L>()
    where
        L: TryLockThen<Target = Int> + Send + Sync + 'static,
    {
        let value = lots_and_lots::<L, THREADS>(mixed_inc_for::<L, ITERS>);
        assert!(EXPECTED_RANGE.contains(&value));
    }

    /// Acquires the same lock twice in sequence with one reused node, dropping
    /// the guard each time.
    pub fn smoke<L>()
    where
        L: LockWithThen<Target = Int>,
    {
        let mutex = L::new(1);
        let mut node = L::Node::default();
        mutex.lock_with_then(&mut node, |data| drop(data));
        mutex.lock_with_then(&mut node, |data| drop(data));
    }

    /// The guard's `Debug` and `Display` output match the protected value's.
    #[cfg(feature = "barging")]
    pub fn test_guard_debug_display<L>()
    where
        L: LockThen<Target = Int>,
        for<'a> <L as LockWithThen>::Guard<'a>: Debug + Display,
    {
        let value = 42;
        let mutex = L::new(value);
        mutex.lock_then(|data| {
            assert_eq!(format!("{value:?}"), format!("{data:?}"));
            assert_eq!(format!("{value}"), format!("{data}"));
        });
    }

    /// The lock's `Debug` output shows the data while unlocked and
    /// `<locked>` while a guard is held (formatted from inside the critical
    /// section through a second `Arc` handle).
    pub fn test_mutex_debug<L>()
    where
        L: LockThen<Target = Int> + Debug + Send + Sync + 'static,
    {
        let value = 42;
        let mutex = Arc::new(L::new(value));
        let msg = format!("Mutex {{ data: {value:?} }}");
        assert_eq!(msg, format!("{mutex:?}"));
        let c_mutex = Arc::clone(&mutex);
        let msg = "Mutex { data: <locked> }".to_string();
        mutex.lock_then(|_data| {
            assert_eq!(msg, format!("{:?}", *c_mutex));
        });
    }

    /// A default-constructed lock holds `Int::default()`.
    pub fn test_mutex_default<L>()
    where
        L: LockData<Target = Int> + Default,
    {
        let mutex: L = Default::default();
        // Spelled via the `Int` alias so the test tracks the alias's type.
        assert_eq!(Int::default(), mutex.into_inner());
    }

    /// `From<Int>` wraps the given value unchanged.
    pub fn test_mutex_from<L>()
    where
        L: LockData<Target = Int> + From<Int>,
    {
        let value = 42;
        let mutex = L::from(value);
        assert_eq!(value, mutex.into_inner());
    }

    /// Single-threaded `try_lock_then`: the uncontended attempt yields a
    /// guard, `is_locked` is `true` inside the critical section and `false`
    /// once it ends.
    pub fn test_try_lock<L>()
    where
        L: TryLockThen<Target = ()>,
    {
        use std::rc::Rc;
        let mutex = Rc::new(L::new(()));
        let c_mutex = Rc::clone(&mutex);
        mutex.try_lock_then(|data| {
            assert!(c_mutex.is_locked());
            *data.unwrap().as_deref_mut() = ();
        });
        assert!(!mutex.is_locked());
    }

    /// `into_inner` returns the wrapped non-`Copy` value.
    pub fn test_into_inner<M>()
    where
        M: LockData<Target = NonCopy>,
    {
        let mutex = M::new(NonCopy(10));
        assert_eq!(mutex.into_inner(), NonCopy(10));
    }

    /// `into_inner` moves the value out without dropping it; the value is
    /// dropped exactly once, when the moved-out binding leaves scope.
    pub fn test_into_inner_drop<M>()
    where
        M: LockData<Target = Foo>,
    {
        let num_drops = Arc::new(AtomicUsize::new(0));
        let mutex = M::new(Foo(num_drops.clone()));
        assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        {
            let _inner = mutex.into_inner();
            assert_eq!(num_drops.load(Ordering::SeqCst), 0);
        }
        assert_eq!(num_drops.load(Ordering::SeqCst), 1);
    }

    /// Writes through `get_mut` are visible to `into_inner`.
    pub fn test_get_mut<M>()
    where
        M: LockData<Target = NonCopy>,
    {
        let mut mutex = M::new(NonCopy(10));
        *mutex.get_mut() = NonCopy(20);
        assert_eq!(mutex.into_inner(), NonCopy(20));
    }

    /// From another thread, locks an outer lock whose data is an `Arc` to an
    /// inner lock, and reads the inner value while the outer guard is held.
    pub fn test_lock_arc_nested<L1, L2>()
    where
        L1: LockThen<Target = Int>,
        L2: LockThen<Target = Arc<L1>> + Send + Sync + 'static,
    {
        let arc = Arc::new(L1::new(1));
        let arc2 = Arc::new(L2::new(arc));
        let (tx, rx) = channel();
        let _t = thread::spawn(move || {
            let val = arc2.lock_then(|arc2| {
                let arc2 = arc2.as_deref();
                get(&arc2)
            });
            assert_eq!(val, 1);
            tx.send(()).unwrap();
        });
        // If the spawned thread panics, `tx` is dropped unsent and this fails.
        rx.recv().unwrap();
    }

    /// Each of several threads holds a shared lock while creating and locking
    /// a second, independent lock.
    pub fn test_acquire_more_than_one_lock<L>()
    where
        L: LockThen<Target = Int> + Send + Sync + 'static,
    {
        let arc = Arc::new(L::new(1));
        let (tx, rx) = channel();
        for _ in 0..4 {
            let tx2 = tx.clone();
            let c_arc = Arc::clone(&arc);
            let _t = thread::spawn(move || {
                c_arc.lock_then(|_d| {
                    let mutex = L::new(1);
                    mutex.lock_then(|_d| ());
                });
                tx2.send(()).unwrap();
            });
        }
        drop(tx);
        for _ in 0..4 {
            rx.recv().unwrap();
        }
    }

    /// A `Drop` impl that runs while a thread unwinds from a panic can still
    /// lock the mutex and increment it; the final value is therefore 2.
    pub fn test_lock_arc_access_in_unwind<L>()
    where
        L: LockThen<Target = Int> + Send + Sync + 'static,
    {
        let arc = Arc::new(L::new(1));
        let arc2 = arc.clone();
        // The panic is intentional, so the `join` result is ignored.
        let _ = thread::spawn(move || {
            struct Unwinder<T: LockThen<Target = Int>> {
                i: Arc<T>,
            }
            impl<T: LockThen<Target = Int>> Drop for Unwinder<T> {
                fn drop(&mut self) {
                    inc(&self.i);
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let value = get(&arc);
        assert_eq!(value, 2);
    }

    /// Mutates individual elements of an array target through the guard.
    ///
    /// NOTE(review): despite the name, `Target` is the sized array
    /// `[Int; 3]` (`get` requires `Copy`); only the final comparison goes
    /// through a slice.
    pub fn test_lock_unsized<L>()
    where
        L: LockThen<Target = [Int; 3]>,
    {
        let mutex = Arc::new(L::new([1, 2, 3]));
        mutex.lock_then(|mut d| {
            d.as_deref_mut()[0] = 4;
            d.as_deref_mut()[2] = 5;
        });
        let comp: &[Int] = &[4, 2, 5];
        let data = get(&mutex);
        assert_eq!(comp, data);
    }
}