gent/runtime/tools/
user_tool.rs1use super::Tool;
4use crate::interpreter::block_eval::evaluate_block;
5use crate::interpreter::{Environment, UserToolValue, Value};
6use crate::parser::ast::TypeName;
7use async_trait::async_trait;
8use serde_json::{json, Value as JsonValue};
9use std::sync::Arc;
10
11pub struct UserToolWrapper {
13 tool: UserToolValue,
14 env: Arc<Environment>,
15}
16
17impl UserToolWrapper {
18 pub fn new(tool: UserToolValue, env: Arc<Environment>) -> Self {
20 Self { tool, env }
21 }
22}
23
24#[async_trait]
25impl Tool for UserToolWrapper {
26 fn name(&self) -> &str {
27 &self.tool.name
28 }
29
30 fn description(&self) -> &str {
31 "User-defined tool"
32 }
33
34 fn parameters_schema(&self) -> JsonValue {
35 let mut properties = serde_json::Map::new();
36 let mut required = Vec::new();
37
38 for param in &self.tool.params {
39 required.push(param.name.clone());
40
41 let type_str = match param.type_name {
42 TypeName::String => "string",
43 TypeName::Number => "number",
44 TypeName::Boolean => "boolean",
45 TypeName::Array => "array",
46 TypeName::Object => "object",
47 TypeName::Any => "string", };
49
50 properties.insert(
51 param.name.clone(),
52 json!({
53 "type": type_str,
54 "description": format!("Parameter {}", param.name)
55 }),
56 );
57 }
58
59 json!({
60 "type": "object",
61 "properties": properties,
62 "required": required
63 })
64 }
65
66 async fn execute(&self, args: JsonValue) -> Result<String, String> {
67 let tool_body = self.tool.body.clone();
69 let params = self.tool.params.clone();
70 let base_env = self.env.clone();
71
72 tokio::task::spawn_blocking(move || {
74 let rt = tokio::runtime::Runtime::new().unwrap();
76 rt.block_on(async move {
77 let mut exec_env = (*base_env).clone();
79
80 for param in ¶ms {
82 let arg_value = args
83 .get(¶m.name)
84 .ok_or_else(|| format!("Missing required parameter: {}", param.name))?;
85
86 let gent_value = json_to_value(arg_value);
88
89 exec_env.define(¶m.name, gent_value);
91 }
92
93 let tools = super::ToolRegistry::new();
96
97 let result = evaluate_block(&tool_body, &mut exec_env, &tools)
99 .await
100 .map_err(|e| format!("Tool execution failed: {}", e))?;
101
102 Ok::<String, String>(result.to_string())
104 })
105 })
106 .await
107 .map_err(|e| format!("Task panicked: {}", e))?
108 }
109}
110
111fn json_to_value(json: &JsonValue) -> Value {
113 match json {
114 JsonValue::Null => Value::Null,
115 JsonValue::Bool(b) => Value::Boolean(*b),
116 JsonValue::Number(n) => {
117 if let Some(f) = n.as_f64() {
118 Value::Number(f)
119 } else {
120 Value::Null
121 }
122 }
123 JsonValue::String(s) => Value::String(s.clone()),
124 JsonValue::Array(arr) => {
125 let items = arr.iter().map(json_to_value).collect();
126 Value::Array(items)
127 }
128 JsonValue::Object(obj) => {
129 let mut map = std::collections::HashMap::new();
130 for (k, v) in obj {
131 map.insert(k.clone(), json_to_value(v));
132 }
133 Value::Object(map)
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON scalars convert to the matching primitive `Value` variants.
    #[test]
    fn test_json_to_value_primitives() {
        let cases = vec![
            (json!(null), Value::Null),
            (json!(true), Value::Boolean(true)),
            (json!(42), Value::Number(42.0)),
            (json!("hello"), Value::String("hello".to_string())),
        ];
        for (input, expected) in cases {
            assert_eq!(json_to_value(&input), expected);
        }
    }

    /// A JSON array converts element-wise, preserving order.
    #[test]
    fn test_json_to_value_array() {
        match json_to_value(&json!([1, 2, 3])) {
            Value::Array(items) => {
                assert_eq!(items.len(), 3);
                let expected = [1.0, 2.0, 3.0];
                for (item, want) in items.iter().zip(expected.iter()) {
                    assert_eq!(*item, Value::Number(*want));
                }
            }
            _ => panic!("Expected Array value"),
        }
    }

    /// A JSON object converts to a map with every field converted.
    #[test]
    fn test_json_to_value_object() {
        let map = match json_to_value(&json!({"key": "value", "num": 42})) {
            Value::Object(map) => map,
            _ => panic!("Expected Object value"),
        };
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("key"), Some(&Value::String("value".to_string())));
        assert_eq!(map.get("num"), Some(&Value::Number(42.0)));
    }
}