Skip to main content

ai_agents_tools/
registry.rs

use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::fmt::Write as _;
use std::sync::Arc;

use ai_agents_core::{Tool, ToolInfo};

use super::ToolError;
use super::provider::{ProviderHealth, ToolProvider, ToolProviderError};
use super::types::ToolAliases;
10
11#[derive(Clone)]
12enum ToolRef {
13    Builtin(Arc<dyn Tool>),
14    Provider {
15        provider_id: String,
16        tool: Arc<dyn Tool>,
17    },
18}
19
20pub struct ToolRegistry {
21    builtin_tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22
23    providers: RwLock<HashMap<String, Arc<dyn ToolProvider>>>,
24
25    tool_index: RwLock<HashMap<String, ToolRef>>,
26
27    alias_index: RwLock<HashMap<String, String>>,
28
29    builtin_aliases: RwLock<HashMap<String, ToolAliases>>,
30}
31
32impl ToolRegistry {
33    pub fn new() -> Self {
34        Self {
35            builtin_tools: RwLock::new(HashMap::new()),
36            providers: RwLock::new(HashMap::new()),
37            tool_index: RwLock::new(HashMap::new()),
38            alias_index: RwLock::new(HashMap::new()),
39            builtin_aliases: RwLock::new(HashMap::new()),
40        }
41    }
42
43    pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolError> {
44        let id = tool.id().to_string();
45
46        let mut builtin_tools = self.builtin_tools.write();
47        let mut tool_index = self.tool_index.write();
48
49        if builtin_tools.contains_key(&id) || tool_index.contains_key(&id) {
50            return Err(ToolError::Duplicate(id));
51        }
52
53        tool_index.insert(id.clone(), ToolRef::Builtin(tool.clone()));
54        builtin_tools.insert(id, tool);
55        Ok(())
56    }
57
58    pub fn get(&self, id_or_alias: &str) -> Option<Arc<dyn Tool>> {
59        let tool_index = self.tool_index.read();
60
61        // Try exact match first
62        if let Some(tool_ref) = tool_index.get(id_or_alias) {
63            return self.resolve_tool_ref(tool_ref);
64        }
65
66        // Try case-insensitive match on tool IDs and display names
67        // (LLM may return display name like "HTTP Client" but ID is "http")
68        let lower_input = id_or_alias.to_lowercase();
69        for (id, tool_ref) in tool_index.iter() {
70            if id.to_lowercase() == lower_input {
71                return self.resolve_tool_ref(tool_ref);
72            }
73            if let Some(tool) = self.resolve_tool_ref(tool_ref) {
74                if tool.name().to_lowercase() == lower_input {
75                    return Some(tool);
76                }
77            }
78        }
79
80        // Try alias lookup (case-insensitive)
81        let alias_index = self.alias_index.read();
82        for (alias_key, tool_id) in alias_index.iter() {
83            if alias_key.ends_with(&format!(":{}", lower_input)) {
84                if let Some(tool_ref) = tool_index.get(tool_id) {
85                    return self.resolve_tool_ref(tool_ref);
86                }
87            }
88        }
89
90        None
91    }
92
93    fn resolve_tool_ref(&self, tool_ref: &ToolRef) -> Option<Arc<dyn Tool>> {
94        match tool_ref {
95            ToolRef::Builtin(tool) => Some(tool.clone()),
96            ToolRef::Provider { tool, .. } => Some(tool.clone()),
97        }
98    }
99
100    pub fn list_ids(&self) -> Vec<String> {
101        self.tool_index.read().keys().cloned().collect()
102    }
103
104    pub fn list_infos(&self) -> Vec<ToolInfo> {
105        let tool_index = self.tool_index.read();
106        let mut infos = Vec::with_capacity(tool_index.len());
107
108        for tool_ref in tool_index.values() {
109            if let Some(tool) = self.resolve_tool_ref(tool_ref) {
110                infos.push(tool.info());
111            }
112        }
113
114        infos
115    }
116
117    pub fn len(&self) -> usize {
118        self.tool_index.read().len()
119    }
120
121    pub fn is_empty(&self) -> bool {
122        self.tool_index.read().is_empty()
123    }
124
125    pub async fn register_provider(
126        &self,
127        provider: Arc<dyn ToolProvider>,
128    ) -> Result<(), ToolError> {
129        let provider_id = provider.id().to_string();
130
131        {
132            let providers = self.providers.read();
133            if providers.contains_key(&provider_id) {
134                return Err(ToolError::Duplicate(format!("Provider: {}", provider_id)));
135            }
136        }
137
138        let tools = provider.list_tools().await;
139
140        {
141            let mut tool_index = self.tool_index.write();
142            let mut alias_index = self.alias_index.write();
143
144            for descriptor in &tools {
145                if tool_index.contains_key(&descriptor.id) {
146                    return Err(ToolError::Duplicate(descriptor.id.clone()));
147                }
148
149                if let Some(tool) = provider.get_tool(&descriptor.id).await {
150                    tool_index.insert(
151                        descriptor.id.clone(),
152                        ToolRef::Provider {
153                            provider_id: provider_id.clone(),
154                            tool,
155                        },
156                    );
157
158                    if let Some(ref aliases) = descriptor.aliases {
159                        for (lang, name) in &aliases.names {
160                            let key = format!("{}:{}", lang, name.to_lowercase());
161                            alias_index.insert(key, descriptor.id.clone());
162                        }
163                    }
164                }
165            }
166        }
167
168        self.providers.write().insert(provider_id, provider);
169
170        Ok(())
171    }
172
173    pub fn unregister_provider(&self, provider_id: &str) -> bool {
174        let removed = self.providers.write().remove(provider_id);
175
176        if removed.is_some() {
177            let mut tool_index = self.tool_index.write();
178            let mut alias_index = self.alias_index.write();
179
180            let tools_to_remove: Vec<String> = tool_index
181                .iter()
182                .filter_map(|(id, tool_ref)| {
183                    if let ToolRef::Provider {
184                        provider_id: pid, ..
185                    } = tool_ref
186                    {
187                        if pid == provider_id {
188                            return Some(id.clone());
189                        }
190                    }
191                    None
192                })
193                .collect();
194
195            for tool_id in &tools_to_remove {
196                tool_index.remove(tool_id);
197            }
198
199            alias_index.retain(|_, tool_id| !tools_to_remove.contains(tool_id));
200
201            true
202        } else {
203            false
204        }
205    }
206
207    pub fn set_tool_aliases(&self, tool_id: &str, aliases: ToolAliases) {
208        if !self.tool_index.read().contains_key(tool_id) {
209            return;
210        }
211
212        {
213            let mut alias_index = self.alias_index.write();
214            for (lang, name) in &aliases.names {
215                let key = format!("{}:{}", lang, name.to_lowercase());
216                alias_index.insert(key, tool_id.to_string());
217            }
218        }
219
220        self.builtin_aliases
221            .write()
222            .insert(tool_id.to_string(), aliases);
223    }
224
225    pub fn get_by_alias(&self, alias: &str, lang: &str) -> Option<Arc<dyn Tool>> {
226        let key = format!("{}:{}", lang, alias.to_lowercase());
227        let alias_index = self.alias_index.read();
228
229        if let Some(tool_id) = alias_index.get(&key) {
230            return self.get(tool_id);
231        }
232
233        None
234    }
235
236    pub fn list_providers(&self) -> Vec<String> {
237        self.providers.read().keys().cloned().collect()
238    }
239
240    pub async fn provider_health(&self, provider_id: &str) -> Option<ProviderHealth> {
241        let providers = self.providers.read();
242        if let Some(provider) = providers.get(provider_id) {
243            Some(provider.health_check().await)
244        } else {
245            None
246        }
247    }
248
249    pub async fn refresh_provider(&self, provider_id: &str) -> Result<(), ToolProviderError> {
250        let provider = {
251            let providers = self.providers.read();
252            providers.get(provider_id).cloned()
253        };
254
255        if let Some(provider) = provider {
256            if provider.supports_refresh() {
257                provider.refresh().await?;
258
259                let tools = provider.list_tools().await;
260
261                let mut tool_index = self.tool_index.write();
262                let mut alias_index = self.alias_index.write();
263
264                let old_tools: Vec<String> = tool_index
265                    .iter()
266                    .filter_map(|(id, tool_ref)| {
267                        if let ToolRef::Provider {
268                            provider_id: pid, ..
269                        } = tool_ref
270                        {
271                            if pid == provider_id {
272                                return Some(id.clone());
273                            }
274                        }
275                        None
276                    })
277                    .collect();
278
279                for tool_id in &old_tools {
280                    tool_index.remove(tool_id);
281                }
282                alias_index.retain(|_, tool_id| !old_tools.contains(tool_id));
283
284                for descriptor in &tools {
285                    if let Some(tool) = provider.get_tool(&descriptor.id).await {
286                        tool_index.insert(
287                            descriptor.id.clone(),
288                            ToolRef::Provider {
289                                provider_id: provider_id.to_string(),
290                                tool,
291                            },
292                        );
293
294                        if let Some(ref aliases) = descriptor.aliases {
295                            for (lang, name) in &aliases.names {
296                                let key = format!("{}:{}", lang, name.to_lowercase());
297                                alias_index.insert(key, descriptor.id.clone());
298                            }
299                        }
300                    }
301                }
302            }
303            Ok(())
304        } else {
305            Err(ToolProviderError::ToolNotFound(format!(
306                "Provider not found: {}",
307                provider_id
308            )))
309        }
310    }
311
312    pub fn generate_tools_prompt(&self) -> String {
313        self.generate_tools_prompt_with_lang(None, false)
314    }
315
316    pub fn generate_tools_prompt_with_parallel(&self, parallel: bool) -> String {
317        self.generate_tools_prompt_with_lang(None, parallel)
318    }
319
320    pub fn generate_tools_prompt_with_lang(
321        &self,
322        language: Option<&str>,
323        parallel: bool,
324    ) -> String {
325        let tool_index = self.tool_index.read();
326        if tool_index.is_empty() {
327            return String::new();
328        }
329
330        let builtin_aliases = self.builtin_aliases.read();
331        let mut prompt = String::from("Available tools:\n");
332
333        for (id, tool_ref) in tool_index.iter() {
334            if let Some(tool) = self.resolve_tool_ref(tool_ref) {
335                let (name, description) = if let Some(lang) = language {
336                    if let Some(aliases) = builtin_aliases.get(id) {
337                        let name = aliases
338                            .names
339                            .get(lang)
340                            .map(|s| s.as_str())
341                            .unwrap_or_else(|| tool.name());
342                        let desc = aliases
343                            .descriptions
344                            .get(lang)
345                            .map(|s| s.as_str())
346                            .unwrap_or_else(|| tool.description());
347                        (name, desc)
348                    } else {
349                        (tool.name(), tool.description())
350                    }
351                } else {
352                    (tool.name(), tool.description())
353                };
354
355                let schema = tool.input_schema();
356                let args_desc = if let Some(props) = schema.get("properties") {
357                    serde_json::to_string(props).unwrap_or_default()
358                } else {
359                    "{}".to_string()
360                };
361
362                prompt.push_str(&format!(
363                    "- {}: {}. Arguments: {}\n",
364                    name, description, args_desc
365                ));
366            }
367        }
368
369        Self::append_tool_format_instructions(&mut prompt, parallel);
370
371        prompt
372    }
373
374    pub fn generate_filtered_prompt(&self, tool_ids: &[String]) -> String {
375        self.generate_filtered_prompt_with_lang(tool_ids, None, false)
376    }
377
378    pub fn generate_filtered_prompt_with_parallel(
379        &self,
380        tool_ids: &[String],
381        parallel: bool,
382    ) -> String {
383        self.generate_filtered_prompt_with_lang(tool_ids, None, parallel)
384    }
385
386    pub fn generate_filtered_prompt_with_lang(
387        &self,
388        tool_ids: &[String],
389        language: Option<&str>,
390        parallel: bool,
391    ) -> String {
392        if tool_ids.is_empty() {
393            return self.generate_tools_prompt_with_lang(language, parallel);
394        }
395
396        let tool_index = self.tool_index.read();
397        let builtin_aliases = self.builtin_aliases.read();
398        let mut prompt = String::from("Available tools:\n");
399        let mut found_any = false;
400
401        for id in tool_ids {
402            if let Some(tool_ref) = tool_index.get(id) {
403                if let Some(tool) = self.resolve_tool_ref(tool_ref) {
404                    found_any = true;
405
406                    let (name, description) = if let Some(lang) = language {
407                        if let Some(aliases) = builtin_aliases.get(id) {
408                            let name = aliases
409                                .names
410                                .get(lang)
411                                .map(|s| s.as_str())
412                                .unwrap_or_else(|| tool.name());
413                            let desc = aliases
414                                .descriptions
415                                .get(lang)
416                                .map(|s| s.as_str())
417                                .unwrap_or_else(|| tool.description());
418                            (name, desc)
419                        } else {
420                            (tool.name(), tool.description())
421                        }
422                    } else {
423                        (tool.name(), tool.description())
424                    };
425
426                    let schema = tool.input_schema();
427                    let args_desc = if let Some(props) = schema.get("properties") {
428                        serde_json::to_string(props).unwrap_or_default()
429                    } else {
430                        "{}".to_string()
431                    };
432
433                    prompt.push_str(&format!(
434                        "- {}: {}. Arguments: {}\n",
435                        name, description, args_desc
436                    ));
437                }
438            }
439        }
440
441        if !found_any {
442            return String::new();
443        }
444
445        Self::append_tool_format_instructions(&mut prompt, parallel);
446
447        prompt
448    }
449
450    /// Append tool call format instructions to a prompt.
451    /// When `parallel` is true, also instructs the LLM to use a JSON array
452    /// for multiple simultaneous tool calls.
453    fn append_tool_format_instructions(prompt: &mut String, parallel: bool) {
454        prompt.push_str(
455            "\nWhen you need to use a tool, respond ONLY with valid JSON in this exact format:\n",
456        );
457        prompt.push_str("{\"tool\": \"tool_name\", \"arguments\": {...}}\n");
458        prompt.push_str("The \"tool\" value MUST be one of the exact tool names listed above. Do not invent tool names.\n");
459        if parallel {
460            prompt.push_str(
461                "\nWhen you need to call multiple tools at once, respond with a JSON array:\n",
462            );
463            prompt.push_str(
464                "[{\"tool\": \"tool_name1\", \"arguments\": {...}}, {\"tool\": \"tool_name2\", \"arguments\": {...}}]\n",
465            );
466        }
467        prompt.push_str("\nWhen you receive a tool result, summarize it naturally for the user.\n");
468        prompt.push_str("If no tool is needed, respond normally.");
469    }
470}
471
472impl Default for ToolRegistry {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolResult;
    use async_trait::async_trait;
    use serde_json::Value;

    /// Tool stub with a configurable ID and display name and a fixed description and
    /// schema.
    struct TestTool {
        id: String,
        name: String,
    }

    impl TestTool {
        /// Stub whose display name is the shared default "Test".
        fn new(id: &str) -> Arc<Self> {
            Self::named(id, "Test")
        }

        /// Stub with a distinct display name, so prompts can tell tools apart.
        fn named(id: &str, name: &str) -> Arc<Self> {
            Arc::new(Self {
                id: id.to_string(),
                name: name.to_string(),
            })
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn id(&self) -> &str {
            &self.id
        }
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn input_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _args: Value) -> ToolResult {
            ToolResult::ok("test")
        }
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("test")).unwrap();
        assert!(registry.get("test").is_some());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_default_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_duplicate_registration() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("test")).unwrap();
        assert!(registry.register(TestTool::new("test")).is_err());
    }

    #[test]
    fn test_list_ids() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("a")).unwrap();
        registry.register(TestTool::new("b")).unwrap();

        let ids = registry.list_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"a".to_string()));
        assert!(ids.contains(&"b".to_string()));
    }

    #[test]
    fn test_get_case_insensitive_id() {
        let mut registry = ToolRegistry::new();
        registry
            .register(TestTool::named("calculator", "Calc"))
            .unwrap();

        let found = registry.get("CALCULATOR").expect("case-insensitive ID lookup");
        assert_eq!(found.id(), "calculator");
    }

    #[test]
    fn test_get_by_display_name() {
        let mut registry = ToolRegistry::new();
        registry
            .register(TestTool::named("http", "HTTP Client"))
            .unwrap();

        let found = registry.get("http client").expect("display-name lookup");
        assert_eq!(found.id(), "http");
    }

    #[test]
    fn test_get_resolves_alias_in_any_language() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("calculator")).unwrap();
        registry.set_tool_aliases("calculator", ToolAliases::new().with_name("ko", "계산기"));

        assert_eq!(
            registry.get("계산기").map(|t| t.id().to_string()),
            Some("calculator".to_string())
        );
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("test")).unwrap();
        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn test_generate_tools_prompt() {
        let empty_registry = ToolRegistry::new();
        assert!(empty_registry.generate_tools_prompt().is_empty());

        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("test")).unwrap();

        let prompt = registry.generate_tools_prompt();
        assert!(prompt.contains("Available tools:"));
        assert!(prompt.contains("Test:"));
        assert!(prompt.contains("A test tool"));
        assert!(prompt.contains("tool_name"));
    }

    #[test]
    fn test_generate_filtered_prompt_with_filter() {
        // Distinct display names: prompts show names, not IDs, so identical names
        // would make the exclusion check below meaningless.
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::named("tool_a", "Alpha")).unwrap();
        registry.register(TestTool::named("tool_b", "Beta")).unwrap();
        registry.register(TestTool::named("tool_c", "Gamma")).unwrap();

        let prompt =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "tool_c".to_string()]);

        assert!(prompt.contains("- Alpha:"));
        assert!(prompt.contains("- Gamma:"));
        assert!(!prompt.contains("Beta"));
    }

    #[test]
    fn test_generate_filtered_prompt_empty_filter() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("tool_a")).unwrap();
        registry.register(TestTool::new("tool_b")).unwrap();

        let prompt = registry.generate_filtered_prompt(&[]);
        assert!(prompt.contains("Test"));
    }

    #[test]
    fn test_generate_filtered_prompt_nonexistent_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("tool_a")).unwrap();

        let prompt = registry.generate_filtered_prompt(&["nonexistent".to_string()]);
        assert!(prompt.is_empty());

        let prompt2 =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "nonexistent".to_string()]);
        assert!(prompt2.contains("Test"));
    }

    #[test]
    fn test_set_tool_aliases() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("calculator")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "계산기")
            .with_name("ja", "計算機")
            .with_description("ko", "수학 계산을 합니다");

        registry.set_tool_aliases("calculator", aliases);

        assert!(registry.get_by_alias("계산기", "ko").is_some());
        assert!(registry.get_by_alias("計算機", "ja").is_some());
        assert!(registry.get("calculator").is_some());
    }

    #[test]
    fn test_set_tool_aliases_unknown_tool_is_noop() {
        let registry = ToolRegistry::new();
        registry.set_tool_aliases("missing", ToolAliases::new().with_name("ko", "없음"));
        assert!(registry.get_by_alias("없음", "ko").is_none());
    }

    #[test]
    fn test_get_by_alias_case_insensitive() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("search")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "검색")
            .with_name("en", "Web Search");
        registry.set_tool_aliases("search", aliases);

        assert!(registry.get_by_alias("검색", "ko").is_some());
        // Hangul has no case, so check case folding with a Latin alias.
        assert!(registry.get_by_alias("WEB SEARCH", "en").is_some());
        assert!(registry.get_by_alias("web search", "ko").is_none());
    }

    #[test]
    fn test_unregister_unknown_provider() {
        let registry = ToolRegistry::new();
        assert!(!registry.unregister_provider("nope"));
        assert!(registry.list_providers().is_empty());
    }

    #[test]
    fn test_generate_prompt_with_language() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("calculator")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "계산기")
            .with_description("ko", "수학 계산");

        registry.set_tool_aliases("calculator", aliases);

        let prompt_en = registry.generate_tools_prompt_with_lang(None, false);
        assert!(prompt_en.contains("Test"));

        let prompt_ko = registry.generate_tools_prompt_with_lang(Some("ko"), false);
        assert!(prompt_ko.contains("계산기"));
        assert!(prompt_ko.contains("수학 계산"));
    }

    #[test]
    fn test_generate_tools_prompt_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("tool_a")).unwrap();
        registry.register(TestTool::new("tool_b")).unwrap();

        // Without parallel: no array instruction
        let prompt_seq = registry.generate_tools_prompt();
        assert!(prompt_seq.contains("\"tool\": \"tool_name\""));
        assert!(!prompt_seq.contains("JSON array"));
        assert!(!prompt_seq.contains("tool_name1"));

        // With parallel: array instruction present
        let prompt_par = registry.generate_tools_prompt_with_parallel(true);
        assert!(prompt_par.contains("\"tool\": \"tool_name\""));
        assert!(prompt_par.contains("JSON array"));
        assert!(prompt_par.contains("tool_name1"));
        assert!(prompt_par.contains("tool_name2"));
    }

    #[test]
    fn test_generate_filtered_prompt_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::new("tool_a")).unwrap();
        registry.register(TestTool::new("tool_b")).unwrap();

        // Filtered without parallel
        let prompt_seq =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "tool_b".to_string()]);
        assert!(!prompt_seq.contains("JSON array"));

        // Filtered with parallel
        let prompt_par = registry.generate_filtered_prompt_with_parallel(
            &["tool_a".to_string(), "tool_b".to_string()],
            true,
        );
        assert!(prompt_par.contains("JSON array"));
        assert!(prompt_par.contains("tool_name1"));
    }
}