1use crate::functions;
4use anyhow::{anyhow, Result};
5use rhai::{Dynamic, Engine, Map, Scope, AST};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use tracing::debug;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ScriptContext {
14 pub profile: ProfileContext,
16 pub provider: ProviderContext,
18 pub agent: AgentContext,
20 pub prefs: PrefsContext,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ProfileContext {
27 pub alias: String,
28 pub home: PathBuf,
29 pub model: String,
30 pub endpoint: String,
31 pub hooks: Vec<String>,
32 pub mcp_servers: Vec<String>,
33 pub hooks_config: Option<serde_json::Value>,
35 pub proxy_url: Option<String>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ProviderContext {
42 pub id: String,
43 pub name: String,
44 pub provider_type: String,
45 pub auth_env_key: String,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AgentContext {
51 pub id: String,
52 pub name: String,
53 pub binary: String,
54}
55
56#[derive(Debug, Clone, Default, Serialize, Deserialize)]
58pub struct PrefsContext {
59 #[serde(flatten)]
61 pub custom: HashMap<String, String>,
62}
63
64#[derive(Debug, Clone, Default)]
66pub struct ScriptOutput {
67 pub files: HashMap<String, String>,
69 pub env: HashMap<String, String>,
71 pub args: Vec<String>,
73 pub hooks: Option<serde_json::Value>,
75 pub mcp_servers: Option<serde_json::Value>,
77}
78
79pub struct ScriptEngine {
81 engine: Engine,
82}
83
84impl ScriptEngine {
85 pub fn new() -> Self {
87 let mut engine = Engine::new();
88
89 engine.set_max_operations(100_000);
91 engine.set_max_string_size(1024 * 1024); engine.set_max_array_size(10_000);
93 engine.set_max_map_size(10_000);
94 engine.set_max_call_levels(64);
95
96 functions::register_all(&mut engine);
98
99 Self { engine }
100 }
101
102 pub fn compile(&self, script: &str) -> Result<AST> {
104 self.engine
105 .compile(script)
106 .map_err(|e| anyhow!("Failed to compile script: {}", e))
107 }
108
109 pub fn run(&self, script: &str, context: &ScriptContext) -> Result<ScriptOutput> {
111 let ast = self.compile(script)?;
112 self.run_ast(&ast, context)
113 }
114
115 pub fn run_ast(&self, ast: &AST, context: &ScriptContext) -> Result<ScriptOutput> {
117 let mut scope = Scope::new();
118
119 let context_dynamic = context_to_dynamic(context)?;
121 scope.push_dynamic("ctx", context_dynamic);
122
123 debug!("Running script with context: {:?}", context);
124
125 let result: Dynamic = self
127 .engine
128 .eval_ast_with_scope(&mut scope, ast)
129 .map_err(|e| anyhow!("Script execution failed: {}", e))?;
130
131 dynamic_to_output(result)
133 }
134}
135
136impl Default for ScriptEngine {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142fn context_to_dynamic(context: &ScriptContext) -> Result<Dynamic> {
144 let mut map = Map::new();
145
146 let mut profile = Map::new();
148 profile.insert("alias".into(), context.profile.alias.clone().into());
149 profile.insert(
150 "home".into(),
151 context.profile.home.to_string_lossy().to_string().into(),
152 );
153 profile.insert("model".into(), context.profile.model.clone().into());
154 profile.insert("endpoint".into(), context.profile.endpoint.clone().into());
155 profile.insert(
156 "hooks".into(),
157 context
158 .profile
159 .hooks
160 .iter()
161 .map(|s| Dynamic::from(s.clone()))
162 .collect::<Vec<_>>()
163 .into(),
164 );
165 profile.insert(
166 "mcp_servers".into(),
167 context
168 .profile
169 .mcp_servers
170 .iter()
171 .map(|s| Dynamic::from(s.clone()))
172 .collect::<Vec<_>>()
173 .into(),
174 );
175 if let Some(ref hooks_json) = context.profile.hooks_config {
177 let hooks_dynamic = json_to_dynamic(hooks_json.clone())?;
178 profile.insert("hooks_config".into(), hooks_dynamic);
179 } else {
180 profile.insert("hooks_config".into(), Dynamic::UNIT);
181 }
182 if let Some(ref proxy_url) = context.profile.proxy_url {
184 profile.insert("proxy_url".into(), proxy_url.clone().into());
185 } else {
186 profile.insert("proxy_url".into(), Dynamic::UNIT);
187 }
188 map.insert("profile".into(), profile.into());
189
190 let mut provider = Map::new();
192 provider.insert("id".into(), context.provider.id.clone().into());
193 provider.insert("name".into(), context.provider.name.clone().into());
194 provider.insert("type".into(), context.provider.provider_type.clone().into());
195 provider.insert(
196 "auth_env_key".into(),
197 context.provider.auth_env_key.clone().into(),
198 );
199 map.insert("provider".into(), provider.into());
200
201 let mut agent = Map::new();
203 agent.insert("id".into(), context.agent.id.clone().into());
204 agent.insert("name".into(), context.agent.name.clone().into());
205 agent.insert("binary".into(), context.agent.binary.clone().into());
206 map.insert("agent".into(), agent.into());
207
208 let mut prefs = Map::new();
210 for (k, v) in &context.prefs.custom {
211 prefs.insert(k.clone().into(), v.clone().into());
212 }
213 map.insert("prefs".into(), prefs.into());
214
215 Ok(map.into())
216}
217
218fn dynamic_to_output(result: Dynamic) -> Result<ScriptOutput> {
220 let mut output = ScriptOutput::default();
221
222 let map = result
223 .try_cast::<Map>()
224 .ok_or_else(|| anyhow!("Script must return an object"))?;
225
226 if let Some(files_dynamic) = map.get("files") {
228 if let Some(files_map) = files_dynamic.clone().try_cast::<Map>() {
229 for (key, value) in files_map {
230 if let Some(content) = value.clone().try_cast::<String>() {
231 output.files.insert(key.to_string(), content);
232 }
233 }
234 }
235 }
236
237 if let Some(env_dynamic) = map.get("env") {
239 if let Some(env_map) = env_dynamic.clone().try_cast::<Map>() {
240 for (key, value) in env_map {
241 if let Some(val) = value.clone().try_cast::<String>() {
242 output.env.insert(key.to_string(), val);
243 }
244 }
245 }
246 }
247
248 if let Some(args_dynamic) = map.get("args") {
250 if let Some(args_arr) = args_dynamic.clone().try_cast::<rhai::Array>() {
251 for arg in args_arr {
252 if let Some(arg_str) = arg.clone().try_cast::<String>() {
253 output.args.push(arg_str);
254 }
255 }
256 }
257 }
258
259 if let Some(hooks_dynamic) = map.get("hooks") {
261 output.hooks = Some(dynamic_to_json(hooks_dynamic.clone())?);
262 }
263
264 if let Some(mcp_dynamic) = map.get("mcp_servers") {
266 output.mcp_servers = Some(dynamic_to_json(mcp_dynamic.clone())?);
267 }
268
269 Ok(output)
270}
271
272fn dynamic_to_json(value: Dynamic) -> Result<serde_json::Value> {
274 if value.is::<()>() {
275 Ok(serde_json::Value::Null)
276 } else if value.is::<bool>() {
277 Ok(serde_json::Value::Bool(value.cast::<bool>()))
278 } else if value.is::<i64>() {
279 Ok(serde_json::Value::Number(value.cast::<i64>().into()))
280 } else if value.is::<f64>() {
281 let f = value.cast::<f64>();
282 Ok(serde_json::Number::from_f64(f)
283 .map(serde_json::Value::Number)
284 .unwrap_or(serde_json::Value::Null))
285 } else if value.is::<String>() {
286 Ok(serde_json::Value::String(value.cast::<String>()))
287 } else if value.is::<rhai::Array>() {
288 let arr = value.cast::<rhai::Array>();
289 let mut json_arr = Vec::with_capacity(arr.len());
290 for item in arr {
291 json_arr.push(dynamic_to_json(item)?);
292 }
293 Ok(serde_json::Value::Array(json_arr))
294 } else if value.is::<Map>() {
295 let map = value.cast::<Map>();
296 let mut json_obj = serde_json::Map::new();
297 for (k, v) in map {
298 json_obj.insert(k.to_string(), dynamic_to_json(v)?);
299 }
300 Ok(serde_json::Value::Object(json_obj))
301 } else {
302 Ok(serde_json::Value::String(value.to_string()))
304 }
305}
306
307fn json_to_dynamic(value: serde_json::Value) -> Result<Dynamic> {
309 match value {
310 serde_json::Value::Null => Ok(Dynamic::UNIT),
311 serde_json::Value::Bool(b) => Ok(Dynamic::from(b)),
312 serde_json::Value::Number(n) => {
313 if let Some(i) = n.as_i64() {
314 Ok(Dynamic::from(i))
315 } else if let Some(f) = n.as_f64() {
316 Ok(Dynamic::from(f))
317 } else {
318 Ok(Dynamic::UNIT)
319 }
320 }
321 serde_json::Value::String(s) => Ok(Dynamic::from(s)),
322 serde_json::Value::Array(arr) => {
323 let mut rhai_arr = rhai::Array::new();
324 for item in arr {
325 rhai_arr.push(json_to_dynamic(item)?);
326 }
327 Ok(Dynamic::from(rhai_arr))
328 }
329 serde_json::Value::Object(obj) => {
330 let mut map = Map::new();
331 for (k, v) in obj {
332 map.insert(k.into(), json_to_dynamic(v)?);
333 }
334 Ok(Dynamic::from(map))
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal context for tests; only `alias` varies between them.
    fn sample_context(alias: &str) -> ScriptContext {
        ScriptContext {
            profile: ProfileContext {
                alias: alias.to_string(),
                home: PathBuf::from("/home/test"),
                model: "test-model".to_string(),
                endpoint: "https://api.test.com".to_string(),
                hooks: vec![],
                mcp_servers: vec![],
                hooks_config: None,
                proxy_url: None,
            },
            provider: ProviderContext {
                id: "test".to_string(),
                name: "Test Provider".to_string(),
                provider_type: "anthropic".to_string(),
                auth_env_key: "TEST_API_KEY".to_string(),
            },
            agent: AgentContext {
                id: "test".to_string(),
                name: "Test Agent".to_string(),
                binary: "test".to_string(),
            },
            prefs: PrefsContext::default(),
        }
    }

    /// A script can read `ctx` and return `files` / `env` entries.
    #[test]
    fn test_simple_script() {
        let engine = ScriptEngine::new();

        let script = r#"
            #{
                files: #{
                    "test.txt": "Hello, " + ctx.profile.alias
                },
                env: #{
                    "TEST_VAR": "test_value"
                }
            }
        "#;

        let output = engine.run(script, &sample_context("myprofile")).unwrap();

        assert_eq!(
            output.files.get("test.txt"),
            Some(&"Hello, myprofile".to_string())
        );
        assert_eq!(output.env.get("TEST_VAR"), Some(&"test_value".to_string()));
    }

    /// The registered `json::encode` function is callable from scripts.
    #[test]
    fn test_json_encode() {
        let engine = ScriptEngine::new();

        let script = r#"
            let obj = #{ name: "test", value: 42 };
            #{
                files: #{
                    "config.json": json::encode(obj)
                },
                env: #{}
            }
        "#;

        let output = engine.run(script, &sample_context("test")).unwrap();
        let json_content = output.files.get("config.json").unwrap();
        assert!(json_content.contains("\"name\""));
        assert!(json_content.contains("\"test\""));
    }
}