use futures::FutureExt;
use std::any::Any;
use std::fmt;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use tokio::task::JoinSet;
/// Reason a single branch of an ordered parallel run did not produce a value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum OrderedParallelError {
    /// The branch panicked; carries a textual rendering of the panic payload.
    Panicked(String),
    /// The branch's task could not be joined; carries the join error's text.
    JoinFailed(String),
}

impl fmt::Display for OrderedParallelError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OrderedParallelError::Panicked(detail) => {
                write!(formatter, "parallel branch panicked: {detail}")
            }
            OrderedParallelError::JoinFailed(detail) => {
                write!(formatter, "parallel branch join failed: {detail}")
            }
        }
    }
}

impl std::error::Error for OrderedParallelError {}
/// Outcome of one branch, tagged with the position of the input item that produced it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct OrderedParallelResult<T> {
    /// Zero-based position of the originating item in the input vector
    /// (`usize::MAX` when the runner could not recover it after a join failure).
    pub index: usize,
    /// The branch's value, or the reason it failed.
    pub output: Result<T, OrderedParallelError>,
}
/// Runs `run(index, item)` for every item with no concurrency cap and returns the
/// results sorted by input index.
///
/// Test-only convenience over [`run_ordered_parallel_with_limit`]. The limit is the item
/// count, floored at one so that an empty input still passes a non-zero limit.
#[cfg(test)]
pub(crate) async fn run_ordered_parallel<I, F, Fut, T>(
    items: Vec<I>,
    run: F,
) -> Vec<OrderedParallelResult<T>>
where
    I: Send + 'static,
    F: Fn(usize, I) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // One slot per item means every branch is started immediately.
    let limit = std::cmp::max(1, items.len());
    run_ordered_parallel_with_limit(items, limit, run).await
}
pub(crate) async fn run_ordered_parallel_with_limit<I, F, Fut, T>(
items: Vec<I>,
max_concurrency: usize,
run: F,
) -> Vec<OrderedParallelResult<T>>
where
I: Send + 'static,
F: Fn(usize, I) -> Fut + Clone + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let mut join_set = JoinSet::new();
let mut pending = items.into_iter().enumerate();
let max_concurrency = max_concurrency.max(1);
let mut active = 0usize;
while active < max_concurrency {
let Some((index, item)) = pending.next() else {
break;
};
let run = run.clone();
join_set.spawn(async move {
let output = AssertUnwindSafe(run(index, item))
.catch_unwind()
.await
.map_err(|payload| {
OrderedParallelError::Panicked(panic_payload_to_string(payload))
});
OrderedParallelResult { index, output }
});
active += 1;
}
let mut results = Vec::new();
while active > 0 {
if let Some(result) = join_set.join_next().await {
active -= 1;
match result {
Ok(result) => results.push(result),
Err(error) => results.push(OrderedParallelResult {
index: usize::MAX,
output: Err(OrderedParallelError::JoinFailed(error.to_string())),
}),
}
}
while active < max_concurrency {
let Some((index, item)) = pending.next() else {
break;
};
let run = run.clone();
join_set.spawn(async move {
let output = AssertUnwindSafe(run(index, item))
.catch_unwind()
.await
.map_err(|payload| {
OrderedParallelError::Panicked(panic_payload_to_string(payload))
});
OrderedParallelResult { index, output }
});
active += 1;
}
}
results.sort_by_key(|result| result.index);
results
}
/// Renders a panic payload as text.
///
/// `String` and `&str` payloads are returned verbatim. Any other payload type becomes the
/// fixed placeholder `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn Any + Send>) -> String {
    match payload.downcast::<String>() {
        // Owned message: move it out of the box instead of cloning.
        Ok(owned) => *owned,
        Err(other) => other
            .downcast_ref::<&str>()
            .map_or_else(|| "unknown panic payload".to_string(), |text| (*text).to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::Barrier;
    /// The 120 ms branch finishes after the 10 ms branch. The results must still come back
    /// in input order, not completion order.
    #[tokio::test]
    async fn preserves_input_order_when_completion_order_differs() {
        // Two-party barrier: neither branch starts sleeping until both are running, which
        // forces the branches to overlap.
        let barrier = std::sync::Arc::new(Barrier::new(2));
        let results = run_ordered_parallel(vec![120_u64, 10], {
            let barrier = std::sync::Arc::clone(&barrier);
            move |_index, delay_ms| {
                // Each branch gets its own handle to the shared barrier.
                let barrier = std::sync::Arc::clone(&barrier);
                async move {
                    barrier.wait().await;
                    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
                    delay_ms
                }
            }
        })
        .await;
        let values = results
            .into_iter()
            .map(|result| result.output.unwrap())
            .collect::<Vec<_>>();
        assert_eq!(values, vec![120, 10]);
    }
    /// A panic inside one branch's future is reported for that branch only. The sibling
    /// branch's value must still be returned, at its own index.
    #[tokio::test]
    async fn catches_branch_panics_without_dropping_other_results() {
        // The panic is raised inside the returned async block, i.e. while the future is polled.
        let results = run_ordered_parallel(vec![0_u8, 1], |_index, value| async move {
            assert!(value != 0, "boom");
            value
        })
        .await;
        assert_eq!(results.len(), 2);
        assert!(matches!(
            results[0].output,
            Err(OrderedParallelError::Panicked(_))
        ));
        assert_eq!(results[1].output, Ok(1));
    }
    /// With four items and a limit of 2, the peak number of simultaneously running branches
    /// must equal exactly 2: the limit is never exceeded, and it is actually reached.
    #[tokio::test]
    async fn respects_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        // `active` counts branches currently inside their body; `max_active` records its peak.
        let active = Arc::new(AtomicUsize::new(0));
        let max_active = Arc::new(AtomicUsize::new(0));
        let results = run_ordered_parallel_with_limit(vec![1_u8, 2, 3, 4], 2, {
            let active = Arc::clone(&active);
            let max_active = Arc::clone(&max_active);
            move |_index, value| {
                let active = Arc::clone(&active);
                let max_active = Arc::clone(&max_active);
                async move {
                    // Enter: bump the live count and fold it into the observed peak.
                    let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                    max_active.fetch_max(now, Ordering::SeqCst);
                    // Stay "active" long enough for the other in-flight branch to overlap.
                    tokio::time::sleep(std::time::Duration::from_millis(30)).await;
                    // Leave: release the slot.
                    active.fetch_sub(1, Ordering::SeqCst);
                    value
                }
            }
        })
        .await;
        assert_eq!(max_active.load(Ordering::SeqCst), 2);
        // Limiting concurrency must not disturb input ordering.
        assert_eq!(
            results
                .into_iter()
                .map(|result| result.output.unwrap())
                .collect::<Vec<_>>(),
            vec![1, 2, 3, 4]
        );
    }
}