use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::FutureExt;
use tokio::{
sync::oneshot::{
Receiver as OneshotReceiver, Sender as OneshotSender, channel as oneshot_channel,
},
time::{self, Duration},
};
use crate::errors::{BoxError, RecvError, SendError};
#[derive(Debug)]
#[repr(transparent)]
pub struct Sender<T>(OneshotSender<Result<T, BoxError>>);
impl<T> Sender<T> {
pub fn send(self, t: T) -> Result<(), SendError<T>> {
self.0.send(Ok(t)).map_err(|err| match err {
Ok(t) => SendError::Closed(t),
Err(_) => unreachable!(),
})
}
pub async fn closed(&mut self) {
self.0.closed().await
}
pub fn is_closed(&self) -> bool {
self.0.is_closed()
}
pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.0.poll_closed(cx)
}
pub fn send_err<E>(self, err: E) -> Result<(), SendError<BoxError>>
where
E: Into<BoxError>,
{
self.0.send(Err(err.into())).map_err(|err| match err {
Ok(_) => unreachable!(),
Err(e) => SendError::Closed(e),
})
}
}
/// Receiving half of a oneshot channel created by [`channel`].
///
/// Resolves to the value passed to [`Sender::send`], or to a [`RecvError`] when the
/// sender reported an error via [`Sender::send_err`] or went away without sending.
#[derive(Debug)]
#[repr(transparent)]
pub struct Receiver<T>(OneshotReceiver<Result<T, BoxError>>);

impl<T> Receiver<T> {
    /// Closes the channel from the receiving side so later sends fail
    /// (delegates to tokio's `oneshot::Receiver::close`).
    pub fn close(&mut self) {
        self.0.close();
    }

    /// Returns `true` once the inner receiver has yielded its final result; after that
    /// it must not be polled again (delegates to tokio's `oneshot::Receiver::is_terminated`).
    pub fn is_terminated(&self) -> bool {
        self.0.is_terminated()
    }

    /// Returns whether no message is currently available
    /// (delegates to tokio's `oneshot::Receiver::is_empty`).
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Attempts to receive without waiting.
    ///
    /// # Errors
    /// Channel-level failures (nothing sent yet, sender dropped) and errors delivered
    /// through [`Sender::send_err`] are both converted into [`RecvError`].
    pub fn try_recv(&mut self) -> Result<T, RecvError> {
        // First `?`: channel error; second `?`: error sent via `send_err`.
        Ok(self.0.try_recv()??)
    }

    /// Blocks the current thread until the result is available, consuming the receiver.
    ///
    /// Per tokio's `oneshot::Receiver::blocking_recv`, this must not be called from
    /// within an async execution context.
    ///
    /// # Errors
    /// Same conversions as [`Receiver::try_recv`].
    pub fn blocking_recv(self) -> Result<T, RecvError> {
        Ok(self.0.blocking_recv()??)
    }

    /// Consumes the receiver and returns a future that resolves exactly like awaiting
    /// the receiver directly.
    pub fn recv(self) -> impl Future<Output = Result<T, RecvError>> {
        self.0.map(|r| Ok(r??))
    }

    /// Waits at most `timeout` for the result without consuming the receiver, so a
    /// timed-out receive can be retried later.
    ///
    /// # Errors
    /// - [`RecvError::Timeout`] if nothing resolves within `timeout`.
    /// - [`RecvError::Closed`] if this receiver has already yielded its result
    ///   (see [`Receiver::is_terminated`]); this matches what `try_recv` reports for
    ///   an exhausted receiver.
    /// - Otherwise, the same errors as awaiting the receiver directly.
    ///
    /// # Panics
    /// Per tokio's `time::timeout` documentation, panics when no Tokio timer is available
    /// (e.g. a runtime built without the time driver).
    pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<T, RecvError> {
        // The inner tokio receiver panics ("called after complete") if it is polled again
        // after yielding a result; since this method takes `&mut self` and is meant to
        // be called repeatedly, report that state as an error instead of crashing.
        if self.is_terminated() {
            return Err(RecvError::Closed);
        }
        match time::timeout(timeout, self).await {
            Ok(r) => r,
            Err(_) => Err(RecvError::Timeout),
        }
    }
}
impl<T> Future for Receiver<T> {
type Output = Result<T, RecvError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx).map(|r| Ok(r??))
}
}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let (tx, rx) = oneshot_channel();
(Sender(tx), Receiver(rx))
}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::errors::SendError;

    /// Exercises every way a receive can resolve: a value, a sender-reported error,
    /// and a sender dropped without sending.
    #[tokio::test]
    async fn test_send_recv() {
        // A plain value arrives through the `Future` impl.
        let (sender, receiver) = channel::<u32>();
        sender.send(42).unwrap();
        assert!(!receiver.is_empty());
        assert!(!receiver.is_terminated());
        let value = receiver.await.unwrap();
        assert_eq!(value, 42);

        // An error pushed with `send_err` comes back as `RecvError::Other`.
        let (sender, receiver) = channel::<u32>();
        sender.send_err("boom").unwrap();
        let outcome = receiver.recv().await;
        match &outcome {
            Err(RecvError::Other(err)) => assert_eq!(err.to_string(), "boom"),
            _ => panic!("expected RecvError::Other, got {:?}", outcome),
        }

        // Dropping the sender without sending reports `Closed`.
        let (sender, receiver) = channel::<u32>();
        drop(sender);
        assert!(matches!(receiver.await, Err(RecvError::Closed)));
    }

    /// `try_recv` reports `Empty` before a send, the value after it, and `Closed` once
    /// the sender is dropped.
    #[tokio::test]
    async fn test_try_send_recv() {
        let (sender, mut receiver) = channel::<u32>();
        assert!(matches!(receiver.try_recv(), Err(RecvError::Empty)));
        sender.send(11).unwrap();
        let received = receiver.try_recv().unwrap();
        assert_eq!(received, 11);

        let (sender, mut receiver) = channel::<u32>();
        drop(sender);
        assert!(matches!(receiver.try_recv(), Err(RecvError::Closed)));
    }

    /// Sending into a channel whose receiver is gone hands the payload back inside
    /// `SendError::Closed`, for both values and errors.
    #[tokio::test]
    async fn test_send_closed() {
        let (sender, receiver) = channel::<u32>();
        drop(receiver);
        let rejected = sender.send(7);
        match &rejected {
            Err(SendError::Closed(v)) => assert_eq!(*v, 7),
            _ => panic!("expected SendError::Closed(7), got {:?}", rejected),
        }

        let (sender, receiver) = channel::<u32>();
        drop(receiver);
        let rejected = sender.send_err("late");
        match &rejected {
            Err(SendError::Closed(e)) => assert_eq!(e.to_string(), "late"),
            _ => panic!("expected SendError::Closed, got {:?}", rejected),
        }
    }

    /// The sender observes closure both when the receiver is dropped and when it is
    /// explicitly closed.
    #[tokio::test]
    async fn test_close_detection() {
        let (mut sender, receiver) = channel::<u32>();
        assert!(!sender.is_closed());
        drop(receiver);
        assert!(sender.is_closed());
        sender.closed().await;

        let (mut sender, mut receiver) = channel::<u32>();
        receiver.close();
        sender.closed().await;
        assert!(sender.is_closed());
    }

    /// `recv_timeout` times out on an empty channel and can be retried once a value
    /// has been sent.
    #[tokio::test]
    async fn test_recv_timeout() {
        let window = Duration::from_millis(10);
        let (sender, mut receiver) = channel::<u32>();
        assert!(matches!(
            receiver.recv_timeout(window).await,
            Err(RecvError::Timeout)
        ));
        sender.send(99).unwrap();
        let received = receiver.recv_timeout(window).await.unwrap();
        assert_eq!(received, 99);
    }
}