matrixcode_core/mcp/
proxy.rs1use anyhow::{anyhow, Result};
6use async_trait::async_trait;
7use serde_json::Value;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::approval::RiskLevel;
12use crate::tools::{Tool, ToolDefinition};
13
14use super::client::McpClient;
15use super::types::{CallToolResult, Content, Tool as McpTool};
16
17#[derive(Clone)]
23pub struct McpToolWrapper {
24 client: Arc<McpClient>,
26 tool_def: McpTool,
28 server_name: String,
30 cached_definition: ToolDefinition,
32}
33
34impl McpToolWrapper {
35 pub fn new(client: Arc<McpClient>, tool_def: McpTool, server_name: String) -> Self {
37 let name = format!("{}_{}", server_name, tool_def.name);
39 let description = tool_def.description.clone()
40 .unwrap_or_else(|| format!("MCP tool: {}", tool_def.name));
41
42 let cached_definition = ToolDefinition {
43 name: name.clone(),
44 description,
45 parameters: tool_def.input_schema.clone(),
46 is_priority: false,
47 };
48
49 Self {
50 client,
51 tool_def,
52 server_name,
53 cached_definition,
54 }
55 }
56
57 pub fn original_name(&self) -> &str {
59 &self.tool_def.name
60 }
61
62 pub fn server_name(&self) -> &str {
64 &self.server_name
65 }
66
67 fn parse_result(&self, result: CallToolResult) -> String {
69 if result.content.is_empty() {
70 return String::new();
71 }
72
73 let mut output = String::new();
74
75 for content in result.content {
76 match content {
77 Content::Text { text } => {
78 output.push_str(&text);
79 output.push('\n');
80 }
81 Content::Image { data, mime_type } => {
82 output.push_str(&format!("[Image: {} ({} bytes)]\n", mime_type, data.len()));
83 }
84 Content::Resource { resource } => {
85 if let Some(text) = resource.text {
86 output.push_str(&text);
87 output.push('\n');
88 } else if let Some(blob) = resource.blob {
89 output.push_str(&format!("[Resource: {} ({} bytes)]\n", resource.uri, blob.len()));
90 } else {
91 output.push_str(&format!("[Resource: {}]\n", resource.uri));
92 }
93 }
94 }
95 }
96
97 output.trim_end().to_string()
98 }
99}
100
101#[async_trait]
102impl Tool for McpToolWrapper {
103 fn definition(&self) -> ToolDefinition {
104 self.cached_definition.clone()
105 }
106
107 async fn execute(&self, params: Value) -> Result<String> {
108 tracing::debug!(
109 "Executing MCP tool '{}' from server '{}'",
110 self.tool_def.name,
111 self.server_name
112 );
113
114 let result = self.client.call_tool(&self.tool_def.name, Some(params)).await
116 .map_err(|e| anyhow!("MCP tool '{}' failed: {}", self.cached_definition.name, e))?;
117
118 if result.is_error.unwrap_or(false) {
120 let error_msg = self.parse_result(result);
121 return Err(anyhow!("MCP tool error: {}", error_msg));
122 }
123
124 Ok(self.parse_result(result))
126 }
127
128 fn risk_level(&self) -> RiskLevel {
129 let name = &self.tool_def.name;
132
133 if name.contains("read") || name.contains("list") || name.contains("get") {
135 RiskLevel::Safe
136 }
137 else if name.contains("browser") || name.contains("navigate") || name.contains("click") {
139 RiskLevel::Mutating
140 }
141 else if name.contains("write") || name.contains("delete") || name.contains("create") {
143 RiskLevel::Dangerous
144 }
145 else {
147 RiskLevel::Mutating
148 }
149 }
150}
151
152pub struct McpToolManager {
158 clients: RwLock<Vec<Arc<McpClient>>>,
160}
161
162impl McpToolManager {
163 pub fn new() -> Self {
165 Self {
166 clients: RwLock::new(Vec::new()),
167 }
168 }
169
170 pub async fn connect_server(
172 &self,
173 server_name: impl Into<String>,
174 config: super::transport::TransportConfig,
175 ) -> Result<Vec<Box<dyn Tool>>> {
176 let server_name = server_name.into();
177
178 let client = Arc::new(McpClient::connect(&server_name, config).await?);
180
181 if !client.supports_tools().await {
183 tracing::warn!("MCP server '{}' does not support tools", server_name);
184 return Ok(Vec::new());
185 }
186
187 let mcp_tools = client.list_tools().await?;
189 tracing::info!(
190 "MCP server '{}' provided {} tools",
191 server_name,
192 mcp_tools.len()
193 );
194
195 let tools: Vec<Box<dyn Tool>> = mcp_tools
197 .into_iter()
198 .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
199 .collect();
200
201 self.clients.write().await.push(client);
203
204 Ok(tools)
205 }
206
207 pub async fn server_count(&self) -> usize {
209 self.clients.read().await.len()
210 }
211
212 pub async fn server_names(&self) -> Vec<String> {
214 self.clients.read().await.iter()
215 .map(|c| c.server_name().to_string())
216 .collect()
217 }
218
219 pub async fn shutdown(&self) {
221 let clients = self.clients.read().await;
222 for client in clients.iter() {
223 if let Err(e) = client.shutdown().await {
224 tracing::error!("Failed to shutdown MCP server '{}': {}", client.server_name(), e);
225 }
226 }
227 }
228}
229
230impl Default for McpToolManager {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236pub async fn connect_mcp_server(
242 server_name: impl Into<String>,
243 config: super::transport::TransportConfig,
244) -> Result<Vec<Box<dyn Tool>>> {
245 let server_name = server_name.into();
246 let client = McpClient::connect(&server_name, config).await?;
247
248 if !client.supports_tools().await {
249 client.shutdown().await?;
250 return Ok(Vec::new());
251 }
252
253 let mcp_tools = client.list_tools().await?;
254 let client = Arc::new(client);
255
256 let tools: Vec<Box<dyn Tool>> = mcp_tools
257 .into_iter()
258 .map(|tool| Box::new(McpToolWrapper::new(client.clone(), tool, server_name.clone())) as Box<dyn Tool>)
259 .collect();
260
261 Ok(tools)
262}
263
264pub async fn connect_mcp_servers_from_config(
266 mcp_config: &std::collections::HashMap<String, super::config::McpServerConfig>,
267) -> Result<(Vec<Box<dyn Tool>>, McpToolManager)> {
268 let manager = McpToolManager::new();
269 let mut all_tools = Vec::new();
270
271 for (name, config) in mcp_config.iter() {
272 if !config.enabled {
273 tracing::debug!("MCP server '{}' is disabled, skipping", name);
274 continue;
275 }
276
277 let transport_config = config.to_transport_config()
279 .map_err(|e| anyhow!("Failed to create transport config for '{}': {}", name, e))?;
280
281 tracing::info!("Connecting to MCP server '{}'...", name);
282 let tools = manager.connect_server(name, transport_config).await?;
283
284 if !tools.is_empty() {
285 tracing::info!("MCP server '{}' provided {} tools", name, tools.len());
286 all_tools.extend(tools);
287 }
288 }
289
290 Ok((all_tools, manager))
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks a standalone copy of the name-based risk heuristic.
    ///
    /// NOTE(review): this does not call `McpToolWrapper::risk_level` (building a
    /// wrapper needs a live `McpClient`), so it cannot detect drift in that method.
    #[test]
    fn test_mcp_tool_wrapper_definition() {
        fn mentions_any(name: &str, needles: &[&str]) -> bool {
            needles.iter().any(|needle| name.contains(*needle))
        }

        fn get_risk_level(name: &str) -> RiskLevel {
            if mentions_any(name, &["read", "list", "get"]) {
                RiskLevel::Safe
            } else if mentions_any(name, &["browser", "navigate", "click"]) {
                RiskLevel::Mutating
            } else if mentions_any(name, &["write", "delete", "create"]) {
                RiskLevel::Dangerous
            } else {
                RiskLevel::Mutating
            }
        }

        let cases = [
            ("read_file", RiskLevel::Safe),
            ("browser_navigate", RiskLevel::Mutating),
            ("write_file", RiskLevel::Dangerous),
        ];
        for (name, expected) in cases {
            assert_eq!(get_risk_level(name), expected);
        }
    }
}