1use std::collections::HashMap;
5use crate::error::AgentError;
6use crate::types::ToolResult;
7
8pub type ToolFn = Box<dyn Fn(&str) -> ToolResult>;
13
/// Public description of a tool: the name it is dispatched under, a
/// human-readable description, and its input schema.
///
/// Nothing in this type parses or validates the schema; it is stored verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolSpec {
    /// Name the tool is registered and dispatched under.
    pub name: String,
    /// Human-readable description, listed by the registry's tools prompt.
    pub description: String,
    /// Input schema as raw text (presumably JSON, e.g. `{"type":"string"}` — confirm with callers).
    pub input_schema: String,
}

impl ToolSpec {
    /// Builds a spec from anything convertible into `String`.
    ///
    /// `schema` is stored unchanged in `input_schema`. No validation happens
    /// here; an empty `name` is only rejected when the spec is registered.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        schema: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            input_schema: schema.into(),
        }
    }
}
36
37pub struct ToolRegistry {
39 specs: HashMap<String, ToolSpec>,
40 handlers: HashMap<String, ToolFn>,
41}
42
43impl ToolRegistry {
44 pub fn new() -> Self {
46 Self { specs: HashMap::new(), handlers: HashMap::new() }
47 }
48
49 pub fn register(&mut self, spec: ToolSpec, handler: ToolFn) -> Result<(), AgentError> {
54 if spec.name.is_empty() {
55 return Err(AgentError::InvalidToolSignature("tool name cannot be empty".into()));
56 }
57 self.specs.insert(spec.name.clone(), spec.clone());
58 self.handlers.insert(spec.name, handler);
59 Ok(())
60 }
61
62 pub fn dispatch(&self, tool_name: &str, input: &str) -> Result<ToolResult, AgentError> {
67 let handler = self.handlers.get(tool_name)
68 .ok_or_else(|| AgentError::ToolNotFound { name: tool_name.to_string() })?;
69 Ok(handler(input))
70 }
71
72 pub fn spec(&self, name: &str) -> Option<&ToolSpec> { self.specs.get(name) }
74
75 pub fn tool_count(&self) -> usize { self.specs.len() }
77
78 pub fn tool_names(&self) -> Vec<&str> { self.specs.keys().map(|s| s.as_str()).collect() }
80
81 pub fn tools_prompt(&self) -> String {
83 let mut lines = vec!["Available tools:".to_string()];
84 for spec in self.specs.values() {
85 lines.push(format!("- {}: {}", spec.name, spec.description));
86 }
87 lines.join("\n")
88 }
89}
90
91impl Default for ToolRegistry {
92 fn default() -> Self { Self::new() }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `echo` tool used throughout these tests.
    fn echo_tool() -> (ToolSpec, ToolFn) {
        let spec = ToolSpec::new("echo", "Echoes input back", r#"{"type":"string"}"#);
        let handler: ToolFn = Box::new(|input: &str| ToolResult {
            tool_name: "echo".into(),
            output: format!("echo: {input}"),
            success: true,
        });
        (spec, handler)
    }

    /// Returns a registry containing only the `echo` tool.
    fn registry_with_echo() -> ToolRegistry {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        reg
    }

    #[test]
    fn test_registry_register_and_dispatch_ok() {
        let reg = registry_with_echo();
        let result = reg.dispatch("echo", "hello").unwrap();
        assert!(result.success);
        assert_eq!(result.output, "echo: hello");
    }

    #[test]
    fn test_registry_dispatch_unknown_tool_returns_error() {
        let reg = ToolRegistry::new();
        let err = reg.dispatch("nonexistent", "").unwrap_err();
        assert!(matches!(err, AgentError::ToolNotFound { .. }));
    }

    #[test]
    fn test_registry_dispatch_unknown_tool_reports_requested_name() {
        let reg = registry_with_echo();
        match reg.dispatch("nonexistent", "") {
            Err(AgentError::ToolNotFound { name }) => assert_eq!(name, "nonexistent"),
            other => panic!("expected ToolNotFound, got {other:?}"),
        }
    }

    #[test]
    fn test_registry_register_empty_name_returns_error() {
        let mut reg = ToolRegistry::new();
        let spec = ToolSpec::new("", "bad", "{}");
        let err = reg
            .register(
                spec,
                Box::new(|_| ToolResult {
                    tool_name: "".into(),
                    output: "".into(),
                    success: false,
                }),
            )
            .unwrap_err();
        assert!(matches!(err, AgentError::InvalidToolSignature(_)));
        assert_eq!(reg.tool_count(), 0);
    }

    #[test]
    fn test_registry_reregister_replaces_spec_and_handler() {
        let mut reg = registry_with_echo();
        let spec = ToolSpec::new("echo", "Echoes input twice", r#"{"type":"string"}"#);
        reg.register(
            spec,
            Box::new(|input: &str| ToolResult {
                tool_name: "echo".into(),
                output: format!("{input}{input}"),
                success: true,
            }),
        )
        .unwrap();
        assert_eq!(reg.tool_count(), 1);
        assert_eq!(reg.spec("echo").unwrap().description, "Echoes input twice");
        assert_eq!(reg.dispatch("echo", "ab").unwrap().output, "abab");
    }

    #[test]
    fn test_registry_tool_count_increments() {
        let mut reg = ToolRegistry::new();
        assert_eq!(reg.tool_count(), 0);
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        assert_eq!(reg.tool_count(), 1);
    }

    #[test]
    fn test_registry_tools_prompt_contains_tool_name() {
        let reg = registry_with_echo();
        assert!(reg.tools_prompt().contains("echo"));
    }

    #[test]
    fn test_registry_spec_retrieval_present_and_absent() {
        let reg = registry_with_echo();
        assert!(reg.spec("echo").is_some());
        assert!(reg.spec("missing").is_none());
    }

    #[test]
    fn test_registry_tool_names_lists_all() {
        let reg = registry_with_echo();
        let names = reg.tool_names();
        assert!(names.contains(&"echo"));
    }

    #[test]
    fn test_registry_default_is_empty() {
        let reg = ToolRegistry::default();
        assert_eq!(reg.tool_count(), 0);
    }
}