use core::cell::UnsafeCell;
use core::default::Default;
use core::hint::spin_loop;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crossbeam_utils::CachePadded;
/// Upper bound on the number of concurrent reader threads; each reader
/// thread id (`tid`) indexes its own dedicated counter slot.
const MAX_READER_THREADS: usize = 192;
const_assert!(MAX_READER_THREADS > 0);
// Interior-mutable const used solely as an array repeat-initializer below:
// each array element receives its own fresh copy of the padded counter,
// which is exactly the intent here — hence the clippy allow.
#[allow(clippy::declare_interior_mutable_const)]
const RLOCK_DEFAULT: CachePadded<AtomicUsize> = CachePadded::new(AtomicUsize::new(0));
/// A spinning reader-writer lock that gives every reader thread its own
/// cache-padded counter slot, so concurrent readers do not contend on a
/// shared cache line.
pub struct RwLock<T>
where
    T: Sized + Sync,
{
    // Writer flag: `true` while a writer holds (or is acquiring) the lock.
    wlock: CachePadded<AtomicBool>,
    // Per-reader-thread hold counters, indexed by `tid`; padded to avoid
    // false sharing between reader threads.
    rlock: [CachePadded<AtomicUsize>; MAX_READER_THREADS],
    // The protected value; access is mediated by the guards returned from
    // `read()` / `write()`.
    data: UnsafeCell<T>,
}
/// RAII guard for a read lock; dropping it releases the reader slot `tid`.
pub struct ReadGuard<'a, T: Sized + Sync + 'a> {
    // Reader thread id: index of the `rlock` slot this guard decrements on drop.
    tid: usize,
    lock: &'a RwLock<T>,
}
/// RAII guard for the write lock; dropping it clears the writer flag.
pub struct WriteGuard<'a, T: Sized + Sync + 'a> {
    lock: &'a RwLock<T>,
}
impl<T> Default for RwLock<T>
where
    T: Sized + Default + Sync,
{
    /// Builds an unlocked `RwLock` around `T::default()`: writer flag clear
    /// and every per-thread reader counter at zero.
    fn default() -> Self {
        Self {
            data: UnsafeCell::new(T::default()),
            rlock: [RLOCK_DEFAULT; MAX_READER_THREADS],
            wlock: CachePadded::new(AtomicBool::new(false)),
        }
    }
}
impl<T> RwLock<T>
where
    T: Sized + Sync,
{
    /// Creates a new unlocked `RwLock` wrapping `t`.
    pub fn new(t: T) -> Self {
        Self {
            wlock: CachePadded::new(AtomicBool::new(false)),
            rlock: [RLOCK_DEFAULT; MAX_READER_THREADS],
            data: UnsafeCell::new(t),
        }
    }

    /// Acquires the write lock, spinning until the writer flag is free and
    /// the reader slots `0..n` have all drained to zero.
    ///
    /// `n` is the number of reader slots in use; slots `>= n` are not
    /// checked, so callers must never `read()` with a `tid >= n` while a
    /// writer relies on this bound.
    pub fn write(&self, n: usize) -> WriteGuard<T> {
        // Claim the writer flag first so no new readers can enter (readers
        // re-check `wlock` after registering).
        //
        // SeqCst is required here: the writer stores `wlock` then loads
        // `rlock`, while a reader stores `rlock` then loads `wlock`. With
        // weaker orderings both sides can miss the other's store
        // (store-buffering) and a reader and writer could both enter.
        while self
            .wlock
            .compare_exchange_weak(false, true, Ordering::SeqCst, Ordering::Relaxed)
            .is_err()
        {
            spin_loop();
        }
        // Wait for in-flight readers in slots 0..n to drain. SeqCst pairs
        // with the readers' SeqCst increments (see `read`) and synchronizes
        // with their Release decrements, so the readers' accesses happen
        // before the writer touches `data`.
        while !self
            .rlock
            .iter()
            .take(n)
            .all(|item| item.load(Ordering::SeqCst) == 0)
        {
            spin_loop();
        }
        unsafe { WriteGuard::new(self) }
    }

    /// Acquires a read lock on behalf of reader thread `tid`.
    ///
    /// Each reader thread must use a distinct `tid < MAX_READER_THREADS`
    /// (and `< n` of any concurrent `write(n)` caller).
    pub fn read(&self, tid: usize) -> ReadGuard<T> {
        loop {
            // Cheap wait until the writer flag looks clear; the
            // authoritative check happens after registering below.
            // (The original read this flag through a `*const bool` +
            // `read_volatile`, which is a non-atomic access racing with
            // atomic stores — UB. A relaxed atomic load is the sound form.)
            while self.wlock.load(Ordering::Relaxed) {
                spin_loop();
            }
            // Register as a reader, then re-check the writer flag. SeqCst on
            // both operations pairs with the writer's SeqCst store/loads so
            // at least one of the two parties observes the other.
            self.rlock[tid].fetch_add(1, Ordering::SeqCst);
            if !self.wlock.load(Ordering::SeqCst) {
                break;
            }
            // A writer slipped in between the two checks: back out, retry.
            self.rlock[tid].fetch_sub(1, Ordering::Release);
        }
        unsafe { ReadGuard::new(self, tid) }
    }

    /// Releases the write lock.
    ///
    /// # Safety
    /// Must only be called by the current holder of the write lock (the
    /// `WriteGuard` destructor). Panics if the lock is not held.
    pub(in crate::rwlock) unsafe fn write_unlock(&self) {
        // Strong compare_exchange: the weak variant may fail spuriously even
        // when the value matches, which would panic on a legitimate unlock.
        // Release publishes the writer's modifications to the next acquirer.
        if self
            .wlock
            .compare_exchange(true, false, Ordering::Release, Ordering::Relaxed)
            .is_err()
        {
            panic!("write_unlock() called without acquiring the write lock");
        }
    }

    /// Releases the read lock held by reader thread `tid`.
    ///
    /// # Safety
    /// Must only be called by the holder of a read lock for `tid` (the
    /// `ReadGuard` destructor). Panics if no read lock is held.
    pub(in crate::rwlock) unsafe fn read_unlock(&self, tid: usize) {
        // fetch_sub returns the previous value; 0 means nothing was held.
        // (The counter has wrapped by the time we panic, but the panic
        // prevents the corrupt value from being used.)
        if self.rlock[tid].fetch_sub(1, Ordering::Release) == 0 {
            panic!("read_unlock() called without acquiring the read lock");
        }
    }
}
impl<'rwlock, T: Sized + Sync> ReadGuard<'rwlock, T> {
unsafe fn new(lock: &'rwlock RwLock<T>, tid: usize) -> ReadGuard<'rwlock, T> {
ReadGuard { tid, lock }
}
}
impl<'rwlock, T: Sized + Sync> WriteGuard<'rwlock, T> {
unsafe fn new(lock: &'rwlock RwLock<T>) -> WriteGuard<'rwlock, T> {
WriteGuard { lock }
}
}
unsafe impl<T: Sized + Sync> Sync for RwLock<T> {}
impl<T: Sized + Sync> Deref for ReadGuard<'_, T> {
    type Target = T;
    // SAFETY: while this guard exists the reader slot is non-zero, so no
    // writer can obtain mutable access to `data`.
    fn deref(&self) -> &T {
        unsafe { &*self.lock.data.get() }
    }
}
impl<T: Sized + Sync> Deref for WriteGuard<'_, T> {
    type Target = T;
    // SAFETY: while this guard exists the writer flag is held, so access to
    // `data` is exclusive.
    fn deref(&self) -> &T {
        unsafe { &*self.lock.data.get() }
    }
}
impl<T: Sized + Sync> DerefMut for WriteGuard<'_, T> {
    // SAFETY: the write lock is held for the guard's lifetime, so this is
    // the only live reference to `data`.
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.lock.data.get() }
    }
}
impl<T: Sized + Sync> Drop for ReadGuard<'_, T> {
    /// Releases this reader's slot when the guard goes out of scope.
    fn drop(&mut self) {
        // SAFETY: the guard was created only after `rlock[tid]` was
        // incremented, so this decrement is balanced.
        unsafe { self.lock.read_unlock(self.tid) };
    }
}
impl<T: Sized + Sync> Drop for WriteGuard<'_, T> {
    /// Releases the write lock when the guard goes out of scope.
    fn drop(&mut self) {
        // SAFETY: the guard was created only after the write lock was
        // acquired, so this release is balanced.
        unsafe {
            self.lock.write_unlock();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{RwLock, MAX_READER_THREADS};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::vec::Vec;

    #[test]
    fn test_rwlock_default() {
        let lock = RwLock::<usize>::default();
        assert_eq!(lock.wlock.load(Ordering::Relaxed), false);
        for idx in 0..MAX_READER_THREADS {
            assert_eq!(lock.rlock[idx].load(Ordering::Relaxed), 0);
        }
        assert_eq!(unsafe { *lock.data.get() }, usize::default());
    }

    #[test]
    fn test_writer_lock() {
        let lock = RwLock::<usize>::default();
        let val = 10;
        let mut guard = lock.write(1);
        *guard = val;
        assert_eq!(lock.wlock.load(Ordering::Relaxed), true);
        assert_eq!(lock.rlock[0].load(Ordering::Relaxed), 0);
        assert_eq!(unsafe { *lock.data.get() }, val);
    }

    #[test]
    fn test_writer_unlock() {
        let lock = RwLock::<usize>::default();
        {
            // `mut` removed: the guard is never written through.
            let _guard = lock.write(1);
            assert_eq!(lock.wlock.load(Ordering::Relaxed), true);
        }
        // Dropping the guard must clear the writer flag.
        assert_eq!(lock.wlock.load(Ordering::Relaxed), false);
    }

    #[test]
    fn test_reader_lock() {
        let lock = RwLock::<usize>::default();
        let val = 10;
        unsafe {
            *lock.data.get() = val;
        }
        let guard = lock.read(0);
        assert_eq!(lock.wlock.load(Ordering::Relaxed), false);
        assert_eq!(lock.rlock[0].load(Ordering::Relaxed), 1);
        assert_eq!(*guard, val);
    }

    #[test]
    fn test_reader_unlock() {
        let lock = RwLock::<usize>::default();
        {
            // `mut` removed: the guard is never written through.
            let _guard = lock.read(0);
            assert_eq!(lock.rlock[0].load(Ordering::Relaxed), 1);
        }
        // Dropping the guard must drain the reader slot.
        assert_eq!(lock.rlock[0].load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_multiple_readers() {
        let lock = RwLock::<usize>::default();
        let val = 10;
        unsafe {
            *lock.data.get() = val;
        }
        // Readers with distinct tids may hold the lock concurrently.
        let f = lock.read(0);
        let s = lock.read(1);
        let t = lock.read(2);
        assert_eq!(lock.rlock[0].load(Ordering::Relaxed), 1);
        assert_eq!(lock.rlock[1].load(Ordering::Relaxed), 1);
        assert_eq!(lock.rlock[2].load(Ordering::Relaxed), 1);
        assert_eq!(*f, val);
        assert_eq!(*s, val);
        assert_eq!(*t, val);
    }

    #[test]
    fn test_lock_combinations() {
        let l = RwLock::<usize>::default();
        {
            let _g = l.write(2);
        }
        {
            let _g = l.write(2);
        }
        {
            let _f = l.read(0);
            let _s = l.read(1);
        }
        {
            let _g = l.write(2);
        }
    }

    #[test]
    fn test_atomic_writes() {
        let lock = Arc::new(RwLock::<usize>::default());
        let t = 100;
        let mut threads = Vec::new();
        for _i in 0..t {
            let l = lock.clone();
            let child = thread::spawn(move || {
                let mut ele = l.write(t);
                *ele += 1;
            });
            threads.push(child);
        }
        for _i in 0..threads.len() {
            let _retval = threads
                .pop()
                .unwrap()
                .join()
                .expect("Thread didn't finish successfully.");
        }
        // Every increment must be visible: writes are mutually exclusive.
        assert_eq!(unsafe { *lock.data.get() }, t);
    }

    #[test]
    fn test_parallel_readers() {
        let lock = Arc::new(RwLock::<usize>::default());
        let t = 100;
        unsafe {
            *lock.data.get() = t;
        }
        let mut threads = Vec::new();
        for i in 0..t {
            let l = lock.clone();
            let child = thread::spawn(move || {
                let ele = l.read(i);
                assert_eq!(*ele, t);
            });
            threads.push(child);
        }
        for _i in 0..threads.len() {
            let _retval = threads
                .pop()
                .unwrap()
                .join()
                .expect("Reading didn't finish successfully.");
        }
    }

    #[test]
    #[should_panic]
    fn test_writer_unlock_without_lock() {
        let lock = RwLock::<usize>::default();
        unsafe { lock.write_unlock() };
    }

    #[test]
    #[should_panic]
    fn test_reader_unlock_without_lock() {
        let lock = RwLock::<usize>::default();
        unsafe { lock.read_unlock(1) };
    }

    // The three tests below verify that re-entrant acquisition deadlocks:
    // the spawned thread is expected to hang, so `shared` stays 0 and the
    // main thread panics with the expected message after a grace period.
    #[test]
    #[should_panic(expected = "This test should always panic")]
    fn test_reader_after_writer() {
        let lock = RwLock::<usize>::default();
        let shared = Arc::new(AtomicUsize::new(0));
        let s = shared.clone();
        let lock_thread = thread::spawn(move || {
            let _w = lock.write(1);
            let _r = lock.read(0);
            s.store(1, Ordering::SeqCst);
        });
        thread::sleep(std::time::Duration::from_secs(2));
        if shared.load(Ordering::SeqCst) == 0 {
            panic!("This test should always panic");
        }
        lock_thread.join().unwrap();
    }

    #[test]
    #[should_panic(expected = "This test should always panic")]
    fn test_writer_after_reader() {
        let lock = RwLock::<usize>::default();
        let shared = Arc::new(AtomicUsize::new(0));
        let s = shared.clone();
        let lock_thread = thread::spawn(move || {
            let _r = lock.read(0);
            let _w = lock.write(1);
            s.store(1, Ordering::SeqCst);
        });
        thread::sleep(std::time::Duration::from_secs(2));
        if shared.load(Ordering::SeqCst) == 0 {
            panic!("This test should always panic");
        }
        lock_thread.join().unwrap();
    }

    #[test]
    #[should_panic(expected = "This test should always panic")]
    fn test_writer_after_writer() {
        let lock = RwLock::<usize>::default();
        let shared = Arc::new(AtomicUsize::new(0));
        let s = shared.clone();
        let lock_thread = thread::spawn(move || {
            let _f = lock.write(1);
            let _s = lock.write(1);
            s.store(1, Ordering::SeqCst);
        });
        thread::sleep(std::time::Duration::from_secs(2));
        if shared.load(Ordering::SeqCst) == 0 {
            panic!("This test should always panic");
        }
        lock_thread.join().unwrap();
    }
}