1use std::future::Future;
73use std::sync::Arc;
74
75use futures_util::future;
76use tokio::sync::Semaphore;
77
78use crate::error::OperationError;
79
80pub async fn try_join_all<I, F, T>(futures: I) -> Result<Vec<T>, OperationError>
108where
109 I: IntoIterator<Item = F>,
110 F: Future<Output = Result<T, OperationError>>,
111{
112 future::try_join_all(futures).await
113}
114
115pub async fn try_join_all_limited<I, F, T>(
150 futures: I,
151 limit: usize,
152) -> Result<Vec<T>, OperationError>
153where
154 I: IntoIterator<Item = F>,
155 F: Future<Output = Result<T, OperationError>>,
156{
157 assert!(limit > 0, "concurrency limit must be greater than 0");
158
159 let sem = Arc::new(Semaphore::new(limit));
160 let guarded = futures.into_iter().map(|f| {
161 let sem = sem.clone();
162 async move {
163 let _permit = sem.acquire().await.expect("semaphore closed unexpectedly");
164 f.await
165 }
166 });
167
168 future::try_join_all(guarded).await
169}
170
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::operations::shell::Shell;

    /// Type-erased future used to build empty, explicitly typed input collections.
    type BoxedFuture = Pin<Box<dyn Future<Output = Result<(), OperationError>> + Send>>;

    /// Builds `count` futures that each run `sleep 0.05 && echo {i}` while tracking how many
    /// of them are in flight at the same moment.
    ///
    /// Returns the futures plus the shared high-water mark of concurrent executions. Read
    /// the high-water mark only after the futures have been driven to completion.
    fn tracked_sleepers(
        count: usize,
    ) -> (
        Vec<impl Future<Output = Result<(), OperationError>>>,
        Arc<AtomicUsize>,
    ) {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let futs: Vec<_> = (0..count)
            .map(|i| {
                let concurrent = Arc::clone(&concurrent);
                let max_concurrent = Arc::clone(&max_concurrent);
                async move {
                    // Register as running before the shell command starts, then record
                    // the peak.
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    let result = Shell::new(&format!("sleep 0.05 && echo {i}"))
                        .dry_run(false)
                        .run()
                        .await;
                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    // The callers only count results, so drop the command output.
                    result.map(|_| ())
                }
            })
            .collect();

        (futs, max_concurrent)
    }

    #[tokio::test]
    async fn try_join_all_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all(Vec::<BoxedFuture>::new()).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn try_join_all_single_future() {
        let results = try_join_all(vec![Shell::new("echo hello").dry_run(false).run()])
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].stdout().trim(), "hello");
    }

    #[tokio::test]
    async fn try_join_all_multiple_futures_preserves_order() {
        let results = try_join_all(vec![
            Shell::new("echo one").dry_run(false).run(),
            Shell::new("echo two").dry_run(false).run(),
            Shell::new("echo three").dry_run(false).run(),
        ])
        .await
        .unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn try_join_all_runs_concurrently() {
        let (futs, max_concurrent) = tracked_sleepers(3);

        let results = try_join_all(futs).await.unwrap();
        assert_eq!(results.len(), 3);

        let max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            max >= 2,
            "expected concurrent execution, max concurrency was {max}"
        );
    }

    #[tokio::test]
    async fn try_join_all_returns_first_error() {
        let result = try_join_all(vec![
            Shell::new("echo ok").dry_run(false).run(),
            Shell::new("exit 1").dry_run(false).run(),
            Shell::new("echo also ok").dry_run(false).run(),
        ])
        .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, OperationError::Shell { exit_code: 1, .. }));
    }

    #[tokio::test]
    async fn try_join_all_from_iterator() {
        let commands = ["echo alpha", "echo beta"];
        let results = try_join_all(commands.iter().map(|c| Shell::new(c).dry_run(false).run()))
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].stdout().trim(), "alpha");
        assert_eq!(results[1].stdout().trim(), "beta");
    }

    #[tokio::test]
    async fn limited_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all_limited(Vec::<BoxedFuture>::new(), 3).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn limited_preserves_order() {
        let results = try_join_all_limited(
            vec![
                Shell::new("echo one").dry_run(false).run(),
                Shell::new("echo two").dry_run(false).run(),
                Shell::new("echo three").dry_run(false).run(),
            ],
            2,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn limited_respects_concurrency_limit() {
        let (futs, max_concurrent) = tracked_sleepers(6);

        let results = try_join_all_limited(futs, 2).await.unwrap();
        assert_eq!(results.len(), 6);

        let max = max_concurrent.load(Ordering::SeqCst);
        assert!(max <= 2, "max concurrency was {max}, expected <= 2");
    }

    #[tokio::test]
    async fn limited_returns_first_error() {
        let result = try_join_all_limited(
            vec![
                Shell::new("echo ok").dry_run(false).run(),
                Shell::new("exit 42").dry_run(false).run(),
                Shell::new("echo also ok").dry_run(false).run(),
            ],
            2,
        )
        .await;

        // Pin the specific failure, not just "some error", matching the unlimited variant.
        let err = result.unwrap_err();
        assert!(matches!(err, OperationError::Shell { exit_code: 42, .. }));
    }

    #[tokio::test]
    #[should_panic(expected = "concurrency limit must be greater than 0")]
    async fn limited_zero_limit_panics() {
        let _: Result<Vec<()>, _> = try_join_all_limited(Vec::<BoxedFuture>::new(), 0).await;
    }

    #[tokio::test]
    async fn limited_with_limit_one_runs_sequentially() {
        let (futs, max_concurrent) = tracked_sleepers(3);

        let results = try_join_all_limited(futs, 1).await.unwrap();
        assert_eq!(results.len(), 3);

        let max = max_concurrent.load(Ordering::SeqCst);
        assert_eq!(max, 1, "expected max concurrency of 1, got {max}");
    }

    #[tokio::test]
    async fn limited_with_limit_greater_than_count() {
        let results = try_join_all_limited(
            vec![
                Shell::new("echo x").dry_run(false).run(),
                Shell::new("echo y").dry_run(false).run(),
            ],
            100,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].stdout().trim(), "x");
        assert_eq!(results[1].stdout().trim(), "y");
    }
}