1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use async_trait::async_trait;
5use rhai::{Dynamic, Engine as RhaiEngine, EvalAltResult, Scope};
6
7use crate::config::StepConfig;
8use crate::engine::context::Context;
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{CmdOutput, StepExecutor, StepOutput};
13
/// Upper bound on the number of Rhai operations a single script may perform.
///
/// The engine aborts evaluation with an error once this limit is reached, so a runaway or
/// infinitely looping script fails the step instead of hanging the workflow.
const MAX_OPERATIONS: u64 = 1_000_000;

/// Step executor for `script` steps: evaluates the step's `run` field as a Rhai script.
///
/// The executor is stateless; a fresh Rhai engine is built for every execution.
#[derive(Debug, Default, Clone, Copy)]
pub struct ScriptExecutor;
18
19#[async_trait]
20impl StepExecutor for ScriptExecutor {
21 async fn execute(
22 &self,
23 step: &StepDef,
24 _config: &StepConfig,
25 ctx: &Context,
26 ) -> Result<StepOutput, StepError> {
27 let script = step
28 .run
29 .as_ref()
30 .ok_or_else(|| StepError::Fail("script step missing 'run' field".into()))?
31 .clone();
32
33 let ctx_snapshot = build_ctx_snapshot(ctx);
35
36 let mut engine = RhaiEngine::new();
38 engine.set_max_operations(MAX_OPERATIONS);
39
40 let snapshot = ctx_snapshot.clone();
42 engine.register_fn("ctx_get", move |key: &str| -> Dynamic {
43 match snapshot.get(key) {
44 Some(v) => json_to_dynamic(v),
45 None => Dynamic::UNIT,
46 }
47 });
48
49 thread_local! {
51 static CTX_WRITES: RefCell<HashMap<String, serde_json::Value>> =
52 RefCell::new(HashMap::new());
53 }
54 CTX_WRITES.with(|w| w.borrow_mut().clear());
55
56 engine.register_fn("ctx_set", |key: &str, value: Dynamic| {
57 let json_val = dynamic_to_json(&value);
58 CTX_WRITES.with(|w| w.borrow_mut().insert(key.to_string(), json_val));
59 });
60
61 let mut scope = Scope::new();
63 let result = engine.eval_with_scope::<Dynamic>(&mut scope, &script);
64
65 let output_text = match result {
66 Ok(val) => dynamic_to_string(&val),
67 Err(e) => {
68 return Err(StepError::Fail(format_rhai_error(&e)));
69 }
70 };
71
72 Ok(StepOutput::Cmd(CmdOutput {
73 stdout: output_text,
74 stderr: String::new(),
75 exit_code: 0,
76 duration: std::time::Duration::ZERO,
77 }))
78 }
79}
80
81fn build_ctx_snapshot(ctx: &Context) -> HashMap<String, serde_json::Value> {
84 let tera_ctx = ctx.to_tera_context();
85 let mut flat: HashMap<String, serde_json::Value> = HashMap::new();
86
87 if let Some(serde_json::Value::Object(steps_map)) = tera_ctx.get("steps") {
89 for (step_name, step_val) in steps_map {
90 if let serde_json::Value::Object(fields) = step_val {
91 for (field, field_val) in fields {
92 flat.insert(format!("{}.{}", step_name, field), field_val.clone());
93 }
94 }
95 }
96 }
97
98 if let Some(target) = tera_ctx.get("target") {
100 flat.insert("target".to_string(), target.clone());
101 }
102
103 flat
104}
105
106fn json_to_dynamic(val: &serde_json::Value) -> Dynamic {
108 match val {
109 serde_json::Value::Null => Dynamic::UNIT,
110 serde_json::Value::Bool(b) => Dynamic::from(*b),
111 serde_json::Value::Number(n) => {
112 if let Some(i) = n.as_i64() {
113 Dynamic::from(i)
114 } else if let Some(f) = n.as_f64() {
115 Dynamic::from(f)
116 } else {
117 Dynamic::UNIT
118 }
119 }
120 serde_json::Value::String(s) => Dynamic::from(s.clone()),
121 serde_json::Value::Array(arr) => {
122 let v: rhai::Array = arr.iter().map(json_to_dynamic).collect();
123 Dynamic::from(v)
124 }
125 serde_json::Value::Object(obj) => {
126 let mut map = rhai::Map::new();
127 for (k, v) in obj {
128 map.insert(k.clone().into(), json_to_dynamic(v));
129 }
130 Dynamic::from(map)
131 }
132 }
133}
134
135fn dynamic_to_json(val: &Dynamic) -> serde_json::Value {
137 if val.is_unit() {
138 serde_json::Value::Null
139 } else if let Some(b) = val.clone().try_cast::<bool>() {
140 serde_json::Value::Bool(b)
141 } else if let Some(i) = val.clone().try_cast::<i64>() {
142 serde_json::json!(i)
143 } else if let Some(f) = val.clone().try_cast::<f64>() {
144 serde_json::json!(f)
145 } else if let Some(s) = val.clone().try_cast::<String>() {
146 serde_json::Value::String(s)
147 } else {
148 serde_json::Value::String(val.to_string())
149 }
150}
151
152fn dynamic_to_string(val: &Dynamic) -> String {
154 if val.is_unit() {
155 String::new()
156 } else if let Some(s) = val.clone().try_cast::<String>() {
157 s
158 } else {
159 val.to_string()
160 }
161}
162
163fn format_rhai_error(e: &EvalAltResult) -> String {
165 format!("Script error: {e}")
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::StepConfig;
    use crate::workflow::schema::StepType;

    /// Builds a script step named `name` whose `run` field holds `run`; every other
    /// optional field is left unset.
    fn script_step(name: &str, run: &str) -> StepDef {
        StepDef {
            name: name.to_string(),
            step_type: StepType::Script,
            run: Some(run.to_string()),
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Creates a context with an empty first argument and no extra entries.
    fn empty_ctx() -> Context {
        Context::new(String::new(), HashMap::new())
    }

    /// Executes `script` as a script step against `ctx` using the default step config.
    async fn run_script(script: &str, ctx: &Context) -> Result<StepOutput, StepError> {
        let step = script_step("s", script);
        ScriptExecutor
            .execute(&step, &StepConfig::default(), ctx)
            .await
    }

    #[tokio::test]
    async fn script_returns_integer_expression() {
        let result = run_script("40 + 2", &empty_ctx()).await.unwrap();
        assert_eq!(result.text().trim(), "42");
    }

    #[tokio::test]
    async fn script_returns_string_value() {
        let result = run_script(r#""hello from rhai""#, &empty_ctx())
            .await
            .unwrap();
        assert_eq!(result.text(), "hello from rhai");
    }

    #[tokio::test]
    async fn script_runtime_error_returns_step_error() {
        let result = run_script("throw \"oops\";", &empty_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Script error") || err.contains("oops"),
            "Got: {err}"
        );
    }

    #[tokio::test]
    async fn script_ctx_get_reads_step_output() {
        use std::time::Duration;

        let mut ctx = empty_ctx();
        ctx.store(
            "prev",
            StepOutput::Cmd(CmdOutput {
                stdout: "hello_world".to_string(),
                stderr: String::new(),
                exit_code: 0,
                duration: Duration::ZERO,
            }),
        );

        let result = run_script(r#"let v = ctx_get("prev.stdout"); v"#, &ctx)
            .await
            .unwrap();
        assert_eq!(result.text(), "hello_world");
    }

    #[tokio::test]
    async fn script_ctx_get_missing_key_returns_empty_output() {
        // Unknown keys evaluate to `()`, which renders as empty output.
        let result = run_script(r#"ctx_get("nope.stdout")"#, &empty_ctx())
            .await
            .unwrap();
        assert!(result.text().trim().is_empty());
    }

    #[tokio::test]
    async fn script_ctx_set_is_accepted() {
        let result = run_script(r#"ctx_set("k", [1, 2]); "done""#, &empty_ctx())
            .await
            .unwrap();
        assert_eq!(result.text(), "done");
    }

    #[tokio::test]
    async fn script_exceeding_operation_limit_returns_error() {
        // An infinite loop must be stopped by MAX_OPERATIONS rather than hang the test.
        let result = run_script("loop {}", &empty_ctx()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn script_missing_run_field_returns_error() {
        let step = StepDef {
            run: None,
            ..script_step("s", "")
        };
        let result = ScriptExecutor
            .execute(&step, &StepConfig::default(), &empty_ctx())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing 'run'"));
    }
}