Skip to main content

minion_engine/steps/parallel.rs

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use tokio::task::JoinSet;

use crate::config::StepConfig;
use crate::engine::context::Context;
use crate::error::StepError;
use crate::workflow::schema::{ScopeDef, StepDef, StepType};

use super::{
    agent::AgentExecutor, cmd::CmdExecutor, chat::ChatExecutor, gate::GateExecutor,
    StepExecutor, StepOutput,
};
15
16pub struct ParallelExecutor {
17    scopes: HashMap<String, ScopeDef>,
18}
19
20impl ParallelExecutor {
21    pub fn new(scopes: &HashMap<String, ScopeDef>) -> Self {
22        Self {
23            scopes: scopes.clone(),
24        }
25    }
26}
27
28#[async_trait]
29impl StepExecutor for ParallelExecutor {
30    async fn execute(
31        &self,
32        step: &StepDef,
33        _config: &StepConfig,
34        ctx: &Context,
35    ) -> Result<StepOutput, StepError> {
36        let nested_steps = step
37            .steps
38            .as_ref()
39            .ok_or_else(|| StepError::Fail("parallel step missing 'steps' field".into()))?;
40
41        let mut set: JoinSet<(String, Result<StepOutput, StepError>)> = JoinSet::new();
42
43        for sub_step in nested_steps.iter() {
44            let sub = sub_step.clone();
45            let scopes = self.scopes.clone();
46            let child_ctx = make_child_ctx(ctx);
47
48            set.spawn(async move {
49                let result = dispatch_step(&sub, &StepConfig::default(), &child_ctx, &scopes).await;
50                (sub.name.clone(), result)
51            });
52        }
53
54        let mut outputs: HashMap<String, StepOutput> = HashMap::new();
55        let mut error: Option<StepError> = None;
56
57        while let Some(res) = set.join_next().await {
58            match res {
59                Ok((name, Ok(output))) => {
60                    outputs.insert(name, output);
61                }
62                Ok((name, Err(StepError::ControlFlow(crate::control_flow::ControlFlow::Skip { .. })))) => {
63                    outputs.insert(name, StepOutput::Empty);
64                }
65                Ok((_, Err(e))) => {
66                    set.abort_all();
67                    error = Some(e);
68                }
69                Err(e) => {
70                    set.abort_all();
71                    if error.is_none() {
72                        error = Some(StepError::Fail(format!("Parallel task panicked: {e}")));
73                    }
74                }
75            }
76        }
77
78        if let Some(e) = error {
79            return Err(e);
80        }
81
82        // Return combined output — for now return last output or Empty
83        let last_output = nested_steps
84            .last()
85            .and_then(|s| outputs.get(&s.name))
86            .cloned()
87            .unwrap_or(StepOutput::Empty);
88
89        Ok(last_output)
90    }
91}
92
93fn make_child_ctx(parent: &Context) -> Context {
94    let target = parent
95        .get_var("target")
96        .and_then(|v| v.as_str())
97        .unwrap_or("")
98        .to_string();
99    Context::new(target, HashMap::new())
100}
101
102async fn dispatch_step(
103    step: &StepDef,
104    _config: &StepConfig,
105    ctx: &Context,
106    _scopes: &HashMap<String, ScopeDef>,
107) -> Result<StepOutput, StepError> {
108    // Build config from step's inline config (convert yaml -> json)
109    let values: HashMap<String, serde_json::Value> = step
110        .config
111        .iter()
112        .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or(serde_json::Value::Null)))
113        .collect();
114    let step_config = StepConfig { values };
115
116    match step.step_type {
117        StepType::Cmd => CmdExecutor.execute(step, &step_config, ctx).await,
118        StepType::Agent => AgentExecutor.execute(step, &step_config, ctx).await,
119        StepType::Gate => GateExecutor.execute(step, &step_config, ctx).await,
120        StepType::Chat => ChatExecutor.execute(step, &step_config, ctx).await,
121        _ => Err(StepError::Fail(format!(
122            "Step type '{}' not supported in parallel",
123            step.step_type
124        ))),
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::workflow::schema::StepType;

    /// A `StepDef` of the given type with every optional field unset and no config.
    fn bare_step(name: &str, step_type: StepType) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type,
            run: None,
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// A `Cmd` step whose `run` field is the given command line.
    fn cmd_step(name: &str, run: &str) -> StepDef {
        StepDef {
            run: Some(run.to_string()),
            ..bare_step(name, StepType::Cmd)
        }
    }

    /// A `Parallel` step wrapping the given sub-steps.
    fn parallel_step(name: &str, sub_steps: Vec<StepDef>) -> StepDef {
        StepDef {
            steps: Some(sub_steps),
            ..bare_step(name, StepType::Parallel)
        }
    }

    /// Runs `step` through a `ParallelExecutor` with no scopes, a default config
    /// and an empty context.
    async fn run_parallel(step: &StepDef) -> Result<StepOutput, StepError> {
        let scopes = HashMap::new();
        let executor = ParallelExecutor::new(&scopes);
        let ctx = Context::new(String::new(), HashMap::new());
        executor.execute(step, &StepConfig::default(), &ctx).await
    }

    #[tokio::test]
    async fn parallel_two_cmd_steps() {
        let step = parallel_step(
            "parallel_test",
            vec![cmd_step("step_a", "echo alpha"), cmd_step("step_b", "echo beta")],
        );

        let result = run_parallel(&step).await;
        assert!(result.is_ok(), "Expected success: {:?}", result.err());
    }

    #[tokio::test]
    async fn parallel_one_failure_returns_error() {
        // An unsupported sub-step type makes dispatch_step return Err.
        let failing = StepDef {
            step_type: StepType::Template,
            ..cmd_step("fail_step", "echo fake")
        };
        let step = parallel_step("parallel_fail", vec![cmd_step("ok_step", "echo ok"), failing]);

        let result = run_parallel(&step).await;
        assert!(result.is_err(), "Expected error due to failing step");
    }
}