1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use tokio::task::JoinSet;
5
6use crate::config::StepConfig;
7use crate::engine::context::Context;
8use crate::error::StepError;
9use crate::workflow::schema::{ScopeDef, StepDef, StepType};
10
11use super::{
12 agent::AgentExecutor, cmd::CmdExecutor, chat::ChatExecutor, gate::GateExecutor,
13 StepExecutor, StepOutput,
14};
15
16pub struct ParallelExecutor {
17 scopes: HashMap<String, ScopeDef>,
18}
19
20impl ParallelExecutor {
21 pub fn new(scopes: &HashMap<String, ScopeDef>) -> Self {
22 Self {
23 scopes: scopes.clone(),
24 }
25 }
26}
27
28#[async_trait]
29impl StepExecutor for ParallelExecutor {
30 async fn execute(
31 &self,
32 step: &StepDef,
33 _config: &StepConfig,
34 ctx: &Context,
35 ) -> Result<StepOutput, StepError> {
36 let nested_steps = step
37 .steps
38 .as_ref()
39 .ok_or_else(|| StepError::Fail("parallel step missing 'steps' field".into()))?;
40
41 let mut set: JoinSet<(String, Result<StepOutput, StepError>)> = JoinSet::new();
42
43 for sub_step in nested_steps.iter() {
44 let sub = sub_step.clone();
45 let scopes = self.scopes.clone();
46 let child_ctx = make_child_ctx(ctx);
47
48 set.spawn(async move {
49 let result = dispatch_step(&sub, &StepConfig::default(), &child_ctx, &scopes).await;
50 (sub.name.clone(), result)
51 });
52 }
53
54 let mut outputs: HashMap<String, StepOutput> = HashMap::new();
55 let mut error: Option<StepError> = None;
56
57 while let Some(res) = set.join_next().await {
58 match res {
59 Ok((name, Ok(output))) => {
60 outputs.insert(name, output);
61 }
62 Ok((name, Err(StepError::ControlFlow(crate::control_flow::ControlFlow::Skip { .. })))) => {
63 outputs.insert(name, StepOutput::Empty);
64 }
65 Ok((_, Err(e))) => {
66 set.abort_all();
67 error = Some(e);
68 }
69 Err(e) => {
70 set.abort_all();
71 if error.is_none() {
72 error = Some(StepError::Fail(format!("Parallel task panicked: {e}")));
73 }
74 }
75 }
76 }
77
78 if let Some(e) = error {
79 return Err(e);
80 }
81
82 let last_output = nested_steps
84 .last()
85 .and_then(|s| outputs.get(&s.name))
86 .cloned()
87 .unwrap_or(StepOutput::Empty);
88
89 Ok(last_output)
90 }
91}
92
93fn make_child_ctx(parent: &Context) -> Context {
94 let target = parent
95 .get_var("target")
96 .and_then(|v| v.as_str())
97 .unwrap_or("")
98 .to_string();
99 Context::new(target, HashMap::new())
100}
101
102async fn dispatch_step(
103 step: &StepDef,
104 _config: &StepConfig,
105 ctx: &Context,
106 _scopes: &HashMap<String, ScopeDef>,
107) -> Result<StepOutput, StepError> {
108 let values: HashMap<String, serde_json::Value> = step
110 .config
111 .iter()
112 .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or(serde_json::Value::Null)))
113 .collect();
114 let step_config = StepConfig { values };
115
116 match step.step_type {
117 StepType::Cmd => CmdExecutor.execute(step, &step_config, ctx).await,
118 StepType::Agent => AgentExecutor.execute(step, &step_config, ctx).await,
119 StepType::Gate => GateExecutor.execute(step, &step_config, ctx).await,
120 StepType::Chat => ChatExecutor.execute(step, &step_config, ctx).await,
121 _ => Err(StepError::Fail(format!(
122 "Step type '{}' not supported in parallel",
123 step.step_type
124 ))),
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::schema::StepType;
    use std::collections::HashMap;

    /// Builds a `cmd` step that runs `run`, with every optional field unset.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Cmd,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Builds a `parallel` step wrapping `sub_steps`; all other fields are unset.
    fn parallel_step(name: &str, sub_steps: Vec<StepDef>) -> StepDef {
        StepDef {
            step_type: StepType::Parallel,
            run: None,
            steps: Some(sub_steps),
            ..cmd_step(name, "")
        }
    }

    /// Executes `step` with a fresh executor, default config and empty context.
    async fn run(step: &StepDef) -> Result<StepOutput, StepError> {
        let scopes = HashMap::new();
        let executor = ParallelExecutor::new(&scopes);
        let ctx = Context::new(String::new(), HashMap::new());
        executor.execute(step, &StepConfig::default(), &ctx).await
    }

    #[tokio::test]
    async fn parallel_two_cmd_steps() {
        let step = parallel_step(
            "parallel_test",
            vec![
                cmd_step("step_a", "echo alpha"),
                cmd_step("step_b", "echo beta"),
            ],
        );

        let result = run(&step).await;
        assert!(result.is_ok(), "Expected success: {:?}", result.err());
    }

    #[tokio::test]
    async fn parallel_one_failure_returns_error() {
        // `Template` is not dispatchable inside a parallel block, so this sub-step fails
        // deterministically without depending on the shell.
        let mut unsupported = cmd_step("fail_step", "echo fake");
        unsupported.step_type = StepType::Template;

        let step = parallel_step(
            "parallel_fail",
            vec![cmd_step("ok_step", "echo ok"), unsupported],
        );

        match run(&step).await {
            Err(StepError::Fail(msg)) => assert!(
                msg.contains("not supported in parallel"),
                "unexpected failure message: {msg}"
            ),
            other => panic!(
                "Expected StepError::Fail for unsupported step, got error {:?}",
                other.err()
            ),
        }
    }

    #[tokio::test]
    async fn parallel_missing_steps_field_fails() {
        let mut step = parallel_step("no_steps", Vec::new());
        step.steps = None;

        match run(&step).await {
            Err(StepError::Fail(msg)) => {
                assert_eq!(msg, "parallel step missing 'steps' field")
            }
            other => panic!(
                "Expected StepError::Fail for missing steps, got error {:?}",
                other.err()
            ),
        }
    }

    #[tokio::test]
    async fn parallel_with_no_sub_steps_returns_empty() {
        let step = parallel_step("empty", Vec::new());

        let result = run(&step).await;
        assert!(
            matches!(result, Ok(StepOutput::Empty)),
            "Expected Ok(StepOutput::Empty), got error {:?}",
            result.err()
        );
    }
}