use std::task::{Context, Poll};
use futures_util::FutureExt;
use tokio::{
sync::mpsc::Receiver as MpscReceiver,
time::{self, Duration},
};
pub use tokio::sync::mpsc::{OwnedPermit, Permit, Sender, WeakSender, error};
use crate::errors::RecvError;
/// Receiving half of a bounded channel created by [`channel`].
///
/// A transparent newtype over [`tokio::sync::mpsc::Receiver`] whose receive methods
/// report a closed (or, for `recv_timeout`, timed-out) channel as a [`RecvError`]
/// instead of `None`.
#[repr(transparent)]
pub struct Receiver<T>(MpscReceiver<T>);

// `Debug` is implemented by hand instead of derived: `#[derive(Debug)]` would add a
// `T: Debug` bound, but the wrapped tokio receiver is `Debug` for every `T` (as is the
// re-exported `Sender`). The output matches what the derive produced: `Receiver(<inner>)`.
impl<T> std::fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("Receiver").field(&self.0).finish()
    }
}
// Every method delegates to the wrapped tokio receiver. The only translation is that a
// `None` from tokio (channel closed) becomes `RecvError::Closed`, and an elapsed
// deadline in `recv_timeout` becomes `RecvError::Timeout`.
impl<T> Receiver<T> {
    // ---- asynchronous receiving ----

    /// Waits for the next message.
    ///
    /// # Errors
    ///
    /// Resolves to [`RecvError::Closed`] when tokio's `recv` yields `None`.
    pub fn recv(&mut self) -> impl Future<Output = Result<T, RecvError>> {
        self.0
            .recv()
            .map(|maybe_msg| maybe_msg.ok_or(RecvError::Closed))
    }

    /// Like [`recv`](Self::recv), but gives up once `timeout` has elapsed.
    ///
    /// # Errors
    ///
    /// - [`RecvError::Timeout`] if the deadline passes before a message arrives.
    /// - [`RecvError::Closed`] if tokio's `recv` yields `None` before the deadline.
    pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<T, RecvError> {
        let outcome = time::timeout(timeout, self.0.recv())
            .await
            .map_err(|_elapsed| RecvError::Timeout)?;
        outcome.ok_or(RecvError::Closed)
    }

    /// Appends up to `limit` messages to `buffer` and resolves to the count tokio reports.
    ///
    /// Delegates directly to tokio's `recv_many`.
    pub fn recv_many(&mut self, buffer: &mut Vec<T>, limit: usize) -> impl Future<Output = usize> {
        self.0.recv_many(buffer, limit)
    }

    // ---- non-blocking and blocking receiving ----

    /// Takes a message only if one is immediately available.
    ///
    /// # Errors
    ///
    /// tokio's `TryRecvError` is converted into [`RecvError`] through its `Into` impl.
    pub fn try_recv(&mut self) -> Result<T, RecvError> {
        self.0.try_recv().map_err(Into::into)
    }

    /// Blocking counterpart of [`recv`](Self::recv), for synchronous callers.
    ///
    /// tokio documents that its `blocking_recv` panics when called from within an
    /// asynchronous execution context.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Closed`] when tokio's `blocking_recv` yields `None`.
    pub fn blocking_recv(&mut self) -> Result<T, RecvError> {
        let Some(msg) = self.0.blocking_recv() else {
            return Err(RecvError::Closed);
        };
        Ok(msg)
    }

    /// Blocking counterpart of [`recv_many`](Self::recv_many).
    pub fn blocking_recv_many(&mut self, buffer: &mut Vec<T>, limit: usize) -> usize {
        self.0.blocking_recv_many(buffer, limit)
    }

    // ---- poll-based receiving ----

    /// Poll-based receive, for use inside hand-written futures or streams.
    ///
    /// Returns `Poll::Pending` whenever tokio does. Returns
    /// `Poll::Ready(Err(RecvError::Closed))` when tokio reports `Ready(None)`.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.0.poll_recv(cx) {
            Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
            Poll::Ready(None) => Poll::Ready(Err(RecvError::Closed)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Poll-based counterpart of [`recv_many`](Self::recv_many).
    pub fn poll_recv_many(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &mut Vec<T>,
        limit: usize,
    ) -> Poll<usize> {
        self.0.poll_recv_many(cx, buffer, limit)
    }

    // ---- lifecycle and introspection (all delegated to tokio) ----

    /// Closes the receiving half via tokio's `close`.
    pub fn close(&mut self) {
        self.0.close();
    }

    /// Whether tokio reports the channel as closed.
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Whether tokio reports no buffered messages.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Number of buffered messages, as reported by tokio.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Currently available capacity, as reported by tokio's `capacity`.
    pub fn capacity(&self) -> usize {
        self.0.capacity()
    }

    /// Maximum capacity, as reported by tokio's `max_capacity`.
    pub fn max_capacity(&self) -> usize {
        self.0.max_capacity()
    }

    /// Count of strong (`Sender`) handles, as reported by tokio.
    pub fn sender_strong_count(&self) -> usize {
        self.0.sender_strong_count()
    }

    /// Count of weak (`WeakSender`) handles, as reported by tokio.
    pub fn sender_weak_count(&self) -> usize {
        self.0.sender_weak_count()
    }
}
/// Creates a bounded tokio mpsc channel with the given `capacity`.
///
/// The tokio [`Sender`] is returned unchanged, paired with this module's [`Receiver`]
/// wrapper around tokio's receiving half.
///
/// # Panics
///
/// tokio's `mpsc::channel` documents a panic when `capacity` is 0.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let (sender, inner_receiver) = tokio::sync::mpsc::channel::<T>(capacity);
    let receiver = Receiver(inner_receiver);
    (sender, receiver)
}