use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{AcquireError, Semaphore};
use tokio::task::{JoinError, JoinSet};
/// Error type returned by the constrained-parallelism helpers in this file.
///
/// `E` is the error type produced by the caller-supplied futures; it is carried
/// unchanged in the [`ParutilsError::Task`] variant.
#[derive(Debug, Error)]
pub enum ParutilsError<E: Send + Sync + 'static> {
/// A spawned task could not be joined. The `#[from]` attribute lets `?` convert a
/// `JoinError` into this variant; the tests in this file observe it when a task panics.
#[error(transparent)]
Join(#[from] JoinError),
/// Acquiring a semaphore permit failed. The `#[from]` attribute lets `?` convert an
/// `AcquireError` into this variant.
#[error(transparent)]
Acquire(#[from] AcquireError),
/// One of the caller-supplied futures resolved to `Err`; the payload is that error.
#[error(transparent)]
Task(E),
/// An internal invariant was violated (a result slot was left unfilled). The string
/// describes which invariant broke.
#[error("Infallible, this should not be possible: {0}")]
Infallible(String),
}
pub async fn run_constrained_with_semaphore<Fut, T, E>(
futures_it: impl Iterator<Item = Fut>,
max_concurrent: Arc<Semaphore>,
) -> Result<Vec<T>, ParutilsError<E>>
where
Fut: Future<Output = Result<T, E>> + Send + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
let handle = tokio::runtime::Handle::current();
let mut js: JoinSet<Result<(usize, T), ParutilsError<E>>> = JoinSet::new();
for (i, fut) in futures_it.enumerate() {
let semaphore = max_concurrent.clone();
js.spawn_on(
async move {
let _permit = semaphore.acquire().await?;
let res = fut.await.map_err(ParutilsError::Task)?;
Ok((i, res))
},
&handle,
);
}
let mut results: Vec<Option<T>> = Vec::with_capacity(js.len());
(0..js.len()).for_each(|_| results.push(None));
while let Some(result) = js.join_next().await {
let (i, res) = result??;
debug_assert!(results[i].is_none());
results[i] = Some(res);
}
debug_assert!(js.is_empty());
debug_assert!(results.iter().all(|r| r.is_some()));
let Some(result) = results.into_iter().collect() else {
return Err(ParutilsError::Infallible("A task was unaccounted for when collecting result".to_string()));
};
Ok(result)
}
/// Runs every future yielded by `futures_it` with at most `max_concurrent` of them holding a
/// permit at any one time, returning their `Ok` values in input order.
///
/// This is a convenience wrapper that creates a fresh semaphore with `max_concurrent` permits
/// and delegates to [`run_constrained_with_semaphore`]; errors are whatever that function
/// returns.
///
/// NOTE(review): a limit of `0` creates a semaphore with no permits, so any spawned task would
/// wait for a permit that is never released — confirm callers never pass `0`.
pub async fn run_constrained<Fut, T, E>(
    futures_it: impl Iterator<Item = Fut>,
    max_concurrent: usize,
) -> Result<Vec<T>, ParutilsError<E>>
where
    Fut: Future<Output = Result<T, E>> + Send + 'static,
    T: Send + Sync + 'static,
    E: Send + Sync + 'static,
{
    run_constrained_with_semaphore(futures_it, Arc::new(Semaphore::new(max_concurrent))).await
}
#[cfg(test)]
mod parallel_tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// Results come back in input order when every task completes immediately.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_simple_parallel() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {i}")).collect();
        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{s}:{i}")).collect();

        let r = run_constrained(
            data.into_iter()
                .enumerate()
                .map(|(i, s)| async move { Result::<_, ()>::Ok(format!("{s}:{i}")) }),
            4,
        )
        .await
        .unwrap();

        // Comparing the vectors directly checks both length and element order.
        assert_eq!(data_ref, r);
    }

    /// Results come back in input order even though earlier tasks sleep longer than later
    /// ones and therefore tend to finish last.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_parallel_with_sleeps() {
        let data: Vec<String> = (0..400).map(|i| format!("Number = {i}")).collect();
        let data_ref: Vec<String> = data.iter().enumerate().map(|(i, s)| format!("{s}:{i}")).collect();

        let r = run_constrained(
            data.into_iter().enumerate().map(|(i, s)| async move {
                tokio::time::sleep(std::time::Duration::from_millis(401 - i as u64)).await;
                Result::<_, ()>::Ok(format!("{s}:{i}"))
            }),
            100,
        )
        .await
        .unwrap();

        assert_eq!(data_ref, r);
    }

    /// The number of futures running at the same time never exceeds the limit, and the
    /// limit is actually reached.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_max_concurrent_constraint() {
        const NUM_TASKS: u64 = 100;
        const TASK_DURATION_BASE_MS: u64 = 100;
        const MAX_CONCURRENT: usize = 5;

        let current_running = Arc::new(AtomicU32::new(0));
        let max_concurrent_observed = Arc::new(AtomicU32::new(0));

        let futures = (0..NUM_TASKS).map(|i| {
            let current_running = current_running.clone();
            let max_concurrent_observed = max_concurrent_observed.clone();
            async move {
                // Count this task as running and record the high-water mark.
                let running = current_running.fetch_add(1, Ordering::SeqCst) + 1;
                max_concurrent_observed.fetch_max(running, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(TASK_DURATION_BASE_MS - i)).await;
                current_running.fetch_sub(1, Ordering::SeqCst);
                Result::<_, ()>::Ok(i)
            }
        });

        let results = run_constrained(futures, MAX_CONCURRENT).await.unwrap();

        let expected: Vec<u64> = (0..NUM_TASKS).collect();
        assert_eq!(results, expected);

        let max_observed = max_concurrent_observed.load(Ordering::SeqCst);
        assert!(
            max_observed <= MAX_CONCURRENT as u32,
            "Max concurrent tasks observed: {}, but limit was: {}",
            max_observed,
            MAX_CONCURRENT
        );
        assert_eq!(
            max_observed, MAX_CONCURRENT as u32,
            "Expected to see exactly {} concurrent tasks, but saw {}",
            MAX_CONCURRENT, max_observed
        );
        assert_eq!(current_running.load(Ordering::SeqCst), 0);
    }

    /// An `Err` from one future is surfaced as `ParutilsError::Task` carrying that error.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_error() {
        let futures = (0..10).map(|i| async move {
            if i == 5 {
                Result::<_, i32>::Err(5)
            } else {
                Result::<_, i32>::Ok(i)
            }
        });

        let result = run_constrained(futures, 2).await;
        assert!(matches!(result, Err(ParutilsError::Task(5))));
    }

    /// A panic inside one future is surfaced as `ParutilsError::Join` with `is_panic()` set.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_returns_join_error_on_panic() {
        let futures = (0..10).map(|i| async move { if i == 5 { panic!("5") } else { Result::<_, i32>::Ok(i) } });

        let result = run_constrained(futures, 2).await;
        match result {
            Err(ParutilsError::Join(e)) => assert!(e.is_panic()),
            other => panic!("Expected to panic, but got {:?}", other),
        }
    }

    /// An empty iterator spawns nothing and yields an empty result vector.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_empty_input() {
        let futures = std::iter::empty::<std::future::Ready<Result<u32, ()>>>();

        let results = run_constrained(futures, 4).await.unwrap();
        assert!(results.is_empty());
    }
}