use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Future returned by [`join`].
///
/// It drives a fixed set of futures and resolves to a `Vec` holding their
/// outputs in the same order the futures were supplied.
pub struct LimitedJoin<Fut: Future> {
    /// One slot per input future, in input order. The slots are pinned on the
    /// heap so the futures they hold can be polled in place.
    inner: Pin<Box<[MaybeCompleted<Fut>]>>,
    /// Limit used by `poll` to decide how many unfinished futures it drives
    /// at once.
    concurrency: usize,
}
pub fn join<Fut>(futures: impl IntoIterator<Item = Fut>, concurrency: usize) -> LimitedJoin<Fut>
where
Fut: Future,
{
let futures = futures
.into_iter()
.map(MaybeCompleted::InProgress)
.collect::<Vec<_>>()
.into_boxed_slice();
LimitedJoin {
inner: futures.into(),
concurrency,
}
}
impl<Fut> Future for LimitedJoin<Fut>
where
Fut: Future,
{
type Output = Vec<Fut::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { Pin::get_unchecked_mut(self) };
let states = unsafe { Pin::get_unchecked_mut(this.inner.as_mut()) };
let mut remaining = states.iter().filter(|state| state.is_in_progress()).count();
let mut to_poll = this.concurrency.min(remaining);
let mut polled = 0;
let mut index = 0;
while polled < to_poll && index < states.len() {
let state = &mut states[index];
if !state.is_in_progress() {
index += 1;
continue;
}
let res = unsafe { Pin::new_unchecked(state).poll(cx) };
if let Poll::Ready(output) = res {
states[index] = MaybeCompleted::Completed(output);
remaining -= 1;
to_poll += 1;
}
polled += 1;
index += 1;
}
if remaining == 0 {
Poll::Ready(states.iter_mut().map(|state| state.take()).collect())
} else {
Poll::Pending
}
}
}
/// State of a single future tracked by the join.
enum MaybeCompleted<Fut>
where
    Fut: Future,
{
    /// The future has not produced its output yet.
    InProgress(Fut),
    /// The future finished; its output is kept here until it is taken.
    Completed(Fut::Output),
    /// Placeholder left behind once the output has been taken.
    Drained,
}
impl<Fut: Future> MaybeCompleted<Fut> {
fn is_in_progress(&self) -> bool {
matches!(self, Self::InProgress { .. })
}
fn take(&mut self) -> Fut::Output {
match std::mem::replace(self, MaybeCompleted::Drained) {
Self::Completed(output) => output,
Self::InProgress(_) => panic!("attempt to get output of incomplete future"),
Self::Drained => panic!("attempt to get output of drained future"),
}
}
unsafe fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Fut::Output> {
let this = self.as_mut();
let this = this.get_unchecked_mut();
match this {
Self::InProgress(future) => Pin::new_unchecked(future).poll(cx),
_ => unreachable!("attempted to poll a complete or drained future"),
}
}
}
// Tests for `join`. They rely on real tokio timers, so they are sensitive to
// timing and to the exact order in which the joined futures are polled.
#[cfg(test)]
mod tests {
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::time::sleep;
use super::*;
// Limit (10) is larger than the number of futures (2): the joined future is
// expected to finish within the 30 ms timeout.
#[tokio::test]
async fn test_not_above_limit() {
let joined = join(
[
sleep(Duration::from_millis(10)),
sleep(Duration::from_millis(20)),
],
10,
);
let timeout = tokio::time::timeout(Duration::from_millis(30), joined);
timeout.await.expect("future timed out before completion");
}
// Limit of 1: each future asserts on entry whether the shared flag has been
// set. The first expects `false`; the second expects `true`, i.e. it must not
// start before the first one has stored the flag.
#[tokio::test]
async fn test_above_limit_no_concurrency() {
let completed = Arc::new(AtomicBool::new(false));
let run = |expected: bool| {
let completed = completed.clone();
async move {
let loaded = completed.load(Ordering::SeqCst);
assert_eq!(loaded, expected);
sleep(Duration::from_millis(10)).await;
completed.store(true, Ordering::SeqCst);
}
};
join([run(false), run(true)], 1).await;
}
// Four futures with a limit of 2: each sends "s<id>" when it starts and
// "e<id>" when it ends, and the recorded order is checked below.
#[tokio::test]
async fn test_above_limit() {
let (tx, rx) = std::sync::mpsc::channel();
let record = |id: usize, millis: u64| {
let tx = tx.clone();
async move {
tx.send(format!("s{id}")).unwrap();
sleep(Duration::from_millis(millis)).await;
tx.send(format!("e{id}")).unwrap();
}
};
join(
[record(0, 10), record(1, 25), record(2, 50), record(3, 50)],
2,
)
.await;
// The iterator is advanced exactly eight times, once per sent message. The
// original `tx` is still alive here, so pulling a ninth item would block.
let mut order = rx.into_iter();
// Futures 0 and 1 start first; 2 starts only after 0 ends, and 3 only after
// 1 ends.
assert_eq!("s0", order.next().unwrap());
assert_eq!("s1", order.next().unwrap());
assert_eq!("e0", order.next().unwrap());
assert_eq!("s2", order.next().unwrap());
assert_eq!("e1", order.next().unwrap());
assert_eq!("s3", order.next().unwrap());
assert_eq!("e2", order.next().unwrap());
assert_eq!("e3", order.next().unwrap());
}
}