1use crate::events::AgentEvent;
30use crate::llm;
31use crate::types::{ToolResult, ToolTier};
32use anyhow::Result;
33use async_trait::async_trait;
34use serde_json::Value;
35use std::collections::HashMap;
36use std::sync::Arc;
37use tokio::sync::mpsc;
38
39pub struct ToolContext<Ctx> {
41 pub app: Ctx,
43 pub metadata: HashMap<String, Value>,
45 event_tx: Option<mpsc::Sender<AgentEvent>>,
47}
48
49impl<Ctx> ToolContext<Ctx> {
50 #[must_use]
51 pub fn new(app: Ctx) -> Self {
52 Self {
53 app,
54 metadata: HashMap::new(),
55 event_tx: None,
56 }
57 }
58
59 #[must_use]
60 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
61 self.metadata.insert(key.into(), value);
62 self
63 }
64
65 #[must_use]
67 pub fn with_event_tx(mut self, tx: mpsc::Sender<AgentEvent>) -> Self {
68 self.event_tx = Some(tx);
69 self
70 }
71
72 pub fn emit_event(&self, event: AgentEvent) {
77 if let Some(tx) = &self.event_tx {
78 let _ = tx.try_send(event);
79 }
80 }
81
82 #[must_use]
87 pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEvent>> {
88 self.event_tx.clone()
89 }
90}
91
92#[async_trait]
94pub trait Tool<Ctx>: Send + Sync {
95 fn name(&self) -> &str;
97
98 fn description(&self) -> &str;
100
101 fn input_schema(&self) -> Value;
103
104 fn tier(&self) -> ToolTier {
106 ToolTier::Observe
107 }
108
109 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
114}
115
116pub struct ToolRegistry<Ctx> {
118 tools: HashMap<String, Arc<dyn Tool<Ctx>>>,
119}
120
121impl<Ctx> Clone for ToolRegistry<Ctx> {
122 fn clone(&self) -> Self {
123 Self {
124 tools: self.tools.clone(),
125 }
126 }
127}
128
129impl<Ctx> Default for ToolRegistry<Ctx> {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl<Ctx> ToolRegistry<Ctx> {
136 #[must_use]
137 pub fn new() -> Self {
138 Self {
139 tools: HashMap::new(),
140 }
141 }
142
143 pub fn register<T: Tool<Ctx> + 'static>(&mut self, tool: T) -> &mut Self {
145 self.tools.insert(tool.name().to_string(), Arc::new(tool));
146 self
147 }
148
149 pub fn register_boxed(&mut self, tool: Arc<dyn Tool<Ctx>>) -> &mut Self {
151 self.tools.insert(tool.name().to_string(), tool);
152 self
153 }
154
155 #[must_use]
157 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool<Ctx>>> {
158 self.tools.get(name)
159 }
160
161 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn Tool<Ctx>>> {
163 self.tools.values()
164 }
165
166 #[must_use]
168 pub fn len(&self) -> usize {
169 self.tools.len()
170 }
171
172 #[must_use]
174 pub fn is_empty(&self) -> bool {
175 self.tools.is_empty()
176 }
177
178 pub fn filter<F>(&mut self, predicate: F)
189 where
190 F: Fn(&str) -> bool,
191 {
192 self.tools.retain(|name, _| predicate(name));
193 }
194
195 #[must_use]
197 pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
198 self.tools
199 .values()
200 .map(|tool| llm::Tool {
201 name: tool.name().to_string(),
202 description: tool.description().to_string(),
203 input_schema: tool.input_schema(),
204 })
205 .collect()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    struct MockTool;

    #[async_trait]
    impl Tool<()> for MockTool {
        fn name(&self) -> &'static str {
            "mock_tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
        assert_eq!(llm_tools[0].description, "A mock tool for testing");
        assert_eq!(llm_tools[0].input_schema, MockTool.input_schema());
    }

    struct AnotherTool;

    #[async_trait]
    impl Tool<()> for AnotherTool {
        fn name(&self) -> &'static str {
            "another_tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        assert_eq!(registry.len(), 2);

        registry.filter(|name| name != "mock_tool");

        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| true);

        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);

        registry.filter(|_| false);

        assert!(registry.is_empty());
    }

    #[test]
    fn test_to_llm_tools_includes_every_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool).register(AnotherTool);

        let llm_tools = registry.to_llm_tools();
        // Compare a sorted copy so this test does not depend on the output order.
        let mut names: Vec<&str> = llm_tools.iter().map(|tool| tool.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["another_tool", "mock_tool"]);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(MockTool);

        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_boxed() {
        let mut registry: ToolRegistry<()> = ToolRegistry::new();
        let tool: Arc<dyn Tool<()>> = Arc::new(AnotherTool);
        registry.register_boxed(Arc::clone(&tool));

        let stored = registry
            .get("another_tool")
            .expect("boxed tool should be registered under its name");
        assert!(Arc::ptr_eq(stored, &tool));
    }

    #[test]
    fn test_clone_shares_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);

        let cloned = registry.clone();
        assert_eq!(cloned.len(), 1);
        assert!(Arc::ptr_eq(
            registry.get("mock_tool").expect("original has the tool"),
            cloned.get("mock_tool").expect("clone has the tool"),
        ));
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry: ToolRegistry<()> = ToolRegistry::default();

        assert!(registry.is_empty());
        assert_eq!(registry.all().count(), 0);
        assert!(registry.to_llm_tools().is_empty());
    }

    #[test]
    fn test_default_tier_is_observe() {
        assert!(matches!(
            <MockTool as Tool<()>>::tier(&MockTool),
            ToolTier::Observe
        ));
    }

    #[test]
    fn test_tool_context_builders() {
        let ctx = ToolContext::new(7_u32).with_metadata("session", serde_json::json!("abc"));

        assert_eq!(ctx.app, 7);
        assert_eq!(ctx.metadata.get("session"), Some(&serde_json::json!("abc")));
        assert!(ctx.event_tx().is_none());
    }
}