use std::future::{Future, IntoFuture};
use std::ops::Drop;
use std::pin::Pin;
use std::sync::{Arc, Mutex, OnceLock, Weak};
use std::task::{Context, Poll, Waker, ready};
use slotmap_careful::DenseSlotMap;
// Generates `WakerKey`, the key type used to index the wakers stored in the
// `DenseSlotMap` held by `Shared::wakers`.
slotmap_careful::new_key_type! { struct WakerKey; }
/// The sending half of the channel.
///
/// It only holds a `Weak` reference to the shared state, so a `Sender` on its
/// own does not keep that state alive.
#[derive(Debug)]
pub(crate) struct Sender<T> {
/// Weak handle to the state shared with the receivers.
shared: Weak<Shared<T>>,
}
/// A receiving half of the channel.
///
/// Every clone refers to the same shared state, so all of them observe the
/// same single message (or the sender being dropped).
#[derive(Debug)]
pub(crate) struct Receiver<T> {
    /// State shared with the sender and with every other receiver.
    shared: Arc<Shared<T>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, even though cloning a receiver only clones the `Arc` and
// never touches the message itself.
impl<T> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        Self {
            shared: Arc::clone(&self.shared),
        }
    }
}
/// State shared between the `Sender`, the `Receiver`s and the receiver futures.
#[derive(Debug)]
struct Shared<T> {
/// The message slot. It is written at most once: with `Ok(msg)` when a
/// message is sent, or with `Err(SenderDropped)` when the sender is dropped
/// before sending.
msg: OnceLock<Result<T, SenderDropped>>,
/// Wakers registered by futures that were polled before the message was set,
/// or `Err(WakersAlreadyWoken)` once the sender has taken them to wake them.
wakers: Mutex<Result<DenseSlotMap<WakerKey, Waker>, WakersAlreadyWoken>>,
}
/// Future created by `Receiver::borrowed`.
///
/// It borrows the shared state and resolves to a reference to the message.
#[derive(Debug)]
pub(crate) struct BorrowedReceiverFuture<'a, T> {
/// The borrowed shared channel state.
shared: &'a Shared<T>,
/// Key of this future's entry in `Shared::wakers`, or `None` while no waker
/// has been registered by a pending poll.
waker_key: Option<WakerKey>,
}
/// Future created by converting a `Receiver` with `IntoFuture`.
///
/// It owns a strong reference to the shared state and resolves to a clone of
/// the message.
#[derive(Debug)]
pub(crate) struct ReceiverFuture<T> {
/// The shared channel state, kept alive by this future.
shared: Arc<Shared<T>>,
/// Key of this future's entry in `Shared::wakers`, or `None` while no waker
/// has been registered by a pending poll.
waker_key: Option<WakerKey>,
}
/// Marker left in `Shared::wakers` after the sender has taken the wakers out
/// of the map in order to wake them.
#[derive(Copy, Clone, Debug)]
struct WakersAlreadyWoken;
/// Internal error returned by `Sender::send_and_wake` when the message slot
/// had already been filled.
#[derive(Copy, Clone, Debug, thiserror::Error)]
#[error("the message was already set")]
struct MessageAlreadySet;
/// Error yielded by the receiver futures when the `Sender` was dropped
/// without sending a message.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("the sender was dropped")]
pub(crate) struct SenderDropped;
/// Creates a new channel and returns its sending and receiving halves.
///
/// The channel starts with an empty message slot and an empty set of
/// registered wakers. The returned `Receiver` owns the shared state, while the
/// `Sender` only refers to it weakly.
pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let receiver = Receiver {
        shared: Arc::new(Shared {
            msg: OnceLock::new(),
            wakers: Mutex::new(Ok(DenseSlotMap::with_key())),
        }),
    };
    let sender = Sender {
        shared: Arc::downgrade(&receiver.shared),
    };
    (sender, receiver)
}
impl<T> Sender<T> {
    /// Delivers `msg` to the receivers and wakes every future waiting for it.
    ///
    /// The sender is consumed. If every receiver is already gone, the message
    /// is discarded.
    ///
    /// # Panics
    ///
    /// Panics if the message slot was already filled, if the waker mutex is
    /// poisoned, or if the wakers had already been taken.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) fn send(self, msg: T) {
        let delivery = Self::send_and_wake(&self.shared, Ok(msg));
        delivery.expect("could not set the message");
    }

    /// Stores `msg` in the shared slot, then wakes all registered wakers.
    ///
    /// Returns `Ok(())` without doing anything when the shared state can no
    /// longer be upgraded (no receiver is left), and `Err(MessageAlreadySet)`
    /// when the slot was already filled.
    fn send_and_wake(
        shared: &Weak<Shared<T>>,
        msg: Result<T, SenderDropped>,
    ) -> Result<(), MessageAlreadySet> {
        let shared = match shared.upgrade() {
            Some(strong) => strong,
            // Nobody is left to receive the message.
            None => return Ok(()),
        };

        // The message must be visible before the wakers are taken, so that a
        // receiver that finds the slot empty while holding the waker lock
        // knows its waker will still be woken.
        if shared.msg.set(msg).is_err() {
            return Err(MessageAlreadySet);
        }

        // Take the whole map while holding the lock, and release the lock
        // before waking anything.
        let mut pending = {
            let mut guard = shared.wakers.lock().expect("poisoned");
            let taken = std::mem::replace(&mut *guard, Err(WakersAlreadyWoken));
            taken.expect("wakers were taken more than once")
        };

        pending.drain().for_each(|(_key, waker)| waker.wake());

        Ok(())
    }

    /// Returns `true` when no strong reference to the shared state remains,
    /// that is, when nothing is left that could receive a message.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) fn is_cancelled(&self) -> bool {
        let live_receivers = self.shared.strong_count();
        live_receivers == 0
    }
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let _ = Self::send_and_wake(&self.shared, Err(SenderDropped));
}
}
impl<T> Receiver<T> {
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) fn borrowed(&self) -> BorrowedReceiverFuture<'_, T> {
BorrowedReceiverFuture {
shared: &self.shared,
waker_key: None,
}
}
pub(crate) fn is_ready(&self) -> bool {
self.shared.msg.get().is_some()
}
}
impl<T: Clone> IntoFuture for Receiver<T> {
type Output = Result<T, SenderDropped>;
type IntoFuture = ReceiverFuture<T>;
fn into_future(self) -> Self::IntoFuture {
ReceiverFuture {
shared: self.shared,
waker_key: None,
}
}
}
impl<'a, T> Future for BorrowedReceiverFuture<'a, T> {
    type Output = Result<&'a T, SenderDropped>;

    /// Polls the shared state, registering or refreshing this future's waker
    /// while the message is not yet available.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self { shared, waker_key } = self.get_mut();
        receiver_fut_poll(*shared, waker_key, cx.waker())
    }
}
impl<T> Drop for BorrowedReceiverFuture<'_, T> {
    /// Hands this future's waker key (if any) to `receiver_fut_drop` for
    /// cleanup.
    fn drop(&mut self) {
        let Self { shared, waker_key } = self;
        receiver_fut_drop(*shared, waker_key);
    }
}
impl<T: Clone> Future for ReceiverFuture<T> {
type Output = Result<T, SenderDropped>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let self_ = self.get_mut();
let poll = receiver_fut_poll(&self_.shared, &mut self_.waker_key, cx.waker());
Poll::Ready(ready!(poll)).map_ok(Clone::clone)
}
}
impl<T> Drop for ReceiverFuture<T> {
fn drop(&mut self) {
receiver_fut_drop(&self.shared, &mut self.waker_key);
}
}
/// Shared poll implementation for both receiver futures.
///
/// Returns `Poll::Ready` with a reference to the message (or
/// `Err(SenderDropped)`) once the message slot is filled. Otherwise it stores
/// `new_waker` in the shared waker map and returns `Poll::Pending`: on the
/// first pending poll a new entry is inserted and its key is written to
/// `waker_key`; on later polls the existing entry is updated.
///
/// # Panics
///
/// Panics if the waker mutex is poisoned, if the wakers were already taken
/// while the message is still unset, or if `waker_key` refers to an entry
/// that is missing from the map.
fn receiver_fut_poll<'a, T>(
    shared: &'a Shared<T>,
    waker_key: &mut Option<WakerKey>,
    new_waker: &Waker,
) -> Poll<Result<&'a T, SenderDropped>> {
    // Fast path: the message is already there, no locking needed.
    if let Some(stored) = shared.msg.get() {
        return Poll::Ready(stored.as_ref().map_err(|_| SenderDropped));
    }

    let mut guard = shared.wakers.lock().expect("poisoned");

    // Look again now that the lock is held: the message may have been set
    // between the first check and acquiring the lock.
    if let Some(stored) = shared.msg.get() {
        return Poll::Ready(stored.as_ref().map_err(|_| SenderDropped));
    }

    let registry = guard.as_mut().expect("wakers were already woken");

    if let Some(existing) = *waker_key {
        // Already registered: refresh the stored waker in place.
        registry
            .get_mut(existing)
            .expect("waker key is missing from map")
            .clone_from(new_waker);
    } else {
        // First pending poll: register and remember the key.
        *waker_key = Some(registry.insert(new_waker.clone()));
    }

    Poll::Pending
}
/// Shared drop implementation for both receiver futures.
///
/// Clears `waker_key` and, if a key was present and the wakers have not been
/// taken yet, removes the matching entry from the shared waker map.
///
/// # Panics
///
/// Panics if the waker mutex is poisoned. In debug builds it also asserts
/// that the key was still present in the map.
fn receiver_fut_drop<T>(shared: &Shared<T>, waker_key: &mut Option<WakerKey>) {
    // The future never registered a waker: nothing to clean up.
    let Some(stale_key) = waker_key.take() else {
        return;
    };

    let mut guard = shared.wakers.lock().expect("poisoned");

    // Once the wakers have been taken there is no entry left to remove.
    let Ok(registry) = guard.as_mut() else {
        return;
    };

    let removed = registry.remove(stale_key);
    debug_assert!(removed.is_some(), "the waker key was not found");
}
// Unit tests for the channel: message delivery, sender/receiver drop order,
// waker bookkeeping, and use from spawned tasks.
#[cfg(test)]
mod test {
#![allow(clippy::unwrap_used)]
use super::*;
use futures::future::FutureExt;
use tor_rtcompat::SpawnExt;
impl<T> Shared<T> {
// Test helper: the number of wakers currently registered, or 0 once the
// wakers have been taken by the sender.
fn count_wakers(&self) -> usize {
self.wakers
.lock()
.expect("poisoned")
.as_ref()
.map(|x| x.len())
.unwrap_or(0)
}
}
// A message sent before receiving is returned by both the borrowed future
// and the owned future.
#[test]
fn standard_usage() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx) = channel();
tx.send(0_u8);
assert_eq!(rx.borrowed().await, Ok(&0));
let (tx, rx) = channel();
tx.send(0_u8);
assert_eq!(rx.await, Ok(0));
});
}
// Dropping both halves without sending, in either order, must not panic.
#[test]
fn immediate_drop() {
let _ = channel::<()>();
let (tx, rx) = channel::<()>();
drop(tx);
drop(rx);
let (tx, rx) = channel::<()>();
drop(rx);
drop(tx);
}
// Dropping the sender without sending makes every receiver resolve to
// `SenderDropped`, whether it was cloned before or after the drop.
#[test]
fn drop_sender() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx_1) = channel::<u8>();
let rx_2 = rx_1.clone();
drop(tx);
let rx_3 = rx_1.clone();
assert_eq!(rx_1.borrowed().await, Err(SenderDropped));
assert_eq!(rx_2.borrowed().await, Err(SenderDropped));
assert_eq!(rx_3.borrowed().await, Err(SenderDropped));
});
}
// A receiver cloned before the send sees the message.
#[test]
fn clone_before_send() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx_1) = channel();
let rx_2 = rx_1.clone();
tx.send(0_u8);
assert_eq!(rx_1.borrowed().await, Ok(&0));
assert_eq!(rx_2.borrowed().await, Ok(&0));
});
}
// A receiver cloned after the send sees the message.
#[test]
fn clone_after_send() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx_1) = channel();
tx.send(0_u8);
let rx_2 = rx_1.clone();
assert_eq!(rx_1.borrowed().await, Ok(&0));
assert_eq!(rx_2.borrowed().await, Ok(&0));
});
}
// A receiver cloned after another receiver already got the message still
// sees the message.
#[test]
fn clone_after_borrowed() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx_1) = channel();
tx.send(0_u8);
assert_eq!(rx_1.borrowed().await, Ok(&0));
let rx_2 = rx_1.clone();
assert_eq!(rx_2.borrowed().await, Ok(&0));
});
}
// Dropping one receiver does not affect delivery to the remaining one.
#[test]
fn drop_one_receiver() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx_1) = channel();
let rx_2 = rx_1.clone();
drop(rx_1);
tx.send(0_u8);
assert_eq!(rx_2.borrowed().await, Ok(&0));
});
}
// Sending after every receiver is gone must not panic.
#[test]
fn drop_all_receivers() {
let (tx, rx_1) = channel();
let rx_2 = rx_1.clone();
drop(rx_1);
drop(rx_2);
tx.send(0_u8);
}
// Waker bookkeeping for the borrowed future: a waker is registered only by
// a pending poll, and it is gone again after the future is dropped or the
// message is sent.
#[test]
fn drop_fut() {
// never polled, nothing sent
let (_tx, rx) = channel::<u8>();
let fut = rx.borrowed();
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// never polled, message already sent
let (tx, rx) = channel();
tx.send(0_u8);
let fut = rx.borrowed();
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// polled while pending, then dropped
let (_tx, rx) = channel::<u8>();
let mut fut = Box::pin(rx.borrowed());
assert_eq!(rx.shared.count_wakers(), 0);
assert_eq!(fut.as_mut().now_or_never(), None);
assert_eq!(rx.shared.count_wakers(), 1);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// polled while pending, then the message is sent before the drop
let (tx, rx) = channel();
let mut fut = Box::pin(rx.borrowed());
assert_eq!(rx.shared.count_wakers(), 0);
assert_eq!(fut.as_mut().now_or_never(), None);
assert_eq!(rx.shared.count_wakers(), 1);
tx.send(0_u8);
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
}
// Same waker bookkeeping checks as `drop_fut`, for the owned future.
#[test]
fn drop_owned_fut() {
// never polled, nothing sent
let (_tx, rx) = channel::<u8>();
let fut = rx.clone().into_future();
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// never polled, message already sent
let (tx, rx) = channel();
tx.send(0_u8);
let fut = rx.clone().into_future();
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// polled while pending, then dropped
let (_tx, rx) = channel::<u8>();
let mut fut = Box::pin(rx.clone().into_future());
assert_eq!(rx.shared.count_wakers(), 0);
assert_eq!(fut.as_mut().now_or_never(), None);
assert_eq!(rx.shared.count_wakers(), 1);
drop(fut);
assert_eq!(rx.shared.count_wakers(), 0);
// polled while pending, then the message is sent before the drop
let (tx, rx) = channel();
let mut fut = Box::pin(rx.clone().into_future());
assert_eq!(rx.shared.count_wakers(), 0);
assert_eq!(fut.as_mut().now_or_never(), None);
assert_eq!(rx.shared.count_wakers(), 1);
tx.send(0_u8);
assert_eq!(rx.shared.count_wakers(), 0);
drop(fut);
}
// `is_ready` is false before the send and true afterwards, for receivers
// cloned before and after the send.
#[test]
fn is_ready_after_send() {
let (tx, rx_1) = channel();
assert!(!rx_1.is_ready());
let rx_2 = rx_1.clone();
assert!(!rx_2.is_ready());
tx.send(0_u8);
assert!(rx_1.is_ready());
assert!(rx_2.is_ready());
let rx_3 = rx_1.clone();
assert!(rx_3.is_ready());
}
// `is_ready` also becomes true when the sender is dropped without sending.
#[test]
fn is_ready_after_drop() {
let (tx, rx_1) = channel::<u8>();
assert!(!rx_1.is_ready());
let rx_2 = rx_1.clone();
assert!(!rx_2.is_ready());
drop(tx);
assert!(rx_1.is_ready());
assert!(rx_2.is_ready());
let rx_3 = rx_1.clone();
assert!(rx_3.is_ready());
}
// `is_cancelled` becomes true only once the last receiver has been dropped.
#[test]
fn is_cancelled() {
let (tx, rx) = channel::<u8>();
assert!(!tx.is_cancelled());
drop(rx);
assert!(tx.is_cancelled());
let (tx, rx_1) = channel::<u8>();
assert!(!tx.is_cancelled());
let rx_2 = rx_1.clone();
drop(rx_1);
assert!(!tx.is_cancelled());
drop(rx_2);
assert!(tx.is_cancelled());
}
// A spawned task receives the message through both future kinds; the send
// is issued after the task has been spawned.
#[test]
fn recv_in_task() {
tor_rtmock::MockRuntime::test_with_various(|rt| async move {
let (tx, rx) = channel();
let join = rt
.spawn_with_handle(async move {
assert_eq!(rx.borrowed().await, Ok(&0));
assert_eq!(rx.await, Ok(0));
})
.unwrap();
tx.send(0_u8);
join.await;
});
}
// Two spawned tasks and the spawning task all receive the same message.
#[test]
fn recv_multiple_in_task() {
tor_rtmock::MockRuntime::test_with_various(|rt| async move {
let (tx, rx) = channel();
let rx_1 = rx.clone();
let rx_2 = rx.clone();
let join_1 = rt
.spawn_with_handle(async move {
assert_eq!(rx_1.borrowed().await, Ok(&0));
})
.unwrap();
let join_2 = rt
.spawn_with_handle(async move {
assert_eq!(rx_2.await, Ok(0));
})
.unwrap();
tx.send(0_u8);
join_1.await;
join_2.await;
assert_eq!(rx.borrowed().await, Ok(&0));
});
}
// The same receiver can yield the message repeatedly.
#[test]
fn recv_multiple_times() {
tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
let (tx, rx) = channel();
tx.send(0_u8);
assert_eq!(rx.borrowed().await, Ok(&0));
assert_eq!(rx.borrowed().await, Ok(&0));
assert_eq!(rx.clone().await, Ok(0));
assert_eq!(rx.await, Ok(0));
});
}
// 100 receiver tasks are spawned one per yield, while a sender task sends
// after yielding 20 times; every receiver task must get the message.
#[test]
fn stress() {
tor_rtmock::MockRuntime::test_with_various(|rt| async move {
let (tx, rx) = channel();
rt.spawn(async move {
for _ in 0..20 {
tor_rtcompat::task::yield_now().await;
}
tx.send(0_u8);
})
.unwrap();
let mut joins = vec![];
for _ in 0..100 {
let rx_clone = rx.clone();
let join = rt
.spawn_with_handle(async move { rx_clone.borrowed().await.cloned() })
.unwrap();
joins.push(join);
tor_rtcompat::task::yield_now().await;
}
for join in joins {
assert!(matches!(join.await, Ok(0)));
}
});
}
}