Skip to main content

enact_core/flow/
parallel.rs

//! Parallel flow: fan-out / fan-in execution.
//!
//! Executes multiple callables concurrently and aggregates their results.
4
5use crate::callable::Callable;
6use crate::kernel::ExecutionId;
7use std::sync::Arc;
8
9/// Result from a parallel execution branch
10#[derive(Debug)]
11pub struct ParallelResult {
12    /// Name of the callable that produced this result
13    pub name: String,
14    /// Execution ID for this branch
15    pub execution_id: ExecutionId,
16    /// Output or error
17    pub output: Result<String, String>,
18}
19
/// Fan-out strategy: how the flow's input is distributed across parallel branches.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FanOut {
    /// Every branch receives the complete, unmodified input (the default).
    #[default]
    Broadcast,
    /// The input is split on `delimiter`, and the branch at index `i` receives the `i`-th piece.
    ///
    /// Branches beyond the number of pieces receive an empty string. Pieces beyond the
    /// number of branches are not delivered to any branch.
    Split { delimiter: String },
    /// Reserved for a custom per-branch distribution.
    ///
    /// No distribution function can be attached yet, so `ParallelFlow` currently
    /// treats this variant exactly like [`FanOut::Broadcast`].
    Custom,
}
31
/// Fan-in strategy: how the outputs of parallel branches are combined into one string.
///
/// Aggregation happens only after every branch has finished. Every strategy
/// ignores branches that failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FanIn {
    /// Join the outputs of all successful branches with `separator`, in branch order.
    ///
    /// Yields an empty string when no branch succeeded.
    Concat { separator: String },
    /// Return the output of the first successful branch in branch order.
    ///
    /// This is not necessarily the first branch to finish. Fails if every branch failed.
    FirstSuccess,
    /// Return the outputs of all successful branches as a JSON array of strings.
    JsonArray,
    /// Reserved for custom aggregation.
    ///
    /// Currently behaves like `Concat` with a `"\n"` separator.
    Custom,
}
44
45impl Default for FanIn {
46    fn default() -> Self {
47        FanIn::Concat {
48            separator: "\n".to_string(),
49        }
50    }
51}
52
53/// Parallel execution flow
54pub struct ParallelFlow<C: Callable> {
55    /// Branches to execute in parallel
56    branches: Vec<Arc<C>>,
57    /// Flow name
58    name: String,
59    /// Fan-out strategy
60    fan_out: FanOut,
61    /// Fan-in strategy
62    fan_in: FanIn,
63}
64
65impl<C: Callable + 'static> ParallelFlow<C> {
66    /// Create a new parallel flow
67    pub fn new(name: impl Into<String>) -> Self {
68        Self {
69            branches: Vec::new(),
70            name: name.into(),
71            fan_out: FanOut::Broadcast,
72            fan_in: FanIn::Concat {
73                separator: "\n".to_string(),
74            },
75        }
76    }
77
78    /// Add a branch
79    pub fn add_branch(mut self, callable: Arc<C>) -> Self {
80        self.branches.push(callable);
81        self
82    }
83
84    /// Set fan-out strategy
85    pub fn with_fan_out(mut self, strategy: FanOut) -> Self {
86        self.fan_out = strategy;
87        self
88    }
89
90    /// Set fan-in strategy
91    pub fn with_fan_in(mut self, strategy: FanIn) -> Self {
92        self.fan_in = strategy;
93        self
94    }
95
96    /// Execute all branches in parallel
97    pub async fn execute(&self, input: &str) -> Vec<ParallelResult> {
98        let input = input.to_string();
99
100        // Spawn all branches
101        let handles: Vec<_> = self
102            .branches
103            .iter()
104            .enumerate()
105            .map(|(idx, branch)| {
106                let branch = Arc::clone(branch);
107                let branch_input = self.prepare_input(&input, idx);
108                let execution_id = ExecutionId::new();
109                let branch_name = branch.name().to_string();
110
111                tokio::spawn(async move {
112                    let result = branch.run(&branch_input).await;
113                    ParallelResult {
114                        name: branch_name,
115                        execution_id,
116                        output: result.map_err(|e| e.to_string()),
117                    }
118                })
119            })
120            .collect();
121
122        // Collect results
123        let mut results = Vec::new();
124        for handle in handles {
125            match handle.await {
126                Ok(result) => results.push(result),
127                Err(e) => {
128                    results.push(ParallelResult {
129                        name: "unknown".to_string(),
130                        execution_id: ExecutionId::new(),
131                        output: Err(format!("Task panicked: {}", e)),
132                    });
133                }
134            }
135        }
136
137        results
138    }
139
140    /// Execute and aggregate results
141    pub async fn execute_aggregated(&self, input: &str) -> anyhow::Result<String> {
142        let results = self.execute(input).await;
143        self.aggregate_results(results)
144    }
145
146    /// Prepare input for a branch based on fan-out strategy
147    fn prepare_input(&self, input: &str, index: usize) -> String {
148        match &self.fan_out {
149            FanOut::Broadcast => input.to_string(),
150            FanOut::Split { delimiter } => {
151                let parts: Vec<&str> = input.split(delimiter).collect();
152                parts.get(index).copied().unwrap_or("").to_string()
153            }
154            FanOut::Custom => input.to_string(),
155        }
156    }
157
158    /// Aggregate results based on fan-in strategy
159    fn aggregate_results(&self, results: Vec<ParallelResult>) -> anyhow::Result<String> {
160        match &self.fan_in {
161            FanIn::Concat { separator } => {
162                let outputs: Vec<String> =
163                    results.into_iter().filter_map(|r| r.output.ok()).collect();
164                Ok(outputs.join(separator))
165            }
166            FanIn::FirstSuccess => results
167                .into_iter()
168                .find_map(|r| r.output.ok())
169                .ok_or_else(|| anyhow::anyhow!("All branches failed")),
170            FanIn::JsonArray => {
171                let outputs: Vec<String> =
172                    results.into_iter().filter_map(|r| r.output.ok()).collect();
173                Ok(serde_json::to_string(&outputs)?)
174            }
175            FanIn::Custom => {
176                // Default to concat
177                let outputs: Vec<String> =
178                    results.into_iter().filter_map(|r| r.output.ok()).collect();
179                Ok(outputs.join("\n"))
180            }
181        }
182    }
183
184    /// Get the flow name
185    pub fn name(&self) -> &str {
186        &self.name
187    }
188
189    /// Get branch count
190    pub fn branch_count(&self) -> usize {
191        self.branches.len()
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Mock callable that echoes `"{response}:{input}"`, optionally after a delay.
    struct MockCallable {
        name: String,
        response: String,
        delay_ms: Option<u64>,
    }

    impl MockCallable {
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: None,
            }
        }

        fn with_delay(name: &str, response: &str, delay_ms: u64) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: Some(delay_ms),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            if let Some(delay) = self.delay_ms {
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            Ok(format!("{}:{}", self.response, input))
        }
    }

    #[tokio::test]
    async fn test_parallel_single_branch() {
        let flow =
            ParallelFlow::new("single").add_branch(Arc::new(MockCallable::new("b1", "result1")));

        let results = flow.execute("input").await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "b1");
        assert!(results[0].output.as_ref().unwrap().contains("result1"));
    }

    #[tokio::test]
    async fn test_parallel_multiple_branches() {
        let flow = ParallelFlow::new("multi")
            .add_branch(Arc::new(MockCallable::new("b1", "r1")))
            .add_branch(Arc::new(MockCallable::new("b2", "r2")))
            .add_branch(Arc::new(MockCallable::new("b3", "r3")));

        assert_eq!(flow.branch_count(), 3);
        assert_eq!(flow.name(), "multi");

        let results = flow.execute("test").await;
        assert_eq!(results.len(), 3);

        // All should succeed
        for result in &results {
            assert!(result.output.is_ok());
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        use std::time::Instant;

        // Each branch takes 50ms, but they run in parallel
        let flow = ParallelFlow::new("concurrent")
            .add_branch(Arc::new(MockCallable::with_delay("b1", "r1", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b2", "r2", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b3", "r3", 50)));

        let start = Instant::now();
        let results = flow.execute("test").await;
        let elapsed = start.elapsed();

        // Should take ~50ms, not 150ms (proves concurrency)
        assert!(
            elapsed.as_millis() < 120,
            "Expected <120ms but took {}ms",
            elapsed.as_millis()
        );
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_parallel_aggregated_concat() {
        let flow = ParallelFlow::new("concat")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_in(FanIn::Concat {
                separator: "|".to_string(),
            });

        let result = flow.execute_aggregated("x").await.unwrap();
        assert!(result.contains("A:x"));
        assert!(result.contains("B:x"));
        assert!(result.contains("|"));
    }

    #[tokio::test]
    async fn test_concat_skips_failed_branches() {
        struct FailCallable;

        #[async_trait]
        impl Callable for FailCallable {
            fn name(&self) -> &str {
                "fail"
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Intentional failure")
            }
        }

        let ok_flow = ParallelFlow::new("concat_ok")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .with_fan_in(FanIn::Concat {
                separator: "|".to_string(),
            });
        assert_eq!(ok_flow.execute_aggregated("x").await.unwrap(), "A:x");

        let fail_flow = ParallelFlow::new("concat_fail")
            .add_branch(Arc::new(FailCallable))
            .add_branch(Arc::new(FailCallable));
        // Failed branches are dropped; with no successes the result is empty, not an error.
        assert_eq!(fail_flow.execute_aggregated("x").await.unwrap(), "");
    }

    #[tokio::test]
    async fn test_parallel_aggregated_json_array() {
        let flow = ParallelFlow::new("json")
            .add_branch(Arc::new(MockCallable::new("a", "result_a")))
            .add_branch(Arc::new(MockCallable::new("b", "result_b")))
            .with_fan_in(FanIn::JsonArray);

        let result = flow.execute_aggregated("input").await.unwrap();
        let parsed: Vec<String> = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed.len(), 2);
    }

    #[tokio::test]
    async fn test_fan_out_split_distributes_by_index() {
        let flow = ParallelFlow::new("split")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")))
            .add_branch(Arc::new(MockCallable::new("c", "third")));

        let results = flow.execute("one,two,three").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();

        assert_eq!(outputs[0], "first:one");
        assert_eq!(outputs[1], "second:two");
        assert_eq!(outputs[2], "third:three");
    }

    #[tokio::test]
    async fn test_fan_out_split_missing_parts_become_empty() {
        let flow = ParallelFlow::new("split_short")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")));

        let results = flow.execute("only").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();

        assert_eq!(outputs, vec!["first:only".to_string(), "second:".to_string()]);
    }

    #[tokio::test]
    async fn test_parallel_first_success() {
        struct MaybeFailCallable {
            name: &'static str,
            should_fail: bool,
        }

        #[async_trait]
        impl Callable for MaybeFailCallable {
            fn name(&self) -> &str {
                self.name
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                if self.should_fail {
                    anyhow::bail!("Intentional failure")
                }
                Ok("success_result".to_string())
            }
        }

        let flow = ParallelFlow::new("first_success")
            .add_branch(Arc::new(MaybeFailCallable {
                name: "fail",
                should_fail: true,
            }))
            .add_branch(Arc::new(MaybeFailCallable {
                name: "success",
                should_fail: false,
            }))
            .with_fan_in(FanIn::FirstSuccess);

        let result = flow.execute_aggregated("test").await.unwrap();
        assert_eq!(result, "success_result");
    }

    #[tokio::test]
    async fn test_parallel_all_fail_first_success() {
        struct FailCallable(&'static str);

        #[async_trait]
        impl Callable for FailCallable {
            fn name(&self) -> &str {
                self.0
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Failed: {}", self.0)
            }
        }

        let flow = ParallelFlow::new("all_fail")
            .add_branch(Arc::new(FailCallable("f1")))
            .add_branch(Arc::new(FailCallable("f2")))
            .with_fan_in(FanIn::FirstSuccess);

        let result = flow.execute_aggregated("test").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("All branches failed"));
    }

    #[tokio::test]
    async fn test_fan_out_broadcast() {
        let flow = ParallelFlow::new("broadcast")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_out(FanOut::Broadcast);

        let results = flow.execute("same_input").await;
        // Both branches should receive the same input
        for result in results {
            assert!(result.output.as_ref().unwrap().contains("same_input"));
        }
    }

    #[tokio::test]
    async fn test_parallel_result_contains_execution_id() {
        let flow =
            ParallelFlow::new("with_ids").add_branch(Arc::new(MockCallable::new("b1", "r1")));

        let results = flow.execute("test").await;
        assert_eq!(results.len(), 1);
        // Each result should carry a non-empty execution ID
        assert!(!results[0].execution_id.as_str().is_empty());
    }

    #[test]
    fn test_fan_out_default() {
        let fan_out = FanOut::default();
        assert!(matches!(fan_out, FanOut::Broadcast));
    }

    #[test]
    fn test_fan_in_default() {
        let fan_in = FanIn::default();
        assert!(matches!(
            &fan_in,
            FanIn::Concat { separator } if separator.as_str() == "\n"
        ));
    }
}