use std::{
cell::RefCell,
future::poll_fn,
rc::Rc,
task::{Poll, Waker},
};
use derive_more::{Display, Error};
/// State shared between a [`Sender`] and its [`Receiver`].
struct Inner<T> {
    /// Waker registered by the receiver when a poll finds no value yet;
    /// the sender takes it and wakes it.
    waker: Option<Waker>,
    /// The value handed over by the sender, held until the receiver takes it.
    value: Option<T>,
}
/// Sending half of a oneshot channel created by [`channel`].
///
/// [`Sender::send`] takes `self` by value, so at most one value can be sent
/// through a given sender.
#[derive(derive_more::Debug)]
pub struct Sender<T> {
    // Debug output prints the address of the shared state rather than its
    // contents; both halves of one channel share the same allocation and so
    // print the same address.
    #[debug("{:p}", Rc::as_ptr(inner))]
    inner: Rc<RefCell<Inner<T>>>,
}
/// Receiving half of a oneshot channel created by [`channel`].
///
/// Get the value with [`Receiver::recv`] (async, consumes the receiver) or
/// [`Receiver::try_recv`] (non-blocking). Once the receiver is dropped,
/// [`Sender::send`] fails and hands the value back.
#[derive(derive_more::Debug)]
// `Receiver` is a handle, not a future: the value is only obtained through
// `recv()` or `try_recv()`, and dropping it closes the channel.
#[must_use = "dropping a Receiver closes the channel; await `recv()` or call `try_recv()` to get the value"]
pub struct Receiver<T> {
    // Debug output prints the address of the shared state rather than its
    // contents; both halves of one channel print the same address.
    #[debug("{:p}", Rc::as_ptr(inner))]
    inner: Rc<RefCell<Inner<T>>>,
}
/// Creates a connected oneshot [`Sender`]/[`Receiver`] pair.
///
/// Both halves point at the same freshly allocated, empty shared state.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Rc::new(RefCell::new(Inner {
        waker: None,
        value: None,
    }));
    let sender = Sender {
        inner: Rc::clone(&shared),
    };
    let receiver = Receiver { inner: shared };
    (sender, receiver)
}
impl<T> Sender<T> {
pub fn send(self, val: T) -> Result<(), T> {
if Rc::strong_count(&self.inner) == 1 {
return Err(val);
}
let mut borrow = self.inner.borrow_mut();
borrow.value = Some(val);
if let Some(waker) = borrow.waker.take() {
waker.wake();
}
Ok(())
}
}
#[derive(Debug, Display, Error, PartialEq, Eq)]
#[display("sender has been dropped")]
pub struct RecvError;
#[derive(Debug, Display, Error, PartialEq, Eq)]
pub enum TryRecvError {
#[display("channel is empty")]
Empty,
#[display("sender has been dropped")]
Closed,
}
impl<T> Receiver<T> {
    /// Polls for the sent value without consuming the receiver.
    ///
    /// Returns:
    /// - `Poll::Ready(Ok(value))` once a value has been sent.
    /// - `Poll::Ready(Err(RecvError))` if the sender is gone and no value is
    ///   waiting.
    /// - `Poll::Pending` otherwise, after storing `cx`'s waker for
    ///   [`Sender::send`] to wake.
    pub(crate) fn poll_recv(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<T, RecvError>> {
        match self.try_recv() {
            Ok(val) => Poll::Ready(Ok(val)),
            Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => {
                let mut inner = self.inner.borrow_mut();
                // Repeated polls usually come from the same task; skip the
                // clone when the stored waker would already wake it.
                let already_registered =
                    matches!(&inner.waker, Some(waker) if waker.will_wake(cx.waker()));
                if !already_registered {
                    inner.waker = Some(cx.waker().clone());
                }
                Poll::Pending
            }
        }
    }

    /// Waits for the value sent by the paired [`Sender`], consuming the
    /// receiver.
    ///
    /// # Errors
    ///
    /// Resolves to `Err(RecvError)` if, when polled, the sender is gone and
    /// no value was sent.
    pub async fn recv(mut self) -> Result<T, RecvError> {
        poll_fn(|cx| self.poll_recv(cx)).await
    }

    /// Takes the sent value if one is available, without waiting.
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if the sender is alive but has not sent yet.
    /// - [`TryRecvError::Closed`] if the sender is gone and no value is
    ///   waiting. This includes the case where the value was already taken.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let taken = self.inner.borrow_mut().value.take();
        match taken {
            Some(val) => Ok(val),
            // Only the two channel halves hold the `Rc` and neither is
            // `Clone`, so a count of 1 means the sender has been dropped.
            None if Rc::strong_count(&self.inner) == 1 => Err(TryRecvError::Closed),
            None => Err(TryRecvError::Empty),
        }
    }
}
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use super::*;
    use crate::{block_on, spawn};

    /// A value sent before receiving is delivered by `recv`.
    #[test]
    fn test_oneshot_send_recv() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            sender.send(5).unwrap();
            let received = receiver.recv().await.unwrap();
            assert_eq!(received, 5);
        });
    }

    /// Dropping the sender first makes `recv` fail.
    #[test]
    fn test_oneshot_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(sender);
            let outcome = receiver.recv().await;
            assert_eq!(outcome, Err(RecvError));
        });
    }

    /// Dropping the receiver first makes `send` hand the value back.
    #[test]
    fn test_oneshot_receiver_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(receiver);
            let rejected = sender.send(99);
            assert_eq!(rejected, Err(99));
        });
    }

    /// A value sent from a spawned task reaches the receiver.
    #[test]
    fn test_oneshot_from_task() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            let handle = spawn(async {
                sender.send(42).unwrap();
            });
            handle.await.unwrap();
            assert_eq!(receiver.recv().await.unwrap(), 42);
        });
    }

    /// `try_recv` reports `Empty` while the sender is alive and silent.
    #[test]
    fn test_oneshot_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            // `_sender` (not `_`) keeps the sender alive for the whole test.
            let (_sender, mut receiver) = channel::<i32>();
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Err(TryRecvError::Empty));
        });
    }

    /// `try_recv` reports `Closed` once the sender is dropped.
    #[test]
    fn test_oneshot_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel::<i32>();
            drop(sender);
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Err(TryRecvError::Closed));
        });
    }

    /// `try_recv` returns an already-sent value.
    #[test]
    fn test_oneshot_try_recv_value() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel();
            sender.send(42).unwrap();
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Ok(42));
        });
    }
}