use std::task::{Context, Poll};
use futures_util::FutureExt;
use tokio::{
sync::mpsc::Receiver as MpscReceiver,
time::{self, Duration},
};
pub use tokio::sync::mpsc::{OwnedPermit, Permit, Sender, WeakSender, error};
use crate::errors::RecvError;
/// Receiving half of a bounded channel: a newtype over tokio's
/// `mpsc::Receiver` (imported here as `MpscReceiver`).
///
/// The methods implemented on this type forward to the wrapped receiver;
/// those that can fail report the failure as a [`RecvError`].
#[derive(Debug)]
// Single-field wrapper: `repr(transparent)` gives it the same layout as the
// wrapped tokio receiver.
#[repr(transparent)]
pub struct Receiver<T>(MpscReceiver<T>);
impl<T> Receiver<T> {
    /// Waits for the next value on the channel.
    ///
    /// The returned future resolves to `Ok(value)` when the wrapped receiver
    /// yields a value and to [`RecvError::Closed`] when it yields `None`.
    pub fn recv(&mut self) -> impl Future<Output = Result<T, RecvError>> {
        self.0.recv().map(|received| match received {
            Some(value) => Ok(value),
            None => Err(RecvError::Closed),
        })
    }

    /// Waits for values and appends up to `limit` of them to `buffer`.
    ///
    /// Forwards directly to the wrapped receiver's `recv_many`; the future
    /// resolves to the count that call reports.
    pub fn recv_many(&mut self, buffer: &mut Vec<T>, limit: usize) -> impl Future<Output = usize> {
        self.0.recv_many(buffer, limit)
    }

    /// Attempts to take a value without waiting.
    ///
    /// # Errors
    ///
    /// Any error from the wrapped receiver's `try_recv` is converted into a
    /// [`RecvError`] through its `Into` conversion.
    pub fn try_recv(&mut self) -> Result<T, RecvError> {
        match self.0.try_recv() {
            Ok(value) => Ok(value),
            Err(cause) => Err(cause.into()),
        }
    }

    /// Blocks the current thread until a value arrives.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Closed`] when the wrapped receiver's
    /// `blocking_recv` yields `None`.
    pub fn blocking_recv(&mut self) -> Result<T, RecvError> {
        match self.0.blocking_recv() {
            Some(value) => Ok(value),
            None => Err(RecvError::Closed),
        }
    }

    /// Blocking counterpart of [`Receiver::recv_many`].
    ///
    /// Forwards directly to the wrapped receiver's `blocking_recv_many` and
    /// returns the count it reports.
    pub fn blocking_recv_many(&mut self, buffer: &mut Vec<T>, limit: usize) -> usize {
        self.0.blocking_recv_many(buffer, limit)
    }

    /// Closes the receiving side by calling `close` on the wrapped receiver.
    pub fn close(&mut self) {
        self.0.close();
    }

    /// Reports whether the wrapped receiver considers the channel closed.
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Reports whether the wrapped receiver currently holds no values.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Number of values currently queued, as reported by the wrapped receiver.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Current capacity, as reported by the wrapped receiver.
    pub fn capacity(&self) -> usize {
        self.0.capacity()
    }

    /// Maximum capacity, as reported by the wrapped receiver.
    pub fn max_capacity(&self) -> usize {
        self.0.max_capacity()
    }

    /// Polls for the next value.
    ///
    /// A ready `Some(value)` becomes `Ok(value)`, a ready `None` becomes
    /// [`RecvError::Closed`], and `Poll::Pending` is passed through unchanged.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.0.poll_recv(cx) {
            Poll::Ready(Some(value)) => Poll::Ready(Ok(value)),
            Poll::Ready(None) => Poll::Ready(Err(RecvError::Closed)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Polling counterpart of [`Receiver::recv_many`].
    ///
    /// Forwards directly to the wrapped receiver's `poll_recv_many`.
    pub fn poll_recv_many(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &mut Vec<T>,
        limit: usize,
    ) -> Poll<usize> {
        self.0.poll_recv_many(cx, buffer, limit)
    }

    /// Number of strong sender handles, as reported by the wrapped receiver.
    pub fn sender_strong_count(&self) -> usize {
        self.0.sender_strong_count()
    }

    /// Number of weak sender handles, as reported by the wrapped receiver.
    pub fn sender_weak_count(&self) -> usize {
        self.0.sender_weak_count()
    }

    /// Waits for the next value, giving up once `timeout` has elapsed.
    ///
    /// # Errors
    ///
    /// * [`RecvError::Timeout`] if the deadline passes before the receive
    ///   completes.
    /// * [`RecvError::Closed`] if the receive completes with `None`.
    pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<T, RecvError> {
        time::timeout(timeout, self.0.recv())
            .await
            // The outer layer only tells us whether the deadline fired.
            .map_err(|_elapsed| RecvError::Timeout)?
            // The inner layer is the receive result itself.
            .ok_or(RecvError::Closed)
    }
}
/// Creates a bounded channel holding up to `capacity` values.
///
/// The sending half is tokio's own `Sender`; the receiving half is wrapped in
/// this module's [`Receiver`].
///
/// # Panics
///
/// Inherits the behaviour of `tokio::sync::mpsc::channel`, which is
/// documented to panic when `capacity` is 0.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let (sender, inner) = tokio::sync::mpsc::channel::<T>(capacity);
    (sender, Receiver(inner))
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Covers `recv`, `try_recv` and `blocking_recv`, including how each one
    /// reports an empty or closed channel.
    #[tokio::test]
    async fn test_recv() {
        let (sender, mut receiver) = channel::<u32>(4);

        // A queued value comes out of the async path; afterwards the
        // non-blocking path reports an empty channel.
        sender.send(1).await.unwrap();
        let first = receiver.recv().await.unwrap();
        assert_eq!(first, 1);
        assert!(matches!(receiver.try_recv(), Err(RecvError::Empty)));

        // The non-blocking path returns a value that is already queued.
        sender.send(2).await.unwrap();
        let second = receiver.try_recv().unwrap();
        assert_eq!(second, 2);

        // With the only sender gone, both paths report a closed channel.
        drop(sender);
        assert!(matches!(receiver.recv().await, Err(RecvError::Closed)));
        assert!(matches!(receiver.try_recv(), Err(RecvError::Closed)));

        // The blocking API has to run off the async worker threads.
        let (sender, mut receiver) = channel::<u32>(4);
        let worker = tokio::task::spawn_blocking(move || {
            let value = receiver.blocking_recv().unwrap();
            let after_close = receiver.blocking_recv();
            (value, after_close)
        });
        sender.send(42).await.unwrap();
        drop(sender);

        let (value, after_close) = worker.await.unwrap();
        assert_eq!(value, 42);
        assert!(matches!(after_close, Err(RecvError::Closed)));
    }

    /// Covers the three outcomes of `recv_timeout`: timeout, value, closed.
    #[tokio::test]
    async fn recv_timeout() {
        let (sender, mut receiver) = channel::<u32>(1);

        // Nothing has been sent, so the short deadline fires first.
        let timed_out = receiver.recv_timeout(Duration::from_millis(10)).await;
        assert!(matches!(timed_out, Err(RecvError::Timeout)));

        // A queued value is returned well within the deadline.
        sender.send(7).await.unwrap();
        let delivered = receiver.recv_timeout(Duration::from_secs(1)).await.unwrap();
        assert_eq!(delivered, 7);

        // Once the sender is dropped the channel reports closed, not timeout.
        drop(sender);
        let closed = receiver.recv_timeout(Duration::from_secs(1)).await;
        assert!(matches!(closed, Err(RecvError::Closed)));
    }

    /// Covers `recv_many` honouring its limit and `blocking_recv_many`
    /// draining whatever is left.
    #[tokio::test]
    async fn recv_many_behavior() {
        let (sender, mut receiver) = channel::<u32>(8);
        for value in 0..5 {
            sender.send(value).await.unwrap();
        }

        // Only `limit` values are taken even though five are queued.
        let mut batch = Vec::new();
        let taken = receiver.recv_many(&mut batch, 3).await;
        assert_eq!(taken, 3);
        assert_eq!(batch, vec![0, 1, 2]);

        // The blocking variant picks up the remaining two values.
        let worker = tokio::task::spawn_blocking(move || {
            let mut rest = Vec::new();
            let count = receiver.blocking_recv_many(&mut rest, 10);
            (count, rest)
        });
        drop(sender);

        let (count, rest) = worker.await.unwrap();
        assert_eq!(count, 2);
        assert_eq!(rest, vec![3, 4]);
    }

    /// Covers the state accessors, the sender counters, and draining after
    /// `close`.
    #[tokio::test]
    async fn channel_state_and_close() {
        let (sender, mut receiver) = channel::<u32>(4);

        // A fresh channel is open, empty, and at full capacity.
        assert!(receiver.is_empty());
        assert!(!receiver.is_closed());
        assert_eq!(receiver.len(), 0);
        assert_eq!(receiver.capacity(), 4);
        assert_eq!(receiver.max_capacity(), 4);

        // Two queued values use up two slots; the maximum does not move.
        sender.send(1).await.unwrap();
        sender.send(2).await.unwrap();
        assert_eq!(receiver.len(), 2);
        assert!(!receiver.is_empty());
        assert_eq!(receiver.capacity(), 2);
        assert_eq!(receiver.max_capacity(), 4);

        // Strong and weak sender handles are counted separately.
        let extra_sender = sender.clone();
        let weak_sender = sender.downgrade();
        assert_eq!(receiver.sender_strong_count(), 2);
        assert_eq!(receiver.sender_weak_count(), 1);
        drop(extra_sender);
        drop(weak_sender);
        assert_eq!(receiver.sender_strong_count(), 1);
        assert_eq!(receiver.sender_weak_count(), 0);

        // Closing rejects new sends but still lets queued values drain.
        receiver.close();
        assert!(receiver.is_closed());
        assert!(sender.send(3).await.is_err());
        let first = receiver.recv().await.unwrap();
        assert_eq!(first, 1);
        let second = receiver.recv().await.unwrap();
        assert_eq!(second, 2);
        assert!(matches!(receiver.recv().await, Err(RecvError::Closed)));
    }
}