1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::sync::Semaphore;
6use tokio::task::JoinSet;
7
8use crate::config::StepConfig;
9use crate::control_flow::ControlFlow;
10use crate::engine::context::Context;
11use crate::error::StepError;
12use crate::workflow::schema::{ScopeDef, StepDef};
13
14use super::{
15 call::dispatch_scope_step, CmdOutput, IterationOutput, ScopeOutput, StepExecutor, StepOutput,
16};
17
18fn apply_reduce(
20 scope: &ScopeOutput,
21 reducer: &str,
22 condition_template: Option<&str>,
23) -> Result<StepOutput, crate::error::StepError> {
24 let iterations = &scope.iterations;
25
26 match reducer {
27 "concat" => {
28 let joined = iterations
29 .iter()
30 .map(|it| it.output.text().to_string())
31 .collect::<Vec<_>>()
32 .join("\n");
33 Ok(StepOutput::Cmd(CmdOutput {
34 stdout: joined,
35 stderr: String::new(),
36 exit_code: 0,
37 duration: std::time::Duration::ZERO,
38 }))
39 }
40 "sum" => {
41 let sum: f64 = iterations
42 .iter()
43 .map(|it| it.output.text().trim().parse::<f64>().unwrap_or(0.0))
44 .sum();
45 let result = if sum.fract() == 0.0 {
47 format!("{}", sum as i64)
48 } else {
49 format!("{}", sum)
50 };
51 Ok(StepOutput::Cmd(CmdOutput {
52 stdout: result,
53 stderr: String::new(),
54 exit_code: 0,
55 duration: std::time::Duration::ZERO,
56 }))
57 }
58 "count" => {
59 Ok(StepOutput::Cmd(CmdOutput {
60 stdout: iterations.len().to_string(),
61 stderr: String::new(),
62 exit_code: 0,
63 duration: std::time::Duration::ZERO,
64 }))
65 }
66 "min" => {
67 let min_val = iterations
68 .iter()
69 .filter_map(|it| it.output.text().trim().parse::<f64>().ok())
70 .fold(f64::INFINITY, f64::min);
71 let result = if min_val.fract() == 0.0 {
72 format!("{}", min_val as i64)
73 } else {
74 format!("{}", min_val)
75 };
76 Ok(StepOutput::Cmd(CmdOutput {
77 stdout: result,
78 stderr: String::new(),
79 exit_code: 0,
80 duration: std::time::Duration::ZERO,
81 }))
82 }
83 "max" => {
84 let max_val = iterations
85 .iter()
86 .filter_map(|it| it.output.text().trim().parse::<f64>().ok())
87 .fold(f64::NEG_INFINITY, f64::max);
88 let result = if max_val.fract() == 0.0 {
89 format!("{}", max_val as i64)
90 } else {
91 format!("{}", max_val)
92 };
93 Ok(StepOutput::Cmd(CmdOutput {
94 stdout: result,
95 stderr: String::new(),
96 exit_code: 0,
97 duration: std::time::Duration::ZERO,
98 }))
99 }
100 "filter" => {
101 let tmpl = condition_template.ok_or_else(|| {
102 crate::error::StepError::Fail(
103 "reduce: 'filter' requires 'reduce_condition' to be set".to_string(),
104 )
105 })?;
106
107 let mut kept = Vec::new();
108 for it in iterations {
109 let mut vars = std::collections::HashMap::new();
111 vars.insert(
112 "item_output".to_string(),
113 serde_json::Value::String(it.output.text().to_string()),
114 );
115 let simplified_tmpl = tmpl
118 .replace("{{item.output}}", "{{ item_output }}")
119 .replace("{{ item.output }}", "{{ item_output }}");
120 let child_ctx =
121 crate::engine::context::Context::new(String::new(), vars);
122 let rendered = child_ctx
123 .render_template(&simplified_tmpl)
124 .unwrap_or_default();
125 let passes = !rendered.trim().is_empty()
126 && rendered.trim() != "false"
127 && rendered.trim() != "0";
128 if passes {
129 kept.push(it.output.text().to_string());
130 }
131 }
132
133 let joined = kept.join("\n");
134 Ok(StepOutput::Cmd(CmdOutput {
135 stdout: joined,
136 stderr: String::new(),
137 exit_code: 0,
138 duration: std::time::Duration::ZERO,
139 }))
140 }
141 other => Err(crate::error::StepError::Fail(format!(
142 "unknown reduce operation '{}'; expected concat, sum, count, filter, min, max",
143 other
144 ))),
145 }
146}
147
148fn apply_collect(scope: ScopeOutput, mode: &str) -> Result<StepOutput, crate::error::StepError> {
150 match mode {
151 "text" => {
152 let joined = scope
153 .iterations
154 .iter()
155 .map(|it| it.output.text().to_string())
156 .collect::<Vec<_>>()
157 .join("\n");
158 Ok(StepOutput::Cmd(CmdOutput {
159 stdout: joined,
160 stderr: String::new(),
161 exit_code: 0,
162 duration: std::time::Duration::ZERO,
163 }))
164 }
165 "all" | "json" => {
166 let arr: Vec<serde_json::Value> = scope
167 .iterations
168 .iter()
169 .map(|it| serde_json::Value::String(it.output.text().to_string()))
170 .collect();
171 let json = serde_json::to_string(&arr)
172 .map_err(|e| crate::error::StepError::Fail(format!("collect serialize error: {e}")))?;
173 Ok(StepOutput::Cmd(CmdOutput {
174 stdout: json,
175 stderr: String::new(),
176 exit_code: 0,
177 duration: std::time::Duration::ZERO,
178 }))
179 }
180 other => Err(crate::error::StepError::Fail(format!(
181 "unknown collect mode '{}'; expected all, text, or json",
182 other
183 ))),
184 }
185}
186
187pub struct MapExecutor {
188 scopes: HashMap<String, ScopeDef>,
189}
190
191impl MapExecutor {
192 pub fn new(scopes: &HashMap<String, ScopeDef>) -> Self {
193 Self {
194 scopes: scopes.clone(),
195 }
196 }
197}
198
199#[async_trait]
200impl StepExecutor for MapExecutor {
201 async fn execute(
202 &self,
203 step: &StepDef,
204 _config: &StepConfig,
205 ctx: &Context,
206 ) -> Result<StepOutput, StepError> {
207 let items_template = step
208 .items
209 .as_ref()
210 .ok_or_else(|| StepError::Fail("map step missing 'items' field".into()))?;
211
212 let scope_name = step
213 .scope
214 .as_ref()
215 .ok_or_else(|| StepError::Fail("map step missing 'scope' field".into()))?;
216
217 let scope = self
218 .scopes
219 .get(scope_name)
220 .ok_or_else(|| StepError::Fail(format!("scope '{}' not found", scope_name)))?
221 .clone();
222
223 let rendered_items = ctx.render_template(items_template)?;
224
225 let items: Vec<String> = if rendered_items.trim().starts_with('[') {
227 serde_json::from_str::<Vec<serde_json::Value>>(&rendered_items)
228 .map(|arr| {
229 arr.into_iter()
230 .map(|v| match v {
231 serde_json::Value::String(s) => s,
232 other => other.to_string(),
233 })
234 .collect()
235 })
236 .unwrap_or_else(|_| {
237 rendered_items
238 .lines()
239 .filter(|l| !l.trim().is_empty())
240 .map(|l| l.to_string())
241 .collect()
242 })
243 } else {
244 rendered_items
245 .lines()
246 .filter(|l| !l.trim().is_empty())
247 .map(|l| l.to_string())
248 .collect()
249 };
250
251 let parallel_count = step.parallel.unwrap_or(0);
252
253 let scope_output = if parallel_count == 0 {
254 serial_execute(items, &scope, ctx, &self.scopes).await?
256 } else {
257 parallel_execute(items, &scope, ctx, &self.scopes, parallel_count).await?
259 };
260
261 let reduce_mode = _config.get_str("reduce").map(|s| s.to_string());
263 if let Some(ref reducer) = reduce_mode {
264 if let StepOutput::Scope(ref s) = scope_output {
265 let condition = _config.get_str("reduce_condition");
266 return apply_reduce(s, reducer, condition);
267 }
268 }
269
270 let collect_mode = _config.get_str("collect").map(|s| s.to_string());
272 match (scope_output, collect_mode) {
273 (StepOutput::Scope(s), Some(mode)) => apply_collect(s, &mode),
274 (output, _) => Ok(output),
275 }
276 }
277}
278
279async fn serial_execute(
280 items: Vec<String>,
281 scope: &ScopeDef,
282 ctx: &Context,
283 scopes: &HashMap<String, ScopeDef>,
284) -> Result<StepOutput, StepError> {
285 let mut iterations = Vec::new();
286
287 for (i, item) in items.iter().enumerate() {
288 let mut child_ctx = make_child_ctx(ctx, Some(serde_json::Value::String(item.clone())), i);
289
290 let iter_output = execute_scope_steps(scope, &mut child_ctx, scopes).await?;
291
292 iterations.push(IterationOutput {
293 index: i,
294 output: iter_output,
295 });
296 }
297
298 let final_value = iterations.last().map(|i| Box::new(i.output.clone()));
299 Ok(StepOutput::Scope(ScopeOutput {
300 iterations,
301 final_value,
302 }))
303}
304
305async fn parallel_execute(
306 items: Vec<String>,
307 scope: &ScopeDef,
308 ctx: &Context,
309 scopes: &HashMap<String, ScopeDef>,
310 parallel_count: usize,
311) -> Result<StepOutput, StepError> {
312 let sem = Arc::new(Semaphore::new(parallel_count));
313 let mut set: JoinSet<(usize, Result<StepOutput, StepError>)> = JoinSet::new();
314
315 for (i, item) in items.iter().enumerate() {
316 let sem = Arc::clone(&sem);
317 let item_val = serde_json::Value::String(item.clone());
318 let child_ctx = make_child_ctx(ctx, Some(item_val), i);
319 let scope_clone = scope.clone();
320 let scopes_clone = scopes.clone();
321
322 set.spawn(async move {
323 let _permit = sem.acquire().await.expect("semaphore closed");
324 let result = execute_scope_steps_owned(scope_clone, child_ctx, scopes_clone).await;
325 (i, result)
326 });
327 }
328
329 let mut results: Vec<Option<StepOutput>> = vec![None; items.len()];
330
331 while let Some(res) = set.join_next().await {
332 match res {
333 Ok((i, Ok(output))) => {
334 results[i] = Some(output);
335 }
336 Ok((_, Err(e))) => {
337 set.abort_all();
338 return Err(e);
339 }
340 Err(e) => {
341 set.abort_all();
342 return Err(StepError::Fail(format!("Task panicked: {e}")));
343 }
344 }
345 }
346
347 let iterations: Vec<IterationOutput> = results
348 .into_iter()
349 .enumerate()
350 .map(|(i, opt)| IterationOutput {
351 index: i,
352 output: opt.unwrap_or(StepOutput::Empty),
353 })
354 .collect();
355
356 let final_value = iterations.last().map(|i| Box::new(i.output.clone()));
357 Ok(StepOutput::Scope(ScopeOutput {
358 iterations,
359 final_value,
360 }))
361}
362
363fn make_child_ctx(
364 parent: &Context,
365 scope_value: Option<serde_json::Value>,
366 index: usize,
367) -> Context {
368 let target = parent
369 .get_var("target")
370 .and_then(|v| v.as_str())
371 .unwrap_or("")
372 .to_string();
373 let mut ctx = Context::new(target, parent.all_variables());
374 ctx.scope_value = scope_value;
375 ctx.scope_index = index;
376 ctx.stack_info = parent.get_stack_info().cloned();
377 ctx.prompts_dir = parent.prompts_dir.clone();
378 ctx
379}
380
381async fn execute_scope_steps(
382 scope: &ScopeDef,
383 child_ctx: &mut Context,
384 scopes: &HashMap<String, ScopeDef>,
385) -> Result<StepOutput, StepError> {
386 let mut last_output = StepOutput::Empty;
387
388 for scope_step in &scope.steps {
389 let config = StepConfig::default();
390 let result = dispatch_scope_step(scope_step, &config, child_ctx, scopes).await;
391
392 match result {
393 Ok(output) => {
394 child_ctx.store(&scope_step.name, output.clone());
395 last_output = output;
396 }
397 Err(StepError::ControlFlow(ControlFlow::Break { value, .. })) => {
398 if let Some(v) = value {
399 last_output = v;
400 }
401 break;
402 }
403 Err(StepError::ControlFlow(ControlFlow::Skip { .. })) => {
404 child_ctx.store(&scope_step.name, StepOutput::Empty);
405 }
406 Err(StepError::ControlFlow(ControlFlow::Next { .. })) => {
407 break;
408 }
409 Err(e) => return Err(e),
410 }
411 }
412
413 if let Some(outputs_template) = &scope.outputs {
415 match child_ctx.render_template(outputs_template) {
416 Ok(rendered) => {
417 return Ok(StepOutput::Cmd(CmdOutput {
418 stdout: rendered,
419 stderr: String::new(),
420 exit_code: 0,
421 duration: std::time::Duration::ZERO,
422 }));
423 }
424 Err(_) => {}
425 }
426 }
427
428 Ok(last_output)
429}
430
431async fn execute_scope_steps_owned(
432 scope: ScopeDef,
433 mut child_ctx: Context,
434 scopes: HashMap<String, ScopeDef>,
435) -> Result<StepOutput, StepError> {
436 execute_scope_steps(&scope, &mut child_ctx, &scopes).await
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::workflow::schema::{ScopeDef, StepType};

    /// A `cmd` step named `name` whose command template is `run`.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Cmd,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// A `map` step over `items` using `scope`, with an optional parallelism limit.
    fn map_step(name: &str, items: &str, scope: &str, parallel: Option<usize>) -> StepDef {
        StepDef {
            parallel,
            ..map_step_with_config(name, items, scope, HashMap::new())
        }
    }

    /// A sequential `map` step carrying `config_values` in its own `config` map.
    fn map_step_with_config(
        name: &str,
        items: &str,
        scope: &str,
        config_values: HashMap<String, serde_yaml::Value>,
    ) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Map,
            run: None,
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: Some(scope.to_string()),
            max_iterations: None,
            initial_value: None,
            items: Some(items.to_string()),
            parallel: None,
            steps: None,
            config: config_values,
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// A scope with a single step that echoes the current item.
    fn echo_scope() -> ScopeDef {
        ScopeDef {
            steps: vec![cmd_step("echo", "echo {{ scope.value }}")],
            outputs: None,
        }
    }

    /// A `StepConfig` holding the given string-valued keys.
    fn string_config(pairs: &[(&str, &str)]) -> StepConfig {
        let values = pairs
            .iter()
            .map(|(key, value)| (key.to_string(), serde_json::Value::String(value.to_string())))
            .collect();
        StepConfig { values }
    }

    /// Trimmed, non-blank lines of `text`, in order.
    fn non_blank_lines(text: &str) -> Vec<String> {
        text.lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(str::to_string)
            .collect()
    }

    /// Executes `step` with `config` on an executor that knows `echo_scope`, in an empty context.
    async fn run_map(step: &StepDef, config: &StepConfig) -> Result<StepOutput, StepError> {
        let mut scopes = HashMap::new();
        scopes.insert("echo_scope".to_string(), echo_scope());
        let executor = MapExecutor::new(&scopes);
        let ctx = Context::new(String::new(), HashMap::new());
        executor.execute(step, config, &ctx).await
    }

    #[tokio::test]
    async fn map_three_items_serial() {
        let step = map_step("map_test", "alpha\nbeta\ngamma", "echo_scope", None);
        let result = run_map(&step, &StepConfig::default()).await.unwrap();
        if let StepOutput::Scope(scope_out) = &result {
            assert_eq!(scope_out.iterations.len(), 3);
            assert!(scope_out.iterations[0].output.text().contains("alpha"));
            assert!(scope_out.iterations[1].output.text().contains("beta"));
            assert!(scope_out.iterations[2].output.text().contains("gamma"));
        } else {
            panic!("Expected Scope output");
        }
    }

    #[tokio::test]
    async fn map_three_items_parallel() {
        let step = map_step("map_parallel", "a\nb\nc", "echo_scope", Some(3));
        let result = run_map(&step, &StepConfig::default()).await.unwrap();
        if let StepOutput::Scope(scope_out) = &result {
            assert_eq!(scope_out.iterations.len(), 3);
        } else {
            panic!("Expected Scope output");
        }
    }

    #[tokio::test]
    async fn map_collect_text_joins_with_newlines() {
        let mut cfg = HashMap::new();
        cfg.insert(
            "collect".to_string(),
            serde_yaml::Value::String("text".to_string()),
        );
        let step = map_step_with_config("map_collect_text", "alpha\nbeta\ngamma", "echo_scope", cfg);

        let result = run_map(&step, &string_config(&[("collect", "text")]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let text = result.text();
        assert_eq!(
            non_blank_lines(text),
            ["alpha", "beta", "gamma"],
            "Unexpected collected text: {}",
            text
        );
    }

    #[tokio::test]
    async fn map_collect_all_produces_json_array() {
        let step = map_step_with_config("map_collect_all", "x\ny\nz", "echo_scope", HashMap::new());

        let result = run_map(&step, &string_config(&[("collect", "all")]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let arr: Vec<serde_json::Value> =
            serde_json::from_str(result.text()).expect("Expected JSON array");
        let items: Vec<String> = arr
            .iter()
            .map(|v| v.as_str().expect("string element").trim().to_string())
            .collect();
        assert_eq!(items, ["x", "y", "z"]);
    }

    #[tokio::test]
    async fn map_no_collect_returns_scope_output() {
        let step = map_step("map_no_collect", "a\nb", "echo_scope", None);
        let result = run_map(&step, &StepConfig::default()).await.unwrap();
        assert!(matches!(result, StepOutput::Scope(_)));
    }

    #[tokio::test]
    async fn map_reduce_concat_joins_outputs() {
        let step = map_step("map_reduce_concat", "hello\nworld", "echo_scope", None);
        let result = run_map(&step, &string_config(&[("reduce", "concat")]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        assert_eq!(non_blank_lines(result.text()), ["hello", "world"]);
    }

    #[tokio::test]
    async fn map_reduce_sum_adds_numbers() {
        let step = map_step("map_reduce_sum", "10\n20\n30", "echo_scope", None);
        let result = run_map(&step, &string_config(&[("reduce", "sum")]))
            .await
            .unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        let text = result.text().trim().to_string();
        assert_eq!(text, "60", "Expected 60, got: {}", text);
    }

    #[tokio::test]
    async fn map_reduce_count_counts_iterations() {
        let step = map_step("map_reduce_count", "a\nb\nc", "echo_scope", None);
        let result = run_map(&step, &string_config(&[("reduce", "count")]))
            .await
            .unwrap();
        assert_eq!(result.text().trim(), "3");
    }

    #[tokio::test]
    async fn map_reduce_min_and_max_pick_extremes() {
        let step = map_step("map_reduce_extremes", "3\n10\n7", "echo_scope", None);

        let min = run_map(&step, &string_config(&[("reduce", "min")]))
            .await
            .unwrap();
        assert_eq!(min.text().trim(), "3");

        let max = run_map(&step, &string_config(&[("reduce", "max")]))
            .await
            .unwrap();
        assert_eq!(max.text().trim(), "10");
    }

    #[tokio::test]
    async fn map_reduce_filter_removes_empty() {
        // The blank item is dropped while parsing `items`; the truthy condition keeps the rest.
        let step = map_step("map_reduce_filter", "hello\n\nworld", "echo_scope", None);
        let config = string_config(&[
            ("reduce", "filter"),
            ("reduce_condition", "{{ item.output }}"),
        ]);

        let result = run_map(&step, &config).await.unwrap();
        assert!(matches!(result, StepOutput::Cmd(_)));
        assert_eq!(non_blank_lines(result.text()), ["hello", "world"]);
    }

    #[tokio::test]
    async fn map_reduce_filter_drops_falsy_condition() {
        let step = map_step("map_reduce_filter_falsy", "hello\nworld", "echo_scope", None);
        let config = string_config(&[("reduce", "filter"), ("reduce_condition", "false")]);

        let result = run_map(&step, &config).await.unwrap();
        assert!(result.text().trim().is_empty(), "Expected nothing kept: {}", result.text());
    }

    #[tokio::test]
    async fn map_reduce_filter_requires_condition() {
        let step = map_step("map_reduce_filter_no_cond", "a", "echo_scope", None);
        let result = run_map(&step, &string_config(&[("reduce", "filter")])).await;
        assert!(matches!(
            result,
            Err(StepError::Fail(ref msg)) if msg.contains("reduce_condition")
        ));
    }

    #[tokio::test]
    async fn map_reduce_unknown_operation_errors() {
        let step = map_step("map_reduce_unknown", "a", "echo_scope", None);
        let result = run_map(&step, &string_config(&[("reduce", "median")])).await;
        assert!(matches!(
            result,
            Err(StepError::Fail(ref msg)) if msg.contains("unknown reduce operation")
        ));
    }

    #[tokio::test]
    async fn map_order_preserved_parallel() {
        let step = map_step("map_order", "first\nsecond\nthird", "echo_scope", Some(3));
        let result = run_map(&step, &StepConfig::default()).await.unwrap();
        if let StepOutput::Scope(scope_out) = &result {
            assert_eq!(scope_out.iterations[0].index, 0);
            assert_eq!(scope_out.iterations[1].index, 1);
            assert_eq!(scope_out.iterations[2].index, 2);
            assert!(scope_out.iterations[0].output.text().contains("first"));
            assert!(scope_out.iterations[1].output.text().contains("second"));
            assert!(scope_out.iterations[2].output.text().contains("third"));
        } else {
            panic!("Expected Scope output");
        }
    }
}