Skip to main content

minion_engine/steps/
script.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use async_trait::async_trait;
5use rhai::{Dynamic, Engine as RhaiEngine, EvalAltResult, Scope};
6
7use crate::config::StepConfig;
8use crate::engine::context::Context;
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{CmdOutput, StepExecutor, StepOutput};
13
/// Maximum number of Rhai operations before timeout (prevents infinite loops).
///
/// Passed to the Rhai engine via `set_max_operations` before a script is
/// evaluated, so a runaway script fails the step instead of running forever.
const MAX_OPERATIONS: u64 = 1_000_000;

/// Step executor that evaluates a step's `run` field as a Rhai script.
///
/// The type carries no state; all per-run data is built inside `execute`.
pub struct ScriptExecutor;
18
19#[async_trait]
20impl StepExecutor for ScriptExecutor {
21    async fn execute(
22        &self,
23        step: &StepDef,
24        _config: &StepConfig,
25        ctx: &Context,
26    ) -> Result<StepOutput, StepError> {
27        let script = step
28            .run
29            .as_ref()
30            .ok_or_else(|| StepError::Fail("script step missing 'run' field".into()))?
31            .clone();
32
33        // Build a flat snapshot of context values for ctx_get access
34        let ctx_snapshot = build_ctx_snapshot(ctx);
35
36        // Build Rhai engine with operation limit
37        let mut engine = RhaiEngine::new();
38        engine.set_max_operations(MAX_OPERATIONS);
39
40        // Register ctx_get(key) — reads from the context snapshot
41        let snapshot = ctx_snapshot.clone();
42        engine.register_fn("ctx_get", move |key: &str| -> Dynamic {
43            match snapshot.get(key) {
44                Some(v) => json_to_dynamic(v),
45                None => Dynamic::UNIT,
46            }
47        });
48
49        // Register ctx_set(key, value) — writes to thread_local storage
50        thread_local! {
51            static CTX_WRITES: RefCell<HashMap<String, serde_json::Value>> =
52                RefCell::new(HashMap::new());
53        }
54        CTX_WRITES.with(|w| w.borrow_mut().clear());
55
56        engine.register_fn("ctx_set", |key: &str, value: Dynamic| {
57            let json_val = dynamic_to_json(&value);
58            CTX_WRITES.with(|w| w.borrow_mut().insert(key.to_string(), json_val));
59        });
60
61        // Evaluate the script synchronously
62        let mut scope = Scope::new();
63        let result = engine.eval_with_scope::<Dynamic>(&mut scope, &script);
64
65        let output_text = match result {
66            Ok(val) => dynamic_to_string(&val),
67            Err(e) => {
68                return Err(StepError::Fail(format_rhai_error(&e)));
69            }
70        };
71
72        Ok(StepOutput::Cmd(CmdOutput {
73            stdout: output_text,
74            stderr: String::new(),
75            exit_code: 0,
76            duration: std::time::Duration::ZERO,
77        }))
78    }
79}
80
81/// Build a flat key-value snapshot from the context for ctx_get access.
82/// Uses the tera context's `get()` API to extract step outputs.
83fn build_ctx_snapshot(ctx: &Context) -> HashMap<String, serde_json::Value> {
84    let tera_ctx = ctx.to_tera_context();
85    let mut flat: HashMap<String, serde_json::Value> = HashMap::new();
86
87    // Extract steps map: "step_name.field" => value
88    if let Some(serde_json::Value::Object(steps_map)) = tera_ctx.get("steps") {
89        for (step_name, step_val) in steps_map {
90            if let serde_json::Value::Object(fields) = step_val {
91                for (field, field_val) in fields {
92                    flat.insert(format!("{}.{}", step_name, field), field_val.clone());
93                }
94            }
95        }
96    }
97
98    // Extract top-level variables
99    if let Some(target) = tera_ctx.get("target") {
100        flat.insert("target".to_string(), target.clone());
101    }
102
103    flat
104}
105
106/// Convert a serde_json::Value to a Rhai Dynamic
107fn json_to_dynamic(val: &serde_json::Value) -> Dynamic {
108    match val {
109        serde_json::Value::Null => Dynamic::UNIT,
110        serde_json::Value::Bool(b) => Dynamic::from(*b),
111        serde_json::Value::Number(n) => {
112            if let Some(i) = n.as_i64() {
113                Dynamic::from(i)
114            } else if let Some(f) = n.as_f64() {
115                Dynamic::from(f)
116            } else {
117                Dynamic::UNIT
118            }
119        }
120        serde_json::Value::String(s) => Dynamic::from(s.clone()),
121        serde_json::Value::Array(arr) => {
122            let v: rhai::Array = arr.iter().map(json_to_dynamic).collect();
123            Dynamic::from(v)
124        }
125        serde_json::Value::Object(obj) => {
126            let mut map = rhai::Map::new();
127            for (k, v) in obj {
128                map.insert(k.clone().into(), json_to_dynamic(v));
129            }
130            Dynamic::from(map)
131        }
132    }
133}
134
135/// Convert a Rhai Dynamic to a serde_json::Value
136fn dynamic_to_json(val: &Dynamic) -> serde_json::Value {
137    if val.is_unit() {
138        serde_json::Value::Null
139    } else if let Some(b) = val.clone().try_cast::<bool>() {
140        serde_json::Value::Bool(b)
141    } else if let Some(i) = val.clone().try_cast::<i64>() {
142        serde_json::json!(i)
143    } else if let Some(f) = val.clone().try_cast::<f64>() {
144        serde_json::json!(f)
145    } else if let Some(s) = val.clone().try_cast::<String>() {
146        serde_json::Value::String(s)
147    } else {
148        serde_json::Value::String(val.to_string())
149    }
150}
151
152/// Convert a Rhai Dynamic to a display string (script return value → step output text)
153fn dynamic_to_string(val: &Dynamic) -> String {
154    if val.is_unit() {
155        String::new()
156    } else if let Some(s) = val.clone().try_cast::<String>() {
157        s
158    } else {
159        val.to_string()
160    }
161}
162
163/// Format a Rhai evaluation error with line number info when available
164fn format_rhai_error(e: &EvalAltResult) -> String {
165    format!("Script error: {e}")
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::StepConfig;
    use crate::workflow::schema::StepType;

    /// Builds a script-type step named `name` whose `run` field is `run`;
    /// every other optional field is left unset.
    fn script_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Script,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// A context with an empty target and no variables.
    fn empty_ctx() -> Context {
        Context::new(String::new(), HashMap::new())
    }

    /// Runs `step` through `ScriptExecutor` with the default step config.
    async fn run_script(step: &StepDef, ctx: &Context) -> Result<StepOutput, StepError> {
        let config = StepConfig::default();
        ScriptExecutor.execute(step, &config, ctx).await
    }

    #[tokio::test]
    async fn script_returns_integer_expression() {
        let step = script_step("s", "40 + 2");

        let output = run_script(&step, &empty_ctx()).await.unwrap();
        assert_eq!(output.text().trim(), "42");
    }

    #[tokio::test]
    async fn script_returns_string_value() {
        let step = script_step("s", r#""hello from rhai""#);

        let output = run_script(&step, &empty_ctx()).await.unwrap();
        assert_eq!(output.text(), "hello from rhai");
    }

    #[tokio::test]
    async fn script_runtime_error_returns_step_error() {
        let step = script_step("s", "throw \"oops\";");

        let outcome = run_script(&step, &empty_ctx()).await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(
            err.contains("Script error") || err.contains("oops"),
            "Got: {err}"
        );
    }

    #[tokio::test]
    async fn script_ctx_get_reads_step_output() {
        // Seed the context with a previous step's command output.
        let previous = StepOutput::Cmd(CmdOutput {
            stdout: "hello_world".to_string(),
            stderr: String::new(),
            exit_code: 0,
            duration: std::time::Duration::ZERO,
        });
        let mut ctx = empty_ctx();
        ctx.store("prev", previous);

        let step = script_step("s", r#"let v = ctx_get("prev.stdout"); v"#);

        let output = run_script(&step, &ctx).await.unwrap();
        assert_eq!(output.text(), "hello_world");
    }

    #[tokio::test]
    async fn script_missing_run_field_returns_error() {
        // Same shape as any other script step, but with `run` removed.
        let step = StepDef {
            run: None,
            ..script_step("s", "")
        };

        let outcome = run_script(&step, &empty_ctx()).await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("missing 'run'"));
    }
}