matrixcode_core/mcp/
proxy.rs1use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
/// Adapter that exposes a single tool of an MCP server through the local
/// [`Tool`] trait.
pub struct McpToolWrapper {
    /// Shared client used by `execute` to call the tool on its server.
    client: Arc<McpClient>,
    /// Tool description as handed to `new` (name, optional description,
    /// input schema).
    tool_def: McpTool,
    /// Name of the server this tool belongs to; also the prefix of the
    /// exposed tool name.
    server_name: String,
    /// Definition built once in `new`; `definition()` returns clones of it.
    cached_definition: ToolDefinition,
}
32
33impl McpToolWrapper {
34 pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
36 let name = format!("{}_{}", server_name, tool_def.name);
38 let description = tool_def.description.clone()
39 .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
40
41 let cached_definition = ToolDefinition {
42 name: name.clone(),
43 description,
44 parameters: tool_def.input_schema.clone(),
45 is_priority: false,
46 };
47
48 Self {
49 client,
50 tool_def,
51 server_name,
52 cached_definition,
53 }
54 }
55
56 pub fn original_name(&self) -> &str {
58 &self.tool_def.name
59 }
60
61 pub fn server_name(&self) -> &str {
63 &self.server_name
64 }
65
66 fn parse_result(&self, result: CallToolResult) -> String {
68 if result.content.is_empty() {
69 return String::new();
70 }
71
72 let mut output = String::new();
73
74 for content in result.content {
75 match content {
76 Content::Text { text } => {
77 output.push_str(&text);
78 output.push('\n');
79 }
80 Content::Image { data, mime_type } => {
81 output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
82 }
83 Content::Resource { resource } => {
84 if let Some(text) = resource.text {
85 output.push_str(&text);
86 output.push('\n');
87 } else if let Some(blob) = resource.blob {
88 output.push_str(&format!("[Resource: {} ({} bytes)]\n", resource.uri, blob.len()));
89 } else {
90 output.push_str(&format!("[Resource: {}]\n", resource.uri));
91 }
92 }
93 }
94 }
95
96 output.trim_end().to_string()
97 }
98}
99
100#[async_trait]
101impl Tool for McpToolWrapper {
102 fn definition(&self) -> ToolDefinition {
103 self.cached_definition.clone()
104 }
105
106 async fn execute(&self, params: Value) -> Result<String> {
107 tracing::debug!(
108 "Executing MCP tool '{}' from server '{}'",
109 self.tool_def.name,
110 self.server_name
111 );
112
113 let result = self.client.call_tool(&self.tool_def.name, Some(params)).await
115 .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
116
117 if result.is_error.unwrap_or(false) {
119 let error_msg = self.parse_result(result);
120 return Err(anyhow!("MCP tool error: {}", error_msg));
121 }
122
123 Ok(self.parse_result(result))
125 }
126
127 fn risk_level(&self) -> RiskLevel {
128 let name = &self.tool_def.name;
131
132 if name.contains("read") || name.contains("list") || name.contains("get") {
134 RiskLevel::Safe
135 }
136 else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
138 RiskLevel::Mutating
139 }
140 else if name.contains("write") || name.contains("delete") || name.contains("create") {
142 RiskLevel::Dangerous
143 }
144 else {
146 RiskLevel::Mutating
147 }
148 }
149}
150
/// Keeps track of the MCP clients opened through `connect_server` so they can
/// be counted, listed by name and shut down together.
pub struct McpToolManager {
    /// Clients registered so far; guarded by an async read/write lock.
    clients: RwLock<Vec<Arc<McpClient>>>,
}
160
161impl McpToolManager {
162 pub fn new() -> Self {
164 Self {
165 clients: RwLock::new(Vec::new()),
166 }
167 }
168
169 pub async fn connect_server(
171 &self,
172 server_name: impl Into<String>,
173 config: super::transport::TransportConfig,
174 ) -> Result<Vec<Box<dyn Tool>>> {
175 let server_name = server_name.into();
176
177 let client = Arc::new(McpClient::connect(&server_name, config).await?);
179
180 if !client.supports_tools().await {
182 tracing::warn!("MCP server '{}' does not support tools", server_name);
183 return Ok(Vec::new());
184 }
185
186 let mcp_tools = client.list_tools().await?;
188 tracing::info!(
189 "MCP server '{}' provided {} tools",
190 server_name,
191 mcp_tools.len()
192 );
193
194 let tools: Vec<Box<dyn Tool>> = mcp_tools
196 .into_iter()
197 .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
198 .collect();
199
200 self.clients.write().await.push(client);
202
203 Ok(tools)
204 }
205
206 pub async fn server_count(&self) -> usize {
208 self.clients.read().await.len()
209 }
210
211 pub async fn server_names(&self) -> Vec<String> {
213 self.clients.read().await.iter()
214 .map(|c| c.server_name().to_string())
215 .collect()
216 }
217
218 pub async fn shutdown(&self) {
220 let clients = self.clients.read().await;
221 for client in clients.iter() {
222 if let Err(e) = client.shutdown().await {
223 tracing::error!("Failed to shutdown MCP server '{}': {}", client.server_name(), e);
224 }
225 }
226 }
227}
228
229impl Default for McpToolManager {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235pub async fn connect_mcp_server(
241 server_name: impl Into<String>,
242 config: super::transport::TransportConfig,
243) -> Result<Vec<Box<dyn Tool>>> {
244 let server_name = server_name.into();
245 let client = McpClient::connect(&server_name, config).await?;
246
247 if !client.supports_tools().await {
248 client.shutdown().await?;
249 return Ok(Vec::new());
250 }
251
252 let mcp_tools = client.list_tools().await?;
253 let client = Arc::new(client);
254
255 let tools: Vec<Box<dyn Tool>> = mcp_tools
256 .into_iter()
257 .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
258 .collect();
259
260 Ok(tools)
261}
262
263pub async fn connect_mcp_servers_from_config(
265 mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
266) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
267 let manager = McpToolManager::new();
268 let mut all_tools = Vec::new();
269
270 for (name, config) in mcp_config.iter() {
271 if !config.enabled {
272 tracing::debug!("MCP server '{}' is disabled, skipping", name);
273 continue;
274 }
275
276 let transport_config = config.to_transport_config()
278 .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
279
280 tracing::info!("Connecting to MCP server '{}'...", name);
281 let tools = manager.connect_server(name, transport_config).await?;
282
283 if !tools.is_empty() {
284 tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
285 all_tools.extend(tools);
286 }
287 }
288
289 Ok((all_tools, manager))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises a local stand-in for the keyword-based risk heuristic, since
    /// a real wrapper needs a connected client.
    #[test]
    fn test_mcp_tool_wrapper_definition() {
        fn get_risk_level(name: &str) -> RiskLevel {
            let has_any = |keywords: &[&str]| keywords.iter().any(|kw| name.contains(*kw));

            if has_any(&["read", "list", "get"]) {
                return RiskLevel::Safe;
            }
            if has_any(&["browser", "navigate", "click"]) {
                return RiskLevel::Mutating;
            }
            if has_any(&["write", "delete", "create"]) {
                return RiskLevel::Dangerous;
            }
            RiskLevel::Mutating
        }

        let expectations = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (tool_name, expected) in expectations {
            assert_eq!(get_risk_level(tool_name), expected);
        }
    }
}