use futures::stream::{FuturesUnordered, Stream};
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[pin_project(project = ThenConcurrentProj)]
pub struct ThenConcurrent<St, Fut: Future, F> {
#[pin]
stream: St,
#[pin]
futures: FuturesUnordered<Fut>,
fun: F,
limit: Option<usize>,
}
impl<St, Fut, F, T> Stream for ThenConcurrent<St, Fut, F>
where
St: Stream,
Fut: Future<Output = T>,
F: FnMut(St::Item) -> Fut,
{
type Item = T;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let ThenConcurrentProj {
mut stream,
mut futures,
fun,
limit,
} = self.project();
if limit.as_ref().is_none_or(|&l| futures.len() < l) {
loop {
match stream.as_mut().poll_next(cx) {
Poll::Ready(Some(n)) => {
futures.push(fun(n));
if limit.as_ref().is_some_and(|&l| futures.len() >= l) {
break;
}
}
Poll::Ready(None) => {
if futures.is_empty() {
return Poll::Ready(None);
}
break;
}
Poll::Pending => {
if futures.is_empty() {
return Poll::Pending;
}
break;
}
}
}
}
futures.as_mut().poll_next(cx)
}
}
pub trait StreamThenConcurrentExt: Stream {
fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
where
Self: Sized,
Fut: Future,
F: FnMut(Self::Item) -> Fut,
L: Into<Option<usize>>;
}
impl<S: Stream> StreamThenConcurrentExt for S {
fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
where
Self: Sized,
Fut: Future,
F: FnMut(Self::Item) -> Fut,
L: Into<Option<usize>>,
{
ThenConcurrent {
stream: self,
futures: FuturesUnordered::new(),
fun: f,
limit: limit.into().filter(|&l| l > 0),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{channel::mpsc::unbounded, StreamExt};
    use std::time::Duration;

    /// Resolves to `x`, after sleeping `x` milliseconds when `x` is non-zero.
    async fn delayed(x: u64) -> u64 {
        if x != 0 {
            tokio::time::sleep(Duration::from_millis(x)).await;
        }
        x
    }

    #[tokio::test]
    async fn no_items() {
        // Explicit future type so the item type does not depend on
        // never-type fallback inference.
        let stream = futures::stream::iter(Vec::<u64>::new()).then_concurrent(
            |_| -> std::future::Ready<u64> { panic!("must not be called") },
            None,
        );
        let items: Vec<u64> = stream.collect().await;
        assert!(items.is_empty());
    }

    #[tokio::test]
    async fn paused_stream() {
        let (mut tx, rx) = unbounded::<u64>();
        let mut stream = rx.then_concurrent(delayed, None);
        let first_item = stream.next();
        tx.start_send(0).unwrap();
        assert_eq!(first_item.await, Some(0));
        let second_item = stream.next();
        tx.start_send(5).unwrap();
        assert_eq!(second_item.await, Some(5));
    }

    #[tokio::test]
    async fn fast_items() {
        let stream = futures::stream::iter(vec![0u64, 0, 7]).then_concurrent(delayed, None);
        let actual_packets = stream.collect::<Vec<u64>>().await;
        assert_eq!(actual_packets, vec![0, 0, 7]);
    }

    #[tokio::test]
    async fn reorder_items() {
        let stream = futures::stream::iter(vec![10u64, 5, 7]).then_concurrent(delayed, None);
        let actual_packets = stream.collect::<Vec<u64>>().await;
        assert_eq!(actual_packets, vec![5, 7, 10]);
    }

    #[tokio::test]
    async fn limit_of_one_preserves_source_order() {
        // With at most one future in flight, completion order equals source
        // order regardless of the individual delays.
        let stream = futures::stream::iter(vec![10u64, 5, 7]).then_concurrent(delayed, 1usize);
        let actual_packets = stream.collect::<Vec<u64>>().await;
        assert_eq!(actual_packets, vec![10, 5, 7]);
    }
}