1use crate::callable::Callable;
6use crate::kernel::ExecutionId;
7use std::sync::Arc;
8
9#[derive(Debug)]
11pub struct ParallelResult {
12 pub name: String,
14 pub execution_id: ExecutionId,
16 pub output: Result<String, String>,
18}
19
/// Strategy deciding which input each branch of a `ParallelFlow` receives.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FanOut {
    /// Every branch receives the full, unmodified input (the default).
    #[default]
    Broadcast,
    /// The input is split on `delimiter` and the branch at position `i`
    /// receives the `i`-th part; branches beyond the last part receive an
    /// empty string.
    Split {
        /// Separator the input is split on.
        delimiter: String,
    },
    /// Placeholder for caller-defined distribution; `ParallelFlow` currently
    /// handles it the same way as [`FanOut::Broadcast`].
    Custom,
}
31
/// Strategy deciding how the branch outputs of a `ParallelFlow` are merged
/// into a single string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FanIn {
    /// Join all successful outputs with `separator`; failed branches are
    /// skipped.
    Concat {
        /// Text placed between consecutive successful outputs.
        separator: String,
    },
    /// Take the first successful output in branch order; aggregation fails
    /// if no branch succeeded.
    FirstSuccess,
    /// Serialize all successful outputs as a JSON array of strings.
    JsonArray,
    /// Placeholder for caller-defined aggregation; `ParallelFlow` currently
    /// joins the successful outputs with a newline.
    Custom,
}
44
45impl Default for FanIn {
46 fn default() -> Self {
47 FanIn::Concat {
48 separator: "\n".to_string(),
49 }
50 }
51}
52
53pub struct ParallelFlow<C: Callable> {
55 branches: Vec<Arc<C>>,
57 name: String,
59 fan_out: FanOut,
61 fan_in: FanIn,
63}
64
65impl<C: Callable + 'static> ParallelFlow<C> {
66 pub fn new(name: impl Into<String>) -> Self {
68 Self {
69 branches: Vec::new(),
70 name: name.into(),
71 fan_out: FanOut::Broadcast,
72 fan_in: FanIn::Concat {
73 separator: "\n".to_string(),
74 },
75 }
76 }
77
78 pub fn add_branch(mut self, callable: Arc<C>) -> Self {
80 self.branches.push(callable);
81 self
82 }
83
84 pub fn with_fan_out(mut self, strategy: FanOut) -> Self {
86 self.fan_out = strategy;
87 self
88 }
89
90 pub fn with_fan_in(mut self, strategy: FanIn) -> Self {
92 self.fan_in = strategy;
93 self
94 }
95
96 pub async fn execute(&self, input: &str) -> Vec<ParallelResult> {
98 let input = input.to_string();
99
100 let handles: Vec<_> = self
102 .branches
103 .iter()
104 .enumerate()
105 .map(|(idx, branch)| {
106 let branch = Arc::clone(branch);
107 let branch_input = self.prepare_input(&input, idx);
108 let execution_id = ExecutionId::new();
109 let branch_name = branch.name().to_string();
110
111 tokio::spawn(async move {
112 let result = branch.run(&branch_input).await;
113 ParallelResult {
114 name: branch_name,
115 execution_id,
116 output: result.map_err(|e| e.to_string()),
117 }
118 })
119 })
120 .collect();
121
122 let mut results = Vec::new();
124 for handle in handles {
125 match handle.await {
126 Ok(result) => results.push(result),
127 Err(e) => {
128 results.push(ParallelResult {
129 name: "unknown".to_string(),
130 execution_id: ExecutionId::new(),
131 output: Err(format!("Task panicked: {}", e)),
132 });
133 }
134 }
135 }
136
137 results
138 }
139
140 pub async fn execute_aggregated(&self, input: &str) -> anyhow::Result<String> {
142 let results = self.execute(input).await;
143 self.aggregate_results(results)
144 }
145
146 fn prepare_input(&self, input: &str, index: usize) -> String {
148 match &self.fan_out {
149 FanOut::Broadcast => input.to_string(),
150 FanOut::Split { delimiter } => {
151 let parts: Vec<&str> = input.split(delimiter).collect();
152 parts.get(index).copied().unwrap_or("").to_string()
153 }
154 FanOut::Custom => input.to_string(),
155 }
156 }
157
158 fn aggregate_results(&self, results: Vec<ParallelResult>) -> anyhow::Result<String> {
160 match &self.fan_in {
161 FanIn::Concat { separator } => {
162 let outputs: Vec<String> =
163 results.into_iter().filter_map(|r| r.output.ok()).collect();
164 Ok(outputs.join(separator))
165 }
166 FanIn::FirstSuccess => results
167 .into_iter()
168 .find_map(|r| r.output.ok())
169 .ok_or_else(|| anyhow::anyhow!("All branches failed")),
170 FanIn::JsonArray => {
171 let outputs: Vec<String> =
172 results.into_iter().filter_map(|r| r.output.ok()).collect();
173 Ok(serde_json::to_string(&outputs)?)
174 }
175 FanIn::Custom => {
176 let outputs: Vec<String> =
178 results.into_iter().filter_map(|r| r.output.ok()).collect();
179 Ok(outputs.join("\n"))
180 }
181 }
182 }
183
184 pub fn name(&self) -> &str {
186 &self.name
187 }
188
189 pub fn branch_count(&self) -> usize {
191 self.branches.len()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Test double that answers `"<response>:<input>"`, optionally after a delay.
    struct MockCallable {
        name: String,
        response: String,
        delay_ms: Option<u64>,
    }

    impl MockCallable {
        /// Builds a mock that answers immediately.
        fn new(name: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: None,
            }
        }

        /// Builds a mock that sleeps for `delay_ms` milliseconds before answering.
        fn with_delay(name: &str, response: &str, delay_ms: u64) -> Self {
            Self {
                name: name.to_string(),
                response: response.to_string(),
                delay_ms: Some(delay_ms),
            }
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            if let Some(delay) = self.delay_ms {
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            Ok(format!("{}:{}", self.response, input))
        }
    }

    #[tokio::test]
    async fn test_parallel_single_branch() {
        let flow =
            ParallelFlow::new("single").add_branch(Arc::new(MockCallable::new("b1", "result1")));

        let results = flow.execute("input").await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "b1");
        assert!(results[0].output.as_ref().unwrap().contains("result1"));
    }

    #[tokio::test]
    async fn test_parallel_multiple_branches() {
        let flow = ParallelFlow::new("multi")
            .add_branch(Arc::new(MockCallable::new("b1", "r1")))
            .add_branch(Arc::new(MockCallable::new("b2", "r2")))
            .add_branch(Arc::new(MockCallable::new("b3", "r3")));

        assert_eq!(flow.branch_count(), 3);
        assert_eq!(flow.name(), "multi");

        let results = flow.execute("test").await;
        assert_eq!(results.len(), 3);

        for result in &results {
            assert!(result.output.is_ok());
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        use std::time::Instant;

        // Three 50 ms branches: sequential execution would take ~150 ms.
        let flow = ParallelFlow::new("concurrent")
            .add_branch(Arc::new(MockCallable::with_delay("b1", "r1", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b2", "r2", 50)))
            .add_branch(Arc::new(MockCallable::with_delay("b3", "r3", 50)));

        let start = Instant::now();
        let results = flow.execute("test").await;
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() < 120,
            "Expected <120ms but took {}ms",
            elapsed.as_millis()
        );
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_parallel_aggregated_concat() {
        let flow = ParallelFlow::new("concat")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_in(FanIn::Concat {
                separator: "|".to_string(),
            });

        // Results are aggregated in branch order, so the output is exact.
        let result = flow.execute_aggregated("x").await.unwrap();
        assert_eq!(result, "A:x|B:x");
    }

    #[tokio::test]
    async fn test_parallel_aggregated_json_array() {
        let flow = ParallelFlow::new("json")
            .add_branch(Arc::new(MockCallable::new("a", "result_a")))
            .add_branch(Arc::new(MockCallable::new("b", "result_b")))
            .with_fan_in(FanIn::JsonArray);

        let result = flow.execute_aggregated("input").await.unwrap();
        let parsed: Vec<String> = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed, ["result_a:input", "result_b:input"]);
    }

    #[tokio::test]
    async fn test_fan_out_split_distributes_by_index() {
        let flow = ParallelFlow::new("split")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")))
            .add_branch(Arc::new(MockCallable::new("c", "third")));

        let results = flow.execute("one,two,three").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();

        assert_eq!(outputs[0], "first:one");
        assert_eq!(outputs[1], "second:two");
        assert_eq!(outputs[2], "third:three");
    }

    #[tokio::test]
    async fn test_fan_out_split_missing_part_is_empty() {
        let flow = ParallelFlow::new("split_short")
            .with_fan_out(FanOut::Split {
                delimiter: ",".to_string(),
            })
            .add_branch(Arc::new(MockCallable::new("a", "first")))
            .add_branch(Arc::new(MockCallable::new("b", "second")));

        // Only one part is available, so the second branch gets "".
        let results = flow.execute("only").await;
        let outputs: Vec<String> = results.into_iter().map(|r| r.output.unwrap()).collect();

        assert_eq!(outputs, ["first:only", "second:"]);
    }

    #[tokio::test]
    async fn test_parallel_first_success() {
        struct MaybeFailCallable {
            name: &'static str,
            should_fail: bool,
        }

        #[async_trait]
        impl Callable for MaybeFailCallable {
            fn name(&self) -> &str {
                self.name
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                if self.should_fail {
                    anyhow::bail!("Intentional failure")
                }
                Ok("success_result".to_string())
            }
        }

        let flow = ParallelFlow::new("first_success")
            .add_branch(Arc::new(MaybeFailCallable {
                name: "fail",
                should_fail: true,
            }))
            .add_branch(Arc::new(MaybeFailCallable {
                name: "success",
                should_fail: false,
            }))
            .with_fan_in(FanIn::FirstSuccess);

        let result = flow.execute_aggregated("test").await.unwrap();
        assert_eq!(result, "success_result");
    }

    #[tokio::test]
    async fn test_parallel_all_fail_first_success() {
        struct FailCallable(&'static str);

        #[async_trait]
        impl Callable for FailCallable {
            fn name(&self) -> &str {
                self.0
            }
            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                anyhow::bail!("Failed: {}", self.0)
            }
        }

        let flow = ParallelFlow::new("all_fail")
            .add_branch(Arc::new(FailCallable("f1")))
            .add_branch(Arc::new(FailCallable("f2")))
            .with_fan_in(FanIn::FirstSuccess);

        let result = flow.execute_aggregated("test").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("All branches failed"));
    }

    #[tokio::test]
    async fn test_parallel_panicking_branch_reports_error() {
        struct PanicCallable;

        #[async_trait]
        impl Callable for PanicCallable {
            fn name(&self) -> &str {
                "boom"
            }
            async fn run(&self, input: &str) -> anyhow::Result<String> {
                if input == "panic" {
                    panic!("intentional panic");
                }
                Ok(input.to_string())
            }
        }

        let flow = ParallelFlow::new("panics").add_branch(Arc::new(PanicCallable));

        // The panic must be contained in the task and surfaced as an `Err`.
        let results = flow.execute("panic").await;
        assert_eq!(results.len(), 1);
        let err = results[0].output.as_ref().unwrap_err();
        assert!(err.contains("Task panicked"));
    }

    #[tokio::test]
    async fn test_fan_out_broadcast() {
        let flow = ParallelFlow::new("broadcast")
            .add_branch(Arc::new(MockCallable::new("a", "A")))
            .add_branch(Arc::new(MockCallable::new("b", "B")))
            .with_fan_out(FanOut::Broadcast);

        let results = flow.execute("same_input").await;
        for result in results {
            assert!(result.output.as_ref().unwrap().contains("same_input"));
        }
    }

    #[tokio::test]
    async fn test_parallel_result_contains_execution_id() {
        let flow =
            ParallelFlow::new("with_ids").add_branch(Arc::new(MockCallable::new("b1", "r1")));

        let results = flow.execute("test").await;
        assert_eq!(results.len(), 1);
        assert!(!results[0].execution_id.as_str().is_empty());
    }

    #[test]
    fn test_fan_out_default() {
        let fan_out = FanOut::default();
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(fan_out, FanOut::Broadcast));
    }

    #[test]
    fn test_fan_in_default() {
        let fan_in = FanIn::default();
        // Check the separator too, not just the variant.
        assert!(matches!(fan_in, FanIn::Concat { ref separator } if separator == "\n"));
    }
}