use alloc::collections::VecDeque;
use alloc::fmt;
use core::{
cell::Cell,
future::Future,
hint::unreachable_unchecked,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll, Waker},
};
use fairly_unsafe_cell::FairlyUnsafeCell;
use crate::{extend_lifetime, extend_lifetime_mut};
/// An async read-write lock.
///
/// At any time there is either any number of [`ReadGuard`]s or exactly one [`WriteGuard`].
/// Bookkeeping uses `Cell`, so the lock is not `Sync` and is meant for use within one thread.
pub struct RwLock<T> {
    /// The protected value.
    value: FairlyUnsafeCell<T>,
    /// `Some(n)`: `n` read guards are alive (`Some(0)` means unlocked).
    /// `None`: a write guard is alive.
    readers: Cell<Option<usize>>,
    /// Wakers of tasks that could not obtain read access yet.
    parked_reads: FairlyUnsafeCell<VecDeque<Waker>>,
    /// Wakers of tasks that could not obtain write access yet, oldest first.
    parked_writes: FairlyUnsafeCell<VecDeque<Waker>>,
}
impl<T> RwLock<T> {
pub fn new(value: T) -> Self {
RwLock {
value: FairlyUnsafeCell::new(value),
readers: Cell::new(Some(0)),
parked_reads: FairlyUnsafeCell::new(VecDeque::new()),
parked_writes: FairlyUnsafeCell::new(VecDeque::new()),
}
}
pub fn into_inner(self) -> T {
self.value.into_inner()
}
pub async fn read(&self) -> ReadGuard<T> {
ReadFuture(self).await
}
pub fn try_read(&self) -> Option<ReadGuard<T>> {
let reader_count = self.readers.get()?;
self.readers.set(Some(reader_count + 1));
Some(ReadGuard { lock: self })
}
pub async fn write(&self) -> WriteGuard<T> {
WriteFuture(self).await
}
pub fn try_write(&self) -> Option<WriteGuard<T>> {
match self.readers.get() {
Some(0) => {
self.readers.set(None);
Some(WriteGuard { lock: self })
}
_ => None,
}
}
pub async fn set(&self, to: T) {
let mut guard = self.write().await;
*guard = to;
}
pub async fn replace(&self, mut to: T) -> T {
let mut guard = self.write().await;
core::mem::swap(guard.deref_mut(), &mut to);
to
}
pub async fn update(&self, with: impl FnOnce(&T) -> T) {
let mut guard = self.write().await;
*guard = with(&guard);
}
pub async fn fallible_update<E>(&self, with: impl FnOnce(&T) -> Result<T, E>) -> Result<(), E> {
let mut guard = self.write().await;
*guard = with(&guard)?;
Ok(())
}
pub async fn mutate(&self, with: impl FnOnce(&mut T)) {
let mut guard = self.write().await;
with(&mut guard)
}
pub async fn fallible_mutate<E>(
&self,
with: impl FnOnce(&mut T) -> Result<(), E>,
) -> Result<(), E> {
let mut guard = self.write().await;
with(&mut guard)
}
fn wake_next(&self) {
debug_assert_eq!(self.readers.get(), Some(0));
let there_are_no_pending_reads = unsafe { self.parked_reads.borrow().deref().is_empty() };
if there_are_no_pending_reads {
if let Some(next_write) =
unsafe { self.parked_writes.borrow_mut().deref_mut().pop_front() }
{
next_write.wake();
}
} else {
for parked_read in unsafe { self.parked_reads.borrow_mut().deref_mut().drain(..) } {
parked_read.wake();
}
}
}
fn park_read(&self, cx: &mut Context<'_>) {
let mut parked_reads = unsafe { self.parked_reads.borrow_mut() };
parked_reads.deref_mut().push_back(cx.waker().clone());
}
fn park_write(&self, cx: &mut Context<'_>) {
let mut parked_writes = unsafe { self.parked_writes.borrow_mut() };
parked_writes.deref_mut().push_back(cx.waker().clone());
}
}
impl<T> AsMut<T> for RwLock<T> {
    /// Returns a mutable reference to the value without locking.
    ///
    /// Holding `&mut self` already proves that no guard (which borrows the lock) is alive.
    fn as_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value.get_mut()
    }
}
impl<T: fmt::Debug> fmt::Debug for RwLock<T> {
    // Prints `RwLock(<value>)`, or `RwLock(<locked>)` while a write guard is alive
    // (the only case in which `try_read` fails).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("RwLock");
        if let Some(read_guard) = self.try_read() {
            let value: &T = &read_guard;
            tuple.field(&value);
        } else {
            tuple.field(&format_args!("<locked>"));
        }
        tuple.finish()
    }
}
/// Shared read access to the value inside a [`RwLock`].
///
/// Obtained from [`RwLock::read`] or [`RwLock::try_read`]. Dropping the guard gives up the
/// read access; when the last reader leaves, parked tasks are woken.
#[must_use = "if unused the RwLock will immediately unlock"]
#[derive(Debug)]
pub struct ReadGuard<'lock, T> {
    lock: &'lock RwLock<T>,
}
impl<T> Drop for ReadGuard<'_, T> {
    /// Gives up this read access and wakes parked tasks once no reader remains.
    fn drop(&mut self) {
        let readers = &self.lock.readers;
        let remaining = match readers.get() {
            Some(count) => count - 1,
            // SAFETY: a `ReadGuard` only exists after `try_read` saw `Some(_)` and
            // incremented it. `try_write` only switches to `None` from `Some(0)`, so the
            // count cannot be `None` while this guard is alive.
            None => unsafe { unreachable_unchecked() },
        };
        readers.set(Some(remaining));
        if remaining == 0 {
            self.lock.wake_next();
        }
    }
}
impl<T> Deref for ReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: while this guard is alive the lock is read-locked, so no `WriteGuard`
        // can be handing out a mutable reference to the value.
        let value_ref = unsafe { self.lock.value.borrow() };
        // SAFETY: for the same reason, the value cannot be mutated until this guard is
        // dropped, so tying the reference to `&self` instead of `value_ref` is sound.
        unsafe { extend_lifetime(&*value_ref) }
    }
}
/// Exclusive write access to the value inside a [`RwLock`].
///
/// Obtained from [`RwLock::write`] or [`RwLock::try_write`]. Dropping the guard unlocks
/// the lock and wakes parked tasks.
#[must_use = "if unused the RwLock will immediately unlock"]
#[derive(Debug)]
pub struct WriteGuard<'lock, T> {
    lock: &'lock RwLock<T>,
}
impl<T> Drop for WriteGuard<'_, T> {
    /// Marks the lock as free (zero readers) and wakes parked tasks.
    fn drop(&mut self) {
        let lock = self.lock;
        lock.readers.set(Some(0));
        lock.wake_next();
    }
}
impl<T> Deref for WriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: this guard has exclusive access to the lock, and the temporary borrow
        // does not overlap any `deref_mut` borrow (each ends when its method returns).
        let value_ref = unsafe { self.lock.value.borrow() };
        // SAFETY: the returned reference borrows `self`, so no `deref_mut` call can
        // happen while it is alive.
        unsafe { extend_lifetime(&*value_ref) }
    }
}
impl<T> DerefMut for WriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: this guard has exclusive access to the lock, and `&mut self` rules out
        // any other reference previously obtained through this guard.
        let mut value_mut = unsafe { self.lock.value.borrow_mut() };
        // SAFETY: the returned reference mutably borrows `self`, keeping it unique.
        unsafe { extend_lifetime_mut(&mut *value_mut) }
    }
}
/// Future behind `RwLock::read`: resolves to a `ReadGuard` once read access is obtained.
struct ReadFuture<'a, T>(&'a RwLock<T>);
impl<'lock, T> Future for ReadFuture<'lock, T> {
type Output = ReadGuard<'lock, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.try_read() {
Some(guard) => Poll::Ready(guard),
None => {
self.0.park_read(cx);
Poll::Pending
}
}
}
}
/// Future behind `RwLock::write`: resolves to a `WriteGuard` once exclusive access is obtained.
struct WriteFuture<'a, T>(&'a RwLock<T>);
impl<'lock, T> Future for WriteFuture<'lock, T> {
type Output = WriteGuard<'lock, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.try_write() {
Some(guard) => Poll::Ready(guard),
None => {
self.0.park_write(cx);
Poll::Pending
}
}
}
}