Skip to main content

limit_agent/
registry.rs

1use crate::error::AgentError;
2use crate::tool::Tool;
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7/// Registry for managing tools
8pub struct ToolRegistry {
9    tools: HashMap<String, Arc<dyn Tool>>,
10}
11
12impl ToolRegistry {
13    /// Create a new empty tool registry
14    pub fn new() -> Self {
15        ToolRegistry {
16            tools: HashMap::new(),
17        }
18    }
19
20    /// Register a tool with the registry
21    pub fn register<T>(&mut self, tool: T) -> Result<(), AgentError>
22    where
23        T: Tool + 'static,
24    {
25        let name = tool.name().to_string();
26        self.tools.insert(name, Arc::new(tool));
27        Ok(())
28    }
29
30    /// Get a tool by name
31    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
32        self.tools.get(name).cloned()
33    }
34
35    /// List all registered tool names
36    pub fn list(&self) -> Vec<String> {
37        let mut names: Vec<String> = self.tools.keys().cloned().collect();
38        names.sort();
39        names
40    }
41
42    /// Execute a tool by name with the given arguments
43    pub async fn execute(&self, name: &str, args: Value) -> Result<Value, AgentError> {
44        let tool = self
45            .get(name)
46            .ok_or_else(|| AgentError::ToolError(format!("Tool '{}' not found", name)))?;
47        tool.execute(args).await
48    }
49}
50
51impl Default for ToolRegistry {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::EchoTool;

    // Only the tests that actually `.await` need an async runtime; everything
    // else is a plain synchronous unit test.

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = ToolRegistry::default();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_registry_register() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        assert_eq!(registry.list(), ["echo"]);
    }

    #[test]
    fn test_registry_get() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let tool = registry.get("echo").expect("echo was just registered");
        assert_eq!(tool.name(), "echo");
    }

    #[test]
    fn test_registry_get_nonexistent() {
        let registry = ToolRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_list() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let names = registry.list();
        assert_eq!(names, ["echo"]);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let args = serde_json::json!({"test": "value"});
        let result = registry.execute("echo", args.clone()).await.unwrap();
        assert_eq!(result, args);
    }

    #[tokio::test]
    async fn test_registry_execute_nonexistent() {
        let registry = ToolRegistry::new();

        let args = serde_json::json!({"test": "value"});
        let result = registry.execute("nonexistent", args).await;

        assert!(matches!(result, Err(AgentError::ToolError(_))));
    }

    #[test]
    fn test_registry_multiple_tools() {
        use async_trait::async_trait;

        struct AnotherTool;

        #[async_trait]
        impl Tool for AnotherTool {
            fn name(&self) -> &str {
                "another"
            }

            async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
                Ok(serde_json::json!({"status": "ok"}))
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();
        registry.register(AnotherTool).unwrap();

        // `list()` sorts its output, so the order is deterministic even though
        // the backing HashMap's iteration order is not.
        assert_eq!(registry.list(), ["another", "echo"]);
    }
}