1use serde_json::Value;
7use std::collections::HashMap;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum ToolError {
13 #[error("tool not found: {0}")]
14 NotFound(String),
15
16 #[error("invalid input: {0}")]
17 InvalidInput(String),
18
19 #[error("execution failed: {0}")]
20 ExecutionFailed(String),
21
22 #[error(transparent)]
23 Io(#[from] std::io::Error),
24
25 #[error("serialization error: {0}")]
26 Serialization(String),
27}
28
29impl From<serde_json::Error> for ToolError {
30 fn from(e: serde_json::Error) -> Self {
31 ToolError::Serialization(e.to_string())
32 }
33}
34
/// Coarse grouping used to organize tools in listings and manifests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// The `"search"` category.
    Search,
    /// The `"edit"` category.
    Edit,
    /// The `"analysis"` category.
    Analysis,
    /// The `"utility"` category.
    Utility,
}

impl ToolCategory {
    /// Every category, in declaration order.
    ///
    /// Lets callers iterate over all categories without hard-coding the
    /// variant list (and its string forms) at each call site.
    pub const ALL: [ToolCategory; 4] = [
        ToolCategory::Search,
        ToolCategory::Edit,
        ToolCategory::Analysis,
        ToolCategory::Utility,
    ];

    /// Returns the lowercase string identifier for this category.
    pub const fn as_str(&self) -> &'static str {
        match self {
            ToolCategory::Search => "search",
            ToolCategory::Edit => "edit",
            ToolCategory::Analysis => "analysis",
            ToolCategory::Utility => "utility",
        }
    }
}

/// Formats the category as its [`ToolCategory::as_str`] identifier.
impl std::fmt::Display for ToolCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
58
59pub type ToolExecutor = fn(Value) -> Result<ToolOutput, ToolError>;
61
62#[derive(Clone)]
64pub struct ToolInfo {
65 pub name: &'static str,
67 pub description: &'static str,
69 pub category: ToolCategory,
71 pub example: &'static str,
73 pub execute: ToolExecutor,
75}
76
77#[derive(Debug, Clone)]
79pub struct ToolOutput {
80 pub success: bool,
82 pub result: Value,
84 pub summary: String,
86 pub duration_ms: u64,
88}
89
90impl ToolOutput {
91 pub fn success(result: Value, summary: impl Into<String>) -> Self {
93 ToolOutput {
94 success: true,
95 result,
96 summary: summary.into(),
97 duration_ms: 0,
98 }
99 }
100
101 pub fn failure(error: impl Into<String>) -> Self {
103 ToolOutput {
104 success: false,
105 result: Value::Null,
106 summary: error.into(),
107 duration_ms: 0,
108 }
109 }
110
111 pub const fn with_duration(mut self, duration_ms: u64) -> Self {
113 self.duration_ms = duration_ms;
114 self
115 }
116}
117
118pub struct ToolRegistry {
120 tools: HashMap<String, ToolInfo>,
121}
122
123impl ToolRegistry {
124 pub fn new() -> Self {
126 ToolRegistry {
127 tools: HashMap::new(),
128 }
129 }
130
131 pub fn register(&mut self, info: ToolInfo) {
133 self.tools.insert(info.name.to_string(), info);
134 }
135
136 pub fn list_tools(&self) -> Vec<&str> {
138 let mut tools: Vec<&str> = self.tools.keys().map(|s| s.as_str()).collect();
139 tools.sort();
140 tools
141 }
142
143 pub fn list_by_category(&self, category: ToolCategory) -> Vec<&str> {
145 let mut tools: Vec<&str> = self
146 .tools
147 .values()
148 .filter(|info| info.category == category)
149 .map(|info| info.name)
150 .collect();
151 tools.sort();
152 tools
153 }
154
155 pub fn get_info(&self, name: &str) -> Option<&ToolInfo> {
157 self.tools.get(name)
158 }
159
160 pub fn execute(&self, name: &str, input: Value) -> Result<ToolOutput, ToolError> {
162 let start = std::time::Instant::now();
163
164 let info = self
165 .tools
166 .get(name)
167 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
168
169 let mut output = (info.execute)(input)?;
170 output.duration_ms = start.elapsed().as_millis() as u64;
171
172 Ok(output)
173 }
174
175 pub fn get_manifest(&self) -> Value {
177 let tools: Vec<Value> = self
178 .tools
179 .values()
180 .map(|info| {
181 serde_json::json!({
182 "name": info.name,
183 "description": info.description,
184 "category": info.category.as_str(),
185 "example": info.example,
186 })
187 })
188 .collect();
189
190 serde_json::json!({
191 "version": "1.0",
192 "tools": tools,
193 "categories": ["search", "edit", "analysis", "utility"],
194 })
195 }
196}
197
198impl Default for ToolRegistry {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor stub that ignores its input and always succeeds.
    fn dummy_tool(_input: Value) -> Result<ToolOutput, ToolError> {
        Ok(ToolOutput::success(
            serde_json::json!({"test": "result"}),
            "Test completed",
        ))
    }

    /// Builds a `ToolInfo` backed by `dummy_tool` with an `"{}"` example.
    fn tool(name: &'static str, description: &'static str, category: ToolCategory) -> ToolInfo {
        ToolInfo {
            name,
            description,
            category,
            example: "{}",
            execute: dummy_tool,
        }
    }

    #[test]
    fn test_registry_basic() {
        let mut registry = ToolRegistry::new();
        registry.register(tool("test", "Test tool", ToolCategory::Utility));

        assert_eq!(registry.list_tools(), vec!["test"]);
        assert!(registry.get_info("test").is_some());
        assert!(registry.get_info("nonexistent").is_none());
    }

    #[test]
    fn test_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(tool("test", "Test tool", ToolCategory::Utility));

        let result = registry.execute("test", serde_json::json!({})).unwrap();
        assert!(result.success);
        assert_eq!(result.summary, "Test completed");
        assert_eq!(result.result, serde_json::json!({"test": "result"}));
    }

    #[test]
    fn test_registry_categories() {
        let mut registry = ToolRegistry::new();
        registry.register(tool("search", "Search tool", ToolCategory::Search));
        registry.register(tool("edit", "Edit tool", ToolCategory::Edit));

        assert_eq!(
            registry.list_by_category(ToolCategory::Search),
            vec!["search"]
        );
        assert_eq!(registry.list_by_category(ToolCategory::Edit), vec!["edit"]);
        assert!(registry.list_by_category(ToolCategory::Analysis).is_empty());
    }

    #[test]
    fn test_tool_not_found() {
        let registry = ToolRegistry::new();
        let result = registry.execute("nonexistent", serde_json::json!({}));

        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[test]
    fn test_manifest() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolInfo {
            example: r#"{"input": "test"}"#,
            ..tool("test", "Test tool", ToolCategory::Utility)
        });

        let manifest = registry.get_manifest();
        assert!(manifest["tools"].is_array());
        assert_eq!(manifest["tools"][0]["name"], "test");
        assert_eq!(manifest["tools"][0]["description"], "Test tool");
        assert_eq!(manifest["tools"][0]["category"], "utility");
        assert_eq!(manifest["tools"][0]["example"], r#"{"input": "test"}"#);
    }

    #[test]
    fn test_manifest_categories_and_version() {
        let manifest = ToolRegistry::new().get_manifest();

        assert_eq!(manifest["version"], "1.0");
        assert_eq!(manifest["tools"], serde_json::json!([]));
        assert_eq!(
            manifest["categories"],
            serde_json::json!(["search", "edit", "analysis", "utility"])
        );
    }

    #[test]
    fn test_list_tools_sorted() {
        let mut registry = ToolRegistry::new();
        registry.register(tool("zeta", "Z", ToolCategory::Utility));
        registry.register(tool("alpha", "A", ToolCategory::Utility));
        registry.register(tool("mid", "M", ToolCategory::Search));

        assert_eq!(registry.list_tools(), vec!["alpha", "mid", "zeta"]);
        assert_eq!(
            registry.list_by_category(ToolCategory::Utility),
            vec!["alpha", "zeta"]
        );
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(tool("dup", "first", ToolCategory::Search));
        registry.register(tool("dup", "second", ToolCategory::Edit));

        assert_eq!(registry.list_tools(), vec!["dup"]);
        let info = registry
            .get_info("dup")
            .expect("tool registered just above");
        assert_eq!(info.description, "second");
        assert_eq!(info.category, ToolCategory::Edit);
        assert!(registry.list_by_category(ToolCategory::Search).is_empty());
    }

    #[test]
    fn test_output_constructors() {
        let ok = ToolOutput::success(serde_json::json!(1), "done");
        assert!(ok.success);
        assert_eq!(ok.summary, "done");
        assert_eq!(ok.duration_ms, 0);

        let failed = ToolOutput::failure("boom");
        assert!(!failed.success);
        assert!(failed.result.is_null());
        assert_eq!(failed.summary, "boom");
        assert_eq!(failed.duration_ms, 0);
        assert_eq!(failed.with_duration(42).duration_ms, 42);
    }
}