// rust_agent/mcp/client.rs

// MCP client interface definition
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use anyhow::Error;
use log::{debug, info, warn};
use serde_json::{json, Value};
use uuid::Uuid;

use crate::mcp::{JSONRPCRequest, JSONRPCResponse};

/// A tool that can be invoked through an MCP client, either registered locally or
/// advertised by the server in a `tools/list` response.
#[derive(Clone, Debug)]
pub struct McpTool {
    /// Name the tool is looked up and invoked by.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    // Other tool metadata is not modelled yet.
}

22// Simple MCP client implementation, modify SimpleMcpClient structure, add tool handler field
23#[derive(Clone)]
24pub struct SimpleMcpClient {
25    pub url: String,
26    pub available_tools: Vec<McpTool>,
27    // Use Arc to wrap tool handlers to support cloning
28    pub tool_handlers: HashMap<String, Arc<dyn Fn(HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<Value, Error>> + Send>> + Send + Sync>>,
29    // Connection status flag, indicates whether successfully connected to MCP server
30    pub is_mcp_server_connected: Arc<Mutex<bool>>,
31}
32
33// Implement methods for SimpleMcpClient structure
34impl SimpleMcpClient {
35    pub fn new(url: String) -> Self {
36        Self {
37            url,
38            available_tools: Vec::new(),
39            tool_handlers: HashMap::new(),
40            is_mcp_server_connected: Arc::new(Mutex::new(false)), // Initial state is disconnected
41        }
42    }
43    
44    // Add custom tool method
45    pub fn add_tool(&mut self, tool: McpTool) {
46        self.available_tools.push(tool);
47    }
48    
49    // Register tool handler method
50    pub fn register_tool_handler<F, Fut>(&mut self, tool_name: String, handler: F)
51    where
52        F: Fn(HashMap<String, Value>) -> Fut + Send + Sync + 'static,
53        Fut: Future<Output = Result<Value, Error>> + Send + 'static,
54    {
55        self.tool_handlers.insert(tool_name, Arc::new(move |params| {
56            let params_clone = params.clone();
57            Box::pin(handler(params_clone))
58        }));
59    }
60    
61    // Batch add tools method
62    pub fn add_tools(&mut self, tools: Vec<McpTool>) {
63        self.available_tools.extend(tools);
64    }
65    
66    // Clear tool list method
67    pub fn clear_tools(&mut self) {
68        self.available_tools.clear();
69    }
70    
71    // Set server connection status
72    pub fn set_server_connected(&self, connected: bool) {
73        if let Ok(mut conn_status) = self.is_mcp_server_connected.lock() {
74            *conn_status = connected;
75            if connected {
76                info!("MCP server connection status set to connected");
77            } else {
78                info!("MCP server connection status set to disconnected");
79            }
80        }
81    }
82    
83    // Get server connection status
84    pub fn is_server_connected(&self) -> bool {
85        *self.is_mcp_server_connected.lock().unwrap_or_else(|e| e.into_inner())
86    }
87}
88
89// Implement McpClient trait for SimpleMcpClient
90impl McpClient for SimpleMcpClient {
91    // Connect to MCP server
92    fn connect(&mut self, url: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
93        let url = url.to_string();
94        Box::pin(async move {
95            self.url = url;
96            Ok(())
97        })
98    }
99    
100    // Get available tool list
101    fn get_tools(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<McpTool>, Error>> + Send + '_>> {
102        let url = self.url.clone();
103        let local_tools = self.available_tools.clone();
104        let is_connected = self.is_mcp_server_connected.clone();
105        Box::pin(async move {
106            // First check connection status flag, return local tool list directly if not connected
107            let connected = if let Ok(conn) = is_connected.lock() {
108                *conn
109            } else {
110                false
111            };
112
113            if !connected {
114                warn!("MCP server is not connected, returning local tools only");
115                return Ok(local_tools);
116            }
117            
118            if !url.is_empty() {
119                // Construct JSON-RPC request
120                let request = JSONRPCRequest {
121                    jsonrpc: "2.0".to_string(),
122                    id: Some(Value::String(Uuid::new_v4().to_string())),
123                    method: "tools/list".to_string(),
124                    params: None,
125                };
126
127                // Send HTTP POST request
128                let client = reqwest::Client::new();
129                let response = client
130                    .post(&format!("{}/rpc", url))
131                    .json(&request)
132                    .send()
133                    .await;
134
135                // Check if request was sent successfully
136                match response {
137                    Ok(response) => {
138                        // Check HTTP status code
139                        if !response.status().is_success() {
140                            let status = response.status();
141                            let body = response.text().await.unwrap_or_else(|_| "Unable to read response body".to_string());
142                            warn!("MCP server returned HTTP error {}: {}. Response body: {}", status.as_u16(), status.canonical_reason().unwrap_or("Unknown error"), body);
143                            // Return local tool list when server returns error
144                            return Ok(local_tools);
145                        }
146
147                        // Get response text for debugging
148                        let response_text = response.text().await
149                            .map_err(|e| Error::msg(format!("Failed to read response body: {}", e)))?;
150                        
151                        // Check if response is empty
152                        if response_text.trim().is_empty() {
153                            warn!("MCP server returned empty response");
154                            // Return local tool list when server returns empty response
155                            return Ok(local_tools);
156                        }
157
158                        // Try to parse JSON
159                        let rpc_response: JSONRPCResponse = serde_json::from_str(&response_text)
160                            .map_err(|e| {
161                                warn!("Failed to parse response as JSON: {}. Response content: {}", e, response_text);
162                                // Return local tool list when JSON parsing fails
163                                Error::msg(format!("Failed to parse response as JSON: {}. Response content: {}", e, response_text))
164                            })?;
165                        
166                        // Check for errors
167                        if let Some(error) = rpc_response.error {
168                            warn!("JSON-RPC error: {} (code: {})", error.message, error.code);
169                            // Return local tool list when JSON-RPC returns error
170                            return Ok(local_tools);
171                        }
172                        
173                        // Parse tool list
174                        if let Some(result) = rpc_response.result {
175                            debug!("Server response result: {:?}", result);
176                            if let Some(tools_value) = result.get("tools") {
177                                debug!("Tools value: {:?}", tools_value);
178                                if let Ok(tools_array) = serde_json::from_value::<Vec<serde_json::Value>>(tools_value.clone()) {
179                                    let mut tools = Vec::new();
180                                    // First add local tools to tools
181                                    tools.extend(local_tools);
182                                    for tool_value in tools_array {
183                                        debug!("Processing tool value: {:?}", tool_value);
184                                        if let (Ok(name), Ok(description)) = (
185                                            serde_json::from_value::<String>(tool_value["name"].clone()),
186                                            serde_json::from_value::<String>(tool_value["description"].clone())
187                                        ) {
188                                            tools.push(McpTool {
189                                                name,
190                                                description,
191                                            });
192                                        } else {
193                                            warn!("Failed to parse tool from server response: {:?}", tool_value);
194                                        }
195                                    }
196                                    return Ok(tools);
197                                } else {
198                                    warn!("Failed to parse tools array from server response: {:?}", tools_value);
199                                }
200                            } else {
201                                warn!("No 'tools' field in server response result: {:?}", result);
202                            }
203                        } else {
204                            warn!("No result in JSON-RPC response");
205                        }
206                        
207                        // Return local tool list if parsing fails
208                        warn!("Failed to parse tools from server response");
209                        Ok(local_tools)
210                    }
211                    Err(e) => {
212                        // Return local tool list when unable to connect to server
213                        warn!("Failed to send request to MCP server: {}", e);
214                        Ok(local_tools)
215                    }
216                }
217            } else {
218                // Return local tool list if no URL is set
219                Ok(local_tools)
220            }
221        })
222    }
223    
224    // Call specified tool
225    fn call_tool(&self, tool_name: &str, params: HashMap<String, Value>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, Error>> + Send + '_>> {
226        let url = self.url.clone();
227        let tool_name = tool_name.to_string();
228        let params = params.clone();
229        let handler_opt = self.tool_handlers.get(&tool_name).cloned();
230        Box::pin(async move {
231            // Check if there is a custom tool handler
232            if let Some(handler) = handler_opt {
233                // If there is a custom handler, call it
234                info!("Calling tool {} with params {:?}", tool_name, params);
235                handler(params.clone()).await
236            } else {
237                // Otherwise send JSON-RPC request via HTTP
238                if !url.is_empty() {
239                    // Construct JSON-RPC request
240                    let request = JSONRPCRequest {
241                        jsonrpc: "2.0".to_string(),
242                        id: Some(Value::String(Uuid::new_v4().to_string())),
243                        method: "tools/call".to_string(),
244                        params: Some(json!({
245                            "name": tool_name,
246                            "arguments": params
247                        })),
248                    };
249
250                    // Send HTTP POST request
251                    let client = reqwest::Client::new();
252                    let response = client
253                        .post(&format!("{}/rpc", url))
254                        .json(&request)
255                        .send()
256                        .await?;
257
258                    // Parse response
259                    let rpc_response: JSONRPCResponse = response.json().await?;
260                    
261                    // Check for errors
262                    if let Some(error) = rpc_response.error {
263                        return Err(Error::msg(format!("JSON-RPC error: {} (code: {})", error.message, error.code)));
264                    }
265                    
266                    // Return result
267                    Ok(rpc_response.result.unwrap_or(Value::Null))
268                } else {
269                    // If no URL is set and no custom handler, use default processing logic
270                    match tool_name.as_str() {
271                        "get_weather" => {
272                            // Bind default values to variables to extend lifetime
273                            let default_city = Value::String("Beijing".to_string());
274                            let city_value = params.get("city").unwrap_or(&default_city);
275                            let city = city_value.as_str().unwrap_or("Beijing");
276                            Ok(json!({
277                                "city": city,
278                                "temperature": "25°C",
279                                "weather": "cloudy",
280                                "humidity": "60%"
281                            }))
282                        },
283                        _ => Err(Error::msg(format!("Unknown tool: {}", tool_name)))
284                    }
285                }
286            }
287        })
288    }
289    
290    // Disconnect
291    fn disconnect(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
292        let url = self.url.clone();
293        let is_connected = self.is_mcp_server_connected.clone();
294        Box::pin(async move {
295            // Simple implementation: simulate successful disconnection
296            if let Ok(mut conn) = is_connected.lock() {
297                *conn = false;
298            }
299            info!("Disconnected from MCP server at {}", url);
300            Ok(())
301        })
302    }
303    
304    // Get tool response
305    fn get_response(&self, tool_call_id: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, Error>> + Send + '_>> {
306        let tool_call_id = tool_call_id.to_string();
307        Box::pin(async move {
308            // Simple implementation: return simulated tool response
309            Ok(serde_json::json!({
310                "tool_call_id": tool_call_id,
311                "status": "completed",
312                "response": {
313                    "data": "Sample tool response data"
314                }
315            }))
316        })
317    }
318    
319    // Clone method
320    fn clone(&self) -> Box<dyn McpClient> {
321        // Manually create deep copy of available_tools
322        let tools = self.available_tools.iter().map(|t| McpTool {
323            name: t.name.clone(),
324            description: t.description.clone()
325        }).collect();
326        
327        // Copy tool handlers
328        let tool_handlers = self.tool_handlers.clone();
329
330        // Clone connection status
331        let is_connected = if let Ok(conn) = self.is_mcp_server_connected.lock() {
332            Arc::new(Mutex::new(*conn))
333        } else {
334            Arc::new(Mutex::new(false))
335        };
336        
337        Box::new(SimpleMcpClient {
338            url: self.url.clone(),
339            available_tools: tools,
340            tool_handlers,
341            is_mcp_server_connected: is_connected,
342        })
343    }
344    
345    // Ping服务器
346    fn ping(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
347        let url = self.url.clone();
348        Box::pin(async move {
349            if !url.is_empty() {
350                // 创建 ping 请求
351                let request = JSONRPCRequest {
352                    jsonrpc: "2.0".to_string(),
353                    id: Some(Value::Number(serde_json::Number::from(1))),
354                    method: "ping".to_string(),
355                    params: None,
356                };
357            
358                // 发送请求到服务器 - 使用正确的路径 /rpc
359                let url = format!("{}/rpc", url);
360                let client = reqwest::Client::new();
361                let response = client
362                    .post(&url)
363                    .header("Content-Type", "application/json")
364                    .json(&request)
365                    .send()
366                    .await
367                    .map_err(|e| Error::msg(format!("Failed to send ping request: {}", e)))?;
368            
369                // 检查响应状态
370                if !response.status().is_success() {
371                    return Err(Error::msg(format!("Ping request failed with status: {}", response.status())));
372                }
373            
374                // 解析响应
375                let response_text = response.text().await
376                    .map_err(|e| Error::msg(format!("Failed to read response: {}", e)))?;
377                let response_value: Value = serde_json::from_str(&response_text)
378                    .map_err(|e| Error::msg(format!("Failed to parse response: {}", e)))?;
379            
380                // 检查响应中是否有错误
381                if let Some(error) = response_value.get("error") {
382                    if !error.is_null() {
383                        return Err(Error::msg(format!("Ping request returned error: {}", error)));
384                    }
385                }
386                
387                // 检查是否有结果字段
388                if let Some(_result) = response_value.get("result") {
389                    // Ping 成功,返回空结果
390                    Ok(())
391                } else {
392                    Err(Error::msg("No result in ping response"))
393                }
394            } else {
395                Err(Error::msg("No URL set for MCP client"))
396            }
397        })
398    }
399}
400
401// MCP client interface
402pub trait McpClient: Send + Sync {
403    // Connect to MCP server
404    fn connect(&mut self, _url: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
405        Box::pin(async move {
406            // Simple implementation: simulate successful connection
407            Ok(())
408        })
409    }
410    
411    // Get available tool list
412    fn get_tools(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<McpTool>, Error>> + Send + '_>> {
413        Box::pin(async move {
414            // Simple implementation: return simulated tool list
415            Ok(vec![McpTool {
416                name: "example_tool".to_string(),
417                description: "Example tool description".to_string()
418            }])
419        })
420    }
421    
422    // Call specified tool
423    fn call_tool(&self, tool_name: &str, params: HashMap<String, Value>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, Error>> + Send + '_>> {
424        let _tool_name = tool_name.to_string();
425        let _params = params.clone();
426        Box::pin(async move {
427            // Default implementation returns error because trait doesn't know how to send HTTP requests
428            Err(Error::msg("HTTP client not implemented in trait"))
429        })
430    }
431    
432    // Disconnect
433    fn disconnect(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
434        Box::pin(async move {
435            // Simple implementation: simulate successful disconnection
436            Ok(())
437        })
438    }
439    
440    // Get tool response
441    fn get_response(&self, tool_call_id: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, Error>> + Send + '_>> {
442        let tool_call_id = tool_call_id.to_string();
443        Box::pin(async move {
444            // Simple implementation: return simulated tool response
445            Ok(serde_json::json!({
446                "tool_call_id": tool_call_id,
447                "status": "completed",
448                "response": {
449                    "data": "Sample tool response data"
450                }
451            }))
452        })
453    }
454    
455    // Clone method
456    fn clone(&self) -> Box<dyn McpClient>;
457    
458    // Ping server
459    fn ping(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>> {
460        Box::pin(async move {
461            // Default implementation returns error because trait doesn't know how to send HTTP requests
462            Err(Error::msg("HTTP client not implemented in trait"))
463        })
464    }
465}