#![doc = include_str!("../README.md")]
#![doc = include_str!("example.md")]
use std::cmp::min;
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
/// Channel state shared by every `Sender` and `Receiver` handle.
///
/// Handles hold it behind an `Arc<Mutex<_>>`, so all fields are read and
/// written with that mutex locked.
#[derive(Debug)]
struct State<T> {
/// Values that have been sent but not yet received, oldest at the front.
queue: VecDeque<T>,
/// Number of live `Sender` handles; starts at 1, bumped on clone, decremented on drop.
tx_count: usize,
/// Number of live `Receiver` handles; starts at 1, bumped on clone, decremented on drop.
rx_count: usize,
/// Wakers registered by receive futures that returned `Poll::Pending`.
rx_wakers: Vec<Waker>,
}
/// Wakes every receive future currently registered in `state`.
///
/// The waker list is moved out of the shared state and the guard is released
/// *before* any waker runs, so a waker that immediately re-polls (and therefore
/// re-locks the channel) does not contend with this call.
fn wake_all<T>(mut state: MutexGuard<State<T>>) {
    let pending = std::mem::take(&mut state.rx_wakers);
    // Unlock first: `wake` may execute arbitrary executor code.
    drop(state);
    pending.into_iter().for_each(Waker::wake);
}
/// Sending half of the channel.
///
/// Cloning produces another handle to the same shared state and increments the
/// shared sender count; dropping a handle decrements it, and dropping the last
/// one wakes all registered receive futures.
#[derive(Debug)]
pub struct Sender<T> {
/// State shared with every other handle of this channel.
state: Arc<Mutex<State<T>>>,
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.state.lock().unwrap().tx_count += 1;
Sender {
state: self.state.clone(),
}
}
}
impl<T> Drop for Sender<T> {
    /// Unregisters this handle; when it was the last sender, wakes every
    /// registered receive future so it can observe that the channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned or the sender count is already zero.
    fn drop(&mut self) {
        let mut state = self.state.lock().unwrap();
        assert!(state.tx_count >= 1);
        state.tx_count -= 1;
        let remaining = state.tx_count;
        if remaining == 0 {
            // `wake_all` consumes the guard and unlocks before waking.
            wake_all(state);
        }
    }
}
/// Error returned when a value could not be sent because no receiver remains.
///
/// The payload is whatever the failing call was asked to send, handed back to
/// the caller (the batching sender uses `()` because its buffered values are
/// not returned).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SendError<T>(pub T);
impl<T> fmt::Display for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "failed to send value on channel")
}
}
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
impl<T> Sender<T> {
    /// Enqueues a single value and wakes every registered receive future.
    ///
    /// # Errors
    ///
    /// Returns the value back inside [`SendError`] when no receiver is alive.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        let mut state = self.state.lock().unwrap();
        if state.rx_count > 0 {
            state.queue.push_back(value);
            wake_all(state);
            return Ok(());
        }
        // The queue is cleared when the last receiver goes away.
        assert!(state.queue.is_empty());
        Err(SendError(value))
    }

    /// Enqueues every value produced by `values` under a single lock
    /// acquisition, then wakes every registered receive future.
    ///
    /// The iterator is driven while the channel lock is held, so it must not
    /// itself use this channel.
    ///
    /// # Errors
    ///
    /// Returns the untouched `values` inside [`SendError`] when no receiver is alive.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned.
    pub fn send_iter<I>(&self, values: I) -> Result<(), SendError<I>>
    where
        I: IntoIterator<Item = T>,
    {
        let mut state = self.state.lock().unwrap();
        if state.rx_count > 0 {
            state.queue.extend(values);
            wake_all(state);
            return Ok(());
        }
        assert!(state.queue.is_empty());
        Err(SendError(values))
    }

    /// Moves every element of `values` into the channel and returns the now
    /// empty vector so its allocation can be reused by the caller.
    ///
    /// # Errors
    ///
    /// Returns the untouched `values` inside [`SendError`] when no receiver is alive.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned.
    pub fn send_vec(&self, mut values: Vec<T>) -> Result<Vec<T>, SendError<Vec<T>>> {
        let mut state = self.state.lock().unwrap();
        if state.rx_count > 0 {
            // `drain` empties the vector but keeps its capacity.
            state.queue.extend(values.drain(..));
            wake_all(state);
            return Ok(values);
        }
        assert!(state.queue.is_empty());
        Err(SendError(values))
    }

    /// Converts this sender into a [`BatchSender`] whose local buffer is
    /// pre-allocated for `capacity` values.
    pub fn batch(self, capacity: usize) -> BatchSender<T> {
        let buffer = Vec::with_capacity(capacity);
        BatchSender {
            sender: self,
            capacity,
            buffer,
        }
    }
}
/// Sender that accumulates values in a local buffer and forwards them to the
/// channel as whole batches.
///
/// Created by `Sender::batch`. Any values still buffered when the
/// `BatchSender` is dropped are sent at that point, ignoring failure.
#[derive(Debug)]
pub struct BatchSender<T> {
/// Underlying channel handle the batches are forwarded to.
sender: Sender<T>,
/// Threshold the buffer length is compared against after each `send`.
capacity: usize,
/// Values accepted by `send` that have not been forwarded yet.
buffer: Vec<T>,
}
impl<T> Drop for BatchSender<T> {
    /// Flushes whatever is still buffered, on a best-effort basis.
    ///
    /// A failed send (no receiver left) is deliberately ignored: there is no
    /// way to report it from `drop`, and the values are discarded either way.
    fn drop(&mut self) {
        if !self.buffer.is_empty() {
            let leftover = std::mem::take(&mut self.buffer);
            let _ = self.sender.send_vec(leftover);
        }
    }
}
impl<T> BatchSender<T> {
    /// Buffers `value`, forwarding the whole buffer to the channel once it
    /// holds at least `capacity` values.
    ///
    /// A `capacity` of zero therefore forwards every value immediately.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` when a flush was attempted and no receiver is
    /// alive; the buffered values are discarded in that case.
    pub fn send(&mut self, value: T) -> Result<(), SendError<()>> {
        self.buffer.push(value);
        // `>=` rather than `==`: after the push the length is at least 1, so an
        // equality test could never fire for `capacity == 0` and the buffer
        // would grow without bound until `drain` or drop.
        if self.buffer.len() >= self.capacity {
            self.drain()
        } else {
            Ok(())
        }
    }

    /// Buffers every value from `values`, flushing each time the buffer fills.
    ///
    /// # Errors
    ///
    /// Stops at the first failed flush and returns its error; values not yet
    /// pulled from the iterator are dropped with it.
    pub fn send_iter<I: IntoIterator<Item = T>>(&mut self, values: I) -> Result<(), SendError<()>> {
        for value in values {
            self.send(value)?;
        }
        Ok(())
    }

    /// Forwards all buffered values to the channel now, regardless of how full
    /// the buffer is.
    ///
    /// On success the emptied vector handed back by the channel becomes the new
    /// buffer, so its allocation is reused.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` when no receiver is alive; the buffered values
    /// are discarded.
    pub fn drain(&mut self) -> Result<(), SendError<()>> {
        let batch = std::mem::take(&mut self.buffer);
        match self.sender.send_vec(batch) {
            Ok(recycled) => {
                self.buffer = recycled;
                Ok(())
            }
            Err(_) => Err(SendError(())),
        }
    }
}
/// Receiving half of the channel.
///
/// Cloning produces another handle to the same shared state and increments the
/// shared receiver count; every handle pops from the same queue, so each value
/// is delivered to exactly one receiver. When the last handle is dropped the
/// queued values are discarded and further sends fail.
#[derive(Debug)]
pub struct Receiver<T> {
/// State shared with every other handle of this channel.
state: Arc<Mutex<State<T>>>,
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.state.lock().unwrap().rx_count += 1;
Receiver {
state: self.state.clone(),
}
}
}
impl<T> Drop for Receiver<T> {
    /// Unregisters this handle; when it was the last receiver, discards every
    /// value still queued, since nobody can receive it any more.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned or the receiver count is already zero.
    fn drop(&mut self) {
        let mut state = self.state.lock().unwrap();
        assert!(state.rx_count >= 1);
        state.rx_count -= 1;
        if state.rx_count == 0 {
            // Move the values out and release the lock before running their
            // destructors: a `T` whose `Drop` touches this channel (for example
            // because it owns a `Sender` for it) would otherwise try to re-lock
            // the non-reentrant mutex from inside this critical section.
            let abandoned = std::mem::take(&mut state.queue);
            drop(state);
            drop(abandoned);
        }
    }
}
/// Future returned by `Receiver::recv`.
///
/// Resolves to `Some(value)` with the oldest queued value, or to `None` once
/// the queue is empty and no sender is alive.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Recv<'a, T> {
/// Receiver whose shared state is locked and inspected on every poll.
receiver: &'a Receiver<T>,
}
// Opts `Recv` into `Unpin` for every `T`; there is no bound on `T` here.
impl<'a, T> Unpin for Recv<'a, T> {}
impl<'a, T> Future for Recv<'a, T> {
    type Output = Option<T>;

    /// Pops the oldest queued value if there is one.
    ///
    /// With an empty queue the future resolves to `None` when every sender is
    /// gone; otherwise it registers the task's waker and stays pending.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.receiver.state.lock().unwrap();
        if let Some(value) = state.queue.pop_front() {
            return Poll::Ready(Some(value));
        }
        if state.tx_count == 0 {
            return Poll::Ready(None);
        }
        // Register at most one waker per task. A future may be polled many
        // times before the next send (e.g. inside a `select!` loop); pushing a
        // fresh clone on every poll would let the list grow without bound.
        let waker = cx.waker();
        if !state.rx_wakers.iter().any(|registered| registered.will_wake(waker)) {
            state.rx_wakers.push(waker.clone());
        }
        Poll::Pending
    }
}
/// Future returned by `Receiver::recv_batch`.
///
/// Resolves to a freshly allocated vector holding up to `element_limit` of the
/// oldest queued values, or to an empty vector once the queue is empty and no
/// sender is alive.
#[must_use = "futures do nothing unless you .await or poll them"]
struct RecvBatch<'a, T> {
/// Receiver whose shared state is locked and inspected on every poll.
receiver: &'a Receiver<T>,
/// Maximum number of values taken from the queue in one resolution.
element_limit: usize,
}
// Opts `RecvBatch` into `Unpin` for every `T`; there is no bound on `T` here.
impl<'a, T> Unpin for RecvBatch<'a, T> {}
impl<'a, T> Future for RecvBatch<'a, T> {
    type Output = Vec<T>;

    /// Takes up to `element_limit` values from the front of the queue.
    ///
    /// With an empty queue the future resolves to an empty vector when every
    /// sender is gone; otherwise it registers the task's waker and stays
    /// pending. Note that an `element_limit` of zero also yields an empty
    /// vector as soon as the queue is non-empty.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.receiver.state.lock().unwrap();
        if state.queue.is_empty() {
            if state.tx_count == 0 {
                return Poll::Ready(Vec::new());
            }
            // Register at most one waker per task. A future may be polled many
            // times before the next send (e.g. inside a `select!` loop);
            // pushing a fresh clone on every poll would let the list grow
            // without bound.
            let waker = cx.waker();
            if !state.rx_wakers.iter().any(|registered| registered.will_wake(waker)) {
                state.rx_wakers.push(waker.clone());
            }
            return Poll::Pending;
        }
        let count = min(state.queue.len(), self.element_limit);
        Poll::Ready(state.queue.drain(..count).collect())
    }
}
/// Future returned by `Receiver::recv_vec`.
///
/// Resolves once up to `element_limit` of the oldest queued values have been
/// appended to `vec`, or once the queue is empty and no sender is alive (in
/// which case `vec` is left empty).
#[must_use = "futures do nothing unless you .await or poll them"]
struct RecvVec<'a, T> {
/// Receiver whose shared state is locked and inspected on every poll.
receiver: &'a Receiver<T>,
/// Maximum number of values taken from the queue in one resolution.
element_limit: usize,
/// Caller-provided output vector; `Receiver::recv_vec` clears it before building this future.
vec: &'a mut Vec<T>,
}
// Opts `RecvVec` into `Unpin` for every `T`; there is no bound on `T` here.
impl<'a, T> Unpin for RecvVec<'a, T> {}
impl<'a, T> Future for RecvVec<'a, T> {
    type Output = ();

    /// Moves up to `element_limit` values from the front of the queue into the
    /// caller's vector.
    ///
    /// With an empty queue the future resolves (leaving the vector empty) when
    /// every sender is gone; otherwise it registers the task's waker and stays
    /// pending.
    ///
    /// # Panics
    ///
    /// Panics if the channel mutex is poisoned, or if the output vector is
    /// unexpectedly non-empty when the closed channel is observed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.receiver.state.lock().unwrap();
        if state.queue.is_empty() {
            if state.tx_count == 0 {
                // The vector is cleared before this future is built and only
                // filled right before resolving, so it must still be empty.
                assert!(self.vec.is_empty());
                return Poll::Ready(());
            }
            // Register at most one waker per task. A future may be polled many
            // times before the next send (e.g. inside a `select!` loop);
            // pushing a fresh clone on every poll would let the list grow
            // without bound.
            let waker = cx.waker();
            if !state.rx_wakers.iter().any(|registered| registered.will_wake(waker)) {
                state.rx_wakers.push(waker.clone());
            }
            return Poll::Pending;
        }
        let count = min(state.queue.len(), self.element_limit);
        self.vec.extend(state.queue.drain(..count));
        Poll::Ready(())
    }
}
impl<T> Receiver<T> {
    /// Returns a future that resolves to the next value, or to `None` once the
    /// queue is empty and every sender has been dropped.
    pub fn recv(&self) -> impl Future<Output = Option<T>> + '_ {
        let receiver = self;
        Recv { receiver }
    }

    /// Returns a future that resolves to a new vector holding up to
    /// `element_limit` queued values.
    ///
    /// The vector is empty once the queue is empty and every sender has been
    /// dropped. It is also empty whenever `element_limit` is zero and the
    /// queue is non-empty.
    pub fn recv_batch(&self, element_limit: usize) -> impl Future<Output = Vec<T>> + '_ {
        let receiver = self;
        RecvBatch {
            receiver,
            element_limit,
        }
    }

    /// Returns a future that fills `vec` with up to `element_limit` queued
    /// values, reusing the vector's existing allocation.
    ///
    /// `vec` is cleared immediately when this method is called, before the
    /// returned future is first polled. It is left empty once the queue is
    /// empty and every sender has been dropped.
    pub fn recv_vec<'a>(
        &'a self,
        element_limit: usize,
        vec: &'a mut Vec<T>,
    ) -> impl Future<Output = ()> + 'a {
        vec.clear();
        let receiver = self;
        RecvVec {
            receiver,
            element_limit,
            vec,
        }
    }
}
/// Creates a new channel with no capacity limit and returns its two halves.
///
/// The shared state starts with an empty queue and counts exactly one sender
/// and one receiver, matching the handles returned here.
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    let initial = State {
        queue: VecDeque::new(),
        tx_count: 1,
        rx_count: 1,
        rx_wakers: Vec::new(),
    };
    let shared = Arc::new(Mutex::new(initial));
    let sender = Sender {
        state: Arc::clone(&shared),
    };
    let receiver = Receiver { state: shared };
    (sender, receiver)
}