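// localq 0.0.2: no-std async primitives for `!Send` tasks.
//
// This module implements a bounded, single-threaded, multi-producer /
// single-consumer channel (`channel`, `Sender`, `Receiver`).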
use alloc::{collections::VecDeque, rc::Rc};
use core::{
    cell::{Cell, RefCell},
    task::{Poll, Waker},
};

use crate::waiter_queue::WaiterQueue;

/// Error returned by [`Sender::send`] when the channel is shut down.
///
/// The item that could not be delivered is handed back so it is not lost.
pub enum SendError<T> {
    /// The channel is shut down; contains the unsent item.
    Shutdown(T),
}

/// Error returned by [`Sender::try_send`].
///
/// Both variants hand back the item that could not be delivered.
pub enum TrySendError<T> {
    /// The channel already buffers as many items as its capacity allows.
    Full(T),
    /// The channel is shut down.
    Shutdown(T),
}

impl<T> SendError<T> {
    /// Consumes the error and returns the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Shutdown(item) => item,
        }
    }
}

impl<T> TrySendError<T> {
    /// Consumes the error and returns the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Full(item) | Self::Shutdown(item) => item,
        }
    }
}

// `Debug` is written by hand instead of derived so that it does not require
// `T: Debug`; the payload is deliberately left out of the output.
impl<T> core::fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Shutdown(_) => write!(f, "SendError::Shutdown"),
        }
    }
}

impl<T> core::fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Full(_) => write!(f, "TrySendError::Full"),
            Self::Shutdown(_) => write!(f, "TrySendError::Shutdown"),
        }
    }
}

impl<T> core::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Shutdown(_) => f.write_str("channel is shut down"),
        }
    }
}

impl<T> core::fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Full(_) => f.write_str("channel is full"),
            Self::Shutdown(_) => f.write_str("channel is shut down"),
        }
    }
}

/// Error returned by [`Receiver::try_recv`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// Nothing is buffered right now, but the channel is still open.
    Empty,
    /// The channel is shut down and every buffered item has been received.
    Shutdown,
}

/// Error returned by [`Receiver::recv`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The channel is shut down and every buffered item has been received.
    Shutdown,
}

impl core::fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Empty => f.write_str("channel is empty"),
            Self::Shutdown => f.write_str("channel is shut down"),
        }
    }
}

impl core::fmt::Display for RecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Shutdown => f.write_str("channel is shut down"),
        }
    }
}

/// The sending half of a channel created by [`channel`].
///
/// Senders can be cloned freely. Once the last `Sender` is dropped the channel is shut
/// down: the [`Receiver`] still receives every item that was already buffered and then
/// gets a `Shutdown` error instead of waiting forever.
pub struct Sender<T> {
    inner: Rc<Inner<T>>,
}

/// The receiving half of a channel created by [`channel`].
///
/// Dropping the `Receiver` shuts the channel down: current and future sends fail with a
/// `Shutdown` error that hands the item back to the caller.
pub struct Receiver<T> {
    inner: Rc<Inner<T>>,
}

/// State shared by every handle of one channel.
struct Inner<T> {
    /// Buffered items, oldest first.
    queue: RefCell<VecDeque<T>>,
    /// Maximum number of buffered items.
    ///
    /// Stored explicitly because `VecDeque::capacity` only guarantees *at least* the
    /// requested space (it is `usize::MAX` for zero-sized `T`), so it cannot serve as
    /// the channel bound.
    capacity: usize,
    /// Senders parked in [`Sender::send`] because the queue was full.
    waiting_senders: WaiterQueue,
    /// Waker of the task parked in [`Receiver::recv`], if any.
    receiver_waker: Cell<Option<Waker>>,
    /// Number of live [`Sender`] handles.
    sender_count: Cell<usize>,
    /// Set once the receiver or the last sender has been dropped.
    ///
    /// Must be a `Cell`: the state lives behind an `Rc`, so a plain `bool` could never
    /// change after construction.
    is_shutdown: Cell<bool>,
}

impl<T> Inner<T> {
    /// Returns `true` when no further item may be buffered.
    fn is_full(&self) -> bool {
        self.queue.borrow().len() >= self.capacity
    }

    /// Appends `item` and wakes the receiver if it is parked.
    ///
    /// The queue borrow ends with the push statement, before the waker runs, so a waker
    /// that does work synchronously cannot trigger a `RefCell` double borrow.
    fn push(&self, item: T) {
        self.queue.borrow_mut().push_back(item);
        self.wake_receiver();
    }

    /// Wakes the task parked in `Receiver::recv`, if any.
    fn wake_receiver(&self) {
        if let Some(waker) = self.receiver_waker.take() {
            waker.wake();
        }
    }
}

/// Creates a bounded channel that buffers at most `capacity` items.
///
/// While `capacity` items are buffered, [`Sender::send`] waits and
/// [`Sender::try_send`] fails with `Full`.
///
/// # Panics
///
/// Panics if `capacity` is zero, since such a channel could never accept an item.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    assert!(capacity > 0, "channel capacity must be greater than zero");

    let inner = Rc::new(Inner {
        queue: RefCell::new(VecDeque::with_capacity(capacity)),
        capacity,
        waiting_senders: WaiterQueue::new(),
        receiver_waker: Cell::new(None),
        sender_count: Cell::new(1),
        is_shutdown: Cell::new(false),
    });

    let sender = Sender {
        inner: Rc::clone(&inner),
    };
    (sender, Receiver { inner })
}

impl<T> Receiver<T> {
    /// Returns the number of items currently buffered.
    pub fn len(&self) -> usize {
        self.inner.queue.borrow().len()
    }

    /// Returns `true` if no items are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.inner.queue.borrow().is_empty()
    }

    /// Returns the maximum number of items the channel buffers.
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// Waits for the next item.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Shutdown`] once the channel is shut down and every buffered
    /// item has been received. This includes shutdowns that happen while the call is
    /// waiting.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        let inner = &*self.inner;
        let item = core::future::poll_fn(|cx| {
            // Bind first so the queue borrow ends before anything else runs.
            let popped = inner.queue.borrow_mut().pop_front();
            if let Some(item) = popped {
                return Poll::Ready(Ok(item));
            }
            if inner.is_shutdown.get() {
                return Poll::Ready(Err(RecvError::Shutdown));
            }
            // Park. Keep the stored waker if it already wakes this task, which avoids a
            // clone; otherwise replace it.
            match inner.receiver_waker.take() {
                Some(waker) if waker.will_wake(cx.waker()) => {
                    inner.receiver_waker.set(Some(waker));
                }
                _ => inner.receiver_waker.set(Some(cx.waker().clone())),
            }
            Poll::Pending
        })
        .await?;

        // A slot was just freed; let one parked sender use it.
        inner.waiting_senders.notify(1);
        Ok(item)
    }

    /// Takes the next item without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TryRecvError::Empty`] if nothing is buffered but the channel is still
    /// open, or [`TryRecvError::Shutdown`] once it is shut down and drained.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        // Bind first so the queue borrow is released before senders are notified.
        let popped = self.inner.queue.borrow_mut().pop_front();
        match popped {
            Some(item) => {
                self.inner.waiting_senders.notify(1);
                Ok(item)
            }
            None if self.inner.is_shutdown.get() => Err(TryRecvError::Shutdown),
            None => Err(TryRecvError::Empty),
        }
    }
}

impl<T> Sender<T> {
    /// Returns the number of items currently buffered.
    pub fn len(&self) -> usize {
        self.inner.queue.borrow().len()
    }

    /// Returns `true` if no items are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.inner.queue.borrow().is_empty()
    }

    /// Returns the maximum number of items the channel buffers.
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// Buffers `item` without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TrySendError::Shutdown`] if the receiver has been dropped, or
    /// [`TrySendError::Full`] if `capacity` items are already buffered. In both cases
    /// the item is handed back.
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        if self.inner.is_shutdown.get() {
            return Err(TrySendError::Shutdown(item));
        }
        if self.inner.is_full() {
            return Err(TrySendError::Full(item));
        }
        self.inner.push(item);
        Ok(())
    }

    /// Buffers `item`, waiting for a free slot if the channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Shutdown`] with the item if the receiver has been dropped,
    /// either before the call or while it was waiting for a slot.
    pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
        let inner = &*self.inner;
        if inner.is_shutdown.get() {
            return Err(SendError::Shutdown(item));
        }

        if inner.is_full() {
            inner
                .waiting_senders
                .wait_until(|| inner.is_shutdown.get() || !inner.is_full())
                .await;

            if inner.is_shutdown.get() {
                // `Receiver::drop` only calls `notify(1)`. Pass the wake-up on so the
                // other parked senders also observe the shutdown.
                inner.waiting_senders.notify(1);
                return Err(SendError::Shutdown(item));
            }
        }

        // No await since the wait finished, so the free slot is still ours.
        inner.push(item);
        Ok(())
    }
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        let count = &self.inner.sender_count;
        count.set(count.get() + 1);
        Self {
            inner: Rc::clone(&self.inner),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        let count = &self.inner.sender_count;
        count.set(count.get() - 1);
        if count.get() == 0 {
            // Last sender gone: shut down so a receiver parked in `recv` wakes up
            // instead of waiting forever.
            self.inner.is_shutdown.set(true);
            self.inner.wake_receiver();
        }
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        self.inner.is_shutdown.set(true);
        // The receiver's task can no longer be interested in wake-ups.
        drop(self.inner.receiver_waker.take());
        // Wake one parked sender; each woken sender passes the wake-up on (see
        // `Sender::send`).
        self.inner.waiting_senders.notify(1);
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Fills the channel, parks extra senders on it, then drains the receiver.
    ///
    /// Expects the pre-filled items first, then one item from each parked sender in the
    /// order the senders were spawned.
    #[test]
    fn test_local_mpsc() {
        use alloc::{rc::Rc, vec};

        const WAITERS: usize = 10;
        let executor = async_executor::LocalExecutor::new();

        let scenario = async {
            let (tx, mut rx) = channel(4);
            let started = Rc::new(async_unsync::semaphore::Semaphore::new(0));

            // Occupy every slot so the spawned `send` calls have to wait.
            (0..tx.capacity()).for_each(|n| tx.try_send(n).unwrap());

            // Park WAITERS senders on the full channel; each one signals when it starts.
            for n in 0..WAITERS {
                let (tx, started) = (tx.clone(), Rc::clone(&started));
                executor
                    .spawn(async move {
                        started.add_permits(1);
                        tx.send(10 + n).await.unwrap();
                    })
                    .detach();
            }

            // Block until every spawned sender has run up to its `send`.
            for _ in 0..WAITERS {
                started.acquire().await.unwrap().forget();
            }

            // Drain the buffered items plus one item per parked sender.
            let total = tx.capacity() + WAITERS;
            let mut received = vec![];
            while received.len() < total {
                received.push(rx.recv().await.unwrap());
            }

            assert_eq!(
                received,
                &[0, 1, 2, 3, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
            );
            assert!(rx.try_recv().is_err());
        };

        pollster::block_on(executor.run(scenario));
    }
}