1use anyhow::{Result, anyhow};
6use rhai::{AST, Dynamic, Engine, Map, Scope};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tracing::{debug, error, info, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ScriptSecurityConfig {
21 pub max_execution_time_ms: u64,
23 pub max_call_stack_depth: usize,
25 pub max_operations: u64,
27 pub max_array_size: usize,
29 pub max_string_size: usize,
31 pub allow_loops: bool,
33 pub allow_file_operations: bool,
35 pub allow_network_operations: bool,
37}
38
39impl Default for ScriptSecurityConfig {
40 fn default() -> Self {
41 Self {
42 max_execution_time_ms: 5000,
43 max_call_stack_depth: 64,
44 max_operations: 100_000,
45 max_array_size: 10_000,
46 max_string_size: 1_000_000,
47 allow_loops: true,
48 allow_file_operations: false,
49 allow_network_operations: false,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct ScriptEngineConfig {
57 pub security: ScriptSecurityConfig,
59 pub script_dirs: Vec<String>,
61 pub debug_mode: bool,
63 pub strict_mode: bool,
65 pub preload_modules: Vec<String>,
67}
68
69#[derive(Debug, Clone, Default)]
75pub struct ScriptContext {
76 pub variables: HashMap<String, serde_json::Value>,
78 pub agent_id: Option<String>,
80 pub workflow_id: Option<String>,
82 pub node_id: Option<String>,
84 pub execution_id: Option<String>,
86 pub metadata: HashMap<String, String>,
88}
89
90impl ScriptContext {
91 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn with_agent(mut self, agent_id: &str) -> Self {
96 self.agent_id = Some(agent_id.to_string());
97 self
98 }
99
100 pub fn with_workflow(mut self, workflow_id: &str) -> Self {
101 self.workflow_id = Some(workflow_id.to_string());
102 self
103 }
104
105 pub fn with_node(mut self, node_id: &str) -> Self {
106 self.node_id = Some(node_id.to_string());
107 self
108 }
109
110 pub fn with_variable<T: Serialize>(mut self, key: &str, value: T) -> Result<Self> {
111 let json_value = serde_json::to_value(value)?;
112 self.variables.insert(key.to_string(), json_value);
113 Ok(self)
114 }
115
116 pub fn set_variable<T: Serialize>(&mut self, key: &str, value: T) -> Result<()> {
117 let json_value = serde_json::to_value(value)?;
118 self.variables.insert(key.to_string(), json_value);
119 Ok(())
120 }
121
122 pub fn get_variable<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
123 self.variables
124 .get(key)
125 .and_then(|v| serde_json::from_value(v.clone()).ok())
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct ScriptResult {
136 pub success: bool,
138 pub value: serde_json::Value,
140 pub error: Option<String>,
142 pub execution_time_ms: u64,
144 pub operations_count: u64,
146 pub logs: Vec<String>,
148}
149
150impl ScriptResult {
151 pub fn success(value: serde_json::Value, execution_time_ms: u64) -> Self {
152 Self {
153 success: true,
154 value,
155 error: None,
156 execution_time_ms,
157 operations_count: 0,
158 logs: Vec::new(),
159 }
160 }
161
162 pub fn failure(error: String) -> Self {
163 Self {
164 success: false,
165 value: serde_json::Value::Null,
166 error: Some(error),
167 execution_time_ms: 0,
168 operations_count: 0,
169 logs: Vec::new(),
170 }
171 }
172
173 pub fn into_typed<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
175 if !self.success {
176 return Err(anyhow!(
177 self.error.unwrap_or_else(|| "Unknown error".into())
178 ));
179 }
180 serde_json::from_value(self.value).map_err(|e| anyhow!("Failed to deserialize: {}", e))
181 }
182
183 pub fn as_bool(&self) -> Option<bool> {
185 self.value.as_bool()
186 }
187
188 pub fn as_str(&self) -> Option<&str> {
190 self.value.as_str()
191 }
192
193 pub fn as_i64(&self) -> Option<i64> {
195 self.value.as_i64()
196 }
197
198 pub fn as_f64(&self) -> Option<f64> {
200 self.value.as_f64()
201 }
202}
203
204pub struct CompiledScript {
210 pub id: String,
212 pub name: String,
214 ast: AST,
216 source: String,
218 pub compiled_at: u64,
220}
221
222impl CompiledScript {
223 pub fn new(id: &str, name: &str, ast: AST, source: String) -> Self {
224 Self {
225 id: id.to_string(),
226 name: name.to_string(),
227 ast,
228 source,
229 compiled_at: std::time::SystemTime::now()
230 .duration_since(std::time::UNIX_EPOCH)
231 .unwrap_or_default()
232 .as_secs(),
233 }
234 }
235
236 pub fn source(&self) -> &str {
237 &self.source
238 }
239}
240
241pub struct RhaiScriptEngine {
247 engine: Engine,
249 #[allow(dead_code)]
251 config: ScriptEngineConfig,
252 script_cache: Arc<RwLock<HashMap<String, CompiledScript>>>,
254 global_scope: Scope<'static>,
256 logs: Arc<RwLock<Vec<String>>>,
258}
259
260impl RhaiScriptEngine {
261 pub fn new(config: ScriptEngineConfig) -> Result<Self> {
263 let mut engine = Engine::new();
264
265 Self::apply_security_limits(&mut engine, &config.security);
267
268 let logs = Arc::new(RwLock::new(Vec::new()));
270 Self::register_builtin_functions(&mut engine, logs.clone());
271
272 let global_scope = Scope::new();
274
275 Ok(Self {
276 engine,
277 config,
278 script_cache: Arc::new(RwLock::new(HashMap::new())),
279 global_scope,
280 logs,
281 })
282 }
283
284 fn apply_security_limits(engine: &mut Engine, security: &ScriptSecurityConfig) {
286 engine.set_max_call_levels(security.max_call_stack_depth);
287 engine.set_max_operations(security.max_operations);
288 engine.set_max_array_size(security.max_array_size);
289 engine.set_max_string_size(security.max_string_size);
290
291 if !security.allow_loops {
292 engine.set_allow_looping(false);
293 }
294
295 engine.set_strict_variables(false);
297 }
298
299 fn register_builtin_functions(engine: &mut Engine, logs: Arc<RwLock<Vec<String>>>) {
301 let logs_clone = logs.clone();
303 engine.register_fn("log", move |msg: &str| {
304 if let Ok(mut l) = logs_clone.try_write() {
305 l.push(format!("[LOG] {}", msg));
306 }
307 });
308
309 let logs_clone = logs.clone();
310 engine.register_fn("debug", move |msg: &str| {
311 if let Ok(mut l) = logs_clone.try_write() {
312 l.push(format!("[DEBUG] {}", msg));
313 }
314 debug!("Script debug: {}", msg);
315 });
316
317 let logs_clone = logs.clone();
319 engine.register_fn("print", move |msg: &str| {
320 if let Ok(mut l) = logs_clone.try_write() {
321 l.push(format!("[PRINT] {}", msg));
322 }
323 debug!("Script print: {}", msg);
324 });
325
326 let logs_clone = logs.clone();
327 engine.register_fn("warn", move |msg: &str| {
328 if let Ok(mut l) = logs_clone.try_write() {
329 l.push(format!("[WARN] {}", msg));
330 }
331 warn!("Script warn: {}", msg);
332 });
333
334 let logs_clone = logs.clone();
335 engine.register_fn("error", move |msg: &str| {
336 if let Ok(mut l) = logs_clone.try_write() {
337 l.push(format!("[ERROR] {}", msg));
338 }
339 error!("Script error: {}", msg);
340 });
341
342 engine.register_fn("to_json", |value: Dynamic| -> String {
344 serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string())
345 });
346
347 engine.register_fn("from_json", |json: &str| -> Dynamic {
348 serde_json::from_str::<serde_json::Value>(json)
349 .map(|v| json_to_dynamic(&v))
350 .unwrap_or(Dynamic::UNIT)
351 });
352
353 engine.register_fn("trim", |s: &str| -> String { s.trim().to_string() });
355
356 engine.register_fn("upper", |s: &str| -> String { s.to_uppercase() });
357
358 engine.register_fn("lower", |s: &str| -> String { s.to_lowercase() });
359
360 engine.register_fn("contains", |s: &str, pattern: &str| -> bool {
361 s.contains(pattern)
362 });
363
364 engine.register_fn("starts_with", |s: &str, pattern: &str| -> bool {
365 s.starts_with(pattern)
366 });
367
368 engine.register_fn("ends_with", |s: &str, pattern: &str| -> bool {
369 s.ends_with(pattern)
370 });
371
372 engine.register_fn("replace", |s: &str, from: &str, to: &str| -> String {
373 s.replace(from, to)
374 });
375
376 engine.register_fn("split", |s: &str, delimiter: &str| -> Vec<Dynamic> {
377 s.split(delimiter)
378 .map(|part| Dynamic::from(part.to_string()))
379 .collect()
380 });
381
382 engine.register_fn("abs", |x: i64| -> i64 { x.abs() });
384 engine.register_fn("abs_f", |x: f64| -> f64 { x.abs() });
385 engine.register_fn("min", |a: i64, b: i64| -> i64 { a.min(b) });
386 engine.register_fn("max", |a: i64, b: i64| -> i64 { a.max(b) });
387 engine.register_fn("clamp", |value: i64, min: i64, max: i64| -> i64 {
388 value.clamp(min, max)
389 });
390
391 engine.register_fn("now", || -> i64 {
393 std::time::SystemTime::now()
394 .duration_since(std::time::UNIX_EPOCH)
395 .unwrap_or_default()
396 .as_secs() as i64
397 });
398
399 engine.register_fn("now_ms", || -> i64 {
400 std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .unwrap_or_default()
403 .as_millis() as i64
404 });
405
406 engine.register_fn("uuid", || -> String { uuid::Uuid::now_v7().to_string() });
408
409 engine.register_fn("is_null", |v: Dynamic| -> bool { v.is_unit() });
411 engine.register_fn("is_string", |v: Dynamic| -> bool { v.is_string() });
412 engine.register_fn("is_int", |v: Dynamic| -> bool { v.is_int() });
413 engine.register_fn("is_float", |v: Dynamic| -> bool { v.is_float() });
414 engine.register_fn("is_bool", |v: Dynamic| -> bool { v.is_bool() });
415 engine.register_fn("is_array", |v: Dynamic| -> bool { v.is_array() });
416 engine.register_fn("is_map", |v: Dynamic| -> bool { v.is_map() });
417
418 engine.register_fn("to_string", |v: i64| -> String { v.to_string() });
420 engine.register_fn("to_string", |v: f64| -> String { v.to_string() });
421 engine.register_fn("to_string", |v: bool| -> String { v.to_string() });
422 engine.register_fn("to_string", |v: &str| -> String { v.to_string() });
423 }
424
425 pub fn compile(&self, id: &str, name: &str, source: &str) -> Result<CompiledScript> {
427 let ast = self
428 .engine
429 .compile(source)
430 .map_err(|e| anyhow!("Compile error: {}", e))?;
431
432 Ok(CompiledScript::new(id, name, ast, source.to_string()))
433 }
434
435 pub async fn compile_and_cache(&self, id: &str, name: &str, source: &str) -> Result<()> {
437 let compiled = self.compile(id, name, source)?;
438 let mut cache = self.script_cache.write().await;
439 cache.insert(id.to_string(), compiled);
440 info!("Script compiled and cached: {} ({})", name, id);
441 Ok(())
442 }
443
444 pub async fn load_from_file(&self, path: &Path) -> Result<String> {
446 let source = tokio::fs::read_to_string(path).await?;
447 let id = path
448 .file_stem()
449 .and_then(|s| s.to_str())
450 .unwrap_or("unnamed");
451 let name = path
452 .file_name()
453 .and_then(|s| s.to_str())
454 .unwrap_or("unnamed");
455
456 self.compile_and_cache(id, name, &source).await?;
457 Ok(id.to_string())
458 }
459
460 pub async fn execute(&self, source: &str, context: &ScriptContext) -> Result<ScriptResult> {
462 let start_time = std::time::Instant::now();
463
464 {
466 let mut logs = self.logs.write().await;
467 logs.clear();
468 }
469
470 let mut scope = self.global_scope.clone();
472 self.prepare_scope(&mut scope, context);
473
474 let result = self.engine.eval_with_scope::<Dynamic>(&mut scope, source);
476
477 let execution_time_ms = start_time.elapsed().as_millis() as u64;
478 let logs = self.logs.read().await.clone();
479
480 match result {
481 Ok(value) => {
482 let json_value = dynamic_to_json(&value);
483 Ok(ScriptResult {
484 success: true,
485 value: json_value,
486 error: None,
487 execution_time_ms,
488 operations_count: 0,
489 logs,
490 })
491 }
492 Err(e) => Ok(ScriptResult {
493 success: false,
494 value: serde_json::Value::Null,
495 error: Some(format!("{}", e)),
496 execution_time_ms,
497 operations_count: 0,
498 logs,
499 }),
500 }
501 }
502
503 pub async fn execute_compiled(
505 &self,
506 script_id: &str,
507 context: &ScriptContext,
508 ) -> Result<ScriptResult> {
509 let cache = self.script_cache.read().await;
510 let compiled = cache
511 .get(script_id)
512 .ok_or_else(|| anyhow!("Script not found: {}", script_id))?;
513
514 let start_time = std::time::Instant::now();
515
516 {
518 let mut logs = self.logs.write().await;
519 logs.clear();
520 }
521
522 let mut scope = self.global_scope.clone();
524 self.prepare_scope(&mut scope, context);
525
526 let result = self
528 .engine
529 .eval_ast_with_scope::<Dynamic>(&mut scope, &compiled.ast);
530
531 let execution_time_ms = start_time.elapsed().as_millis() as u64;
532 let logs = self.logs.read().await.clone();
533
534 match result {
535 Ok(value) => {
536 let json_value = dynamic_to_json(&value);
537 Ok(ScriptResult {
538 success: true,
539 value: json_value,
540 error: None,
541 execution_time_ms,
542 operations_count: 0,
543 logs,
544 })
545 }
546 Err(e) => Ok(ScriptResult {
547 success: false,
548 value: serde_json::Value::Null,
549 error: Some(format!("{}", e)),
550 execution_time_ms,
551 operations_count: 0,
552 logs,
553 }),
554 }
555 }
556
557 pub async fn call_function<T: for<'de> Deserialize<'de>>(
559 &self,
560 script_id: &str,
561 function_name: &str,
562 args: Vec<serde_json::Value>,
563 context: &ScriptContext,
564 ) -> Result<T> {
565 let cache = self.script_cache.read().await;
566 let compiled = cache
567 .get(script_id)
568 .ok_or_else(|| anyhow!("Script not found: {}", script_id))?;
569
570 let mut scope = self.global_scope.clone();
572 self.prepare_scope(&mut scope, context);
573
574 let dynamic_args: Vec<Dynamic> = args.iter().map(json_to_dynamic).collect();
576
577 let result: Dynamic = self
579 .engine
580 .call_fn(&mut scope, &compiled.ast, function_name, dynamic_args)
581 .map_err(|e| anyhow!("Function call error: {}", e))?;
582
583 let json_value = dynamic_to_json(&result);
585 serde_json::from_value(json_value).map_err(|e| anyhow!("Result conversion error: {}", e))
586 }
587
588 fn prepare_scope(&self, scope: &mut Scope, context: &ScriptContext) {
590 if let Some(ref agent_id) = context.agent_id {
592 scope.push_constant("AGENT_ID", agent_id.clone());
593 }
594 if let Some(ref workflow_id) = context.workflow_id {
595 scope.push_constant("WORKFLOW_ID", workflow_id.clone());
596 }
597 if let Some(ref node_id) = context.node_id {
598 scope.push_constant("NODE_ID", node_id.clone());
599 }
600 if let Some(ref execution_id) = context.execution_id {
601 scope.push_constant("EXECUTION_ID", execution_id.clone());
602 }
603
604 for (key, value) in &context.variables {
606 let dynamic_value = json_to_dynamic(value);
607 scope.push(key.clone(), dynamic_value);
608 }
609
610 let mut metadata_map = Map::new();
612 for (k, v) in &context.metadata {
613 metadata_map.insert(k.clone().into(), Dynamic::from(v.clone()));
614 }
615 scope.push_constant("metadata", metadata_map);
616 }
617
618 pub fn validate(&self, source: &str) -> Result<Vec<String>> {
620 match self.engine.compile(source) {
621 Ok(_) => Ok(Vec::new()),
622 Err(e) => {
623 let errors = vec![format!("{}", e)];
624 Ok(errors)
625 }
626 }
627 }
628
629 pub async fn cached_scripts(&self) -> Vec<String> {
631 let cache = self.script_cache.read().await;
632 cache.keys().cloned().collect()
633 }
634
635 pub async fn remove_cached(&self, script_id: &str) -> bool {
637 let mut cache = self.script_cache.write().await;
638 cache.remove(script_id).is_some()
639 }
640
641 pub async fn clear_cache(&self) {
643 let mut cache = self.script_cache.write().await;
644 cache.clear();
645 }
646
647 pub fn engine(&self) -> &Engine {
649 &self.engine
650 }
651
652 pub fn engine_mut(&mut self) -> &mut Engine {
654 &mut self.engine
655 }
656}
657
658pub fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
664 match value {
665 serde_json::Value::Null => Dynamic::UNIT,
666 serde_json::Value::Bool(b) => Dynamic::from(*b),
667 serde_json::Value::Number(n) => {
668 if let Some(i) = n.as_i64() {
669 Dynamic::from(i)
670 } else if let Some(f) = n.as_f64() {
671 Dynamic::from(f)
672 } else {
673 Dynamic::UNIT
674 }
675 }
676 serde_json::Value::String(s) => Dynamic::from(s.clone()),
677 serde_json::Value::Array(arr) => {
678 let vec: Vec<Dynamic> = arr.iter().map(json_to_dynamic).collect();
679 Dynamic::from(vec)
680 }
681 serde_json::Value::Object(obj) => {
682 let mut map = Map::new();
683 for (k, v) in obj {
684 map.insert(k.clone().into(), json_to_dynamic(v));
685 }
686 Dynamic::from(map)
687 }
688 }
689}
690
691pub fn dynamic_to_json(value: &Dynamic) -> serde_json::Value {
693 if value.is_unit() {
694 serde_json::Value::Null
695 } else if let Some(b) = value.clone().try_cast::<bool>() {
696 serde_json::Value::Bool(b)
697 } else if let Some(i) = value.clone().try_cast::<i64>() {
698 serde_json::json!(i)
699 } else if let Some(f) = value.clone().try_cast::<f64>() {
700 serde_json::json!(f)
701 } else if let Some(s) = value.clone().try_cast::<String>() {
702 serde_json::Value::String(s)
703 } else if value.is_array() {
704 let arr = value.clone().cast::<rhai::Array>();
705 let json_arr: Vec<serde_json::Value> = arr.iter().map(dynamic_to_json).collect();
706 serde_json::Value::Array(json_arr)
707 } else if value.is_map() {
708 let map = value.clone().cast::<Map>();
709 let mut json_obj = serde_json::Map::new();
710 for (k, v) in map.iter() {
711 json_obj.insert(k.to_string(), dynamic_to_json(v));
712 }
713 serde_json::Value::Object(json_obj)
714 } else {
715 serde_json::Value::String(value.to_string())
717 }
718}
719
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an engine with the default configuration.
    fn default_engine() -> RhaiScriptEngine {
        RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap()
    }

    /// A bare expression evaluates to its value.
    #[tokio::test]
    async fn test_basic_script_execution() {
        let engine = default_engine();

        let outcome = engine
            .execute("1 + 2", &ScriptContext::new())
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.value, json!(3));
    }

    /// Context variables are visible to the script by name.
    #[tokio::test]
    async fn test_script_with_variables() {
        let engine = default_engine();

        let mut context = ScriptContext::new();
        context.set_variable("x", 10).unwrap();
        context.set_variable("y", 20).unwrap();

        let outcome = engine.execute("x + y", &context).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.value, json!(30));
    }

    /// Functions defined in the script can be called from its top level.
    #[tokio::test]
    async fn test_script_with_function() {
        let engine = default_engine();

        let script = r#"
            fn double(n) {
                n * 2
            }
            double(21)
        "#;

        let outcome = engine
            .execute(script, &ScriptContext::new())
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.value, json!(42));
    }

    /// A cached script receives an object variable and returns a map.
    #[tokio::test]
    async fn test_compiled_script() {
        let engine = default_engine();

        let source = r#"
            fn process(input) {
                let result = #{};
                result.doubled = input.value * 2;
                result.message = "processed: " + input.name;
                result
            }
            process(input)
        "#;
        engine
            .compile_and_cache("test_script", "Test Script", source)
            .await
            .unwrap();

        let input = json!({
            "name": "test",
            "value": 21
        });
        let context = ScriptContext::new().with_variable("input", input).unwrap();

        let outcome = engine
            .execute_compiled("test_script", &context)
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.value["doubled"], 42);
        assert_eq!(outcome.value["message"], "processed: test");
    }

    /// Spot-checks a string helper, the JSON encoder and the clock.
    #[tokio::test]
    async fn test_builtin_functions() {
        let engine = default_engine();
        let context = ScriptContext::new();

        let upper = engine.execute(r#"upper("hello")"#, &context).await.unwrap();
        assert_eq!(upper.value, "HELLO");

        let encoded = engine
            .execute(r#"to_json(#{name: "test", value: 42})"#, &context)
            .await
            .unwrap();
        assert!(encoded.value.as_str().is_some());

        let clock = engine.execute("now()", &context).await.unwrap();
        assert!(clock.value.as_i64().is_some());
    }

    /// JSON survives a round trip through the Rhai representation.
    #[test]
    fn test_json_conversion() {
        let original = json!({
            "name": "test",
            "values": [1, 2, 3],
            "nested": {
                "flag": true
            }
        });

        let round_tripped = dynamic_to_json(&json_to_dynamic(&original));

        assert_eq!(original, round_tripped);
    }
}