Skip to main content

minion_engine/steps/
script.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use async_trait::async_trait;
5use rhai::{Dynamic, Engine as RhaiEngine, EvalAltResult, Scope};
6
7use crate::config::StepConfig;
8use crate::engine::context::Context;
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{CmdOutput, StepExecutor, StepOutput};
13
/// Operation budget handed to `Engine::set_max_operations` for every script run.
///
/// Capping the number of Rhai operations stops runaway scripts (for example an
/// unbounded `loop`) instead of letting them spin forever.
const MAX_OPERATIONS: u64 = 1_000 * 1_000;

/// Step executor that evaluates a step's `run` field as a Rhai script.
pub struct ScriptExecutor;
18
19#[async_trait]
20impl StepExecutor for ScriptExecutor {
21    async fn execute(
22        &self,
23        step: &StepDef,
24        _config: &StepConfig,
25        ctx: &Context,
26    ) -> Result<StepOutput, StepError> {
27        let script = step
28            .run
29            .as_ref()
30            .ok_or_else(|| StepError::Fail("script step missing 'run' field".into()))?
31            .clone();
32
33        // Build a flat snapshot of context values for ctx_get access
34        let ctx_snapshot = build_ctx_snapshot(ctx);
35
36        // Build Rhai engine with operation limit
37        let mut engine = RhaiEngine::new();
38        engine.set_max_operations(MAX_OPERATIONS);
39
40        // Register ctx_get(key) — reads from the context snapshot
41        let snapshot = ctx_snapshot.clone();
42        engine.register_fn("ctx_get", move |key: &str| -> Dynamic {
43            match snapshot.get(key) {
44                Some(v) => json_to_dynamic(v),
45                None => Dynamic::UNIT,
46            }
47        });
48
49        // Register ctx_set(key, value) — writes to thread_local storage
50        thread_local! {
51            static CTX_WRITES: RefCell<HashMap<String, serde_json::Value>> =
52                RefCell::new(HashMap::new());
53        }
54        CTX_WRITES.with(|w| w.borrow_mut().clear());
55
56        engine.register_fn("ctx_set", |key: &str, value: Dynamic| {
57            let json_val = dynamic_to_json(&value);
58            CTX_WRITES.with(|w| w.borrow_mut().insert(key.to_string(), json_val));
59        });
60
61        // Evaluate the script synchronously
62        let mut scope = Scope::new();
63        let result = engine.eval_with_scope::<Dynamic>(&mut scope, &script);
64
65        let output_text = match result {
66            Ok(val) => dynamic_to_string(&val),
67            Err(e) => {
68                return Err(StepError::Fail(format_rhai_error(&e)));
69            }
70        };
71
72        Ok(StepOutput::Cmd(CmdOutput {
73            stdout: output_text,
74            stderr: String::new(),
75            exit_code: 0,
76            duration: std::time::Duration::ZERO,
77        }))
78    }
79}
80
81/// Build a flat key-value snapshot from the context for ctx_get access.
82/// Uses the tera context's `get()` API to extract step outputs.
83fn build_ctx_snapshot(ctx: &Context) -> HashMap<String, serde_json::Value> {
84    let tera_ctx = ctx.to_tera_context();
85    let mut flat: HashMap<String, serde_json::Value> = HashMap::new();
86
87    // Extract steps map: "step_name.field" => value
88    if let Some(steps_val) = tera_ctx.get("steps") {
89        if let serde_json::Value::Object(steps_map) = steps_val {
90            for (step_name, step_val) in steps_map {
91                if let serde_json::Value::Object(fields) = step_val {
92                    for (field, field_val) in fields {
93                        flat.insert(format!("{}.{}", step_name, field), field_val.clone());
94                    }
95                }
96            }
97        }
98    }
99
100    // Extract top-level variables
101    if let Some(target) = tera_ctx.get("target") {
102        flat.insert("target".to_string(), target.clone());
103    }
104
105    flat
106}
107
108/// Convert a serde_json::Value to a Rhai Dynamic
109fn json_to_dynamic(val: &serde_json::Value) -> Dynamic {
110    match val {
111        serde_json::Value::Null => Dynamic::UNIT,
112        serde_json::Value::Bool(b) => Dynamic::from(*b),
113        serde_json::Value::Number(n) => {
114            if let Some(i) = n.as_i64() {
115                Dynamic::from(i)
116            } else if let Some(f) = n.as_f64() {
117                Dynamic::from(f)
118            } else {
119                Dynamic::UNIT
120            }
121        }
122        serde_json::Value::String(s) => Dynamic::from(s.clone()),
123        serde_json::Value::Array(arr) => {
124            let v: rhai::Array = arr.iter().map(json_to_dynamic).collect();
125            Dynamic::from(v)
126        }
127        serde_json::Value::Object(obj) => {
128            let mut map = rhai::Map::new();
129            for (k, v) in obj {
130                map.insert(k.clone().into(), json_to_dynamic(v));
131            }
132            Dynamic::from(map)
133        }
134    }
135}
136
137/// Convert a Rhai Dynamic to a serde_json::Value
138fn dynamic_to_json(val: &Dynamic) -> serde_json::Value {
139    if val.is_unit() {
140        serde_json::Value::Null
141    } else if let Some(b) = val.clone().try_cast::<bool>() {
142        serde_json::Value::Bool(b)
143    } else if let Some(i) = val.clone().try_cast::<i64>() {
144        serde_json::json!(i)
145    } else if let Some(f) = val.clone().try_cast::<f64>() {
146        serde_json::json!(f)
147    } else if let Some(s) = val.clone().try_cast::<String>() {
148        serde_json::Value::String(s)
149    } else {
150        serde_json::Value::String(val.to_string())
151    }
152}
153
154/// Convert a Rhai Dynamic to a display string (script return value → step output text)
155fn dynamic_to_string(val: &Dynamic) -> String {
156    if val.is_unit() {
157        String::new()
158    } else if let Some(s) = val.clone().try_cast::<String>() {
159        s
160    } else {
161        val.to_string()
162    }
163}
164
165/// Format a Rhai evaluation error with line number info when available
166fn format_rhai_error(e: &EvalAltResult) -> String {
167    format!("Script error: {e}")
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::StepConfig;
    use crate::workflow::schema::StepType;
    use std::time::Duration;

    /// Build a script step whose only meaningful fields are `name` and `run`.
    fn script_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Script,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Execute `step` with a default config against an empty context.
    async fn run_in_empty_ctx(step: &StepDef) -> Result<StepOutput, StepError> {
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        ScriptExecutor.execute(step, &config, &ctx).await
    }

    #[tokio::test]
    async fn script_returns_integer_expression() {
        let step = script_step("s", "40 + 2");
        let result = run_in_empty_ctx(&step).await.unwrap();
        assert_eq!(result.text().trim(), "42");
    }

    #[tokio::test]
    async fn script_returns_string_value() {
        let step = script_step("s", r#""hello from rhai""#);
        let result = run_in_empty_ctx(&step).await.unwrap();
        assert_eq!(result.text(), "hello from rhai");
    }

    #[tokio::test]
    async fn script_runtime_error_returns_step_error() {
        let step = script_step("s", "throw \"oops\";");
        let result = run_in_empty_ctx(&step).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Script error") || err.contains("oops"),
            "Got: {err}"
        );
    }

    #[tokio::test]
    async fn script_ctx_get_reads_step_output() {
        let mut ctx = Context::new(String::new(), HashMap::new());
        ctx.store(
            "prev",
            StepOutput::Cmd(CmdOutput {
                stdout: "hello_world".to_string(),
                stderr: String::new(),
                exit_code: 0,
                duration: Duration::ZERO,
            }),
        );

        let step = script_step("s", r#"let v = ctx_get("prev.stdout"); v"#);
        let config = StepConfig::default();

        let result = ScriptExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "hello_world");
    }

    #[tokio::test]
    async fn script_ctx_get_missing_key_yields_empty_output() {
        // Unknown keys resolve to unit, which renders as an empty string.
        let step = script_step("s", r#"ctx_get("no_such_step.stdout")"#);
        let result = run_in_empty_ctx(&step).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn script_ctx_set_accepts_structured_values() {
        let step = script_step(
            "s",
            r#"ctx_set("list", [1, 2.5, "x"]); ctx_set("obj", #{a: 1}); "ok""#,
        );
        let result = run_in_empty_ctx(&step).await.unwrap();
        assert_eq!(result.text(), "ok");
    }

    #[tokio::test]
    async fn script_missing_run_field_returns_error() {
        let step = StepDef {
            run: None,
            ..script_step("s", "")
        };
        let result = run_in_empty_ctx(&step).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing 'run'"));
    }
}