#![allow(unreachable_pub)]
use std::{
collections::{HashMap, hash_map},
hash::Hash,
pin::Pin,
sync::Arc,
task::Poll,
};
use futures::future::FutureExt;
use futures::{
Future,
channel::mpsc::{UnboundedReceiver, UnboundedSender},
};
use pin_project::pin_project;
/// A waker that, when woken, sends a copy of `key` over `sender`.
///
/// `KeyedFuturesUnordered` creates one of these for each poll of a stored
/// future, so waking that future queues its key to be polled again.
struct KeyedWaker<K> {
    /// Key of the future this waker was created for.
    key: K,
    /// Notification channel on which `key` is sent when this waker is woken.
    sender: UnboundedSender<K>,
}
impl<K> std::task::Wake for KeyedWaker<K>
where
    K: Clone,
{
    /// Queues a clone of `key` on the notification channel.
    ///
    /// If the receiving side has gone away (the channel is disconnected) the
    /// notification is simply dropped: there is nobody left to poll the
    /// future. Any other send failure is unexpected and is logged as a bug.
    fn wake(self: Arc<Self>) {
        let Err(e) = self.sender.unbounded_send(self.key.clone()) else {
            return;
        };
        if !e.is_disconnected() {
            tracing::error!("Bug: Unexpected send error: {e:?}");
        }
    }
}
/// A collection of futures, each stored under a unique key of type `K`.
///
/// As a [`futures::Stream`], it yields `(key, output)` pairs as the stored
/// futures complete. A future is only polled when its key arrives on the
/// internal notification channel: this happens once on insertion and again
/// each time the waker handed to that future is woken.
#[derive(Debug)]
#[pin_project]
pub struct KeyedFuturesUnordered<K, F>
where
    F: Future,
{
    /// Receives keys of futures that should be (re)polled.
    #[pin]
    notification_receiver: UnboundedReceiver<K>,
    /// Sending half of the notification channel. It is cloned into every waker
    /// given to a stored future, and used directly when a future is inserted.
    notification_sender: UnboundedSender<K>,
    /// Futures that have neither completed nor been removed, by key.
    futures: HashMap<K, F>,
}
impl<K, F> KeyedFuturesUnordered<K, F>
where
F: Future,
K: Eq + Hash + Clone,
{
pub fn new() -> Self {
let (send, recv) = futures::channel::mpsc::unbounded();
Self {
notification_sender: send,
notification_receiver: recv,
futures: Default::default(),
}
}
pub fn try_insert(&mut self, key: K, fut: F) -> Result<(), KeyAlreadyInsertedError<K, F>> {
let hash_map::Entry::Vacant(v) = self.futures.entry(key.clone()) else {
return Err(KeyAlreadyInsertedError { key, fut });
};
v.insert(fut);
self.notification_sender
.unbounded_send(key)
.expect("Unbounded send unexpectedly failed");
Ok(())
}
pub fn remove(&mut self, key: &K) -> Option<(K, F)> {
self.futures.remove_entry(key)
}
#[allow(dead_code)]
pub fn get<'a>(&'a self, key: &K) -> Option<&'a F> {
self.futures.get(key)
}
#[allow(dead_code)]
pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut F> {
self.futures.get_mut(key)
}
}
impl<K, F> futures::Stream for KeyedFuturesUnordered<K, F>
where
    F: Future + Unpin,
    K: Clone + Hash + Eq + Send + Sync + 'static,
{
    type Item = (K, F::Output);

    /// Polls the futures that have been notified, yielding the first to finish.
    ///
    /// Returns `Poll::Ready(None)` when no futures are stored. Otherwise the
    /// method drains keys from the notification channel, in order. For each
    /// key, it polls the matching future with a waker that re-queues that key
    /// when woken. Keys with no matching future (removed, or already
    /// completed) are skipped.
    ///
    /// At most `len` futures are polled per call, where `len` is the number of
    /// futures stored when the call starts. If that budget runs out before any
    /// future completes, the task is woken and `Poll::Pending` is returned.
    /// This lets the executor make progress on other tasks rather than
    /// spinning on futures that keep waking themselves.
    ///
    /// # Panics
    ///
    /// Panics if the notification channel ends. This should not happen, since
    /// `self` owns a sender for it.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if self.futures.is_empty() {
            return Poll::Ready(None);
        }
        let mut self_ = self.project();
        // A future that wakes itself from inside `poll` re-queues its key
        // immediately. Without a bound this loop would re-poll it forever and
        // never return control to the executor.
        let poll_budget = self_.futures.len();
        let mut polled = 0_usize;
        loop {
            let key = match self_.notification_receiver.as_mut().poll_next(cx) {
                Poll::Ready(key) => key.expect("Unexpected end of stream"),
                Poll::Pending => {
                    return Poll::Pending;
                }
            };
            let Some(fut) = self_.futures.get_mut(&key) else {
                // Stale notification for a removed or already-completed future.
                continue;
            };
            let waker = std::task::Waker::from(Arc::new(KeyedWaker {
                key: key.clone(),
                sender: self_.notification_sender.clone(),
            }));
            match fut.poll_unpin(&mut std::task::Context::from_waker(&waker)) {
                Poll::Ready(o) => {
                    self_.futures.remove(&key);
                    return Poll::Ready(Some((key, o)));
                }
                Poll::Pending => {
                    polled += 1;
                    if polled >= poll_budget {
                        // Yield to the executor. Notifications may still be
                        // queued, so ask to be polled again; otherwise nothing
                        // would wake this task for keys that are already queued.
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}
/// Error returned by [`KeyedFuturesUnordered::try_insert`] when a future is
/// already stored under the requested key.
///
/// Nothing was inserted; the caller's key and future are handed back.
#[derive(Debug, thiserror::Error)]
#[allow(clippy::exhaustive_structs)]
pub struct KeyAlreadyInsertedError<K, F> {
    /// The key the caller tried to insert under.
    #[allow(dead_code)]
    pub key: K,
    /// The future the caller tried to insert.
    #[allow(dead_code)]
    pub fut: F,
}
#[cfg(test)]
mod tests {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use std::task::Waker;

    use futures::{StreamExt as _, executor::block_on, future::poll_fn};
    use oneshot_fused_workaround as oneshot;
    use tor_rtmock::MockRuntime;

    use super::*;

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Key(u64);

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Value(u64);

    /// A manually controlled future.
    ///
    /// It stays pending (remembering the most recent waker) until it is
    /// constructed ready or `make_ready` is called, and then yields `value`.
    #[derive(Debug, Clone)]
    struct ValueFut<V> {
        value: Option<V>,
        ready: bool,
        waker: Option<Waker>,
    }

    // Equality ignores `waker`, which is not meaningfully comparable.
    impl<V> std::cmp::PartialEq for ValueFut<V>
    where
        V: std::cmp::PartialEq,
    {
        fn eq(&self, other: &Self) -> bool {
            self.value == other.value && self.ready == other.ready
        }
    }

    impl<V> std::cmp::Eq for ValueFut<V> where V: std::cmp::Eq {}

    impl<V> ValueFut<V> {
        /// A future that completes with `value` on its first poll.
        fn ready(value: V) -> Self {
            Self {
                value: Some(value),
                ready: true,
                waker: None,
            }
        }

        /// A future that stays pending until `make_ready` is called.
        fn pending(value: V) -> Self {
            Self {
                value: Some(value),
                ready: false,
                waker: None,
            }
        }

        /// Marks the future ready and wakes the last waker it was polled with.
        fn make_ready(&mut self) {
            self.ready = true;
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }
        }
    }

    impl<V> Future for ValueFut<V>
    where
        V: Unpin,
    {
        type Output = V;

        fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
            if !self.ready {
                self.waker.replace(cx.waker().clone());
                Poll::Pending
            } else {
                Poll::Ready(self.value.take().expect("Polled future after it was ready"))
            }
        }
    }

    #[test]
    fn test_empty() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::<Key, ValueFut<Value>>::new();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::pending(Value(0))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::pending(Value(0))));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::ready(Value(1))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::ready(Value(1))));
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(1))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_then_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            let (send, recv) = oneshot::channel::<Value>();
            kfu.try_insert(Key(0), recv).unwrap();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);
            assert!(kfu.get(&Key(0)).is_some());
            assert!(kfu.get_mut(&Key(0)).is_some());
            send.send(Value(1)).unwrap();
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Ok(Value(1)))))
            );
            assert!(kfu.get(&Key(0)).is_none());
            assert!(kfu.get_mut(&Key(0)).is_none());
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_pending() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::pending(Value(0))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_and_reuse_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            kfu.try_insert(Key(0), ValueFut::ready(Value(2))).unwrap();
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    /// A wake from a removed future must not complete a new future that was
    /// inserted under the same key.
    #[test]
    fn test_remove_and_reuse_pending_then_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(1))).unwrap();
            // Poll once so the first future captures a waker bound to `Key(0)`;
            // otherwise the `make_ready` below would not send any notification.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);
            let (_key, mut removed_value) = kfu.remove(&Key(0)).unwrap();
            assert!(removed_value.waker.is_some());
            kfu.try_insert(Key(0), ValueFut::pending(Value(2))).unwrap();
            // Queues a stale `Key(0)` notification via the removed future's waker.
            removed_value.make_ready();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);
            kfu.get_mut(&Key(0)).unwrap().make_ready();
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_async() {
        MockRuntime::test_with_various(|rt| async move {
            let mut kfu = KeyedFuturesUnordered::new();
            for i in 0..10 {
                let (send, recv) = oneshot::channel();
                kfu.try_insert(Key(i), recv).unwrap();
                rt.spawn_identified(format!("sender-{i}"), async move {
                    send.send(Value(i)).unwrap();
                });
            }
            let values = kfu.collect::<Vec<_>>().await;
            let mut values = values
                .into_iter()
                .map(|(k, v)| (k, v.unwrap()))
                .collect::<Vec<_>>();
            values.sort();
            let expected_values = (0..10).map(|i| (Key(i), Value(i))).collect::<Vec<_>>();
            assert_eq!(values, expected_values);
        });
    }
}