Skip to main content

matrixcode_core/mcp/
proxy.rs

1//! MCP Tool Proxy
2//!
3//! 将 MCP 服务器的工具映射为 MatrixCode 内置工具
4
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
17// ============================================================================
18// MCP Tool Wrapper
19// ============================================================================
20
21/// MCP 工具包装器 - 将 MCP 工具映射为内置 Tool
22#[derive(Clone)]
23pub struct McpToolWrapper {
24    /// MCP 客户端
25    client: Arc<McpClient>,
26    /// 工具定义
27    tool_def: McpTool,
28    /// 服务器名称(用于区分不同服务器的同名工具)
29    server_name: String,
30    /// 工具定义(缓存)
31    cached_definition: ToolDefinition,
32}
33
34impl McpToolWrapper {
35    /// 创建 MCP 工具包装器
36    pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
37        // 转换 MCP 工具定义为内置工具定义
38        let name = format!("{}_{}", server_name, tool_def.name);
39        let description = tool_def
40            .description
41            .clone()
42            .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
43
44        let cached_definition = ToolDefinition {
45            name: name.clone(),
46            description,
47            parameters: tool_def.input_schema.clone(),
48            is_priority: false,
49        };
50
51        Self {
52            client,
53            tool_def,
54            server_name,
55            cached_definition,
56        }
57    }
58
59    /// 获取原始工具名称
60    pub fn original_name(&self) -> &str {
61        &self.tool_def.name
62    }
63
64    /// 获取服务器名称
65    pub fn server_name(&self) -> &str {
66        &self.server_name
67    }
68
69    /// 解析工具调用结果
70    fn parse_result(&self, result: CallToolResult) -> String {
71        if result.content.is_empty() {
72            return String::new();
73        }
74
75        let mut output = String::new();
76
77        for content in result.content {
78            match content {
79                Content::Text { text } => {
80                    output.push_str(&text);
81                    output.push('\n');
82                }
83                Content::Image { data, mime_type } => {
84                    output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
85                }
86                Content::Resource { resource } => {
87                    if let Some(text) = resource.text {
88                        output.push_str(&text);
89                        output.push('\n');
90                    } else if let Some(blob) = resource.blob {
91                        output.push_str(&format!(
92                            "[Resource: {} ({} bytes)]\n",
93                            resource.uri,
94                            blob.len()
95                        ));
96                    } else {
97                        output.push_str(&format!("[Resource: {}]\n", resource.uri));
98                    }
99                }
100            }
101        }
102
103        output.trim_end().to_string()
104    }
105}
106
107#[async_trait]
108impl Tool for McpToolWrapper {
109    fn definition(&self) -> ToolDefinition {
110        self.cached_definition.clone()
111    }
112
113    async fn execute(&self, params: Value) -> Result<String> {
114        tracing::debug!(
115            "Executing MCP tool '{}' from server '{}'",
116            self.tool_def.name,
117            self.server_name
118        );
119
120        // 调用 MCP 工具
121        let result = self
122            .client
123            .call_tool(&self.tool_def.name, Some(params))
124            .await
125            .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
126
127        // 检查是否为错误
128        if result.is_error.unwrap_or(false) {
129            let error_msg = self.parse_result(result);
130            return Err(anyhow!("MCP tool error: {}", error_msg));
131        }
132
133        // 解析结果
134        Ok(self.parse_result(result))
135    }
136
137    fn risk_level(&self) -> RiskLevel {
138        // MCP 工具默认为中等风险
139        // 可以根据工具名称或模式设置不同风险级别
140        let name = &self.tool_def.name;
141
142        // 只读操作为安全
143        if name.contains("read") || name.contains("list") || name.contains("get") {
144            RiskLevel::Safe
145        }
146        // 浏览器操作为中等风险
147        else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
148            RiskLevel::Mutating
149        }
150        // 写入操作为高风险
151        else if name.contains("write") || name.contains("delete") || name.contains("create") {
152            RiskLevel::Dangerous
153        }
154        // 默认中等风险
155        else {
156            RiskLevel::Mutating
157        }
158    }
159}
160
161// ============================================================================
162// MCP Tool Manager
163// ============================================================================
164
165/// MCP 工具管理器 - 管理多个 MCP 服务器的连接
166pub struct McpToolManager {
167    /// 已连接的 MCP 客户端
168    clients: RwLock<Vec<Arc<McpClient>>>,
169}
170
171impl McpToolManager {
172    /// 创建工具管理器
173    pub fn new() -> Self {
174        Self {
175            clients: RwLock::new(Vec::new()),
176        }
177    }
178
179    /// 连接 MCP 服务器并获取工具
180    pub async fn connect_server(
181        &self,
182        server_name: impl Into<String>,
183        config: super::transport::TransportConfig,
184    ) -> Result<Vec<Box<dyn Tool>>> {
185        let server_name = server_name.into();
186
187        // 连接服务器
188        let client = Arc::new(McpClient::connect(&server_name, config).await?);
189
190        // 检查是否支持工具
191        if !client.supports_tools().await {
192            tracing::warn!("MCP server '{}' does not support tools", server_name);
193            return Ok(Vec::new());
194        }
195
196        // 获取工具列表
197        let mcp_tools = client.list_tools().await?;
198        tracing::info!(
199            "MCP server '{}' provided {} tools",
200            server_name,
201            mcp_tools.len()
202        );
203
204        // 转换为内置工具
205        let tools: Vec<Box<dyn Tool>> = mcp_tools
206            .into_iter()
207            .map(|tool| {
208                Box::new(McpToolWrapper::new(
209                    client.clone(),
210                    tool,
211                    server_name.clone(),
212                )) as Box<dyn Tool>
213            })
214            .collect();
215
216        // 缓存客户端(用于后续 shutdown)
217        self.clients.write().await.push(client);
218
219        Ok(tools)
220    }
221
222    /// 获取已连接的服务器数量
223    pub async fn server_count(&self) -> usize {
224        self.clients.read().await.len()
225    }
226
227    /// 获取所有已连接的服务器名称
228    pub async fn server_names(&self) -> Vec<String> {
229        self.clients
230            .read()
231            .await
232            .iter()
233            .map(|c| c.server_name().to_string())
234            .collect()
235    }
236
237    /// 关闭所有连接
238    pub async fn shutdown(&self) {
239        let clients = self.clients.read().await;
240        for client in clients.iter() {
241            if let Err(e) = client.shutdown().await {
242                tracing::error!(
243                    "Failed to shutdown MCP server '{}': {}",
244                    client.server_name(),
245                    e
246                );
247            }
248        }
249    }
250}
251
252impl Default for McpToolManager {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258// ============================================================================
259// Convenience Functions
260// ============================================================================
261
262/// 连接单个 MCP 服务器并返回工具列表
263pub async fn connect_mcp_server(
264    server_name: impl Into<String>,
265    config: super::transport::TransportConfig,
266) -> Result<Vec<Box<dyn Tool>>> {
267    let server_name = server_name.into();
268    let client = McpClient::connect(&server_name, config).await?;
269
270    if !client.supports_tools().await {
271        client.shutdown().await?;
272        return Ok(Vec::new());
273    }
274
275    let mcp_tools = client.list_tools().await?;
276    let client = Arc::new(client);
277
278    let tools: Vec<Box<dyn Tool>> = mcp_tools
279        .into_iter()
280        .map(|tool| {
281            Box::new(McpToolWrapper::new(
282                client.clone(),
283                tool,
284                server_name.clone(),
285            )) as Box<dyn Tool>
286        })
287        .collect();
288
289    Ok(tools)
290}
291
292/// 从配置连接所有 MCP 服务器并返回工具列表
293pub async fn connect_mcp_servers_from_config(
294    mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
295) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
296    let manager = McpToolManager::new();
297    let mut all_tools = Vec::new();
298
299    for (name, config) in mcp_config.iter() {
300        if !config.enabled {
301            tracing::debug!("MCP server '{}' is disabled, skipping", name);
302            continue;
303        }
304
305        // 从配置创建 TransportConfig
306        let transport_config = config
307            .to_transport_config()
308            .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
309
310        tracing::info!("Connecting to MCP server '{}'...", name);
311        let tools = manager.connect_server(name, transport_config).await?;
312
313        if !tools.is_empty() {
314            tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
315            all_tools.extend(tools);
316        }
317    }
318
319    Ok((all_tools, manager))
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_wrapper_definition() {
        // NOTE(review): this exercises a local copy of the name-based risk rules, not
        // `McpToolWrapper::risk_level` itself, because building a wrapper needs a live
        // `McpClient`. It can therefore drift from the real implementation.
        fn get_risk_level(name: &str) -> RiskLevel {
            let mentions_any = |keys: &[&str]| keys.iter().any(|k| name.contains(*k));

            if mentions_any(&["read", "list", "get"]) {
                return RiskLevel::Safe;
            }
            if mentions_any(&["browser", "navigate", "click"]) {
                return RiskLevel::Mutating;
            }
            if mentions_any(&["write", "delete", "create"]) {
                return RiskLevel::Dangerous;
            }
            RiskLevel::Mutating
        }

        let cases = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(get_risk_level(name), *expected);
        }
    }
}