use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::{
sync::oneshot::{
Receiver as OneshotReceiver, Sender as OneshotSender, channel as oneshot_channel,
},
time::{self, Duration},
};
use crate::error::{BoxError, RecvError, SendError};
/// Sending half of the one-shot channel created by [`channel`].
///
/// This is a newtype over tokio's oneshot sender whose payload is
/// `Result<T, BoxError>`, so a single handle can deliver either a value
/// (see `send`) or an error (see `send_err`).
#[derive(Debug)]
#[repr(transparent)]
pub struct Sender<T>(OneshotSender<Result<T, BoxError>>);
impl<T> Sender<T> {
pub fn send(self, t: T) -> Result<(), SendError<T>> {
self.0
.send(Ok(t))
.map_err(|t| SendError::Closed(t.unwrap()))
}
pub async fn closed(&mut self) {
self.0.closed().await
}
pub fn is_closed(&self) -> bool {
self.0.is_closed()
}
pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.0.poll_closed(cx)
}
pub fn send_err<E>(self, err: E) -> Result<(), SendError<BoxError>>
where
E: Into<BoxError>,
{
self.0
.send(Err(err.into()))
.map_err(|err| SendError::Closed(err.err().unwrap()))
}
}
/// Receiving half of the one-shot channel created by [`channel`].
///
/// This is a newtype over tokio's oneshot receiver whose payload is
/// `Result<T, BoxError>`; the receiving methods flatten that payload and the
/// transport error into a single `Result<T, RecvError>`.
#[derive(Debug)]
#[repr(transparent)]
pub struct Receiver<T>(OneshotReceiver<Result<T, BoxError>>);
impl<T> Receiver<T> {
    /// Closes the receiving half by delegating to the wrapped oneshot
    /// receiver's `close`.
    pub fn close(&mut self) {
        self.0.close();
    }

    /// Returns the wrapped oneshot receiver's `is_terminated` flag.
    ///
    /// Per tokio's documentation this becomes `true` once the receiver has
    /// yielded its final result, after which it must not be polled again.
    pub fn is_terminated(&self) -> bool {
        self.0.is_terminated()
    }

    /// Returns the wrapped oneshot receiver's `is_empty` flag.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Attempts to take the outcome without waiting.
    ///
    /// # Errors
    ///
    /// Any failure reported by the wrapped receiver's `try_recv`, and any
    /// error delivered through `Sender::send_err`, is converted into
    /// `RecvError`.
    pub fn try_recv(&mut self) -> Result<T, RecvError> {
        Ok(self.0.try_recv()??)
    }

    /// Receives the outcome via the wrapped receiver's `blocking_recv`,
    /// consuming the receiver.
    ///
    /// See tokio's `oneshot::Receiver::blocking_recv` documentation for the
    /// restrictions on where this may be called.
    ///
    /// # Errors
    ///
    /// A receive failure or an error delivered through `Sender::send_err`
    /// is converted into `RecvError`.
    pub fn blocking_recv(self) -> Result<T, RecvError> {
        Ok(self.0.blocking_recv()??)
    }

    /// Awaits the outcome for at most `timeout`.
    ///
    /// The receiver is only borrowed, so the call may be repeated after a
    /// timeout to keep waiting for the same message.
    ///
    /// # Errors
    ///
    /// * `RecvError::Timeout` if `timeout` elapses first.
    /// * `RecvError::Closed` if this receiver has already terminated.
    /// * Otherwise whatever awaiting the receiver itself yields.
    pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<T, RecvError> {
        // Because `self` is borrowed, a caller can reach this method again
        // after the outcome was already taken. Polling tokio's oneshot
        // receiver after it completed panics, so report a closed channel
        // instead, as `try_recv` does in the same state.
        if self.is_terminated() {
            return Err(RecvError::Closed);
        }

        match time::timeout(timeout, self).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(RecvError::Timeout),
        }
    }
}
/// Awaiting a [`Receiver`] resolves to the sent value, or to a `RecvError`
/// converted from either the transport failure or the error that was sent.
impl<T> Future for Receiver<T> {
    type Output = Result<T, RecvError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.0).poll(cx) {
            Poll::Pending => Poll::Pending,
            // First `?` converts the oneshot receive error, second `?`
            // converts an error payload; both surface as `Ready(Err(..))`.
            Poll::Ready(received) => Poll::Ready(Ok(received??)),
        }
    }
}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let (tx, rx) = oneshot_channel();
(Sender(tx), Receiver(rx))
}
// Async tests covering the send, receive, timeout and close paths of the
// channel wrappers defined in this file.
#[cfg(test)]
mod tests {
use std::future::poll_fn;
use anyhow::Result;
use pretty_assertions::assert_eq;
use super::*;
use crate::error::SendError;
// Exercises awaiting the receiver for each way a channel can be resolved.
#[tokio::test]
async fn test_send_recv() -> Result<()> {
// A sent value is visible via `is_empty` and is returned by awaiting.
let (tx, rx) = channel::<u32>();
tx.send(42)?;
assert!(!rx.is_empty());
assert_eq!(rx.await?, 42);
// An error sent with `send_err` arrives as `RecvError::Other`.
let (tx, rx) = channel::<u32>();
tx.send_err("boom")?;
let result = rx.await;
assert!(matches!(result, Err(RecvError::Other(e)) if e.to_string() == "boom"));
// Dropping the sender without sending resolves to `RecvError::Closed`.
let (tx, rx) = channel::<u32>();
drop(tx);
assert!(matches!(rx.await, Err(RecvError::Closed)));
// Sending after the receiver is gone hands the value back to the caller.
let (tx, rx) = channel::<u32>();
drop(rx);
assert!(matches!(tx.send(7), Err(SendError::Closed(v)) if v == 7));
// Likewise, `send_err` hands the boxed error back.
let (tx, rx) = channel::<u32>();
drop(rx);
assert!(
matches!(tx.send_err("late"), Err(SendError::Closed(e)) if e.to_string() == "late")
);
Ok(())
}
// Exercises the non-waiting `try_recv` path.
#[tokio::test]
async fn test_try_send_recv() -> Result<()> {
// Nothing sent yet: `Empty`; after sending, the value is returned.
let (tx, mut rx) = channel::<u32>();
assert!(matches!(rx.try_recv(), Err(RecvError::Empty)));
tx.send(11)?;
assert_eq!(rx.try_recv()?, 11);
// Sender dropped without sending: `Closed`.
let (tx, mut rx) = channel::<u32>();
drop(tx);
assert!(matches!(rx.try_recv(), Err(RecvError::Closed)));
Ok(())
}
// A timed-out receive leaves the receiver usable for a later attempt.
#[tokio::test]
async fn test_recv_timeout() -> Result<()> {
let (tx, mut rx) = channel::<u32>();
// No message yet, so the first wait must time out.
assert!(matches!(
rx.recv_timeout(Duration::from_millis(100)).await,
Err(RecvError::Timeout)
));
// After sending, the same receiver yields the value.
tx.send(99)?;
assert_eq!(rx.recv_timeout(Duration::from_millis(100)).await?, 99);
Ok(())
}
// Exercises closure notification in both directions.
#[tokio::test]
async fn test_close() -> Result<()> {
// Dropping the receiver resolves `closed()` and flips `is_closed`.
let (mut tx, rx) = channel::<u32>();
assert!(!tx.is_closed());
drop(rx);
tx.closed().await;
assert!(tx.is_closed());
// `is_terminated` stays false when the sender is dropped and only turns
// true once `try_recv` has reported the closure.
let (tx, mut rx) = channel::<u32>();
assert!(!rx.is_terminated());
drop(tx);
assert!(!rx.is_terminated());
assert!(matches!(rx.try_recv(), Err(RecvError::Closed)));
assert!(rx.is_terminated());
// An explicit `close()` on the receiver is observable through `poll_closed`.
let (mut tx, mut rx) = channel::<u32>();
rx.close();
poll_fn(|cx| tx.poll_closed(cx)).await;
assert!(tx.is_closed());
Ok(())
}
}