matrixcode_core/mcp/
proxy.rs1use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
17#[derive(Clone)]
23pub struct McpToolWrapper {
24 client: Arc<McpClient>,
26 tool_def: McpTool,
28 server_name: String,
30 cached_definition: ToolDefinition,
32}
33
34impl McpToolWrapper {
35 pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
37 let name = format!("{}_{}", server_name, tool_def.name);
39 let description = tool_def
40 .description
41 .clone()
42 .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
43
44 let cached_definition = ToolDefinition {
45 name: name.clone(),
46 description,
47 parameters: tool_def.input_schema.clone(),
48 is_priority: false,
49 };
50
51 Self {
52 client,
53 tool_def,
54 server_name,
55 cached_definition,
56 }
57 }
58
59 pub fn original_name(&self) -> &str {
61 &self.tool_def.name
62 }
63
64 pub fn server_name(&self) -> &str {
66 &self.server_name
67 }
68
69 fn parse_result(&self, result: CallToolResult) -> String {
71 if result.content.is_empty() {
72 return String::new();
73 }
74
75 let mut output = String::new();
76
77 for content in result.content {
78 match content {
79 Content::Text { text } => {
80 output.push_str(&text);
81 output.push('\n');
82 }
83 Content::Image { data, mime_type } => {
84 output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
85 }
86 Content::Resource { resource } => {
87 if let Some(text) = resource.text {
88 output.push_str(&text);
89 output.push('\n');
90 } else if let Some(blob) = resource.blob {
91 output.push_str(&format!(
92 "[Resource: {} ({} bytes)]\n",
93 resource.uri,
94 blob.len()
95 ));
96 } else {
97 output.push_str(&format!("[Resource: {}]\n", resource.uri));
98 }
99 }
100 }
101 }
102
103 output.trim_end().to_string()
104 }
105}
106
107#[async_trait]
108impl Tool for McpToolWrapper {
109 fn definition(&self) -> ToolDefinition {
110 self.cached_definition.clone()
111 }
112
113 async fn execute(&self, params: Value) -> Result<String> {
114 tracing::debug!(
115 "Executing MCP tool '{}' from server '{}'",
116 self.tool_def.name,
117 self.server_name
118 );
119
120 let result = self
122 .client
123 .call_tool(&self.tool_def.name, Some(params))
124 .await
125 .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
126
127 if result.is_error.unwrap_or(false) {
129 let error_msg = self.parse_result(result);
130 return Err(anyhow!("MCP tool error: {}", error_msg));
131 }
132
133 Ok(self.parse_result(result))
135 }
136
137 fn risk_level(&self) -> RiskLevel {
138 let name = &self.tool_def.name;
141
142 if name.contains("read") || name.contains("list") || name.contains("get") {
144 RiskLevel::Safe
145 }
146 else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
148 RiskLevel::Mutating
149 }
150 else if name.contains("write") || name.contains("delete") || name.contains("create") {
152 RiskLevel::Dangerous
153 }
154 else {
156 RiskLevel::Mutating
157 }
158 }
159}
160
161pub struct McpToolManager {
167 clients: RwLock<Vec<Arc<McpClient>>>,
169}
170
171impl McpToolManager {
172 pub fn new() -> Self {
174 Self {
175 clients: RwLock::new(Vec::new()),
176 }
177 }
178
179 pub async fn connect_server(
181 &self,
182 server_name: impl Into<String>,
183 config: super::transport::TransportConfig,
184 ) -> Result<Vec<Box<dyn Tool>>> {
185 let server_name = server_name.into();
186
187 let client = Arc::new(McpClient::connect(&server_name, config).await?);
189
190 if !client.supports_tools().await {
192 tracing::warn!("MCP server '{}' does not support tools", server_name);
193 return Ok(Vec::new());
194 }
195
196 let mcp_tools = client.list_tools().await?;
198 tracing::info!(
199 "MCP server '{}' provided {} tools",
200 server_name,
201 mcp_tools.len()
202 );
203
204 let tools: Vec<Box<dyn Tool>> = mcp_tools
206 .into_iter()
207 .map(|tool| {
208 Box::new(McpToolWrapper::new(
209 client.clone(),
210 tool,
211 server_name.clone(),
212 )) as Box<dyn Tool>
213 })
214 .collect();
215
216 self.clients.write().await.push(client);
218
219 Ok(tools)
220 }
221
222 pub async fn server_count(&self) -> usize {
224 self.clients.read().await.len()
225 }
226
227 pub async fn server_names(&self) -> Vec<String> {
229 self.clients
230 .read()
231 .await
232 .iter()
233 .map(|c| c.server_name().to_string())
234 .collect()
235 }
236
237 pub async fn shutdown(&self) {
239 let clients = self.clients.read().await;
240 for client in clients.iter() {
241 if let Err(e) = client.shutdown().await {
242 tracing::error!(
243 "Failed to shutdown MCP server '{}': {}",
244 client.server_name(),
245 e
246 );
247 }
248 }
249 }
250}
251
252impl Default for McpToolManager {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub async fn connect_mcp_server(
264 server_name: impl Into<String>,
265 config: super::transport::TransportConfig,
266) -> Result<Vec<Box<dyn Tool>>> {
267 let server_name = server_name.into();
268 let client = McpClient::connect(&server_name, config).await?;
269
270 if !client.supports_tools().await {
271 client.shutdown().await?;
272 return Ok(Vec::new());
273 }
274
275 let mcp_tools = client.list_tools().await?;
276 let client = Arc::new(client);
277
278 let tools: Vec<Box<dyn Tool>> = mcp_tools
279 .into_iter()
280 .map(|tool| {
281 Box::new(McpToolWrapper::new(
282 client.clone(),
283 tool,
284 server_name.clone(),
285 )) as Box<dyn Tool>
286 })
287 .collect();
288
289 Ok(tools)
290}
291
292pub async fn connect_mcp_servers_from_config(
294 mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
295) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
296 let manager = McpToolManager::new();
297 let mut all_tools = Vec::new();
298
299 for (name, config) in mcp_config.iter() {
300 if !config.enabled {
301 tracing::debug!("MCP server '{}' is disabled, skipping", name);
302 continue;
303 }
304
305 let transport_config = config
307 .to_transport_config()
308 .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
309
310 tracing::info!("Connecting to MCP server '{}'...", name);
311 let tools = manager.connect_server(name, transport_config).await?;
312
313 if !tools.is_empty() {
314 tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
315 all_tools.extend(tools);
316 }
317 }
318
319 Ok((all_tools, manager))
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_tool_wrapper_definition() {
        // NOTE(review): this checks a local copy of the name heuristic rather than
        // `McpToolWrapper::risk_level` (which needs a live `McpClient`), so it can
        // drift from the real implementation.
        fn get_risk_level(name: &str) -> RiskLevel {
            let mentions = |keywords: [&str; 3]| keywords.iter().any(|k| name.contains(*k));

            if mentions(["read", "list", "get"]) {
                RiskLevel::Safe
            } else if mentions(["browser", "navigate", "click"]) {
                RiskLevel::Mutating
            } else if mentions(["write", "delete", "create"]) {
                RiskLevel::Dangerous
            } else {
                RiskLevel::Mutating
            }
        }

        let cases = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(&get_risk_level(name), expected, "risk level for {}", name);
        }
    }
}