#![allow(dead_code)]
use std::time::Duration;
#[cfg(feature = "crossbeam")]
use crossbeam_channel::{
bounded as bounded_ext, select, unbounded as unbounded_ext, Receiver as ReceiverExt,
RecvError as ExtRecvError, RecvTimeoutError as ExtRecvTimeoutError, SendError as SendErrorExt,
Sender as SenderExt, TryRecvError as ExtTryRecvError,
};
#[cfg(feature = "flume")]
use flume::{
bounded as bounded_ext, unbounded as unbounded_ext, Receiver as ReceiverExt,
RecvError as ExtRecvError, RecvTimeoutError as ExtRecvTimeoutError, SendError as SendErrorExt,
Sender as SenderExt, TryRecvError as ExtTryRecvError,
};
pub trait ChannelItem: Send + 'static {}
impl<T: Send + 'static> ChannelItem for T {}
pub type SendError<T> = SendErrorExt<T>;
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RecvError {
Disconnected,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TryRecvError {
Empty,
Disconnected,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RecvTimeoutError {
Timeout,
Disconnected,
}
impl From<ExtRecvError> for RecvError {
fn from(_: ExtRecvError) -> Self {
Self::Disconnected
}
}
impl From<ExtTryRecvError> for TryRecvError {
fn from(e: ExtTryRecvError) -> Self {
match e {
ExtTryRecvError::Empty => Self::Empty,
ExtTryRecvError::Disconnected => Self::Disconnected,
}
}
}
impl From<ExtRecvTimeoutError> for RecvTimeoutError {
fn from(e: ExtRecvTimeoutError) -> Self {
match e {
ExtRecvTimeoutError::Timeout => Self::Timeout,
ExtRecvTimeoutError::Disconnected => Self::Disconnected,
}
}
}
pub fn bounded<T: Send + 'static>(size: usize) -> (Sender<T>, Receiver<T>) {
let (tx, rx) = bounded_ext(size);
(Sender(tx), Receiver(rx))
}
pub fn unbounded<T: Send + 'static>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
let (tx, rx) = unbounded_ext();
(UnboundedSender(tx), UnboundedReceiver(rx))
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SelectResult<In1, In2> {
A(Result<In1, RecvError>),
B(Result<In2, RecvError>),
}
#[cfg(feature = "crossbeam")]
#[macro_use]
mod select_impl {
macro_rules! select_impl {
($self:expr, $other:expr) => {
select! {
recv($self.0) -> el => SelectResult::A(el.map_err(RecvError::from)),
recv($other.0) -> el => SelectResult::B(el.map_err(RecvError::from)),
}
};
}
macro_rules! select_timeout_impl {
($self:expr, $other:expr, $timeout:expr) => {
select! {
recv($self.0) -> el => Ok(SelectResult::A(el.map_err(RecvError::from))),
recv($other.0) -> el => Ok(SelectResult::B(el.map_err(RecvError::from))),
default($timeout) => Err(RecvTimeoutError::Timeout),
}
};
}
}
#[cfg(feature = "flume")]
#[macro_use]
mod select_impl {
macro_rules! select_impl {
($self:expr, $other:expr) => {{
flume::Selector::new()
.recv(&$self.0, |el| SelectResult::A(el.map_err(RecvError::from)))
.recv(&$other.0, |el| SelectResult::B(el.map_err(RecvError::from)))
.wait()
}};
}
macro_rules! select_timeout_impl {
($self:expr, $other:expr, $timeout:expr) => {
flume::Selector::new()
.recv(&$self.0, |el| SelectResult::A(el.map_err(RecvError::from)))
.recv(&$other.0, |el| SelectResult::B(el.map_err(RecvError::from)))
.wait_timeout($timeout)
.map_err(|_| RecvTimeoutError::Timeout)
};
}
}
#[derive(Debug, Clone)]
pub struct Sender<T: ChannelItem>(SenderExt<T>);
#[derive(Debug)]
pub struct Receiver<T: ChannelItem>(ReceiverExt<T>);
impl<T: ChannelItem> Sender<T> {
#[inline]
pub fn send(&self, item: T) -> Result<(), SendError<T>> {
self.0.send(item)
}
}
impl<T: ChannelItem> Receiver<T> {
#[inline]
pub fn recv(&self) -> Result<T, RecvError> {
self.0.recv().map_err(RecvError::from)
}
#[inline]
pub fn try_recv(&self) -> Result<T, TryRecvError> {
self.0.try_recv().map_err(TryRecvError::from)
}
#[inline]
#[cfg(feature = "flume")]
pub async fn recv_async(&self) -> Result<T, RecvError> {
self.0.recv_async().await.map_err(RecvError::from)
}
#[inline]
pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvTimeoutError> {
self.0.recv_timeout(timeout).map_err(RecvTimeoutError::from)
}
#[inline]
pub fn select<T2: ChannelItem>(&self, other: &Receiver<T2>) -> SelectResult<T, T2> {
select_impl!(self, other)
}
#[inline]
pub fn select_timeout<T2: ChannelItem>(
&self,
other: &Receiver<T2>,
timeout: Duration,
) -> Result<SelectResult<T, T2>, RecvTimeoutError> {
select_timeout_impl!(self, other, timeout)
}
}
#[derive(Debug, Clone)]
pub struct UnboundedSender<T: ChannelItem>(SenderExt<T>);
#[derive(Debug)]
pub struct UnboundedReceiver<T: ChannelItem>(pub(crate) ReceiverExt<T>);
impl<T: ChannelItem> UnboundedSender<T> {
#[inline]
pub fn send(&self, item: T) -> Result<(), SendError<T>> {
self.0.send(item)
}
}
impl<T: ChannelItem> UnboundedReceiver<T> {
#[inline]
pub fn recv(&self) -> Result<T, RecvError> {
self.0.recv().map_err(RecvError::from)
}
#[inline]
pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvTimeoutError> {
self.0.recv_timeout(timeout).map_err(RecvTimeoutError::from)
}
#[inline]
pub fn select<T2: ChannelItem>(&self, other: &UnboundedReceiver<T2>) -> SelectResult<T, T2> {
select_impl!(self, other)
}
#[inline]
pub fn select_timeout<T2: ChannelItem>(
&self,
other: &UnboundedReceiver<T2>,
timeout: Duration,
) -> Result<SelectResult<T, T2>, RecvTimeoutError> {
select_timeout_impl!(self, other, timeout)
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use itertools::Itertools;
use crate::channel::{bounded, SelectResult};
const TEST_CAPACITY: usize = 10;
#[test]
fn test_recv_local() {
let (sender, receiver) = bounded(TEST_CAPACITY);
sender.send(123).unwrap();
sender.send(456).unwrap();
drop(sender);
assert_eq!(receiver.recv().unwrap(), 123);
assert_eq!(receiver.recv().unwrap(), 456);
assert!(receiver.recv().is_err());
}
#[test]
fn test_recv_timeout_local() {
let (sender, receiver) = bounded(TEST_CAPACITY);
sender.send(123).unwrap();
assert_eq!(
receiver.recv_timeout(Duration::from_millis(1)).unwrap(),
123
);
assert!(receiver.recv_timeout(Duration::from_millis(50)).is_err());
sender.send(456).unwrap();
assert_eq!(
receiver.recv_timeout(Duration::from_millis(1)).unwrap(),
456
);
}
#[test]
fn test_select_local() {
let (sender1, receiver1) = bounded(TEST_CAPACITY);
let (sender2, receiver2) = bounded(TEST_CAPACITY);
sender1.send(123).unwrap();
let elem1 = receiver1.select(&receiver2);
assert_eq!(elem1, SelectResult::A(Ok(123)));
sender2.send("test".to_string()).unwrap();
let elem2 = receiver1.select(&receiver2);
assert_eq!(elem2, SelectResult::B(Ok("test".to_string())));
}
#[test]
fn test_select_timeout_local() {
let (sender1, receiver1) = bounded(TEST_CAPACITY);
let (sender2, receiver2) = bounded(TEST_CAPACITY);
sender1.send(123).unwrap();
let elem1 = receiver1
.select_timeout(&receiver2, Duration::from_millis(1))
.unwrap();
assert_eq!(elem1, SelectResult::A(Ok(123)));
let timeout = receiver1.select_timeout(&receiver2, Duration::from_millis(50));
assert!(timeout.is_err());
sender2.send("test".to_string()).unwrap();
let elem2 = receiver1
.select_timeout(&receiver2, Duration::from_millis(1))
.unwrap();
assert_eq!(elem2, SelectResult::B(Ok("test".to_string())));
}
#[test]
fn test_select_fairness() {
let tries = 100;
let mut failures = 0;
for _ in 0..100 {
let (sender1, receiver1) = bounded(TEST_CAPACITY);
let (sender2, receiver2) = bounded(TEST_CAPACITY);
for _ in 0..TEST_CAPACITY {
sender1.send(1).unwrap();
sender2.send(2).unwrap();
}
let mut order = Vec::new();
for _ in 0..2 * TEST_CAPACITY {
let elem = receiver1.select(&receiver2);
match elem {
SelectResult::A(Ok(_)) => order.push(1),
SelectResult::B(Ok(_)) => order.push(2),
_ => {}
}
}
let in_order1 = (0..TEST_CAPACITY)
.map(|_| 1)
.chain((0..TEST_CAPACITY).map(|_| 2))
.collect_vec();
let in_order2 = (0..TEST_CAPACITY)
.map(|_| 2)
.chain((0..TEST_CAPACITY).map(|_| 1))
.collect_vec();
if order == in_order1 || order == in_order2 {
failures += 1;
}
}
assert!(100 * failures < 5 * tries);
}
}