Skip to main content

rs_utcp/plugins/codemode/
mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
5
6use anyhow::{anyhow, Result};
7use rhai::{Dynamic, Engine, EvalAltResult, Map, Scope};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tokio::runtime::{Builder, RuntimeFlavor};
11
12use crate::security;
13use crate::tools::{Tool, ToolInputOutputSchema};
14use crate::UtcpClientInterface;
15
// Security configuration constants
/// Maximum code snippet size (100KB) to prevent DoS attacks
const MAX_CODE_SIZE: usize = 100_000;

/// Maximum timeout for code execution (45 seconds)
const MAX_TIMEOUT_MS: u64 = 45_000;

/// Default timeout if none specified (5 seconds)
const DEFAULT_TIMEOUT_MS: u64 = 5_000;

/// Maximum size for script output (10MB) to prevent memory exhaustion
const MAX_OUTPUT_SIZE: usize = 10_000_000;

/// Maximum operations per script execution
const MAX_OPERATIONS: u64 = 100_000;

/// Maximum expression depth to prevent stack overflow.
/// The two values are passed, in order, to `Engine::set_max_expr_depths`.
const MAX_EXPR_DEPTH: (usize, usize) = (64, 32);

/// Maximum string size (1MB) within scripts
const MAX_STRING_SIZE: usize = 1_000_000;

/// Maximum array/map sizes to prevent memory exhaustion
const MAX_ARRAY_SIZE: usize = 10_000;
const MAX_MAP_SIZE: usize = 10_000;

/// Maximum number of modules
const MAX_MODULES: usize = 16;

/// Dangerous code patterns that are prohibited.
/// Each entry is matched as a plain, case-sensitive substring of the submitted code.
const DANGEROUS_PATTERNS: &[&str] = &[
    "eval(",
    "import ",
    "fn ",        // Function definitions could be abused
    "while true", // Infinite loops
    "loop {",     // Infinite loops
];
53
/// Minimal facade exposing UTCP calls to Rhai scripts executed by CodeMode.
///
/// Every tool call, streaming call and tool search made by this type is
/// forwarded to the wrapped client.
pub struct CodeModeUtcp {
    /// Shared handle to the UTCP client that performs the actual calls.
    client: Arc<dyn UtcpClientInterface>,
}
58
59impl CodeModeUtcp {
    /// Wrap an `UtcpClientInterface` so codemode scripts can invoke tools.
    ///
    /// The client is stored as-is; no calls are made at construction time.
    pub fn new(client: Arc<dyn UtcpClientInterface>) -> Self {
        Self { client }
    }
64
65    /// Validates code for security issues before execution.
66    fn validate_code(&self, code: &str) -> Result<()> {
67        // Check code size
68        if code.len() > MAX_CODE_SIZE {
69            return Err(anyhow!(
70                "Code size {} bytes exceeds maximum allowed {} bytes",
71                code.len(),
72                MAX_CODE_SIZE
73            ));
74        }
75
76        // Check for dangerous patterns
77        for pattern in DANGEROUS_PATTERNS {
78            if code.contains(pattern) {
79                return Err(anyhow!("Code contains prohibited pattern: '{}'", pattern));
80            }
81        }
82
83        Ok(())
84    }
85
86    /// Execute a snippet or JSON payload, returning the resulting value and captured output.
87    pub async fn execute(&self, args: CodeModeArgs) -> Result<CodeModeResult> {
88        // Validate code before execution
89        self.validate_code(&args.code)?;
90
91        // Determine and validate timeout
92        let timeout_ms = args.timeout.unwrap_or(DEFAULT_TIMEOUT_MS);
93        security::validate_timeout(timeout_ms, MAX_TIMEOUT_MS)?;
94
95        // If it's JSON already, return it directly (no execution needed)
96        if let Ok(json) = serde_json::from_str::<Value>(&args.code) {
97            return Ok(CodeModeResult {
98                value: json,
99                stdout: String::new(),
100                stderr: String::new(),
101            });
102        }
103
104        // Execute with timeout
105        let result = tokio::time::timeout(
106            Duration::from_millis(timeout_ms),
107            self.eval_rusty_snippet(&args.code, Some(timeout_ms)),
108        )
109        .await;
110
111        let value = match result {
112            Ok(Ok(v)) => v,
113            Ok(Err(e)) => return Err(e),
114            Err(_) => {
115                return Err(anyhow!("Code execution timed out after {}ms", timeout_ms));
116            }
117        };
118
119        // Validate output size to prevent memory exhaustion
120        let serialized = serde_json::to_vec(&value)?;
121        if serialized.len() > MAX_OUTPUT_SIZE {
122            return Err(anyhow!(
123                "Output size {} bytes exceeds maximum allowed {} bytes",
124                serialized.len(),
125                MAX_OUTPUT_SIZE
126            ));
127        }
128
129        Ok(CodeModeResult {
130            value,
131            stdout: String::new(),
132            stderr: String::new(),
133        })
134    }
135
136    fn tool_schema(&self) -> Tool {
137        Tool {
138            name: "codemode.run_code".to_string(),
139            description: "Execute a Rust-like snippet with access to UTCP tools.".to_string(),
140            inputs: ToolInputOutputSchema {
141                type_: "object".to_string(),
142                properties: Some(HashMap::from([
143                    (
144                        "code".to_string(),
145                        serde_json::json!({"type": "string", "description": "Rust-like snippet"}),
146                    ),
147                    (
148                        "timeout".to_string(),
149                        serde_json::json!({"type": "integer", "description": "Timeout ms"}),
150                    ),
151                ])),
152                required: Some(vec!["code".to_string()]),
153                description: None,
154                title: Some("CodeModeArgs".to_string()),
155                items: None,
156                enum_: None,
157                minimum: None,
158                maximum: None,
159                format: None,
160            },
161            outputs: ToolInputOutputSchema {
162                type_: "object".to_string(),
163                properties: Some(HashMap::from([
164                    ("value".to_string(), serde_json::json!({"type": "string"})),
165                    ("stdout".to_string(), serde_json::json!({"type": "string"})),
166                    ("stderr".to_string(), serde_json::json!({"type": "string"})),
167                ])),
168                required: None,
169                description: None,
170                title: Some("CodeModeResult".to_string()),
171                items: None,
172                enum_: None,
173                minimum: None,
174                maximum: None,
175                format: None,
176            },
177            tags: vec!["codemode".to_string(), "utcp".to_string()],
178            average_response_size: None,
179            provider: None,
180        }
181    }
182
183    fn build_engine(&self) -> Engine {
184        let mut engine = Engine::new();
185
186        // Security: Comprehensive sandboxing using centralized constants
187        engine.set_max_expr_depths(MAX_EXPR_DEPTH.0, MAX_EXPR_DEPTH.1);
188        engine.set_max_operations(MAX_OPERATIONS);
189        engine.set_max_modules(MAX_MODULES);
190        engine.set_max_string_size(MAX_STRING_SIZE);
191        engine.set_max_array_size(MAX_ARRAY_SIZE);
192        engine.set_max_map_size(MAX_MAP_SIZE);
193
194        // Note: File I/O and other dangerous operations are disabled by default in Rhai
195        // when not explicitly importing the std modules
196
197        engine.register_fn("sprintf", sprintf);
198
199        let client = self.client.clone();
200        engine.register_fn(
201            "call_tool",
202            move |name: &str, map: Map| -> Result<Dynamic, Box<EvalAltResult>> {
203                // Security: Validate tool name format
204                if name.is_empty() || name.len() > 200 {
205                    return Err(EvalAltResult::ErrorRuntime(
206                        "Invalid tool name length".into(),
207                        rhai::Position::NONE,
208                    )
209                    .into());
210                }
211
212                let args_val = serde_json::to_value(map).map_err(|e| {
213                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
214                })?;
215                let args = value_to_map(args_val)?;
216
217                let res = block_on_any_runtime(async { client.call_tool(name, args).await })
218                    .map_err(|e| {
219                        EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
220                    })?;
221
222                Ok(rhai::serde::to_dynamic(res).map_err(|e| {
223                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
224                })?)
225            },
226        );
227
228        let client = self.client.clone();
229        engine.register_fn(
230            "call_tool_stream",
231            move |name: &str, map: Map| -> Result<Dynamic, Box<EvalAltResult>> {
232                // Security: Validate tool name format
233                if name.is_empty() || name.len() > 200 {
234                    return Err(EvalAltResult::ErrorRuntime(
235                        "Invalid tool name length".into(),
236                        rhai::Position::NONE,
237                    )
238                    .into());
239                }
240
241                let args_val = serde_json::to_value(map).map_err(|e| {
242                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
243                })?;
244                let args = value_to_map(args_val)?;
245
246                let mut stream =
247                    block_on_any_runtime(async { client.call_tool_stream(name, args).await })
248                        .map_err(|e| {
249                            EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
250                        })?;
251
252                let mut items = Vec::new();
253                // Security: Limit maximum number of stream items to prevent memory exhaustion
254                const MAX_STREAM_ITEMS: usize = 10_000;
255
256                loop {
257                    if items.len() >= MAX_STREAM_ITEMS {
258                        return Err(EvalAltResult::ErrorRuntime(
259                            format!("Stream exceeded maximum {} items", MAX_STREAM_ITEMS).into(),
260                            rhai::Position::NONE,
261                        )
262                        .into());
263                    }
264
265                    let next =
266                        block_on_any_runtime(async { stream.next().await }).map_err(|e| {
267                            EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
268                        })?;
269                    match next {
270                        Some(value) => items.push(value),
271                        None => break,
272                    }
273                }
274
275                if let Err(e) = block_on_any_runtime(async { stream.close().await }) {
276                    return Err(EvalAltResult::ErrorRuntime(
277                        e.to_string().into(),
278                        rhai::Position::NONE,
279                    )
280                    .into());
281                }
282
283                Ok(rhai::serde::to_dynamic(items).map_err(|e| {
284                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
285                })?)
286            },
287        );
288
289        let client = self.client.clone();
290        engine.register_fn(
291            "search_tools",
292            move |query: &str, limit: i64| -> Result<Dynamic, Box<EvalAltResult>> {
293                // Security: Validate query length
294                if query.len() > 1000 {
295                    return Err(EvalAltResult::ErrorRuntime(
296                        "Search query too long (max 1000 chars)".into(),
297                        rhai::Position::NONE,
298                    )
299                    .into());
300                }
301
302                // Security: Enforce reasonable search limit
303                const MAX_SEARCH_LIMIT: i64 = 500;
304                let safe_limit = if limit <= 0 || limit > MAX_SEARCH_LIMIT {
305                    MAX_SEARCH_LIMIT
306                } else {
307                    limit
308                };
309
310                let res = block_on_any_runtime(async {
311                    client.search_tools(query, safe_limit as usize).await
312                })
313                .map_err(|e| {
314                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
315                })?;
316                Ok(rhai::serde::to_dynamic(res).map_err(|e| {
317                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
318                })?)
319            },
320        );
321
322        engine
323    }
324
325    async fn eval_rusty_snippet(&self, code: &str, _timeout_ms: Option<u64>) -> Result<Value> {
326        let wrapped = format!("let __out = {{ {} }};\n__out", code);
327        let engine = self.build_engine();
328        let mut scope = Scope::new();
329
330        let dyn_result = engine.eval_with_scope::<Dynamic>(&mut scope, &wrapped);
331        let dyn_value = dyn_result.map_err(|e| anyhow!("codemode eval error: {}", e))?;
332        let value: Value = rhai::serde::from_dynamic(&dyn_value)
333            .map_err(|e| anyhow!("Failed to convert result: {}", e))?;
334        Ok(value)
335    }
336
    /// Expose the codemode tool definition for registration.
    ///
    /// Returns the descriptor built by `tool_schema`.
    pub fn tool(&self) -> Tool {
        self.tool_schema()
    }

    /// Convenience helpers mirroring go-utcp codemode helper exports.
    ///
    /// Forwards the call to the wrapped client's `call_tool` unchanged.
    pub async fn call_tool(&self, name: &str, args: HashMap<String, Value>) -> Result<Value> {
        self.client.call_tool(name, args).await
    }

    /// Forwards the call to the wrapped client's `call_tool_stream` unchanged
    /// and returns the stream handle it produces.
    pub async fn call_tool_stream(
        &self,
        name: &str,
        args: HashMap<String, Value>,
    ) -> Result<Box<dyn crate::transports::stream::StreamResult>> {
        self.client.call_tool_stream(name, args).await
    }

    /// Forwards the query and limit to the wrapped client's `search_tools` unchanged.
    pub async fn search_tools(&self, query: &str, limit: usize) -> Result<Vec<Tool>> {
        self.client.search_tools(query, limit).await
    }
358}
359
/// Abstraction over the language model driven by `CodemodeOrchestrator`.
///
/// The orchestrator reads the returned value with `Value::as_str`, so a
/// non-string completion is treated by it as an empty answer.
#[async_trait::async_trait]
pub trait LlmModel: Send + Sync {
    /// Produce a completion for the provided prompt.
    async fn complete(&self, prompt: &str) -> Result<Value>;
}
365
/// High-level orchestrator that mirrors go-utcp's CodeMode flow:
/// 1) Decide if tools are needed
/// 2) Select tools by name
/// 3) Ask the model to emit a Rhai snippet using call_tool helpers
/// 4) Execute the snippet via CodeMode
pub struct CodemodeOrchestrator {
    /// Executes generated snippets and provides tool search.
    codemode: Arc<CodeModeUtcp>,
    /// Model queried for the yes/no decision, the tool selection and the snippet.
    model: Arc<dyn LlmModel>,
    /// Rendered tool reference text; `None` until it has been rendered once.
    tool_specs_cache: RwLock<Option<String>>,
}
376
377impl CodemodeOrchestrator {
    /// Create a new orchestrator backed by a CodeMode UTCP shim and an LLM model.
    ///
    /// The tool reference cache starts out empty.
    pub fn new(codemode: Arc<CodeModeUtcp>, model: Arc<dyn LlmModel>) -> Self {
        Self {
            codemode,
            model,
            tool_specs_cache: RwLock::new(None),
        }
    }
386
387    /// Run the full orchestration flow. Returns Ok(None) if the model says no tools are needed
388    /// or fails to pick any tools. Otherwise returns the codemode execution result.
389    pub async fn call_prompt(&self, prompt: &str) -> Result<Option<Value>> {
390        let specs = self.render_tool_specs().await?;
391
392        if !self.decide_if_tools_needed(prompt, &specs).await? {
393            return Ok(None);
394        }
395
396        let selected = self.select_tools(prompt, &specs).await?;
397        if selected.is_empty() {
398            return Ok(None);
399        }
400
401        let snippet = self.generate_snippet(prompt, &selected, &specs).await?;
402        let raw = self
403            .codemode
404            .execute(CodeModeArgs {
405                code: snippet,
406                timeout: Some(20_000),
407            })
408            .await?;
409
410        Ok(Some(raw.value))
411    }
412
413    async fn render_tool_specs(&self) -> Result<String> {
414        {
415            let cache = self.tool_specs_cache.read().await;
416            if let Some(specs) = &*cache {
417                return Ok(specs.clone());
418            }
419        }
420
421        let tools = self
422            .codemode
423            .search_tools("", 200)
424            .await
425            .unwrap_or_default();
426        let mut rendered =
427            String::from("UTCP TOOL REFERENCE (use exact field names and required keys):\n");
428        for tool in tools {
429            rendered.push_str(&format!("TOOL: {} - {}\n", tool.name, tool.description));
430
431            rendered.push_str("INPUTS:\n");
432            match tool.inputs.properties.as_ref() {
433                Some(props) if !props.is_empty() => {
434                    for (key, schema) in props {
435                        rendered.push_str(&format!("  - {}: {}\n", key, schema_type_hint(schema)));
436                    }
437                }
438                _ => rendered.push_str("  - none\n"),
439            }
440
441            if let Some(required) = tool.inputs.required.as_ref() {
442                if !required.is_empty() {
443                    rendered.push_str("  REQUIRED:\n");
444                    for field in required {
445                        rendered.push_str(&format!("  - {}\n", field));
446                    }
447                }
448            }
449
450            rendered.push_str("OUTPUTS:\n");
451            match tool.outputs.properties.as_ref() {
452                Some(props) if !props.is_empty() => {
453                    for (key, schema) in props {
454                        rendered.push_str(&format!("  - {}: {}\n", key, schema_type_hint(schema)));
455                    }
456                }
457                _ => {
458                    if !tool.outputs.type_.is_empty() {
459                        rendered.push_str(&format!("  - type: {}\n", tool.outputs.type_));
460                    } else {
461                        rendered.push_str("  - (shape unspecified)\n");
462                    }
463                }
464            }
465
466            rendered.push('\n');
467        }
468
469        let mut cache = self.tool_specs_cache.write().await;
470        *cache = Some(rendered.clone());
471        Ok(rendered)
472    }
473
474    async fn decide_if_tools_needed(&self, prompt: &str, specs: &str) -> Result<bool> {
475        let request = format!(
476            "You can call tools described below. Respond with only 'yes' or 'no'.\n\nTOOLS:\n{}\n\nUSER:\n{}",
477            specs, prompt
478        );
479        let resp_val = self.model.complete(&request).await?;
480        Ok(resp_val
481            .as_str()
482            .unwrap_or_default()
483            .trim_start()
484            .to_ascii_lowercase()
485            .starts_with('y'))
486    }
487
488    async fn select_tools(&self, prompt: &str, specs: &str) -> Result<Vec<String>> {
489        let request = format!(
490            "Choose relevant tool names from the list. Respond with a comma-separated list of names only.\n\nTOOLS:\n{}\n\nUSER:\n{}",
491            specs, prompt
492        );
493        let resp_val = self.model.complete(&request).await?;
494        let resp = resp_val.as_str().unwrap_or_default();
495        let mut out = Vec::new();
496        for name in resp.split(',') {
497            let n = name.trim();
498            if !n.is_empty() {
499                out.push(n.to_string());
500            }
501        }
502        Ok(out)
503    }
504
505    async fn generate_snippet(
506        &self,
507        prompt: &str,
508        tools: &[String],
509        specs: &str,
510    ) -> Result<String> {
511        let tool_list = tools.join(", ");
512        let request = format!(
513            "Generate a Rhai snippet that chains UTCP tool calls to satisfy the user request.\n\
514Use ONLY these tools: {tool_list}.\n\
515Helpers available: call_tool(name, map), call_tool_stream(name, map) -> array of streamed chunks, search_tools(query, limit), sprintf(fmt, list).\n\
516Use Rhai map syntax #{{\"field\": value}} with exact input field names; include required fields and never invent new keys.\n\
517You may call multiple tools, store results in variables, and pass them into subsequent tools.\n\
518When using call_tool_stream, treat the returned array as the streamed items and chain it into later calls or the final output.\n\
519Return the final value as the last expression (map/list/scalar). No markdown or commentary, code only.\n\
520\nUSER:\n{prompt}\n\nTOOLS (use exact field names):\n{specs}"
521        );
522        let resp_val = self.model.complete(&request).await?;
523        Ok(resp_val.as_str().unwrap_or_default().trim().to_string())
524    }
525}
526
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Arguments accepted by the codemode tool.
pub struct CodeModeArgs {
    /// Snippet to evaluate, or a JSON document to return as-is.
    pub code: String,
    /// Timeout in milliseconds; `DEFAULT_TIMEOUT_MS` is used when absent.
    #[serde(default)]
    pub timeout: Option<u64>,
}
534
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Result payload returned from codemode execution.
pub struct CodeModeResult {
    /// Value produced by the snippet (or the parsed JSON payload).
    pub value: Value,
    /// Captured standard output; `CodeModeUtcp::execute` always leaves it empty.
    #[serde(default)]
    pub stdout: String,
    /// Captured standard error; `CodeModeUtcp::execute` always leaves it empty.
    #[serde(default)]
    pub stderr: String,
}
544
545fn schema_type_hint(value: &Value) -> String {
546    if let Some(t) = value.get("type").and_then(|v| v.as_str()) {
547        t.to_string()
548    } else if let Some(s) = value.as_str() {
549        s.to_string()
550    } else if value.is_array() {
551        "array".to_string()
552    } else if value.is_object() {
553        "object".to_string()
554    } else {
555        "any".to_string()
556    }
557}
558
559fn value_to_map(value: Value) -> Result<HashMap<String, Value>, Box<EvalAltResult>> {
560    match value {
561        Value::Object(obj) => Ok(obj.into_iter().collect()),
562        _ => Err(EvalAltResult::ErrorRuntime(
563            "call_tool expects object args".into(),
564            rhai::Position::NONE,
565        )
566        .into()),
567    }
568}
569
570/// Minimal string formatter exposed to Rhai snippets.
571/// Security: Limited to prevent DoS attacks.
572pub fn sprintf(fmt: &str, args: &[Dynamic]) -> String {
573    // Security: Limit format string size
574    const MAX_FMT_SIZE: usize = 10_000;
575    const MAX_ARGS: usize = 100;
576
577    if fmt.len() > MAX_FMT_SIZE {
578        return "[ERROR: Format string too long]".to_string();
579    }
580
581    if args.len() > MAX_ARGS {
582        return "[ERROR: Too many arguments]".to_string();
583    }
584
585    let mut out = fmt.to_string();
586    for rendered in args.iter().map(|v| v.to_string()) {
587        // Security: Limit argument string length
588        let safe_rendered = if rendered.len() > 1000 {
589            format!("{}...[truncated]", &rendered[..1000])
590        } else {
591            rendered
592        };
593        out = out.replacen("{}", &safe_rendered, 1);
594    }
595
596    // Security: Limit total output size
597    if out.len() > MAX_FMT_SIZE * 2 {
598        out.truncate(MAX_FMT_SIZE * 2);
599        out.push_str("...[truncated]");
600    }
601
602    out
603}
604
605fn block_on_any_runtime<F, T>(fut: F) -> Result<T, anyhow::Error>
606where
607    F: std::future::Future<Output = Result<T, anyhow::Error>>,
608    T: Send + 'static,
609{
610    match tokio::runtime::Handle::try_current() {
611        Ok(handle) => match handle.runtime_flavor() {
612            RuntimeFlavor::MultiThread => tokio::task::block_in_place(|| handle.block_on(fut)),
613            RuntimeFlavor::CurrentThread => {
614                let rt = Builder::new_current_thread().enable_all().build()?;
615                rt.block_on(fut)
616            }
617            _ => {
618                let rt = Builder::new_current_thread().enable_all().build()?;
619                rt.block_on(fut)
620            }
621        },
622        Err(_) => {
623            let rt = Builder::new_current_thread().enable_all().build()?;
624            rt.block_on(fut)
625        }
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::transports::stream::boxed_vec_stream;
    use tokio::sync::Mutex;

    /// Client double that records which operation was invoked and returns fixed data.
    #[derive(Clone)]
    struct MockClient {
        /// Log of invocations: the tool name, `search:<query>` or `stream:<tool>`.
        called: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait::async_trait]
    impl UtcpClientInterface for MockClient {
        async fn register_tool_provider(
            &self,
            _prov: Arc<dyn crate::providers::base::Provider>,
        ) -> Result<Vec<Tool>> {
            Ok(vec![])
        }

        async fn register_tool_provider_with_tools(
            &self,
            _prov: Arc<dyn crate::providers::base::Provider>,
            tools: Vec<Tool>,
        ) -> Result<Vec<Tool>> {
            Ok(tools)
        }

        async fn deregister_tool_provider(&self, _provider_name: &str) -> Result<()> {
            Ok(())
        }

        // Every tool call answers with the number 5.
        async fn call_tool(&self, tool_name: &str, _args: HashMap<String, Value>) -> Result<Value> {
            self.called.lock().await.push(tool_name.to_string());
            Ok(Value::Number(serde_json::Number::from(5)))
        }

        async fn search_tools(&self, query: &str, _limit: usize) -> Result<Vec<Tool>> {
            self.called.lock().await.push(format!("search:{query}"));
            Ok(vec![])
        }

        fn get_transports(&self) -> HashMap<String, Arc<dyn crate::transports::ClientTransport>> {
            HashMap::new()
        }

        // Every stream yields the single chunk "chunk".
        async fn call_tool_stream(
            &self,
            tool_name: &str,
            _args: HashMap<String, Value>,
        ) -> Result<Box<dyn crate::transports::stream::StreamResult>> {
            self.called.lock().await.push(format!("stream:{tool_name}"));
            Ok(boxed_vec_stream(vec![Value::String("chunk".into())]))
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn codemode_helpers_forward_to_client() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client.clone());

        codemode
            .call_tool("demo.tool", HashMap::new())
            .await
            .unwrap();
        codemode.search_tools("demo", 5).await.unwrap();
        let mut stream = codemode
            .call_tool_stream("demo.tool", HashMap::new())
            .await
            .unwrap();
        let _ = stream.next().await.unwrap();

        let calls = client.called.lock().await.clone();
        assert_eq!(calls, vec!["demo.tool", "search:demo", "stream:demo.tool"]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn execute_runs_rusty_snippet_and_call_tool() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        let code = r#"let x = 2 + 3; let y = call_tool("math.add", #{"a":1}); x + y"#;
        let args = CodeModeArgs {
            code: code.into(),
            timeout: Some(1000),
        };
        let res = codemode.execute(args).await.unwrap();
        assert_eq!(res.value, serde_json::json!(10));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn execute_collects_stream_results() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client.clone());

        let code = r#"let chunks = call_tool_stream("demo.tool", #{}); chunks"#;
        let args = CodeModeArgs {
            code: code.into(),
            timeout: Some(1_000),
        };
        let res = codemode.execute(args).await.unwrap();
        assert_eq!(res.value, serde_json::json!(["chunk"]));
        let calls = client.called.lock().await.clone();
        assert_eq!(calls, vec!["stream:demo.tool"]);
    }

    // Security Tests

    #[tokio::test(flavor = "multi_thread")]
    async fn security_rejects_oversized_code() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        // Create code larger than MAX_CODE_SIZE (100KB)
        let large_code = "x".repeat(150_000);
        let args = CodeModeArgs {
            code: large_code,
            timeout: Some(1000),
        };

        let result = codemode.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn security_rejects_dangerous_patterns() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        // Test each dangerous pattern
        let dangerous_codes = vec![
            "eval(some_code)",
            "import some_module",
            "fn evil() { }",
            "while true { }",
            "loop { break; }",
        ];

        for code in dangerous_codes {
            let args = CodeModeArgs {
                code: code.to_string(),
                timeout: Some(1000),
            };

            let result = codemode.execute(args).await;
            assert!(result.is_err(), "Should reject: {}", code);
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("prohibited pattern"));
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn security_enforces_timeout() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        // Code that takes a while (but not infinite due to operation limits)
        let code = r#"let sum = 0; for i in 0..100000 { sum = sum + i; } sum"#;
        let args = CodeModeArgs {
            code: code.to_string(),
            timeout: Some(1), // Very short timeout - 1ms
        };

        // The script may finish in time, time out, or hit the operation limit;
        // any other failure is unexpected.
        if let Err(e) = codemode.execute(args).await {
            let err = e.to_string();
            // `execute` words its timeout error as "timed out", so accept that
            // spelling alongside "timeout".
            assert!(
                err.contains("timed out") || err.contains("timeout") || err.contains("operations"),
                "Unexpected error: {}",
                err
            );
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn security_rejects_excessive_timeout() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        let args = CodeModeArgs {
            code: "42".to_string(),
            timeout: Some(60_000), // 60 seconds - over MAX_TIMEOUT_MS
        };

        let result = codemode.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn security_limits_output_size() {
        let client = Arc::new(MockClient {
            called: Arc::new(Mutex::new(Vec::new())),
        });
        let codemode = CodeModeUtcp::new(client);

        // Create code that would produce large output through array limits
        // This will hit the array size limit (10,000 items)
        let code = r#"let arr = []; for i in 0..15000 { arr.push(i); } arr"#;
        let args = CodeModeArgs {
            code: code.to_string(),
            timeout: Some(10_000),
        };

        let result = codemode.execute(args).await;
        // Should fail due to array size limit or operations limit
        assert!(result.is_err(), "Should fail due to limits");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("array") || err.contains("operations") || err.contains("eval error"),
            "Unexpected error: {}",
            err
        );
    }

    #[test]
    fn security_sprintf_limits_format_size() {
        let fmt = "x".repeat(20_000); // Over MAX_FMT_SIZE
        let result = sprintf(&fmt, &[]);
        assert_eq!(result, "[ERROR: Format string too long]");
    }

    #[test]
    fn security_sprintf_limits_args_count() {
        let args: Vec<Dynamic> = (0..200).map(|i| Dynamic::from(i)).collect();
        let result = sprintf("{}", &args);
        assert_eq!(result, "[ERROR: Too many arguments]");
    }

    #[test]
    fn security_sprintf_truncates_long_args() {
        let long_arg = Dynamic::from("x".repeat(2000));
        let result = sprintf("Value: {}", &[long_arg]);
        assert!(result.contains("...[truncated]"));
    }

    #[test]
    fn security_sprintf_limits_output_size() {
        // The format string (200 bytes), the argument count (100) and each
        // argument (exactly 1000 bytes) all sit within their own limits, so
        // only the cap on total output size can trigger here.
        let fmt = "{}".repeat(100);
        let args: Vec<Dynamic> = (0..100)
            .map(|_| Dynamic::from("x".repeat(1000)))
            .collect();
        let result = sprintf(&fmt, &args);
        assert!(result.ends_with("...[truncated]"));
        assert_eq!(result.len(), 20_000 + "...[truncated]".len());
    }
}