use crate::channels::ChannelCapacity;
use futures_lite::future;
use futures_lite::stream::Stream;
use futures_lite::FutureExt;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, Waker};
#[derive(Debug)]
pub struct ChannelError<T> {
kind: io::ErrorKind,
pub item: T,
}
impl<T> ChannelError<T> {
fn new(kind: io::ErrorKind, item: T) -> ChannelError<T> {
ChannelError { kind, item }
}
}
impl<T: std::fmt::Debug> From<ChannelError<T>> for io::Error {
fn from(error: ChannelError<T>) -> Self {
io::Error::new(error.kind, format!("item: {:?}", error.item))
}
}
#[derive(Debug)]
pub struct LocalSender<T> {
channel: LocalChannel<T>,
}
#[derive(Debug)]
pub struct LocalReceiver<T> {
channel: LocalChannel<T>,
}
#[derive(Debug)]
struct State<T> {
capacity: ChannelCapacity,
channel: VecDeque<T>,
recv_waiters: Option<VecDeque<Waker>>,
send_waiters: Option<VecDeque<Waker>>,
}
#[derive(Debug, Clone)]
struct LocalChannel<T> {
state: Rc<RefCell<State<T>>>,
}
impl<T> LocalChannel<T> {
#[allow(clippy::new_ret_no_self)]
fn new(capacity: ChannelCapacity) -> (LocalSender<T>, LocalReceiver<T>) {
let channel = match capacity {
ChannelCapacity::Unbounded => VecDeque::new(),
ChannelCapacity::Bounded(x) => VecDeque::with_capacity(x),
};
let state = Rc::new(RefCell::new(State {
capacity,
channel,
send_waiters: Some(VecDeque::new()),
recv_waiters: Some(VecDeque::new()),
}));
(
LocalSender {
channel: LocalChannel {
state: state.clone(),
},
},
LocalReceiver {
channel: LocalChannel { state },
},
)
}
fn push(&self, item: T) -> Result<(), ChannelError<T>> {
let mut state = self.state.borrow_mut();
if state.recv_waiters.is_none() {
return Err(ChannelError::new(io::ErrorKind::BrokenPipe, item));
}
if let ChannelCapacity::Bounded(x) = state.capacity {
if state.channel.len() >= x {
return Err(ChannelError::new(io::ErrorKind::WouldBlock, item));
}
}
state.channel.push_back(item);
if let Some(w) = state.recv_waiters.as_mut().and_then(|x| x.pop_front()) {
drop(state);
w.wake();
}
Ok(())
}
fn is_full(&self) -> bool {
let state = self.state.borrow();
match state.capacity {
ChannelCapacity::Unbounded => false,
ChannelCapacity::Bounded(x) => state.channel.len() >= x,
}
}
}
pub fn new_unbounded<T>() -> (LocalSender<T>, LocalReceiver<T>) {
LocalChannel::new(ChannelCapacity::Unbounded)
}
pub fn new_bounded<T>(size: usize) -> (LocalSender<T>, LocalReceiver<T>) {
LocalChannel::new(ChannelCapacity::Bounded(size))
}
#[allow(clippy::len_without_is_empty)]
impl<T> LocalSender<T> {
pub fn try_send(&self, item: T) -> Result<(), ChannelError<T>> {
self.channel.push(item)
}
pub async fn send(&self, item: T) -> Result<(), ChannelError<T>> {
if !self.is_full() {
self.try_send(item)
} else {
let waiter = future::poll_fn(|cx| self.wait_for_room(cx));
waiter.await;
self.try_send(item)
}
}
pub fn is_full(&self) -> bool {
self.channel.is_full()
}
pub fn len(&self) -> usize {
let state = self.channel.state.borrow();
state.channel.len()
}
fn wait_for_room(&self, cx: &mut Context<'_>) -> Poll<()> {
if !self.channel.is_full() {
Poll::Ready(())
} else {
let mut state = self.channel.state.borrow_mut();
state
.send_waiters
.as_mut()
.unwrap()
.push_back(cx.waker().clone());
Poll::Pending
}
}
}
impl<T> Drop for LocalSender<T> {
fn drop(&mut self) {
let mut state = self.channel.state.borrow_mut();
state.send_waiters.take();
if state.recv_waiters.is_none() {
return;
}
let waiters = state.recv_waiters.replace(VecDeque::new()).unwrap();
drop(state);
for w in waiters {
w.wake();
}
}
}
impl<T> Drop for LocalReceiver<T> {
fn drop(&mut self) {
let mut state = self.channel.state.borrow_mut();
state.recv_waiters.take();
if state.send_waiters.is_none() {
return;
}
let waiters = state.send_waiters.replace(VecDeque::new()).unwrap();
drop(state);
for w in waiters {
w.wake();
}
}
}
impl<T> Stream for LocalReceiver<T> {
type Item = T;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut waiter = future::poll_fn(|cx| self.recv_one(cx));
waiter.poll(cx)
}
}
impl<T> LocalReceiver<T> {
pub async fn recv(&self) -> Option<T> {
let waiter = future::poll_fn(|cx| self.recv_one(cx));
waiter.await
}
fn recv_one(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
let mut state = self.channel.state.borrow_mut();
match state.channel.pop_front() {
Some(item) => {
if let Some(w) = state.send_waiters.as_mut().and_then(|x| x.pop_front()) {
drop(state);
w.wake();
}
Poll::Ready(Some(item))
}
None => match state.send_waiters.is_some() {
true => {
state
.recv_waiters
.as_mut()
.unwrap()
.push_back(cx.waker().clone());
Poll::Pending
}
false => Poll::Ready(None),
},
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{enclose, Local};
use futures_lite::stream::{self, StreamExt};
#[test]
fn producer_consumer() {
test_executor!(async move {
let (sender, receiver) = new_unbounded();
let handle = Local::local(async move {
let sum = receiver.fold(0, |acc, x| acc + x).await;
assert_eq!(sum, 10);
})
.detach();
for _ in 0..10 {
sender.try_send(1).unwrap();
}
drop(sender);
handle.await;
});
}
#[test]
fn producer_parallel_consumer() {
use futures::future;
use futures::stream::StreamExt;
test_executor!(async move {
let (sender, receiver) = new_unbounded();
let handle = Local::local(async move {
let mut sum = 0;
receiver
.for_each_concurrent(1000, |x| {
sum += x;
future::ready(())
})
.await;
assert_eq!(sum, 10);
})
.detach();
for _ in 0..10 {
sender.try_send(1).unwrap();
}
drop(sender);
handle.await;
});
}
#[test]
fn producer_receiver_serialized() {
test_executor!(async move {
let (sender, receiver) = new_unbounded();
for _ in 0..10 {
sender.try_send(1).unwrap();
}
drop(sender);
let handle = Local::local(async move {
let sum = receiver.fold(0, |acc, x| acc + x).await;
assert_eq!(sum, 10);
})
.detach();
handle.await;
});
}
#[test]
fn producer_early_drop_receiver() {
test_executor!(async move {
let (sender, receiver) = new_unbounded();
let handle = Local::local(async move {
let sum = receiver.take(3).fold(0, |acc, x| acc + x).await;
assert_eq!(sum, 3);
})
.detach();
loop {
match sender.try_send(1) {
Ok(_) => Local::later().await,
Err(x) => {
assert_eq!(x.item, 1);
assert_eq!(x.kind, io::ErrorKind::BrokenPipe);
break;
}
}
}
handle.await;
});
}
#[test]
fn producer_bounded_early_error() {
test_executor!(async move {
let (sender, receiver) = new_bounded(1);
sender.try_send(0).unwrap();
sender.try_send(0).unwrap_err();
drop(receiver);
});
}
#[test]
fn producer_bounded_has_capacity() {
test_executor!(async move {
let (sender, receiver) = new_bounded(1);
assert_eq!(sender.is_full(), false);
sender.try_send(0).unwrap();
assert_eq!(sender.is_full(), true);
drop(receiver);
});
}
#[test]
fn producer_bounded_ping_pong() {
test_executor!(async move {
let (sender, receiver) = new_bounded(1);
let handle = Local::local(async move {
let sum = receiver.fold(0, |acc, x| acc + x).await;
assert_eq!(sum, 10);
})
.detach();
for _ in 0..10 {
sender.send(1).await.unwrap();
}
drop(sender);
handle.await;
});
}
#[test]
fn producer_bounded_previously_blocked_still_errors_out() {
test_executor!(async move {
let (sender, mut receiver) = new_bounded(1);
let s = Local::local(async move {
sender.try_send(0).unwrap();
sender.send(0).await.unwrap_err();
})
.detach();
let r = Local::local(async move {
receiver.next().await.unwrap();
drop(receiver);
})
.detach();
r.await;
s.await;
});
}
#[test]
fn non_stream_receiver() {
test_executor!(async move {
let (sender, receiver) = new_bounded(1);
let s = Local::local(async move {
sender.try_send(0).unwrap();
sender.send(0).await.unwrap_err();
})
.detach();
let r = Local::local(async move {
receiver.recv().await.unwrap();
drop(receiver);
})
.detach();
r.await;
s.await;
});
}
#[test]
fn multiple_task_receivers() {
test_executor!(async move {
let (sender, receiver) = new_bounded(1);
let receiver = Rc::new(receiver);
let mut ret = Vec::new();
for _ in 0..10 {
ret.push(Local::local(enclose! {(receiver) async move {
receiver.recv().await.unwrap();
}}));
}
for _ in 0..10 {
sender.send(0).await.unwrap();
}
let recvd = stream::iter(ret).then(|f| f).count().await;
assert_eq!(recvd, 10);
});
}
}