Skip to main content

ironflow_core/
parallel.rs

1//! Parallel step execution utilities.
2//!
3//! Run independent workflow steps concurrently to reduce total wall-clock time.
//! Three patterns are supported:
5//!
6//! # Static parallelism (known number of steps)
7//!
8//! Use [`tokio::try_join!`] when you know at compile time how many steps
9//! to run in parallel:
10//!
11//! ```no_run
12//! use ironflow_core::prelude::*;
13//!
14//! # async fn example() -> Result<(), OperationError> {
15//! let (files, status) = tokio::try_join!(
16//!     Shell::new("ls -la"),
17//!     Shell::new("git status"),
18//! )?;
19//!
20//! println!("files:\n{}", files.stdout());
21//! println!("status:\n{}", status.stdout());
22//! # Ok(())
23//! # }
24//! ```
25//!
26//! # Dynamic parallelism (runtime-determined number of steps)
27//!
28//! Use [`try_join_all`] when the number of steps is determined at runtime:
29//!
30//! ```no_run
31//! use ironflow_core::prelude::*;
32//!
33//! # async fn example() -> Result<(), OperationError> {
34//! let commands = vec!["ls -la", "git status", "df -h"];
35//! let results = try_join_all(
36//!     commands.iter().map(|cmd| Shell::new(cmd).run())
37//! ).await?;
38//!
39//! for (cmd, output) in commands.iter().zip(&results) {
40//!     println!("{cmd}: {}", output.stdout());
41//! }
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! # Concurrency-limited parallelism
47//!
48//! Use [`try_join_all_limited`] to cap the number of steps running
49//! simultaneously (useful when launching many agent calls):
50//!
51//! ```no_run
52//! use ironflow_core::prelude::*;
53//!
54//! # async fn example() -> Result<(), OperationError> {
55//! let provider = ClaudeCodeProvider::new();
56//! let prompts = vec!["Summarize file A", "Summarize file B", "Summarize file C"];
57//!
58//! let results = try_join_all_limited(
59//!     prompts.iter().map(|p| {
60//!         Agent::new()
61//!             .prompt(p)
62//!             .model(Model::HAIKU)
63//!             .max_budget_usd(0.10)
64//!             .run(&provider)
65//!     }),
66//!     2, // at most 2 agent calls at a time
67//! ).await?;
68//! # Ok(())
69//! # }
70//! ```
71
72use std::future::Future;
73use std::sync::Arc;
74
75use futures_util::future;
76use tokio::sync::Semaphore;
77
78use crate::error::OperationError;
79
80/// Run a collection of futures concurrently and collect their results.
81///
82/// All futures start executing immediately. Returns a [`Vec<T>`] in the same
83/// order as the input iterator, or the first [`OperationError`] encountered
84/// (remaining futures are dropped on error).
85///
86/// # Examples
87///
88/// ```no_run
89/// use ironflow_core::prelude::*;
90///
91/// # async fn example() -> Result<(), OperationError> {
92/// let outputs = try_join_all(vec![
93///     Shell::new("echo one").run(),
94///     Shell::new("echo two").run(),
95///     Shell::new("echo three").run(),
96/// ]).await?;
97///
98/// assert_eq!(outputs.len(), 3);
99/// # Ok(())
100/// # }
101/// ```
102///
103/// # Errors
104///
105/// Returns the first [`OperationError`] produced by any future. When an error
106/// occurs, all other in-flight futures are cancelled.
107pub async fn try_join_all<I, F, T>(futures: I) -> Result<Vec<T>, OperationError>
108where
109    I: IntoIterator<Item = F>,
110    F: Future<Output = Result<T, OperationError>>,
111{
112    future::try_join_all(futures).await
113}
114
115/// Run a collection of futures with a concurrency limit.
116///
117/// At most `limit` futures execute simultaneously. Results are returned in
118/// the same order as the input iterator. Useful when running many agent
119/// calls to avoid overwhelming the system or exceeding rate limits.
120///
121/// # Examples
122///
123/// ```no_run
124/// use ironflow_core::prelude::*;
125///
126/// # async fn example() -> Result<(), OperationError> {
127/// let commands: Vec<&str> = (0..20)
128///     .map(|_| "echo hello")
129///     .collect();
130///
131/// let outputs = try_join_all_limited(
132///     commands.iter().map(|cmd| Shell::new(cmd).run()),
133///     5, // run at most 5 in parallel
134/// ).await?;
135///
136/// assert_eq!(outputs.len(), 20);
137/// # Ok(())
138/// # }
139/// ```
140///
141/// # Errors
142///
143/// Returns the first [`OperationError`] produced by any future. When an error
144/// occurs, all other in-flight futures are cancelled.
145///
146/// # Panics
147///
148/// Panics if `limit` is `0`.
149pub async fn try_join_all_limited<I, F, T>(
150    futures: I,
151    limit: usize,
152) -> Result<Vec<T>, OperationError>
153where
154    I: IntoIterator<Item = F>,
155    F: Future<Output = Result<T, OperationError>>,
156{
157    assert!(limit > 0, "concurrency limit must be greater than 0");
158
159    let sem = Arc::new(Semaphore::new(limit));
160    let guarded = futures.into_iter().map(|f| {
161        let sem = sem.clone();
162        async move {
163            let _permit = sem.acquire().await.expect("semaphore closed unexpectedly");
164            f.await
165        }
166    });
167
168    future::try_join_all(guarded).await
169}
170
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::operations::shell::Shell;

    /// Type-erased future used to build empty inputs, whose element type could
    /// not otherwise be inferred.
    type BoxedUnitFuture = Pin<Box<dyn Future<Output = Result<(), OperationError>> + Send>>;

    /// Builds `count` futures that each run `sleep 0.05 && echo {i}`.
    ///
    /// Each future records into `max_concurrent` the highest number of these
    /// futures observed in flight at once. The in-flight counter is only
    /// bumped when a future is first polled. For `try_join_all_limited`, that
    /// happens after its permit has been acquired. So `max_concurrent`
    /// reflects how many shell commands actually overlapped.
    fn tracked_sleepers(
        count: usize,
        max_concurrent: Arc<AtomicUsize>,
    ) -> Vec<impl Future<Output = Result<impl Sized, OperationError>>> {
        let concurrent = Arc::new(AtomicUsize::new(0));
        (0..count)
            .map(|i| {
                let concurrent = Arc::clone(&concurrent);
                let max_concurrent = Arc::clone(&max_concurrent);
                async move {
                    let current = concurrent.fetch_add(1, Ordering::SeqCst) + 1;
                    max_concurrent.fetch_max(current, Ordering::SeqCst);
                    let result = Shell::new(&format!("sleep 0.05 && echo {i}"))
                        .dry_run(false)
                        .run()
                        .await;
                    concurrent.fetch_sub(1, Ordering::SeqCst);
                    result
                }
            })
            .collect()
    }

    #[tokio::test]
    async fn try_join_all_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all(Vec::<BoxedUnitFuture>::new()).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn try_join_all_single_future() {
        let results = try_join_all(vec![Shell::new("echo hello").dry_run(false).run()])
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].stdout().trim(), "hello");
    }

    #[tokio::test]
    async fn try_join_all_multiple_futures_preserves_order() {
        let results = try_join_all(vec![
            Shell::new("echo one").dry_run(false).run(),
            Shell::new("echo two").dry_run(false).run(),
            Shell::new("echo three").dry_run(false).run(),
        ])
        .await
        .unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn try_join_all_runs_concurrently() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs = tracked_sleepers(3, Arc::clone(&max_concurrent));

        let results = try_join_all(futs).await.unwrap();
        assert_eq!(results.len(), 3);
        // All 3 should run concurrently (no limit)
        assert!(
            max_concurrent.load(Ordering::SeqCst) >= 2,
            "expected concurrent execution, max concurrency was {}",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn try_join_all_returns_first_error() {
        let result = try_join_all(vec![
            Shell::new("echo ok").dry_run(false).run(),
            Shell::new("exit 1").dry_run(false).run(),
            Shell::new("echo also ok").dry_run(false).run(),
        ])
        .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, OperationError::Shell { exit_code: 1, .. }));
    }

    #[tokio::test]
    async fn try_join_all_from_iterator() {
        let commands = ["echo alpha", "echo beta"];
        let results = try_join_all(commands.iter().map(|c| Shell::new(c).dry_run(false).run()))
            .await
            .unwrap();

        assert_eq!(results[0].stdout().trim(), "alpha");
        assert_eq!(results[1].stdout().trim(), "beta");
    }

    // --- try_join_all_limited ---

    #[tokio::test]
    async fn limited_empty_returns_empty_vec() {
        let result: Result<Vec<()>, OperationError> =
            try_join_all_limited(Vec::<BoxedUnitFuture>::new(), 3).await;
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn limited_preserves_order() {
        let results = try_join_all_limited(
            vec![
                Shell::new("echo one").dry_run(false).run(),
                Shell::new("echo two").dry_run(false).run(),
                Shell::new("echo three").dry_run(false).run(),
            ],
            2,
        )
        .await
        .unwrap();

        assert_eq!(results[0].stdout().trim(), "one");
        assert_eq!(results[1].stdout().trim(), "two");
        assert_eq!(results[2].stdout().trim(), "three");
    }

    #[tokio::test]
    async fn limited_respects_concurrency_limit() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs = tracked_sleepers(6, Arc::clone(&max_concurrent));

        let results = try_join_all_limited(futs, 2).await.unwrap();
        assert_eq!(results.len(), 6);
        assert!(
            max_concurrent.load(Ordering::SeqCst) <= 2,
            "max concurrency was {}, expected <= 2",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn limited_returns_first_error() {
        let result = try_join_all_limited(
            vec![
                Shell::new("echo ok").dry_run(false).run(),
                Shell::new("exit 42").dry_run(false).run(),
                Shell::new("echo also ok").dry_run(false).run(),
            ],
            2,
        )
        .await;

        assert!(result.is_err());
        // Pin the actual failure, not just "some error", mirroring
        // `try_join_all_returns_first_error`.
        let err = result.unwrap_err();
        assert!(matches!(err, OperationError::Shell { exit_code: 42, .. }));
    }

    #[tokio::test]
    #[should_panic(expected = "concurrency limit must be greater than 0")]
    async fn limited_zero_limit_panics() {
        let _: Result<Vec<()>, _> =
            try_join_all_limited(Vec::<BoxedUnitFuture>::new(), 0).await;
    }

    #[tokio::test]
    async fn limited_with_limit_one_runs_sequentially() {
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        let futs = tracked_sleepers(3, Arc::clone(&max_concurrent));

        let results = try_join_all_limited(futs, 1).await.unwrap();
        assert_eq!(results.len(), 3);
        // With limit=1, only 1 should run at a time
        assert_eq!(
            max_concurrent.load(Ordering::SeqCst),
            1,
            "expected max concurrency of 1, got {}",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn limited_with_limit_greater_than_count() {
        let results = try_join_all_limited(
            vec![
                Shell::new("echo x").dry_run(false).run(),
                Shell::new("echo y").dry_run(false).run(),
            ],
            100,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].stdout().trim(), "x");
        assert_eq!(results[1].stdout().trim(), "y");
    }
}