1use crate::agent_tool::Tool;
4use crate::tool::ToolDef;
5use indexmap::IndexMap;
6
7struct ProxyTool {
10 def: ToolDef,
11}
12
13impl ProxyTool {
14 fn from_def(def: ToolDef) -> Self {
15 Self { def }
16 }
17}
18
19#[async_trait::async_trait]
20impl Tool for ProxyTool {
21 fn name(&self) -> &str {
22 &self.def.name
23 }
24 fn description(&self) -> &str {
25 &self.def.description
26 }
27 fn parameters_schema(&self) -> serde_json::Value {
28 self.def.parameters.clone()
29 }
30 async fn execute(
31 &self,
32 _: serde_json::Value,
33 _: &mut crate::context::AgentContext,
34 ) -> Result<crate::agent_tool::ToolOutput, crate::agent_tool::ToolError> {
35 Err(crate::agent_tool::ToolError::Execution(
36 "ProxyTool cannot execute — use the original registry".into(),
37 ))
38 }
39}
40
/// Reason why [`ToolRegistry::resolve`] could not return an active tool.
///
/// Both variants carry the tool name exactly as the caller supplied it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolveError {
    /// The name matches a deferred tool; its schema has to be loaded (via
    /// `tool_search`) before it can be resolved.
    Deferred(String),
    /// No active tool matched (exactly, case-insensitively, or fuzzily) and
    /// no deferred tool matched by name.
    NotFound(String),
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ResolveError::Deferred(name) => write!(
                f,
                "Tool '{name}' is deferred. Call tool_search to load its schema first."
            ),
            ResolveError::NotFound(name) => write!(f, "Tool '{name}' not found."),
        }
    }
}

impl std::error::Error for ResolveError {}
64
65pub struct ToolRegistry {
71 tools: IndexMap<String, Box<dyn Tool>>,
72 deferred: IndexMap<String, Box<dyn Tool>>,
75}
76
77impl ToolRegistry {
78 pub fn new() -> Self {
79 Self {
80 tools: IndexMap::new(),
81 deferred: IndexMap::new(),
82 }
83 }
84
85 pub fn register(mut self, tool: impl Tool + 'static) -> Self {
87 self.tools.insert(tool.name().to_string(), Box::new(tool));
88 self
89 }
90
91 pub fn add(&mut self, tool: impl Tool + 'static) {
93 self.tools.insert(tool.name().to_string(), Box::new(tool));
94 }
95
96 pub fn register_deferred(mut self, tool: impl Tool + 'static) -> Self {
100 self.deferred
101 .insert(tool.name().to_string(), Box::new(tool));
102 self
103 }
104
105 pub fn add_deferred(&mut self, tool: impl Tool + 'static) {
107 self.deferred
108 .insert(tool.name().to_string(), Box::new(tool));
109 }
110
111 pub fn promote_deferred(&mut self, name: &str) -> bool {
114 let lower = name.to_lowercase();
115 let key = self
116 .deferred
117 .keys()
118 .find(|k| k.to_lowercase() == lower)
119 .cloned();
120 if let Some(key) = key
121 && let Some(tool) = self.deferred.swap_remove(&key)
122 {
123 self.tools.insert(key, tool);
124 return true;
125 }
126 false
127 }
128
129 pub fn deferred_names(&self) -> Vec<&str> {
131 self.deferred.keys().map(|s| s.as_str()).collect()
132 }
133
134 pub fn is_deferred(&self, name: &str) -> bool {
136 let lower = name.to_lowercase();
137 self.deferred.keys().any(|k| k.to_lowercase() == lower)
138 }
139
140 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
142 let lower = name.to_lowercase();
143 self.tools
144 .iter()
145 .find(|(k, _)| k.to_lowercase() == lower)
146 .map(|(_, v)| v.as_ref())
147 }
148
149 pub fn list(&self) -> Vec<&dyn Tool> {
151 self.tools.values().map(|t| t.as_ref()).collect()
152 }
153
154 pub fn system_tools(&self) -> Vec<&dyn Tool> {
156 self.tools
157 .values()
158 .filter(|t| t.is_system())
159 .map(|t| t.as_ref())
160 .collect()
161 }
162
163 pub fn to_defs(&self) -> Vec<ToolDef> {
166 let mut defs: Vec<ToolDef> = self.tools.values().map(|t| t.to_def()).collect();
167 for tool in self.deferred.values() {
169 defs.push(ToolDef {
170 name: tool.name().to_string(),
171 description: tool.description().to_string(),
172 parameters: serde_json::json!({"type": "object", "properties": {}}),
173 });
174 }
175 defs
176 }
177
178 pub fn core_defs(&self) -> Vec<ToolDef> {
180 self.tools.values().map(|t| t.to_def()).collect()
181 }
182
183 pub fn resolve(&self, name: &str) -> Result<&dyn Tool, ResolveError> {
186 if let Some(t) = self.get(name) {
188 return Ok(t);
189 }
190 if self.is_deferred(name) {
192 return Err(ResolveError::Deferred(name.to_string()));
193 }
194 let lower = name.to_lowercase();
196 let mut best: Option<(&str, f64)> = None;
197 for key in self.tools.keys() {
198 let score = strsim::normalized_levenshtein(&lower, &key.to_lowercase());
199 if score > 0.6 && (best.is_none() || score > best.unwrap().1) {
200 best = Some((key.as_str(), score));
201 }
202 }
203 match best.and_then(|(k, _)| self.tools.get(k).map(|t| t.as_ref())) {
204 Some(t) => Ok(t),
205 None => Err(ResolveError::NotFound(name.to_string())),
206 }
207 }
208
209 pub fn len(&self) -> usize {
211 self.tools.len()
212 }
213
214 pub fn is_empty(&self) -> bool {
215 self.tools.is_empty()
216 }
217
218 pub fn filter(&self, names: &[String]) -> ToolRegistry {
221 self.clone_filtered(names)
227 }
228
229 fn clone_filtered(&self, names: &[String]) -> ToolRegistry {
231 let mut new_tools = IndexMap::new();
232 for name in names {
233 let lower = name.to_lowercase();
234 for (k, v) in &self.tools {
235 if k.to_lowercase() == lower {
236 new_tools.insert(k.clone(), ProxyTool::from_def(v.to_def()));
240 }
241 }
242 }
243 let mut reg = ToolRegistry::new();
244 for (_, tool) in new_tools {
245 reg.tools.insert(tool.def.name.clone(), Box::new(tool));
246 }
247 reg
248 }
249}
250
251impl Default for ToolRegistry {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::context::AgentContext;
    use serde_json::Value;

    /// Minimal tool for exercising the registry; `execute` always returns "ok".
    struct MockTool {
        tool_name: String,
        desc: String,
        system: bool,
    }

    impl MockTool {
        /// Builds a regular (non-system) mock tool.
        fn new(name: &str, desc: &str) -> Self {
            Self {
                tool_name: name.into(),
                desc: desc.into(),
                system: false,
            }
        }

        /// Builds a mock tool whose `is_system()` returns `true`.
        fn system(name: &str, desc: &str) -> Self {
            Self {
                system: true,
                ..Self::new(name, desc)
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.tool_name
        }
        fn description(&self) -> &str {
            &self.desc
        }
        fn is_system(&self) -> bool {
            self.system
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[test]
    fn registry_builder() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("read_file", "Read a file"))
            .register(MockTool::new("write_file", "Write a file"));
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn registry_get_case_insensitive() {
        let reg = ToolRegistry::new().register(MockTool::new("ReadFile", "Read"));
        assert!(reg.get("readfile").is_some());
        assert!(reg.get("READFILE").is_some());
        assert!(reg.get("ReadFile").is_some());
    }

    #[test]
    fn registry_list_preserves_order() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("alpha", "a"))
            .register(MockTool::new("beta", "b"))
            .register(MockTool::new("gamma", "c"));
        let names: Vec<_> = reg.list().iter().map(|t| t.name()).collect();
        assert_eq!(names, vec!["alpha", "beta", "gamma"]);
    }

    #[test]
    fn registry_system_tools() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("read_file", "Read"))
            .register(MockTool::system("finish", "Finish task"));
        let sys = reg.system_tools();
        assert_eq!(sys.len(), 1);
        assert_eq!(sys[0].name(), "finish");
    }

    #[test]
    fn registry_to_defs() {
        let reg = ToolRegistry::new().register(MockTool::new("bash", "Run command"));
        let defs = reg.to_defs();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "bash");
    }

    #[test]
    fn registry_fuzzy_resolve() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("read_file", "Read"))
            .register(MockTool::new("write_file", "Write"));
        // Exact name.
        assert_eq!(reg.resolve("read_file").unwrap().name(), "read_file");
        // One-character typo is close enough for the fuzzy step.
        assert_eq!(reg.resolve("reed_file").unwrap().name(), "read_file");
        // Nothing similar enough.
        assert!(reg.resolve("xyz").is_err());
    }

    #[test]
    fn registry_add_mutable() {
        let mut reg = ToolRegistry::new();
        reg.add(MockTool::new("tool_a", "A"));
        reg.add(MockTool::new("tool_b", "B"));
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn register_deferred_adds_to_deferred_map() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("active", "Active tool"))
            .register_deferred(MockTool::new("lazy", "Lazy tool"));
        assert_eq!(reg.len(), 1);
        assert!(reg.is_deferred("lazy"));
        assert!(!reg.is_deferred("active"));
        assert_eq!(reg.deferred_names(), vec!["lazy"]);
    }

    #[test]
    fn to_defs_returns_deferred_with_empty_params() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("active", "Active tool"))
            .register_deferred(MockTool::new("lazy", "Lazy tool"));
        let defs = reg.to_defs();
        assert_eq!(defs.len(), 2);
        // Active tools expose their full schema.
        let active_def = defs.iter().find(|d| d.name == "active").unwrap();
        assert!(active_def.parameters["type"] == "object");
        // Deferred tools keep their description but get an empty schema.
        let lazy_def = defs.iter().find(|d| d.name == "lazy").unwrap();
        assert_eq!(lazy_def.description, "Lazy tool");
        assert_eq!(lazy_def.parameters["properties"], serde_json::json!({}));
    }

    #[test]
    fn core_defs_excludes_deferred() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("active", "Active tool"))
            .register_deferred(MockTool::new("lazy", "Lazy tool"));
        let names: Vec<String> = reg.core_defs().into_iter().map(|d| d.name).collect();
        assert_eq!(names, vec!["active".to_string()]);
    }

    #[test]
    fn promote_deferred_moves_to_active() {
        let mut reg = ToolRegistry::new().register_deferred(MockTool::new("lazy", "Lazy tool"));
        assert_eq!(reg.len(), 0);
        assert!(reg.is_deferred("lazy"));

        let promoted = reg.promote_deferred("lazy");
        assert!(promoted);
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_deferred("lazy"));
        assert!(reg.get("lazy").is_some());
    }

    #[test]
    fn promote_deferred_is_case_insensitive() {
        let mut reg = ToolRegistry::new().register_deferred(MockTool::new("lazy", "Lazy tool"));
        assert!(reg.promote_deferred("LAZY"));
        assert!(reg.get("lazy").is_some());
        assert!(!reg.is_deferred("lazy"));
    }

    #[test]
    fn promote_deferred_not_found() {
        let mut reg = ToolRegistry::new();
        assert!(!reg.promote_deferred("ghost"));
    }

    #[test]
    fn resolve_deferred_returns_error() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("active", "Active"))
            .register_deferred(MockTool::new("lazy", "Lazy"));
        assert!(reg.resolve("active").is_ok());
        let Err(err) = reg.resolve("lazy") else {
            panic!("expected Deferred error");
        };
        assert!(matches!(err, ResolveError::Deferred(_)));
        assert!(err.to_string().contains("tool_search"));
    }

    #[test]
    fn resolve_unknown_returns_not_found() {
        let reg = ToolRegistry::new().register(MockTool::new("bash", "Run command"));
        let Err(err) = reg.resolve("xyz") else {
            panic!("expected NotFound error");
        };
        assert!(matches!(&err, ResolveError::NotFound(n) if n == "xyz"));
        assert_eq!(err.to_string(), "Tool 'xyz' not found.");
    }

    #[test]
    fn resolve_deferred_after_promote() {
        let mut reg = ToolRegistry::new().register_deferred(MockTool::new("lazy", "Lazy"));
        assert!(reg.resolve("lazy").is_err());
        reg.promote_deferred("lazy");
        assert!(reg.resolve("lazy").is_ok());
    }

    #[test]
    fn filter_keeps_named_tools_case_insensitively() {
        let reg = ToolRegistry::new()
            .register(MockTool::new("read_file", "Read"))
            .register(MockTool::new("write_file", "Write"))
            .register(MockTool::new("bash", "Run"));
        let filtered = reg.filter(&["BASH".to_string(), "read_file".to_string()]);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.get("bash").is_some());
        assert!(filtered.get("write_file").is_none());
        let read = filtered.get("read_file").unwrap();
        assert_eq!(read.parameters_schema()["type"], "object");
    }
}