// agent_runtime/tool.rs

1use crate::types::{ToolError, ToolExecutionResult, ToolResult};
2use async_trait::async_trait;
3use futures::future::BoxFuture;
4use serde_json::Value as JsonValue;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8/// Tool trait that all tools must implement
9#[async_trait]
10pub trait Tool: Send + Sync {
11    /// Unique name for this tool
12    fn name(&self) -> &str;
13
14    /// Human-readable description for LLM
15    fn description(&self) -> &str;
16
17    /// JSON schema for input parameters
18    fn input_schema(&self) -> JsonValue;
19
20    /// Execute the tool with given parameters
21    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult;
22}
23
24type ToolExecutor = Arc<
25    dyn Fn(HashMap<String, JsonValue>) -> BoxFuture<'static, ToolExecutionResult> + Send + Sync,
26>;
27
28/// A native (in-memory) tool implemented as a Rust async function
29///
30/// Native tools execute directly in the runtime process with no IPC overhead.
31/// They are defined as async closures that accept parameters and return results.
32pub struct NativeTool {
33    name: String,
34    description: String,
35    input_schema: JsonValue,
36    executor: ToolExecutor,
37}
38
39impl NativeTool {
40    /// Create a new native tool
41    ///
42    /// # Arguments
43    /// * `name` - Unique identifier for the tool
44    /// * `description` - Human-readable description
45    /// * `input_schema` - JSON Schema describing input parameters
46    /// * `executor` - Async function that executes the tool
47    pub fn new<F, Fut>(
48        name: impl Into<String>,
49        description: impl Into<String>,
50        input_schema: JsonValue,
51        executor: F,
52    ) -> Self
53    where
54        F: Fn(HashMap<String, JsonValue>) -> Fut + Send + Sync + 'static,
55        Fut: std::future::Future<Output = ToolExecutionResult> + Send + 'static,
56    {
57        Self {
58            name: name.into(),
59            description: description.into(),
60            input_schema,
61            executor: Arc::new(move |params| Box::pin(executor(params))),
62        }
63    }
64}
65
66#[async_trait]
67impl Tool for NativeTool {
68    fn name(&self) -> &str {
69        &self.name
70    }
71
72    fn description(&self) -> &str {
73        &self.description
74    }
75
76    fn input_schema(&self) -> JsonValue {
77        self.input_schema.clone()
78    }
79
80    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
81        (self.executor)(params).await
82    }
83}
84
85impl std::fmt::Debug for NativeTool {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("NativeTool")
88            .field("name", &self.name)
89            .field("description", &self.description)
90            .field("input_schema", &self.input_schema)
91            .finish()
92    }
93}
94
95/// Registry for managing tools
96///
97/// The registry stores all available tools and provides methods to
98/// list, query, and execute them.
99pub struct ToolRegistry {
100    tools: HashMap<String, Arc<dyn Tool>>,
101}
102
103impl ToolRegistry {
104    /// Create a new empty tool registry
105    pub fn new() -> Self {
106        Self {
107            tools: HashMap::new(),
108        }
109    }
110
111    /// Register a tool
112    ///
113    /// # Arguments
114    /// * `tool` - The tool to register (must implement `Tool` trait)
115    ///
116    /// # Returns
117    /// * `&mut Self` - For method chaining
118    pub fn register(&mut self, tool: impl Tool + 'static) -> &mut Self {
119        let name = tool.name().to_string();
120        self.tools.insert(name, Arc::new(tool));
121        self
122    }
123
124    /// Get a tool by name
125    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
126        self.tools.get(name)
127    }
128
129    /// List all tool names
130    pub fn list_names(&self) -> Vec<String> {
131        self.tools.keys().cloned().collect()
132    }
133
134    /// List all tools with their schemas (for LLM function calling)
135    pub fn list_tools(&self) -> Vec<JsonValue> {
136        self.tools
137            .values()
138            .map(|tool| {
139                serde_json::json!({
140                    "type": "function",
141                    "function": {
142                        "name": tool.name(),
143                        "description": tool.description(),
144                        "parameters": tool.input_schema(),
145                    }
146                })
147            })
148            .collect()
149    }
150
151    /// Call a tool by name with the given parameters
152    pub async fn call_tool(
153        &self,
154        name: &str,
155        params: HashMap<String, JsonValue>,
156    ) -> ToolExecutionResult {
157        match self.tools.get(name) {
158            Some(tool) => tool.execute(params).await,
159            None => Err(ToolError::InvalidParameters(format!(
160                "Tool not found: {}",
161                name
162            ))),
163        }
164    }
165
166    /// Check if a tool exists
167    pub fn has_tool(&self, name: &str) -> bool {
168        self.tools.contains_key(name)
169    }
170
171    /// Get the number of registered tools
172    pub fn len(&self) -> usize {
173        self.tools.len()
174    }
175
176    /// Check if the registry is empty
177    pub fn is_empty(&self) -> bool {
178        self.tools.is_empty()
179    }
180}
181
182impl Default for ToolRegistry {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl std::fmt::Debug for ToolRegistry {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("ToolRegistry")
191            .field("tool_count", &self.tools.len())
192            .field("tools", &self.tools.keys().collect::<Vec<_>>())
193            .finish()
194    }
195}
196
197/// Example: Echo tool that returns its input
198pub struct EchoTool;
199
200#[async_trait]
201impl Tool for EchoTool {
202    fn name(&self) -> &str {
203        "echo"
204    }
205
206    fn description(&self) -> &str {
207        "Echoes back the input message"
208    }
209
210    fn input_schema(&self) -> JsonValue {
211        serde_json::json!({
212            "type": "object",
213            "properties": {
214                "message": {
215                    "type": "string",
216                    "description": "The message to echo"
217                }
218            },
219            "required": ["message"]
220        })
221    }
222
223    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
224        let start = std::time::Instant::now();
225
226        let message = params
227            .get("message")
228            .and_then(|v| v.as_str())
229            .ok_or_else(|| ToolError::InvalidParameters("missing 'message' parameter".into()))?;
230
231        let output = serde_json::json!({
232            "echoed": message
233        });
234
235        Ok(ToolResult::success(
236            output,
237            start.elapsed().as_secs_f64() * 1000.0,
238        ))
239    }
240}
241
242/// Example: Calculator tool for simple math
243pub struct CalculatorTool;
244
245#[async_trait]
246impl Tool for CalculatorTool {
247    fn name(&self) -> &str {
248        "calculator"
249    }
250
251    fn description(&self) -> &str {
252        "Performs basic arithmetic operations (add, subtract, multiply, divide)"
253    }
254
255    fn input_schema(&self) -> JsonValue {
256        serde_json::json!({
257            "type": "object",
258            "properties": {
259                "operation": {
260                    "type": "string",
261                    "enum": ["add", "subtract", "multiply", "divide"]
262                },
263                "a": { "type": "number" },
264                "b": { "type": "number" }
265            },
266            "required": ["operation", "a", "b"]
267        })
268    }
269
270    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
271        let start = std::time::Instant::now();
272
273        let operation = params
274            .get("operation")
275            .and_then(|v| v.as_str())
276            .ok_or_else(|| ToolError::InvalidParameters("missing 'operation'".into()))?;
277
278        let a = params
279            .get("a")
280            .and_then(|v| v.as_f64())
281            .ok_or_else(|| ToolError::InvalidParameters("missing 'a'".into()))?;
282
283        let b = params
284            .get("b")
285            .and_then(|v| v.as_f64())
286            .ok_or_else(|| ToolError::InvalidParameters("missing 'b'".into()))?;
287
288        let result = match operation {
289            "add" => a + b,
290            "subtract" => a - b,
291            "multiply" => a * b,
292            "divide" => {
293                if b == 0.0 {
294                    return Err(ToolError::ExecutionFailed("division by zero".into()));
295                }
296                a / b
297            }
298            _ => {
299                return Err(ToolError::InvalidParameters(format!(
300                    "unknown operation: {}",
301                    operation
302                )))
303            }
304        };
305
306        Ok(ToolResult::success(
307            serde_json::json!({ "result": result }),
308            start.elapsed().as_secs_f64() * 1000.0,
309        ))
310    }
311}