1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::task::JoinSet;
6
7use crate::config::StepConfig;
8use crate::config::manager::ConfigManager;
9use crate::engine::context::Context;
10use crate::error::StepError;
11use crate::workflow::schema::{ScopeDef, StepDef, StepType};
12
13use super::{
14 agent::AgentExecutor, cmd::CmdExecutor, chat::ChatExecutor, gate::GateExecutor,
15 call::resolve_scope_step_config,
16 SandboxAwareExecutor, SharedSandbox, StepExecutor, StepOutput,
17};
18
19pub struct ParallelExecutor {
20 scopes: HashMap<String, ScopeDef>,
21 sandbox: SharedSandbox,
22 config_manager: Option<Arc<ConfigManager>>,
23}
24
25impl ParallelExecutor {
26 pub fn new(scopes: &HashMap<String, ScopeDef>, sandbox: SharedSandbox) -> Self {
27 Self {
28 scopes: scopes.clone(),
29 sandbox,
30 config_manager: None,
31 }
32 }
33
34 pub fn with_config_manager(mut self, cm: Option<Arc<ConfigManager>>) -> Self {
35 self.config_manager = cm;
36 self
37 }
38}
39
40#[async_trait]
41impl StepExecutor for ParallelExecutor {
42 async fn execute(
43 &self,
44 step: &StepDef,
45 _config: &StepConfig,
46 ctx: &Context,
47 ) -> Result<StepOutput, StepError> {
48 let nested_steps = step
49 .steps
50 .as_ref()
51 .ok_or_else(|| StepError::Fail("parallel step missing 'steps' field".into()))?;
52
53 let mut set: JoinSet<(String, Result<StepOutput, StepError>)> = JoinSet::new();
54
55 for sub_step in nested_steps.iter() {
56 let sub = sub_step.clone();
57 let scopes = self.scopes.clone();
58 let child_ctx = make_child_ctx(ctx);
59 let sandbox_clone = self.sandbox.clone();
60 let cm_clone = self.config_manager.clone();
61
62 set.spawn(async move {
63 let step_config = resolve_scope_step_config(&cm_clone, &sub);
64 let result = dispatch_step(&sub, &step_config, &child_ctx, &scopes, &sandbox_clone).await;
65 (sub.name.clone(), result)
66 });
67 }
68
69 let mut outputs: HashMap<String, StepOutput> = HashMap::new();
70 let mut error: Option<StepError> = None;
71
72 while let Some(res) = set.join_next().await {
73 match res {
74 Ok((name, Ok(output))) => {
75 outputs.insert(name, output);
76 }
77 Ok((name, Err(StepError::ControlFlow(crate::control_flow::ControlFlow::Skip { .. })))) => {
78 outputs.insert(name, StepOutput::Empty);
79 }
80 Ok((_, Err(e))) => {
81 set.abort_all();
82 error = Some(e);
83 }
84 Err(e) => {
85 set.abort_all();
86 if error.is_none() {
87 error = Some(StepError::Fail(format!("Parallel task panicked: {e}")));
88 }
89 }
90 }
91 }
92
93 if let Some(e) = error {
94 return Err(e);
95 }
96
97 let last_output = nested_steps
99 .last()
100 .and_then(|s| outputs.get(&s.name))
101 .cloned()
102 .unwrap_or(StepOutput::Empty);
103
104 Ok(last_output)
105 }
106}
107
108fn make_child_ctx(parent: &Context) -> Context {
109 let target = parent
110 .get_var("target")
111 .and_then(|v| v.as_str())
112 .unwrap_or("")
113 .to_string();
114 Context::new(target, HashMap::new())
115}
116
117async fn dispatch_step(
118 step: &StepDef,
119 _config: &StepConfig,
120 ctx: &Context,
121 _scopes: &HashMap<String, ScopeDef>,
122 sandbox: &SharedSandbox,
123) -> Result<StepOutput, StepError> {
124 let values: HashMap<String, serde_json::Value> = step
126 .config
127 .iter()
128 .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or(serde_json::Value::Null)))
129 .collect();
130 let step_config = StepConfig { values };
131
132 match step.step_type {
133 StepType::Cmd => CmdExecutor.execute_sandboxed(step, &step_config, ctx, sandbox).await,
134 StepType::Agent => AgentExecutor.execute_sandboxed(step, &step_config, ctx, sandbox).await,
135 StepType::Gate => GateExecutor.execute(step, &step_config, ctx).await,
136 StepType::Chat => ChatExecutor.execute(step, &step_config, ctx).await,
137 _ => Err(StepError::Fail(format!(
138 "Step type '{}' not supported in parallel",
139 step.step_type
140 ))),
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::workflow::schema::StepType;

    /// Builds a `cmd` step called `name` that runs the shell command `run`.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Cmd,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Builds a `parallel` step called `name` whose nested `steps` are `sub_steps`.
    fn parallel_step(name: &str, sub_steps: Vec<StepDef>) -> StepDef {
        StepDef {
            step_type: StepType::Parallel,
            run: None,
            steps: Some(sub_steps),
            ..cmd_step(name, "")
        }
    }

    /// Executes `step` with a fresh executor (no scopes, no sandbox) and an empty context.
    async fn run_parallel(step: &StepDef) -> Result<StepOutput, StepError> {
        let scopes = HashMap::new();
        let executor = ParallelExecutor::new(&scopes, None);
        let ctx = Context::new(String::new(), HashMap::new());
        executor.execute(step, &StepConfig::default(), &ctx).await
    }

    #[tokio::test]
    async fn parallel_two_cmd_steps() {
        let step = parallel_step(
            "parallel_test",
            vec![cmd_step("step_a", "echo alpha"), cmd_step("step_b", "echo beta")],
        );

        let result = run_parallel(&step).await;
        assert!(result.is_ok(), "Expected success: {:?}", result.err());
    }

    #[tokio::test]
    async fn parallel_one_failure_returns_error() {
        // `Template` is not dispatchable inside a parallel block, so this sub-step fails.
        let mut unsupported = cmd_step("fail_step", "echo fake");
        unsupported.step_type = StepType::Template;

        let step = parallel_step(
            "parallel_fail",
            vec![cmd_step("ok_step", "echo ok"), unsupported],
        );

        let result = run_parallel(&step).await;
        assert!(result.is_err(), "Expected error due to failing step");
    }
}
236}