Skip to main content

minion_engine/steps/
parallel.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio::task::JoinSet;
6
7use crate::config::StepConfig;
8use crate::config::manager::ConfigManager;
9use crate::engine::context::Context;
10use crate::error::StepError;
11use crate::workflow::schema::{ScopeDef, StepDef, StepType};
12
13use super::{
14    agent::AgentExecutor, cmd::CmdExecutor, chat::ChatExecutor, gate::GateExecutor,
15    call::resolve_scope_step_config,
16    SandboxAwareExecutor, SharedSandbox, StepExecutor, StepOutput,
17};
18
19pub struct ParallelExecutor {
20    scopes: HashMap<String, ScopeDef>,
21    sandbox: SharedSandbox,
22    config_manager: Option<Arc<ConfigManager>>,
23}
24
25impl ParallelExecutor {
26    pub fn new(scopes: &HashMap<String, ScopeDef>, sandbox: SharedSandbox) -> Self {
27        Self {
28            scopes: scopes.clone(),
29            sandbox,
30            config_manager: None,
31        }
32    }
33
34    pub fn with_config_manager(mut self, cm: Option<Arc<ConfigManager>>) -> Self {
35        self.config_manager = cm;
36        self
37    }
38}
39
40#[async_trait]
41impl StepExecutor for ParallelExecutor {
42    async fn execute(
43        &self,
44        step: &StepDef,
45        _config: &StepConfig,
46        ctx: &Context,
47    ) -> Result<StepOutput, StepError> {
48        let nested_steps = step
49            .steps
50            .as_ref()
51            .ok_or_else(|| StepError::Fail("parallel step missing 'steps' field".into()))?;
52
53        let mut set: JoinSet<(String, Result<StepOutput, StepError>)> = JoinSet::new();
54
55        for sub_step in nested_steps.iter() {
56            let sub = sub_step.clone();
57            let scopes = self.scopes.clone();
58            let child_ctx = make_child_ctx(ctx);
59            let sandbox_clone = self.sandbox.clone();
60            let cm_clone = self.config_manager.clone();
61
62            set.spawn(async move {
63                let step_config = resolve_scope_step_config(&cm_clone, &sub);
64                let result = dispatch_step(&sub, &step_config, &child_ctx, &scopes, &sandbox_clone).await;
65                (sub.name.clone(), result)
66            });
67        }
68
69        let mut outputs: HashMap<String, StepOutput> = HashMap::new();
70        let mut error: Option<StepError> = None;
71
72        while let Some(res) = set.join_next().await {
73            match res {
74                Ok((name, Ok(output))) => {
75                    outputs.insert(name, output);
76                }
77                Ok((name, Err(StepError::ControlFlow(crate::control_flow::ControlFlow::Skip { .. })))) => {
78                    outputs.insert(name, StepOutput::Empty);
79                }
80                Ok((_, Err(e))) => {
81                    set.abort_all();
82                    error = Some(e);
83                }
84                Err(e) => {
85                    set.abort_all();
86                    if error.is_none() {
87                        error = Some(StepError::Fail(format!("Parallel task panicked: {e}")));
88                    }
89                }
90            }
91        }
92
93        if let Some(e) = error {
94            return Err(e);
95        }
96
97        // Return combined output — for now return last output or Empty
98        let last_output = nested_steps
99            .last()
100            .and_then(|s| outputs.get(&s.name))
101            .cloned()
102            .unwrap_or(StepOutput::Empty);
103
104        Ok(last_output)
105    }
106}
107
108fn make_child_ctx(parent: &Context) -> Context {
109    let target = parent
110        .get_var("target")
111        .and_then(|v| v.as_str())
112        .unwrap_or("")
113        .to_string();
114    Context::new(target, HashMap::new())
115}
116
117async fn dispatch_step(
118    step: &StepDef,
119    _config: &StepConfig,
120    ctx: &Context,
121    _scopes: &HashMap<String, ScopeDef>,
122    sandbox: &SharedSandbox,
123) -> Result<StepOutput, StepError> {
124    // Build config from step's inline config (convert yaml -> json)
125    let values: HashMap<String, serde_json::Value> = step
126        .config
127        .iter()
128        .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or(serde_json::Value::Null)))
129        .collect();
130    let step_config = StepConfig { values };
131
132    match step.step_type {
133        StepType::Cmd => CmdExecutor.execute_sandboxed(step, &step_config, ctx, sandbox).await,
134        StepType::Agent => AgentExecutor.execute_sandboxed(step, &step_config, ctx, sandbox).await,
135        StepType::Gate => GateExecutor.execute(step, &step_config, ctx).await,
136        StepType::Chat => ChatExecutor.execute(step, &step_config, ctx).await,
137        _ => Err(StepError::Fail(format!(
138            "Step type '{}' not supported in parallel",
139            step.step_type
140        ))),
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::workflow::schema::StepType;

    /// Builds a `cmd` step that runs `run`, with every optional field unset.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Cmd,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Builds a `parallel` step wrapping `sub_steps`.
    ///
    /// All other fields reuse `cmd_step`'s defaults, so the field list is not
    /// duplicated.
    fn parallel_step(name: &str, sub_steps: Vec<StepDef>) -> StepDef {
        StepDef {
            step_type: StepType::Parallel,
            run: None,
            steps: Some(sub_steps),
            ..cmd_step(name, "")
        }
    }

    #[tokio::test]
    async fn parallel_two_cmd_steps() {
        let scopes = HashMap::new();
        let step = parallel_step(
            "parallel_test",
            vec![
                cmd_step("step_a", "echo alpha"),
                cmd_step("step_b", "echo beta"),
            ],
        );
        let executor = ParallelExecutor::new(&scopes, None);
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());

        let result = executor.execute(&step, &config, &ctx).await;
        assert!(result.is_ok(), "Expected success: {:?}", result.err());
    }

    #[tokio::test]
    async fn parallel_one_failure_returns_error() {
        let scopes = HashMap::new();
        let step = parallel_step(
            "parallel_fail",
            vec![
                cmd_step("ok_step", "echo ok"),
                {
                    // Use an unsupported step type to force dispatch_step to return Err
                    let mut s = cmd_step("fail_step", "echo fake");
                    s.step_type = StepType::Template;
                    s
                },
            ],
        );
        let executor = ParallelExecutor::new(&scopes, None);
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());

        let result = executor.execute(&step, &config, &ctx).await;
        assert!(result.is_err(), "Expected error due to failing step");
    }

    #[tokio::test]
    async fn parallel_missing_steps_field_returns_error() {
        let scopes = HashMap::new();
        let mut step = parallel_step("parallel_no_steps", vec![]);
        step.steps = None;
        let executor = ParallelExecutor::new(&scopes, None);
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());

        let result = executor.execute(&step, &config, &ctx).await;
        assert!(
            matches!(&result, Err(StepError::Fail(msg)) if msg.contains("missing 'steps' field")),
            "Expected missing-steps error: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn parallel_empty_steps_returns_empty_output() {
        let scopes = HashMap::new();
        let step = parallel_step("parallel_empty", vec![]);
        let executor = ParallelExecutor::new(&scopes, None);
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());

        let result = executor.execute(&step, &config, &ctx).await;
        assert!(
            matches!(result, Ok(StepOutput::Empty)),
            "Expected Ok(StepOutput::Empty): {:?}",
            result.err()
        );
    }
}