1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::sync::Semaphore;
6use tokio::task::JoinSet;
7
8use crate::config::StepConfig;
9use crate::control_flow::ControlFlow;
10use crate::engine::context::Context;
11use crate::error::StepError;
12use crate::workflow::schema::{ScopeDef, StepDef};
13
14use super::{
15 call::dispatch_scope_step, CmdOutput, IterationOutput, ScopeOutput, StepExecutor, StepOutput,
16};
17
18fn apply_reduce(
20 scope: &ScopeOutput,
21 reducer: &str,
22 condition_template: Option<&str>,
23) -> Result<StepOutput, crate::error::StepError> {
24 let iterations = &scope.iterations;
25
26 match reducer {
27 "concat" => {
28 let joined = iterations
29 .iter()
30 .map(|it| it.output.text().to_string())
31 .collect::<Vec<_>>()
32 .join("\n");
33 Ok(StepOutput::Cmd(CmdOutput {
34 stdout: joined,
35 stderr: String::new(),
36 exit_code: 0,
37 duration: std::time::Duration::ZERO,
38 }))
39 }
40 "sum" => {
41 let sum: f64 = iterations
42 .iter()
43 .map(|it| it.output.text().trim().parse::<f64>().unwrap_or(0.0))
44 .sum();
45 let result = if sum.fract() == 0.0 {
47 format!("{}", sum as i64)
48 } else {
49 format!("{}", sum)
50 };
51 Ok(StepOutput::Cmd(CmdOutput {
52 stdout: result,
53 stderr: String::new(),
54 exit_code: 0,
55 duration: std::time::Duration::ZERO,
56 }))
57 }
58 "count" => {
59 Ok(StepOutput::Cmd(CmdOutput {
60 stdout: iterations.len().to_string(),
61 stderr: String::new(),
62 exit_code: 0,
63 duration: std::time::Duration::ZERO,
64 }))
65 }
66 "min" => {
67 let min_val = iterations
68 .iter()
69 .filter_map(|it| it.output.text().trim().parse::<f64>().ok())
70 .fold(f64::INFINITY, f64::min);
71 let result = if min_val.fract() == 0.0 {
72 format!("{}", min_val as i64)
73 } else {
74 format!("{}", min_val)
75 };
76 Ok(StepOutput::Cmd(CmdOutput {
77 stdout: result,
78 stderr: String::new(),
79 exit_code: 0,
80 duration: std::time::Duration::ZERO,
81 }))
82 }
83 "max" => {
84 let max_val = iterations
85 .iter()
86 .filter_map(|it| it.output.text().trim().parse::<f64>().ok())
87 .fold(f64::NEG_INFINITY, f64::max);
88 let result = if max_val.fract() == 0.0 {
89 format!("{}", max_val as i64)
90 } else {
91 format!("{}", max_val)
92 };
93 Ok(StepOutput::Cmd(CmdOutput {
94 stdout: result,
95 stderr: String::new(),
96 exit_code: 0,
97 duration: std::time::Duration::ZERO,
98 }))
99 }
100 "filter" => {
101 let tmpl = condition_template.ok_or_else(|| {
102 crate::error::StepError::Fail(
103 "reduce: 'filter' requires 'reduce_condition' to be set".to_string(),
104 )
105 })?;
106
107 let mut kept = Vec::new();
108 for it in iterations {
109 let mut vars = std::collections::HashMap::new();
111 vars.insert(
112 "item_output".to_string(),
113 serde_json::Value::String(it.output.text().to_string()),
114 );
115 let simplified_tmpl = tmpl
118 .replace("{{item.output}}", "{{ item_output }}")
119 .replace("{{ item.output }}", "{{ item_output }}");
120 let child_ctx =
121 crate::engine::context::Context::new(String::new(), vars);
122 let rendered = child_ctx
123 .render_template(&simplified_tmpl)
124 .unwrap_or_default();
125 let passes = !rendered.trim().is_empty()
126 && rendered.trim() != "false"
127 && rendered.trim() != "0";
128 if passes {
129 kept.push(it.output.text().to_string());
130 }
131 }
132
133 let joined = kept.join("\n");
134 Ok(StepOutput::Cmd(CmdOutput {
135 stdout: joined,
136 stderr: String::new(),
137 exit_code: 0,
138 duration: std::time::Duration::ZERO,
139 }))
140 }
141 other => Err(crate::error::StepError::Fail(format!(
142 "unknown reduce operation '{}'; expected concat, sum, count, filter, min, max",
143 other
144 ))),
145 }
146}
147
148fn apply_collect(scope: ScopeOutput, mode: &str) -> Result<StepOutput, crate::error::StepError> {
150 match mode {
151 "text" => {
152 let joined = scope
153 .iterations
154 .iter()
155 .map(|it| it.output.text().to_string())
156 .collect::<Vec<_>>()
157 .join("\n");
158 Ok(StepOutput::Cmd(CmdOutput {
159 stdout: joined,
160 stderr: String::new(),
161 exit_code: 0,
162 duration: std::time::Duration::ZERO,
163 }))
164 }
165 "all" | "json" => {
166 let arr: Vec<serde_json::Value> = scope
167 .iterations
168 .iter()
169 .map(|it| serde_json::Value::String(it.output.text().to_string()))
170 .collect();
171 let json = serde_json::to_string(&arr)
172 .map_err(|e| crate::error::StepError::Fail(format!("collect serialize error: {e}")))?;
173 Ok(StepOutput::Cmd(CmdOutput {
174 stdout: json,
175 stderr: String::new(),
176 exit_code: 0,
177 duration: std::time::Duration::ZERO,
178 }))
179 }
180 other => Err(crate::error::StepError::Fail(format!(
181 "unknown collect mode '{}'; expected all, text, or json",
182 other
183 ))),
184 }
185}
186
187pub struct MapExecutor {
188 scopes: HashMap<String, ScopeDef>,
189}
190
191impl MapExecutor {
192 pub fn new(scopes: &HashMap<String, ScopeDef>) -> Self {
193 Self {
194 scopes: scopes.clone(),
195 }
196 }
197}
198
199#[async_trait]
200impl StepExecutor for MapExecutor {
201 async fn execute(
202 &self,
203 step: &StepDef,
204 _config: &StepConfig,
205 ctx: &Context,
206 ) -> Result<StepOutput, StepError> {
207 let items_template = step
208 .items
209 .as_ref()
210 .ok_or_else(|| StepError::Fail("map step missing 'items' field".into()))?;
211
212 let scope_name = step
213 .scope
214 .as_ref()
215 .ok_or_else(|| StepError::Fail("map step missing 'scope' field".into()))?;
216
217 let scope = self
218 .scopes
219 .get(scope_name)
220 .ok_or_else(|| StepError::Fail(format!("scope '{}' not found", scope_name)))?
221 .clone();
222
223 let rendered_items = ctx.render_template(items_template)?;
224
225 let items: Vec<String> = if rendered_items.trim().starts_with('[') {
227 serde_json::from_str::<Vec<serde_json::Value>>(&rendered_items)
228 .map(|arr| {
229 arr.into_iter()
230 .map(|v| match v {
231 serde_json::Value::String(s) => s,
232 other => other.to_string(),
233 })
234 .collect()
235 })
236 .unwrap_or_else(|_| {
237 rendered_items
238 .lines()
239 .filter(|l| !l.trim().is_empty())
240 .map(|l| l.to_string())
241 .collect()
242 })
243 } else {
244 rendered_items
245 .lines()
246 .filter(|l| !l.trim().is_empty())
247 .map(|l| l.to_string())
248 .collect()
249 };
250
251 let parallel_count = step.parallel.unwrap_or(0);
252
253 let scope_output = if parallel_count == 0 {
254 serial_execute(items, &scope, ctx, &self.scopes).await?
256 } else {
257 parallel_execute(items, &scope, ctx, &self.scopes, parallel_count).await?
259 };
260
261 let reduce_mode = _config.get_str("reduce").map(|s| s.to_string());
263 if let Some(ref reducer) = reduce_mode {
264 if let StepOutput::Scope(ref s) = scope_output {
265 let condition = _config.get_str("reduce_condition");
266 return apply_reduce(s, reducer, condition);
267 }
268 }
269
270 let collect_mode = _config.get_str("collect").map(|s| s.to_string());
272 match (scope_output, collect_mode) {
273 (StepOutput::Scope(s), Some(mode)) => apply_collect(s, &mode),
274 (output, _) => Ok(output),
275 }
276 }
277}
278
279async fn serial_execute(
280 items: Vec<String>,
281 scope: &ScopeDef,
282 ctx: &Context,
283 scopes: &HashMap<String, ScopeDef>,
284) -> Result<StepOutput, StepError> {
285 let mut iterations = Vec::new();
286
287 for (i, item) in items.iter().enumerate() {
288 let mut child_ctx = make_child_ctx(ctx, Some(serde_json::Value::String(item.clone())), i);
289
290 let iter_output = execute_scope_steps(scope, &mut child_ctx, scopes).await?;
291
292 iterations.push(IterationOutput {
293 index: i,
294 output: iter_output,
295 });
296 }
297
298 let final_value = iterations.last().map(|i| Box::new(i.output.clone()));
299 Ok(StepOutput::Scope(ScopeOutput {
300 iterations,
301 final_value,
302 }))
303}
304
305async fn parallel_execute(
306 items: Vec<String>,
307 scope: &ScopeDef,
308 ctx: &Context,
309 scopes: &HashMap<String, ScopeDef>,
310 parallel_count: usize,
311) -> Result<StepOutput, StepError> {
312 let sem = Arc::new(Semaphore::new(parallel_count));
313 let mut set: JoinSet<(usize, Result<StepOutput, StepError>)> = JoinSet::new();
314
315 for (i, item) in items.iter().enumerate() {
316 let sem = Arc::clone(&sem);
317 let item_val = serde_json::Value::String(item.clone());
318 let child_ctx = make_child_ctx(ctx, Some(item_val), i);
319 let scope_clone = scope.clone();
320 let scopes_clone = scopes.clone();
321
322 set.spawn(async move {
323 let _permit = sem.acquire().await.expect("semaphore closed");
324 let result = execute_scope_steps_owned(scope_clone, child_ctx, scopes_clone).await;
325 (i, result)
326 });
327 }
328
329 let mut results: Vec<Option<StepOutput>> = vec![None; items.len()];
330
331 while let Some(res) = set.join_next().await {
332 match res {
333 Ok((i, Ok(output))) => {
334 results[i] = Some(output);
335 }
336 Ok((_, Err(e))) => {
337 set.abort_all();
338 return Err(e);
339 }
340 Err(e) => {
341 set.abort_all();
342 return Err(StepError::Fail(format!("Task panicked: {e}")));
343 }
344 }
345 }
346
347 let iterations: Vec<IterationOutput> = results
348 .into_iter()
349 .enumerate()
350 .map(|(i, opt)| IterationOutput {
351 index: i,
352 output: opt.unwrap_or(StepOutput::Empty),
353 })
354 .collect();
355
356 let final_value = iterations.last().map(|i| Box::new(i.output.clone()));
357 Ok(StepOutput::Scope(ScopeOutput {
358 iterations,
359 final_value,
360 }))
361}
362
363fn make_child_ctx(
364 parent: &Context,
365 scope_value: Option<serde_json::Value>,
366 index: usize,
367) -> Context {
368 let target = parent
369 .get_var("target")
370 .and_then(|v| v.as_str())
371 .unwrap_or("")
372 .to_string();
373 let mut ctx = Context::new(target, HashMap::new());
374 ctx.scope_value = scope_value;
375 ctx.scope_index = index;
376 ctx
377}
378
379async fn execute_scope_steps(
380 scope: &ScopeDef,
381 child_ctx: &mut Context,
382 scopes: &HashMap<String, ScopeDef>,
383) -> Result<StepOutput, StepError> {
384 let mut last_output = StepOutput::Empty;
385
386 for scope_step in &scope.steps {
387 let config = StepConfig::default();
388 let result = dispatch_scope_step(scope_step, &config, child_ctx, scopes).await;
389
390 match result {
391 Ok(output) => {
392 child_ctx.store(&scope_step.name, output.clone());
393 last_output = output;
394 }
395 Err(StepError::ControlFlow(ControlFlow::Break { value, .. })) => {
396 if let Some(v) = value {
397 last_output = v;
398 }
399 break;
400 }
401 Err(StepError::ControlFlow(ControlFlow::Skip { .. })) => {
402 child_ctx.store(&scope_step.name, StepOutput::Empty);
403 }
404 Err(StepError::ControlFlow(ControlFlow::Next { .. })) => {
405 break;
406 }
407 Err(e) => return Err(e),
408 }
409 }
410
411 if let Some(outputs_template) = &scope.outputs {
413 match child_ctx.render_template(outputs_template) {
414 Ok(rendered) => {
415 return Ok(StepOutput::Cmd(CmdOutput {
416 stdout: rendered,
417 stderr: String::new(),
418 exit_code: 0,
419 duration: std::time::Duration::ZERO,
420 }));
421 }
422 Err(_) => {}
423 }
424 }
425
426 Ok(last_output)
427}
428
429async fn execute_scope_steps_owned(
430 scope: ScopeDef,
431 mut child_ctx: Context,
432 scopes: HashMap<String, ScopeDef>,
433) -> Result<StepOutput, StepError> {
434 execute_scope_steps(&scope, &mut child_ctx, &scopes).await
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::schema::{ScopeDef, StepType};
    use std::collections::HashMap;

    // Builds a `cmd` step named `name` that runs `run`.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Cmd,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    // Builds a `map` step over `items` that runs the scope named `scope`, with optional
    // parallelism and step-definition config.
    fn map_step_full(
        name: &str,
        items: &str,
        scope: &str,
        parallel: Option<usize>,
        config: HashMap<String, serde_yaml::Value>,
    ) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Map,
            run: None,
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: Some(scope.to_string()),
            max_iterations: None,
            initial_value: None,
            items: Some(items.to_string()),
            parallel,
            steps: None,
            config,
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    fn map_step(name: &str, items: &str, scope: &str, parallel: Option<usize>) -> StepDef {
        map_step_full(name, items, scope, parallel, HashMap::new())
    }

    fn map_step_with_config(
        name: &str,
        items: &str,
        scope: &str,
        config_values: HashMap<String, serde_yaml::Value>,
    ) -> StepDef {
        map_step_full(name, items, scope, None, config_values)
    }

    // A scope with a single step that echoes the current item.
    fn echo_scope() -> ScopeDef {
        ScopeDef {
            steps: vec![cmd_step("echo", "echo {{ scope.value }}")],
            outputs: None,
        }
    }

    // Scope table containing only `echo_scope`, registered under that name.
    fn echo_scopes() -> HashMap<String, ScopeDef> {
        let mut scopes = HashMap::new();
        scopes.insert("echo_scope".to_string(), echo_scope());
        scopes
    }

    // Builds a runtime step config from string key/value pairs.
    fn string_config(pairs: &[(&str, &str)]) -> StepConfig {
        let values = pairs
            .iter()
            .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string())))
            .collect();
        StepConfig { values }
    }

    // Unwraps a `StepOutput::Scope`, failing the test for any other variant.
    fn expect_scope(output: &StepOutput) -> &ScopeOutput {
        match output {
            StepOutput::Scope(s) => s,
            _ => panic!("Expected Scope output"),
        }
    }

    // Runs a `map` step over `items` with `echo_scope` and the given runtime config.
    async fn run_echo_map(
        items: &str,
        parallel: Option<usize>,
        config: &StepConfig,
    ) -> Result<StepOutput, StepError> {
        let scopes = echo_scopes();
        let step = map_step("map_test", items, "echo_scope", parallel);
        let executor = MapExecutor::new(&scopes);
        let ctx = Context::new(String::new(), HashMap::new());
        executor.execute(&step, config, &ctx).await
    }

    // Runs a serial echo map with `reduce = reducer` and returns the trimmed result text.
    async fn reduce_echo(items: &str, reducer: &str) -> String {
        let result = run_echo_map(items, None, &string_config(&[("reduce", reducer)]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        result.text().trim().to_string()
    }

    #[tokio::test]
    async fn map_three_items_serial() {
        let result = run_echo_map("alpha\nbeta\ngamma", None, &StepConfig::default())
            .await
            .unwrap();
        let scope_out = expect_scope(&result);
        assert_eq!(scope_out.iterations.len(), 3);
        assert!(scope_out.iterations[0].output.text().contains("alpha"));
        assert!(scope_out.iterations[1].output.text().contains("beta"));
        assert!(scope_out.iterations[2].output.text().contains("gamma"));
    }

    #[tokio::test]
    async fn map_three_items_parallel() {
        let result = run_echo_map("a\nb\nc", Some(3), &StepConfig::default())
            .await
            .unwrap();
        let scope_out = expect_scope(&result);
        assert_eq!(scope_out.iterations.len(), 3);
        assert!(scope_out.iterations[0].output.text().contains('a'));
        assert!(scope_out.iterations[1].output.text().contains('b'));
        assert!(scope_out.iterations[2].output.text().contains('c'));
    }

    #[tokio::test]
    async fn map_collect_text_joins_with_newlines() {
        let scopes = echo_scopes();
        let mut cfg = HashMap::new();
        cfg.insert(
            "collect".to_string(),
            serde_yaml::Value::String("text".to_string()),
        );
        let step =
            map_step_with_config("map_collect_text", "alpha\nbeta\ngamma", "echo_scope", cfg);
        let executor = MapExecutor::new(&scopes);
        let config = string_config(&[("collect", "text")]);
        let ctx = Context::new(String::new(), HashMap::new());

        let result = executor.execute(&step, &config, &ctx).await.unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let text = result.text();
        let positions: Vec<usize> = ["alpha", "beta", "gamma"]
            .iter()
            .map(|word| {
                text.find(*word)
                    .unwrap_or_else(|| panic!("Missing {} in: {}", word, text))
            })
            .collect();
        assert!(
            positions.windows(2).all(|w| w[0] < w[1]),
            "Outputs out of order: {}",
            text
        );
    }

    #[tokio::test]
    async fn map_collect_all_produces_json_array() {
        let result = run_echo_map("x\ny\nz", None, &string_config(&[("collect", "all")]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let arr: Vec<serde_json::Value> =
            serde_json::from_str(result.text()).expect("Expected JSON array");
        assert_eq!(arr.len(), 3);
        for (value, expected) in arr.iter().zip(["x", "y", "z"].iter()) {
            let s = value.as_str().expect("Expected string elements");
            assert!(s.contains(*expected), "Expected {} in element {:?}", expected, s);
        }
    }

    #[tokio::test]
    async fn map_no_collect_returns_scope_output() {
        let result = run_echo_map("a\nb", None, &StepConfig::default())
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Scope(_)));
    }

    #[tokio::test]
    async fn map_reduce_concat_joins_outputs() {
        let text = reduce_echo("hello\nworld", "concat").await;
        assert!(text.contains("hello"), "Missing hello: {}", text);
        assert!(text.contains("world"), "Missing world: {}", text);
    }

    #[tokio::test]
    async fn map_reduce_sum_adds_numbers() {
        let text = reduce_echo("10\n20\n30", "sum").await;
        assert_eq!(text, "60", "Expected 60, got: {}", text);
    }

    #[tokio::test]
    async fn map_reduce_count_counts_items() {
        assert_eq!(reduce_echo("a\nb\nc", "count").await, "3");
    }

    #[tokio::test]
    async fn map_reduce_min_and_max_pick_extremes() {
        assert_eq!(reduce_echo("3\n1\n2", "min").await, "1");
        assert_eq!(reduce_echo("3\n1\n2", "max").await, "3");
    }

    #[tokio::test]
    async fn map_reduce_filter_removes_empty() {
        let config = string_config(&[
            ("reduce", "filter"),
            ("reduce_condition", "{{ item.output }}"),
        ]);
        let result = run_echo_map("hello\n\nworld", None, &config).await.unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let text = result.text();
        let lines: Vec<&str> = text
            .lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .collect();
        // The blank input line is dropped before mapping, so at most two items remain, and
        // only original outputs may be kept.
        assert!(lines.len() <= 2, "Should have at most 2 lines: {:?}", lines);
        assert!(
            lines.iter().all(|l| *l == "hello" || *l == "world"),
            "Unexpected kept output: {:?}",
            lines
        );
    }

    #[tokio::test]
    async fn map_reduce_filter_without_condition_fails() {
        let result = run_echo_map("a", None, &string_config(&[("reduce", "filter")])).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn map_unknown_reduce_fails() {
        let result = run_echo_map("a", None, &string_config(&[("reduce", "median")])).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn map_order_preserved_parallel() {
        let result = run_echo_map("first\nsecond\nthird", Some(3), &StepConfig::default())
            .await
            .unwrap();
        let scope_out = expect_scope(&result);
        assert_eq!(scope_out.iterations[0].index, 0);
        assert_eq!(scope_out.iterations[1].index, 1);
        assert_eq!(scope_out.iterations[2].index, 2);
        assert!(scope_out.iterations[0].output.text().contains("first"));
        assert!(scope_out.iterations[1].output.text().contains("second"));
        assert!(scope_out.iterations[2].output.text().contains("third"));
    }
}
747}