// localq 0.0.2
//
// No-std async primitives for `!Send` tasks.
use core::{
    cell::{Cell, UnsafeCell},
    ops::{Deref, DerefMut},
};

use crate::waiter_queue::WaiterQueue;

/// Which kind of access, if any, the lock currently grants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// No guard is alive.
    Available,
    /// One or more read slots are outstanding.
    Reading,
    /// A single write guard is alive.
    Writing,
}

/// A `!Sync` async read-write lock.
///
/// Any number of [`ReadGuard`]s, or a single [`WriteGuard`], may be alive at a time. Once a
/// task is queued in the writer queue, new readers wait instead of joining the active readers,
/// so a stream of readers cannot starve writers.
///
/// Readers that have to wait form a *batch*. The first of them (the leader) queues in
/// `write_waiters` like a writer would; the rest (the followers) wait in `read_waiters`. When
/// the leader finds the lock available, the whole batch is admitted in one step.
pub struct RwLock<T> {
    value: UnsafeCell<T>,
    state: Cell<State>,
    /// Outstanding read slots: live `ReadGuard`s plus admitted followers not yet polled.
    active_reader_count: Cell<usize>,
    /// Readers of the current, not-yet-admitted batch (the leader included).
    waiting_reader_count: Cell<usize>,
    /// Set while a batch leader is queued (or is about to be replaced after a cancellation).
    readers_waiting_on_writers: Cell<bool>,
    /// Set when the batch leader was dropped while followers remain; one of them must take over.
    reader_leader_cancelled: Cell<bool>,
    /// Bumped (wrapping) every time a batch is admitted. A follower remembers the value it saw
    /// when it started waiting; a different value means its batch has been admitted.
    read_batch: Cell<usize>,
    /// Followers waiting behind a batch leader.
    read_waiters: WaiterQueue,
    /// Writers, plus the current batch leader (if any).
    write_waiters: WaiterQueue,
}

impl<T> RwLock<T> {
    /// Create a new, unlocked `RwLock` holding `initial`.
    pub fn new(initial: T) -> Self {
        Self {
            value: UnsafeCell::new(initial),
            state: Cell::new(State::Available),
            active_reader_count: Cell::new(0),
            waiting_reader_count: Cell::new(0),
            readers_waiting_on_writers: Cell::new(false),
            reader_leader_cancelled: Cell::new(false),
            read_batch: Cell::new(0),
            read_waiters: WaiterQueue::new(),
            write_waiters: WaiterQueue::new(),
        }
    }

    /// Whether a new reader has to wait instead of being admitted immediately.
    ///
    /// Readers wait while a writer holds the lock, while anything is queued in `write_waiters`,
    /// and while a reader batch is already queued behind writers.
    #[inline]
    fn reader_must_wait(&self) -> bool {
        self.state.get() == State::Writing
            || self.write_waiters.waiter_count() > 0
            || self.readers_waiting_on_writers.get()
    }

    /// Attempt to immediately gain read access to the inner value.
    ///
    /// Returns the read guard if successful. Returns `None` if a writer holds the lock or if
    /// writers (or readers queued behind writers) are already waiting. In other words,
    /// `try_read` succeeds exactly when [`read`](Self::read) would not have to wait.
    #[inline]
    pub fn try_read(&self) -> Option<ReadGuard<'_, T>> {
        if self.reader_must_wait() {
            return None;
        }
        debug_assert!(matches!(self.state.get(), State::Available | State::Reading));
        self.state.set(State::Reading);
        self.active_reader_count.set(self.active_reader_count.get() + 1);
        Some(ReadGuard { lock: self })
    }

    /// Acquire a read guard, waiting to do so if necessary.
    ///
    /// If the reader cannot be admitted immediately, it joins the current waiting batch. If no
    /// batch is queued yet, it becomes the batch leader. Cancelling (dropping) the returned future
    /// at any await point leaves the lock in a consistent state.
    #[inline]
    pub async fn read(&self) -> ReadGuard<'_, T> {
        if let Some(guard) = self.try_read() {
            return guard;
        }

        // Count this reader in the waiting batch. The whole count is moved into
        // `active_reader_count` in one step when the batch is admitted, so no writer can slip in
        // between the leader being admitted and its followers being polled.
        self.waiting_reader_count.set(self.waiting_reader_count.get() + 1);

        if self.readers_waiting_on_writers.get() {
            if self.follow_read_batch().await {
                debug_assert!(self.state.get() == State::Reading);
                return ReadGuard { lock: self };
            }
            // The leader was cancelled and this reader takes over. There is no `.await` between
            // clearing the flag here and `lead_read_batch` setting it again.
            self.readers_waiting_on_writers.set(false);
            self.reader_leader_cancelled.set(false);
        }

        self.lead_read_batch().await
    }

    /// Queue in `write_waiters` on behalf of the current reader batch, then admit the whole
    /// batch once the lock is `Available`.
    async fn lead_read_batch(&self) -> ReadGuard<'_, T> {
        // Undoes this leader's registration if it is dropped before the batch is admitted.
        struct LeaderCancelGuard<'a, T>(&'a RwLock<T>);
        impl<T> Drop for LeaderCancelGuard<'_, T> {
            fn drop(&mut self) {
                let lock = self.0;
                let remaining = lock.waiting_reader_count.get() - 1;
                lock.waiting_reader_count.set(remaining);
                if remaining == 0 {
                    // No follower is left to take over: stop blocking new readers.
                    lock.readers_waiting_on_writers.set(false);
                } else {
                    // Wake every follower; the first one polled takes over as leader. Waking all
                    // of them means the batch is not stranded if one woken follower is itself
                    // dropped before it is polled.
                    lock.reader_leader_cancelled.set(true);
                    lock.read_waiters.notify_all();
                }
                if lock.state.get() == State::Available {
                    // NOTE(review): this leader may have consumed the wakeup meant for the head
                    // of `write_waiters`; pass it on so queued writers are not stranded. This is
                    // only a spurious wakeup if `WaiterQueue` already forwards wakeups from
                    // dropped waiters.
                    lock.write_waiters.notify(1);
                }
            }
        }

        let cancel_guard = LeaderCancelGuard(self);
        self.readers_waiting_on_writers.set(true);
        // Waiting in the writer queue lets writers that queued earlier go first, while writers
        // that queue later wait behind this batch.
        self.write_waiters
            .wait_until(|| self.state.get() == State::Available)
            .await;
        core::mem::forget(cancel_guard);

        self.readers_waiting_on_writers.set(false);
        debug_assert!(self.active_reader_count.get() == 0);
        self.state.set(State::Reading);
        // Admit the leader and all of its followers at once.
        self.active_reader_count
            .set(self.active_reader_count.get() + self.waiting_reader_count.replace(0));
        self.read_batch.set(self.read_batch.get().wrapping_add(1));
        self.read_waiters.notify_all();
        ReadGuard { lock: self }
    }

    /// Wait behind the current batch leader.
    ///
    /// Returns `true` once this reader's batch has been admitted; the reader is then already
    /// counted in `active_reader_count`. Returns `false` if the leader was cancelled, in which
    /// case the caller must take over as leader.
    async fn follow_read_batch(&self) -> bool {
        // Releases whatever this follower holds if it is dropped while waiting.
        struct FollowerCancelGuard<'a, T> {
            lock: &'a RwLock<T>,
            batch: usize,
        }
        impl<T> Drop for FollowerCancelGuard<'_, T> {
            fn drop(&mut self) {
                let lock = self.lock;
                if lock.read_batch.get() != self.batch {
                    // Admitted but dropped before being polled: give back the read slot the
                    // leader reserved for this follower, exactly as a `ReadGuard` would.
                    drop(ReadGuard { lock });
                } else {
                    let remaining = lock.waiting_reader_count.get() - 1;
                    lock.waiting_reader_count.set(remaining);
                    if remaining == 0 {
                        // The leader counts itself until its batch is admitted, so reaching zero
                        // here means it was cancelled and nobody is left to take over.
                        lock.readers_waiting_on_writers.set(false);
                        lock.reader_leader_cancelled.set(false);
                    }
                }
            }
        }

        // The batch is captured in the same poll that registered this reader, so it identifies
        // the batch this follower belongs to.
        let batch = self.read_batch.get();
        let cancel_guard = FollowerCancelGuard { lock: self, batch };
        self.read_waiters
            .wait_until(|| self.read_batch.get() != batch || self.reader_leader_cancelled.get())
            .await;
        core::mem::forget(cancel_guard);

        // Admission is checked first: a later batch's leader may have been cancelled too.
        self.read_batch.get() != batch
    }

    /// Attempt to immediately gain write access to the inner value.
    ///
    /// Returns the write guard if the lock is `Available`, otherwise returns `None`.
    #[inline]
    pub fn try_write(&self) -> Option<WriteGuard<'_, T>> {
        if self.state.get() == State::Available {
            self.state.set(State::Writing);
            Some(WriteGuard { lock: self })
        } else {
            None
        }
    }

    /// Acquire a write guard, waiting in `write_waiters` until [`try_write`](Self::try_write)
    /// succeeds.
    #[inline]
    pub async fn write(&self) -> WriteGuard<'_, T> {
        self.write_waiters.wait_for(|| self.try_write()).await
    }
}

/// Shared read access to the value inside a [`RwLock`].
///
/// The read access is released when the guard is dropped.
#[must_use = "if unused the RwLock read access is released immediately"]
pub struct ReadGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for ReadGuard<'_, T> {
    type Target = T;

    /// Borrow the protected value for as long as this read guard is borrowed.
    #[inline]
    fn deref(&self) -> &T {
        let value_ptr = self.lock.value.get();
        // SAFETY: while a `ReadGuard` is alive the lock is in the `Reading` state, and
        // `try_write` only hands out a `WriteGuard` from `Available`, so no `&mut T` aliases this.
        unsafe { &*value_ptr }
    }
}

impl<T> Drop for ReadGuard<'_, T> {
    /// Give back one read slot; the last reader out makes the lock available again.
    #[inline]
    fn drop(&mut self) {
        let lock = self.lock;
        debug_assert!(lock.state.get() == State::Reading);

        let remaining = lock.active_reader_count.get() - 1;
        lock.active_reader_count.set(remaining);
        if remaining != 0 {
            return;
        }

        lock.state.set(State::Available);
        // Wake one task queued in the writer queue, if there is one.
        lock.write_waiters.notify(1);
    }
}

/// Exclusive write access to the value inside a [`RwLock`].
///
/// The write access is released when the guard is dropped.
#[must_use = "if unused the RwLock write access is released immediately"]
pub struct WriteGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for WriteGuard<'_, T> {
    type Target = T;

    /// Borrow the protected value immutably through the write guard.
    #[inline]
    fn deref(&self) -> &T {
        let value_ptr = self.lock.value.get();
        // SAFETY: while a `WriteGuard` is alive the lock is in the `Writing` state, so no other
        // guard can exist, and the only mutable access goes through `&mut self` of this guard.
        unsafe { &*value_ptr }
    }
}

impl<T> DerefMut for WriteGuard<'_, T> {
    /// Borrow the protected value mutably through the write guard.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr = self.lock.value.get();
        // SAFETY: the lock is in the `Writing` state while this guard lives, so no other guard
        // exists, and `&mut self` rules out any other borrow through this guard.
        unsafe { &mut *value_ptr }
    }
}

impl<T> Drop for WriteGuard<'_, T> {
    /// Release write access and notify one task queued in the writer queue.
    #[inline]
    fn drop(&mut self) {
        let lock = self.lock;
        debug_assert!(lock.state.get() == State::Writing);
        lock.state.set(State::Available);

        // Readers that queued behind writers are represented in `write_waiters` by the first of
        // them, so this single notification can wake either a writer or that reader.
        lock.write_waiters.notify(1);
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_rw_lock_immediate() {
        let lock = RwLock::<u32>::new(42);

        let initial = *pollster::block_on(lock.read());
        assert_eq!(initial, 42);

        {
            let mut writer = pollster::block_on(lock.write());
            *writer += 2;
        }

        let updated = *pollster::block_on(lock.read());
        assert_eq!(updated, 44);
    }
}