use std::future::{poll_fn, Future};
use std::task::Poll;
use futures_util::FutureExt;
/// Drives a fixed list of futures concurrently and yields their outputs in the
/// order the futures were supplied, regardless of the order in which they finish.
///
/// A future that completes ahead of its turn has its output buffered until every
/// earlier output has been handed out by [`next`](Self::next).
pub struct PollAllPreservingOrder<V, F> {
    /// One slot per input future. `Some` holds an output that has completed but has
    /// not yet been returned by [`next`](Self::next).
    values: Vec<Option<V>>,
    /// Futures that have not completed yet, tagged with their original position.
    pending_futures: Vec<(usize, F)>,
    /// Position of the next output to hand out; equals `values.len()` when exhausted.
    current_index: usize,
}

impl<V, F: Future<Output = V> + Unpin> PollAllPreservingOrder<V, F> {
    /// Creates a poller over `futures`.
    ///
    /// Nothing is polled here. The futures only make progress while a future
    /// returned by [`next`](Self::next) is being polled.
    pub fn new(futures: Vec<F>) -> Self {
        Self {
            // `Option<V>` is not `Clone` for arbitrary `V`, so fill the slots with
            // `repeat_with` rather than `vec![None; n]`.
            values: std::iter::repeat_with(|| None).take(futures.len()).collect(),
            pending_futures: futures.into_iter().enumerate().collect(),
            current_index: 0,
        }
    }

    /// Returns the output of the next future in input order.
    ///
    /// Returns `None` once every output has been yielded, and on every call after that.
    ///
    /// Each poll drives every still-pending future with the caller's context, so
    /// later futures keep making progress while an earlier one is awaited.
    /// Completed outputs are stored in `self`, so dropping the returned future
    /// before it resolves does not lose any results.
    pub async fn next(&mut self) -> Option<V> {
        poll_fn(move |cx| {
            if self.current_index == self.values.len() {
                return Poll::Ready(None);
            }
            // Poll everything still outstanding. A finished future has its output
            // stored in its slot and is removed, so it is never polled again.
            self.pending_futures
                .retain_mut(|(i, f)| match f.poll_unpin(cx) {
                    Poll::Pending => true,
                    Poll::Ready(e) => {
                        self.values[*i] = Some(e);
                        false
                    }
                });
            // Only the in-order slot may be released. Later buffered outputs wait for it.
            if let Some(next_value) = self.values[self.current_index].take() {
                self.current_index += 1;
                Poll::Ready(Some(next_value))
            } else {
                Poll::Pending
            }
        })
        .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::ready;

    /// Already-complete futures come back in input order, and the poller keeps
    /// returning `None` once exhausted.
    #[tokio::test]
    async fn test_poll_all_ready() {
        let futures = vec![ready(1), ready(2), ready(3), ready(4), ready(5), ready(6)];
        let mut poll_all = PollAllPreservingOrder::new(futures);
        for expected in 1..=6 {
            assert_eq!(poll_all.next().await, Some(expected));
        }
        assert_eq!(poll_all.next().await, None);
        assert_eq!(poll_all.next().await, None);
    }

    /// An empty input is exhausted on the first call.
    #[tokio::test]
    async fn test_poll_empty() {
        let mut poll_all =
            PollAllPreservingOrder::new(Vec::<futures_util::future::Ready<i32>>::new());
        assert_eq!(poll_all.next().await, None);
    }

    /// Futures that need different numbers of messages to finish still yield their
    /// outputs in input order.
    #[tokio::test]
    async fn test_poll_some_pending() {
        /// Resolves to `value` once a message `>= threshold` arrives on `rx`,
        /// yielding to the scheduler between messages.
        async fn yield_until_threshold_reached(
            mut rx: tokio::sync::broadcast::Receiver<i32>,
            threshold: i32,
            value: i32,
        ) -> i32 {
            loop {
                let msg = rx.recv().await.unwrap();
                if msg >= threshold {
                    return value;
                }
                tokio::task::yield_now().await;
            }
        }

        let (tx, _) = tokio::sync::broadcast::channel(100);
        // Thresholds are deliberately out of order so completion order differs
        // from input order.
        let receivers = vec![
            yield_until_threshold_reached(tx.subscribe(), 6, 1),
            yield_until_threshold_reached(tx.subscribe(), 3, 2),
            yield_until_threshold_reached(tx.subscribe(), 5, 3),
            yield_until_threshold_reached(tx.subscribe(), 1, 4),
            yield_until_threshold_reached(tx.subscribe(), 2, 5),
            yield_until_threshold_reached(tx.subscribe(), 4, 6),
        ];
        // Every receiver subscribed above, so each one sees all ten messages.
        for x in 0..10 {
            tx.send(x).unwrap();
        }

        let mut poll_all =
            PollAllPreservingOrder::new(receivers.into_iter().map(Box::pin).collect());
        for expected in 1..=6 {
            assert_eq!(poll_all.next().await, Some(expected));
        }
        assert_eq!(poll_all.next().await, None);
    }
}