use std::pin::Pin;
use futures::Stream;
#[cfg(async_channel_impl = "tokio")]
mod inner {
    //! Channel backend built on `tokio::sync::mpsc`.

    use tokio::sync::mpsc::{Receiver as InnerReceiver, Sender as InnerSender};
    pub use tokio::sync::mpsc::error::{SendError, TryRecvError};

    /// Error reported by `Receiver::recv` when the tokio receiver yields `None`.
    ///
    /// tokio's async `recv` returns an `Option` rather than a `Result`, so this
    /// unit struct stands in as the error type for this backend.
    #[derive(Debug, PartialEq, Eq)]
    pub struct RecvError;

    impl std::fmt::Display for RecvError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("RecvError")
        }
    }

    impl std::error::Error for RecvError {}

    /// Sending half; wraps a tokio `mpsc::Sender`.
    pub struct Sender<T>(pub(super) InnerSender<T>);

    /// Receiving half; wraps a tokio `mpsc::Receiver`.
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);

    /// Stream adapter produced by `Receiver::into_stream`.
    pub struct BoundedStream<T>(pub(super) tokio_stream::wrappers::ReceiverStream<T>);

    /// Classifies a `try_recv` failure.
    ///
    /// Returns `None` when the channel is merely empty for now, and
    /// `Some(RecvError)` when it has been disconnected.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Disconnected => Some(RecvError),
            TryRecvError::Empty => None,
        }
    }

    /// Creates a channel that buffers at most `len` messages.
    ///
    /// NOTE: per tokio's documentation, `mpsc::channel` panics when the capacity is 0.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = tokio::sync::mpsc::channel(len);
        (Sender(tx), Receiver(rx))
    }
}
#[cfg(async_channel_impl = "flume")]
mod inner {
    //! Channel backend built on `flume`.

    use flume::{r#async::RecvStream, Receiver as InnerReceiver, Sender as InnerSender};
    pub use flume::{RecvError, SendError, TryRecvError};

    /// Sending half; wraps a `flume::Sender`.
    pub struct Sender<T>(pub(super) InnerSender<T>);

    /// Receiving half; wraps a `flume::Receiver`.
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);

    /// Stream adapter produced by `Receiver::into_stream`.
    ///
    /// It stores a `RecvStream<'static, T>`, which is why `T: 'static` is required.
    pub struct BoundedStream<T: 'static>(pub(super) RecvStream<'static, T>);

    /// Classifies a `try_recv` failure.
    ///
    /// Returns `None` when the channel is merely empty for now, and
    /// `Some(RecvError::Disconnected)` when it has been disconnected.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Disconnected => Some(RecvError::Disconnected),
            TryRecvError::Empty => None,
        }
    }

    /// Creates a channel that buffers at most `len` messages.
    ///
    /// NOTE: per flume's documentation, a capacity of 0 yields a rendezvous
    /// channel, unlike the tokio/async_std backends.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = flume::bounded(len);
        (Sender(tx), Receiver(rx))
    }
}
#[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
mod inner {
    //! Default channel backend built on `async_std::channel`.
    //!
    //! Selected whenever `async_channel_impl` is neither `"flume"` nor `"tokio"`.

    use async_std::channel::{Receiver as InnerReceiver, Sender as InnerSender};
    pub use async_std::channel::{RecvError, SendError, TryRecvError};

    /// Sending half; wraps an `async_std::channel::Sender`.
    pub struct Sender<T>(pub(super) InnerSender<T>);

    /// Receiving half; wraps an `async_std::channel::Receiver`.
    pub struct Receiver<T>(pub(super) InnerReceiver<T>);

    /// Stream adapter produced by `Receiver::into_stream`.
    ///
    /// The async_std receiver is itself used as the stream.
    pub struct BoundedStream<T>(pub(super) InnerReceiver<T>);

    /// Classifies a `try_recv` failure.
    ///
    /// Returns `None` when the channel is merely empty for now, and
    /// `Some(RecvError)` when it has been closed.
    pub(super) fn try_recv_error_to_recv_error(e: TryRecvError) -> Option<RecvError> {
        match e {
            TryRecvError::Closed => Some(RecvError),
            TryRecvError::Empty => None,
        }
    }

    /// Creates a channel that buffers at most `len` messages.
    ///
    /// NOTE: per async-channel's documentation, `bounded` panics when the capacity is 0.
    #[must_use]
    pub fn bounded<T>(len: usize) -> (Sender<T>, Receiver<T>) {
        let (tx, rx) = async_std::channel::bounded(len);
        (Sender(tx), Receiver(rx))
    }
}
pub use inner::*;
impl<T> Sender<T> {
    /// Sends `msg` through the channel, awaiting until the backend accepts it.
    ///
    /// flume exposes its async send as `send_async`. The tokio and async_std
    /// senders both provide an async `send` with the same shape, so they share
    /// the non-flume branch.
    ///
    /// # Errors
    ///
    /// Returns the backend's `SendError<T>` when the message cannot be delivered.
    /// The error type is generic over the message type; presumably the backends
    /// use it to hand the unsent message back (see each backend's docs).
    pub async fn send(&self, msg: T) -> Result<(), SendError<T>> {
        #[cfg(async_channel_impl = "flume")]
        let result = self.0.send_async(msg).await;
        // Exact complement of the flume branch above; the previous
        // `not(all(..))` wrapped a single predicate in a redundant `all()`.
        #[cfg(not(async_channel_impl = "flume"))]
        let result = self.0.send(msg).await;
        result
    }
}
impl<T> Receiver<T> {
    /// Waits for the next message.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] when the backend reports that no further messages
    /// can arrive. On the tokio backend, that is whenever its `recv` yields `None`.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        #[cfg(async_channel_impl = "flume")]
        let result = self.0.recv_async().await;
        #[cfg(async_channel_impl = "tokio")]
        let result = self.0.recv().await.ok_or(RecvError);
        #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
        let result = self.0.recv().await;
        result
    }

    /// Converts this receiver into a [`BoundedStream`] that yields its messages.
    ///
    /// `T: 'static` is required because the flume backend stores a
    /// `RecvStream<'static, T>`.
    pub fn into_stream(self) -> BoundedStream<T>
    where
        T: 'static,
    {
        #[cfg(not(any(async_channel_impl = "flume", async_channel_impl = "tokio")))]
        let result = self.0;
        #[cfg(async_channel_impl = "tokio")]
        let result = tokio_stream::wrappers::ReceiverStream::new(self.0);
        #[cfg(async_channel_impl = "flume")]
        let result = self.0.into_stream();
        BoundedStream(result)
    }

    /// Attempts to receive a message without waiting.
    ///
    /// # Errors
    ///
    /// Returns the backend's `TryRecvError`. Use `try_recv_error_to_recv_error`
    /// to tell "nothing buffered right now" apart from "disconnected".
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        self.0.try_recv()
    }

    /// Waits for one message, then also takes everything already buffered.
    ///
    /// On success the returned `Vec` always contains at least one element.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] if the channel is disconnected, either before the
    /// first message arrives or while the remaining messages are being drained.
    /// In the latter case an error is logged and the messages collected so far
    /// are dropped.
    pub async fn drain_at_least_one(&mut self) -> Result<Vec<T>, RecvError> {
        let first = self.recv().await?;
        let mut ret = vec![first];
        if let Err(e) = self.drain_pending_into(&mut ret) {
            tracing::error!(
                "Tried to empty {:?} queue but it disconnected while we were emptying it ({} items are being dropped)",
                std::any::type_name::<Self>(),
                ret.len()
            );
            return Err(e);
        }
        Ok(ret)
    }

    /// Takes every message currently buffered, without waiting.
    ///
    /// Returns an empty `Vec` if nothing is buffered.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] if the channel reports a disconnect. Any messages
    /// pulled before that point are dropped.
    pub fn drain(&mut self) -> Result<Vec<T>, RecvError> {
        let mut result = Vec::new();
        self.drain_pending_into(&mut result)?;
        Ok(result)
    }

    /// Pushes every message `try_recv` can deliver right now onto `buf`.
    ///
    /// Shared by `drain` and `drain_at_least_one`. Stops with `Ok(())` at the
    /// first "empty" signal, or with the converted [`RecvError`] on a disconnect.
    /// Messages already pushed stay in `buf` in both cases, so each caller decides
    /// what to do with them (e.g. report how many are being dropped).
    fn drain_pending_into(&mut self, buf: &mut Vec<T>) -> Result<(), RecvError> {
        loop {
            match self.try_recv() {
                Ok(item) => buf.push(item),
                Err(e) => {
                    return match try_recv_error_to_recv_error(e) {
                        Some(disconnected) => Err(disconnected),
                        None => Ok(()),
                    };
                }
            }
        }
    }
}
/// Yields the messages produced by the wrapped backend stream.
///
/// Each backend's inner stream type is `Unpin`, as `Pin::new` requires. So the
/// wrapper re-pins its field and delegates to that type's `Stream::poll_next`.
/// This is the same call for every backend, so no per-backend `cfg` is needed.
impl<T> Stream for BoundedStream<T> {
    type Item = T;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.0);
        inner.poll_next(cx)
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Opaque `Debug` output: prints only `Sender`.
///
/// No `T: Debug` bound is needed, and the backend handle is not exposed.
impl<T> std::fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Sender")
    }
}
/// Opaque `Debug` output: prints only `Receiver`.
///
/// No `T: Debug` bound is needed, and the backend handle is not exposed.
impl<T> std::fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Receiver")
    }
}