Skip to main content

mofa_foundation/agent/tools/
registry.rs

//! Unified tool registry.
//!
//! A registry that brings together built-in tools, MCP tools, and custom tools.

5use async_trait::async_trait;
6use mofa_kernel::agent::components::tool::{
7    Tool, ToolDescriptor, ToolRegistry as ToolRegistryTrait,
8};
9use mofa_kernel::agent::error::AgentResult;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// 统一工具注册中心
14///
15/// 整合多种工具来源的注册中心
16///
17/// # 示例
18///
19/// ```rust,ignore
20/// use mofa_foundation::agent::tools::ToolRegistry;
21/// use mofa_foundation::agent::components::tool::EchoTool;
22///
23/// let mut registry = ToolRegistry::new();
24///
25/// // 注册内置工具
26/// registry.register(Arc::new(EchoTool)).unwrap();
27///
28/// // 注册 MCP 服务器的工具
29/// registry.load_mcp_server("http://localhost:8080").await?;
30///
31/// // 列出所有工具
32/// for tool in registry.list() {
33///     info!("{}: {}", tool.name, tool.description);
34/// }
35/// ```
36pub struct ToolRegistry {
37    /// 工具存储
38    tools: HashMap<String, Arc<dyn Tool>>,
39    /// 工具来源
40    sources: HashMap<String, ToolSource>,
41    /// MCP 客户端 (TODO: 实际 MCP 客户端实现)
42    mcp_endpoints: Vec<String>,
43}
44
/// Where a registered tool came from.
///
/// Derives `PartialEq`/`Eq` so sources can be compared directly
/// (e.g. `registry.get_source(name) == Some(&ToolSource::Builtin)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolSource {
    /// Built-in tool.
    Builtin,
    /// Tool provided by an MCP server.
    Mcp {
        /// Endpoint of the MCP server the tool belongs to.
        endpoint: String,
    },
    /// Tool provided by a custom plugin.
    Plugin {
        /// Path of the plugin that provides the tool.
        path: String,
    },
    /// Tool registered at runtime without a more specific source.
    Dynamic,
}
57
58impl ToolRegistry {
59    /// 创建新的统一注册中心
60    pub fn new() -> Self {
61        Self {
62            tools: HashMap::new(),
63            sources: HashMap::new(),
64            mcp_endpoints: Vec::new(),
65        }
66    }
67
68    /// 注册工具并记录来源
69    pub fn register_with_source(
70        &mut self,
71        tool: Arc<dyn Tool>,
72        source: ToolSource,
73    ) -> AgentResult<()> {
74        let name = tool.name().to_string();
75        self.sources.insert(name.clone(), source);
76        self.tools.insert(name, tool);
77        Ok(())
78    }
79
80    /// 加载 MCP 服务器的工具 (TODO: 实际 MCP 实现)
81    pub async fn load_mcp_server(&mut self, endpoint: &str) -> AgentResult<Vec<String>> {
82        // TODO: 实际 MCP 客户端实现
83        // 这里只是记录端点
84        self.mcp_endpoints.push(endpoint.to_string());
85
86        // 模拟加载的工具名称
87        Ok(vec![])
88    }
89
90    /// 卸载 MCP 服务器的工具
91    pub async fn unload_mcp_server(&mut self, endpoint: &str) -> AgentResult<Vec<String>> {
92        self.mcp_endpoints.retain(|e| e != endpoint);
93
94        // 移除该服务器的工具
95        let to_remove: Vec<String> = self
96            .sources
97            .iter()
98            .filter_map(|(name, source)| {
99                if let ToolSource::Mcp { endpoint: ep } = source
100                    && ep == endpoint
101                {
102                    return Some(name.clone());
103                }
104                None
105            })
106            .collect();
107
108        for name in &to_remove {
109            self.tools.remove(name);
110            self.sources.remove(name);
111        }
112
113        Ok(to_remove)
114    }
115
116    /// 热加载插件 (TODO: 实际插件系统实现)
117    pub async fn hot_reload_plugin(&mut self, _path: &str) -> AgentResult<Vec<String>> {
118        // TODO: 实际插件热加载实现
119        Ok(vec![])
120    }
121
122    /// 获取工具来源
123    pub fn get_source(&self, name: &str) -> Option<&ToolSource> {
124        self.sources.get(name)
125    }
126
127    /// 按来源过滤工具
128    pub fn filter_by_source(&self, source_type: &str) -> Vec<ToolDescriptor> {
129        self.tools
130            .iter()
131            .filter(|(name, _)| {
132                if let Some(source) = self.sources.get(*name) {
133                    match source {
134                        ToolSource::Builtin => source_type == "builtin",
135                        ToolSource::Mcp { .. } => source_type == "mcp",
136                        ToolSource::Plugin { .. } => source_type == "plugin",
137                        ToolSource::Dynamic => source_type == "dynamic",
138                    }
139                } else {
140                    false
141                }
142            })
143            .map(|(_, tool)| ToolDescriptor::from_tool(tool.as_ref()))
144            .collect()
145    }
146
147    /// 获取 MCP 端点列表
148    pub fn mcp_endpoints(&self) -> &[String] {
149        &self.mcp_endpoints
150    }
151}
152
153impl Default for ToolRegistry {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159#[async_trait]
160impl ToolRegistryTrait for ToolRegistry {
161    fn register(&mut self, tool: Arc<dyn Tool>) -> AgentResult<()> {
162        self.register_with_source(tool, ToolSource::Dynamic)
163    }
164
165    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
166        self.tools.get(name).cloned()
167    }
168
169    fn unregister(&mut self, name: &str) -> AgentResult<bool> {
170        self.sources.remove(name);
171        Ok(self.tools.remove(name).is_some())
172    }
173
174    fn list(&self) -> Vec<ToolDescriptor> {
175        self.tools
176            .values()
177            .map(|t| ToolDescriptor::from_tool(t.as_ref()))
178            .collect()
179    }
180
181    fn list_names(&self) -> Vec<String> {
182        self.tools.keys().cloned().collect()
183    }
184
185    fn contains(&self, name: &str) -> bool {
186        self.tools.contains_key(name)
187    }
188
189    fn count(&self) -> usize {
190        self.tools.len()
191    }
192}
193
194// ============================================================================
195// 工具搜索
196// ============================================================================
197
198/// 工具搜索器
199pub struct ToolSearcher<'a> {
200    registry: &'a ToolRegistry,
201}
202
203impl<'a> ToolSearcher<'a> {
204    /// 创建搜索器
205    pub fn new(registry: &'a ToolRegistry) -> Self {
206        Self { registry }
207    }
208
209    /// 按名称模糊搜索
210    pub fn search_by_name(&self, pattern: &str) -> Vec<ToolDescriptor> {
211        let pattern_lower = pattern.to_lowercase();
212        self.registry
213            .tools
214            .iter()
215            .filter(|(name, _)| name.to_lowercase().contains(&pattern_lower))
216            .map(|(_, tool)| ToolDescriptor::from_tool(tool.as_ref()))
217            .collect()
218    }
219
220    /// 按描述搜索
221    pub fn search_by_description(&self, query: &str) -> Vec<ToolDescriptor> {
222        let query_lower = query.to_lowercase();
223        self.registry
224            .tools
225            .values()
226            .filter(|tool| tool.description().to_lowercase().contains(&query_lower))
227            .map(|tool| ToolDescriptor::from_tool(tool.as_ref()))
228            .collect()
229    }
230
231    /// 按标签搜索
232    pub fn search_by_tag(&self, tag: &str) -> Vec<ToolDescriptor> {
233        self.registry
234            .tools
235            .values()
236            .filter(|tool| {
237                let metadata = tool.metadata();
238                metadata.tags.iter().any(|t| t == tag)
239            })
240            .map(|tool| ToolDescriptor::from_tool(tool.as_ref()))
241            .collect()
242    }
243
244    /// 搜索需要确认的工具
245    pub fn search_dangerous(&self) -> Vec<ToolDescriptor> {
246        self.registry
247            .tools
248            .values()
249            .filter(|tool| tool.metadata().is_dangerous || tool.requires_confirmation())
250            .map(|tool| ToolDescriptor::from_tool(tool.as_ref()))
251            .collect()
252    }
253}