// rustant_tools/registry.rs

//! Tool Registry — manages dynamic tool registration, lookup, and execution.
//!
//! Tools are registered at startup and can be added/removed at runtime.
//! The registry provides tool definitions for the LLM and executes tool calls,
//! rejecting duplicate names at registration and enforcing each tool's timeout
//! at execution. Argument validation is left to each tool's `execute`.
6
7use async_trait::async_trait;
8use rustant_core::error::ToolError;
9use rustant_core::types::{RiskLevel, ToolDefinition, ToolOutput};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, info};
14
15/// Trait that all tools must implement.
16#[async_trait]
17pub trait Tool: Send + Sync {
18    /// The unique name of this tool.
19    fn name(&self) -> &str;
20
21    /// Human-readable description of what this tool does.
22    fn description(&self) -> &str;
23
24    /// JSON Schema for the tool's parameters.
25    fn parameters_schema(&self) -> serde_json::Value;
26
27    /// Execute the tool with the given arguments.
28    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError>;
29
30    /// The risk level of this tool.
31    fn risk_level(&self) -> RiskLevel;
32
33    /// Maximum execution time before timeout.
34    fn timeout(&self) -> Duration {
35        Duration::from_secs(30)
36    }
37}
38
39/// The tool registry holds all registered tools and handles execution.
40#[derive(Clone)]
41pub struct ToolRegistry {
42    tools: HashMap<String, Arc<dyn Tool>>,
43}
44
45impl ToolRegistry {
46    pub fn new() -> Self {
47        Self {
48            tools: HashMap::new(),
49        }
50    }
51
52    /// Register a tool. Returns error if a tool with the same name is already registered.
53    pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolError> {
54        let name = tool.name().to_string();
55        if self.tools.contains_key(&name) {
56            return Err(ToolError::AlreadyRegistered { name });
57        }
58        debug!(tool = %name, "Registering tool");
59        self.tools.insert(name, tool);
60        Ok(())
61    }
62
63    /// Unregister a tool by name.
64    pub fn unregister(&mut self, name: &str) -> Result<(), ToolError> {
65        if self.tools.remove(name).is_none() {
66            return Err(ToolError::NotFound {
67                name: name.to_string(),
68            });
69        }
70        debug!(tool = %name, "Unregistered tool");
71        Ok(())
72    }
73
74    /// Get a tool by name.
75    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
76        self.tools.get(name).cloned()
77    }
78
79    /// List all registered tool definitions (for sending to LLM).
80    pub fn list_definitions(&self) -> Vec<ToolDefinition> {
81        self.tools
82            .values()
83            .map(|tool| ToolDefinition {
84                name: tool.name().to_string(),
85                description: tool.description().to_string(),
86                parameters: tool.parameters_schema(),
87            })
88            .collect()
89    }
90
91    /// List all registered tool names.
92    pub fn list_names(&self) -> Vec<String> {
93        self.tools.keys().cloned().collect()
94    }
95
96    /// Get the risk level of a tool by name.
97    pub fn get_risk_level(&self, name: &str) -> Option<RiskLevel> {
98        self.tools.get(name).map(|t| t.risk_level())
99    }
100
101    /// Get the parameters schema for a tool by name.
102    pub fn get_parameters_schema(&self, name: &str) -> Option<serde_json::Value> {
103        self.tools.get(name).map(|t| t.parameters_schema())
104    }
105
106    /// Get the number of registered tools.
107    pub fn len(&self) -> usize {
108        self.tools.len()
109    }
110
111    /// Check if the registry is empty.
112    pub fn is_empty(&self) -> bool {
113        self.tools.is_empty()
114    }
115
116    /// Execute a tool by name with the given arguments, applying timeout.
117    pub async fn execute(
118        &self,
119        name: &str,
120        args: serde_json::Value,
121    ) -> Result<ToolOutput, ToolError> {
122        let tool = self.tools.get(name).ok_or_else(|| ToolError::NotFound {
123            name: name.to_string(),
124        })?;
125
126        let timeout = tool.timeout();
127        info!(tool = %name, timeout_secs = timeout.as_secs(), "Executing tool");
128
129        match tokio::time::timeout(timeout, tool.execute(args)).await {
130            Ok(result) => result,
131            Err(_) => Err(ToolError::Timeout {
132                name: name.to_string(),
133                timeout_secs: timeout.as_secs(),
134            }),
135        }
136    }
137}
138
139impl Default for ToolRegistry {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// A simple echo tool for testing.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the input text back"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "text": { "type": "string", "description": "Text to echo" }
                },
                "required": ["text"]
            })
        }

        async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
            let text = args["text"]
                .as_str()
                .ok_or_else(|| ToolError::InvalidArguments {
                    name: "echo".to_string(),
                    reason: "missing 'text' parameter".to_string(),
                })?;
            Ok(ToolOutput::text(format!("Echo: {}", text)))
        }

        fn risk_level(&self) -> RiskLevel {
            RiskLevel::ReadOnly
        }
    }

    /// A tool whose 60 s sleep far outlives its 100 ms timeout, for timeout testing.
    struct SlowTool;

    #[async_trait]
    impl Tool for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A tool that takes forever"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(ToolOutput::text("done"))
        }

        fn risk_level(&self) -> RiskLevel {
            RiskLevel::ReadOnly
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(100) // Very short timeout for testing
        }
    }

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_tool() {
        let mut registry = ToolRegistry::new();
        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
        registry.register(tool).unwrap();

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.get("echo").is_some());
    }

    #[test]
    fn test_register_duplicate() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let result = registry.register(Arc::new(EchoTool));
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::AlreadyRegistered { name } => assert_eq!(name, "echo"),
            e => panic!("Expected AlreadyRegistered error, got: {:?}", e),
        }
        // The failed registration must not disturb the existing one.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_multiple_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();
        registry.register(Arc::new(SlowTool)).unwrap();
        assert_eq!(registry.len(), 2);

        // Sort locally so the assertion does not depend on the listing order.
        let mut names = registry.list_names();
        names.sort();
        assert_eq!(names, ["echo", "slow"]);
    }

    #[test]
    fn test_unregister_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();
        assert_eq!(registry.len(), 1);

        registry.unregister("echo").unwrap();
        assert_eq!(registry.len(), 0);
        assert!(registry.get("echo").is_none());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let mut registry = ToolRegistry::new();
        let result = registry.unregister("nonexistent");
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::NotFound { name } => assert_eq!(name, "nonexistent"),
            e => panic!("Expected NotFound error, got: {:?}", e),
        }
    }

    #[test]
    fn test_list_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let defs = registry.list_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
        assert_eq!(defs[0].description, "Echoes the input text back");
        assert_eq!(defs[0].parameters["required"], serde_json::json!(["text"]));
    }

    #[test]
    fn test_list_names() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let names = registry.list_names();
        assert_eq!(names, vec!["echo"]);
    }

    #[test]
    fn test_get_risk_level_and_schema() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        assert!(matches!(
            registry.get_risk_level("echo"),
            Some(RiskLevel::ReadOnly)
        ));
        assert!(registry.get_risk_level("missing").is_none());

        assert_eq!(
            registry.get_parameters_schema("echo"),
            Some(EchoTool.parameters_schema())
        );
        assert!(registry.get_parameters_schema("missing").is_none());
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let result = registry
            .execute("echo", serde_json::json!({"text": "hello"}))
            .await
            .unwrap();
        assert_eq!(result.content, "Echo: hello");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let result = registry.execute("missing", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::NotFound { name } => assert_eq!(name, "missing"),
            e => panic!("Expected NotFound error, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_execute_invalid_args() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        // Missing required 'text' parameter: the tool's own error must pass through.
        let result = registry.execute("echo", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidArguments { name, .. } => assert_eq!(name, "echo"),
            e => panic!("Expected InvalidArguments error, got: {:?}", e),
        }
    }

    /// Waits for SlowTool's real 100 ms timeout to elapse.
    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool)).unwrap();

        let result = registry.execute("slow", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::Timeout { name, .. } => assert_eq!(name, "slow"),
            e => panic!("Expected Timeout error, got: {:?}", e),
        }
    }

    #[test]
    fn test_get_nonexistent() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").is_none());
    }
}