use anyhow::{Result, anyhow};
use rhai::{AST, Dynamic, Engine, Map, Scope};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
/// Resource limits and capability switches for script execution.
///
/// NOTE(review): in this file only the call-depth, operation-count,
/// array-size, string-size and loop settings are applied to the engine
/// (see `RhaiScriptEngine::apply_security_limits`). The remaining fields are
/// stored but never read here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptSecurityConfig {
/// Intended execution-time budget in milliseconds (not enforced in this file).
pub max_execution_time_ms: u64,
/// Passed to `Engine::set_max_call_levels`.
pub max_call_stack_depth: usize,
/// Passed to `Engine::set_max_operations`.
pub max_operations: u64,
/// Passed to `Engine::set_max_array_size`.
pub max_array_size: usize,
/// Passed to `Engine::set_max_string_size`.
pub max_string_size: usize,
/// When `false`, looping is disabled on the engine.
pub allow_loops: bool,
/// Capability flag for file access (not consulted in this file).
pub allow_file_operations: bool,
/// Capability flag for network access (not consulted in this file).
pub allow_network_operations: bool,
}
impl Default for ScriptSecurityConfig {
    /// Returns the default sandbox: loops allowed, file and network access
    /// switched off, and fixed numeric limits.
    fn default() -> Self {
        Self {
            // Capability switches.
            allow_loops: true,
            allow_file_operations: false,
            allow_network_operations: false,
            // Numeric limits.
            max_execution_time_ms: 5_000,
            max_operations: 100_000,
            max_call_stack_depth: 64,
            max_array_size: 10_000,
            max_string_size: 1_000_000,
        }
    }
}
/// Top-level configuration handed to `RhaiScriptEngine::new`.
///
/// NOTE(review): within this file only `security` is read; the other fields
/// are stored on the engine but not acted upon here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScriptEngineConfig {
/// Limits applied to the underlying Rhai engine at construction time.
pub security: ScriptSecurityConfig,
/// Script directories (not read in this file).
pub script_dirs: Vec<String>,
/// Debug-mode flag (not read in this file).
pub debug_mode: bool,
/// Strict-mode flag (not read in this file).
pub strict_mode: bool,
/// Modules to preload (not read in this file).
pub preload_modules: Vec<String>,
}
/// Per-execution inputs that are injected into the script scope.
#[derive(Debug, Clone, Default)]
pub struct ScriptContext {
/// Named JSON values; each entry is pushed into the scope as a variable.
pub variables: HashMap<String, serde_json::Value>,
/// Exposed to scripts as the constant `AGENT_ID` when set.
pub agent_id: Option<String>,
/// Exposed to scripts as the constant `WORKFLOW_ID` when set.
pub workflow_id: Option<String>,
/// Exposed to scripts as the constant `NODE_ID` when set.
pub node_id: Option<String>,
/// Exposed to scripts as the constant `EXECUTION_ID` when set.
pub execution_id: Option<String>,
/// String key/value pairs exposed to scripts as the constant map `metadata`.
pub metadata: HashMap<String, String>,
}
impl ScriptContext {
    /// Creates an empty context with no IDs, variables or metadata.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets the agent ID.
    pub fn with_agent(mut self, agent_id: &str) -> Self {
        self.agent_id = Some(agent_id.to_string());
        self
    }

    /// Builder: sets the workflow ID.
    pub fn with_workflow(mut self, workflow_id: &str) -> Self {
        self.workflow_id = Some(workflow_id.to_string());
        self
    }

    /// Builder: sets the node ID.
    pub fn with_node(mut self, node_id: &str) -> Self {
        self.node_id = Some(node_id.to_string());
        self
    }

    /// Builder: sets the execution ID.
    ///
    /// Completes the set of ID builders; `execution_id` was previously the
    /// only ID field that had to be assigned directly.
    pub fn with_execution(mut self, execution_id: &str) -> Self {
        self.execution_id = Some(execution_id.to_string());
        self
    }

    /// Builder form of [`ScriptContext::set_variable`].
    ///
    /// # Errors
    /// Returns an error when `value` cannot be serialized to JSON.
    pub fn with_variable<T: Serialize>(mut self, key: &str, value: T) -> Result<Self> {
        self.set_variable(key, value)?;
        Ok(self)
    }

    /// Serializes `value` to JSON and stores it under `key`, replacing any
    /// previous entry with the same key.
    ///
    /// # Errors
    /// Returns an error when `value` cannot be serialized to JSON.
    pub fn set_variable<T: Serialize>(&mut self, key: &str, value: T) -> Result<()> {
        let json_value = serde_json::to_value(value)?;
        self.variables.insert(key.to_string(), json_value);
        Ok(())
    }

    /// Looks up `key` and deserializes the stored JSON into `T`.
    ///
    /// Returns `None` both when the key is absent and when the stored value
    /// does not deserialize into `T`.
    pub fn get_variable<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
        // `&serde_json::Value` is itself a `Deserializer`, so the stored value
        // can be read in place instead of being deep-cloned first.
        self.variables
            .get(key)
            .and_then(|stored| T::deserialize(stored).ok())
    }
}
/// Outcome of one script evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptResult {
/// `true` when the script evaluated without error.
pub success: bool,
/// The script's return value as JSON; `Null` on failure.
pub value: serde_json::Value,
/// Error text when `success` is `false`.
pub error: Option<String>,
/// Elapsed time of the evaluation in milliseconds.
pub execution_time_ms: u64,
/// Operation counter. NOTE(review): every code path in this file sets it to 0.
pub operations_count: u64,
/// Messages captured from the script's logging functions.
pub logs: Vec<String>,
}
impl ScriptResult {
    /// Builds a successful result carrying `value` and the measured duration.
    /// Logs start empty and the operation count is zero.
    pub fn success(value: serde_json::Value, execution_time_ms: u64) -> Self {
        Self {
            logs: Vec::new(),
            operations_count: 0,
            execution_time_ms,
            error: None,
            value,
            success: true,
        }
    }

    /// Builds a failed result carrying `error`, a `Null` value and zeroed
    /// counters.
    pub fn failure(error: String) -> Self {
        Self {
            logs: Vec::new(),
            operations_count: 0,
            execution_time_ms: 0,
            error: Some(error),
            value: serde_json::Value::Null,
            success: false,
        }
    }

    /// Consumes the result and deserializes its value into `T`.
    ///
    /// # Errors
    /// Returns the stored error text (or `"Unknown error"` if none was
    /// recorded) for a failed result, or a deserialization error when the
    /// value does not fit `T`.
    pub fn into_typed<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
        if self.success {
            return serde_json::from_value(self.value)
                .map_err(|e| anyhow!("Failed to deserialize: {}", e));
        }
        let message = self.error.unwrap_or_else(|| "Unknown error".into());
        Err(anyhow!(message))
    }

    /// The value as a boolean, if it is one.
    pub fn as_bool(&self) -> Option<bool> {
        self.value.as_bool()
    }

    /// The value as a string slice, if it is a JSON string.
    pub fn as_str(&self) -> Option<&str> {
        self.value.as_str()
    }

    /// The value as an `i64`, if it is an integer in range.
    pub fn as_i64(&self) -> Option<i64> {
        self.value.as_i64()
    }

    /// The value as an `f64`, if it is a number.
    pub fn as_f64(&self) -> Option<f64> {
        self.value.as_f64()
    }
}
/// A script that has been compiled to a Rhai AST, together with its source.
pub struct CompiledScript {
/// Identifier used as the cache key.
pub id: String,
/// Human-readable name (used in log output).
pub name: String,
// Compiled form, evaluated by the engine.
ast: AST,
// Original source text, exposed read-only through `source()`.
source: String,
/// Compilation time in whole seconds since the Unix epoch.
pub compiled_at: u64,
}
impl CompiledScript {
    /// Wraps an already-compiled `ast` with its identifying data and stamps
    /// it with the current time (seconds since the Unix epoch, or 0 if the
    /// system clock reads earlier than the epoch).
    pub fn new(id: &str, name: &str, ast: AST, source: String) -> Self {
        let compiled_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_secs())
            .unwrap_or(0);
        Self {
            id: id.to_owned(),
            name: name.to_owned(),
            ast,
            source,
            compiled_at,
        }
    }

    /// The source text this script was compiled from.
    pub fn source(&self) -> &str {
        self.source.as_str()
    }
}
/// Rhai engine wrapper with a compiled-script cache and a captured log buffer.
pub struct RhaiScriptEngine {
// The configured Rhai engine (limits and built-in functions are set in `new`).
engine: Engine,
// Kept for reference; only `config.security` is read, during construction.
#[allow(dead_code)]
config: ScriptEngineConfig,
// Compiled scripts keyed by script ID.
script_cache: Arc<RwLock<HashMap<String, CompiledScript>>>,
// Base scope cloned for every execution; created empty in `new`.
global_scope: Scope<'static>,
// Buffer the registered logging functions append to. It is shared by all
// executions of this engine instance.
logs: Arc<RwLock<Vec<String>>>,
}
impl RhaiScriptEngine {
/// Builds an engine from `config`: applies the security limits, registers
/// the built-in script functions, and starts with an empty cache, scope and
/// log buffer.
///
/// NOTE(review): only `config.security` is consulted here; the other
/// configuration fields are stored unchanged.
pub fn new(config: ScriptEngineConfig) -> Result<Self> {
    let mut engine = Engine::new();
    Self::apply_security_limits(&mut engine, &config.security);

    // The registered logging functions and the engine share this buffer.
    let logs = Arc::new(RwLock::new(Vec::new()));
    Self::register_builtin_functions(&mut engine, Arc::clone(&logs));

    let script_cache = Arc::new(RwLock::new(HashMap::new()));
    Ok(Self {
        engine,
        config,
        script_cache,
        global_scope: Scope::new(),
        logs,
    })
}
fn apply_security_limits(engine: &mut Engine, security: &ScriptSecurityConfig) {
engine.set_max_call_levels(security.max_call_stack_depth);
engine.set_max_operations(security.max_operations);
engine.set_max_array_size(security.max_array_size);
engine.set_max_string_size(security.max_string_size);
if !security.allow_loops {
engine.set_allow_looping(false);
}
engine.set_strict_variables(false);
}
/// Registers the host functions available to every script: logging, JSON
/// conversion, string helpers, integer math, time, UUID generation, type
/// predicates and `to_string` overloads.
///
/// Logging functions append `"[LEVEL] message"` to `logs`; all except `log`
/// also forward the message to `tracing`.
fn register_builtin_functions(engine: &mut Engine, logs: Arc<RwLock<Vec<String>>>) {
    /// Appends one tagged line to the shared buffer. These callbacks are
    /// synchronous, so the lock is only tried; if it is held elsewhere the
    /// line is dropped rather than blocking the script.
    fn record(logs: &RwLock<Vec<String>>, level: &str, msg: &str) {
        if let Ok(mut entries) = logs.try_write() {
            entries.push(format!("[{}] {}", level, msg));
        }
    }

    // --- Logging -----------------------------------------------------------
    let sink = Arc::clone(&logs);
    engine.register_fn("log", move |msg: &str| record(&sink, "LOG", msg));
    let sink = Arc::clone(&logs);
    engine.register_fn("debug", move |msg: &str| {
        record(&sink, "DEBUG", msg);
        debug!("Script debug: {}", msg);
    });
    let sink = Arc::clone(&logs);
    engine.register_fn("print", move |msg: &str| {
        record(&sink, "PRINT", msg);
        debug!("Script print: {}", msg);
    });
    let sink = Arc::clone(&logs);
    engine.register_fn("warn", move |msg: &str| {
        record(&sink, "WARN", msg);
        warn!("Script warn: {}", msg);
    });
    let sink = logs;
    engine.register_fn("error", move |msg: &str| {
        record(&sink, "ERROR", msg);
        error!("Script error: {}", msg);
    });

    // --- JSON --------------------------------------------------------------
    // Serialization failure yields the literal string "null".
    engine.register_fn("to_json", |value: Dynamic| -> String {
        serde_json::to_string(&value).unwrap_or_else(|_| "null".to_string())
    });
    // Invalid JSON yields the unit value.
    engine.register_fn("from_json", |json: &str| -> Dynamic {
        serde_json::from_str::<serde_json::Value>(json)
            .map(|v| json_to_dynamic(&v))
            .unwrap_or(Dynamic::UNIT)
    });

    // --- Strings -----------------------------------------------------------
    engine.register_fn("trim", |s: &str| -> String { s.trim().to_string() });
    engine.register_fn("upper", |s: &str| -> String { s.to_uppercase() });
    engine.register_fn("lower", |s: &str| -> String { s.to_lowercase() });
    engine.register_fn("contains", |s: &str, pattern: &str| -> bool {
        s.contains(pattern)
    });
    engine.register_fn("starts_with", |s: &str, pattern: &str| -> bool {
        s.starts_with(pattern)
    });
    engine.register_fn("ends_with", |s: &str, pattern: &str| -> bool {
        s.ends_with(pattern)
    });
    engine.register_fn("replace", |s: &str, from: &str, to: &str| -> String {
        s.replace(from, to)
    });
    engine.register_fn("split", |s: &str, delimiter: &str| -> Vec<Dynamic> {
        s.split(delimiter)
            .map(|part| Dynamic::from(part.to_string()))
            .collect()
    });

    // --- Integer / float math ---------------------------------------------
    // `i64::abs` overflows for `i64::MIN` (a panic when overflow checks are
    // enabled); saturate instead so a script value cannot panic the host.
    engine.register_fn("abs", |x: i64| -> i64 { x.saturating_abs() });
    engine.register_fn("abs_f", |x: f64| -> f64 { x.abs() });
    engine.register_fn("min", |a: i64, b: i64| -> i64 { a.min(b) });
    engine.register_fn("max", |a: i64, b: i64| -> i64 { a.max(b) });
    // `i64::clamp` panics when `min > max`, and both bounds come from the
    // script. `max`-then-`min` gives the same answer for valid bounds and
    // returns `max` for inverted ones instead of panicking.
    engine.register_fn("clamp", |value: i64, min: i64, max: i64| -> i64 {
        value.max(min).min(max)
    });

    // --- Time and identifiers ---------------------------------------------
    // Both clocks fall back to 0 if the system time is before the Unix epoch.
    engine.register_fn("now", || -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs() as i64
    });
    engine.register_fn("now_ms", || -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as i64
    });
    engine.register_fn("uuid", || -> String { uuid::Uuid::now_v7().to_string() });

    // --- Type predicates ---------------------------------------------------
    engine.register_fn("is_null", |v: Dynamic| -> bool { v.is_unit() });
    engine.register_fn("is_string", |v: Dynamic| -> bool { v.is_string() });
    engine.register_fn("is_int", |v: Dynamic| -> bool { v.is_int() });
    engine.register_fn("is_float", |v: Dynamic| -> bool { v.is_float() });
    engine.register_fn("is_bool", |v: Dynamic| -> bool { v.is_bool() });
    engine.register_fn("is_array", |v: Dynamic| -> bool { v.is_array() });
    engine.register_fn("is_map", |v: Dynamic| -> bool { v.is_map() });

    // --- to_string overloads ----------------------------------------------
    engine.register_fn("to_string", |v: i64| -> String { v.to_string() });
    engine.register_fn("to_string", |v: f64| -> String { v.to_string() });
    engine.register_fn("to_string", |v: bool| -> String { v.to_string() });
    engine.register_fn("to_string", |v: &str| -> String { v.to_string() });
}
pub fn compile(&self, id: &str, name: &str, source: &str) -> Result<CompiledScript> {
let ast = self
.engine
.compile(source)
.map_err(|e| anyhow!("Compile error: {}", e))?;
Ok(CompiledScript::new(id, name, ast, source.to_string()))
}
/// Compiles `source` and stores it in the cache under `id`, replacing any
/// script previously cached with the same ID.
///
/// # Errors
/// Propagates the compile error; the cache is left untouched in that case.
pub async fn compile_and_cache(&self, id: &str, name: &str, source: &str) -> Result<()> {
    let script = self.compile(id, name, source)?;
    let mut cached = self.script_cache.write().await;
    let _replaced = cached.insert(id.to_owned(), script);
    info!("Script compiled and cached: {} ({})", name, id);
    Ok(())
}
/// Reads a script file, compiles it and caches it.
///
/// The cache ID is the file stem and the display name is the full file name;
/// either falls back to `"unnamed"` when missing or not valid UTF-8.
/// Returns the ID the script was cached under.
///
/// # Errors
/// Returns the I/O error from reading the file or the compile error.
pub async fn load_from_file(&self, path: &Path) -> Result<String> {
    /// Converts an optional path component to `&str`, defaulting to "unnamed".
    fn utf8_or_unnamed(part: Option<&std::ffi::OsStr>) -> &str {
        match part.and_then(std::ffi::OsStr::to_str) {
            Some(text) => text,
            None => "unnamed",
        }
    }

    let source = tokio::fs::read_to_string(path).await?;
    let id = utf8_or_unnamed(path.file_stem());
    let name = utf8_or_unnamed(path.file_name());
    self.compile_and_cache(id, name, &source).await?;
    Ok(id.to_owned())
}
/// Evaluates `source` with the variables and constants from `context`.
///
/// Script errors do not produce `Err`: they are reported as an
/// `Ok(ScriptResult)` with `success == false` and the error text filled in.
/// The log buffer is cleared before the run and a snapshot of it is attached
/// to the result afterwards.
///
/// NOTE(review): the log buffer is shared by the whole engine, so overlapping
/// executions would see each other's messages. `operations_count` is always
/// reported as 0.
pub async fn execute(&self, source: &str, context: &ScriptContext) -> Result<ScriptResult> {
    let started = std::time::Instant::now();

    // Start from an empty log buffer; the write guard is released right away.
    self.logs.write().await.clear();

    let mut scope = self.global_scope.clone();
    self.prepare_scope(&mut scope, context);

    let outcome = self.engine.eval_with_scope::<Dynamic>(&mut scope, source);
    let execution_time_ms = started.elapsed().as_millis() as u64;
    let logs = self.logs.read().await.clone();

    let (value, error) = match outcome {
        Ok(returned) => (dynamic_to_json(&returned), None),
        Err(failure) => (serde_json::Value::Null, Some(format!("{}", failure))),
    };
    Ok(ScriptResult {
        success: error.is_none(),
        value,
        error,
        execution_time_ms,
        operations_count: 0,
        logs,
    })
}
/// Evaluates a previously cached script with the variables and constants
/// from `context`.
///
/// Script errors are reported as `Ok(ScriptResult)` with `success == false`.
/// The cache read lock is held for the whole evaluation.
///
/// # Errors
/// Returns `"Script not found: …"` when `script_id` is not in the cache.
pub async fn execute_compiled(
    &self,
    script_id: &str,
    context: &ScriptContext,
) -> Result<ScriptResult> {
    let cache = self.script_cache.read().await;
    let compiled = match cache.get(script_id) {
        Some(script) => script,
        None => return Err(anyhow!("Script not found: {}", script_id)),
    };

    let started = std::time::Instant::now();

    // Start from an empty log buffer; the write guard is released right away.
    self.logs.write().await.clear();

    let mut scope = self.global_scope.clone();
    self.prepare_scope(&mut scope, context);

    let outcome = self
        .engine
        .eval_ast_with_scope::<Dynamic>(&mut scope, &compiled.ast);
    let execution_time_ms = started.elapsed().as_millis() as u64;
    let logs = self.logs.read().await.clone();

    let (value, error) = match outcome {
        Ok(returned) => (dynamic_to_json(&returned), None),
        Err(failure) => (serde_json::Value::Null, Some(format!("{}", failure))),
    };
    Ok(ScriptResult {
        success: error.is_none(),
        value,
        error,
        execution_time_ms,
        operations_count: 0,
        logs,
    })
}
pub async fn call_function<T: for<'de> Deserialize<'de>>(
&self,
script_id: &str,
function_name: &str,
args: Vec<serde_json::Value>,
context: &ScriptContext,
) -> Result<T> {
let cache = self.script_cache.read().await;
let compiled = cache
.get(script_id)
.ok_or_else(|| anyhow!("Script not found: {}", script_id))?;
let mut scope = self.global_scope.clone();
self.prepare_scope(&mut scope, context);
let dynamic_args: Vec<Dynamic> = args.iter().map(json_to_dynamic).collect();
let result: Dynamic = self
.engine
.call_fn(&mut scope, &compiled.ast, function_name, dynamic_args)
.map_err(|e| anyhow!("Function call error: {}", e))?;
let json_value = dynamic_to_json(&result);
serde_json::from_value(json_value).map_err(|e| anyhow!("Result conversion error: {}", e))
}
fn prepare_scope(&self, scope: &mut Scope, context: &ScriptContext) {
if let Some(ref agent_id) = context.agent_id {
scope.push_constant("AGENT_ID", agent_id.clone());
}
if let Some(ref workflow_id) = context.workflow_id {
scope.push_constant("WORKFLOW_ID", workflow_id.clone());
}
if let Some(ref node_id) = context.node_id {
scope.push_constant("NODE_ID", node_id.clone());
}
if let Some(ref execution_id) = context.execution_id {
scope.push_constant("EXECUTION_ID", execution_id.clone());
}
for (key, value) in &context.variables {
let dynamic_value = json_to_dynamic(value);
scope.push(key.clone(), dynamic_value);
}
let mut metadata_map = Map::new();
for (k, v) in &context.metadata {
metadata_map.insert(k.clone().into(), Dynamic::from(v.clone()));
}
scope.push_constant("metadata", metadata_map);
}
/// Checks whether `source` compiles.
///
/// Always returns `Ok`: an empty list means the source compiled, otherwise
/// the list holds the single compile error message.
pub fn validate(&self, source: &str) -> Result<Vec<String>> {
    let problems: Vec<String> = self
        .engine
        .compile(source)
        .err()
        .map(|parse_error| format!("{}", parse_error))
        .into_iter()
        .collect();
    Ok(problems)
}
/// Returns the IDs of all cached scripts, in no particular order.
pub async fn cached_scripts(&self) -> Vec<String> {
    let guard = self.script_cache.read().await;
    let mut ids = Vec::with_capacity(guard.len());
    ids.extend(guard.keys().cloned());
    ids
}
/// Evicts `script_id` from the cache. Returns `true` if an entry was removed.
pub async fn remove_cached(&self, script_id: &str) -> bool {
    let mut guard = self.script_cache.write().await;
    match guard.remove(script_id) {
        Some(_evicted) => true,
        None => false,
    }
}
/// Removes every compiled script from the cache.
pub async fn clear_cache(&self) {
    self.script_cache.write().await.clear();
}
/// Shared access to the underlying Rhai engine.
pub fn engine(&self) -> &Engine {
&self.engine
}
/// Mutable access to the underlying Rhai engine, e.g. to register further
/// functions after construction.
pub fn engine_mut(&mut self) -> &mut Engine {
&mut self.engine
}
}
pub fn json_to_dynamic(value: &serde_json::Value) -> Dynamic {
match value {
serde_json::Value::Null => Dynamic::UNIT,
serde_json::Value::Bool(b) => Dynamic::from(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Dynamic::from(i)
} else if let Some(f) = n.as_f64() {
Dynamic::from(f)
} else {
Dynamic::UNIT
}
}
serde_json::Value::String(s) => Dynamic::from(s.clone()),
serde_json::Value::Array(arr) => {
let vec: Vec<Dynamic> = arr.iter().map(json_to_dynamic).collect();
Dynamic::from(vec)
}
serde_json::Value::Object(obj) => {
let mut map = Map::new();
for (k, v) in obj {
map.insert(k.clone().into(), json_to_dynamic(v));
}
Dynamic::from(map)
}
}
}
/// Converts a Rhai `Dynamic` into a JSON value, recursing into arrays and
/// maps.
///
/// Unit becomes `null`; booleans, integers, floats and strings map to their
/// JSON counterparts; any other type is rendered through its string form.
///
/// The value is inspected through borrowing accessors. The previous version
/// cloned the entire value for every type probe, which deep-copied each
/// array and map several times at every nesting level.
pub fn dynamic_to_json(value: &Dynamic) -> serde_json::Value {
    if value.is_unit() {
        return serde_json::Value::Null;
    }
    if let Ok(flag) = value.as_bool() {
        return serde_json::Value::Bool(flag);
    }
    if let Ok(int) = value.as_int() {
        return serde_json::json!(int);
    }
    if let Ok(float) = value.as_float() {
        // Non-finite floats have no JSON form; `json!` turns them into null.
        return serde_json::json!(float);
    }
    if value.is_string() {
        // Only string values reach this clone, never a container.
        if let Ok(text) = value.clone().into_string() {
            return serde_json::Value::String(text);
        }
    }
    if let Some(items) = value.read_lock::<rhai::Array>() {
        let converted: Vec<serde_json::Value> = items.iter().map(dynamic_to_json).collect();
        return serde_json::Value::Array(converted);
    }
    if let Some(fields) = value.read_lock::<Map>() {
        let mut object = serde_json::Map::new();
        for (key, item) in fields.iter() {
            object.insert(key.to_string(), dynamic_to_json(item));
        }
        return serde_json::Value::Object(object);
    }
    // Fallback for every other type: use its string form.
    serde_json::Value::String(value.to_string())
}
// Tests for script execution, caching, built-ins and JSON conversion.
#[cfg(test)]
mod tests {
use super::*;
// A bare expression evaluates and its value is returned as JSON.
#[tokio::test]
async fn test_basic_script_execution() {
let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
let context = ScriptContext::new();
let result = engine.execute("1 + 2", &context).await.unwrap();
assert!(result.success);
assert_eq!(result.value, serde_json::json!(3));
}
// Context variables are visible to the script by name.
#[tokio::test]
async fn test_script_with_variables() {
let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
let context = ScriptContext::new()
.with_variable("x", 10)
.unwrap()
.with_variable("y", 20)
.unwrap();
let result = engine.execute("x + y", &context).await.unwrap();
assert!(result.success);
assert_eq!(result.value, serde_json::json!(30));
}
// Functions defined inside the script can be called from its top level.
#[tokio::test]
async fn test_script_with_function() {
let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
let context = ScriptContext::new();
let script = r#"
fn double(n) {
n * 2
}
double(21)
"#;
let result = engine.execute(script, &context).await.unwrap();
assert!(result.success);
assert_eq!(result.value, serde_json::json!(42));
}
// A cached script runs against a JSON object variable and returns a map.
#[tokio::test]
async fn test_compiled_script() {
let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
engine
.compile_and_cache(
"test_script",
"Test Script",
r#"
fn process(input) {
let result = #{};
result.doubled = input.value * 2;
result.message = "processed: " + input.name;
result
}
process(input)
"#,
)
.await
.unwrap();
let context = ScriptContext::new()
.with_variable(
"input",
serde_json::json!({
"name": "test",
"value": 21
}),
)
.unwrap();
let result = engine
.execute_compiled("test_script", &context)
.await
.unwrap();
assert!(result.success);
assert_eq!(result.value["doubled"], 42);
assert_eq!(result.value["message"], "processed: test");
}
// Spot-checks three registered built-ins: `upper`, `to_json` and `now`.
#[tokio::test]
async fn test_builtin_functions() {
let engine = RhaiScriptEngine::new(ScriptEngineConfig::default()).unwrap();
let context = ScriptContext::new();
let result = engine.execute(r#"upper("hello")"#, &context).await.unwrap();
assert_eq!(result.value, "HELLO");
let result = engine
.execute(r#"to_json(#{name: "test", value: 42})"#, &context)
.await
.unwrap();
assert!(result.value.as_str().is_some());
let result = engine.execute("now()", &context).await.unwrap();
assert!(result.value.as_i64().is_some());
}
// JSON -> Dynamic -> JSON must round-trip for nested objects and arrays.
#[test]
fn test_json_conversion() {
let json = serde_json::json!({
"name": "test",
"values": [1, 2, 3],
"nested": {
"flag": true
}
});
let dynamic = json_to_dynamic(&json);
let back = dynamic_to_json(&dynamic);
assert_eq!(json, back);
}
}