use crate::sync::Mutex;
use core::num::NonZeroUsize;
use futures::{stream::FusedStream, Sink, Stream};
use std::{
collections::VecDeque,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use thiserror::Error;
/// Error produced by a [`Sender`] once the channel's [`Receiver`] has been dropped.
///
/// Returned from both `poll_ready` and `start_send` of the `Sink` implementation.
#[derive(Debug, Error)]
#[error("channel closed")]
pub struct ChannelClosed;
/// State shared between every [`Sender`] and the single [`Receiver`] of one channel.
///
/// Always accessed through the surrounding `Arc<Mutex<_>>`.
#[derive(Debug)]
struct Shared<T: Send + Sync> {
    /// Queued items; new items are pushed at the back and consumed from the front.
    buffer: VecDeque<T>,
    /// Maximum number of queued items; a send at this limit evicts the front item first.
    capacity: usize,
    /// Waker stored by the receiver when it polled an empty buffer, if any.
    receiver_waker: Option<Waker>,
    /// Number of live `Sender` handles (starts at 1, adjusted by clone/drop).
    sender_count: usize,
    /// Set to `true` when the `Receiver` is dropped; sends fail from then on.
    receiver_dropped: bool,
}
/// Sending half of the channel.
///
/// Handles can be cloned; each clone feeds the same buffer. Sending never
/// waits for space: when the buffer is at capacity the oldest queued item is
/// discarded to make room. Once the [`Receiver`] is dropped, sends fail with
/// [`ChannelClosed`].
///
/// Implements `Debug` (like [`Receiver`]) so that types embedding a sender
/// can derive `Debug` themselves.
#[derive(Debug)]
pub struct Sender<T: Send + Sync> {
    /// Channel state shared with the receiver and all other senders.
    shared: Arc<Mutex<Shared<T>>>,
}
impl<T: Send + Sync> Sender<T> {
    /// Reports whether the receiving half has been dropped.
    ///
    /// When this returns `true`, any further send fails with [`ChannelClosed`].
    pub fn is_closed(&self) -> bool {
        // The guard is a temporary, so the lock is released at the end of
        // this expression.
        self.shared.lock().receiver_dropped
    }
}
impl<T: Send + Sync> Clone for Sender<T> {
fn clone(&self) -> Self {
let mut shared = self.shared.lock();
shared.sender_count += 1;
drop(shared);
Self {
shared: self.shared.clone(),
}
}
}
impl<T: Send + Sync> Drop for Sender<T> {
    /// Unregisters this sender; when it was the last one, wakes the receiver
    /// so it can observe the end of the stream.
    fn drop(&mut self) {
        // Decide under the lock, act after releasing it.
        let pending_waker = {
            let mut state = self.shared.lock();
            state.sender_count -= 1;
            match state.sender_count {
                0 => state.receiver_waker.take(),
                _ => None,
            }
        };

        if let Some(waker) = pending_waker {
            waker.wake();
        }
    }
}
impl<T: Send + Sync> Sink<T> for Sender<T> {
    type Error = ChannelClosed;

    /// Always ready: the sink only ever reports an error (receiver dropped)
    /// or immediate readiness, never `Pending`.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let readiness = if self.shared.lock().receiver_dropped {
            Err(ChannelClosed)
        } else {
            Ok(())
        };
        Poll::Ready(readiness)
    }

    /// Queues `item`, evicting the oldest queued item when the buffer is
    /// already at capacity, then wakes the receiver if it was waiting.
    ///
    /// Fails with [`ChannelClosed`] when the receiver has been dropped.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        // Everything that touches shared state happens inside this scope;
        // the evicted item and the waker are handled only after unlocking.
        let (evicted, waiting) = {
            let mut state = self.shared.lock();
            if state.receiver_dropped {
                return Err(ChannelClosed);
            }

            let evicted = if state.buffer.len() < state.capacity {
                None
            } else {
                state.buffer.pop_front()
            };
            state.buffer.push_back(item);

            (evicted, state.receiver_waker.take())
        };

        // Run the evicted item's destructor outside the lock, before waking.
        drop(evicted);

        if let Some(waker) = waiting {
            waker.wake();
        }
        Ok(())
    }

    /// Nothing is buffered on the sender side, so flushing completes at once.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Closing performs no work and completes at once.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}
/// Receiving half of the channel.
///
/// Yields queued items through its `Stream` implementation and ends once the
/// buffer is empty and every [`Sender`] has been dropped.
#[derive(Debug)]
pub struct Receiver<T: Send + Sync> {
    /// Channel state shared with all senders.
    shared: Arc<Mutex<Shared<T>>>,
}
impl<T: Send + Sync> Stream for Receiver<T> {
    type Item = T;

    /// Yields the oldest queued item, ends the stream when the buffer is
    /// empty and no sender remains, and otherwise registers the task's waker
    /// and returns `Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self.shared.lock();

        match state.buffer.pop_front() {
            Some(item) => Poll::Ready(Some(item)),
            // Empty buffer and nobody left to fill it: end of stream.
            None if state.sender_count == 0 => Poll::Ready(None),
            None => {
                let current = cx.waker();
                // Skip the clone when the stored waker already targets this task.
                let already_registered =
                    matches!(&state.receiver_waker, Some(stored) if stored.will_wake(current));
                if !already_registered {
                    state.receiver_waker = Some(current.clone());
                }
                Poll::Pending
            }
        }
    }
}
impl<T: Send + Sync> FusedStream for Receiver<T> {
    /// The stream is terminated once nothing is queued and no sender is left
    /// to queue anything more.
    fn is_terminated(&self) -> bool {
        let state = self.shared.lock();
        let drained = state.buffer.is_empty();
        let no_senders = state.sender_count == 0;
        no_senders && drained
    }
}
impl<T: Send + Sync> Drop for Receiver<T> {
    /// Marks the channel as closed and releases everything only the receiver
    /// could have used.
    ///
    /// After this point no send can succeed and nothing can read the buffer,
    /// so the queued items and any registered waker are unreachable. They are
    /// released here instead of lingering until the last [`Sender`] is
    /// dropped; this also prevents that last sender from waking a task whose
    /// receiver no longer exists.
    fn drop(&mut self) {
        let (unread, stale_waker) = {
            let mut shared = self.shared.lock();
            shared.receiver_dropped = true;
            (
                std::mem::take(&mut shared.buffer),
                shared.receiver_waker.take(),
            )
        };

        // Run item and waker destructors only after the lock is released so
        // arbitrary `Drop` code never executes while holding it.
        drop(unread);
        drop(stale_waker);
    }
}
/// Creates a channel that keeps at most `capacity` items.
///
/// Returns the initial [`Sender`] (which may be cloned) together with the
/// [`Receiver`]. When the buffer is at capacity, sending discards the oldest
/// queued item instead of waiting.
pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
    let limit = capacity.get();
    let state = Shared {
        buffer: VecDeque::with_capacity(limit),
        capacity: limit,
        receiver_waker: None,
        // Accounts for the sender returned below.
        sender_count: 1,
        receiver_dropped: false,
    };
    let shared = Arc::new(Mutex::new(state));

    let sender = Sender {
        shared: Arc::clone(&shared),
    };
    (sender, Receiver { shared })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    /// Items are received in the order they were sent.
    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }

            for expected in 1..=3 {
                assert_eq!(rx.next().await, Some(expected));
            }
        });
    }

    /// Sending past capacity discards the oldest queued items.
    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(2));

            for value in 1..=4 {
                tx.send(value).await.unwrap();
            }

            assert_eq!(rx.next().await, Some(3));
            assert_eq!(rx.next().await, Some(4));
        });
    }

    /// A send fails with `ChannelClosed` once the receiver is gone.
    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));
            drop(rx);

            let outcome = tx.send(1).await;
            assert!(matches!(outcome, Err(ChannelClosed)));
        });
    }

    /// Queued items stay readable after the sender is dropped, then the
    /// stream ends.
    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=2 {
                tx.send(value).await.unwrap();
            }
            drop(tx);

            assert_eq!(rx.next().await, Some(1));
            assert_eq!(rx.next().await, Some(2));
            assert_eq!(rx.next().await, None);
        });
    }

    /// `collect` gathers every queued item once all senders are dropped.
    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }
            drop(tx);

            let collected: Vec<_> = rx.collect().await;
            assert_eq!(collected, vec![1, 2, 3]);
        });
    }

    /// A cloned sender feeds the same receiver.
    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut first, mut rx) = channel::<i32>(NZUsize!(10));
            let mut second = first.clone();

            first.send(1).await.unwrap();
            second.send(2).await.unwrap();

            assert_eq!(rx.next().await, Some(1));
            assert_eq!(rx.next().await, Some(2));
        });
    }

    /// The stream only ends after the last clone of the sender is dropped.
    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (first, mut rx) = channel::<i32>(NZUsize!(10));
            let mut second = first.clone();
            drop(first);

            second.send(1).await.unwrap();
            assert_eq!(rx.next().await, Some(1));

            drop(second);
            assert_eq!(rx.next().await, None);
        });
    }

    /// With room for a single item, only the newest send survives.
    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(1));

            for value in 1..=2 {
                tx.send(value).await.unwrap();
            }
            assert_eq!(rx.next().await, Some(2));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }
            assert_eq!(rx.next().await, Some(3));
        });
    }

    /// `send_all` forwards a whole stream into the channel.
    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));
            let source = futures::stream::iter(vec![1, 2, 3]);

            tx.send_all(&mut source.map(Ok)).await.unwrap();
            drop(tx);

            let collected: Vec<_> = rx.collect().await;
            assert_eq!(collected, vec![1, 2, 3]);
        });
    }

    /// `is_terminated` flips only when no sender remains and the buffer is
    /// empty.
    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));
            assert!(!rx.is_terminated());

            tx.send(1).await.unwrap();
            assert!(!rx.is_terminated());

            // One item is still queued, so the stream is not finished yet.
            drop(tx);
            assert!(!rx.is_terminated());

            assert_eq!(rx.next().await, Some(1));
            assert!(rx.is_terminated());

            assert_eq!(rx.next().await, None);
            assert!(rx.is_terminated());
        });
    }

    /// `is_closed` reflects whether the receiver has been dropped.
    #[test]
    fn test_is_closed() {
        block_on(async {
            let (tx, rx) = channel::<i32>(NZUsize!(10));
            assert!(!tx.is_closed());

            drop(rx);
            assert!(tx.is_closed());
        });
    }
}