use std::{
cell::RefCell,
collections::VecDeque,
future::poll_fn,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
};
use derive_more::{Display, Error};
use futures::Stream;
/// State shared between every [`BoundedSender`] and the [`Receiver`] of one channel.
struct Inner<T> {
    /// Buffered values, oldest first.
    queue: VecDeque<T>,
    /// Maximum number of buffered values. Stored explicitly because
    /// `VecDeque::capacity` may exceed the requested size (e.g. for zero-sized
    /// `T`), so it cannot serve as the bound.
    cap: usize,
    /// Cleared when the receiver is dropped; senders then fail with `Closed`.
    rx_alive: bool,
    /// Waker of a receiver task parked on an empty queue, if any.
    rx_waker: Option<Waker>,
    /// Wakers of sender tasks parked on a full queue.
    tx_wakers: Vec<Waker>,
}

/// Sending half of a channel created by [`bounded`].
///
/// Cloning yields another handle to the same channel.
pub struct BoundedSender<T> {
    inner: Rc<RefCell<Inner<T>>>,
}

impl<T> Clone for BoundedSender<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Rc::clone(&self.inner),
        }
    }
}

/// Receiving half of a channel created by [`bounded`].
pub struct Receiver<T> {
    inner: Rc<RefCell<Inner<T>>>,
}

/// Creates a single-threaded channel that buffers at most `cap` values.
///
/// Both halves share state through `Rc<RefCell<_>>`, so neither is `Send` or
/// `Sync`. Space for `cap` values is allocated up front.
///
/// # Panics
///
/// Panics if `cap` is zero.
pub fn bounded<T>(cap: usize) -> (BoundedSender<T>, Receiver<T>) {
    assert_ne!(cap, 0, "a bounded channel with capacity 0 does not work");
    let inner = Rc::new(RefCell::new(Inner {
        queue: VecDeque::with_capacity(cap),
        cap,
        rx_alive: true,
        rx_waker: None,
        tx_wakers: Vec::new(),
    }));
    let tx = BoundedSender {
        inner: Rc::clone(&inner),
    };
    let rx = Receiver { inner };
    (tx, rx)
}

/// Error returned by [`BoundedSender::send`] when the receiver has been
/// dropped; carries the unsent value back to the caller.
#[derive(Debug, Display, Error)]
#[display("receiver has been dropped")]
pub struct SendError<T>(pub T);

/// Error returned by [`BoundedSender::try_send`]; every variant carries the
/// unsent value back to the caller.
#[derive(Debug, Display, Error)]
pub enum TrySendError<T> {
    /// The queue already holds `cap` values.
    #[display("channel is full")]
    Full(T),
    /// The receiver has been dropped.
    #[display("receiver has been dropped")]
    Closed(T),
}

impl<T> BoundedSender<T> {
    /// Sends `val`, waiting for a free slot while the queue is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying `val` back if the receiver has been dropped.
    pub async fn send(&mut self, val: T) -> Result<(), SendError<T>> {
        let mut val = Some(val);
        poll_fn(|cx| {
            let pending = val
                .take()
                .expect("`send` future polled after completion");
            match self.try_send(pending) {
                Ok(()) => Poll::Ready(Ok(())),
                Err(TrySendError::Full(v)) => {
                    val = Some(v);
                    // Park until the receiver takes a value. Skip the push when this
                    // task is already registered so repeated polls do not grow the list.
                    let waker = cx.waker();
                    let mut inner = self.inner.borrow_mut();
                    if !inner.tx_wakers.iter().any(|w| w.will_wake(waker)) {
                        inner.tx_wakers.push(waker.clone());
                    }
                    Poll::Pending
                }
                Err(TrySendError::Closed(v)) => Poll::Ready(Err(SendError(v))),
            }
        })
        .await
    }

    /// Attempts to enqueue `val` without waiting, waking a parked receiver on success.
    ///
    /// # Errors
    ///
    /// * [`TrySendError::Closed`] if the receiver has been dropped.
    /// * [`TrySendError::Full`] if `cap` values are already buffered.
    ///
    /// Both variants hand `val` back to the caller.
    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
        let rx_waker = {
            let mut inner = self.inner.borrow_mut();
            if !inner.rx_alive {
                return Err(TrySendError::Closed(val));
            }
            if inner.queue.len() >= inner.cap {
                return Err(TrySendError::Full(val));
            }
            inner.queue.push_back(val);
            inner.rx_waker.take()
        };
        // Wake only after the `RefCell` borrow is released: waking a waker may run
        // executor code that could re-enter this channel.
        if let Some(waker) = rx_waker {
            waker.wake();
        }
        Ok(())
    }
}
#[derive(Debug, Display, Error, PartialEq, Eq)]
#[display("all senders have been dropped")]
pub struct RecvError;
#[derive(Debug, Display, Error, PartialEq, Eq)]
pub enum TryRecvError {
#[display("channel is empty")]
Empty,
#[display("all senders have been dropped")]
Closed,
}
impl<T> Receiver<T> {
    /// Polls for the next value, registering `cx`'s waker while the queue is empty.
    ///
    /// Resolves to `Err(RecvError)` once the queue is drained and every sender
    /// has been dropped.
    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.try_recv() {
            Ok(val) => Poll::Ready(Ok(val)),
            Err(TryRecvError::Empty) => {
                // Swap in the current waker; any previous one is dropped only after
                // the `RefCell` borrow has been released.
                let previous = self
                    .inner
                    .borrow_mut()
                    .rx_waker
                    .replace(cx.waker().clone());
                drop(previous);
                Poll::Pending
            }
            Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError)),
        }
    }

    /// Receives the next value, waiting while the queue is empty.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once the queue is drained and every sender has been dropped.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        poll_fn(|cx| self.poll_recv(cx)).await
    }

    /// Takes the next value without waiting.
    ///
    /// On success, every sender parked on a full queue is woken, because a slot
    /// has just been freed.
    ///
    /// # Errors
    ///
    /// * [`TryRecvError::Empty`] if nothing is buffered but a sender is still alive.
    /// * [`TryRecvError::Closed`] if nothing is buffered and every sender has been dropped.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let mut inner = self.inner.borrow_mut();
        let Some(val) = inner.queue.pop_front() else {
            // Only the receiver and the senders hold the shared `Rc`, so a count of
            // one means every sender is gone.
            return Err(if Rc::strong_count(&self.inner) == 1 {
                TryRecvError::Closed
            } else {
                TryRecvError::Empty
            });
        };
        let parked = std::mem::take(&mut inner.tx_wakers);
        // Release the borrow before waking: a waker may run executor code that
        // re-enters this channel.
        drop(inner);
        for waker in parked {
            waker.wake();
        }
        Ok(val)
    }
}
impl<T> Stream for Receiver<T> {
    type Item = T;

    /// Yields the next buffered value, or `None` once the queue is drained and
    /// every sender has been dropped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Receiver` only holds an `Rc`, so it is `Unpin` and can be unpinned freely.
        let receiver = self.get_mut();
        receiver.poll_recv(cx).map(Result::ok)
    }

    /// The lower bound is the number of values already buffered; no upper bound
    /// is reported.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let buffered = self.inner.borrow().queue.len();
        (buffered, None)
    }
}
impl<T> Drop for BoundedSender<T> {
    /// When the last sender goes away, wakes a receiver parked on an empty queue
    /// so it can observe that the channel is closed.
    fn drop(&mut self) {
        // This handle is still counted while it is being dropped. With the receiver
        // alive, a count of 2 means this is the last sender.
        if Rc::strong_count(&self.inner) != 2 {
            return;
        }
        let rx_waker = {
            let mut inner = self.inner.borrow_mut();
            // If the receiver is already gone, a count of 2 instead means another
            // sender remains, and any stored receiver waker is stale.
            if inner.rx_alive {
                inner.rx_waker.take()
            } else {
                None
            }
        };
        // Wake only after the `RefCell` borrow is released: waking may run executor
        // code that re-enters this channel.
        if let Some(waker) = rx_waker {
            waker.wake();
        }
    }
}
impl<T> Drop for Receiver<T> {
    /// Marks the channel closed and wakes every sender parked on a full queue.
    ///
    /// Re-polled `send` calls then see the closed channel and return
    /// `SendError` instead of waiting forever for a free slot.
    fn drop(&mut self) {
        let parked = {
            let mut inner = self.inner.borrow_mut();
            inner.rx_alive = false;
            std::mem::take(&mut inner.tx_wakers)
        };
        // Wake only after the `RefCell` borrow is released: waking may run executor
        // code that re-enters this channel.
        for waker in parked {
            waker.wake();
        }
    }
}
/// Unit tests for the bounded channel.
///
/// Every test runs its body through the crate's `block_on` with a default
/// `VirtualIo` (presumably a simulated I/O backend — confirm in `tempest_io`).
#[cfg(test)]
mod tests {
    use std::{
        pin::{Pin, pin},
        task::Poll,
    };

    use futures::StreamExt;
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    /// A single `try_send` into an empty channel succeeds.
    #[test]
    fn test_try_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            assert!(tx.try_send(42).is_ok());
        });
    }

    /// Filling the channel exactly to its capacity succeeds.
    #[test]
    fn test_try_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            assert!(tx.try_send(1).is_ok());
            assert!(tx.try_send(2).is_ok());
        });
    }

    /// One value past capacity is rejected with `Full`, which hands the value back.
    #[test]
    fn test_try_send_over_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();
            match tx.try_send(99) {
                Err(TrySendError::Full(v)) => assert_eq!(v, 99),
                _ => panic!("expected Full"),
            }
        });
    }

    /// After the receiver is dropped, `try_send` fails with `Closed` and returns the value.
    #[test]
    fn test_try_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            match tx.try_send(99) {
                Err(TrySendError::Closed(v)) => assert_eq!(v, 99),
                _ => panic!("expected Closed"),
            }
        });
    }

    /// `send` into an empty channel completes without waiting.
    #[test]
    fn test_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.send(42).await.unwrap();
        });
    }

    /// Successive `send`s up to capacity complete without waiting.
    #[test]
    fn test_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
        });
    }

    /// Polling `send` once on a full channel (with a no-op waker) yields `Pending`.
    #[test]
    fn test_send_pending_when_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            let mut fut = pin!(tx.send(2));
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    /// A spawned sender blocked on a full channel delivers its value once the
    /// receiver frees a slot. Values arrive in order.
    #[test]
    fn test_send_when_full_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(1).unwrap();
            let mut handle = spawn(async move { tx.send(2).await.unwrap() });
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            // The spawned task has not completed yet, so its handle is still pending.
            assert!(matches!(Pin::new(&mut handle).poll(&mut cx), Poll::Pending));
            assert_eq!(rx.recv().await, Ok(1));
            assert_eq!(rx.recv().await, Ok(2));
            // The queue is drained. Whether this is `Empty` or `Closed` depends on
            // whether the spawned task has dropped its sender yet; either is an error.
            assert!(rx.try_recv().is_err())
        });
    }

    /// The receiver runs in a spawned task. The main task fills the channel, sees
    /// it is full, and then waits in `send` until the receiver drains a slot.
    #[test]
    fn test_recv_when_empty_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            spawn(async move {
                assert_eq!(rx.recv().await, Ok(1));
                assert_eq!(rx.recv().await, Ok(2));
            });
            tx.send(1).await.unwrap();
            // Nothing has yielded to the executor yet, so the queue still holds `1`
            // (assumes `spawn` does not poll the task eagerly).
            assert!(matches!(tx.try_send(2), Err(TrySendError::Full(2))));
            tx.send(2).await.unwrap();
        });
    }

    /// After the receiver is dropped, `send` fails with `SendError` carrying the value.
    #[test]
    fn test_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            match tx.send(99).await {
                Err(SendError(v)) => assert_eq!(v, 99),
                Ok(()) => panic!("expected Err"),
            }
        });
    }

    /// A value sent with `try_send` is returned by `try_recv`.
    #[test]
    fn test_try_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(42).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 42);
        });
    }

    /// `try_recv` yields values in FIFO order.
    #[test]
    fn test_try_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.try_send(1).unwrap();
            tx.try_send(2).unwrap();
            tx.try_send(3).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 1);
            assert_eq!(rx.try_recv().unwrap(), 2);
            assert_eq!(rx.try_recv().unwrap(), 3);
        });
    }

    /// An empty channel with a live sender reports `Empty`.
    #[test]
    fn test_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = bounded::<i32>(1);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        });
    }

    /// An empty channel with no senders left reports `Closed`.
    #[test]
    fn test_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        });
    }

    /// `recv` returns a value that was already sent.
    #[test]
    fn test_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.send(42).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 42);
        });
    }

    /// `recv` yields values in FIFO order.
    #[test]
    fn test_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
            tx.send(3).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 1);
            assert_eq!(rx.recv().await.unwrap(), 2);
            assert_eq!(rx.recv().await.unwrap(), 3);
        });
    }

    /// Polling `recv` once on an empty channel (with a no-op waker) yields `Pending`.
    #[test]
    fn test_recv_pending_when_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = bounded::<i32>(1);
            let waker = std::task::Waker::noop();
            let mut cx = std::task::Context::from_waker(&waker);
            let mut fut = pin!(rx.recv());
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    /// `recv` on an empty channel with no senders resolves to `RecvError`.
    #[test]
    fn test_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    /// A receiver parked in `recv` is woken when the only sender is dropped in another task.
    #[test]
    fn test_recv_woken_when_last_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            spawn(async move {
                drop(tx);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    /// With several senders, the parked receiver is woken only once the last one is dropped.
    #[test]
    fn test_recv_woken_when_last_of_multiple_senders_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            let tx2 = tx.clone();
            spawn(async move {
                drop(tx);
                drop(tx2);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    /// Collecting the receiver as a `Stream` yields every sent item in order. The
    /// stream ends once the spawned sender task finishes and drops its sender.
    #[test]
    fn test_stream_recv() {
        const ITEMS: &[i32; 3] = &[1, 2, 3];
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            spawn(async move {
                for &item in ITEMS {
                    tx.send(item).await.unwrap();
                }
            });
            let result: Vec<_> = rx.collect().await;
            assert_eq!(result, ITEMS);
        })
    }
}