use core::{
cell::{Cell, UnsafeCell},
ops::{Deref, DerefMut},
};
use crate::waiter_queue::WaiterQueue;
/// Coarse ownership state of an `RwLock`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
/// No guard holds the lock; this is the only state in which `try_write` succeeds.
Available,
/// One or more `ReadGuard`s are outstanding (tracked by `active_reader_count`).
Reading,
/// A `WriteGuard` is outstanding.
Writing,
}
/// An async reader-writer lock: either any number of `ReadGuard`s or a single
/// `WriteGuard` may be outstanding at a time.
///
/// All bookkeeping lives in `Cell`s and the value in an `UnsafeCell`, so the
/// lock is manipulated through `&self` without atomics.
/// NOTE(review): no `Sync` impl is visible in this file, so this appears to be
/// intended for single-threaded executors — confirm.
pub struct RwLock<T> {
/// The protected value; only reached through the guards' `Deref`/`DerefMut`.
value: UnsafeCell<T>,
/// Whether the lock is currently free, read-held or write-held.
state: Cell<State>,
/// Number of readers that currently own a share of the lock.
active_reader_count: Cell<usize>,
/// Number of `read()` calls parked in the slow path, not yet admitted.
waiting_reader_count: Cell<usize>,
/// Set by the first parked reader (the "leader") before it waits on
/// `write_waiters`; while set, `try_read` fails and new `read()` calls queue
/// up behind the leader.
readers_waiting_on_writers: Cell<bool>,
/// Set when the leader's future is dropped mid-wait so that a parked
/// follower can take over; cleared by the follower that does so.
reader_leader_cancelled: Cell<bool>,
/// Queue on which follower readers park until the leader admits them.
read_waiters: WaiterQueue,
/// Queue on which writers and the reader leader park; notified whenever a
/// guard drop returns the lock to `State::Available`.
write_waiters: WaiterQueue,
}
impl<T> RwLock<T> {
pub fn new(initial: T) -> Self {
Self {
value: UnsafeCell::new(initial),
state: Cell::new(State::Available),
active_reader_count: Cell::new(0),
waiting_reader_count: Cell::new(0),
readers_waiting_on_writers: Cell::new(false),
reader_leader_cancelled: Cell::new(false),
read_waiters: WaiterQueue::new(),
write_waiters: WaiterQueue::new(),
}
}
/// Attempts to take a shared guard without waiting.
///
/// Returns `None` when a writer holds the lock, or while a batch of `read()`
/// callers is queued (`readers_waiting_on_writers` is set). Otherwise the lock
/// moves to (or stays in) `State::Reading` and the active reader count is
/// bumped.
///
/// Only the flag and the state are consulted; `write_waiters` is not
/// inspected here.
#[inline]
pub fn try_read(&self) -> Option<ReadGuard<'_, T>> {
    if self.readers_waiting_on_writers.get() {
        return None;
    }
    match self.state.get() {
        State::Writing => None,
        State::Available | State::Reading => {
            self.state.set(State::Reading);
            let readers = self.active_reader_count.get() + 1;
            self.active_reader_count.set(readers);
            Some(ReadGuard { lock: self })
        }
    }
}
/// Acquires a shared guard, waiting while a writer holds or is queued for the
/// lock.
///
/// **Fast path.** If the lock is not write-held, nothing is parked on
/// `write_waiters`, and no reader batch is pending, the guard is handed out
/// immediately.
///
/// **Slow path.** Queued readers are admitted as one batch. The first queued
/// reader becomes the *leader*: it parks on `write_waiters` until the lock is
/// `Available`, switches the state to `Reading`, converts every queued reader
/// into an active reader and wakes the others (*followers*) through
/// `read_waiters`. Followers simply return a guard once woken, because the
/// leader has already counted them in `active_reader_count`.
///
/// # Cancellation
///
/// Dropping the returned future while it is parked removes the caller from
/// `waiting_reader_count`. If a leader is cancelled, one remaining follower is
/// woken to take over; if no queued reader remains, the batch flags are
/// cleared so that `try_read` does not keep failing on an otherwise free lock.
#[inline]
pub async fn read(&self) -> ReadGuard<'_, T> {
    // Runs only if the leader's future is dropped while parked on
    // `write_waiters`; it is defused with `mem::forget` on the success path.
    struct LeaderCancelGuard<'a, T>(&'a RwLock<T>);
    impl<T> Drop for LeaderCancelGuard<'_, T> {
        fn drop(&mut self) {
            let lock = self.0;
            lock.waiting_reader_count
                .set(lock.waiting_reader_count.get() - 1);
            lock.hand_off_reader_leadership();
        }
    }

    // Runs only if a follower's future is dropped while parked on
    // `read_waiters`; it is defused with `mem::forget` on the success path.
    //
    // NOTE(review): if the drop happens after the leader has already promoted
    // this follower (`waiting_reader_count` reset to 0 and the follower counted
    // in `active_reader_count`) but before the follower was polled, the
    // decrement below underflows and the active slot is never released.
    // Fixing that needs per-batch bookkeeping outside this method — confirm
    // against `WaiterQueue`'s wake-up semantics.
    struct FollowerCancelGuard<'a, T>(&'a RwLock<T>);
    impl<T> Drop for FollowerCancelGuard<'_, T> {
        fn drop(&mut self) {
            let lock = self.0;
            lock.waiting_reader_count
                .set(lock.waiting_reader_count.get() - 1);
            // A hand-off from a cancelled leader is still pending and this
            // follower may have been the one chosen for it: pass it on (or
            // reset the flags if nobody is left) so it is not lost.
            if lock.reader_leader_cancelled.get() {
                lock.hand_off_reader_leadership();
            }
        }
    }

    let must_queue = self.state.get() == State::Writing
        || self.write_waiters.waiter_count() > 0
        || self.readers_waiting_on_writers.get();
    if !must_queue {
        debug_assert!(self.state.get() == State::Available || self.state.get() == State::Reading);
        self.state.set(State::Reading);
        self.active_reader_count
            .set(self.active_reader_count.get() + 1);
        return ReadGuard { lock: self };
    }

    self.waiting_reader_count
        .set(self.waiting_reader_count.get() + 1);
    loop {
        if !self.readers_waiting_on_writers.get() {
            // Leader: wait in the writers' queue on behalf of the whole batch.
            let cancel_guard = LeaderCancelGuard(self);
            self.readers_waiting_on_writers.set(true);
            self.write_waiters
                .wait_until(|| self.state.get() == State::Available)
                .await;
            core::mem::forget(cancel_guard);

            // Admit every queued reader (this one included) in one step, then
            // release the followers.
            self.readers_waiting_on_writers.set(false);
            self.state.set(State::Reading);
            self.active_reader_count
                .set(self.waiting_reader_count.replace(0));
            self.read_waiters.notify_all();
            return ReadGuard { lock: self };
        }

        // Follower: park until the leader admits the batch or gives up.
        let cancel_guard = FollowerCancelGuard(self);
        self.read_waiters
            .wait_until(|| {
                !self.readers_waiting_on_writers.get() || self.reader_leader_cancelled.get()
            })
            .await;
        core::mem::forget(cancel_guard);

        if !self.reader_leader_cancelled.get() {
            // Already counted in `active_reader_count` by the leader.
            debug_assert!(self.state.get() == State::Reading);
            return ReadGuard { lock: self };
        }

        // The leader was cancelled: clear the flags and loop around to become
        // the new leader.
        self.readers_waiting_on_writers.set(false);
        self.reader_leader_cancelled.set(false);
    }
}

/// Re-homes batch leadership after a queued reader's future was dropped.
///
/// Must be called after the cancelled reader has been removed from
/// `waiting_reader_count`, and only while no leader is parked on
/// `write_waiters`.
fn hand_off_reader_leadership(&self) {
    if self.waiting_reader_count.get() == 0 {
        // Nobody is left to take over; leaving the flags set would make
        // `try_read` fail until some later `read()` call cleared them.
        self.readers_waiting_on_writers.set(false);
        self.reader_leader_cancelled.set(false);
    } else {
        self.reader_leader_cancelled.set(true);
        self.read_waiters.notify(1);
    }
}
/// Attempts to take the exclusive guard without waiting.
///
/// Succeeds only when the lock is `State::Available`, in which case the state
/// becomes `State::Writing`; returns `None` while any guard is outstanding.
#[inline]
pub fn try_write(&self) -> Option<WriteGuard<'_, T>> {
    match self.state.get() {
        State::Available => {
            self.state.set(State::Writing);
            Some(WriteGuard { lock: self })
        }
        State::Reading | State::Writing => None,
    }
}
/// Acquires the exclusive guard, waiting on `write_waiters` until
/// `try_write` succeeds.
///
/// NOTE(review): presumably `WaiterQueue::wait_for` re-runs the closure each
/// time the queue is notified and resolves with its first `Some` — confirm
/// against `WaiterQueue`.
#[inline]
pub async fn write(&self) -> WriteGuard<'_, T> {
    let acquisition = self.write_waiters.wait_for(|| self.try_write());
    acquisition.await
}
}
/// Shared access to the value inside an `RwLock`.
///
/// Dereferences to `&T`. Dropping the guard decrements the lock's active
/// reader count and, when it was the last reader, frees the lock and notifies
/// one entry of `write_waiters`.
pub struct ReadGuard<'a, T> {
/// The lock this guard holds a read share of.
lock: &'a RwLock<T>,
}
impl<T> Deref for ReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard itself is borrowed.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let shared: *const T = self.lock.value.get();
        // SAFETY: a `ReadGuard` is only constructed after the lock's state has
        // been set to `State::Reading`, and `try_write` refuses to create a
        // `WriteGuard` unless the state is `State::Available`, so no `&mut T`
        // to the value can exist while this guard is alive.
        unsafe { &*shared }
    }
}
impl<T> Drop for ReadGuard<'_, T> {
    /// Gives up this reader's share; the last reader out frees the lock and
    /// notifies one entry of `write_waiters`.
    #[inline]
    fn drop(&mut self) {
        debug_assert!(self.lock.state.get() == State::Reading);
        let lock = self.lock;
        let remaining = lock.active_reader_count.get() - 1;
        lock.active_reader_count.set(remaining);
        if remaining != 0 {
            return;
        }
        lock.state.set(State::Available);
        lock.write_waiters.notify(1);
    }
}
/// Exclusive access to the value inside an `RwLock`.
///
/// Dereferences to `&T` and `&mut T`. Dropping the guard returns the lock to
/// `State::Available` and notifies one entry of `write_waiters`.
pub struct WriteGuard<'a, T> {
/// The lock this guard holds exclusively.
lock: &'a RwLock<T>,
}
impl<T> Deref for WriteGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value immutably through the exclusive guard.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let shared: *const T = self.lock.value.get();
        // SAFETY: a `WriteGuard` is only created by `try_write`, which moves
        // the lock from `State::Available` to `State::Writing`; while in that
        // state neither `try_read`/`read` nor `try_write` hands out another
        // guard, so this guard is the only path to the value.
        unsafe { &*shared }
    }
}
impl<T> DerefMut for WriteGuard<'_, T> {
    /// Borrows the protected value mutably through the exclusive guard.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let exclusive: *mut T = self.lock.value.get();
        // SAFETY: while the lock is in `State::Writing` this is the only guard
        // in existence, and `&mut self` guarantees no other borrow obtained
        // through this guard is live, so the `&mut T` is unique.
        unsafe { &mut *exclusive }
    }
}
impl<T> Drop for WriteGuard<'_, T> {
#[inline]
fn drop(&mut self) {
debug_assert!(self.lock.state.get() == State::Writing);
self.lock.state.set(State::Available);
self.lock.write_waiters.notify(1);
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Uncontended `read`/`write` calls complete immediately and observe each
    /// other's effects.
    #[test]
    fn test_rw_lock_immediate() {
        let rw_lock: RwLock<u32> = RwLock::new(42);
        assert_eq!(*pollster::block_on(rw_lock.read()), 42);
        *pollster::block_on(rw_lock.write()) += 2;
        assert_eq!(*pollster::block_on(rw_lock.read()), 44);
    }

    /// `try_read`/`try_write` enforce "many readers or one writer" and the
    /// lock becomes available again once the last guard is dropped.
    #[test]
    fn test_try_lock_exclusion() {
        let rw_lock: RwLock<u32> = RwLock::new(0);

        // Readers share the lock but keep writers out until all are gone.
        let first_reader = rw_lock.try_read().expect("free lock must admit a reader");
        let second_reader = rw_lock.try_read().expect("readers must be able to share");
        assert!(rw_lock.try_write().is_none());
        drop(first_reader);
        assert!(rw_lock.try_write().is_none());
        drop(second_reader);

        // A writer excludes both readers and other writers.
        let mut writer = rw_lock.try_write().expect("free lock must admit a writer");
        assert!(rw_lock.try_read().is_none());
        assert!(rw_lock.try_write().is_none());
        *writer = 7;
        drop(writer);

        // The write is visible to the next reader.
        let reader = rw_lock.try_read().expect("lock must be free after writer drop");
        assert_eq!(*reader, 7);
    }
}