use alloc::{collections::VecDeque, rc::Rc};
use core::{
cell::{Cell, RefCell},
task::{Poll, Waker},
};
use crate::waiter_queue::WaiterQueue;
/// Error returned by [`Sender::send`] when the channel has been shut down.
///
/// The item that could not be sent is handed back to the caller.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SendError<T> {
    /// The channel is shut down; carries the rejected item.
    Shutdown(T),
}

impl<T> SendError<T> {
    /// Consumes the error and returns the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Shutdown(item) => item,
        }
    }
}
/// Error returned by [`Sender::try_send`] when an item cannot be enqueued immediately.
///
/// Both variants hand the rejected item back to the caller.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The buffer already holds `capacity` items; carries the rejected item.
    Full(T),
    /// The channel is shut down; carries the rejected item.
    Shutdown(T),
}

impl<T> TrySendError<T> {
    /// Consumes the error and returns the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Full(item) | Self::Shutdown(item) => item,
        }
    }
}
impl<T> core::fmt::Debug for SendError<T> {
    /// Prints only the variant name, so `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::Shutdown(_) => "SendError::Shutdown",
        };
        f.write_str(label)
    }
}
impl<T> core::fmt::Debug for TrySendError<T> {
    /// Prints only the variant name, so `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::Full(_) => "TrySendError::Full",
            Self::Shutdown(_) => "TrySendError::Shutdown",
        };
        f.write_str(label)
    }
}
/// Error returned by [`Receiver::try_recv`] when no item can be returned immediately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// No item is currently buffered.
    Empty,
    /// The channel is shut down and no buffered items remain.
    Shutdown,
}
/// Error returned by [`Receiver::recv`] once the channel is shut down and drained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The channel is shut down and no buffered items remain.
    Shutdown,
}
/// Sending half of a bounded, single-threaded channel created by [`channel`].
///
/// `Sender` can be cloned; every clone feeds the same [`Receiver`]. Dropping the
/// last `Sender` shuts the channel down. The receiver first drains whatever is
/// still buffered and then observes `Shutdown`.
pub struct Sender<T> {
    inner: Rc<Inner<T>>,
}

/// Receiving half of a bounded, single-threaded channel created by [`channel`].
///
/// Dropping the `Receiver` shuts the channel down. After that, every send fails
/// with `Shutdown` and hands the item back to the caller.
pub struct Receiver<T> {
    inner: Rc<Inner<T>>,
}

/// State shared by every handle of one channel.
struct Inner<T> {
    /// Buffered items, oldest at the front.
    queue: RefCell<VecDeque<T>>,
    /// Maximum number of buffered items.
    ///
    /// Stored explicitly because `VecDeque::capacity` only promises *at least*
    /// the requested space (and reports `usize::MAX` for zero-sized `T`), so it
    /// cannot serve as the bound.
    capacity: usize,
    /// Senders parked in [`Sender::send`] until a slot frees up or the channel
    /// shuts down.
    waiting_senders: WaiterQueue,
    /// Waker of the task parked in [`Receiver::recv`], if any.
    receiver_waker: Cell<Option<Waker>>,
    /// Number of live `Sender` handles; dropping to zero shuts the channel down.
    sender_count: Cell<usize>,
    /// Set once the receiver or the last sender has been dropped.
    is_shutdown: Cell<bool>,
}

impl<T> Inner<T> {
    /// Returns `true` when the buffer cannot accept another item.
    fn is_full(&self) -> bool {
        self.queue.borrow().len() >= self.capacity
    }

    /// Appends `item` to the buffer and wakes the receiver if it is parked.
    ///
    /// The caller must already have checked that the buffer is not full.
    fn push(&self, item: T) {
        self.queue.borrow_mut().push_back(item);
        // The `RefCell` borrow ends with the statement above. A waker that polls
        // the receiver inline therefore cannot trip over an outstanding borrow.
        self.wake_receiver();
    }

    /// Wakes the task parked in `Receiver::recv`, if there is one.
    fn wake_receiver(&self) {
        if let Some(waker) = self.receiver_waker.take() {
            waker.wake();
        }
    }
}

/// Creates a bounded channel that buffers at most `capacity` items.
///
/// The handles are built on `Rc`, so they are `!Send` and must stay on the
/// thread that created them.
///
/// With `capacity == 0` nothing can ever be buffered. [`Sender::try_send`]
/// always reports `Full`, and [`Sender::send`] waits until the channel shuts down.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let inner = Rc::new(Inner {
        queue: RefCell::new(VecDeque::with_capacity(capacity)),
        capacity,
        waiting_senders: WaiterQueue::new(),
        receiver_waker: Cell::new(None),
        sender_count: Cell::new(1),
        is_shutdown: Cell::new(false),
    });
    let sender = Sender {
        inner: Rc::clone(&inner),
    };
    (sender, Receiver { inner })
}

impl<T> Receiver<T> {
    /// Returns the number of items currently buffered.
    pub fn len(&self) -> usize {
        self.inner.queue.borrow().len()
    }

    /// Returns `true` if no items are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.inner.queue.borrow().is_empty()
    }

    /// Returns the maximum number of items the channel buffers.
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// Waits for the next item and returns it.
    ///
    /// Taking an item frees a slot, so one parked sender is notified afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Shutdown`] once every [`Sender`] has been dropped and
    /// the buffer is empty. Items sent before shutdown are still delivered first.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        let item = core::future::poll_fn(|cx| {
            let popped = self.inner.queue.borrow_mut().pop_front();
            if let Some(item) = popped {
                return Poll::Ready(Ok(item));
            }
            // Checked on every poll, not just once up front, so that a receiver
            // already parked here is released when the last sender goes away.
            if self.inner.is_shutdown.get() {
                return Poll::Ready(Err(RecvError::Shutdown));
            }
            // Keep the stored waker if it would wake this same task; otherwise replace it.
            let waker = match self.inner.receiver_waker.take() {
                Some(stored) if stored.will_wake(cx.waker()) => stored,
                _ => cx.waker().clone(),
            };
            self.inner.receiver_waker.set(Some(waker));
            Poll::Pending
        })
        .await?;
        self.inner.waiting_senders.notify(1);
        Ok(item)
    }

    /// Returns the next buffered item without waiting.
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if nothing is buffered but the channel is still open.
    /// - [`TryRecvError::Shutdown`] if nothing is buffered and the channel has shut down.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        // Bind first so the `RefCell` borrow is released before senders are notified.
        let popped = self.inner.queue.borrow_mut().pop_front();
        match popped {
            Some(item) => {
                self.inner.waiting_senders.notify(1);
                Ok(item)
            }
            None if self.inner.is_shutdown.get() => Err(TryRecvError::Shutdown),
            None => Err(TryRecvError::Empty),
        }
    }
}

impl<T> Sender<T> {
    /// Returns the number of items currently buffered.
    pub fn len(&self) -> usize {
        self.inner.queue.borrow().len()
    }

    /// Returns `true` if no items are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.inner.queue.borrow().is_empty()
    }

    /// Returns the maximum number of items the channel buffers.
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// Enqueues `item` if there is room, without waiting.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Shutdown`] if the receiver has been dropped.
    /// - [`TrySendError::Full`] if the buffer already holds `capacity` items.
    ///
    /// Both variants hand `item` back to the caller.
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        if self.inner.is_shutdown.get() {
            return Err(TrySendError::Shutdown(item));
        }
        if self.inner.is_full() {
            return Err(TrySendError::Full(item));
        }
        self.inner.push(item);
        Ok(())
    }

    /// Enqueues `item`, waiting for a free slot if the buffer is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Shutdown`] with `item` if the receiver has been
    /// dropped, either before the call or while waiting for space.
    pub async fn send(&self, item: T) -> Result<(), SendError<T>> {
        if self.inner.is_shutdown.get() {
            return Err(SendError::Shutdown(item));
        }
        if self.inner.is_full() {
            self.inner
                .waiting_senders
                .wait_until(|| self.inner.is_shutdown.get() || !self.inner.is_full())
                .await;
            if self.inner.is_shutdown.get() {
                // `Receiver::drop` notifies a single waiter. Passing the wake-up on
                // lets every parked sender eventually observe the shutdown.
                self.inner.waiting_senders.notify(1);
                return Err(SendError::Shutdown(item));
            }
        }
        // There has been no `.await` since the capacity check, so the free slot is still available.
        self.inner.push(item);
        Ok(())
    }
}

impl<T> Clone for Sender<T> {
    /// Creates another handle to the same channel.
    fn clone(&self) -> Self {
        self.inner.sender_count.set(self.inner.sender_count.get() + 1);
        Self {
            inner: Rc::clone(&self.inner),
        }
    }
}

impl<T> Drop for Sender<T> {
    /// When the last sender goes away, shuts the channel down and wakes a parked
    /// receiver so `recv` can report `Shutdown` once the buffer is drained.
    fn drop(&mut self) {
        // Cannot underflow: the count is at least 1 while any `Sender` exists.
        let remaining = self.inner.sender_count.get() - 1;
        self.inner.sender_count.set(remaining);
        if remaining == 0 {
            self.inner.is_shutdown.set(true);
            self.inner.wake_receiver();
        }
    }
}

impl<T> Drop for Receiver<T> {
    /// Shuts the channel down so that current and future sends fail with `Shutdown`.
    fn drop(&mut self) {
        self.inner.is_shutdown.set(true);
        // NOTE(review): relies on `WaiterQueue::wait_until` re-evaluating its predicate
        // when notified. `Sender::send` forwards this notification to the next waiter.
        self.inner.waiting_senders.notify(1);
    }
}
#[cfg(test)]
mod test {
// NOTE(review): relies on the `async_executor`, `pollster` and `async_unsync` crates,
// presumably declared as dev-dependencies — confirm in Cargo.toml.
use super::*;
/// Fills the channel to capacity with `try_send`, then spawns `waiter_count` tasks
/// that each `send` one more item while the buffer is full. Checks that `recv`
/// yields the buffered items first and then the waiting senders' items in spawn
/// order, and that nothing is left over afterwards.
#[test]
fn test_local_mpsc() {
use alloc::{rc::Rc, vec};
let waiter_count = 10;
let ex = async_executor::LocalExecutor::new();
pollster::block_on(ex.run(async {
// The expected output below assumes exactly 4 items are buffered before the senders wait.
let (sender, mut receiver) = channel(4);
// Each spawned task adds one permit right before calling `send`. Acquiring
// `waiter_count` permits below blocks the main task until every spawned task has started.
let acquire_starts = Rc::new(async_unsync::semaphore::Semaphore::new(0));
// Fill the buffer so every subsequent `send` has to wait for space.
for i in 0..sender.capacity() {
sender.try_send(i).unwrap();
}
// Spawn the senders; their items (10..10 + waiter_count) are distinguishable from the buffered ones.
for i in 0..waiter_count {
let sender = sender.clone();
let acquire_starts = acquire_starts.clone();
ex.spawn(async move {
acquire_starts.add_permits(1);
sender.send(10 + i).await.unwrap();
})
.detach();
}
// Wait until all spawned tasks have signalled that they started.
for _ in 0..waiter_count {
acquire_starts.acquire().await.unwrap().forget();
}
// Drain everything: the buffered items plus one item per spawned sender.
let mut received = vec![];
for _ in 0..sender.capacity() + waiter_count {
let item = receiver.recv().await.unwrap();
received.push(item);
}
// Buffered items come first, then the waiting senders' items in the order they were spawned.
assert_eq!(
received,
&[0, 1, 2, 3, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,]
);
// Nothing should be left in the channel.
assert!(receiver.try_recv().is_err());
}));
}
}