Skip to main content

lellm_agent/tools/
registry.rs

//! Tool registry: look up tools by exact name, synonym, or category.
2
3use std::collections::{HashMap, HashSet};
4
5use lellm_core::ToolDefinition;
6
/// Where a registered tool came from.
///
/// Recorded on every [`ToolSearchResult`] so callers can tell tool origins apart.
/// `Hash` is derived (alongside `Eq`) so the enum can be used directly as a
/// `HashMap`/`HashSet` key, e.g. to group search results by source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolSource {
    Builtin,
    Dynamic,
    Mcp,
    Skill,
}
15
16/// 工具搜索结果(包含来源信息)。
17#[derive(Debug, Clone)]
18pub struct ToolSearchResult {
19    pub definition: ToolDefinition,
20    pub source: ToolSource,
21    pub category: String,
22}
23
24impl ToolSearchResult {
25    pub fn name(&self) -> &str {
26        &self.definition.name
27    }
28}
29
30/// 工具注册表。
31pub struct ToolRegistry {
32    tools: HashMap<String, ToolSearchResult>,
33    synonyms: HashMap<String, Vec<String>>,
34    categories: HashMap<String, Vec<String>>,
35}
36
37impl ToolRegistry {
38    pub fn new() -> Self {
39        Self {
40            tools: HashMap::new(),
41            synonyms: HashMap::new(),
42            categories: HashMap::new(),
43        }
44    }
45
46    pub fn register(&mut self, name: &str, source: ToolSource, def: ToolDefinition) {
47        let category = Self::infer_category(name);
48        let result = ToolSearchResult {
49            definition: def,
50            source,
51            category: category.clone(),
52        };
53        self.tools.insert(name.to_string(), result);
54        self.categories
55            .entry(category)
56            .or_default()
57            .push(name.to_string());
58    }
59
60    pub fn add_synonyms(&mut self, tool_name: &str, synonyms: &[&str]) {
61        for syn in synonyms {
62            self.synonyms
63                .entry(syn.to_string())
64                .or_default()
65                .push(tool_name.to_string());
66        }
67    }
68
69    /// 搜索工具(精确 → 同义词 → 子串兜底)
70    pub fn search(&self, query: &str) -> Vec<ToolSearchResult> {
71        let query_lower = query.to_lowercase();
72        let mut results = Vec::new();
73        let mut seen = HashSet::new();
74
75        if let Some(result) = self.tools.get(query) {
76            results.push(result.clone());
77            seen.insert(query);
78        }
79
80        if let Some(names) = self.synonyms.get(&query_lower) {
81            for name in names {
82                if seen.insert(name)
83                    && let Some(result) = self.tools.get(name)
84                {
85                    results.push(result.clone());
86                }
87            }
88        }
89
90        if results.is_empty() {
91            for (name, result) in &self.tools {
92                if name.to_lowercase().contains(&query_lower) {
93                    results.push(result.clone());
94                }
95            }
96        }
97
98        results
99    }
100
101    pub fn search_category(&self, category: &str) -> Vec<ToolSearchResult> {
102        let mut results = Vec::new();
103        if let Some(names) = self.categories.get(category) {
104            for name in names {
105                if let Some(result) = self.tools.get(name) {
106                    results.push(result.clone());
107                }
108            }
109        }
110        results
111    }
112
113    pub fn list_tools(&self) -> Vec<ToolSearchResult> {
114        self.tools.values().cloned().collect()
115    }
116
117    fn infer_category(name: &str) -> String {
118        if name.starts_with("read") || name.starts_with("write") || name.starts_with("bash") {
119            "builtin".to_string()
120        } else {
121            "custom".to_string()
122        }
123    }
124}
125
126impl Default for ToolRegistry {
127    fn default() -> Self {
128        Self::new()
129    }
130}