Skip to main content

matrixcode_core/mcp/
proxy.rs

1//! MCP Tool Proxy
2//!
3//! 将 MCP 服务器的工具映射为 MatrixCode 内置工具
4
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
17// ============================================================================
18// MCP Tool Wrapper
19// ============================================================================
20
21/// MCP 工具包装器 - 将 MCP 工具映射为内置 Tool
22#[derive(Clone)]
23pub struct McpToolWrapper {
24    /// MCP 客户端
25    client: Arc<McpClient>,
26    /// 工具定义
27    tool_def: McpTool,
28    /// 服务器名称(用于区分不同服务器的同名工具)
29    server_name: String,
30    /// 工具定义(缓存)
31    cached_definition: ToolDefinition,
32}
33
34impl McpToolWrapper {
35    /// 创建 MCP 工具包装器
36    pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
37        // 转换 MCP 工具定义为内置工具定义
38        let name = format!("{}_{}", server_name, tool_def.name);
39        let description = tool_def.description.clone()
40            .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
41        
42        let cached_definition = ToolDefinition {
43            name: name.clone(),
44            description,
45            parameters: tool_def.input_schema.clone(),
46            is_priority: false,
47        };
48        
49        Self {
50            client,
51            tool_def,
52            server_name,
53            cached_definition,
54        }
55    }
56    
57    /// 获取原始工具名称
58    pub fn original_name(&self) -> &str {
59        &self.tool_def.name
60    }
61    
62    /// 获取服务器名称
63    pub fn server_name(&self) -> &str {
64        &self.server_name
65    }
66    
67    /// 解析工具调用结果
68    fn parse_result(&self, result: CallToolResult) -> String {
69        if result.content.is_empty() {
70            return String::new();
71        }
72        
73        let mut output = String::new();
74        
75        for content in result.content {
76            match content {
77                Content::Text { text } => {
78                    output.push_str(&text);
79                    output.push('\n');
80                }
81                Content::Image { data, mime_type } => {
82                    output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
83                }
84                Content::Resource { resource } => {
85                    if let Some(text) = resource.text {
86                        output.push_str(&text);
87                        output.push('\n');
88                    } else if let Some(blob) = resource.blob {
89                        output.push_str(&format!("[Resource: {} ({} bytes)]\n", resource.uri, blob.len()));
90                    } else {
91                        output.push_str(&format!("[Resource: {}]\n", resource.uri));
92                    }
93                }
94            }
95        }
96        
97        output.trim_end().to_string()
98    }
99}
100
101#[async_trait]
102impl Tool for McpToolWrapper {
103    fn definition(&self) -> ToolDefinition {
104        self.cached_definition.clone()
105    }
106    
107    async fn execute(&self, params: Value) -> Result<String> {
108        tracing::debug!(
109            "Executing MCP tool '{}' from server '{}'",
110            self.tool_def.name,
111            self.server_name
112        );
113        
114        // 调用 MCP 工具
115        let result = self.client.call_tool(&self.tool_def.name, Some(params)).await
116            .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
117        
118        // 检查是否为错误
119        if result.is_error.unwrap_or(false) {
120            let error_msg = self.parse_result(result);
121            return Err(anyhow!("MCP tool error: {}", error_msg));
122        }
123        
124        // 解析结果
125        Ok(self.parse_result(result))
126    }
127    
128    fn risk_level(&self) -> RiskLevel {
129        // MCP 工具默认为中等风险
130        // 可以根据工具名称或模式设置不同风险级别
131        let name = &self.tool_def.name;
132        
133        // 只读操作为安全
134        if name.contains("read") || name.contains("list") || name.contains("get") {
135            RiskLevel::Safe
136        }
137        // 浏览器操作为中等风险
138        else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
139            RiskLevel::Mutating
140        }
141        // 写入操作为高风险
142        else if name.contains("write") || name.contains("delete") || name.contains("create") {
143            RiskLevel::Dangerous
144        }
145        // 默认中等风险
146        else {
147            RiskLevel::Mutating
148        }
149    }
150}
151
152// ============================================================================
153// MCP Tool Manager
154// ============================================================================
155
156/// MCP 工具管理器 - 管理多个 MCP 服务器的连接
157pub struct McpToolManager {
158    /// 已连接的 MCP 客户端
159    clients: RwLock<Vec<Arc<McpClient>>>,
160}
161
162impl McpToolManager {
163    /// 创建工具管理器
164    pub fn new() -> Self {
165        Self {
166            clients: RwLock::new(Vec::new()),
167        }
168    }
169    
170    /// 连接 MCP 服务器并获取工具
171    pub async fn connect_server(
172        &self,
173        server_name: impl Into<String>,
174        config: super::transport::TransportConfig,
175    ) -> Result<Vec<Box<dyn Tool>>> {
176        let server_name = server_name.into();
177        
178        // 连接服务器
179        let client = Arc::new(McpClient::connect(&server_name, config).await?);
180        
181        // 检查是否支持工具
182        if !client.supports_tools().await {
183            tracing::warn!("MCP server '{}' does not support tools", server_name);
184            return Ok(Vec::new());
185        }
186        
187        // 获取工具列表
188        let mcp_tools = client.list_tools().await?;
189        tracing::info!(
190            "MCP server '{}' provided {} tools",
191            server_name,
192            mcp_tools.len()
193        );
194        
195        // 转换为内置工具
196        let tools: Vec<Box<dyn Tool>> = mcp_tools
197            .into_iter()
198            .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
199            .collect();
200        
201        // 缓存客户端(用于后续 shutdown)
202        self.clients.write().await.push(client);
203        
204        Ok(tools)
205    }
206    
207    /// 获取已连接的服务器数量
208    pub async fn server_count(&self) -> usize {
209        self.clients.read().await.len()
210    }
211    
212    /// 获取所有已连接的服务器名称
213    pub async fn server_names(&self) -> Vec<String> {
214        self.clients.read().await.iter()
215            .map(|c| c.server_name().to_string())
216            .collect()
217    }
218    
219    /// 关闭所有连接
220    pub async fn shutdown(&self) {
221        let clients = self.clients.read().await;
222        for client in clients.iter() {
223            if let Err(e) = client.shutdown().await {
224                tracing::error!("Failed to shutdown MCP server '{}': {}", client.server_name(), e);
225            }
226        }
227    }
228}
229
230impl Default for McpToolManager {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236// ============================================================================
237// Convenience Functions
238// ============================================================================
239
240/// 连接单个 MCP 服务器并返回工具列表
241pub async fn connect_mcp_server(
242    server_name: impl Into<String>,
243    config: super::transport::TransportConfig,
244) -> Result<Vec<Box<dyn Tool>>> {
245    let server_name = server_name.into();
246    let client = McpClient::connect(&server_name, config).await?;
247    
248    if !client.supports_tools().await {
249        client.shutdown().await?;
250        return Ok(Vec::new());
251    }
252    
253    let mcp_tools = client.list_tools().await?;
254    let client = Arc::new(client);
255    
256    let tools: Vec<Box<dyn Tool>> = mcp_tools
257        .into_iter()
258        .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
259        .collect();
260    
261    Ok(tools)
262}
263
264/// 从配置连接所有 MCP 服务器并返回工具列表
265pub async fn connect_mcp_servers_from_config(
266    mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
267) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
268    let manager = McpToolManager::new();
269    let mut all_tools = Vec::new();
270    
271    for (name, config) in mcp_config.iter() {
272        if !config.enabled {
273            tracing::debug!("MCP server '{}' is disabled, skipping", name);
274            continue;
275        }
276        
277        // 从配置创建 TransportConfig
278        let transport_config = config.to_transport_config()
279            .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
280        
281        tracing::info!("Connecting to MCP server '{}'...", name);
282        let tools = manager.connect_server(name, transport_config).await?;
283        
284        if !tools.is_empty() {
285            tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
286            all_tools.extend(tools);
287        }
288    }
289    
290    Ok((all_tools, manager))
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the keyword-based risk classification on representative names.
    ///
    /// The classification is exercised through a local copy of the rules; the
    /// test does not construct a `McpToolWrapper`.
    #[test]
    fn test_mcp_tool_wrapper_definition() {
        fn get_risk_level(name: &str) -> RiskLevel {
            let has_any = |words: &[&str]| words.iter().any(|w| name.contains(*w));

            if has_any(&["read", "list", "get"]) {
                return RiskLevel::Safe;
            }
            if has_any(&["browser", "navigate", "click"]) {
                return RiskLevel::Mutating;
            }
            if has_any(&["write", "delete", "create"]) {
                return RiskLevel::Dangerous;
            }
            RiskLevel::Mutating
        }

        let expectations = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (tool_name, expected) in expectations {
            assert_eq!(get_risk_level(tool_name), expected);
        }
    }
}