use std::{
collections::VecDeque,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use parking_lot::Mutex;
/// Create `n` distribution channels that share a single gate.
///
/// Element `i` of the returned sender vector feeds element `i` of the returned
/// receiver vector. Every channel starts with exactly one sender and a live
/// receiver. The shared gate starts out counting all `n` channels as empty.
/// A send is only admitted while at least one channel is counted as empty
/// (see [`SendFuture`]).
pub fn channels<T>(
    n: usize,
) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
    let gate = Arc::new(Mutex::new(Gate {
        empty_channels: n,
        send_wakers: Vec::default(),
    }));

    (0..n)
        .map(|id| {
            let channel = Arc::new(Mutex::new(Channel {
                data: VecDeque::default(),
                n_senders: 1,
                recv_alive: true,
                recv_wakers: Vec::default(),
                id,
            }));
            let sender = DistributionSender {
                channel: Arc::clone(&channel),
                gate: Arc::clone(&gate),
            };
            let receiver = DistributionReceiver {
                channel,
                gate: Arc::clone(&gate),
            };
            (sender, receiver)
        })
        .unzip()
}
/// Error produced by [`SendFuture`] when the receiving side of the channel is gone.
///
/// The element that could not be delivered is handed back in field `0`.
#[derive(PartialEq, Eq)]
pub struct SendError<T>(pub T);

// `Debug` is written by hand instead of derived so it does not require
// `T: Debug`; the payload is deliberately not printed.
impl<T> std::fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("SendError");
        tuple.finish()
    }
}

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cannot send data, receiver is gone")
    }
}

impl<T> std::error::Error for SendError<T> {}
/// Sending half of one distribution channel.
///
/// Cloning a sender registers one more sender on the same channel. The
/// channel is closed for its receiver once every sender has been dropped.
#[derive(Debug)]
pub struct DistributionSender<T> {
    channel: SharedChannel<T>,
    gate: SharedGate,
}

impl<T> DistributionSender<T> {
    /// Create a future that pushes `element` into this sender's channel.
    ///
    /// The future stays pending while the shared gate is closed, i.e. while no
    /// channel is counted as empty. It resolves to [`SendError`], handing the
    /// element back, when the receiver has been dropped.
    pub fn send(&self, element: T) -> SendFuture<'_, T> {
        let element = Box::new(Some(element));
        SendFuture {
            element,
            gate: &self.gate,
            channel: &self.channel,
        }
    }
}
impl<T> Clone for DistributionSender<T> {
    /// Register one more sender on the same channel and gate.
    fn clone(&self) -> Self {
        // The channel keeps a sender count so it knows when the last one is gone.
        self.channel.lock().n_senders += 1;

        Self {
            gate: Arc::clone(&self.gate),
            channel: Arc::clone(&self.channel),
        }
    }
}
impl<T> Drop for DistributionSender<T> {
    /// Unregister this sender. The last sender to go closes the channel.
    ///
    /// If the closing channel is still counted as empty by the gate (no
    /// buffered data, receiver alive), it is removed from that count, because
    /// it can never be filled again. Parked receivers are woken either way so
    /// they can observe the closed state.
    fn drop(&mut self) {
        let mut channel = self.channel.lock();
        channel.n_senders -= 1;
        if channel.n_senders != 0 {
            // Other senders still feed this channel.
            return;
        }

        if channel.recv_alive && channel.data.is_empty() {
            self.gate.lock().empty_channels -= 1;
        }
        channel.wake_receivers();
    }
}
/// Future returned by [`DistributionSender::send`].
///
/// Resolves to `Ok(())` once the element has been pushed into the channel, or
/// to `Err(SendError(element))` if the receiver is gone. It is a concrete type
/// rather than `impl Future`, so the trait-level `#[must_use]` of `Future`
/// does not apply. The attribute is repeated here so that a forgotten
/// `.await` is reported.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendFuture<'a, T> {
    channel: &'a SharedChannel<T>,
    gate: &'a SharedGate,
    /// The element to send; `None` once the future has completed.
    // NOTE(review): presumably boxed so that `SendFuture` is `Unpin` for every
    // `T` (`Box` is always `Unpin`) — confirm before unboxing.
    element: Box<Option<T>>,
}
impl<'a, T> Future for SendFuture<'a, T> {
    type Output = Result<(), SendError<T>>;

    /// Try to push the element into the channel.
    ///
    /// Returns one of:
    /// - `Ready(Err(..))` carrying the element if the receiver was dropped.
    /// - `Pending` while the gate is closed (no channel is counted as empty).
    /// - `Ready(Ok(()))` once the element is enqueued.
    ///
    /// # Panics
    /// If polled again after it already returned `Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        assert!(this.element.is_some(), "polled ready future");

        // Lock order is channel -> gate, matching every other code path.
        let mut guard_channel = this.channel.lock();
        if !guard_channel.recv_alive {
            return Poll::Ready(Err(SendError(
                this.element.take().expect("just checked"),
            )));
        }

        let mut guard_gate = this.gate.lock();
        if guard_gate.empty_channels == 0 {
            // Gate closed: park until some channel becomes empty again or this
            // channel's receiver goes away. A task that re-polls while the gate
            // stays closed must not register the same waker again, otherwise
            // `send_wakers` grows without bound.
            let id = guard_channel.id;
            let waker = cx.waker();
            let already_registered = guard_gate
                .send_wakers
                .iter()
                .any(|(registered, registered_id)| {
                    *registered_id == id && registered.will_wake(waker)
                });
            if !already_registered {
                guard_gate.send_wakers.push((waker.clone(), id));
            }
            return Poll::Pending;
        }

        let was_empty = guard_channel.data.is_empty();
        guard_channel
            .data
            .push_back(this.element.take().expect("just checked"));
        if was_empty {
            // This channel no longer counts as empty.
            guard_gate.empty_channels -= 1;
        }
        // The gate lock is shared by all channels. Release it before waking
        // receivers to keep that critical section short.
        drop(guard_gate);

        if was_empty {
            // Receivers only park on an empty channel, so otherwise there is
            // nobody to wake.
            guard_channel.wake_receivers();
        }
        Poll::Ready(Ok(()))
    }
}
/// Receiving half of one distribution channel.
///
/// Dropping the receiver closes the channel for its senders and discards any
/// data still buffered in it.
#[derive(Debug)]
pub struct DistributionReceiver<T> {
    channel: SharedChannel<T>,
    gate: SharedGate,
}

impl<T> DistributionReceiver<T> {
    /// Create a future that yields the next element of this channel.
    ///
    /// Takes `&mut self`, so at most one receive can be outstanding at a time.
    /// The future resolves to `None` once the channel is drained and every
    /// sender has been dropped.
    pub fn recv(&mut self) -> RecvFuture<'_, T> {
        let Self { channel, gate } = self;
        RecvFuture {
            channel,
            gate,
            rdy: false,
        }
    }
}
impl<T> Drop for DistributionReceiver<T> {
    /// Close the channel from the receiving side.
    ///
    /// This:
    /// - marks the receiver as gone;
    /// - removes the channel from the gate's empty-channel count if it was
    ///   counted there;
    /// - wakes the senders parked on this channel (their next poll fails with
    ///   [`SendError`]);
    /// - discards the buffered data.
    fn drop(&mut self) {
        let mut guard_channel = self.channel.lock();
        let mut guard_gate = self.gate.lock();

        guard_channel.recv_alive = false;
        if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
            // The channel was counted as empty but can no longer take data.
            guard_gate.empty_channels -= 1;
        }
        guard_gate.wake_channel_senders(guard_channel.id);

        // Move the buffered elements out instead of clearing them in place.
        // Their destructors are arbitrary user code (possibly slow, possibly
        // touching these channels), so they run only after both locks are
        // released. Otherwise every sender of every channel would block on the
        // shared gate meanwhile.
        let buffered = std::mem::take(&mut guard_channel.data);
        drop(guard_gate);
        drop(guard_channel);
        drop(buffered);
    }
}
/// Future returned by [`DistributionReceiver::recv`].
///
/// Resolves to `Some(element)` for the next buffered element, or to `None`
/// once the channel is empty and all senders are gone. It derives `Debug` for
/// consistency with [`SendFuture`]. It is `#[must_use]` because the
/// trait-level attribute of `Future` does not cover concrete future types.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvFuture<'a, T> {
    channel: &'a mut SharedChannel<T>,
    gate: &'a mut SharedGate,
    /// Set once the future has returned `Ready`; polling again panics.
    rdy: bool,
}
impl<'a, T> Future for RecvFuture<'a, T> {
    type Output = Option<T>;

    /// Pop the next element from the channel.
    ///
    /// Returns one of:
    /// - `Ready(Some(..))` with the front element.
    /// - `Ready(None)` when the channel is empty and all senders are gone.
    /// - `Pending` otherwise.
    ///
    /// Taking the last buffered element while senders remain makes the channel
    /// count as empty again. If that re-opens a closed gate, all parked
    /// senders are woken.
    ///
    /// # Panics
    /// If polled again after it already returned `Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        assert!(!this.rdy, "polled ready future");

        let mut guard_channel = this.channel.lock();
        let popped = guard_channel.data.pop_front();
        match popped {
            Some(element) => {
                if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
                    // Lock order is channel -> gate, as everywhere else.
                    let mut guard_gate = this.gate.lock();
                    let old_counter = guard_gate.empty_channels;
                    guard_gate.empty_channels += 1;
                    // The channel lock is not needed for waking senders.
                    // Release it first so woken senders of this channel are
                    // not held up by it.
                    drop(guard_channel);
                    if old_counter == 0 {
                        // A 0 -> 1 transition re-opens the gate.
                        guard_gate.wake_all_senders();
                    }
                }
                this.rdy = true;
                Poll::Ready(Some(element))
            }
            None if guard_channel.n_senders == 0 => {
                this.rdy = true;
                Poll::Ready(None)
            }
            None => {
                // Avoid piling up copies of the same waker when the task
                // re-polls while the channel is still empty.
                let waker = cx.waker();
                let already_registered = guard_channel
                    .recv_wakers
                    .iter()
                    .any(|registered| registered.will_wake(waker));
                if !already_registered {
                    guard_channel.recv_wakers.push(waker.clone());
                }
                Poll::Pending
            }
        }
    }
}
/// State of a single channel, shared by its senders and its receiver.
#[derive(Debug)]
struct Channel<T> {
    /// Elements sent but not yet received, in FIFO order.
    data: VecDeque<T>,
    /// Number of live `DistributionSender`s for this channel.
    n_senders: usize,
    /// `false` once the `DistributionReceiver` has been dropped.
    recv_alive: bool,
    /// Wakers of receive futures parked on this (empty) channel.
    recv_wakers: Vec<Waker>,
    /// Index of this channel within the set created by `channels`.
    id: usize,
}

impl<T> Channel<T> {
    /// Wake every parked receiver and forget its waker.
    fn wake_receivers(&mut self) {
        self.recv_wakers.drain(..).for_each(Waker::wake);
    }
}
/// Channel state shared between the channel's senders and its receiver, behind a mutex.
type SharedChannel<T> = Arc<Mutex<Channel<T>>>;
/// State shared by all channels created by one `channels` call.
#[derive(Debug)]
struct Gate {
    /// Number of channels that currently count as empty (no buffered data, at
    /// least one sender, receiver alive). Sending is blocked while this is 0.
    empty_channels: usize,
    /// Parked send futures, each tagged with the id of its channel.
    send_wakers: Vec<(Waker, usize)>,
}

impl Gate {
    /// Wake every parked sender, regardless of channel.
    fn wake_all_senders(&mut self) {
        self.send_wakers
            .drain(..)
            .for_each(|(waker, _channel_id)| waker.wake());
    }

    /// Wake only the senders parked on channel `id`; all others stay parked,
    /// in their original order.
    fn wake_channel_senders(&mut self, id: usize) {
        let parked = std::mem::take(&mut self.send_wakers);
        let mut to_wake = Vec::new();
        for (waker, channel_id) in parked {
            if channel_id == id {
                to_wake.push(waker);
            } else {
                self.send_wakers.push((waker, channel_id));
            }
        }
        to_wake.into_iter().for_each(Waker::wake);
    }
}
/// Gate state shared, behind a mutex, by all senders and receivers created by one `channels` call.
type SharedGate = Arc<Mutex<Gate>>;
// Unit tests for the distribution channels.
//
// Futures are polled by hand through the `poll_ready` / `poll_pending`
// helpers at the bottom of this module. Each poll supplies a fresh
// `TestWaker`, so a test can check exactly which waker a given state change
// woke.
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};
use futures::{task::ArcWake, FutureExt};
use super::*;
// A receiver parked on an empty channel is woken by a send. Elements arrive
// in FIFO order. After the only sender is dropped, the receiver drains the
// rest and then keeps getting `None`. Channel 1 stays empty throughout, so
// the gate never closes.
#[test]
fn test_single_channel_no_gate() {
let (mut txs, mut rxs) = channels(2);
let mut recv_fut = rxs[0].recv();
let waker = poll_pending(&mut recv_fut);
poll_ready(&mut txs[0].send("foo")).unwrap();
assert!(waker.woken());
assert_eq!(poll_ready(&mut recv_fut), Some("foo"),);
poll_ready(&mut txs[0].send("bar")).unwrap();
poll_ready(&mut txs[0].send("baz")).unwrap();
poll_ready(&mut txs[0].send("end")).unwrap();
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("baz"),);
txs.remove(0);
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("end"),);
assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
}
// A cloned sender feeds the same channel as the original, preserving order.
#[test]
fn test_multi_sender() {
let (txs, mut rxs) = channels(2);
let tx_clone = txs[0].clone();
poll_ready(&mut txs[0].send("foo")).unwrap();
poll_ready(&mut tx_clone.send("bar")).unwrap();
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("foo"),);
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
}
// Once both channels hold data the gate closes and a further send parks.
// Taking one element from channel 0 is not enough (it still holds "0_b").
// Draining it completely re-opens the gate and wakes the parked sender.
#[test]
fn test_gate() {
let (txs, mut rxs) = channels(2);
poll_ready(&mut txs[0].send("0_a")).unwrap();
poll_ready(&mut txs[0].send("0_b")).unwrap();
poll_ready(&mut txs[1].send("1_a")).unwrap();
let mut send_fut = txs[1].send("1_b");
let waker = poll_pending(&mut send_fut);
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_a"),);
poll_pending(&mut send_fut);
assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_b"),);
assert!(waker.woken());
poll_ready(&mut send_fut).unwrap();
}
// Channel 0 closes only when the last of several sender clones is dropped.
// Only then is the parked receiver woken, and it sees `None`. With channel 0
// closed and channel 1 non-empty, no channel counts as empty, so the final
// send on channel 1 parks.
#[test]
fn test_close_channel_by_dropping_tx() {
let (mut txs, mut rxs) = channels(2);
let tx0 = txs.remove(0);
let tx1 = txs.remove(0);
let tx0_clone = tx0.clone();
let mut recv_fut = rxs[0].recv();
poll_ready(&mut tx1.send("a")).unwrap();
let recv_waker = poll_pending(&mut recv_fut);
drop(tx0);
assert!(!recv_waker.woken());
poll_ready(&mut tx1.send("b")).unwrap();
let recv_waker = poll_pending(&mut recv_fut);
let tx0_clone2 = tx0_clone.clone();
assert!(!recv_waker.woken());
poll_ready(&mut tx1.send("c")).unwrap();
let recv_waker = poll_pending(&mut recv_fut);
drop(tx0_clone);
assert!(!recv_waker.woken());
poll_ready(&mut tx1.send("d")).unwrap();
let recv_waker = poll_pending(&mut recv_fut);
drop(tx0_clone2);
poll_pending(&mut tx1.send("e"));
assert!(recv_waker.woken());
assert_eq!(poll_ready(&mut recv_fut), None,);
}
// Dropping receiver 0 while its channel is empty (gate still open) removes
// it from the empty count. Channel 1 already holds data, so the gate closes.
// Sends on channel 0 now fail and hand the element back.
#[test]
fn test_close_channel_by_dropping_rx_on_open_gate() {
let (txs, mut rxs) = channels(2);
let rx0 = rxs.remove(0);
let _rx1 = rxs.remove(0);
poll_ready(&mut txs[1].send("a")).unwrap();
drop(rx0);
poll_pending(&mut txs[1].send("b"));
assert_eq!(poll_ready(&mut txs[0].send("foo")), Err(SendError("foo")),);
}
// With the gate closed and a sender parked on each channel, dropping
// receiver 0 wakes only channel 0's sender, which then fails. Channel 1's
// sender stays parked and channel 1's data is untouched.
#[test]
fn test_close_channel_by_dropping_rx_on_closed_gate() {
let (txs, mut rxs) = channels(2);
let rx0 = rxs.remove(0);
let mut rx1 = rxs.remove(0);
poll_ready(&mut txs[0].send("0_a")).unwrap();
poll_ready(&mut txs[1].send("1_a")).unwrap();
let mut send_fut0 = txs[0].send("0_b");
let mut send_fut1 = txs[1].send("1_b");
let waker0 = poll_pending(&mut send_fut0);
let waker1 = poll_pending(&mut send_fut1);
drop(rx0);
assert!(waker0.woken());
assert!(!waker1.woken());
assert_eq!(poll_ready(&mut send_fut0), Err(SendError("0_b")),);
poll_pending(&mut send_fut1);
assert_eq!(poll_ready(&mut rx1.recv()), Some("1_a"),);
}
// Dropping a receiver whose channel holds data does not change the gate
// count. Draining channel 0 re-opens the gate, and refilling it closes the
// gate again, so the send on channel 2 parks. Sends on the dropped channel 1
// fail.
#[test]
fn test_drop_rx_three_channels() {
let (mut txs, mut rxs) = channels(3);
let tx0 = txs.remove(0);
let tx1 = txs.remove(0);
let tx2 = txs.remove(0);
let mut rx0 = rxs.remove(0);
let rx1 = rxs.remove(0);
let _rx2 = rxs.remove(0);
poll_ready(&mut tx0.send("0_a")).unwrap();
poll_ready(&mut tx1.send("1_a")).unwrap();
poll_ready(&mut tx2.send("2_a")).unwrap();
drop(rx1);
assert_eq!(poll_ready(&mut rx0.recv()), Some("0_a"),);
poll_ready(&mut tx0.send("0_b")).unwrap();
assert_eq!(poll_ready(&mut tx1.send("1_b")), Err(SendError("1_b")),);
poll_pending(&mut tx2.send("2_b"));
}
// Dropping the receiver drops the buffered elements. This is observed through
// the strong count of a `Weak` to the sent `Arc`.
#[test]
fn test_close_channel_by_dropping_rx_clears_data() {
let (txs, rxs) = channels(1);
let obj = Arc::new(());
let counter = Arc::downgrade(&obj);
assert_eq!(counter.strong_count(), 1);
poll_ready(&mut txs[0].send(obj)).unwrap();
assert_eq!(counter.strong_count(), 1);
drop(rxs);
assert_eq!(counter.strong_count(), 0);
}
// Polling a send future again after it returned `Ready(Ok)` panics.
#[test]
#[should_panic(expected = "polled ready future")]
fn test_panic_poll_send_future_after_ready_ok() {
let (txs, _rxs) = channels(1);
let mut fut = txs[0].send("foo");
poll_ready(&mut fut).unwrap();
poll_ready(&mut fut).ok();
}
// Polling a send future again after it returned `Ready(Err)` panics.
#[test]
#[should_panic(expected = "polled ready future")]
fn test_panic_poll_send_future_after_ready_err() {
let (txs, rxs) = channels(1);
drop(rxs);
let mut fut = txs[0].send("foo");
poll_ready(&mut fut).unwrap_err();
poll_ready(&mut fut).ok();
}
// Polling a receive future again after it returned `Ready(Some)` panics.
#[test]
#[should_panic(expected = "polled ready future")]
fn test_panic_poll_recv_future_after_ready_some() {
let (txs, mut rxs) = channels(1);
poll_ready(&mut txs[0].send("foo")).unwrap();
let mut fut = rxs[0].recv();
poll_ready(&mut fut).unwrap();
poll_ready(&mut fut);
}
// Polling a receive future again after it returned `Ready(None)` panics.
#[test]
#[should_panic(expected = "polled ready future")]
fn test_panic_poll_recv_future_after_ready_none() {
let (txs, mut rxs) = channels::<u8>(1);
drop(txs);
let mut fut = rxs[0].recv();
assert!(poll_ready(&mut fut).is_none());
poll_ready(&mut fut);
}
// Sanity check of the helper: `poll_ready` panics on a pending future.
#[test]
#[should_panic(expected = "future is pending")]
fn test_meta_poll_ready_wrong_state() {
let mut fut = futures::future::pending::<u8>();
poll_ready(&mut fut);
}
// Sanity check of the helper: `poll_pending` panics on a ready future.
#[test]
#[should_panic(expected = "future is ready")]
fn test_meta_poll_pending_wrong_state() {
let mut fut = futures::future::ready(1);
poll_pending(&mut fut);
}
// Sanity check of the helper: the waker returned by `poll_pending` records a
// wake-up coming from an unrelated channel implementation.
#[test]
fn test_meta_poll_pending_waker() {
let (tx, mut rx) = futures::channel::oneshot::channel();
let waker = poll_pending(&mut rx);
assert!(!waker.woken());
tx.send(1).unwrap();
assert!(waker.woken());
}
// Poll `fut` once and return its output; panics if it is still pending.
#[track_caller]
fn poll_ready<F>(fut: &mut F) -> F::Output
where
F: Future + Unpin,
{
match poll(fut).0 {
Poll::Ready(x) => x,
Poll::Pending => panic!("future is pending"),
}
}
// Poll `fut` once and return the waker it was given; panics if it is ready.
#[track_caller]
fn poll_pending<F>(fut: &mut F) -> Arc<TestWaker>
where
F: Future + Unpin,
{
let (res, waker) = poll(fut);
match res {
Poll::Ready(_) => panic!("future is ready"),
Poll::Pending => waker,
}
}
// Poll `fut` once with a fresh `TestWaker`, returning the result and that waker.
fn poll<F>(fut: &mut F) -> (Poll<F::Output>, Arc<TestWaker>)
where
F: Future + Unpin,
{
let test_waker = Arc::new(TestWaker::default());
let waker = futures::task::waker(Arc::clone(&test_waker));
let mut cx = std::task::Context::from_waker(&waker);
let res = fut.poll_unpin(&mut cx);
(res, test_waker)
}
// Waker that only records whether it has been woken.
#[derive(Debug, Default)]
struct TestWaker {
woken: AtomicBool,
}
impl TestWaker {
// Whether `wake`/`wake_by_ref` has been called on this waker.
fn woken(&self) -> bool {
self.woken.load(Ordering::SeqCst)
}
}
impl ArcWake for TestWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.woken.store(true, Ordering::SeqCst);
}
}
}