use std::future::Future;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::error::OperationError;
/// Drives every future in `futures` concurrently and collects their successful
/// outputs, in the same order as the input.
///
/// This delegates to `futures_util::future::try_join_all`.
///
/// # Errors
///
/// Returns the [`OperationError`] reported by `futures_util::future::try_join_all`
/// as soon as any of the futures fails.
pub async fn try_join_all<I, F, T>(futures: I) -> Result<Vec<T>, OperationError>
where
    I: IntoIterator<Item = F>,
    F: Future<Output = Result<T, OperationError>>,
{
    let joined = futures_util::future::try_join_all(futures);
    joined.await
}
pub async fn try_join_all_limited<I, F, T>(
futures: I,
limit: usize,
) -> Result<Vec<T>, OperationError>
where
I: IntoIterator<Item = F>,
F: Future<Output = Result<T, OperationError>>,
{
assert!(limit > 0, "concurrency limit must be greater than 0");
let sem = Arc::new(Semaphore::new(limit));
let guarded = futures.into_iter().map(|f| {
let sem = sem.clone();
async move {
let _permit = sem.acquire().await.expect("semaphore closed unexpectedly");
f.await
}
});
futures_util::future::try_join_all(guarded).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::operations::shell::Shell;

    /// Concrete future type used to build empty, explicitly-typed inputs.
    type BoxedUnitFuture =
        std::pin::Pin<Box<dyn Future<Output = Result<(), OperationError>> + Send>>;

    #[tokio::test]
    async fn try_join_all_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all(Vec::<BoxedUnitFuture>::new()).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn try_join_all_single_future() {
        let results = try_join_all(vec![Shell::new("echo hello").run()])
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].stdout().trim(), "hello");
    }

    #[tokio::test]
    async fn try_join_all_multiple_futures_preserves_order() {
        let results = try_join_all(vec![
            Shell::new("echo one").run(),
            Shell::new("echo two").run(),
            Shell::new("echo three").run(),
        ])
        .await
        .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn try_join_all_runs_concurrently() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs: Vec<_> = (0..3)
            .map(|i| {
                let concurrent = concurrent.clone();
                let max_concurrent = max_concurrent.clone();
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    let result = Shell::new(&format!("sleep 0.05 && echo {i}")).run().await;
                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    result
                }
            })
            .collect();
        let results = try_join_all(futs).await.unwrap();
        assert_eq!(results.len(), 3);
        assert!(
            max_concurrent.load(Ordering::SeqCst) >= 2,
            "expected concurrent execution, max concurrency was {}",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn try_join_all_returns_first_error() {
        let result = try_join_all(vec![
            Shell::new("echo ok").run(),
            Shell::new("exit 1").run(),
            Shell::new("echo also ok").run(),
        ])
        .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, OperationError::Shell { exit_code: 1, .. }));
    }

    #[tokio::test]
    async fn try_join_all_from_iterator() {
        let commands = ["echo alpha", "echo beta"];
        let results = try_join_all(commands.iter().map(|c| Shell::new(c).run()))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].stdout().trim(), "alpha");
        assert_eq!(results[1].stdout().trim(), "beta");
    }

    #[tokio::test]
    async fn limited_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all_limited(Vec::<BoxedUnitFuture>::new(), 3).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn limited_preserves_order() {
        let results = try_join_all_limited(
            vec![
                Shell::new("echo one").run(),
                Shell::new("echo two").run(),
                Shell::new("echo three").run(),
            ],
            2,
        )
        .await
        .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn limited_respects_concurrency_limit() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs: Vec<_> = (0..6)
            .map(|i| {
                let concurrent = concurrent.clone();
                let max_concurrent = max_concurrent.clone();
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    let result = Shell::new(&format!("sleep 0.05 && echo {i}")).run().await;
                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    result
                }
            })
            .collect();
        let results = try_join_all_limited(futs, 2).await.unwrap();
        assert_eq!(results.len(), 6);
        assert!(
            max_concurrent.load(Ordering::SeqCst) <= 2,
            "max concurrency was {}, expected <= 2",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn limited_returns_first_error() {
        let result = try_join_all_limited(
            vec![
                Shell::new("echo ok").run(),
                Shell::new("exit 42").run(),
                Shell::new("echo also ok").run(),
            ],
            2,
        )
        .await;
        assert!(result.is_err());
        // Mirror `try_join_all_returns_first_error`: the failing command's
        // error must come through unchanged, not just "some" error.
        let err = result.unwrap_err();
        assert!(
            matches!(err, OperationError::Shell { exit_code: 42, .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    #[should_panic(expected = "concurrency limit must be greater than 0")]
    fn limited_zero_limit_panics() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let _: Result<Vec<()>, _> =
                try_join_all_limited(Vec::<BoxedUnitFuture>::new(), 0).await;
        });
    }

    #[tokio::test]
    async fn limited_with_limit_one_runs_sequentially() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs: Vec<_> = (0..3)
            .map(|i| {
                let concurrent = concurrent.clone();
                let max_concurrent = max_concurrent.clone();
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    let result = Shell::new(&format!("sleep 0.05 && echo {i}")).run().await;
                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    result
                }
            })
            .collect();
        let results = try_join_all_limited(futs, 1).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(
            max_concurrent.load(Ordering::SeqCst),
            1,
            "expected max concurrency of 1, got {}",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn limited_with_limit_greater_than_count() {
        let results = try_join_all_limited(
            vec![Shell::new("echo x").run(), Shell::new("echo y").run()],
            100,
        )
        .await
        .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].stdout().trim(), "x");
        assert_eq!(results[1].stdout().trim(), "y");
    }
}