1#![allow(dead_code)]
8
9use crate::agency::error::AgencyResult;
10use crate::agency::models::ToolResult;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ToolParameter {
22 pub name: String,
24 #[serde(rename = "type")]
26 pub param_type: String,
27 pub description: String,
29 #[serde(default)]
31 pub required: bool,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
34 pub enum_values: Option<Vec<String>>,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
37 pub default: Option<Value>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct Tool {
43 pub name: String,
45 pub description: String,
47 #[serde(default)]
49 pub parameters: Vec<ToolParameter>,
50 #[serde(default)]
52 pub category: ToolCategory,
53 #[serde(default)]
55 pub requires_confirmation: bool,
56 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
58 pub metadata: HashMap<String, Value>,
59}
60
61impl Tool {
62 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
64 Self {
65 name: name.into(),
66 description: description.into(),
67 parameters: Vec::new(),
68 category: ToolCategory::Custom,
69 requires_confirmation: false,
70 metadata: HashMap::new(),
71 }
72 }
73
74 pub fn to_function_definition(&self) -> Value {
76 let mut properties = serde_json::Map::new();
77 let mut required = Vec::new();
78
79 for param in &self.parameters {
80 let mut prop = serde_json::Map::new();
81 prop.insert("type".to_string(), Value::String(param.param_type.clone()));
82 prop.insert(
83 "description".to_string(),
84 Value::String(param.description.clone()),
85 );
86
87 if let Some(enum_vals) = ¶m.enum_values {
88 prop.insert(
89 "enum".to_string(),
90 Value::Array(enum_vals.iter().map(|v| Value::String(v.clone())).collect()),
91 );
92 }
93
94 properties.insert(param.name.clone(), Value::Object(prop));
95
96 if param.required {
97 required.push(Value::String(param.name.clone()));
98 }
99 }
100
101 serde_json::json!({
102 "type": "function",
103 "function": {
104 "name": self.name,
105 "description": self.description,
106 "parameters": {
107 "type": "object",
108 "properties": properties,
109 "required": required
110 }
111 }
112 })
113 }
114}
115
116#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
118#[serde(rename_all = "snake_case")]
119pub enum ToolCategory {
120 #[default]
121 Custom,
122 Search,
123 Code,
124 File,
125 Data,
126 Communication,
127 System,
128 Builtin,
129}
130
131pub struct ToolBuilder {
133 tool: Tool,
134}
135
136impl ToolBuilder {
137 pub fn new(name: impl Into<String>) -> Self {
139 Self {
140 tool: Tool {
141 name: name.into(),
142 description: String::new(),
143 parameters: Vec::new(),
144 category: ToolCategory::Custom,
145 requires_confirmation: false,
146 metadata: HashMap::new(),
147 },
148 }
149 }
150
151 pub fn description(mut self, desc: impl Into<String>) -> Self {
153 self.tool.description = desc.into();
154 self
155 }
156
157 pub fn parameter(
159 mut self,
160 name: impl Into<String>,
161 param_type: impl Into<String>,
162 description: impl Into<String>,
163 required: bool,
164 ) -> Self {
165 self.tool.parameters.push(ToolParameter {
166 name: name.into(),
167 param_type: param_type.into(),
168 description: description.into(),
169 required,
170 enum_values: None,
171 default: None,
172 });
173 self
174 }
175
176 pub fn string_param(
178 self,
179 name: impl Into<String>,
180 description: impl Into<String>,
181 required: bool,
182 ) -> Self {
183 self.parameter(name, "string", description, required)
184 }
185
186 pub fn number_param(
188 self,
189 name: impl Into<String>,
190 description: impl Into<String>,
191 required: bool,
192 ) -> Self {
193 self.parameter(name, "number", description, required)
194 }
195
196 pub fn bool_param(
198 self,
199 name: impl Into<String>,
200 description: impl Into<String>,
201 required: bool,
202 ) -> Self {
203 self.parameter(name, "boolean", description, required)
204 }
205
206 pub fn category(mut self, category: ToolCategory) -> Self {
208 self.tool.category = category;
209 self
210 }
211
212 pub fn requires_confirmation(mut self, requires: bool) -> Self {
214 self.tool.requires_confirmation = requires;
215 self
216 }
217
218 pub fn build(self) -> Tool {
220 self.tool
221 }
222}
223
224#[async_trait]
226pub trait ToolExecutor: Send + Sync {
227 fn definition(&self) -> &Tool;
229
230 async fn execute(&self, args: Value) -> AgencyResult<ToolResult>;
232}
233
234pub type ToolFn = Box<
236 dyn Fn(Value) -> Pin<Box<dyn Future<Output = AgencyResult<ToolResult>> + Send>> + Send + Sync,
237>;
238
239#[derive(Default)]
241pub struct ToolRegistry {
242 tools: HashMap<String, Arc<Tool>>,
243 executors: HashMap<String, Arc<dyn ToolExecutor>>,
244}
245
246impl ToolRegistry {
247 pub fn new() -> Self {
249 Self::default()
250 }
251
252 pub fn with_builtins() -> Self {
254 let mut registry = Self::new();
255 registry.register_builtins();
256 registry
257 }
258
259 pub fn register(&mut self, tool: Tool) {
261 self.tools.insert(tool.name.clone(), Arc::new(tool));
262 }
263
264 pub fn register_with_executor(&mut self, executor: impl ToolExecutor + 'static) {
266 let tool = executor.definition().clone();
267 let name = tool.name.clone();
268 self.tools.insert(name.clone(), Arc::new(tool));
269 self.executors.insert(name, Arc::new(executor));
270 }
271
272 pub fn get(&self, name: &str) -> Option<&Arc<Tool>> {
274 self.tools.get(name)
275 }
276
277 pub fn get_executor(&self, name: &str) -> Option<&Arc<dyn ToolExecutor>> {
279 self.executors.get(name)
280 }
281
282 pub fn list(&self) -> Vec<&Tool> {
284 self.tools.values().map(|t| t.as_ref()).collect()
285 }
286
287 pub fn to_definitions(&self) -> Vec<Value> {
289 self.tools
290 .values()
291 .map(|t| t.to_function_definition())
292 .collect()
293 }
294
295 fn register_builtins(&mut self) {
297 for tool in BuiltinTools::all() {
298 self.register(tool);
299 }
300 }
301}
302
/// Zero-sized namespace for the constructors of the built-in tool definitions.
pub struct BuiltinTools;
305
306impl BuiltinTools {
307 pub fn all() -> Vec<Tool> {
309 vec![
310 Self::web_search(),
311 Self::code_execution(),
312 Self::read_file(),
313 Self::write_file(),
314 Self::list_directory(),
315 Self::http_request(),
316 Self::calculator(),
317 ]
318 }
319
320 pub fn web_search() -> Tool {
322 ToolBuilder::new("web_search")
323 .description("Search the web for information. Returns relevant snippets and URLs.")
324 .string_param("query", "The search query", true)
325 .number_param(
326 "max_results",
327 "Maximum number of results (default: 5)",
328 false,
329 )
330 .category(ToolCategory::Search)
331 .build()
332 }
333
334 pub fn code_execution() -> Tool {
336 ToolBuilder::new("code_execution")
337 .description("Execute code in a sandboxed environment. Supports Python, JavaScript, and shell scripts.")
338 .string_param("code", "The code to execute", true)
339 .string_param("language", "Programming language (python, javascript, shell)", true)
340 .number_param("timeout", "Execution timeout in seconds (default: 30)", false)
341 .category(ToolCategory::Code)
342 .requires_confirmation(true)
343 .build()
344 }
345
346 pub fn read_file() -> Tool {
348 ToolBuilder::new("read_file")
349 .description("Read the contents of a file from the filesystem.")
350 .string_param("path", "Path to the file to read", true)
351 .string_param("encoding", "File encoding (default: utf-8)", false)
352 .category(ToolCategory::File)
353 .build()
354 }
355
356 pub fn write_file() -> Tool {
358 ToolBuilder::new("write_file")
359 .description("Write content to a file. Creates the file if it doesn't exist.")
360 .string_param("path", "Path to the file to write", true)
361 .string_param("content", "Content to write to the file", true)
362 .bool_param("append", "Append to file instead of overwriting", false)
363 .category(ToolCategory::File)
364 .requires_confirmation(true)
365 .build()
366 }
367
368 pub fn list_directory() -> Tool {
370 ToolBuilder::new("list_directory")
371 .description("List the contents of a directory.")
372 .string_param("path", "Path to the directory", true)
373 .bool_param("recursive", "Include subdirectories", false)
374 .bool_param("include_hidden", "Include hidden files", false)
375 .category(ToolCategory::File)
376 .build()
377 }
378
379 pub fn http_request() -> Tool {
381 ToolBuilder::new("http_request")
382 .description("Make an HTTP request to a URL.")
383 .string_param("url", "The URL to request", true)
384 .string_param("method", "HTTP method (GET, POST, PUT, DELETE)", false)
385 .string_param("body", "Request body (for POST/PUT)", false)
386 .string_param("headers", "JSON object of headers", false)
387 .category(ToolCategory::Communication)
388 .build()
389 }
390
391 pub fn calculator() -> Tool {
393 ToolBuilder::new("calculator")
394 .description("Evaluate mathematical expressions. Supports basic arithmetic, functions, and constants.")
395 .string_param("expression", "The mathematical expression to evaluate", true)
396 .category(ToolCategory::Data)
397 .build()
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder records name, description and parameter required-flags in order.
    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test_tool")
            .description("A test tool")
            .string_param("input", "Input parameter", true)
            .number_param("count", "Count parameter", false)
            .category(ToolCategory::Custom)
            .build();

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
        assert_eq!(tool.parameters.len(), 2);
        assert!(tool.parameters[0].required);
        assert!(!tool.parameters[1].required);
        assert_eq!(tool.parameters[0].param_type, "string");
        assert_eq!(tool.parameters[1].param_type, "number");
    }

    /// The function definition exposes name, parameter types and the `required` list.
    #[test]
    fn test_function_definition() {
        let def = BuiltinTools::web_search().to_function_definition();

        assert_eq!(def["type"], "function");
        assert_eq!(def["function"]["name"], "web_search");

        let params = &def["function"]["parameters"];
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["query"]["type"], "string");
        assert_eq!(params["properties"]["max_results"]["type"], "number");
        assert_eq!(params["required"], serde_json::json!(["query"]));
    }

    /// `enum_values` are rendered as a JSON-Schema `enum` array.
    #[test]
    fn test_enum_values_in_definition() {
        let mut tool = Tool::new("pick", "Pick a color");
        tool.parameters.push(ToolParameter {
            name: "color".into(),
            param_type: "string".into(),
            description: "Color".into(),
            required: true,
            enum_values: Some(vec!["red".into(), "blue".into()]),
            default: None,
        });

        let def = tool.to_function_definition();
        let prop = &def["function"]["parameters"]["properties"]["color"];
        assert_eq!(prop["enum"], serde_json::json!(["red", "blue"]));
        assert_eq!(
            def["function"]["parameters"]["required"],
            serde_json::json!(["color"])
        );
    }

    /// `param_type` serializes as `"type"`; `None` optionals are skipped.
    #[test]
    fn test_tool_parameter_serialization() {
        let param = ToolParameter {
            name: "q".into(),
            param_type: "string".into(),
            description: "d".into(),
            required: false,
            enum_values: None,
            default: None,
        };

        let value = serde_json::to_value(&param).unwrap();
        assert_eq!(value["type"], "string");
        assert!(value.get("param_type").is_none());
        assert!(value.get("enum_values").is_none());
        assert!(value.get("default").is_none());
    }

    /// Only `name` and `description` are required when deserializing a `Tool`.
    #[test]
    fn test_tool_deserialize_defaults() {
        let tool: Tool =
            serde_json::from_value(serde_json::json!({"name": "x", "description": "y"}))
                .unwrap();

        assert!(tool.parameters.is_empty());
        assert_eq!(tool.category, ToolCategory::Custom);
        assert!(!tool.requires_confirmation);
        assert!(tool.metadata.is_empty());
    }

    /// Categories serialize in snake_case.
    #[test]
    fn test_category_serializes_snake_case() {
        assert_eq!(serde_json::to_value(ToolCategory::Builtin).unwrap(), "builtin");
        assert_eq!(
            serde_json::to_value(ToolCategory::Communication).unwrap(),
            "communication"
        );
    }

    /// Only the side-effecting built-ins are flagged for confirmation.
    #[test]
    fn test_builtin_confirmation_flags() {
        for tool in BuiltinTools::all() {
            let expected = matches!(tool.name.as_str(), "code_execution" | "write_file");
            assert_eq!(tool.requires_confirmation, expected, "{}", tool.name);
        }
    }

    /// The built-in registry exposes every built-in definition and no executors.
    #[test]
    fn test_registry() {
        let registry = ToolRegistry::with_builtins();
        assert!(registry.get("web_search").is_some());
        assert!(registry.get("code_execution").is_some());
        assert!(registry.get("nonexistent").is_none());
        assert!(registry.get_executor("web_search").is_none());

        let builtin_count = BuiltinTools::all().len();
        assert_eq!(registry.list().len(), builtin_count);
        assert_eq!(registry.to_definitions().len(), builtin_count);
    }

    /// Registering a tool under an existing name replaces the previous definition.
    #[test]
    fn test_registry_register_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(Tool::new("echo", "first"));
        registry.register(Tool::new("echo", "second"));

        assert_eq!(registry.list().len(), 1);
        assert_eq!(registry.get("echo").unwrap().description, "second");
    }
}