use std::sync::Arc;
#[cfg(test)]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Notify;
/// Error returned by send operations when the receiving half of the
/// channel has been dropped. Carries the unsent message so the caller
/// can recover it via the public `.0` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("channel disconnected")
    }
}

impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
/// Error returned by receive operations once the channel is empty and
/// every sender has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecvError;

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("channel disconnected")
    }
}

impl std::error::Error for RecvError {}
/// Error returned by `try_send`. Carries the unsent message so the
/// caller can recover it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The channel was at capacity.
    Full(T),
    /// The receiving half has been dropped.
    Disconnected(T),
}

/// Error returned by `try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// The channel currently holds no messages.
    Empty,
    /// The channel is empty and every sender has been dropped.
    Disconnected,
}

impl<T> std::fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrySendError::Full(_) => write!(f, "channel full"),
            TrySendError::Disconnected(_) => write!(f, "channel disconnected"),
        }
    }
}

impl<T: std::fmt::Debug> std::error::Error for TrySendError<T> {}

// Consistency fix: every other error type in this module implements
// Display + Error, but TryRecvError previously implemented neither,
// so it could not be boxed as `dyn Error` or formatted with `{}`.
impl std::fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TryRecvError::Empty => write!(f, "channel empty"),
            TryRecvError::Disconnected => write!(f, "channel disconnected"),
        }
    }
}

impl std::error::Error for TryRecvError {}
/// Sending half of the fast bounded channel: a crossbeam queue for the data
/// path plus tokio `Notify` handles for async parking. Cloneable for
/// multi-producer use.
pub struct FastSender<T> {
    // Underlying bounded crossbeam queue shared with the receiver.
    inner: crossbeam::channel::Sender<T>,
    // Signaled by the receiver after it consumes a message (and via
    // `notify_waiters` when it is dropped) to wake senders parked on a
    // full channel in `send_async`.
    send_notify: Arc<Notify>,
    // Signaled by senders after enqueueing a message (and on sender drop)
    // to wake a receiver parked on an empty channel in `recv_async`.
    recv_notify: Arc<Notify>,
    // Set to true by the receiver's Drop; read by `is_closed`.
    closed: Arc<AtomicBool>,
}
impl<T> Clone for FastSender<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
send_notify: self.send_notify.clone(),
recv_notify: self.recv_notify.clone(),
closed: self.closed.clone(),
}
}
}
impl<T> FastSender<T> {
#[cfg(test)]
#[inline]
pub fn send(&self, msg: T) -> Result<(), SendError<T>> {
match self.inner.send(msg) {
Ok(()) => {
self.recv_notify.notify_one();
Ok(())
}
Err(crossbeam::channel::SendError(msg)) => Err(SendError(msg)),
}
}
pub async fn send_async(&self, mut msg: T) -> Result<(), SendError<T>> {
loop {
match self.inner.try_send(msg) {
Ok(()) => {
self.recv_notify.notify_one();
return Ok(());
}
Err(crossbeam::channel::TrySendError::Full(returned)) => {
msg = returned;
self.send_notify.notified().await;
}
Err(crossbeam::channel::TrySendError::Disconnected(returned)) => {
return Err(SendError(returned));
}
}
}
}
#[inline]
pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
match self.inner.try_send(msg) {
Ok(()) => {
self.recv_notify.notify_one();
Ok(())
}
Err(crossbeam::channel::TrySendError::Full(msg)) => Err(TrySendError::Full(msg)),
Err(crossbeam::channel::TrySendError::Disconnected(msg)) => {
Err(TrySendError::Disconnected(msg))
}
}
}
#[cfg(test)]
#[inline]
pub fn is_full(&self) -> bool {
self.inner.is_full()
}
#[inline]
pub fn is_closed(&self) -> bool {
self.closed.load(Ordering::Acquire)
}
}
/// Receiving half of the fast bounded channel. No `Clone` impl is
/// provided, so there is a single consumer.
pub struct FastReceiver<T> {
    // Underlying bounded crossbeam queue shared with the senders.
    inner: crossbeam::channel::Receiver<T>,
    // Notified (`notify_one`) after each successful receive to wake one
    // sender parked on a full channel; `notify_waiters` fires in Drop.
    send_notify: Arc<Notify>,
    // Awaited by `recv_async` while the channel is empty; senders signal
    // it after enqueueing and on sender drop.
    recv_notify: Arc<Notify>,
    // Set to true in Drop so senders can observe closure via `is_closed`.
    closed: Arc<AtomicBool>,
    // Test-only instrumentation: counts `recv_async` loop iterations to
    // verify the receiver parks instead of busy-polling.
    #[cfg(test)]
    poll_count: Arc<AtomicU64>,
}
impl<T> Drop for FastReceiver<T> {
    /// Marks the channel closed and wakes every sender parked in
    /// `send_async` so they can observe the disconnect.
    fn drop(&mut self) {
        // Order matters: `closed` must be set (Release) before the wakeup
        // so a woken sender that re-checks the flag is guaranteed to see
        // `true`. `notify_waiters` wakes all currently-parked senders.
        self.closed.store(true, Ordering::Release);
        self.send_notify.notify_waiters();
    }
}
impl<T> Drop for FastSender<T> {
fn drop(&mut self) {
self.recv_notify.notify_one();
}
}
impl<T> FastReceiver<T> {
    /// Receives a message, waiting asynchronously while the channel is
    /// empty.
    ///
    /// # Errors
    /// Returns `RecvError` once the channel is empty and every sender has
    /// been dropped.
    pub async fn recv_async(&self) -> Result<T, RecvError> {
        use crossbeam::channel::TryRecvError as Inner;
        loop {
            // Test instrumentation: one tick per poll iteration.
            #[cfg(test)]
            self.poll_count.fetch_add(1, Ordering::Relaxed);
            match self.inner.try_recv() {
                Ok(value) => {
                    // A slot was freed: wake one sender parked on a full
                    // channel.
                    self.send_notify.notify_one();
                    return Ok(value);
                }
                Err(Inner::Disconnected) => return Err(RecvError),
                Err(Inner::Empty) => self.recv_notify.notified().await,
            }
        }
    }

    /// Test helper: number of loop iterations `recv_async` has performed.
    #[cfg(test)]
    pub fn poll_count(&self) -> u64 {
        self.poll_count.load(Ordering::Relaxed)
    }

    /// Blocking receive; waits until a message arrives or all senders
    /// are dropped.
    #[inline]
    pub fn recv(&self) -> Result<T, RecvError> {
        match self.inner.recv() {
            Ok(value) => {
                self.send_notify.notify_one();
                Ok(value)
            }
            Err(crossbeam::channel::RecvError) => Err(RecvError),
        }
    }

    /// Non-blocking receive.
    ///
    /// # Errors
    /// `TryRecvError::Empty` when nothing is buffered,
    /// `TryRecvError::Disconnected` when empty with all senders gone.
    #[inline]
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        use crossbeam::channel::TryRecvError as Inner;
        match self.inner.try_recv() {
            Ok(value) => {
                self.send_notify.notify_one();
                Ok(value)
            }
            Err(Inner::Empty) => Err(TryRecvError::Empty),
            Err(Inner::Disconnected) => Err(TryRecvError::Disconnected),
        }
    }
}
/// Creates a bounded channel with a lock-free data path and
/// notification-based async waiting.
///
/// Returns the (cloneable) sending half and the single-consumer
/// receiving half.
///
/// # Panics
///
/// Panics if `capacity` is zero. The async paths are built on
/// non-blocking `try_send`/`try_recv`, and a zero-capacity (rendezvous)
/// crossbeam channel only transfers a message while one side is blocked
/// inside `send`/`recv` — so a zero-capacity channel could never make
/// progress here and would deadlock silently.
pub fn bounded<T>(capacity: usize) -> (FastSender<T>, FastReceiver<T>) {
    assert!(capacity > 0, "channel capacity must be at least 1");
    let (tx, rx) = crossbeam::channel::bounded(capacity);
    let send_notify = Arc::new(Notify::new());
    let recv_notify = Arc::new(Notify::new());
    let closed = Arc::new(AtomicBool::new(false));
    let sender = FastSender {
        inner: tx,
        send_notify: send_notify.clone(),
        recv_notify: recv_notify.clone(),
        closed: closed.clone(),
    };
    let receiver = FastReceiver {
        inner: rx,
        send_notify,
        recv_notify,
        closed,
        #[cfg(test)]
        poll_count: Arc::new(AtomicU64::new(0)),
    };
    (sender, receiver)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::GlobalExecutor;
// FIFO delivery: messages arrive in send order via the async receive path.
#[tokio::test]
async fn test_send_recv() {
    let (tx, rx) = bounded::<i32>(10);
    tx.send(42).unwrap();
    tx.send(43).unwrap();
    assert_eq!(rx.recv_async().await.unwrap(), 42);
    assert_eq!(rx.recv_async().await.unwrap(), 43);
}

// Round trip through the fully async send and receive paths.
#[tokio::test]
async fn test_async_send() {
    let (tx, rx) = bounded::<i32>(10);
    tx.send_async(42).await.unwrap();
    assert_eq!(rx.recv_async().await.unwrap(), 42);
}

// A full channel reports is_full(), and one receive frees a slot.
#[tokio::test]
async fn test_backpressure() {
    let (tx, rx) = bounded::<i32>(2);
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    assert!(tx.is_full());
    assert_eq!(rx.recv_async().await.unwrap(), 1);
    assert!(!tx.is_full());
    tx.send(3).unwrap();
}

// Cloned senders feed the same underlying queue.
#[tokio::test]
async fn test_sender_clone() {
    let (tx, rx) = bounded::<i32>(10);
    let tx2 = tx.clone();
    tx.send(1).unwrap();
    tx2.send(2).unwrap();
    assert_eq!(rx.recv_async().await.unwrap(), 1);
    assert_eq!(rx.recv_async().await.unwrap(), 2);
}

// Buffered messages are still delivered after the sender drops; the next
// receive then reports disconnection instead of hanging.
#[tokio::test]
async fn test_disconnection() {
    let (tx, rx) = bounded::<i32>(10);
    tx.send(42).unwrap();
    drop(tx);
    assert_eq!(rx.recv_async().await.unwrap(), 42);
    assert!(rx.recv_async().await.is_err());
}

// is_closed() flips to true once the receiver is dropped.
#[tokio::test]
async fn test_is_closed() {
    let (tx, rx) = bounded::<i32>(10);
    assert!(!tx.is_closed());
    drop(rx);
    assert!(tx.is_closed());
}
// One concurrent sender/receiver pair moves every message exactly once.
#[tokio::test]
async fn test_concurrent_send_recv() {
    let (tx, rx) = bounded::<i32>(100);
    let count = 1000;
    let sender = GlobalExecutor::spawn(async move {
        for i in 0..count {
            tx.send_async(i).await.unwrap();
        }
    });
    let receiver = GlobalExecutor::spawn(async move {
        let mut received = 0;
        while rx.recv_async().await.is_ok() {
            received += 1;
            if received == count {
                break;
            }
        }
        received
    });
    sender.await.unwrap();
    let received = receiver.await.unwrap();
    assert_eq!(received, count);
}

// recv_async must park on the Notify rather than busy-poll: with a one
// second delay before the send, the poll counter stays small only if the
// notification wakeup actually works.
#[tokio::test]
async fn test_recv_async_no_busy_loop() {
    let (tx, rx) = bounded::<i32>(10);
    let sender = GlobalExecutor::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        tx.send_async(42).await.unwrap();
    });
    let result = rx.recv_async().await.unwrap();
    assert_eq!(result, 42);
    sender.await.unwrap();
    let polls = rx.poll_count();
    assert!(
        polls < 10,
        "Too many poll iterations ({polls}), notification wakeup may not be working"
    );
    assert!(polls > 0, "Should have polled at least once");
}

// A burst buffered before any receive is drained completely and in order.
#[tokio::test]
async fn test_burst_send_before_recv_drains_all() {
    let (tx, rx) = bounded::<i32>(100);
    for i in 0..50 {
        tx.send(i).unwrap();
    }
    for i in 0..50 {
        assert_eq!(rx.recv_async().await.unwrap(), i);
    }
}

// Dropping the last sender wakes a receiver parked inside recv_async.
#[tokio::test]
async fn test_sender_drop_wakes_blocked_receiver() {
    let (tx, rx) = bounded::<i32>(10);
    let receiver = GlobalExecutor::spawn(async move {
        rx.recv_async().await
    });
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    drop(tx);
    let result = receiver.await.unwrap();
    assert!(
        result.is_err(),
        "Should return RecvError after all senders drop"
    );
}
// Five cloned senders feeding one receiver: nothing lost or duplicated.
#[tokio::test]
async fn test_multiple_concurrent_senders() {
    let (tx, rx) = bounded::<i32>(100);
    let n_senders = 5;
    let msgs_per_sender = 200;
    let handles: Vec<_> = (0..n_senders)
        .map(|_| {
            let tx = tx.clone();
            GlobalExecutor::spawn(async move {
                for i in 0..msgs_per_sender {
                    tx.send_async(i).await.unwrap();
                }
            })
        })
        .collect();
    // Drop the original handle so the receiver sees disconnect when the
    // spawned senders finish.
    drop(tx);
    let receiver = GlobalExecutor::spawn(async move {
        let mut count = 0;
        while rx.recv_async().await.is_ok() {
            count += 1;
        }
        count
    });
    for h in handles {
        h.await.unwrap();
    }
    let count = receiver.await.unwrap();
    assert_eq!(count, n_senders * msgs_per_sender);
}

// Fifty senders contending for a five-slot channel: exercises the
// backpressure/wakeup path heavily; every message must still arrive.
#[tokio::test]
async fn stress_high_concurrency_small_channel() {
    let (tx, rx) = bounded::<u64>(5);
    let n_senders = 50;
    let msgs_per_sender = 1000;
    let handles: Vec<_> = (0..n_senders)
        .map(|sender_id| {
            let tx = tx.clone();
            GlobalExecutor::spawn(async move {
                for i in 0..msgs_per_sender {
                    tx.send_async(sender_id * msgs_per_sender + i)
                        .await
                        .unwrap();
                }
            })
        })
        .collect();
    drop(tx);
    let receiver = GlobalExecutor::spawn(async move {
        let mut count = 0u64;
        while rx.recv_async().await.is_ok() {
            count += 1;
        }
        count
    });
    for h in handles {
        h.await.unwrap();
    }
    let count = receiver.await.unwrap();
    assert_eq!(count, n_senders * msgs_per_sender);
}
// 100 short-lived sender waves disconnecting rapidly; the received count
// must equal the number of sends that actually succeeded.
#[tokio::test]
async fn stress_rapid_sender_disconnect() {
    let (tx, rx) = bounded::<u64>(100);
    let total_expected = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
    let handles: Vec<_> = (0..100)
        .map(|wave| {
            let tx = tx.clone();
            let total = total_expected.clone();
            GlobalExecutor::spawn(async move {
                for i in 0..10 {
                    tx.send_async(wave * 10 + i).await.unwrap();
                    total.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();
    drop(tx);
    let receiver = GlobalExecutor::spawn(async move {
        let mut count = 0u64;
        while rx.recv_async().await.is_ok() {
            count += 1;
        }
        count
    });
    for h in handles {
        h.await.unwrap();
    }
    let count = receiver.await.unwrap();
    let expected = total_expected.load(Ordering::Relaxed);
    assert_eq!(count, expected);
}

// 200 independent channels each moving 500 messages in parallel; checks
// that channel instances do not interfere with one another.
#[tokio::test]
async fn stress_many_channels_parallel() {
    let n_channels = 200;
    let msgs_per_channel = 500;
    let handles: Vec<_> = (0..n_channels)
        .map(|_| {
            GlobalExecutor::spawn(async move {
                let (tx, rx) = bounded::<u64>(50);
                let sender = GlobalExecutor::spawn(async move {
                    for i in 0..msgs_per_channel {
                        tx.send_async(i).await.unwrap();
                    }
                });
                let receiver = GlobalExecutor::spawn(async move {
                    let mut count = 0u64;
                    while count < msgs_per_channel {
                        rx.recv_async().await.unwrap();
                        count += 1;
                    }
                    count
                });
                sender.await.unwrap();
                let count = receiver.await.unwrap();
                assert_eq!(count, msgs_per_channel);
            })
        })
        .collect();
    for h in handles {
        h.await.unwrap();
    }
}
// Capacity-1 channel with a receiver that yields every 10 messages:
// ordering and completeness must hold under sustained backpressure.
#[tokio::test]
async fn stress_backpressure_with_slow_receiver() {
    let (tx, rx) = bounded::<u64>(1);
    let n_msgs = 500;
    let sender = GlobalExecutor::spawn(async move {
        for i in 0..n_msgs {
            tx.send_async(i).await.unwrap();
        }
    });
    let receiver = GlobalExecutor::spawn(async move {
        let mut count = 0u64;
        for _ in 0..n_msgs {
            let msg = rx.recv_async().await.unwrap();
            assert_eq!(msg, count);
            count += 1;
            if count % 10 == 0 {
                tokio::task::yield_now().await;
            }
        }
        count
    });
    sender.await.unwrap();
    let count = receiver.await.unwrap();
    assert_eq!(count, n_msgs);
}

// Senders parked on a full channel must error out — not hang — once the
// receiver is dropped.
#[tokio::test]
async fn stress_sender_drop_during_backpressure() {
    let (tx, rx) = bounded::<u64>(2);
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let tx = tx.clone();
            GlobalExecutor::spawn(async move {
                tx.send_async(100 + i).await
            })
        })
        .collect();
    drop(tx);
    // Give the spawned senders time to park before the receiver goes away.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    drop(rx);
    for h in handles {
        let result = h.await.unwrap();
        assert!(
            result.is_err(),
            "Sender should get error after receiver drops"
        );
    }
}
// Property-based tests exercising the channel across randomized
// capacities, sender counts, and message volumes.
mod proptest_channel {
    use super::*;
    use crate::config::GlobalExecutor;
    use proptest::prelude::*;
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(128))]
        // Every message sent by every sender is eventually received: no
        // silent drops regardless of capacity or concurrency level.
        #[test]
        fn no_silent_drops(
            n_senders in 1usize..20,
            msgs_per_sender in 1usize..200,
            capacity in 1usize..100,
        ) {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let total = n_senders * msgs_per_sender;
            let received = rt.block_on(async {
                let (tx, rx) = bounded::<u64>(capacity);
                let handles: Vec<_> = (0..n_senders)
                    .map(|sender_id| {
                        let tx = tx.clone();
                        GlobalExecutor::spawn(async move {
                            for i in 0..msgs_per_sender {
                                tx.send_async(
                                    (sender_id * msgs_per_sender + i) as u64
                                )
                                .await
                                .unwrap();
                            }
                        })
                    })
                    .collect();
                drop(tx);
                let receiver = GlobalExecutor::spawn(async move {
                    let mut count = 0u64;
                    while rx.recv_async().await.is_ok() {
                        count += 1;
                    }
                    count
                });
                for h in handles {
                    h.await.unwrap();
                }
                receiver.await.unwrap()
            });
            prop_assert_eq!(
                received, total as u64,
                "Expected {} messages, received {}",
                total, received
            );
        }
        // try_send on a full channel always reports Full and never sneaks
        // a message past the capacity bound.
        #[test]
        fn backpressure_signals_full(
            capacity in 1usize..50,
            n_extra in 1usize..100,
        ) {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            rt.block_on(async {
                let (tx, rx) = bounded::<u64>(capacity);
                for i in 0..capacity {
                    tx.try_send(i as u64).unwrap();
                }
                let mut full_count = 0usize;
                for i in 0..n_extra {
                    match tx.try_send((capacity + i) as u64) {
                        Err(TrySendError::Full(_)) => full_count += 1,
                        Ok(()) => {
                            panic!(
                                "try_send succeeded beyond capacity without draining"
                            );
                        }
                        Err(TrySendError::Disconnected(_)) => {
                            panic!("channel disconnected unexpectedly");
                        }
                    }
                }
                prop_assert_eq!(
                    full_count, n_extra,
                    "All extra sends should return Full"
                );
                let mut received = 0usize;
                while rx.try_recv().is_ok() {
                    received += 1;
                }
                prop_assert_eq!(
                    received, capacity,
                    "Should receive exactly capacity messages"
                );
                Ok(())
            })?;
        }
        // Buffered messages survive sender drop; the receiver then sees
        // disconnection rather than hanging.
        #[test]
        fn sender_disconnect_detected(
            n_msgs in 0usize..100,
            capacity in 1usize..50,
        ) {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let received = rt.block_on(async {
                let (tx, rx) = bounded::<u64>(capacity);
                // Only up to `capacity` messages fit with nobody draining.
                let to_send = n_msgs.min(capacity);
                for i in 0..to_send {
                    tx.send(i as u64).unwrap();
                }
                drop(tx);
                let mut count = 0u64;
                while rx.recv_async().await.is_ok() {
                    count += 1;
                }
                count
            });
            prop_assert_eq!(
                received, n_msgs.min(capacity) as u64,
                "Should receive all buffered messages before disconnect"
            );
        }
        // Dropping the receiver flips is_closed and makes send_async fail.
        #[test]
        fn receiver_disconnect_detected(
            capacity in 1usize..50,
        ) {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            rt.block_on(async {
                let (tx, rx) = bounded::<u64>(capacity);
                prop_assert!(!tx.is_closed());
                drop(rx);
                prop_assert!(tx.is_closed());
                let result = tx.send_async(42).await;
                prop_assert!(result.is_err(), "send_async should fail after receiver drops");
                Ok(())
            })?;
        }
    }
}
}