gent/runtime/tools/
user_tool.rs

1//! User-defined tool wrapper for the tool registry
2
3use super::Tool;
4use crate::interpreter::block_eval::evaluate_block;
5use crate::interpreter::{Environment, UserToolValue, Value};
6use crate::parser::ast::TypeName;
7use async_trait::async_trait;
8use serde_json::{json, Value as JsonValue};
9use std::sync::Arc;
10
11/// Wrapper that makes UserToolValue implement the Tool trait
12pub struct UserToolWrapper {
13    tool: UserToolValue,
14    env: Arc<Environment>,
15}
16
17impl UserToolWrapper {
18    /// Create a new UserToolWrapper
19    pub fn new(tool: UserToolValue, env: Arc<Environment>) -> Self {
20        Self { tool, env }
21    }
22}
23
24#[async_trait]
25impl Tool for UserToolWrapper {
26    fn name(&self) -> &str {
27        &self.tool.name
28    }
29
30    fn description(&self) -> &str {
31        "User-defined tool"
32    }
33
34    fn parameters_schema(&self) -> JsonValue {
35        let mut properties = serde_json::Map::new();
36        let mut required = Vec::new();
37
38        for param in &self.tool.params {
39            required.push(param.name.clone());
40
41            let type_str = match param.type_name {
42                TypeName::String => "string",
43                TypeName::Number => "number",
44                TypeName::Boolean => "boolean",
45                TypeName::Array => "array",
46                TypeName::Object => "object",
47                TypeName::Any => "string", // Default to string for Any
48            };
49
50            properties.insert(
51                param.name.clone(),
52                json!({
53                    "type": type_str,
54                    "description": format!("Parameter {}", param.name)
55                }),
56            );
57        }
58
59        json!({
60            "type": "object",
61            "properties": properties,
62            "required": required
63        })
64    }
65
66    async fn execute(&self, args: JsonValue) -> Result<String, String> {
67        // Clone all the data we need to own it in the async block
68        let tool_body = self.tool.body.clone();
69        let params = self.tool.params.clone();
70        let base_env = self.env.clone();
71
72        // Use spawn_blocking to run the non-Send future in a blocking context
73        tokio::task::spawn_blocking(move || {
74            // Create a new runtime for the blocking task
75            let rt = tokio::runtime::Runtime::new().unwrap();
76            rt.block_on(async move {
77                // Clone the environment to create an isolated execution context
78                let mut exec_env = (*base_env).clone();
79
80                // Bind parameters from JSON args to the environment
81                for param in &params {
82                    let arg_value = args
83                        .get(&param.name)
84                        .ok_or_else(|| format!("Missing required parameter: {}", param.name))?;
85
86                    // Convert JSON value to GENT Value
87                    let gent_value = json_to_value(arg_value);
88
89                    // Define the parameter in the environment
90                    exec_env.define(&param.name, gent_value);
91                }
92
93                // Create an empty tool registry for executing the tool body
94                // User tools cannot call other tools during their execution
95                let tools = super::ToolRegistry::new();
96
97                // Execute the tool body
98                let result = evaluate_block(&tool_body, &mut exec_env, &tools)
99                    .await
100                    .map_err(|e| format!("Tool execution failed: {}", e))?;
101
102                // Convert the result to a string
103                Ok::<String, String>(result.to_string())
104            })
105        })
106        .await
107        .map_err(|e| format!("Task panicked: {}", e))?
108    }
109}
110
111/// Convert a JSON value to a GENT Value
112fn json_to_value(json: &JsonValue) -> Value {
113    match json {
114        JsonValue::Null => Value::Null,
115        JsonValue::Bool(b) => Value::Boolean(*b),
116        JsonValue::Number(n) => {
117            if let Some(f) = n.as_f64() {
118                Value::Number(f)
119            } else {
120                Value::Null
121            }
122        }
123        JsonValue::String(s) => Value::String(s.clone()),
124        JsonValue::Array(arr) => {
125            let items = arr.iter().map(json_to_value).collect();
126            Value::Array(items)
127        }
128        JsonValue::Object(obj) => {
129            let mut map = std::collections::HashMap::new();
130            for (k, v) in obj {
131                map.insert(k.clone(), json_to_value(v));
132            }
133            Value::Object(map)
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_to_value_primitives() {
        // Table of (JSON input, expected GENT value) pairs.
        let cases = [
            (json!(null), Value::Null),
            (json!(true), Value::Boolean(true)),
            (json!(42), Value::Number(42.0)),
            (json!("hello"), Value::String("hello".to_string())),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(&json_to_value(input), expected);
        }
    }

    #[test]
    fn test_json_to_value_array() {
        match json_to_value(&json!([1, 2, 3])) {
            Value::Array(items) => {
                assert_eq!(items.len(), 3);
                let expected = [1.0, 2.0, 3.0];
                for (item, want) in items.iter().zip(expected.iter()) {
                    assert_eq!(*item, Value::Number(*want));
                }
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_json_to_value_object() {
        let converted = json_to_value(&json!({"key": "value", "num": 42}));
        match converted {
            Value::Object(map) => {
                assert_eq!(map.len(), 2);
                assert_eq!(map.get("key"), Some(&Value::String("value".to_string())));
                assert_eq!(map.get("num"), Some(&Value::Number(42.0)));
            }
            _ => panic!("Expected Object value"),
        }
    }
}