Skip to main content

matrixcode_core/mcp/
proxy.rs

1//! MCP Tool Proxy
2//!
3//! 将 MCP 服务器的工具映射为 MatrixCode 内置工具
4
5use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
17// ============================================================================
18// MCP Tool Wrapper
19// ============================================================================
20
21/// MCP 工具包装器 - 将 MCP 工具映射为内置 Tool
22pub struct McpToolWrapper {
23    /// MCP 客户端
24    client: Arc<McpClient>,
25    /// 工具定义
26    tool_def: McpTool,
27    /// 服务器名称(用于区分不同服务器的同名工具)
28    server_name: String,
29    /// 工具定义(缓存)
30    cached_definition: ToolDefinition,
31}
32
33impl McpToolWrapper {
34    /// 创建 MCP 工具包装器
35    pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
36        // 转换 MCP 工具定义为内置工具定义
37        let name = format!("{}_{}", server_name, tool_def.name);
38        let description = tool_def.description.clone()
39            .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
40        
41        let cached_definition = ToolDefinition {
42            name: name.clone(),
43            description,
44            parameters: tool_def.input_schema.clone(),
45            is_priority: false,
46        };
47        
48        Self {
49            client,
50            tool_def,
51            server_name,
52            cached_definition,
53        }
54    }
55    
56    /// 获取原始工具名称
57    pub fn original_name(&self) -> &str {
58        &self.tool_def.name
59    }
60    
61    /// 获取服务器名称
62    pub fn server_name(&self) -> &str {
63        &self.server_name
64    }
65    
66    /// 解析工具调用结果
67    fn parse_result(&self, result: CallToolResult) -> String {
68        if result.content.is_empty() {
69            return String::new();
70        }
71        
72        let mut output = String::new();
73        
74        for content in result.content {
75            match content {
76                Content::Text { text } => {
77                    output.push_str(&text);
78                    output.push('\n');
79                }
80                Content::Image { data, mime_type } => {
81                    output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
82                }
83                Content::Resource { resource } => {
84                    if let Some(text) = resource.text {
85                        output.push_str(&text);
86                        output.push('\n');
87                    } else if let Some(blob) = resource.blob {
88                        output.push_str(&format!("[Resource: {} ({} bytes)]\n", resource.uri, blob.len()));
89                    } else {
90                        output.push_str(&format!("[Resource: {}]\n", resource.uri));
91                    }
92                }
93            }
94        }
95        
96        output.trim_end().to_string()
97    }
98}
99
100#[async_trait]
101impl Tool for McpToolWrapper {
102    fn definition(&self) -> ToolDefinition {
103        self.cached_definition.clone()
104    }
105    
106    async fn execute(&self, params: Value) -> Result<String> {
107        tracing::debug!(
108            "Executing MCP tool '{}' from server '{}'",
109            self.tool_def.name,
110            self.server_name
111        );
112        
113        // 调用 MCP 工具
114        let result = self.client.call_tool(&self.tool_def.name, Some(params)).await
115            .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
116        
117        // 检查是否为错误
118        if result.is_error.unwrap_or(false) {
119            let error_msg = self.parse_result(result);
120            return Err(anyhow!("MCP tool error: {}", error_msg));
121        }
122        
123        // 解析结果
124        Ok(self.parse_result(result))
125    }
126    
127    fn risk_level(&self) -> RiskLevel {
128        // MCP 工具默认为中等风险
129        // 可以根据工具名称或模式设置不同风险级别
130        let name = &self.tool_def.name;
131        
132        // 只读操作为安全
133        if name.contains("read") || name.contains("list") || name.contains("get") {
134            RiskLevel::Safe
135        }
136        // 浏览器操作为中等风险
137        else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
138            RiskLevel::Mutating
139        }
140        // 写入操作为高风险
141        else if name.contains("write") || name.contains("delete") || name.contains("create") {
142            RiskLevel::Dangerous
143        }
144        // 默认中等风险
145        else {
146            RiskLevel::Mutating
147        }
148    }
149}
150
151// ============================================================================
152// MCP Tool Manager
153// ============================================================================
154
155/// MCP 工具管理器 - 管理多个 MCP 服务器的连接
156pub struct McpToolManager {
157    /// 已连接的 MCP 客户端
158    clients: RwLock<Vec<Arc<McpClient>>>,
159}
160
161impl McpToolManager {
162    /// 创建工具管理器
163    pub fn new() -> Self {
164        Self {
165            clients: RwLock::new(Vec::new()),
166        }
167    }
168    
169    /// 连接 MCP 服务器并获取工具
170    pub async fn connect_server(
171        &self,
172        server_name: impl Into<String>,
173        config: super::transport::TransportConfig,
174    ) -> Result<Vec<Box<dyn Tool>>> {
175        let server_name = server_name.into();
176        
177        // 连接服务器
178        let client = Arc::new(McpClient::connect(&server_name, config).await?);
179        
180        // 检查是否支持工具
181        if !client.supports_tools().await {
182            tracing::warn!("MCP server '{}' does not support tools", server_name);
183            return Ok(Vec::new());
184        }
185        
186        // 获取工具列表
187        let mcp_tools = client.list_tools().await?;
188        tracing::info!(
189            "MCP server '{}' provided {} tools",
190            server_name,
191            mcp_tools.len()
192        );
193        
194        // 转换为内置工具
195        let tools: Vec<Box<dyn Tool>> = mcp_tools
196            .into_iter()
197            .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
198            .collect();
199        
200        // 缓存客户端(用于后续 shutdown)
201        self.clients.write().await.push(client);
202        
203        Ok(tools)
204    }
205    
206    /// 获取已连接的服务器数量
207    pub async fn server_count(&self) -> usize {
208        self.clients.read().await.len()
209    }
210    
211    /// 获取所有已连接的服务器名称
212    pub async fn server_names(&self) -> Vec<String> {
213        self.clients.read().await.iter()
214            .map(|c| c.server_name().to_string())
215            .collect()
216    }
217    
218    /// 关闭所有连接
219    pub async fn shutdown(&self) {
220        let clients = self.clients.read().await;
221        for client in clients.iter() {
222            if let Err(e) = client.shutdown().await {
223                tracing::error!("Failed to shutdown MCP server '{}': {}", client.server_name(), e);
224            }
225        }
226    }
227}
228
229impl Default for McpToolManager {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235// ============================================================================
236// Convenience Functions
237// ============================================================================
238
239/// 连接单个 MCP 服务器并返回工具列表
240pub async fn connect_mcp_server(
241    server_name: impl Into<String>,
242    config: super::transport::TransportConfig,
243) -> Result<Vec<Box<dyn Tool>>> {
244    let server_name = server_name.into();
245    let client = McpClient::connect(&server_name, config).await?;
246    
247    if !client.supports_tools().await {
248        client.shutdown().await?;
249        return Ok(Vec::new());
250    }
251    
252    let mcp_tools = client.list_tools().await?;
253    let client = Arc::new(client);
254    
255    let tools: Vec<Box<dyn Tool>> = mcp_tools
256        .into_iter()
257        .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
258        .collect();
259    
260    Ok(tools)
261}
262
263/// 从配置连接所有 MCP 服务器并返回工具列表
264pub async fn connect_mcp_servers_from_config(
265    mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
266) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
267    let manager = McpToolManager::new();
268    let mut all_tools = Vec::new();
269    
270    for (name, config) in mcp_config.iter() {
271        if !config.enabled {
272            tracing::debug!("MCP server '{}' is disabled, skipping", name);
273            continue;
274        }
275        
276        // 从配置创建 TransportConfig
277        let transport_config = config.to_transport_config()
278            .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
279        
280        tracing::info!("Connecting to MCP server '{}'...", name);
281        let tools = manager.connect_server(name, transport_config).await?;
282        
283        if !tools.is_empty() {
284            tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
285            all_tools.extend(tools);
286        }
287    }
288    
289    Ok((all_tools, manager))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_wrapper_definition() {
        // Self-contained keyword heuristic for MCP tool risk levels, checked in
        // order: read-only, browser, destructive, then the default.
        fn contains_any(name: &str, keywords: &[&str]) -> bool {
            keywords.iter().any(|kw| name.contains(*kw))
        }

        fn get_risk_level(name: &str) -> RiskLevel {
            if contains_any(name, &["read", "list", "get"]) {
                return RiskLevel::Safe;
            }
            if contains_any(name, &["browser", "navigate", "click"]) {
                return RiskLevel::Mutating;
            }
            if contains_any(name, &["write", "delete", "create"]) {
                return RiskLevel::Dangerous;
            }
            RiskLevel::Mutating
        }

        let expectations = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (name, expected) in expectations {
            assert_eq!(get_risk_level(name), expected);
        }
    }
}