#[cfg(loom)]
use loom::sync::atomic::AtomicU32;
#[cfg(not(loom))]
use std::sync::atomic::AtomicU32;
use std::{
    cell::UnsafeCell,
    marker::PhantomData,
    ops::{Deref, DerefMut},
    sync::atomic::Ordering,
};
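/// A try-only reader-writer lock that remembers *which* worker holds the
/// write lock, so a failed acquisition can report who is in the way.
///
/// Every acquisition method is non-blocking: it either returns a guard or an
/// error describing the current holder(s).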
#[derive(Debug, Default)]
pub struct IdLock<T> {
inner: RawIdLock,
data: UnsafeCell<T>,
}
unsafe impl<T: Send> Send for IdLock<T> {}
unsafe impl<T: Send + Sync> Sync for IdLock<T> {}
impl<T> IdLock<T> {
pub fn new(data: T) -> Self {
Self {
inner: RawIdLock::new(),
data: UnsafeCell::new(data),
}
}
pub fn try_read(&self) -> Result<IdLockReadGuard<'_, T>, IdLockReadErr> {
self.inner.try_read().map(|_| IdLockReadGuard::new(self))
}
pub fn try_read_upgradable(&self) -> Result<IdLockReadGuard<'_, T, Upgradable>, IdLockReadErr> {
self.inner
.try_read()
.map(|_| IdLockReadGuard::new_upgradable(self))
}
pub fn try_write(&self, worker_id: u32) -> Result<IdLockWriteGuard<'_, T>, IdLockWriteErr> {
self.inner
.try_write(worker_id)
.map(|_| IdLockWriteGuard::new(self))
}
}
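/// The whole lock state packed into one `AtomicU32`:
/// - bit 31 (`WRITE_LOCK_MASK`): set while write-locked
/// - bits 0..31 (`ID_COUNT_MASK`): the writer's worker id while write-locked,
///   otherwise the number of outstanding read locks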
#[derive(Debug, Default)]
struct RawIdLock {
state: AtomicU32,
}
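/// Raw pointers are `!Send`; embedding this marker makes the guards `!Send`
/// as well, so a guard cannot be moved to (and dropped on) another thread.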
type PhantomUnsend = PhantomData<*const ()>;
#[derive(Debug)]
pub struct Upgradable;
#[derive(Debug)]
pub struct NotUpgradable;
#[must_use = "if unused the IdLock will immediately unlock"]
#[clippy::has_significant_drop]
#[derive(Debug)]
pub struct IdLockReadGuard<'a, T: 'a, U = NotUpgradable> {
lock: &'a IdLock<T>,
_no_send: PhantomUnsend,
_upgradable: PhantomData<U>,
}
impl<'a, T> IdLockReadGuard<'a, T> {
fn new(lock: &'a IdLock<T>) -> Self {
Self {
lock,
_no_send: Default::default(),
_upgradable: Default::default(),
}
}
}
impl<'a, T> IdLockReadGuard<'a, T, Upgradable> {
fn new_upgradable(lock: &'a IdLock<T>) -> Self {
Self {
lock,
_no_send: Default::default(),
_upgradable: Default::default(),
}
}
pub fn deny_upgrade(self) -> IdLockReadGuard<'a, T, NotUpgradable> {
let lock = self.lock;
std::mem::forget(self);
IdLockReadGuard {
lock,
_no_send: Default::default(),
_upgradable: Default::default(),
}
}
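    /// Attempt to trade the read lock for a write lock tagged with
    /// `worker_id`. On failure the guard is handed back unchanged, so the
    /// caller keeps the read lock.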
pub fn try_upgrade(self, worker_id: u32) -> Result<IdLockWriteGuard<'a, T>, Self> {
let r = self.lock.inner.try_upgrade(worker_id);
match r {
Ok(_) => {
let lock = self.lock;
                std::mem::forget(self);
                Ok(IdLockWriteGuard::new(lock))
}
Err(_) => Err(self),
}
}
}
impl<'a, T> IdLockWriteGuard<'a, T> {
fn new(lock: &'a IdLock<T>) -> Self {
Self {
lock,
_no_send: Default::default(),
}
}
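    /// Atomically trade the write lock for a single upgradable read lock,
    /// without ever releasing the lock in between.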
pub fn downgrade(self) -> IdLockReadGuard<'a, T, Upgradable> {
unsafe { self.lock.inner.downgrade() };
let lock = self.lock;
std::mem::forget(self);
IdLockReadGuard::new_upgradable(lock)
}
}
unsafe impl<T: Sync, U> Sync for IdLockReadGuard<'_, T, U> {}
#[must_use = "if unused the IdLock will immediately unlock"]
#[clippy::has_significant_drop]
#[derive(Debug)]
pub struct IdLockWriteGuard<'a, T: 'a> {
lock: &'a IdLock<T>,
_no_send: PhantomUnsend,
}
impl<T, U> Drop for IdLockReadGuard<'_, T, U> {
fn drop(&mut self) {
unsafe {
self.lock.inner.unlock_read();
}
}
}
impl<T> Drop for IdLockWriteGuard<'_, T> {
fn drop(&mut self) {
unsafe { self.lock.inner.unlock_write() }
}
}
unsafe impl<T: Sync> Sync for IdLockWriteGuard<'_, T> {}
impl<T, U> Deref for IdLockReadGuard<'_, T, U> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.lock.data.get() }
    }
}
impl<T> Deref for IdLockWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.data.get() }
}
}
impl<T> DerefMut for IdLockWriteGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.data.get() }
}
}
#[test]
#[cfg(not(loom))]
fn test_idlock_single_thread() {
let counter = IdLock::new(7);
{
let guard1 = counter.try_read().unwrap();
let guard2 = counter.try_read().unwrap();
assert_eq!(*guard1.deref(), 7);
assert_eq!(*guard2.deref(), 7);
assert_eq!(
counter.try_write(42).err(),
Some(IdLockWriteErr::NumberOfReaders(2))
);
}
{
let mut guard = counter.try_write(42).unwrap();
assert_eq!(
counter.try_read().err(),
Some(IdLockReadErr::CurrentWriter(42))
);
assert_eq!(
counter.try_write(77).err(),
Some(IdLockWriteErr::CurrentWriter(42))
);
*guard += 1;
}
assert_eq!(*counter.try_read().unwrap().deref(), 8);
}
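// A minimal blocking-acquire sketch: `acquire_write` is a hypothetical helper
// (not part of the lock's API) that spins on `try_write` until it succeeds,
// shown only to illustrate how the try-only methods are meant to be used.
#[cfg(not(loom))]
#[allow(dead_code)]
fn acquire_write<T>(lock: &IdLock<T>, worker_id: u32) -> IdLockWriteGuard<'_, T> {
    loop {
        if let Ok(guard) = lock.try_write(worker_id) {
            return guard;
        }
        std::hint::spin_loop();
    }
}
#[test]
#[cfg(not(loom))]
fn test_idlock_spin_write_sketch() {
    let lock = IdLock::new(0u32);
    let mut guard = acquire_write(&lock, 3);
    *guard += 1;
    drop(guard);
    assert_eq!(*lock.try_read().unwrap(), 1);
}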
#[test]
#[cfg(not(loom))]
fn test_idlock_upgrade_downgrade() {
    let counter = IdLock::new(7);
    let read_guard = counter.try_read_upgradable().unwrap();
    let write_guard = read_guard.try_upgrade(42).unwrap();
    assert_eq!(*write_guard, 7);
    // exercise the downgrade path the test name promises
    let read_guard = write_guard.downgrade();
    assert_eq!(*read_guard, 7);
}
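// Added sanity check: `deny_upgrade` must keep the read lock held while
// returning a guard that can no longer be upgraded (enforced by its type).
#[test]
#[cfg(not(loom))]
fn test_idlock_deny_upgrade() {
    let counter = IdLock::new(7);
    let read_guard = counter.try_read_upgradable().unwrap().deny_upgrade();
    assert_eq!(*read_guard, 7);
    // the read lock is still held, so a writer must be turned away
    assert_eq!(
        counter.try_write(42).err(),
        Some(IdLockWriteErr::NumberOfReaders(1))
    );
}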
/// Error returned by `try_write` when the lock cannot be write-locked.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum IdLockWriteErr {
    /// Another worker holds the write lock; the payload is its id.
    CurrentWriter(u32),
    /// Outstanding read locks block the writer; the payload is their count.
    NumberOfReaders(u32),
}
/// Error returned by `try_read` when the lock cannot be read-locked.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum IdLockReadErr {
    /// A worker holds the write lock; the payload is its id.
    CurrentWriter(u32),
    /// The 31-bit reader count is saturated; the payload is the current count.
    LocksExhausted(u32),
}
impl RawIdLock {
    /// Increment applied to the state for each read lock taken.
    const READ_LOCKED: u32 = 1;
    /// Bit position of the write-lock flag.
    const WRITE_LOCK_BIT: u32 = 31;
    const WRITE_LOCK_MASK: u32 = 1 << Self::WRITE_LOCK_BIT;
    /// Low 31 bits: the writer's worker id while write-locked,
    /// the reader count otherwise.
    const ID_COUNT_MASK: u32 = !Self::WRITE_LOCK_MASK;
    const MAX_WORKER_ID: u32 = (1 << Self::WRITE_LOCK_BIT) - 1;
    const MAX_READ_LOCKS: u32 = (1 << Self::WRITE_LOCK_BIT) - 1;
fn new() -> Self {
Self {
state: Default::default(),
}
}
    #[inline]
    fn is_read_lockable(s: u32) -> bool {
        s & Self::WRITE_LOCK_MASK == 0 && s < Self::MAX_READ_LOCKS
    }
#[inline]
fn is_upgradable(s: u32) -> bool {
Self::num_readers(s) == 1 && Self::worker_id(s).is_none()
}
#[inline]
fn worker_id(s: u32) -> Option<u32> {
let id = s & Self::ID_COUNT_MASK;
let is_write_locked = s & Self::WRITE_LOCK_MASK != 0;
is_write_locked.then_some(id)
}
#[inline]
fn num_readers(s: u32) -> u32 {
let num_readers = s & Self::ID_COUNT_MASK;
let is_write_locked = s & Self::WRITE_LOCK_MASK != 0;
if is_write_locked { 0 } else { num_readers }
}
fn try_read(&self) -> Result<u32, IdLockReadErr> {
self.state
.fetch_update(Ordering::Acquire, Ordering::Relaxed, |s| {
Self::is_read_lockable(s).then_some(s + Self::READ_LOCKED)
})
            .map(|prev_value| Self::num_readers(prev_value) + 1)
            .map_err(|prev_value| match Self::worker_id(prev_value) {
                Some(worker_id) => IdLockReadErr::CurrentWriter(worker_id),
                None => IdLockReadErr::LocksExhausted(Self::num_readers(prev_value)),
            })
    }
unsafe fn unlock_read(&self) -> u32 {
let previous = self.state.fetch_sub(Self::READ_LOCKED, Ordering::Release);
debug_assert!(
Self::num_readers(previous) > 0,
"there must have been a pending read lock"
);
debug_assert_eq!(
Self::worker_id(previous),
None,
"cannot have a pending write lock when unlocking a read lock"
);
Self::num_readers(previous) - 1
}
fn try_write(&self, worker_id: u32) -> Result<(), IdLockWriteErr> {
assert!(
worker_id <= Self::MAX_WORKER_ID,
"worker_id out of allowed range"
);
self.state
.compare_exchange(
                0,
                worker_id | Self::WRITE_LOCK_MASK,
                Ordering::Acquire,
                Ordering::Relaxed,
            )
            .map(|_| ())
            .map_err(|prev_value| {
let num_readers = Self::num_readers(prev_value);
let worker_id = Self::worker_id(prev_value);
match worker_id {
Some(id) => {
debug_assert_eq!(num_readers, 0);
IdLockWriteErr::CurrentWriter(id)
}
None => {
debug_assert_ne!(num_readers, 0);
IdLockWriteErr::NumberOfReaders(num_readers)
}
}
})
}
fn try_upgrade(&self, worker_id: u32) -> Result<(), u32> {
assert!(
worker_id <= Self::MAX_WORKER_ID,
"worker_id out of allowed range"
);
self.state
.fetch_update(Ordering::Acquire, Ordering::Relaxed, |s| {
Self::is_upgradable(s).then_some(worker_id | Self::WRITE_LOCK_MASK)
})
.map(|prev_value| {
debug_assert_eq!(
Self::num_readers(prev_value),
1,
"there must be exactly one reader, otherwise an upgrade cannot make sense"
);
})
.map_err(Self::num_readers)
}
unsafe fn downgrade(&self) {
let previous = self.state.swap(Self::READ_LOCKED, Ordering::Release);
debug_assert_ne!(
Self::worker_id(previous),
None,
"there must have been a pending write lock"
);
debug_assert_eq!(
Self::num_readers(previous),
0,
"it's impossible to have pending read locks here"
);
}
unsafe fn unlock_write(&self) {
let previous = self.state.swap(0, Ordering::Release);
debug_assert_ne!(
Self::worker_id(previous),
None,
"there must have been a pending write lock"
);
debug_assert_eq!(
Self::num_readers(previous),
0,
"it's impossible to have pending read locks here"
);
}
}
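// Added sanity check pinning down the state encoding described on `RawIdLock`:
// bit 31 selects between "worker id" and "reader count" readings of the low
// 31 bits.
#[test]
#[cfg(not(loom))]
fn test_rawidlock_state_encoding() {
    // write bit set: the low bits are the worker id, the reader count is zero
    let write_locked = 5 | RawIdLock::WRITE_LOCK_MASK;
    assert_eq!(RawIdLock::worker_id(write_locked), Some(5));
    assert_eq!(RawIdLock::num_readers(write_locked), 0);
    assert!(!RawIdLock::is_read_lockable(write_locked));
    // write bit clear: the low bits count readers
    assert_eq!(RawIdLock::worker_id(3), None);
    assert_eq!(RawIdLock::num_readers(3), 3);
    assert!(RawIdLock::is_read_lockable(3));
    // only a single reader may upgrade
    assert!(!RawIdLock::is_upgradable(3));
    assert!(RawIdLock::is_upgradable(RawIdLock::READ_LOCKED));
}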
#[test]
#[cfg(not(loom))]
fn test_rawidlock_single_thread() {
let lock = RawIdLock::new();
assert_eq!(lock.try_read(), Ok(1));
assert_eq!(lock.try_read(), Ok(2));
assert_eq!(lock.try_read(), Ok(3));
assert_eq!(lock.try_write(0), Err(IdLockWriteErr::NumberOfReaders(3)));
assert_eq!(lock.try_write(0), Err(IdLockWriteErr::NumberOfReaders(3)));
assert_eq!(unsafe { lock.unlock_read() }, 2);
assert_eq!(unsafe { lock.unlock_read() }, 1);
assert_eq!(lock.try_write(0), Err(IdLockWriteErr::NumberOfReaders(1)));
assert_eq!(unsafe { lock.unlock_read() }, 0);
assert_eq!(
lock.try_write(7),
Ok(()),
"should be able to acquire one write lock"
);
assert_eq!(lock.try_write(0), Err(IdLockWriteErr::CurrentWriter(7)));
assert_eq!(
lock.try_read(),
Err(IdLockReadErr::CurrentWriter(7)),
"should fail to acquire a read lock while it is locked for writing"
);
unsafe { lock.unlock_write() };
}
#[test]
#[cfg(not(loom))]
fn test_rawidlock_upgrade() {
let lock = RawIdLock::new();
assert_eq!(lock.try_read(), Ok(1));
assert_eq!(lock.try_upgrade(7), Ok(()));
assert_eq!(lock.try_read(), Err(IdLockReadErr::CurrentWriter(7)));
unsafe {
lock.unlock_write();
}
assert_eq!(
lock.try_read(),
Ok(1),
"should be able to acquire read locks again"
);
}
#[test]
#[cfg(not(loom))]
fn test_rawidlock_upgrade_fail() {
let lock = RawIdLock::new();
assert_eq!(lock.try_read(), Ok(1));
assert_eq!(lock.try_read(), Ok(2));
assert_eq!(lock.try_upgrade(7), Err(2));
unsafe {
lock.unlock_read();
lock.unlock_read();
}
}
#[test]
#[cfg(not(loom))]
fn test_rawidlock_downgrade() {
let lock = RawIdLock::new();
assert_eq!(lock.try_write(7), Ok(()));
unsafe {
lock.downgrade();
}
assert_eq!(lock.try_read(), Ok(2));
unsafe {
lock.unlock_read();
lock.unlock_read();
}
}
#[test]
#[cfg(not(loom))]
fn test_rawidlock_multi_thread() {
use std::thread;
let lock = RawIdLock::new();
let counter = 0usize;
let num_threads = 8;
let num_loops = 100000;
thread::scope(|s| {
let lock = &lock;
let counter = &counter;
for worker_id in 0..num_threads {
s.spawn(move || {
let counter_ptr: *const usize = counter;
let ptr_mut: *mut usize = counter_ptr.cast_mut();
for _ in 0..num_loops {
loop {
if lock.try_write(worker_id as u32).is_ok() {
break;
}
}
{
*unsafe { ptr_mut.as_mut() }.unwrap() += 1;
}
unsafe {
lock.unlock_write();
}
}
});
}
for _ in 0..num_threads {
s.spawn(move || {
let mut last_counter = 0;
for _ in 0..num_loops {
loop {
if lock.try_read().is_ok() {
break;
}
}
{
assert!(*counter >= last_counter);
last_counter = *counter;
}
unsafe {
lock.unlock_read();
}
}
});
}
});
assert_eq!(counter, num_threads * num_loops);
}
#[test]
#[cfg(loom)]
fn test_loom_raw_idlock_dual_write() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(RawIdLock::new());
let lock1 = lock.clone();
let data = Arc::new(AtomicU32::new(0));
let data1 = data.clone();
let data_atomic_reference = Arc::new(AtomicU32::new(0));
let data_atomic_reference1 = data_atomic_reference.clone();
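        // These load-then-store increments are deliberately not atomic
        // read-modify-writes: if the lock ever admitted two writers at once,
        // an update could be lost and `data` would drift away from the
        // atomically maintained `data_atomic_reference`, failing the final
        // assertion.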
let t1 = thread::spawn(move || {
match lock1.try_write(1) {
Ok(()) => {
data1.store(data1.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
data_atomic_reference1.fetch_add(1, Ordering::Relaxed);
unsafe {
lock1.unlock_write();
}
}
Err(id) => assert_eq!(id, IdLockWriteErr::CurrentWriter(2)),
};
});
match lock.try_write(2) {
Ok(()) => {
data.store(data.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
data_atomic_reference.fetch_add(1, Ordering::Relaxed);
unsafe {
lock.unlock_write();
}
}
Err(id) => assert_eq!(id, IdLockWriteErr::CurrentWriter(1)),
};
t1.join().unwrap();
assert_eq!(
lock.state.load(Ordering::Relaxed),
0,
"lock must be in the base state"
);
assert_eq!(
data.load(Ordering::Relaxed),
data_atomic_reference.load(Ordering::Relaxed)
);
});
}
#[test]
#[cfg(loom)]
fn test_loom_raw_idlock_read_write() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(RawIdLock::new());
let lock1 = lock.clone();
let t1 = thread::spawn(move || {
match lock1.try_write(42) {
Ok(()) => unsafe {
lock1.unlock_write();
},
Err(id) => assert_eq!(id, IdLockWriteErr::NumberOfReaders(1)),
};
});
match lock.try_read() {
Ok(num_readers) => {
assert_eq!(num_readers, 1);
unsafe {
lock.unlock_read();
}
}
Err(id) => assert_eq!(id, IdLockReadErr::CurrentWriter(42)),
};
t1.join().unwrap();
assert_eq!(
lock.state.load(Ordering::Relaxed),
0,
"lock must be in the base state"
);
});
}
#[test]
#[cfg(loom)]
fn test_loom_raw_idlock_read_upgrade_write() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(RawIdLock::new());
let lock1 = lock.clone();
let t1 = thread::spawn(move || {
match lock1.try_write(42) {
Ok(()) => unsafe {
lock1.unlock_write();
},
Err(id) => assert!(
id == IdLockWriteErr::NumberOfReaders(1)
|| id == IdLockWriteErr::CurrentWriter(77)
),
};
});
match lock.try_read() {
Ok(num_readers) => {
assert_eq!(num_readers, 1);
match lock.try_upgrade(77) {
Ok(()) => unsafe {
lock.unlock_write();
},
Err(num_readers) => {
unsafe {
lock.unlock_read();
}
assert_eq!(num_readers, 1);
}
}
}
Err(id) => assert_eq!(id, IdLockReadErr::CurrentWriter(42)),
};
t1.join().unwrap();
assert_eq!(
lock.state.load(Ordering::Relaxed),
0,
"lock must be in the base state"
);
});
}
#[test]
#[cfg(loom)]
fn test_loom_raw_idlock_read_upgrade_write_downgrade() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(RawIdLock::new());
let lock1 = lock.clone();
let nonatomic_counter = Arc::new(AtomicU32::new(0));
let nonatomic_counter1 = nonatomic_counter.clone();
let atomic_counter = Arc::new(AtomicU32::new(0));
let atomic_counter1 = atomic_counter.clone();
let t1 = thread::spawn(move || {
match lock1.try_write(42) {
Ok(()) => {
atomic_counter1.fetch_add(1, Ordering::Relaxed);
nonatomic_counter1.store(
nonatomic_counter1.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
unsafe {
lock1.unlock_write();
}
}
Err(id) => assert!(
id == IdLockWriteErr::NumberOfReaders(1)
|| id == IdLockWriteErr::CurrentWriter(77)
),
};
});
match lock.try_read() {
Ok(num_readers) => {
assert_eq!(num_readers, 1);
match lock.try_upgrade(77) {
Ok(()) => {
atomic_counter.fetch_add(1, Ordering::Relaxed);
nonatomic_counter.store(
nonatomic_counter.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
unsafe {
lock.downgrade();
lock.unlock_read();
}
}
Err(num_readers) => {
unsafe {
lock.unlock_read();
}
assert_eq!(num_readers, 1);
}
}
}
Err(id) => assert_eq!(id, IdLockReadErr::CurrentWriter(42)),
};
t1.join().unwrap();
assert_eq!(
lock.state.load(Ordering::Relaxed),
0,
"lock must be in the base state"
);
assert_eq!(
atomic_counter.load(Ordering::Relaxed),
nonatomic_counter.load(Ordering::Relaxed)
);
});
}
#[test]
#[cfg(loom)]
fn test_loom_idlock_no_write_races() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(IdLock::new(0));
let lock1 = lock.clone();
        let counter = Arc::new(AtomicU32::new(0));
        let counter1 = counter.clone();
let t1 = thread::spawn(move || {
match lock1.try_write(1) {
Ok(mut data) => {
*data += 1;
counter1.fetch_add(1, Ordering::Relaxed);
}
Err(id) => assert_eq!(id, IdLockWriteErr::CurrentWriter(0)),
};
});
match lock.try_write(0) {
Ok(mut data) => {
*data += 1;
counter.fetch_add(1, Ordering::Relaxed);
}
Err(id) => assert_eq!(id, IdLockWriteErr::CurrentWriter(1)),
};
t1.join().unwrap();
assert_eq!(*lock.try_read().unwrap(), counter.load(Ordering::Relaxed));
})
}
#[test]
#[cfg(loom)]
fn test_loom_idlock_read_and_write() {
loom::model(|| {
use loom::thread;
use std::sync::Arc;
let lock = Arc::new(IdLock::new(0));
let lock1 = lock.clone();
        let counter = Arc::new(AtomicU32::new(0));
        let counter1 = counter.clone();
let t1 = thread::spawn(move || {
match lock1.try_read() {
Ok(data) => {
assert_eq!(*data, counter1.load(Ordering::Relaxed));
}
Err(id) => assert_eq!(id, IdLockReadErr::CurrentWriter(0)),
};
});
match lock.try_write(0) {
Ok(mut data) => {
*data += 1;
counter.fetch_add(1, Ordering::Relaxed);
}
Err(id) => assert_eq!(id, IdLockWriteErr::NumberOfReaders(1)),
};
t1.join().unwrap();
assert_eq!(*lock.try_read().unwrap(), counter.load(Ordering::Relaxed));
})
}