use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
use std::time::Duration;
/// Creates a bounded multi-producer, single-consumer channel that can buffer
/// at most `capacity` messages.
///
/// Convenience shorthand for [`MpscChannel::bounded`]; returns the producer
/// half first and the consumer half second.
#[must_use]
pub fn channel<T: Send>(capacity: usize) -> (MpscSender<T>, MpscReceiver<T>) {
    let (producer, consumer) = MpscChannel::bounded::<T>(capacity);
    (producer, consumer)
}
/// Namespace type for constructing bounded MPSC channels.
pub struct MpscChannel;

impl MpscChannel {
    /// Builds a bounded channel backed by `crossbeam_channel::bounded` and
    /// wraps its two halves in [`MpscSender`] / [`MpscReceiver`].
    ///
    /// `capacity` is forwarded unchanged to crossbeam.
    #[must_use]
    pub fn bounded<T: Send>(capacity: usize) -> (MpscSender<T>, MpscReceiver<T>) {
        let (tx, rx) = bounded::<T>(capacity);
        let producer = MpscSender { inner: tx };
        let consumer = MpscReceiver { inner: rx };
        (producer, consumer)
    }
}
/// Producer half of a bounded MPSC channel created by [`channel`] or
/// [`MpscChannel::bounded`].
///
/// Cloning yields another producer feeding the same channel.
pub struct MpscSender<T> {
    inner: Sender<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, so senders of non-`Clone` payloads could not be
// duplicated even though the wrapped crossbeam `Sender<T>` is `Clone` for
// every `T`.
impl<T> Clone for MpscSender<T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl<T> MpscSender<T> {
    /// Attempts to send `item` without blocking.
    ///
    /// # Errors
    /// Returns the crossbeam [`TrySendError`] (which carries `item` back) when
    /// the message cannot be sent immediately, e.g. because the buffer is full
    /// or every receiver has been dropped.
    #[inline]
    pub fn try_send(&self, item: T) -> Result<(), TrySendError<T>> {
        self.inner.try_send(item)
    }

    /// Sends `item`, blocking until it can be delivered into the channel.
    ///
    /// # Errors
    /// Returns `item` back if the send fails (every receiver was dropped).
    pub fn send(&self, item: T) -> Result<(), T> {
        self.inner.send(item).map_err(|err| err.0)
    }

    /// Sends `item`, blocking for at most `timeout`.
    ///
    /// # Errors
    /// Returns `item` back on timeout or when every receiver was dropped; the
    /// two cases are deliberately collapsed into the same error value.
    pub fn send_timeout(&self, item: T, timeout: Duration) -> Result<(), T> {
        self.inner
            .send_timeout(item, timeout)
            .map_err(crossbeam_channel::SendTimeoutError::into_inner)
    }

    /// Best-effort check of whether a receiver may still be attached.
    ///
    /// crossbeam does not expose receiver liveness on a `Sender`, so this is
    /// inferred from send readiness. A send on a *full* buffer of non-zero
    /// capacity is only "ready" (would not block) once every receiver has been
    /// dropped. When the buffer has free space, or the channel has zero
    /// capacity, disconnection cannot be observed without sending an item, so
    /// `true` is returned. `true` therefore means "no disconnection observed",
    /// not a guarantee that a receiver exists.
    ///
    /// The probes below are separate, non-atomic reads. Under concurrent
    /// receive-then-send traffic on a full channel, a live channel can
    /// momentarily look disconnected.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        // Readiness is ambiguous here: free space (or, for capacity 0, a
        // waiting receiver) also makes a send ready.
        if self.inner.capacity() == Some(0) || !self.inner.is_full() {
            return true;
        }
        let mut probe = crossbeam_channel::Select::new();
        probe.send(&self.inner);
        let send_ready = probe.try_ready().is_ok();
        // Re-check fullness: if a receiver freed a slot between the probes,
        // readiness came from that slot rather than from disconnection.
        !send_ready || !self.inner.is_full()
    }

    /// Number of messages currently buffered in the channel.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if no messages are currently buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns `true` if the buffer currently holds `capacity` messages.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.inner.is_full()
    }

    /// Buffer capacity reported by crossbeam (`Some(n)` for bounded channels).
    #[must_use]
    pub fn capacity(&self) -> Option<usize> {
        self.inner.capacity()
    }
}
pub struct MpscReceiver<T> {
inner: Receiver<T>,
}
impl<T> MpscReceiver<T> {
#[inline]
pub fn try_recv(&self) -> Option<T> {
self.inner.try_recv().ok()
}
pub fn recv(&self) -> Option<T> {
self.inner.recv().ok()
}
pub fn recv_timeout(&self, timeout: Duration) -> Option<T> {
self.inner.recv_timeout(timeout).ok()
}
#[must_use]
pub fn as_select(&self) -> &Receiver<T> {
&self.inner
}
pub fn drain(&self) -> impl Iterator<Item = T> + '_ {
std::iter::from_fn(|| self.inner.try_recv().ok())
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[must_use]
pub fn is_disconnected(&self) -> bool {
self.inner.is_empty() && self.inner.try_recv().is_err()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A single message round-trips, after which the channel reads as empty.
    #[test]
    fn test_basic_send_recv() {
        let (producer, consumer) = channel::<u64>(16);
        assert!(producer.try_send(42).is_ok());
        assert_eq!(consumer.try_recv(), Some(42));
        assert_eq!(consumer.try_recv(), None);
    }

    /// Two cloned producers both reach the same consumer.
    #[test]
    fn test_multiple_senders() {
        let (first, consumer) = channel::<u64>(16);
        let second = first.clone();
        first.send(1).unwrap();
        second.send(2).unwrap();
        let mut got: Vec<u64> = (0..2).map(|_| consumer.recv().unwrap()).collect();
        got.sort_unstable();
        assert_eq!(got, [1, 2]);
    }

    /// Four threads each send ten messages; all forty are observed.
    #[test]
    fn test_threaded_send() {
        let (tx, rx) = channel::<u64>(100);
        let workers: Vec<thread::JoinHandle<()>> = (0..4u64)
            .map(|worker| {
                let tx = tx.clone();
                thread::spawn(move || {
                    (0..10).for_each(|k| tx.send(worker * 10 + k).unwrap());
                })
            })
            .collect();
        drop(tx);
        workers
            .into_iter()
            .for_each(|handle| handle.join().unwrap());
        assert_eq!(rx.drain().count(), 40);
    }

    /// A full capacity-1 channel times out; `_rx` keeps the receiver alive.
    #[test]
    fn test_send_timeout() {
        let (tx, _rx) = channel::<u64>(1);
        tx.send(1).unwrap();
        assert!(tx.send_timeout(2, Duration::from_millis(10)).is_err());
    }

    /// An empty channel times out; `_tx` keeps a sender alive.
    #[test]
    fn test_recv_timeout() {
        let (_tx, rx) = channel::<u64>(16);
        assert!(rx.recv_timeout(Duration::from_millis(10)).is_none());
    }

    /// Draining yields buffered messages in FIFO order and empties the channel.
    #[test]
    fn test_drain() {
        let (tx, rx) = channel::<u64>(16);
        (0..5).for_each(|value| tx.send(value).unwrap());
        let drained: Vec<u64> = rx.drain().collect();
        assert_eq!(drained, [0, 1, 2, 3, 4]);
        assert!(rx.is_empty());
    }

    /// The receiver's length tracks buffered messages.
    #[test]
    fn test_receiver_len() {
        let (tx, rx) = channel::<u64>(16);
        assert_eq!(rx.len(), 0);
        assert!(rx.is_empty());
        for value in [1, 2] {
            tx.send(value).unwrap();
        }
        assert_eq!(rx.len(), 2);
        assert!(!rx.is_empty());
    }

    /// The raw crossbeam receiver sees messages sent through the wrapper.
    #[test]
    fn test_as_select() {
        let (tx, rx) = channel::<u64>(16);
        tx.send(42).unwrap();
        assert_eq!(rx.as_select().try_recv().ok(), Some(42));
    }

    /// The sender reports the capacity the channel was built with.
    #[test]
    fn test_sender_capacity() {
        let (tx, _rx) = channel::<u64>(16);
        assert_eq!(tx.capacity(), Some(16));
    }
}