#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![warn(rustdoc::missing_crate_level_docs)]
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::ops::Deref;
use std::ops::DerefMut;
#[doc(hidden)]
pub use assoc_static::*;
#[cfg(debug_assertions)]
#[doc(hidden)]
pub use assoc_threadlocal::*;
#[doc(hidden)]
pub use parking_lot;
use parking_lot::lock_api::RawMutex as RawMutexTrait;
use parking_lot::RawMutex;
/// Associates a static [`MutexPool`] (and, in debug builds, a thread-local
/// `LockCount`) with the type `$T`, optionally disambiguated by a `$TAG`
/// type. Must be invoked once for every type that is to be protected by a
/// `ShardedMutex`.
#[macro_export]
#[cfg(debug_assertions)]
macro_rules! sharded_mutex {
    ($TAG:ty : $T:ty) => {
        // One pool of POOL_SIZE mutexes shared by all ShardedMutex<$T, $TAG>.
        $crate::assoc_static!(
            $TAG: $T,
            $crate::MutexPool = $crate::MutexPool([$crate::MUTEXRC_INIT; $crate::POOL_SIZE])
        );
        // Debug builds additionally track how many of these locks the current
        // thread holds, enabling same-thread deadlock detection.
        $crate::assoc_threadlocal!($TAG: $T, LockCount = LockCount(0));
    };
    ($T:ty) => {
        $crate::assoc_static!(
            $T,
            $crate::MutexPool = $crate::MutexPool([$crate::MUTEXRC_INIT; $crate::POOL_SIZE])
        );
        $crate::assoc_threadlocal!($T, LockCount = LockCount(0));
    };
}
// Release-build variant of `sharded_mutex!`: identical to the debug version
// above except that it omits the thread-local `LockCount`, which exists only
// for debug deadlock detection.
#[allow(missing_docs)]
#[macro_export]
#[cfg(not(debug_assertions))]
macro_rules! sharded_mutex {
    ($TAG:ty : $T:ty) => {
        $crate::assoc_static!(
            $TAG: $T,
            $crate::MutexPool = $crate::MutexPool([$crate::MUTEXRC_INIT; $crate::POOL_SIZE])
        );
    };
    ($T:ty) => {
        $crate::assoc_static!(
            $T,
            $crate::MutexPool = $crate::MutexPool([$crate::MUTEXRC_INIT; $crate::POOL_SIZE])
        );
    };
}
/// A mutex-protected value of type `T`. The lock itself is not stored inline:
/// the wrapper is `#[repr(transparent)]` over `UnsafeCell<T>` and borrows one
/// mutex from the static pool associated with `(T, TAG)`, selected by the
/// object's address (see `get_mutex`). This keeps each instance exactly the
/// size of `T`.
#[derive(Debug)]
#[repr(transparent)]
pub struct ShardedMutex<T, TAG = ()>(UnsafeCell<T>, PhantomData<TAG>)
where
    T: AssocStatic<MutexPool, TAG>;

// SAFETY: the inner `UnsafeCell` is only dereferenced through
// `ShardedMutexGuard`, which is created while the associated pool mutex is
// held, so `T: Send` suffices for both `Sync` and `Send`.
unsafe impl<T, TAG> Sync for ShardedMutex<T, TAG> where T: Send + AssocStatic<MutexPool, TAG> {}
unsafe impl<T, TAG> Send for ShardedMutex<T, TAG> where T: Send + AssocStatic<MutexPool, TAG> {}
/// Per-thread count of how many sharded-mutex locks the current thread holds
/// for a given `(T, TAG)`. Exists only in debug builds, where it backs the
/// same-thread deadlock detection.
#[cfg(debug_assertions)]
#[doc(hidden)]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct LockCount(pub usize);

/// Bundles every associated object a `ShardedMutex` needs. In debug builds
/// this is the static mutex pool plus the thread-local lock count.
#[doc(hidden)]
#[cfg(debug_assertions)]
pub trait AssocObjects<TAG>:
    AssocStatic<MutexPool, TAG> + AssocThreadLocal<LockCount, TAG>
{
}
// Blanket impl: any type with both associations qualifies.
#[cfg(debug_assertions)]
impl<T, TAG> AssocObjects<TAG> for T where
    T: AssocStatic<MutexPool, TAG> + AssocThreadLocal<LockCount, TAG>
{
}

/// Release-build variant: only the static mutex pool is required.
#[doc(hidden)]
#[cfg(not(debug_assertions))]
pub trait AssocObjects<TAG>: AssocStatic<MutexPool, TAG> {}
#[cfg(not(debug_assertions))]
impl<T, TAG> AssocObjects<TAG> for T where T: AssocStatic<MutexPool, TAG> {}
/// Number of mutexes in each pool. A prime, so address-based slot selection
/// (`address % POOL_SIZE`) spreads objects across slots reasonably evenly.
#[doc(hidden)]
pub const POOL_SIZE: usize = 127;

/// A raw mutex paired with a recursion counter. The counter lets the
/// multi-lock operations account for several objects that map to the same
/// pool slot: extra acquisitions are counted instead of re-locking (which
/// would self-deadlock on a non-recursive mutex).
#[doc(hidden)]
pub struct RawMutexRc(RawMutex, UnsafeCell<u8>);

/// Pool-entry initializer: unlocked, recursion count zero.
#[doc(hidden)]
#[allow(clippy::declare_interior_mutable_const)] pub const MUTEXRC_INIT: RawMutexRc = RawMutexRc(RawMutex::INIT, UnsafeCell::new(0));

// SAFETY: the `UnsafeCell<u8>` counter is only touched via the unsafe
// `again`/`unlock` methods, whose contract requires the contained mutex to be
// held, serializing all access to the counter.
unsafe impl Sync for RawMutexRc {}
unsafe impl Send for RawMutexRc {}
impl RawMutexRc {
    /// Blocks until the mutex is acquired.
    #[inline]
    fn lock(&self) {
        self.0.lock();
    }

    /// Tries to acquire the mutex without blocking; `true` on success.
    #[inline]
    fn try_lock(&self) -> bool {
        self.0.try_lock()
    }

    /// Records an additional acquisition of an already-held mutex by
    /// incrementing the recursion counter.
    ///
    /// # Safety
    ///
    /// The caller must hold the mutex; otherwise the counter is mutated
    /// without synchronization.
    #[inline]
    unsafe fn again(&self) {
        *self.1.get() += 1;
    }

    /// Releases one acquisition: while `again()` acquisitions are
    /// outstanding only the counter is decremented; the final release
    /// unlocks the underlying mutex.
    ///
    /// # Safety
    ///
    /// The caller must hold the mutex.
    #[inline]
    unsafe fn unlock(&self) {
        if *self.1.get() == 0 {
            self.0.unlock()
        } else {
            *self.1.get() -= 1;
        }
    }
}
/// A pool of `POOL_SIZE` mutexes; one static pool exists per protected
/// `(T, TAG)` pair. Aligned to 128 bytes — presumably to keep the pool on
/// its own cache lines (avoid false sharing); confirm with the author.
#[doc(hidden)]
#[repr(align(128))] pub struct MutexPool(pub [RawMutexRc; POOL_SIZE]);
impl<T, TAG> ShardedMutex<T, TAG>
where
    T: AssocObjects<TAG>,
{
    /// Returns the pool mutex responsible for `self`: the object's address
    /// indexes into the static `MutexPool` associated with `(T, TAG)`.
    fn get_mutex(&self) -> &'static RawMutexRc {
        // SAFETY: `address % POOL_SIZE` is always a valid index into the
        // POOL_SIZE-element array.
        unsafe {
            <T as AssocStatic<MutexPool, TAG>>::get_static()
                .0
                .get_unchecked(self as *const Self as usize % POOL_SIZE)
        }
    }

    /// Wraps `value` in a new `ShardedMutex`.
    pub fn new(value: T) -> Self {
        ShardedMutex(UnsafeCell::new(value), PhantomData)
    }

    /// Debug-only check: panics when the calling thread already holds any
    /// sharded-mutex lock for `(T, TAG)`, since acquiring another could
    /// deadlock on a shared pool slot.
    #[cfg(debug_assertions)]
    fn deadlock_check_before_locking() {
        assert_eq!(
            <T as AssocThreadLocal<LockCount, TAG>>::get_threadlocal(),
            LockCount(0),
            "already locked from the same thread"
        );
    }

    /// Acquires the lock, blocking until it is available, and returns a
    /// guard that releases it on drop.
    ///
    /// # Panics
    ///
    /// In debug builds, panics when the calling thread already holds a lock
    /// for this `(T, TAG)` (potential deadlock).
    pub fn lock(&self) -> ShardedMutexGuard<T, TAG> {
        #[cfg(debug_assertions)]
        Self::deadlock_check_before_locking();
        self.get_mutex().lock();
        ShardedMutexGuard::new(self)
    }

    /// Tries to acquire the lock without blocking. No deadlock check is
    /// required here: a lock already held (even by this thread) simply
    /// makes this return `None`.
    pub fn try_lock(&self) -> Option<ShardedMutexGuard<T, TAG>> {
        self.get_mutex()
            .try_lock()
            .then(|| ShardedMutexGuard::new(self))
    }

    /// Locks all `objects` at once and returns their guards in the same
    /// order. Deadlock-free against concurrent multi-locks: the underlying
    /// mutexes are acquired in a globally consistent order (sorted by
    /// address), and objects sharing a pool slot are counted via
    /// `RawMutexRc::again` instead of being re-locked.
    ///
    /// # Panics
    ///
    /// Panics when `N > 255` (the recursion counter is a `u8`) and, in
    /// debug builds, when the thread already holds a lock for `(T, TAG)`.
    pub fn multi_lock<const N: usize>(objects: [&Self; N]) -> [ShardedMutexGuard<T, TAG>; N] {
        // The u8 recursion counter caps the number of duplicate slots.
        assert!(N <= u8::MAX as usize);
        #[cfg(debug_assertions)]
        Self::deadlock_check_before_locking();
        let mut locks = objects.map(|o| o.get_mutex());
        // Sort by mutex address to establish the global locking order.
        // Unstable sort is fine: equal keys are identical pointers.
        locks.sort_unstable_by_key(|m| *m as *const RawMutexRc as usize);
        for i in 0..locks.len() {
            // SAFETY: `i` is in-bounds and `i - 1` is only read when i > 0;
            // `again` is only called on a mutex this loop already locked.
            unsafe {
                if i == 0
                    || *locks.get_unchecked(i - 1) as *const RawMutexRc
                        != *locks.get_unchecked(i) as *const RawMutexRc
                {
                    locks.get_unchecked(i).lock();
                } else {
                    // Duplicate of the previous (sorted) entry: already
                    // held, just bump its recursion counter.
                    locks.get_unchecked(i).again();
                }
            }
        }
        objects.map(|o| ShardedMutexGuard::new(o))
    }

    /// Non-blocking variant of [`multi_lock`](Self::multi_lock): returns
    /// `None` — after releasing everything acquired so far — when any of
    /// the mutexes cannot be taken immediately.
    ///
    /// # Panics
    ///
    /// Panics when `N > 255`.
    pub fn try_multi_lock<const N: usize>(
        objects: [&Self; N],
    ) -> Option<[ShardedMutexGuard<T, TAG>; N]> {
        assert!(N <= u8::MAX as usize);
        let mut locks = objects.map(|o| o.get_mutex());
        locks.sort_unstable_by_key(|m| *m as *const RawMutexRc as usize);
        for i in 0..locks.len() {
            // SAFETY: all indices are in-bounds; `unlock`/`again` are only
            // called on mutexes this loop has already acquired.
            unsafe {
                if i == 0
                    || *locks.get_unchecked(i - 1) as *const RawMutexRc
                        != *locks.get_unchecked(i) as *const RawMutexRc
                {
                    if !locks.get_unchecked(i).try_lock() {
                        // Roll back every acquisition made so far, then bail.
                        for j in 0..i {
                            locks.get_unchecked(j).unlock();
                        }
                        return None;
                    }
                } else {
                    locks.get_unchecked(i).again();
                }
            }
        }
        Some(objects.map(|o| ShardedMutexGuard::new(o)))
    }

    /// Returns a mutable reference to the contained value. No locking is
    /// needed: the exclusive `&mut self` borrow guarantees sole access.
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut()
    }

    /// Consumes the mutex and returns the contained value.
    pub fn into_inner(self) -> T {
        self.0.into_inner()
    }
}
/// Atomic-style operations for types that have no native atomic support,
/// implemented by briefly taking the sharded mutex around each operation.
pub trait PseudoAtomicOps<T, TAG> {
    /// Returns a copy of the contained value.
    fn load(&self) -> T;
    /// Overwrites the contained value with a copy of `*value`.
    fn store(&self, value: &T);
    /// Exchanges the contained value with `*value`.
    fn swap(&self, value: &mut T);
    /// When the contained value equals `*current`, replaces it with a copy
    /// of `*new` and returns `true`; otherwise leaves it unchanged and
    /// returns `false`.
    fn compare_and_set(&self, current: &T, new: &T) -> bool;
}
impl<T, TAG> PseudoAtomicOps<T, TAG> for ShardedMutex<T, TAG>
where
    T: AssocObjects<TAG> + Copy + std::cmp::PartialEq,
{
    // Copy the value out while holding the lock.
    fn load(&self) -> T {
        let guard = self.lock();
        *guard
    }

    // Overwrite the value while holding the lock.
    fn store(&self, value: &T) {
        let mut guard = self.lock();
        *guard = *value;
    }

    // Exchange the caller's value with the protected one under the lock.
    fn swap(&self, value: &mut T) {
        let mut guard = self.lock();
        std::mem::swap(&mut *guard, value);
    }

    // Classic CAS under the lock: write `*new` only when the current value
    // matches `*current`; report whether the write happened.
    fn compare_and_set(&self, current: &T, new: &T) -> bool {
        let mut guard = self.lock();
        let matched = *guard == *current;
        if matched {
            *guard = *new;
        }
        matched
    }
}
/// RAII guard for a locked [`ShardedMutex`]: grants access to the protected
/// value via `Deref`/`DerefMut`/`AsRef`/`AsMut` and releases the lock when
/// dropped.
pub struct ShardedMutexGuard<'a, T, TAG>(&'a ShardedMutex<T, TAG>)
where
    T: AssocObjects<TAG>;

impl<'a, T, TAG> ShardedMutexGuard<'a, T, TAG>
where
    T: AssocObjects<TAG>,
{
    // Creates a guard for `mutex`. The caller must already hold the
    // associated pool mutex.
    fn new(mutex: &'a ShardedMutex<T, TAG>) -> ShardedMutexGuard<'a, T, TAG> {
        #[cfg(debug_assertions)]
        Self::deadlock_increment_lock_count();
        ShardedMutexGuard(mutex)
    }

    // Debug bookkeeping: bump this thread's lock count for (T, TAG).
    #[cfg(debug_assertions)]
    fn deadlock_increment_lock_count() {
        let LockCount(n) = <T as AssocThreadLocal<LockCount, TAG>>::get_threadlocal();
        <T as AssocThreadLocal<LockCount, TAG>>::set_threadlocal(LockCount(n + 1));
    }

    // Debug bookkeeping: decrement this thread's lock count on release.
    #[cfg(debug_assertions)]
    fn deadlock_decrement_lock_count() {
        let LockCount(n) = <T as AssocThreadLocal<LockCount, TAG>>::get_threadlocal();
        <T as AssocThreadLocal<LockCount, TAG>>::set_threadlocal(LockCount(n - 1));
    }
}
impl<T, TAG> Deref for ShardedMutexGuard<'_, T, TAG>
where
    T: AssocObjects<TAG>,
{
    type Target = T;
    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard exists only while the pool mutex is held, so no
        // other thread can access the cell; `&self` rules prevent aliasing
        // with a mutable borrow.
        unsafe {
            &*self.0.0.get()
        }
    }
}

impl<T, TAG> DerefMut for ShardedMutexGuard<'_, T, TAG>
where
    T: AssocObjects<TAG>,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as above; `&mut self` guarantees this is the only borrow
        // obtained through the guard.
        unsafe {
            &mut *self.0.0.get()
        }
    }
}

impl<T, TAG> AsRef<T> for ShardedMutexGuard<'_, T, TAG>
where
    T: AssocObjects<TAG>,
{
    fn as_ref(&self) -> &T {
        // SAFETY: same reasoning as `Deref::deref`.
        unsafe {
            &*self.0.0.get()
        }
    }
}

impl<T, TAG> AsMut<T> for ShardedMutexGuard<'_, T, TAG>
where
    T: AssocObjects<TAG>,
{
    fn as_mut(&mut self) -> &mut T {
        // SAFETY: same reasoning as `DerefMut::deref_mut`.
        unsafe {
            &mut *self.0.0.get()
        }
    }
}

impl<T, TAG> Drop for ShardedMutexGuard<'_, T, TAG>
where
    T: AssocObjects<TAG>,
{
    // Releases the pool mutex (and, in debug builds, decrements the
    // thread-local lock count) when the guard goes out of scope.
    fn drop(&mut self) {
        #[cfg(debug_assertions)]
        Self::deadlock_decrement_lock_count();
        // SAFETY: the guard was created with the mutex held, so unlocking
        // here matches exactly one earlier lock()/again() acquisition.
        unsafe {
            self.0.get_mutex().unlock();
        }
    }
}
// Register mutex pools for the common std types so that ShardedMutex<bool>,
// ShardedMutex<String>, etc. work out of the box without the user invoking
// `sharded_mutex!` themselves.
sharded_mutex!(bool);
sharded_mutex!(i8);
sharded_mutex!(u8);
sharded_mutex!(i16);
sharded_mutex!(u16);
sharded_mutex!(i32);
sharded_mutex!(u32);
sharded_mutex!(i64);
sharded_mutex!(u64);
sharded_mutex!(i128);
sharded_mutex!(u128);
sharded_mutex!(isize);
sharded_mutex!(usize);
sharded_mutex!(char);
sharded_mutex!(f32);
sharded_mutex!(f64);
sharded_mutex!(String);
#[cfg(test)]
mod tests {
    use crate::ShardedMutex;

    // Construction and read access work.
    #[test]
    fn smoke() {
        let x = ShardedMutex::new(123);
        assert_eq!(*x.lock(), 123);
    }

    // A write through the guard is visible after the guard is dropped.
    #[test]
    fn simple_lock() {
        let x = ShardedMutex::new(123);
        assert_eq!(*x.lock(), 123);
        let mut guard = x.lock();
        *guard = 234;
        drop(guard);
        assert_eq!(*x.lock(), 234);
    }

    // Debug builds panic when the same thread locks twice (deadlock check).
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn simple_deadlock() {
        let x = ShardedMutex::new(123);
        let _guard = x.lock();
        let _guard_deadlock = x.lock();
    }

    // multi_lock returns guards in argument order regardless of the
    // internal address-sorted locking order.
    #[test]
    fn multi_lock() {
        let x = ShardedMutex::new(123);
        let y = ShardedMutex::new(234);
        let z = ShardedMutex::new(345);
        let mut guards = ShardedMutex::multi_lock([&x, &z, &y]);
        assert_eq!(*guards[0], 123);
        assert_eq!(*guards[1], 345);
        assert_eq!(*guards[2], 234);
        *guards[1] = 456;
        drop(guards);
        assert_eq!(*z.lock(), 456);
        let guards = ShardedMutex::multi_lock([&z, &y, &x]);
        assert_eq!(*guards[0], 456);
        assert_eq!(*guards[1], 234);
        assert_eq!(*guards[2], 123);
    }

    // The debug deadlock check also fires for nested multi_lock calls.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn multi_deadlock() {
        let x = ShardedMutex::new(123);
        let y = ShardedMutex::new(234);
        let z = ShardedMutex::new(345);
        let _guards = ShardedMutex::multi_lock([&x, &z, &y]);
        let _guards_deadlock = ShardedMutex::multi_lock([&x, &z, &y]);
    }

    // try_multi_lock fails (returning None) while locks are held and
    // succeeds once they are released.
    #[test]
    fn try_multi_lock() {
        let x = ShardedMutex::new(123);
        let y = ShardedMutex::new(234);
        let z = ShardedMutex::new(345);
        let guards = ShardedMutex::multi_lock([&x, &z, &y]);
        assert!(ShardedMutex::try_multi_lock([&x, &z, &y]).is_none());
        drop(guards);
        let guards = ShardedMutex::try_multi_lock([&z, &y, &x]);
        assert!(guards.is_some());
        assert_eq!(*guards.as_ref().unwrap()[0], 345);
        assert_eq!(*guards.as_ref().unwrap()[1], 234);
        assert_eq!(*guards.as_ref().unwrap()[2], 123);
    }

    // load/store/swap/compare_and_set behave like their atomic namesakes.
    #[test]
    fn pseudo_atomic_ops() {
        use crate::PseudoAtomicOps;
        let x = ShardedMutex::new(123);
        let loaded = x.load();
        assert_eq!(loaded, 123);
        x.store(&234);
        assert_eq!(x.load(), 234);
        let mut swapping = 345;
        x.swap(&mut swapping);
        assert_eq!(swapping, 234);
        assert_eq!(x.load(), 345);
        assert!(!x.compare_and_set(&123, &456));
        assert!(x.compare_and_set(&345, &456));
        assert_eq!(x.load(), 456);
    }

    // get_mut gives lock-free access through an exclusive borrow.
    #[test]
    fn get_mut() {
        let mut x = ShardedMutex::new(123);
        *x.get_mut() = 234;
        assert_eq!(*x.get_mut(), 234);
    }

    // into_inner recovers the wrapped value.
    #[test]
    fn into_inner() {
        let x = ShardedMutex::new(123);
        let v = x.into_inner();
        assert_eq!(v, 123);
    }
}