// rs_utcp/plugins/codemode/mod.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{anyhow, Result};
use rhai::{Dynamic, Engine, EvalAltResult, Map, Scope};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::runtime::{Builder, RuntimeFlavor};
use tokio::sync::RwLock;

use crate::tools::{Tool, ToolInputOutputSchema};
use crate::UtcpClientInterface;
13
14pub struct CodeModeUtcp {
15    client: Arc<dyn UtcpClientInterface>,
16}
17
18impl CodeModeUtcp {
19    pub fn new(client: Arc<dyn UtcpClientInterface>) -> Self {
20        Self { client }
21    }
22
23    pub async fn execute(&self, args: CodeModeArgs) -> Result<CodeModeResult> {
24        // If it's JSON already, return it directly.
25        if let Ok(json) = serde_json::from_str::<Value>(&args.code) {
26            return Ok(CodeModeResult {
27                value: json,
28                stdout: String::new(),
29                stderr: String::new(),
30            });
31        }
32
33        let value = self.eval_rusty_snippet(&args.code, args.timeout).await?;
34        Ok(CodeModeResult {
35            value,
36            stdout: String::new(),
37            stderr: String::new(),
38        })
39    }
40
41    fn tool_schema(&self) -> Tool {
42        Tool {
43            name: "codemode.run_code".to_string(),
44            description: "Execute a Rust-like snippet with access to UTCP tools.".to_string(),
45            inputs: ToolInputOutputSchema {
46                type_: "object".to_string(),
47                properties: Some(HashMap::from([
48                    (
49                        "code".to_string(),
50                        serde_json::json!({"type": "string", "description": "Rust-like snippet"}),
51                    ),
52                    (
53                        "timeout".to_string(),
54                        serde_json::json!({"type": "integer", "description": "Timeout ms"}),
55                    ),
56                ])),
57                required: Some(vec!["code".to_string()]),
58                description: None,
59                title: Some("CodeModeArgs".to_string()),
60                items: None,
61                enum_: None,
62                minimum: None,
63                maximum: None,
64                format: None,
65            },
66            outputs: ToolInputOutputSchema {
67                type_: "object".to_string(),
68                properties: Some(HashMap::from([
69                    ("value".to_string(), serde_json::json!({"type": "string"})),
70                    ("stdout".to_string(), serde_json::json!({"type": "string"})),
71                    ("stderr".to_string(), serde_json::json!({"type": "string"})),
72                ])),
73                required: None,
74                description: None,
75                title: Some("CodeModeResult".to_string()),
76                items: None,
77                enum_: None,
78                minimum: None,
79                maximum: None,
80                format: None,
81            },
82            tags: vec!["codemode".to_string(), "utcp".to_string()],
83            average_response_size: None,
84            provider: None,
85        }
86    }
87
88    fn build_engine(&self) -> Engine {
89        let mut engine = Engine::new();
90        engine.register_fn("sprintf", sprintf);
91
92        let client = self.client.clone();
93        engine.register_fn(
94            "call_tool",
95            move |name: &str, map: Map| -> Result<Dynamic, Box<EvalAltResult>> {
96                let args_val = serde_json::to_value(map).map_err(|e| {
97                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
98                })?;
99                let args = value_to_map(args_val)?;
100
101                let res = block_on_any_runtime(async { client.call_tool(name, args).await })
102                    .map_err(|e| {
103                        EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
104                    })?;
105
106                Ok(rhai::serde::to_dynamic(res).map_err(|e| {
107                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
108                })?)
109            },
110        );
111
112        let client = self.client.clone();
113        engine.register_fn(
114            "call_tool_stream",
115            move |name: &str, map: Map| -> Result<Dynamic, Box<EvalAltResult>> {
116                let args_val = serde_json::to_value(map).map_err(|e| {
117                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
118                })?;
119                let args = value_to_map(args_val)?;
120
121                let mut stream =
122                    block_on_any_runtime(async { client.call_tool_stream(name, args).await })
123                        .map_err(|e| {
124                            EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
125                        })?;
126
127                let mut items = Vec::new();
128                loop {
129                    let next =
130                        block_on_any_runtime(async { stream.next().await }).map_err(|e| {
131                            EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
132                        })?;
133                    match next {
134                        Some(value) => items.push(value),
135                        None => break,
136                    }
137                }
138
139                if let Err(e) = block_on_any_runtime(async { stream.close().await }) {
140                    return Err(EvalAltResult::ErrorRuntime(
141                        e.to_string().into(),
142                        rhai::Position::NONE,
143                    )
144                    .into());
145                }
146
147                Ok(rhai::serde::to_dynamic(items).map_err(|e| {
148                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
149                })?)
150            },
151        );
152
153        let client = self.client.clone();
154        engine.register_fn(
155            "search_tools",
156            move |query: &str, limit: i64| -> Result<Dynamic, Box<EvalAltResult>> {
157                let res = block_on_any_runtime(async {
158                    client.search_tools(query, limit as usize).await
159                })
160                .map_err(|e| {
161                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
162                })?;
163                Ok(rhai::serde::to_dynamic(res).map_err(|e| {
164                    EvalAltResult::ErrorRuntime(e.to_string().into(), rhai::Position::NONE)
165                })?)
166            },
167        );
168
169        engine
170    }
171
172    async fn eval_rusty_snippet(&self, code: &str, _timeout_ms: Option<u64>) -> Result<Value> {
173        let wrapped = format!("let __out = {{ {} }};\n__out", code);
174        let engine = self.build_engine();
175        let mut scope = Scope::new();
176
177        let dyn_result = engine.eval_with_scope::<Dynamic>(&mut scope, &wrapped);
178        let dyn_value = dyn_result.map_err(|e| anyhow!("codemode eval error: {}", e))?;
179        let value: Value = rhai::serde::from_dynamic(&dyn_value)
180            .map_err(|e| anyhow!("Failed to convert result: {}", e))?;
181        Ok(value)
182    }
183
184    pub fn tool(&self) -> Tool {
185        self.tool_schema()
186    }
187
188    /// Convenience helpers mirroring go-utcp codemode helper exports.
189    pub async fn call_tool(&self, name: &str, args: HashMap<String, Value>) -> Result<Value> {
190        self.client.call_tool(name, args).await
191    }
192
193    pub async fn call_tool_stream(
194        &self,
195        name: &str,
196        args: HashMap<String, Value>,
197    ) -> Result<Box<dyn crate::transports::stream::StreamResult>> {
198        self.client.call_tool_stream(name, args).await
199    }
200
201    pub async fn search_tools(&self, query: &str, limit: usize) -> Result<Vec<Tool>> {
202        self.client.search_tools(query, limit).await
203    }
204}
205
206#[async_trait::async_trait]
207pub trait LlmModel: Send + Sync {
208    async fn complete(&self, prompt: &str) -> Result<Value>;
209}
210
211/// High-level orchestrator that mirrors go-utcp's CodeMode flow:
212/// 1) Decide if tools are needed
213/// 2) Select tools by name
214/// 3) Ask the model to emit a Rhai snippet using call_tool helpers
215/// 4) Execute the snippet via CodeMode
216pub struct CodemodeOrchestrator {
217    codemode: Arc<CodeModeUtcp>,
218    model: Arc<dyn LlmModel>,
219    tool_specs_cache: RwLock<Option<String>>,
220}
221
222impl CodemodeOrchestrator {
223    pub fn new(codemode: Arc<CodeModeUtcp>, model: Arc<dyn LlmModel>) -> Self {
224        Self {
225            codemode,
226            model,
227            tool_specs_cache: RwLock::new(None),
228        }
229    }
230
231    /// Run the full orchestration flow. Returns Ok(None) if the model says no tools are needed
232    /// or fails to pick any tools. Otherwise returns the codemode execution result.
233    pub async fn call_prompt(&self, prompt: &str) -> Result<Option<Value>> {
234        let specs = self.render_tool_specs().await?;
235
236        if !self.decide_if_tools_needed(prompt, &specs).await? {
237            return Ok(None);
238        }
239
240        let selected = self.select_tools(prompt, &specs).await?;
241        if selected.is_empty() {
242            return Ok(None);
243        }
244
245        let snippet = self.generate_snippet(prompt, &selected, &specs).await?;
246        let raw = self
247            .codemode
248            .execute(CodeModeArgs {
249                code: snippet,
250                timeout: Some(20_000),
251            })
252            .await?;
253
254        Ok(Some(raw.value))
255    }
256
257    async fn render_tool_specs(&self) -> Result<String> {
258        {
259            let cache = self.tool_specs_cache.read().await;
260            if let Some(specs) = &*cache {
261                return Ok(specs.clone());
262            }
263        }
264
265        let tools = self
266            .codemode
267            .search_tools("", 200)
268            .await
269            .unwrap_or_default();
270        let mut rendered =
271            String::from("UTCP TOOL REFERENCE (use exact field names and required keys):\n");
272        for tool in tools {
273            rendered.push_str(&format!("TOOL: {} - {}\n", tool.name, tool.description));
274
275            rendered.push_str("INPUTS:\n");
276            match tool.inputs.properties.as_ref() {
277                Some(props) if !props.is_empty() => {
278                    for (key, schema) in props {
279                        rendered.push_str(&format!("  - {}: {}\n", key, schema_type_hint(schema)));
280                    }
281                }
282                _ => rendered.push_str("  - none\n"),
283            }
284
285            if let Some(required) = tool.inputs.required.as_ref() {
286                if !required.is_empty() {
287                    rendered.push_str("  REQUIRED:\n");
288                    for field in required {
289                        rendered.push_str(&format!("  - {}\n", field));
290                    }
291                }
292            }
293
294            rendered.push_str("OUTPUTS:\n");
295            match tool.outputs.properties.as_ref() {
296                Some(props) if !props.is_empty() => {
297                    for (key, schema) in props {
298                        rendered.push_str(&format!("  - {}: {}\n", key, schema_type_hint(schema)));
299                    }
300                }
301                _ => {
302                    if !tool.outputs.type_.is_empty() {
303                        rendered.push_str(&format!("  - type: {}\n", tool.outputs.type_));
304                    } else {
305                        rendered.push_str("  - (shape unspecified)\n");
306                    }
307                }
308            }
309
310            rendered.push('\n');
311        }
312
313        let mut cache = self.tool_specs_cache.write().await;
314        *cache = Some(rendered.clone());
315        Ok(rendered)
316    }
317
318    async fn decide_if_tools_needed(&self, prompt: &str, specs: &str) -> Result<bool> {
319        let request = format!(
320            "You can call tools described below. Respond with only 'yes' or 'no'.\n\nTOOLS:\n{}\n\nUSER:\n{}",
321            specs, prompt
322        );
323        let resp_val = self.model.complete(&request).await?;
324        Ok(resp_val
325            .as_str()
326            .unwrap_or_default()
327            .trim_start()
328            .to_ascii_lowercase()
329            .starts_with('y'))
330    }
331
332    async fn select_tools(&self, prompt: &str, specs: &str) -> Result<Vec<String>> {
333        let request = format!(
334            "Choose relevant tool names from the list. Respond with a comma-separated list of names only.\n\nTOOLS:\n{}\n\nUSER:\n{}",
335            specs, prompt
336        );
337        let resp_val = self.model.complete(&request).await?;
338        let resp = resp_val.as_str().unwrap_or_default();
339        let mut out = Vec::new();
340        for name in resp.split(',') {
341            let n = name.trim();
342            if !n.is_empty() {
343                out.push(n.to_string());
344            }
345        }
346        Ok(out)
347    }
348
349    async fn generate_snippet(
350        &self,
351        prompt: &str,
352        tools: &[String],
353        specs: &str,
354    ) -> Result<String> {
355        let tool_list = tools.join(", ");
356        let request = format!(
357            "Generate a Rhai snippet that chains UTCP tool calls to satisfy the user request.\n\
358Use ONLY these tools: {tool_list}.\n\
359Helpers available: call_tool(name, map), call_tool_stream(name, map) -> array of streamed chunks, search_tools(query, limit), sprintf(fmt, list).\n\
360Use Rhai map syntax #{{\"field\": value}} with exact input field names; include required fields and never invent new keys.\n\
361You may call multiple tools, store results in variables, and pass them into subsequent tools.\n\
362When using call_tool_stream, treat the returned array as the streamed items and chain it into later calls or the final output.\n\
363Return the final value as the last expression (map/list/scalar). No markdown or commentary, code only.\n\
364\nUSER:\n{prompt}\n\nTOOLS (use exact field names):\n{specs}"
365        );
366        let resp_val = self.model.complete(&request).await?;
367        Ok(resp_val.as_str().unwrap_or_default().trim().to_string())
368    }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
372pub struct CodeModeArgs {
373    pub code: String,
374    #[serde(default)]
375    pub timeout: Option<u64>,
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
379pub struct CodeModeResult {
380    pub value: Value,
381    #[serde(default)]
382    pub stdout: String,
383    #[serde(default)]
384    pub stderr: String,
385}
386
387fn schema_type_hint(value: &Value) -> String {
388    if let Some(t) = value.get("type").and_then(|v| v.as_str()) {
389        t.to_string()
390    } else if let Some(s) = value.as_str() {
391        s.to_string()
392    } else if value.is_array() {
393        "array".to_string()
394    } else if value.is_object() {
395        "object".to_string()
396    } else {
397        "any".to_string()
398    }
399}
400
401fn value_to_map(value: Value) -> Result<HashMap<String, Value>, Box<EvalAltResult>> {
402    match value {
403        Value::Object(obj) => Ok(obj.into_iter().collect()),
404        _ => Err(EvalAltResult::ErrorRuntime(
405            "call_tool expects object args".into(),
406            rhai::Position::NONE,
407        )
408        .into()),
409    }
410}
411
412pub fn sprintf(fmt: &str, args: &[Dynamic]) -> String {
413    let mut out = fmt.to_string();
414    for rendered in args.iter().map(|v| v.to_string()) {
415        out = out.replacen("{}", &rendered, 1);
416    }
417    out
418}
419
420fn block_on_any_runtime<F, T>(fut: F) -> Result<T, anyhow::Error>
421where
422    F: std::future::Future<Output = Result<T, anyhow::Error>>,
423    T: Send + 'static,
424{
425    match tokio::runtime::Handle::try_current() {
426        Ok(handle) => match handle.runtime_flavor() {
427            RuntimeFlavor::MultiThread => tokio::task::block_in_place(|| handle.block_on(fut)),
428            RuntimeFlavor::CurrentThread => {
429                let rt = Builder::new_current_thread().enable_all().build()?;
430                rt.block_on(fut)
431            }
432            _ => {
433                let rt = Builder::new_current_thread().enable_all().build()?;
434                rt.block_on(fut)
435            }
436        },
437        Err(_) => {
438            let rt = Builder::new_current_thread().enable_all().build()?;
439            rt.block_on(fut)
440        }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::transports::stream::boxed_vec_stream;
    use tokio::sync::Mutex;

    /// Test double that records every client call it receives, in order.
    #[derive(Clone)]
    struct MockClient {
        called: Arc<Mutex<Vec<String>>>,
    }

    impl MockClient {
        /// Creates a fresh mock with an empty call log.
        fn shared() -> Arc<Self> {
            Arc::new(Self {
                called: Arc::new(Mutex::new(Vec::new())),
            })
        }

        /// Appends one entry to the call log.
        async fn record(&self, entry: String) {
            self.called.lock().await.push(entry);
        }

        /// Snapshot of the call log.
        async fn calls(&self) -> Vec<String> {
            self.called.lock().await.clone()
        }
    }

    #[async_trait::async_trait]
    impl UtcpClientInterface for MockClient {
        async fn register_tool_provider(
            &self,
            _prov: Arc<dyn crate::providers::base::Provider>,
        ) -> Result<Vec<Tool>> {
            Ok(Vec::new())
        }

        async fn register_tool_provider_with_tools(
            &self,
            _prov: Arc<dyn crate::providers::base::Provider>,
            tools: Vec<Tool>,
        ) -> Result<Vec<Tool>> {
            Ok(tools)
        }

        async fn deregister_tool_provider(&self, _provider_name: &str) -> Result<()> {
            Ok(())
        }

        async fn call_tool(&self, tool_name: &str, _args: HashMap<String, Value>) -> Result<Value> {
            self.record(tool_name.to_owned()).await;
            Ok(serde_json::json!(5))
        }

        async fn search_tools(&self, query: &str, _limit: usize) -> Result<Vec<Tool>> {
            self.record(format!("search:{query}")).await;
            Ok(Vec::new())
        }

        fn get_transports(&self) -> HashMap<String, Arc<dyn crate::transports::ClientTransport>> {
            HashMap::new()
        }

        async fn call_tool_stream(
            &self,
            tool_name: &str,
            _args: HashMap<String, Value>,
        ) -> Result<Box<dyn crate::transports::stream::StreamResult>> {
            self.record(format!("stream:{tool_name}")).await;
            Ok(boxed_vec_stream(vec![serde_json::json!("chunk")]))
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn codemode_helpers_forward_to_client() {
        let client = MockClient::shared();
        let codemode = CodeModeUtcp::new(client.clone());

        codemode.call_tool("demo.tool", HashMap::new()).await.unwrap();
        codemode.search_tools("demo", 5).await.unwrap();
        let mut stream = codemode
            .call_tool_stream("demo.tool", HashMap::new())
            .await
            .unwrap();
        let _ = stream.next().await.unwrap();

        assert_eq!(
            client.calls().await,
            vec!["demo.tool", "search:demo", "stream:demo.tool"]
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn execute_runs_rusty_snippet_and_call_tool() {
        let codemode = CodeModeUtcp::new(MockClient::shared());

        let outcome = codemode
            .execute(CodeModeArgs {
                code: r#"let x = 2 + 3; let y = call_tool("math.add", #{"a":1}); x + y"#.into(),
                timeout: Some(1000),
            })
            .await
            .unwrap();

        assert_eq!(outcome.value, serde_json::json!(10));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn execute_collects_stream_results() {
        let client = MockClient::shared();
        let codemode = CodeModeUtcp::new(client.clone());

        let outcome = codemode
            .execute(CodeModeArgs {
                code: r#"let chunks = call_tool_stream("demo.tool", #{}); chunks"#.into(),
                timeout: Some(1_000),
            })
            .await
            .unwrap();

        assert_eq!(outcome.value, serde_json::json!(["chunk"]));
        assert_eq!(client.calls().await, vec!["stream:demo.tool"]);
    }
}