chasm_cli/agency/
tools.rs

1// Copyright (c) 2024-2026 Nervosys LLC
2// SPDX-License-Identifier: Apache-2.0
3//! Tool System
4//!
5//! Define and register tools that agents can use.
6
7#![allow(dead_code)]
8
9use crate::agency::error::AgencyResult;
10use crate::agency::models::ToolResult;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18
19/// Tool parameter definition
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ToolParameter {
22    /// Parameter name
23    pub name: String,
24    /// Parameter type (string, number, boolean, array, object)
25    #[serde(rename = "type")]
26    pub param_type: String,
27    /// Parameter description
28    pub description: String,
29    /// Whether the parameter is required
30    #[serde(default)]
31    pub required: bool,
32    /// Enum values (if applicable)
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub enum_values: Option<Vec<String>>,
35    /// Default value
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub default: Option<Value>,
38}
39
40/// Tool definition
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct Tool {
43    /// Tool name (must be unique)
44    pub name: String,
45    /// Tool description
46    pub description: String,
47    /// Tool parameters
48    #[serde(default)]
49    pub parameters: Vec<ToolParameter>,
50    /// Tool category
51    #[serde(default)]
52    pub category: ToolCategory,
53    /// Whether the tool requires confirmation before execution
54    #[serde(default)]
55    pub requires_confirmation: bool,
56    /// Custom metadata
57    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
58    pub metadata: HashMap<String, Value>,
59}
60
61impl Tool {
62    /// Create a new tool
63    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
64        Self {
65            name: name.into(),
66            description: description.into(),
67            parameters: Vec::new(),
68            category: ToolCategory::Custom,
69            requires_confirmation: false,
70            metadata: HashMap::new(),
71        }
72    }
73
74    /// Convert to function definition for model API
75    pub fn to_function_definition(&self) -> Value {
76        let mut properties = serde_json::Map::new();
77        let mut required = Vec::new();
78
79        for param in &self.parameters {
80            let mut prop = serde_json::Map::new();
81            prop.insert("type".to_string(), Value::String(param.param_type.clone()));
82            prop.insert(
83                "description".to_string(),
84                Value::String(param.description.clone()),
85            );
86
87            if let Some(enum_vals) = &param.enum_values {
88                prop.insert(
89                    "enum".to_string(),
90                    Value::Array(enum_vals.iter().map(|v| Value::String(v.clone())).collect()),
91                );
92            }
93
94            properties.insert(param.name.clone(), Value::Object(prop));
95
96            if param.required {
97                required.push(Value::String(param.name.clone()));
98            }
99        }
100
101        serde_json::json!({
102            "type": "function",
103            "function": {
104                "name": self.name,
105                "description": self.description,
106                "parameters": {
107                    "type": "object",
108                    "properties": properties,
109                    "required": required
110                }
111            }
112        })
113    }
114}
115
116/// Tool category for organization
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
118#[serde(rename_all = "snake_case")]
119pub enum ToolCategory {
120    #[default]
121    Custom,
122    Search,
123    Code,
124    File,
125    Data,
126    Communication,
127    System,
128    Builtin,
129}
130
131/// Fluent builder for tools
132pub struct ToolBuilder {
133    tool: Tool,
134}
135
136impl ToolBuilder {
137    /// Create a new tool builder
138    pub fn new(name: impl Into<String>) -> Self {
139        Self {
140            tool: Tool {
141                name: name.into(),
142                description: String::new(),
143                parameters: Vec::new(),
144                category: ToolCategory::Custom,
145                requires_confirmation: false,
146                metadata: HashMap::new(),
147            },
148        }
149    }
150
151    /// Set description
152    pub fn description(mut self, desc: impl Into<String>) -> Self {
153        self.tool.description = desc.into();
154        self
155    }
156
157    /// Add a parameter
158    pub fn parameter(
159        mut self,
160        name: impl Into<String>,
161        param_type: impl Into<String>,
162        description: impl Into<String>,
163        required: bool,
164    ) -> Self {
165        self.tool.parameters.push(ToolParameter {
166            name: name.into(),
167            param_type: param_type.into(),
168            description: description.into(),
169            required,
170            enum_values: None,
171            default: None,
172        });
173        self
174    }
175
176    /// Add a string parameter
177    pub fn string_param(
178        self,
179        name: impl Into<String>,
180        description: impl Into<String>,
181        required: bool,
182    ) -> Self {
183        self.parameter(name, "string", description, required)
184    }
185
186    /// Add a number parameter
187    pub fn number_param(
188        self,
189        name: impl Into<String>,
190        description: impl Into<String>,
191        required: bool,
192    ) -> Self {
193        self.parameter(name, "number", description, required)
194    }
195
196    /// Add a boolean parameter
197    pub fn bool_param(
198        self,
199        name: impl Into<String>,
200        description: impl Into<String>,
201        required: bool,
202    ) -> Self {
203        self.parameter(name, "boolean", description, required)
204    }
205
206    /// Set category
207    pub fn category(mut self, category: ToolCategory) -> Self {
208        self.tool.category = category;
209        self
210    }
211
212    /// Set requires confirmation
213    pub fn requires_confirmation(mut self, requires: bool) -> Self {
214        self.tool.requires_confirmation = requires;
215        self
216    }
217
218    /// Build the tool
219    pub fn build(self) -> Tool {
220        self.tool
221    }
222}
223
224/// Trait for executable tools
225#[async_trait]
226pub trait ToolExecutor: Send + Sync {
227    /// Get the tool definition
228    fn definition(&self) -> &Tool;
229
230    /// Execute the tool with the given arguments
231    async fn execute(&self, args: Value) -> AgencyResult<ToolResult>;
232}
233
234/// Type alias for tool execution function
235pub type ToolFn = Box<
236    dyn Fn(Value) -> Pin<Box<dyn Future<Output = AgencyResult<ToolResult>> + Send>> + Send + Sync,
237>;
238
239/// Tool registry for managing available tools
240#[derive(Default)]
241pub struct ToolRegistry {
242    tools: HashMap<String, Arc<Tool>>,
243    executors: HashMap<String, Arc<dyn ToolExecutor>>,
244}
245
246impl ToolRegistry {
247    /// Create a new empty registry
248    pub fn new() -> Self {
249        Self::default()
250    }
251
252    /// Create a registry with builtin tools
253    pub fn with_builtins() -> Self {
254        let mut registry = Self::new();
255        registry.register_builtins();
256        registry
257    }
258
259    /// Register a tool
260    pub fn register(&mut self, tool: Tool) {
261        self.tools.insert(tool.name.clone(), Arc::new(tool));
262    }
263
264    /// Register a tool with its executor
265    pub fn register_with_executor(&mut self, executor: impl ToolExecutor + 'static) {
266        let tool = executor.definition().clone();
267        let name = tool.name.clone();
268        self.tools.insert(name.clone(), Arc::new(tool));
269        self.executors.insert(name, Arc::new(executor));
270    }
271
272    /// Get a tool by name
273    pub fn get(&self, name: &str) -> Option<&Arc<Tool>> {
274        self.tools.get(name)
275    }
276
277    /// Get an executor by name
278    pub fn get_executor(&self, name: &str) -> Option<&Arc<dyn ToolExecutor>> {
279        self.executors.get(name)
280    }
281
282    /// List all tools
283    pub fn list(&self) -> Vec<&Tool> {
284        self.tools.values().map(|t| t.as_ref()).collect()
285    }
286
287    /// Get tool definitions for model API
288    pub fn to_definitions(&self) -> Vec<Value> {
289        self.tools
290            .values()
291            .map(|t| t.to_function_definition())
292            .collect()
293    }
294
295    /// Register builtin tools
296    fn register_builtins(&mut self) {
297        for tool in BuiltinTools::all() {
298            self.register(tool);
299        }
300    }
301}
302
303/// Builtin tools provided by the Agency
304pub struct BuiltinTools;
305
306impl BuiltinTools {
307    /// Get all builtin tools
308    pub fn all() -> Vec<Tool> {
309        vec![
310            Self::web_search(),
311            Self::code_execution(),
312            Self::read_file(),
313            Self::write_file(),
314            Self::list_directory(),
315            Self::http_request(),
316            Self::calculator(),
317        ]
318    }
319
320    /// Web search tool
321    pub fn web_search() -> Tool {
322        ToolBuilder::new("web_search")
323            .description("Search the web for information. Returns relevant snippets and URLs.")
324            .string_param("query", "The search query", true)
325            .number_param(
326                "max_results",
327                "Maximum number of results (default: 5)",
328                false,
329            )
330            .category(ToolCategory::Search)
331            .build()
332    }
333
334    /// Code execution tool
335    pub fn code_execution() -> Tool {
336        ToolBuilder::new("code_execution")
337            .description("Execute code in a sandboxed environment. Supports Python, JavaScript, and shell scripts.")
338            .string_param("code", "The code to execute", true)
339            .string_param("language", "Programming language (python, javascript, shell)", true)
340            .number_param("timeout", "Execution timeout in seconds (default: 30)", false)
341            .category(ToolCategory::Code)
342            .requires_confirmation(true)
343            .build()
344    }
345
346    /// Read file tool
347    pub fn read_file() -> Tool {
348        ToolBuilder::new("read_file")
349            .description("Read the contents of a file from the filesystem.")
350            .string_param("path", "Path to the file to read", true)
351            .string_param("encoding", "File encoding (default: utf-8)", false)
352            .category(ToolCategory::File)
353            .build()
354    }
355
356    /// Write file tool
357    pub fn write_file() -> Tool {
358        ToolBuilder::new("write_file")
359            .description("Write content to a file. Creates the file if it doesn't exist.")
360            .string_param("path", "Path to the file to write", true)
361            .string_param("content", "Content to write to the file", true)
362            .bool_param("append", "Append to file instead of overwriting", false)
363            .category(ToolCategory::File)
364            .requires_confirmation(true)
365            .build()
366    }
367
368    /// List directory tool
369    pub fn list_directory() -> Tool {
370        ToolBuilder::new("list_directory")
371            .description("List the contents of a directory.")
372            .string_param("path", "Path to the directory", true)
373            .bool_param("recursive", "Include subdirectories", false)
374            .bool_param("include_hidden", "Include hidden files", false)
375            .category(ToolCategory::File)
376            .build()
377    }
378
379    /// HTTP request tool
380    pub fn http_request() -> Tool {
381        ToolBuilder::new("http_request")
382            .description("Make an HTTP request to a URL.")
383            .string_param("url", "The URL to request", true)
384            .string_param("method", "HTTP method (GET, POST, PUT, DELETE)", false)
385            .string_param("body", "Request body (for POST/PUT)", false)
386            .string_param("headers", "JSON object of headers", false)
387            .category(ToolCategory::Communication)
388            .build()
389    }
390
391    /// Calculator tool
392    pub fn calculator() -> Tool {
393        ToolBuilder::new("calculator")
394            .description("Evaluate mathematical expressions. Supports basic arithmetic, functions, and constants.")
395            .string_param("expression", "The mathematical expression to evaluate", true)
396            .category(ToolCategory::Data)
397            .build()
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test_tool")
            .description("A test tool")
            .string_param("input", "Input parameter", true)
            .number_param("count", "Count parameter", false)
            .category(ToolCategory::Custom)
            .build();

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description, "A test tool");
        assert_eq!(tool.parameters.len(), 2);
        assert!(tool.parameters[0].required);
        assert!(!tool.parameters[1].required);
    }

    #[test]
    fn test_function_definition() {
        let tool = BuiltinTools::web_search();
        let def = tool.to_function_definition();

        assert_eq!(def["type"], "function");
        assert_eq!(def["function"]["name"], "web_search");
        assert_eq!(def["function"]["parameters"]["type"], "object");
        // Only `query` is required; `max_results` is optional.
        assert_eq!(
            def["function"]["parameters"]["required"],
            serde_json::json!(["query"])
        );
        assert_eq!(
            def["function"]["parameters"]["properties"]["max_results"]["type"],
            "number"
        );
    }

    #[test]
    fn test_function_definition_includes_enum_values() {
        let mut tool = Tool::new("pick", "Pick a color");
        tool.parameters.push(ToolParameter {
            name: "color".to_string(),
            param_type: "string".to_string(),
            description: "Color to pick".to_string(),
            required: true,
            enum_values: Some(vec!["red".to_string(), "blue".to_string()]),
            default: None,
        });

        let def = tool.to_function_definition();
        assert_eq!(
            def["function"]["parameters"]["properties"]["color"]["enum"],
            serde_json::json!(["red", "blue"])
        );
    }

    #[test]
    fn test_category_serialization_and_default() {
        assert_eq!(
            serde_json::to_value(ToolCategory::Communication).unwrap(),
            "communication"
        );
        assert_eq!(ToolCategory::default(), ToolCategory::Custom);
    }

    #[test]
    fn test_builtin_confirmation_flags() {
        assert!(BuiltinTools::code_execution().requires_confirmation);
        assert!(BuiltinTools::write_file().requires_confirmation);
        assert!(!BuiltinTools::read_file().requires_confirmation);
    }

    #[test]
    fn test_registry() {
        let registry = ToolRegistry::with_builtins();
        assert!(registry.get("web_search").is_some());
        assert!(registry.get("code_execution").is_some());
        assert!(registry.get("nonexistent").is_none());
        assert_eq!(registry.list().len(), BuiltinTools::all().len());
        assert_eq!(registry.to_definitions().len(), BuiltinTools::all().len());
    }
}