1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use async_trait::async_trait;
5use rhai::{Dynamic, Engine as RhaiEngine, EvalAltResult, Scope};
6
7use crate::config::StepConfig;
8use crate::engine::context::Context;
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{CmdOutput, StepExecutor, StepOutput};
13
/// Operation budget handed to Rhai's `set_max_operations` for every script run,
/// so a runaway script (e.g. an infinite loop) is cut off instead of spinning forever.
const MAX_OPERATIONS: u64 = 1000 * 1000;
16
/// Step executor that evaluates a step's `run` field as an embedded Rhai script.
///
/// The type carries no state; each call to `execute` builds its own Rhai engine.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScriptExecutor;
18
19#[async_trait]
20impl StepExecutor for ScriptExecutor {
21 async fn execute(
22 &self,
23 step: &StepDef,
24 _config: &StepConfig,
25 ctx: &Context,
26 ) -> Result<StepOutput, StepError> {
27 let script = step
28 .run
29 .as_ref()
30 .ok_or_else(|| StepError::Fail("script step missing 'run' field".into()))?
31 .clone();
32
33 let ctx_snapshot = build_ctx_snapshot(ctx);
35
36 let mut engine = RhaiEngine::new();
38 engine.set_max_operations(MAX_OPERATIONS);
39
40 let snapshot = ctx_snapshot.clone();
42 engine.register_fn("ctx_get", move |key: &str| -> Dynamic {
43 match snapshot.get(key) {
44 Some(v) => json_to_dynamic(v),
45 None => Dynamic::UNIT,
46 }
47 });
48
49 thread_local! {
51 static CTX_WRITES: RefCell<HashMap<String, serde_json::Value>> =
52 RefCell::new(HashMap::new());
53 }
54 CTX_WRITES.with(|w| w.borrow_mut().clear());
55
56 engine.register_fn("ctx_set", |key: &str, value: Dynamic| {
57 let json_val = dynamic_to_json(&value);
58 CTX_WRITES.with(|w| w.borrow_mut().insert(key.to_string(), json_val));
59 });
60
61 let mut scope = Scope::new();
63 let result = engine.eval_with_scope::<Dynamic>(&mut scope, &script);
64
65 let output_text = match result {
66 Ok(val) => dynamic_to_string(&val),
67 Err(e) => {
68 return Err(StepError::Fail(format_rhai_error(&e)));
69 }
70 };
71
72 Ok(StepOutput::Cmd(CmdOutput {
73 stdout: output_text,
74 stderr: String::new(),
75 exit_code: 0,
76 duration: std::time::Duration::ZERO,
77 }))
78 }
79}
80
81fn build_ctx_snapshot(ctx: &Context) -> HashMap<String, serde_json::Value> {
84 let tera_ctx = ctx.to_tera_context();
85 let mut flat: HashMap<String, serde_json::Value> = HashMap::new();
86
87 if let Some(steps_val) = tera_ctx.get("steps") {
89 if let serde_json::Value::Object(steps_map) = steps_val {
90 for (step_name, step_val) in steps_map {
91 if let serde_json::Value::Object(fields) = step_val {
92 for (field, field_val) in fields {
93 flat.insert(format!("{}.{}", step_name, field), field_val.clone());
94 }
95 }
96 }
97 }
98 }
99
100 if let Some(target) = tera_ctx.get("target") {
102 flat.insert("target".to_string(), target.clone());
103 }
104
105 flat
106}
107
108fn json_to_dynamic(val: &serde_json::Value) -> Dynamic {
110 match val {
111 serde_json::Value::Null => Dynamic::UNIT,
112 serde_json::Value::Bool(b) => Dynamic::from(*b),
113 serde_json::Value::Number(n) => {
114 if let Some(i) = n.as_i64() {
115 Dynamic::from(i)
116 } else if let Some(f) = n.as_f64() {
117 Dynamic::from(f)
118 } else {
119 Dynamic::UNIT
120 }
121 }
122 serde_json::Value::String(s) => Dynamic::from(s.clone()),
123 serde_json::Value::Array(arr) => {
124 let v: rhai::Array = arr.iter().map(json_to_dynamic).collect();
125 Dynamic::from(v)
126 }
127 serde_json::Value::Object(obj) => {
128 let mut map = rhai::Map::new();
129 for (k, v) in obj {
130 map.insert(k.clone().into(), json_to_dynamic(v));
131 }
132 Dynamic::from(map)
133 }
134 }
135}
136
137fn dynamic_to_json(val: &Dynamic) -> serde_json::Value {
139 if val.is_unit() {
140 serde_json::Value::Null
141 } else if let Some(b) = val.clone().try_cast::<bool>() {
142 serde_json::Value::Bool(b)
143 } else if let Some(i) = val.clone().try_cast::<i64>() {
144 serde_json::json!(i)
145 } else if let Some(f) = val.clone().try_cast::<f64>() {
146 serde_json::json!(f)
147 } else if let Some(s) = val.clone().try_cast::<String>() {
148 serde_json::Value::String(s)
149 } else {
150 serde_json::Value::String(val.to_string())
151 }
152}
153
154fn dynamic_to_string(val: &Dynamic) -> String {
156 if val.is_unit() {
157 String::new()
158 } else if let Some(s) = val.clone().try_cast::<String>() {
159 s
160 } else {
161 val.to_string()
162 }
163}
164
165fn format_rhai_error(e: &EvalAltResult) -> String {
167 format!("Script error: {e}")
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::StepConfig;
    use crate::workflow::schema::StepType;

    /// Builds a minimal `Script` step named `name` whose `run` field is `run`.
    fn script_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Script,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Executes `script` as a step against an empty context with default config.
    async fn run_script(script: &str) -> Result<StepOutput, StepError> {
        let step = script_step("s", script);
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        ScriptExecutor.execute(&step, &config, &ctx).await
    }

    #[tokio::test]
    async fn script_returns_integer_expression() {
        let result = run_script("40 + 2").await.unwrap();
        assert_eq!(result.text().trim(), "42");
    }

    #[tokio::test]
    async fn script_returns_string_value() {
        let result = run_script(r#""hello from rhai""#).await.unwrap();
        assert_eq!(result.text(), "hello from rhai");
    }

    #[tokio::test]
    async fn script_runtime_error_returns_step_error() {
        let result = run_script("throw \"oops\";").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Script error") || err.contains("oops"),
            "Got: {err}"
        );
    }

    #[tokio::test]
    async fn script_exceeding_operation_limit_returns_error() {
        // An unbounded loop must be stopped by MAX_OPERATIONS rather than hang.
        let result = run_script("loop {}").await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Script error"), "Got: {err}");
    }

    #[tokio::test]
    async fn script_ctx_get_missing_key_returns_empty_output() {
        // `ctx_get` yields `()` for unknown keys, which renders as empty stdout.
        let result = run_script(r#"ctx_get("missing.key")"#).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn script_ctx_get_reads_step_output() {
        use crate::steps::{CmdOutput, StepOutput};
        use std::time::Duration;

        let mut ctx = Context::new(String::new(), HashMap::new());
        ctx.store(
            "prev",
            StepOutput::Cmd(CmdOutput {
                stdout: "hello_world".to_string(),
                stderr: String::new(),
                exit_code: 0,
                duration: Duration::ZERO,
            }),
        );

        let step = script_step("s", r#"let v = ctx_get("prev.stdout"); v"#);
        let config = StepConfig::default();

        let result = ScriptExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "hello_world");
    }

    #[tokio::test]
    async fn script_missing_run_field_returns_error() {
        let step = StepDef {
            run: None,
            ..script_step("s", "")
        };
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ScriptExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing 'run'"));
    }
}
285}