use std::task::{Context, Poll};
use tokio::sync::mpsc;
/// Receiving half of a channel created by [`unbounded`].
///
/// This is a thin wrapper: every method forwards to the tokio
/// [`mpsc::UnboundedReceiver`] stored in `inner`.
#[derive(Debug)]
pub struct Receiver<T> {
    /// The wrapped tokio receiver that all receive operations delegate to.
    inner: mpsc::UnboundedReceiver<T>,
}
/// Sending half of a channel created by [`unbounded`].
///
/// This is a thin wrapper: every method forwards to the tokio
/// [`mpsc::UnboundedSender`] stored in `inner`.
#[derive(Debug)]
pub struct Sender<T> {
    /// The wrapped tokio sender that all send operations delegate to.
    inner: mpsc::UnboundedSender<T>,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum RecvError {
#[error("Disconnected")]
Disconnected,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum TryRecvError {
#[error("Empty")]
Empty,
#[error("Disconnected")]
Disconnected,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum RecvTimeoutError {
#[error("Timeout")]
Timeout,
#[error("Disconnected")]
Disconnected,
}
impl From<mpsc::error::TryRecvError> for TryRecvError {
fn from(err: mpsc::error::TryRecvError) -> Self {
match err {
mpsc::error::TryRecvError::Empty => Self::Empty,
mpsc::error::TryRecvError::Disconnected => Self::Disconnected,
}
}
}
impl From<RecvError> for TryRecvError {
fn from(_: RecvError) -> Self {
Self::Disconnected
}
}
impl From<RecvTimeoutError> for TryRecvError {
fn from(_: RecvTimeoutError) -> Self {
Self::Disconnected
}
}
impl From<RecvTimeoutError> for RecvError {
fn from(_: RecvTimeoutError) -> Self {
Self::Disconnected
}
}
impl From<RecvError> for RecvTimeoutError {
fn from(_: RecvError) -> Self {
Self::Disconnected
}
}
impl<T> Receiver<T> {
pub fn recv(&mut self) -> Result<T, RecvError> {
self.inner.blocking_recv().ok_or(RecvError::Disconnected)
}
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
Ok(self.inner.try_recv()?)
}
pub fn recv_timeout(&mut self, timeout: std::time::Duration) -> Result<T, RecvTimeoutError> {
crate::runtime::Handle::current().block_on(self.recv_timeout_async(timeout))
}
pub async fn recv_timeout_async(
&mut self,
timeout: std::time::Duration,
) -> Result<T, RecvTimeoutError> {
crate::select! {
result = self.recv_async() => {
Ok(result?)
}
() = crate::time::sleep(timeout) => {
Err(RecvTimeoutError::Timeout)
}
}
}
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.inner.poll_recv(cx)
}
pub async fn recv_async(&mut self) -> Result<T, RecvError> {
self.inner.recv().await.ok_or(RecvError::Disconnected)
}
}
/// Error returned by [`Sender::send`] and [`Sender::send_async`].
///
/// `Debug` is implemented by hand elsewhere in this file rather than derived.
#[derive(thiserror::Error)]
pub enum SendError<T> {
    /// The send failed; the value that could not be delivered is handed back
    /// to the caller as the payload.
    #[error("Disconnected")]
    Disconnected(T),
}
/// Manual `Debug` that prints only the variant name.
///
/// The payload is deliberately not formatted, so this impl needs no
/// `T: Debug` bound.
impl<T> std::fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Irrefutable while there is exactly one variant; adding another
        // variant turns this into a compile error so the impl gets updated.
        let Self::Disconnected(_) = self;
        f.debug_tuple("SendError::Disconnected")
            .finish_non_exhaustive()
    }
}
impl<T> From<mpsc::error::SendError<T>> for SendError<T> {
fn from(e: mpsc::error::SendError<T>) -> Self {
Self::Disconnected(e.0)
}
}
/// Error returned by [`Sender::try_send`].
///
/// Both variants hand the unsent value back to the caller as their payload.
/// `Debug` is implemented by hand elsewhere in this file rather than derived.
#[derive(thiserror::Error)]
pub enum TrySendError<T> {
    /// Counterpart of tokio's `TrySendError::Full`.
    #[error("Full")]
    Full(T),
    /// Counterpart of tokio's `TrySendError::Closed` / `SendError`.
    #[error("Disconnected")]
    Disconnected(T),
}
/// Manual `Debug` that prints only the variant name.
///
/// The payload is deliberately not formatted, so this impl needs no
/// `T: Debug` bound.
impl<T> std::fmt::Debug for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then share the single formatting call.
        let label = match self {
            Self::Full(_) => "TrySendError::Full",
            Self::Disconnected(_) => "TrySendError::Disconnected",
        };
        f.debug_tuple(label).finish_non_exhaustive()
    }
}
impl<T> From<mpsc::error::TrySendError<T>> for TrySendError<T> {
fn from(err: mpsc::error::TrySendError<T>) -> Self {
match err {
mpsc::error::TrySendError::Full(t) => Self::Full(t),
mpsc::error::TrySendError::Closed(t) => Self::Disconnected(t),
}
}
}
/// Widens a [`SendError`] into a [`TrySendError`], keeping the unsent value.
impl<T> From<SendError<T>> for TrySendError<T> {
    fn from(e: SendError<T>) -> Self {
        // Irrefutable while `SendError` has one variant; a new variant makes
        // this a compile error so the mapping gets revisited.
        let SendError::Disconnected(unsent) = e;
        Self::Disconnected(unsent)
    }
}
/// Converts tokio's `SendError` directly into
/// [`TrySendError::Disconnected`], keeping the unsent value.
///
/// `Sender::try_send` relies on this conversion because it sends through the
/// inner sender's plain `send`.
impl<T> From<mpsc::error::SendError<T>> for TrySendError<T> {
    fn from(e: mpsc::error::SendError<T>) -> Self {
        Self::Disconnected(e.0)
    }
}
impl<T> Sender<T> {
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
Ok(self.inner.send(value)?)
}
#[allow(clippy::unused_async)]
pub async fn send_async(&self, value: T) -> Result<(), SendError<T>> {
Ok(self.inner.send(value)?)
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
Ok(self.inner.send(value)?)
}
}
/// Hand-written `Clone` that clones the inner sender.
///
/// Written by hand rather than derived so that no `T: Clone` bound is
/// required on the message type.
impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        Self { inner }
    }
}
/// Creates a channel backed by tokio's `mpsc::unbounded_channel` and returns
/// its two halves wrapped in this module's [`Sender`] and [`Receiver`].
#[must_use]
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    let (inner_tx, inner_rx) = mpsc::unbounded_channel();
    let sender = Sender { inner: inner_tx };
    let receiver = Receiver { inner: inner_rx };
    (sender, receiver)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A sent value comes back out of `try_recv`.
    #[test_log::test]
    fn test_unbounded_channel_send_and_try_recv() {
        let (sender, mut receiver) = unbounded::<i32>();
        sender.send(42).unwrap();
        assert_eq!(receiver.try_recv().unwrap(), 42);
    }

    /// With a live sender and nothing queued, `try_recv` reports `Empty`.
    #[test_log::test]
    fn test_unbounded_channel_try_recv_empty() {
        // The sender is bound (not dropped) so the channel stays connected.
        let (_sender, mut receiver) = unbounded::<i32>();
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }

    /// After the only sender is dropped, `try_recv` reports `Disconnected`.
    #[test_log::test]
    fn test_unbounded_channel_try_recv_disconnected() {
        let (sender, mut receiver) = unbounded::<i32>();
        drop(sender);
        assert!(matches!(
            receiver.try_recv(),
            Err(TryRecvError::Disconnected)
        ));
    }

    /// `send` hands the value back once the receiver is gone.
    #[test_log::test]
    fn test_sender_send_after_receiver_dropped() {
        let (sender, receiver) = unbounded::<i32>();
        drop(receiver);
        assert!(matches!(
            sender.send(42),
            Err(SendError::Disconnected(42))
        ));
    }

    /// `try_send` delivers a value just like `send`.
    #[test_log::test]
    fn test_sender_try_send() {
        let (sender, mut receiver) = unbounded::<i32>();
        sender.try_send(100).unwrap();
        assert_eq!(receiver.try_recv().unwrap(), 100);
    }

    /// A cloned sender feeds the same receiver, in send order.
    #[test_log::test]
    fn test_sender_clone() {
        let (first, mut receiver) = unbounded::<i32>();
        let second = first.clone();
        first.send(1).unwrap();
        second.send(2).unwrap();
        for expected in [1, 2] {
            assert_eq!(receiver.try_recv().unwrap(), expected);
        }
    }

    /// Several queued messages are received in FIFO order, then `Empty`.
    #[test_log::test]
    fn test_multiple_messages() {
        let (sender, mut receiver) = unbounded::<String>();
        let payloads = ["first", "second", "third"];
        for text in payloads {
            sender.send(text.to_string()).unwrap();
        }
        for expected in payloads {
            assert_eq!(receiver.try_recv().unwrap(), expected);
        }
        assert!(matches!(receiver.try_recv(), Err(TryRecvError::Empty)));
    }

    /// `recv_async` resolves to an already-queued value.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_recv_async_success() {
        let (sender, mut receiver) = unbounded::<i32>();
        sender.send(42).unwrap();
        let received = receiver.recv_async().await;
        assert_eq!(received.unwrap(), 42);
    }

    /// `recv_async` reports `Disconnected` once the sender is dropped.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_recv_async_disconnected() {
        let (sender, mut receiver) = unbounded::<i32>();
        drop(sender);
        let received = receiver.recv_async().await;
        assert!(matches!(received, Err(RecvError::Disconnected)));
    }

    /// `send_async` delivers a value that `try_recv` can pick up.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_send_async() {
        let (sender, mut receiver) = unbounded::<i32>();
        sender.send_async(99).await.unwrap();
        assert_eq!(receiver.try_recv().unwrap(), 99);
    }

    /// `recv_timeout_async` returns a queued value before the timeout.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_recv_timeout_async_success() {
        let (sender, mut receiver) = unbounded::<i32>();
        sender.send(123).unwrap();
        let limit = std::time::Duration::from_millis(100);
        let outcome = receiver.recv_timeout_async(limit).await;
        assert_eq!(outcome.unwrap(), 123);
    }

    /// `recv_timeout_async` reports `Timeout` when nothing arrives.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_recv_timeout_async_timeout() {
        // The sender is bound (not dropped) so the channel stays connected.
        let (_sender, mut receiver) = unbounded::<i32>();
        let limit = std::time::Duration::from_millis(10);
        let outcome = receiver.recv_timeout_async(limit).await;
        assert!(matches!(outcome, Err(RecvTimeoutError::Timeout)));
    }

    /// `recv_timeout_async` reports `Disconnected` once the sender is dropped.
    #[test_log::test(crate::internal_test(real_time))]
    async fn test_recv_timeout_async_disconnected() {
        let (sender, mut receiver) = unbounded::<i32>();
        drop(sender);
        let limit = std::time::Duration::from_millis(100);
        let outcome = receiver.recv_timeout_async(limit).await;
        assert!(matches!(outcome, Err(RecvTimeoutError::Disconnected)));
    }

    /// `try_send` hands the value back once the receiver is gone.
    #[test_log::test]
    fn test_try_send_disconnected() {
        let (sender, receiver) = unbounded::<i32>();
        drop(receiver);
        assert!(matches!(
            sender.try_send(42),
            Err(TrySendError::Disconnected(42))
        ));
    }
}