use std::{
cell::UnsafeCell,
ops::{Deref, DerefMut},
sync::atomic::{AtomicU32, Ordering},
};
use crate::backoff::Backoff;
use crate::hint::unlikely;
/// Lock-word value that marks the lock as write-locked.
///
/// Active-reader counts live strictly below this value; a lock word at or
/// above it means a writer holds the lock.
const WRITE_NUMBER: u32 = 0x4000_0000;

/// A reader-writer spinlock protecting a value of type `T`.
///
/// All lock state is kept in a single atomic word: `0` means unlocked,
/// `1..WRITE_NUMBER` counts active readers, and `WRITE_NUMBER` or more means a
/// writer holds the lock (readers that fail to acquire may bump it above
/// `WRITE_NUMBER` until the writer unlocks).
pub struct RwSpinlock<T> {
    // Lock word; see the type-level docs for its encoding.
    flag: AtomicU32,
    // The protected value; it is only reached through a lock guard.
    value: UnsafeCell<T>,
}
// SAFETY: the lock owns its `T`; moving the lock to another thread moves the
// `T` with it, which is sound when `T: Send`.
unsafe impl<T: Send> Send for RwSpinlock<T> {}
// SAFETY: sharing `&RwSpinlock<T>` lets several threads hold read guards and
// observe `&T` simultaneously, so `T: Sync` is required; a write guard hands out
// `&mut T` on whichever thread holds it, so `T: Send` is required as well.
// (This matches the bounds on `std::sync::RwLock`.)
unsafe impl<T: Send + Sync> Sync for RwSpinlock<T> {}
impl<T> RwSpinlock<T> {
    /// Creates a new, unlocked `RwSpinlock` wrapping `value`.
    ///
    /// This is a `const fn`, so the lock can also initialize a `static`.
    #[inline(always)]
    pub const fn new(value: T) -> Self {
        Self {
            flag: AtomicU32::new(0),
            value: UnsafeCell::new(value),
        }
    }

    /// Attempts to take the write (exclusive) lock without spinning.
    ///
    /// Returns `None` if any reader or writer currently holds the lock.
    pub fn try_write(&self) -> Option<RwSpinlockGuard<T>> {
        // Use the strong CAS: `compare_exchange_weak` may fail spuriously, which
        // would make `try_write` report contention on an unlocked lock.
        match self
            .flag
            .compare_exchange(0, WRITE_NUMBER, Ordering::Acquire, Ordering::Relaxed)
        {
            Ok(_) => Some(RwSpinlockGuard { parent: self }),
            Err(_) => None,
        }
    }

    /// Attempts to take a read (shared) lock without spinning.
    ///
    /// Returns `None` if a writer currently holds the lock.
    pub fn try_read(&self) -> Option<RwSpinlockGuard<T>> {
        // Test before incrementing. A failed `fetch_add` leaves its increment in
        // the lock word (only the writer's unlock, which stores 0, clears it).
        // Incrementing unconditionally while a writer holds the lock could
        // therefore wrap the counter back to 0 and admit a second locker.
        if self.flag.load(Ordering::Relaxed) >= WRITE_NUMBER {
            return None;
        }
        // Acquire pairs with the Release unlock of the previous writer, so this
        // reader observes everything that writer stored.
        let previous = self.flag.fetch_add(1, Ordering::Acquire);
        if previous < WRITE_NUMBER {
            Some(RwSpinlockGuard { parent: self })
        } else {
            None
        }
    }

    /// Takes the write (exclusive) lock, spinning with `Backoff::spin` until
    /// no reader or writer holds it.
    pub fn write(&self) -> RwSpinlockGuard<T> {
        let backoff = Backoff::default();
        loop {
            // The weak CAS is fine inside a retry loop: a spurious failure only
            // costs one more iteration.
            if self
                .flag
                .compare_exchange_weak(0, WRITE_NUMBER, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                return RwSpinlockGuard { parent: self };
            }
            // Spin on a plain load so the cache line is not hammered with RMWs.
            while self.flag.load(Ordering::Relaxed) != 0 {
                backoff.spin();
            }
        }
    }

    /// Takes a read (shared) lock, spinning with `Backoff::spin` while a
    /// writer holds the lock.
    pub fn read(&self) -> RwSpinlockGuard<T> {
        let backoff = Backoff::default();
        loop {
            if let Some(guard) = self.try_read() {
                return guard;
            }
            // Wait only for the writer to leave. Other readers are compatible
            // with us, so waiting for the word to reach 0 could starve this
            // reader under steady read traffic.
            while self.flag.load(Ordering::Relaxed) >= WRITE_NUMBER {
                backoff.spin();
            }
        }
    }
}
/// RAII guard returned by `RwSpinlock::{read, write, try_read, try_write}`;
/// it releases the lock when dropped.
///
/// The same type serves both shared and exclusive access. Whether a guard is
/// exclusive is recovered from the lock word: while a write guard is alive the
/// word is `>= WRITE_NUMBER` (set by its own CAS; others only add to it), and
/// while a read guard is alive no writer can CAS it away from a nonzero value,
/// so it stays below `WRITE_NUMBER` (assuming fewer than `WRITE_NUMBER`
/// concurrent readers).
#[must_use = "if unused the RwSpinlock will immediately unlock"]
pub struct RwSpinlockGuard<'a, T> {
    parent: &'a RwSpinlock<T>,
}

impl<T> RwSpinlockGuard<'_, T> {
    /// Returns `true` if this guard holds the write lock (see the type docs).
    #[inline(always)]
    fn is_writer(&self) -> bool {
        self.parent.flag.load(Ordering::Relaxed) >= WRITE_NUMBER
    }
}

impl<T> Drop for RwSpinlockGuard<'_, T> {
    #[inline(always)]
    fn drop(&mut self) {
        if unlikely(self.is_writer()) {
            // Writer unlock. Release publishes our writes to the next locker.
            // Storing 0 (rather than subtracting) also discards increments left
            // behind by readers that failed to acquire while we held the lock.
            self.parent.flag.store(0, Ordering::Release);
        } else {
            // Reader unlock. Release orders our reads before a later writer's
            // Acquire CAS, so that writer's stores cannot race with them.
            self.parent.flag.fetch_sub(1, Ordering::Release);
        }
    }
}

impl<T> Deref for RwSpinlockGuard<'_, T> {
    type Target = T;
    #[inline(always)]
    fn deref(&self) -> &T {
        // SAFETY: the guard holds at least a read lock, so no other write guard
        // exists; a `&mut T` from this same guard requires `&mut self`, which
        // the borrow checker keeps exclusive with this `&self` borrow.
        unsafe { &*self.parent.value.get() }
    }
}

impl<T> DerefMut for RwSpinlockGuard<'_, T> {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        // Read guards may coexist with other read guards, so handing out
        // `&mut T` from one would alias their `&T`; reject it.
        assert!(
            self.is_writer(),
            "cannot mutably access RwSpinlock data through a read guard"
        );
        // SAFETY: this guard holds the write lock, so no other guard exists, and
        // `&mut self` rules out any other borrow through this guard.
        unsafe { &mut *self.parent.value.get() }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Read guards block writers until every one of them is dropped.
    #[test]
    fn test_read_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            // `_`-prefixed bindings keep the guards alive until the end of the
            // scope (unlike `let _ = ...`, which would drop them immediately).
            let _r1 = m.read();
            {
                let _r2 = m.read();
                let _r3 = m.read();
                assert!(m.try_write().is_none());
            }
            assert!(m.try_write().is_none());
        }
        assert!(m.try_write().is_some());
    }

    // A write guard blocks readers until it is dropped.
    #[test]
    fn test_write_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            let _w1 = m.write();
            assert!(m.try_read().is_none());
        }
        assert!(m.try_read().is_some());
    }

    // Readers never observe the writer's temporary `-1`, and the writer's ten
    // increments are all visible afterwards.
    #[test]
    fn test_rw_arc() {
        let arc = Arc::new(RwSpinlock::new(0));
        let arc2 = arc.clone();
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            let mut lock = arc2.write();
            for _ in 0..10 {
                let tmp = *lock;
                *lock = -1;
                std::thread::yield_now();
                *lock = tmp + 1;
            }
            tx.send(()).unwrap();
        });
        let children: Vec<_> = (0..5)
            .map(|_| {
                let arc3 = arc.clone();
                std::thread::spawn(move || {
                    let lock = arc3.read();
                    assert!(*lock >= 0);
                })
            })
            .collect();
        for r in children {
            assert!(r.join().is_ok());
        }
        rx.recv().unwrap();
        let lock = arc.read();
        assert_eq!(*lock, 10);
    }

    // A write lock taken from a destructor that runs during unwinding still works.
    #[test]
    fn test_rw_access_in_unwind() {
        let arc = Arc::new(RwSpinlock::new(1));
        let arc2 = arc.clone();
        let _ = std::thread::spawn(move || {
            struct Unwinder {
                i: Arc<RwSpinlock<isize>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    let mut lock = self.i.write();
                    *lock += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.read();
        assert_eq!(*lock, 2);
    }

    // Mutation through a write guard is visible to a later reader.
    // NOTE(review): despite the name, `[i32; 3]` is a sized type.
    #[test]
    fn test_rwlock_unsized() {
        let rw: &RwSpinlock<[i32; 3]> = &RwSpinlock::new([1, 2, 3]);
        {
            let b = &mut *rw.write();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*rw.read(), comp);
    }

    // `try_write` fails while a read guard is held.
    #[test]
    fn test_rwlock_try_write() {
        let lock = RwSpinlock::new(0isize);
        let read_guard = lock.read();
        assert!(
            lock.try_write().is_none(),
            "try_write should not succeed while read_guard is in scope"
        );
        drop(read_guard);
    }

    // `try_read` fails while a (leaked) write guard holds the lock.
    #[test]
    fn test_rw_try_read() {
        let m = RwSpinlock::new(0);
        std::mem::forget(m.write());
        assert!(m.try_read().is_none());
    }
}