use std::future::Future;
use futures::stream::{FuturesUnordered, StreamExt};
/// Configuration handle for running a batch of fallible async tasks with a cap
/// on how many of them are in flight at the same time.
///
/// The type is a plain value (`Copy`), so it can be passed around freely.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DagScheduler {
    /// Upper bound on concurrently running task futures; the `new`
    /// constructor clamps this to at least 1.
    max_in_flight: usize,
}
impl DagScheduler {
    /// Creates a scheduler that keeps at most `max_in_flight` tasks running
    /// at once.
    ///
    /// A value of `0` is raised to `1` so that the scheduler can always make
    /// progress.
    pub fn new(max_in_flight: usize) -> Self {
        Self {
            max_in_flight: max_in_flight.max(1),
        }
    }

    /// Returns the effective concurrency limit (always at least 1 for
    /// schedulers built through [`DagScheduler::new`]).
    pub fn max_in_flight(self) -> usize {
        self.max_in_flight
    }

    /// Runs one task per input item, with at most `max_in_flight` task
    /// futures pending at any moment, and returns the outputs in input order.
    ///
    /// `task_builder` is invoked lazily: an item is only pulled from `inputs`
    /// and turned into a future when a concurrency slot is free. The futures
    /// are driven inside this call through a `FuturesUnordered`; nothing is
    /// spawned onto an executor.
    ///
    /// Iteration over `inputs` stops at the first `None` it yields, exactly
    /// like a `for` loop would.
    ///
    /// # Errors
    ///
    /// Returns the first error *to complete* (not necessarily the one with
    /// the lowest input index). On error the remaining in-flight futures are
    /// dropped and inputs that were not yet pulled are never built.
    pub async fn run_bounded<I, F, Fut, T, E>(
        &self,
        inputs: I,
        mut task_builder: F,
    ) -> Result<Vec<T>, E>
    where
        I: IntoIterator,
        F: FnMut(I::Item) -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        // `next()` is called again on every refill pass, including after the
        // iterator has already returned `None`. Fusing guarantees that an
        // exhausted input stays exhausted even for non-fused iterators.
        let mut indexed_inputs = inputs.into_iter().enumerate().fuse();
        let mut in_flight = FuturesUnordered::new();
        let mut results: Vec<Option<T>> = Vec::new();

        loop {
            // Top up the in-flight set until the limit is hit or inputs run out.
            while in_flight.len() < self.max_in_flight {
                let Some((index, item)) = indexed_inputs.next() else {
                    break;
                };
                // `enumerate` hands out indices 0, 1, 2, ... in order, so the
                // slot for `index` is always the next one to append.
                results.push(None);
                let task = task_builder(item);
                // Tag the output with its index so completion order does not
                // affect result order.
                in_flight.push(async move { (index, task.await) });
            }

            // `None` here means nothing is pending and nothing is left to start.
            let Some((index, output)) = in_flight.next().await else {
                break;
            };
            results[index] = Some(output?);
        }

        Ok(results
            .into_iter()
            .map(|entry| entry.expect("scheduler result slot must be filled"))
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use tokio::sync::Mutex;
    use tokio::time::sleep;

    use super::DagScheduler;

    /// Runs four tasks that each sleep for 20 ms under the given concurrency
    /// limit and reports how long the whole batch took.
    async fn time_four_sleepers(limit: usize) -> Duration {
        let scheduler = DagScheduler::new(limit);
        let started = Instant::now();
        let _ = scheduler
            .run_bounded(0..4usize, |_| async {
                sleep(Duration::from_millis(20)).await;
                Ok::<(), ()>(())
            })
            .await
            .expect("scheduler should run all tasks");
        started.elapsed()
    }

    /// With a limit of 2, no more than two tasks may ever be active at once,
    /// and outputs must come back in input order.
    #[tokio::test]
    async fn respects_max_in_flight_limit() {
        let scheduler = DagScheduler::new(2);
        let running = Arc::new(Mutex::new(0usize));
        let high_water = Arc::new(Mutex::new(0usize));

        // Each task bumps the live counter, records the high-water mark,
        // sleeps briefly, then decrements the counter again.
        let build_task = {
            let running = Arc::clone(&running);
            let high_water = Arc::clone(&high_water);
            move |item: usize| {
                let running = Arc::clone(&running);
                let high_water = Arc::clone(&high_water);
                async move {
                    {
                        let mut now_running = running.lock().await;
                        *now_running += 1;
                        let mut seen_max = high_water.lock().await;
                        *seen_max = (*seen_max).max(*now_running);
                    }
                    sleep(Duration::from_millis(10)).await;
                    {
                        let mut now_running = running.lock().await;
                        *now_running = now_running.saturating_sub(1);
                    }
                    Ok::<usize, ()>(item * 2)
                }
            }
        };

        let outputs = scheduler
            .run_bounded(0..8usize, build_task)
            .await
            .expect("bounded scheduling should succeed");

        assert_eq!(outputs, vec![0, 2, 4, 6, 8, 10, 12, 14]);
        assert!(*high_water.lock().await <= 2);
    }

    /// A limit of 4 must finish the same sleeping workload faster than a
    /// limit of 1.
    #[tokio::test]
    async fn runs_concurrently_when_limit_above_one() {
        let serial_elapsed = time_four_sleepers(1).await;
        let parallel_elapsed = time_four_sleepers(4).await;

        assert!(
            parallel_elapsed < serial_elapsed,
            "expected parallel scheduler to finish faster (parallel={parallel_elapsed:?}, serial={serial_elapsed:?})"
        );
    }
}