use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use ai_agents_core::{Tool, ToolInfo};

use super::ToolError;
use super::provider::{ProviderHealth, ToolProvider, ToolProviderError};
use super::types::ToolAliases;
10
11#[derive(Clone)]
12enum ToolRef {
13 Builtin(Arc<dyn Tool>),
14 Provider {
15 provider_id: String,
16 tool: Arc<dyn Tool>,
17 },
18}
19
20pub struct ToolRegistry {
21 builtin_tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22
23 providers: RwLock<HashMap<String, Arc<dyn ToolProvider>>>,
24
25 tool_index: RwLock<HashMap<String, ToolRef>>,
26
27 alias_index: RwLock<HashMap<String, String>>,
28
29 builtin_aliases: RwLock<HashMap<String, ToolAliases>>,
30}
31
32impl ToolRegistry {
33 pub fn new() -> Self {
34 Self {
35 builtin_tools: RwLock::new(HashMap::new()),
36 providers: RwLock::new(HashMap::new()),
37 tool_index: RwLock::new(HashMap::new()),
38 alias_index: RwLock::new(HashMap::new()),
39 builtin_aliases: RwLock::new(HashMap::new()),
40 }
41 }
42
43 pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolError> {
44 let id = tool.id().to_string();
45
46 let mut builtin_tools = self.builtin_tools.write();
47 let mut tool_index = self.tool_index.write();
48
49 if builtin_tools.contains_key(&id) || tool_index.contains_key(&id) {
50 return Err(ToolError::Duplicate(id));
51 }
52
53 tool_index.insert(id.clone(), ToolRef::Builtin(tool.clone()));
54 builtin_tools.insert(id, tool);
55 Ok(())
56 }
57
58 pub fn get(&self, id_or_alias: &str) -> Option<Arc<dyn Tool>> {
59 let tool_index = self.tool_index.read();
60
61 if let Some(tool_ref) = tool_index.get(id_or_alias) {
63 return self.resolve_tool_ref(tool_ref);
64 }
65
66 let lower_input = id_or_alias.to_lowercase();
69 for (id, tool_ref) in tool_index.iter() {
70 if id.to_lowercase() == lower_input {
71 return self.resolve_tool_ref(tool_ref);
72 }
73 if let Some(tool) = self.resolve_tool_ref(tool_ref) {
74 if tool.name().to_lowercase() == lower_input {
75 return Some(tool);
76 }
77 }
78 }
79
80 let alias_index = self.alias_index.read();
82 for (alias_key, tool_id) in alias_index.iter() {
83 if alias_key.ends_with(&format!(":{}", lower_input)) {
84 if let Some(tool_ref) = tool_index.get(tool_id) {
85 return self.resolve_tool_ref(tool_ref);
86 }
87 }
88 }
89
90 None
91 }
92
93 fn resolve_tool_ref(&self, tool_ref: &ToolRef) -> Option<Arc<dyn Tool>> {
94 match tool_ref {
95 ToolRef::Builtin(tool) => Some(tool.clone()),
96 ToolRef::Provider { tool, .. } => Some(tool.clone()),
97 }
98 }
99
100 pub fn list_ids(&self) -> Vec<String> {
101 self.tool_index.read().keys().cloned().collect()
102 }
103
104 pub fn list_infos(&self) -> Vec<ToolInfo> {
105 let tool_index = self.tool_index.read();
106 let mut infos = Vec::with_capacity(tool_index.len());
107
108 for tool_ref in tool_index.values() {
109 if let Some(tool) = self.resolve_tool_ref(tool_ref) {
110 infos.push(tool.info());
111 }
112 }
113
114 infos
115 }
116
117 pub fn len(&self) -> usize {
118 self.tool_index.read().len()
119 }
120
121 pub fn is_empty(&self) -> bool {
122 self.tool_index.read().is_empty()
123 }
124
125 pub async fn register_provider(
126 &self,
127 provider: Arc<dyn ToolProvider>,
128 ) -> Result<(), ToolError> {
129 let provider_id = provider.id().to_string();
130
131 {
132 let providers = self.providers.read();
133 if providers.contains_key(&provider_id) {
134 return Err(ToolError::Duplicate(format!("Provider: {}", provider_id)));
135 }
136 }
137
138 let tools = provider.list_tools().await;
139
140 {
141 let mut tool_index = self.tool_index.write();
142 let mut alias_index = self.alias_index.write();
143
144 for descriptor in &tools {
145 if tool_index.contains_key(&descriptor.id) {
146 return Err(ToolError::Duplicate(descriptor.id.clone()));
147 }
148
149 if let Some(tool) = provider.get_tool(&descriptor.id).await {
150 tool_index.insert(
151 descriptor.id.clone(),
152 ToolRef::Provider {
153 provider_id: provider_id.clone(),
154 tool,
155 },
156 );
157
158 if let Some(ref aliases) = descriptor.aliases {
159 for (lang, name) in &aliases.names {
160 let key = format!("{}:{}", lang, name.to_lowercase());
161 alias_index.insert(key, descriptor.id.clone());
162 }
163 }
164 }
165 }
166 }
167
168 self.providers.write().insert(provider_id, provider);
169
170 Ok(())
171 }
172
173 pub fn unregister_provider(&self, provider_id: &str) -> bool {
174 let removed = self.providers.write().remove(provider_id);
175
176 if removed.is_some() {
177 let mut tool_index = self.tool_index.write();
178 let mut alias_index = self.alias_index.write();
179
180 let tools_to_remove: Vec<String> = tool_index
181 .iter()
182 .filter_map(|(id, tool_ref)| {
183 if let ToolRef::Provider {
184 provider_id: pid, ..
185 } = tool_ref
186 {
187 if pid == provider_id {
188 return Some(id.clone());
189 }
190 }
191 None
192 })
193 .collect();
194
195 for tool_id in &tools_to_remove {
196 tool_index.remove(tool_id);
197 }
198
199 alias_index.retain(|_, tool_id| !tools_to_remove.contains(tool_id));
200
201 true
202 } else {
203 false
204 }
205 }
206
207 pub fn set_tool_aliases(&self, tool_id: &str, aliases: ToolAliases) {
208 if !self.tool_index.read().contains_key(tool_id) {
209 return;
210 }
211
212 {
213 let mut alias_index = self.alias_index.write();
214 for (lang, name) in &aliases.names {
215 let key = format!("{}:{}", lang, name.to_lowercase());
216 alias_index.insert(key, tool_id.to_string());
217 }
218 }
219
220 self.builtin_aliases
221 .write()
222 .insert(tool_id.to_string(), aliases);
223 }
224
225 pub fn get_by_alias(&self, alias: &str, lang: &str) -> Option<Arc<dyn Tool>> {
226 let key = format!("{}:{}", lang, alias.to_lowercase());
227 let alias_index = self.alias_index.read();
228
229 if let Some(tool_id) = alias_index.get(&key) {
230 return self.get(tool_id);
231 }
232
233 None
234 }
235
236 pub fn list_providers(&self) -> Vec<String> {
237 self.providers.read().keys().cloned().collect()
238 }
239
240 pub async fn provider_health(&self, provider_id: &str) -> Option<ProviderHealth> {
241 let providers = self.providers.read();
242 if let Some(provider) = providers.get(provider_id) {
243 Some(provider.health_check().await)
244 } else {
245 None
246 }
247 }
248
249 pub async fn refresh_provider(&self, provider_id: &str) -> Result<(), ToolProviderError> {
250 let provider = {
251 let providers = self.providers.read();
252 providers.get(provider_id).cloned()
253 };
254
255 if let Some(provider) = provider {
256 if provider.supports_refresh() {
257 provider.refresh().await?;
258
259 let tools = provider.list_tools().await;
260
261 let mut tool_index = self.tool_index.write();
262 let mut alias_index = self.alias_index.write();
263
264 let old_tools: Vec<String> = tool_index
265 .iter()
266 .filter_map(|(id, tool_ref)| {
267 if let ToolRef::Provider {
268 provider_id: pid, ..
269 } = tool_ref
270 {
271 if pid == provider_id {
272 return Some(id.clone());
273 }
274 }
275 None
276 })
277 .collect();
278
279 for tool_id in &old_tools {
280 tool_index.remove(tool_id);
281 }
282 alias_index.retain(|_, tool_id| !old_tools.contains(tool_id));
283
284 for descriptor in &tools {
285 if let Some(tool) = provider.get_tool(&descriptor.id).await {
286 tool_index.insert(
287 descriptor.id.clone(),
288 ToolRef::Provider {
289 provider_id: provider_id.to_string(),
290 tool,
291 },
292 );
293
294 if let Some(ref aliases) = descriptor.aliases {
295 for (lang, name) in &aliases.names {
296 let key = format!("{}:{}", lang, name.to_lowercase());
297 alias_index.insert(key, descriptor.id.clone());
298 }
299 }
300 }
301 }
302 }
303 Ok(())
304 } else {
305 Err(ToolProviderError::ToolNotFound(format!(
306 "Provider not found: {}",
307 provider_id
308 )))
309 }
310 }
311
312 pub fn generate_tools_prompt(&self) -> String {
313 self.generate_tools_prompt_with_lang(None, false)
314 }
315
316 pub fn generate_tools_prompt_with_parallel(&self, parallel: bool) -> String {
317 self.generate_tools_prompt_with_lang(None, parallel)
318 }
319
320 pub fn generate_tools_prompt_with_lang(
321 &self,
322 language: Option<&str>,
323 parallel: bool,
324 ) -> String {
325 let tool_index = self.tool_index.read();
326 if tool_index.is_empty() {
327 return String::new();
328 }
329
330 let builtin_aliases = self.builtin_aliases.read();
331 let mut prompt = String::from("Available tools:\n");
332
333 for (id, tool_ref) in tool_index.iter() {
334 if let Some(tool) = self.resolve_tool_ref(tool_ref) {
335 let (name, description) = if let Some(lang) = language {
336 if let Some(aliases) = builtin_aliases.get(id) {
337 let name = aliases
338 .names
339 .get(lang)
340 .map(|s| s.as_str())
341 .unwrap_or_else(|| tool.name());
342 let desc = aliases
343 .descriptions
344 .get(lang)
345 .map(|s| s.as_str())
346 .unwrap_or_else(|| tool.description());
347 (name, desc)
348 } else {
349 (tool.name(), tool.description())
350 }
351 } else {
352 (tool.name(), tool.description())
353 };
354
355 let schema = tool.input_schema();
356 let args_desc = if let Some(props) = schema.get("properties") {
357 serde_json::to_string(props).unwrap_or_default()
358 } else {
359 "{}".to_string()
360 };
361
362 prompt.push_str(&format!(
363 "- {}: {}. Arguments: {}\n",
364 name, description, args_desc
365 ));
366 }
367 }
368
369 Self::append_tool_format_instructions(&mut prompt, parallel);
370
371 prompt
372 }
373
374 pub fn generate_filtered_prompt(&self, tool_ids: &[String]) -> String {
375 self.generate_filtered_prompt_with_lang(tool_ids, None, false)
376 }
377
378 pub fn generate_filtered_prompt_with_parallel(
379 &self,
380 tool_ids: &[String],
381 parallel: bool,
382 ) -> String {
383 self.generate_filtered_prompt_with_lang(tool_ids, None, parallel)
384 }
385
386 pub fn generate_filtered_prompt_with_lang(
387 &self,
388 tool_ids: &[String],
389 language: Option<&str>,
390 parallel: bool,
391 ) -> String {
392 if tool_ids.is_empty() {
393 return self.generate_tools_prompt_with_lang(language, parallel);
394 }
395
396 let tool_index = self.tool_index.read();
397 let builtin_aliases = self.builtin_aliases.read();
398 let mut prompt = String::from("Available tools:\n");
399 let mut found_any = false;
400
401 for id in tool_ids {
402 if let Some(tool_ref) = tool_index.get(id) {
403 if let Some(tool) = self.resolve_tool_ref(tool_ref) {
404 found_any = true;
405
406 let (name, description) = if let Some(lang) = language {
407 if let Some(aliases) = builtin_aliases.get(id) {
408 let name = aliases
409 .names
410 .get(lang)
411 .map(|s| s.as_str())
412 .unwrap_or_else(|| tool.name());
413 let desc = aliases
414 .descriptions
415 .get(lang)
416 .map(|s| s.as_str())
417 .unwrap_or_else(|| tool.description());
418 (name, desc)
419 } else {
420 (tool.name(), tool.description())
421 }
422 } else {
423 (tool.name(), tool.description())
424 };
425
426 let schema = tool.input_schema();
427 let args_desc = if let Some(props) = schema.get("properties") {
428 serde_json::to_string(props).unwrap_or_default()
429 } else {
430 "{}".to_string()
431 };
432
433 prompt.push_str(&format!(
434 "- {}: {}. Arguments: {}\n",
435 name, description, args_desc
436 ));
437 }
438 }
439 }
440
441 if !found_any {
442 return String::new();
443 }
444
445 Self::append_tool_format_instructions(&mut prompt, parallel);
446
447 prompt
448 }
449
450 fn append_tool_format_instructions(prompt: &mut String, parallel: bool) {
454 prompt.push_str(
455 "\nWhen you need to use a tool, respond ONLY with valid JSON in this exact format:\n",
456 );
457 prompt.push_str("{\"tool\": \"tool_name\", \"arguments\": {...}}\n");
458 prompt.push_str("The \"tool\" value MUST be one of the exact tool names listed above. Do not invent tool names.\n");
459 if parallel {
460 prompt.push_str(
461 "\nWhen you need to call multiple tools at once, respond with a JSON array:\n",
462 );
463 prompt.push_str(
464 "[{\"tool\": \"tool_name1\", \"arguments\": {...}}, {\"tool\": \"tool_name2\", \"arguments\": {...}}]\n",
465 );
466 }
467 prompt.push_str("\nWhen you receive a tool result, summarize it naturally for the user.\n");
468 prompt.push_str("If no tool is needed, respond normally.");
469 }
470}
471
472impl Default for ToolRegistry {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolResult;
    use async_trait::async_trait;
    use serde_json::Value;

    /// Minimal builtin tool: the id varies per instance; name, description and schema are fixed.
    struct TestTool {
        id: String,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn id(&self) -> &str {
            &self.id
        }
        fn name(&self) -> &str {
            "Test"
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn input_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _args: Value) -> ToolResult {
            ToolResult::ok("test")
        }
    }

    /// Builds a `TestTool` with the given id.
    fn test_tool(id: &str) -> Arc<dyn Tool> {
        Arc::new(TestTool { id: id.to_string() })
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("test")).unwrap();

        assert!(registry.get("test").is_some());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_duplicate_registration() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("test")).unwrap();

        let result = registry.register(test_tool("test"));
        assert!(matches!(result, Err(ToolError::Duplicate(ref id)) if id == "test"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_get_case_insensitive_id_and_name() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("calculator")).unwrap();

        assert_eq!(
            registry.get("CALCULATOR").map(|t| t.id().to_string()),
            Some("calculator".to_string())
        );
        // No id matches "test", so the display name "Test" is used.
        assert!(registry.get("test").is_some());
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_list_ids() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("a")).unwrap();
        registry.register(test_tool("b")).unwrap();

        let ids = registry.list_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"a".to_string()));
        assert!(ids.contains(&"b".to_string()));
    }

    #[test]
    fn test_generate_tools_prompt() {
        let empty_registry = ToolRegistry::new();
        let empty_prompt = empty_registry.generate_tools_prompt();
        assert!(empty_prompt.is_empty());

        let mut registry = ToolRegistry::new();
        registry.register(test_tool("test")).unwrap();

        let prompt = registry.generate_tools_prompt();
        assert!(prompt.contains("Available tools:"));
        assert!(prompt.contains("Test:"));
        assert!(prompt.contains("A test tool"));
        assert!(prompt.contains("tool_name"));
    }

    #[test]
    fn test_generate_filtered_prompt_with_filter() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool_a")).unwrap();
        registry.register(test_tool("tool_b")).unwrap();
        registry.register(test_tool("tool_c")).unwrap();

        // Every TestTool renders as "- Test: ...", so filtering shows up in the line count
        // (ids themselves never appear in the prompt).
        let prompt =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "tool_c".to_string()]);
        assert_eq!(prompt.matches("- Test:").count(), 2);

        let single = registry.generate_filtered_prompt(&["tool_b".to_string()]);
        assert_eq!(single.matches("- Test:").count(), 1);
    }

    #[test]
    fn test_generate_filtered_prompt_empty_filter() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool_a")).unwrap();
        registry.register(test_tool("tool_b")).unwrap();

        let prompt = registry.generate_filtered_prompt(&[]);
        assert!(prompt.contains("Test"));
        assert_eq!(prompt.matches("- Test:").count(), 2);
    }

    #[test]
    fn test_generate_filtered_prompt_nonexistent_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool_a")).unwrap();

        let prompt = registry.generate_filtered_prompt(&["nonexistent".to_string()]);
        assert!(prompt.is_empty());

        let prompt2 =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "nonexistent".to_string()]);
        assert!(prompt2.contains("Test"));
    }

    #[test]
    fn test_set_tool_aliases() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("calculator")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "계산기")
            .with_name("ja", "計算機")
            .with_description("ko", "수학 계산을 합니다");

        registry.set_tool_aliases("calculator", aliases);

        assert!(registry.get_by_alias("계산기", "ko").is_some());
        assert!(registry.get_by_alias("計算機", "ja").is_some());
        assert!(registry.get("calculator").is_some());
    }

    #[test]
    fn test_get_by_alias_case_insensitive() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("search")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "검색")
            .with_name("en", "Lookup");
        registry.set_tool_aliases("search", aliases);

        assert!(registry.get_by_alias("검색", "ko").is_some());
        // ASCII alias: the lookup really is case-insensitive.
        assert!(registry.get_by_alias("LOOKUP", "en").is_some());
        assert!(registry.get_by_alias("lookup", "en").is_some());
        // Aliases are scoped by language.
        assert!(registry.get_by_alias("lookup", "ko").is_none());
        // `get` falls back to alias names in any language.
        assert!(registry.get("LookUp").is_some());
    }

    #[test]
    fn test_generate_prompt_with_language() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("calculator")).unwrap();

        let aliases = ToolAliases::new()
            .with_name("ko", "계산기")
            .with_description("ko", "수학 계산");

        registry.set_tool_aliases("calculator", aliases);

        let prompt_en = registry.generate_tools_prompt_with_lang(None, false);
        assert!(prompt_en.contains("Test"));

        let prompt_ko = registry.generate_tools_prompt_with_lang(Some("ko"), false);
        assert!(prompt_ko.contains("계산기"));
        assert!(prompt_ko.contains("수학 계산"));
    }

    #[test]
    fn test_generate_tools_prompt_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool_a")).unwrap();
        registry.register(test_tool("tool_b")).unwrap();

        let prompt_seq = registry.generate_tools_prompt();
        assert!(prompt_seq.contains("\"tool\": \"tool_name\""));
        assert!(!prompt_seq.contains("JSON array"));
        assert!(!prompt_seq.contains("tool_name1"));

        let prompt_par = registry.generate_tools_prompt_with_parallel(true);
        assert!(prompt_par.contains("\"tool\": \"tool_name\""));
        assert!(prompt_par.contains("JSON array"));
        assert!(prompt_par.contains("tool_name1"));
        assert!(prompt_par.contains("tool_name2"));
    }

    #[test]
    fn test_generate_filtered_prompt_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool_a")).unwrap();
        registry.register(test_tool("tool_b")).unwrap();

        let prompt_seq =
            registry.generate_filtered_prompt(&["tool_a".to_string(), "tool_b".to_string()]);
        assert!(!prompt_seq.contains("JSON array"));

        let prompt_par = registry.generate_filtered_prompt_with_parallel(
            &["tool_a".to_string(), "tool_b".to_string()],
            true,
        );
        assert!(prompt_par.contains("JSON array"));
        assert!(prompt_par.contains("tool_name1"));
    }
}