agent_sdk/
tools.rs

1//! Tool definition and registry.
2//!
3//! Tools allow the LLM to perform actions in the real world. This module provides:
4//!
5//! - [`Tool`] trait - Define custom tools the LLM can call
6//! - [`ToolRegistry`] - Collection of available tools
7//! - [`ToolContext`] - Context passed to tool execution
8//!
9//! # Implementing a Tool
10//!
11//! ```ignore
12//! use agent_sdk::{Tool, ToolContext, ToolResult, ToolTier};
13//!
14//! struct MyTool;
15//!
16//! #[async_trait]
17//! impl Tool<MyContext> for MyTool {
18//!     fn name(&self) -> &str { "my_tool" }
19//!     fn description(&self) -> &str { "Does something useful" }
20//!     fn input_schema(&self) -> Value { json!({ "type": "object" }) }
21//!     fn tier(&self) -> ToolTier { ToolTier::Observe }
22//!
23//!     async fn execute(&self, ctx: &ToolContext<MyContext>, input: Value) -> Result<ToolResult> {
24//!         Ok(ToolResult::success("Done!"))
25//!     }
26//! }
27//! ```
28
29use crate::events::AgentEvent;
30use crate::llm;
31use crate::types::{ToolResult, ToolTier};
32use anyhow::Result;
33use async_trait::async_trait;
34use serde_json::Value;
35use std::collections::HashMap;
36use std::sync::Arc;
37use tokio::sync::mpsc;
38
39/// Context passed to tool execution
40pub struct ToolContext<Ctx> {
41    /// Application-specific context (e.g., `user_id`, db connection)
42    pub app: Ctx,
43    /// Tool-specific metadata
44    pub metadata: HashMap<String, Value>,
45    /// Optional channel for tools to emit events (e.g., subagent progress)
46    event_tx: Option<mpsc::Sender<AgentEvent>>,
47}
48
49impl<Ctx> ToolContext<Ctx> {
50    #[must_use]
51    pub fn new(app: Ctx) -> Self {
52        Self {
53            app,
54            metadata: HashMap::new(),
55            event_tx: None,
56        }
57    }
58
59    #[must_use]
60    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
61        self.metadata.insert(key.into(), value);
62        self
63    }
64
65    /// Set the event channel for tools that need to emit events during execution.
66    #[must_use]
67    pub fn with_event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
68        self.event_tx = Some(tx);
69        self
70    }
71
72    /// Emit an event through the event channel (if set).
73    ///
74    /// This uses `try_send` to avoid blocking and to ensure the future is `Send`.
75    /// The event is silently dropped if the channel is full.
76    pub fn emit_event(&self, event: AgentEvent) {
77        if let Some(tx) = &self.event_tx {
78            let _ = tx.try_send(event);
79        }
80    }
81
82    /// Get a clone of the event channel sender (if set).
83    ///
84    /// This is useful for tools that spawn subprocesses (like subagents)
85    /// and need to forward events to the parent's event stream.
86    #[must_use]
87    pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEvent>> {
88        self.event_tx.clone()
89    }
90}
91
92/// Definition of a tool that can be called by the agent
93#[async_trait]
94pub trait Tool<Ctx>: Send + Sync {
95    /// Unique name for the tool (used in LLM tool calls)
96    fn name(&self) -> &str;
97
98    /// Human-readable description of what the tool does
99    fn description(&self) -> &str;
100
101    /// JSON schema for the tool's input parameters
102    fn input_schema(&self) -> Value;
103
104    /// Permission tier for this tool
105    fn tier(&self) -> ToolTier {
106        ToolTier::Observe
107    }
108
109    /// Execute the tool with the given input
110    ///
111    /// # Errors
112    /// Returns an error if tool execution fails.
113    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
114}
115
116/// Registry of available tools
117pub struct ToolRegistry<Ctx> {
118    tools: HashMap<String, Arc<dyn Tool<Ctx>>>,
119}
120
121impl<Ctx> Clone for ToolRegistry<Ctx> {
122    fn clone(&self) -> Self {
123        Self {
124            tools: self.tools.clone(),
125        }
126    }
127}
128
129impl<Ctx> Default for ToolRegistry<Ctx> {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl<Ctx> ToolRegistry<Ctx> {
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            tools: HashMap::new(),
140        }
141    }
142
143    /// Register a tool in the registry
144    pub fn register<T: Tool<Ctx> + 'static>(&mut self, tool: T) -> &mut Self {
145        self.tools.insert(tool.name().to_string(), Arc::new(tool));
146        self
147    }
148
149    /// Register a boxed tool
150    pub fn register_boxed(&mut self, tool: Arc<dyn Tool<Ctx>>) -> &mut Self {
151        self.tools.insert(tool.name().to_string(), tool);
152        self
153    }
154
155    /// Get a tool by name
156    #[must_use]
157    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool<Ctx>>> {
158        self.tools.get(name)
159    }
160
161    /// Get all registered tools
162    pub fn all(&self) -> impl Iterator<Item = &Arc<dyn Tool<Ctx>>> {
163        self.tools.values()
164    }
165
166    /// Get the number of registered tools
167    #[must_use]
168    pub fn len(&self) -> usize {
169        self.tools.len()
170    }
171
172    /// Check if the registry is empty
173    #[must_use]
174    pub fn is_empty(&self) -> bool {
175        self.tools.is_empty()
176    }
177
178    /// Filter tools by a predicate.
179    ///
180    /// Removes tools for which the predicate returns false.
181    /// The predicate receives the tool name.
182    ///
183    /// # Example
184    ///
185    /// ```ignore
186    /// registry.filter(|name| name != "bash");
187    /// ```
188    pub fn filter<F>(&mut self, predicate: F)
189    where
190        F: Fn(&str) -> bool,
191    {
192        self.tools.retain(|name, _| predicate(name));
193    }
194
195    /// Convert tools to LLM tool definitions
196    #[must_use]
197    pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
198        self.tools
199            .values()
200            .map(|tool| llm::Tool {
201                name: tool.name().to_string(),
202                description: tool.description().to_string(),
203                input_schema: tool.input_schema(),
204            })
205            .collect()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Echoes the `message` field of its input.
    struct MockTool;

    #[async_trait]
    impl Tool<()> for MockTool {
        fn name(&self) -> &'static str {
            "mock_tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    /// A second, independent tool so multi-tool behavior can be exercised.
    struct AnotherTool;

    #[async_trait]
    impl Tool<()> for AnotherTool {
        fn name(&self) -> &'static str {
            "another_tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    /// Registry pre-populated with both mock tools.
    fn registry_with_both() -> ToolRegistry<()> {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool).register(AnotherTool);
        registry
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
        assert_eq!(llm_tools[0].description, "A mock tool for testing");
        assert_eq!(llm_tools[0].input_schema["type"], "object");
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = registry_with_both();
        assert_eq!(registry.len(), 2);

        // Drop mock_tool only.
        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| true);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = registry_with_both();
        registry.filter(|_| false);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry: ToolRegistry<()> = ToolRegistry::new();
        registry.register(MockTool).register(MockTool);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_boxed() {
        let mut registry: ToolRegistry<()> = ToolRegistry::new();
        registry.register_boxed(Arc::new(MockTool));
        assert!(registry.get("mock_tool").is_some());
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry: ToolRegistry<()> = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_default_tier_is_observe() {
        assert!(matches!(MockTool.tier(), ToolTier::Observe));
    }

    #[test]
    fn test_tool_context_metadata() {
        let ctx = ToolContext::new(()).with_metadata("key", serde_json::json!(42));
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!(42)));
        assert!(ctx.event_tx().is_none());
    }
}