// matrixcode_core/mcp/client.rs

//! MCP client.
//!
//! Client side of the MCP protocol, responsible for:
//! - connection management
//! - the protocol handshake
//! - tool discovery and invocation
7
8use anyhow::{anyhow, Result};
9use serde_json::{json, Value};
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12
13use super::transport::{create_transport, Transport, TransportConfig};
14use super::types::*;
15
16// ============================================================================
17// MCP Client
18// ============================================================================
19
20/// MCP 客户端
21pub struct McpClient {
22    /// 服务器名称
23    server_name: String,
24    /// 传输层
25    transport: Box<dyn Transport>,
26    /// 服务器能力
27    capabilities: RwLock<Option<ServerCapabilities>>,
28    /// 服务器信息
29    server_info: RwLock<Option<Implementation>>,
30    /// 工具缓存
31    tools_cache: RwLock<Vec<Tool>>,
32    /// 请求 ID 计数器
33    request_id: RwLock<i64>,
34    /// 是否已初始化
35    initialized: RwLock<bool>,
36}
37
38impl McpClient {
39    /// 创建并初始化 MCP 客户端
40    pub async fn connect(
41        server_name: impl Into<String>,
42        config: TransportConfig,
43    ) -> Result<Self> {
44        let server_name = server_name.into();
45        let transport = create_transport(&server_name, &config).await?;
46        
47        let client = Self {
48            server_name,
49            transport,
50            capabilities: RwLock::new(None),
51            server_info: RwLock::new(None),
52            tools_cache: RwLock::new(Vec::new()),
53            request_id: RwLock::new(0),
54            initialized: RwLock::new(false),
55        };
56        
57        // 执行初始化握手
58        client.initialize().await?;
59        
60        Ok(client)
61    }
62    
63    /// 获取服务器名称
64    pub fn server_name(&self) -> &str {
65        &self.server_name
66    }
67    
68    /// 是否已初始化
69    pub async fn is_initialized(&self) -> bool {
70        *self.initialized.read().await
71    }
72    
73    /// 获取服务器能力
74    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
75        self.capabilities.read().await.clone()
76    }
77    
78    /// 获取服务器信息
79    pub async fn server_info(&self) -> Option<Implementation> {
80        self.server_info.read().await.clone()
81    }
82    
83    // ========================================================================
84    // Protocol Methods
85    // ========================================================================
86    
87    /// 生成下一个请求 ID
88    async fn next_request_id(&self) -> RequestId {
89        let mut id = self.request_id.write().await;
90        *id += 1;
91        RequestId::Number(*id)
92    }
93    
94    /// 发送请求并解析响应
95    async fn send_request<T: serde::de::DeserializeOwned>(
96        &self,
97        method: &str,
98        params: Option<Value>,
99    ) -> Result<T> {
100        let id = self.next_request_id().await;
101        
102        let request = JsonRpcRequest {
103            jsonrpc: "2.0".to_string(),
104            id: id.clone(),
105            method: method.to_string(),
106            params,
107        };
108        
109        let message = serde_json::to_string(&request)?;
110        tracing::debug!("MCP request to '{}': {}", self.server_name, message);
111        
112        // 发送请求
113        self.transport.notify(&message).await?;
114        
115        // 循环读取消息直到找到匹配的响应
116        loop {
117            let response = self.transport.receive().await?;
118            tracing::debug!("MCP message from '{}': {}", self.server_name, response);
119            
120            // 尝试解析为服务端请求(如 roots/list)
121            if let Ok(server_req) = serde_json::from_str::<JsonRpcRequest>(&response) {
122                // 处理服务端请求
123                self.handle_server_request(&server_req).await?;
124                continue;
125            }
126            
127            // 尝试解析为成功响应
128            if let Ok(success) = serde_json::from_str::<JsonRpcResponse>(&response) {
129                if success.id != id {
130                    // 不是我们要的响应,继续等待
131                    continue;
132                }
133                return serde_json::from_value(success.result)
134                    .map_err(|e| anyhow!("Failed to parse result: {}", e));
135            }
136            
137            // 尝试解析为错误响应
138            if let Ok(error) = serde_json::from_str::<JsonRpcError>(&response) {
139                if error.id != id {
140                    continue;
141                }
142                return Err(anyhow!(
143                    "MCP error from '{}': [{}] {}",
144                    self.server_name,
145                    error.error.code,
146                    error.error.message
147                ));
148            }
149            
150            // 尝试解析为通知(无 id)
151            if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(&response) {
152                tracing::debug!("MCP notification from '{}': {}", self.server_name, notification.method);
153                continue;
154            }
155            
156            // 无法识别的消息格式
157            tracing::warn!("Unexpected MCP message format: {}", response);
158        }
159    }
160    
161    /// 处理服务端发来的请求
162    async fn handle_server_request(&self, request: &JsonRpcRequest) -> Result<()> {
163        tracing::debug!("MCP server request '{}': {}", self.server_name, request.method);
164        
165        // 根据方法名处理
166        match request.method.as_str() {
167            "roots/list" => {
168                // 返回空的 roots 列表
169                let response = JsonRpcResponse {
170                    jsonrpc: "2.0".to_string(),
171                    id: request.id.clone(),
172                    result: json!({ "roots": [] }),
173                };
174                let message = serde_json::to_string(&response)?;
175                self.transport.notify(&message).await?;
176            }
177            "ping" => {
178                // 响应 pong
179                let response = JsonRpcResponse {
180                    jsonrpc: "2.0".to_string(),
181                    id: request.id.clone(),
182                    result: json!({}),
183                };
184                let message = serde_json::to_string(&response)?;
185                self.transport.notify(&message).await?;
186            }
187            _ => {
188                tracing::warn!("Unhandled MCP server request: {}", request.method);
189                // 返回方法不存在的错误
190                let error_response = JsonRpcError {
191                    jsonrpc: "2.0".to_string(),
192                    id: request.id.clone(),
193                    error: JsonRpcErrorDetail {
194                        code: -32601,
195                        message: "Method not found".to_string(),
196                        data: None,
197                    },
198                };
199                let message = serde_json::to_string(&error_response)?;
200                self.transport.notify(&message).await?;
201            }
202        }
203        
204        Ok(())
205    }
206    
207    /// 发送通知(无需响应)
208    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
209        let notification = JsonRpcNotification {
210            jsonrpc: "2.0".to_string(),
211            method: method.to_string(),
212            params,
213        };
214        
215        let message = serde_json::to_string(&notification)?;
216        self.transport.notify(&message).await?;
217        Ok(())
218    }
219    
220    // ========================================================================
221    // Initialization
222    // ========================================================================
223    
224    /// 执行初始化握手
225    async fn initialize(&self) -> Result<()> {
226        tracing::info!("Initializing MCP server '{}'", self.server_name);
227        
228        // 发送 initialize 请求
229        let params = InitializeParams {
230            capabilities: ClientCapabilities {
231                roots: Some(RootsCapability {
232                    list_changed: Some(false),
233                }),
234                ..Default::default()
235            },
236            client_info: Implementation::default(),
237            protocol_version: Some("2024-11-05".to_string()),
238        };
239        
240        let result: InitializeResult = self.send_request(
241            "initialize",
242            Some(serde_json::to_value(params)?),
243        ).await?;
244        
245        // 保存服务器信息(先 clone 用于日志)
246        let server_name = result.server_info.name.clone();
247        let server_version = result.server_info.version.clone();
248        
249        *self.capabilities.write().await = Some(result.capabilities);
250        *self.server_info.write().await = Some(result.server_info);
251        
252        tracing::info!(
253            "MCP server '{}' initialized: {} v{}",
254            self.server_name,
255            server_name,
256            server_version
257        );
258        
259        // 发送 initialized 通知
260        self.send_notification("notifications/initialized", None).await?;
261        
262        *self.initialized.write().await = true;
263        Ok(())
264    }
265    
266    // ========================================================================
267    // Tools API
268    // ========================================================================
269    
270    /// 列出所有工具
271    pub async fn list_tools(&self) -> Result<Vec<Tool>> {
272        if !self.is_initialized().await {
273            return Err(anyhow!("MCP client not initialized"));
274        }
275        
276        let result: ListToolsResult = self.send_request("tools/list", None).await?;
277        
278        // 缓存工具列表
279        *self.tools_cache.write().await = result.tools.clone();
280        
281        Ok(result.tools)
282    }
283    
284    /// 获取缓存的工具列表
285    pub async fn cached_tools(&self) -> Vec<Tool> {
286        self.tools_cache.read().await.clone()
287    }
288    
289    /// 调用工具
290    pub async fn call_tool(
291        &self,
292        name: &str,
293        arguments: Option<Value>,
294    ) -> Result<CallToolResult> {
295        if !self.is_initialized().await {
296            return Err(anyhow!("MCP client not initialized"));
297        }
298        
299        let params = CallToolParams {
300            name: name.to_string(),
301            arguments,
302        };
303        
304        self.send_request("tools/call", Some(serde_json::to_value(params)?)).await
305    }
306    
307    /// 检查服务器是否支持工具
308    pub async fn supports_tools(&self) -> bool {
309        self.capabilities.read().await
310            .as_ref()
311            .map(|c| c.tools.is_some())
312            .unwrap_or(false)
313    }
314    
315    // ========================================================================
316    // Resources API (Optional)
317    // ========================================================================
318    
319    /// 列出所有资源
320    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
321        if !self.is_initialized().await {
322            return Err(anyhow!("MCP client not initialized"));
323        }
324        
325        let result: ListResourcesResult = self.send_request("resources/list", None).await?;
326        Ok(result.resources)
327    }
328    
329    /// 读取资源
330    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
331        if !self.is_initialized().await {
332            return Err(anyhow!("MCP client not initialized"));
333        }
334        
335        self.send_request("resources/read", Some(json!({ "uri": uri }))).await
336    }
337    
338    /// 检查服务器是否支持资源
339    pub async fn supports_resources(&self) -> bool {
340        self.capabilities.read().await
341            .as_ref()
342            .map(|c| c.resources.is_some())
343            .unwrap_or(false)
344    }
345    
346    // ========================================================================
347    // Prompts API (Optional)
348    // ========================================================================
349    
350    /// 列出所有 prompt
351    pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
352        if !self.is_initialized().await {
353            return Err(anyhow!("MCP client not initialized"));
354        }
355        
356        let result: ListPromptsResult = self.send_request("prompts/list", None).await?;
357        Ok(result.prompts)
358    }
359    
360    /// 获取 prompt
361    pub async fn get_prompt(&self, name: &str, arguments: Option<HashMap<String, String>>) -> Result<Value> {
362        if !self.is_initialized().await {
363            return Err(anyhow!("MCP client not initialized"));
364        }
365        
366        let mut params = json!({ "name": name });
367        if let Some(args) = arguments {
368            params["arguments"] = serde_json::to_value(args)?;
369        }
370        
371        self.send_request("prompts/get", Some(params)).await
372    }
373    
374    /// 检查服务器是否支持 prompt
375    pub async fn supports_prompts(&self) -> bool {
376        self.capabilities.read().await
377            .as_ref()
378            .map(|c| c.prompts.is_some())
379            .unwrap_or(false)
380    }
381    
382    // ========================================================================
383    // Logging API
384    // ========================================================================
385    
386    /// 设置日志级别
387    pub async fn set_logging_level(&self, level: LogLevel) -> Result<()> {
388        if !self.is_initialized().await {
389            return Err(anyhow!("MCP client not initialized"));
390        }
391        
392        let params = SetLoggingLevelParams { level };
393        self.send_request("logging/setLevel", Some(serde_json::to_value(params)?)).await
394    }
395    
396    // ========================================================================
397    // Lifecycle
398    // ========================================================================
399    
400    /// 关���连接
401    pub async fn shutdown(&self) -> Result<()> {
402        tracing::info!("Shutting down MCP server '{}'", self.server_name);
403        self.transport.close().await
404    }
405}
406
407// ============================================================================
408// MCP Client Builder
409// ============================================================================
410
411/// MCP 客户端构建器
412pub struct McpClientBuilder {
413    server_name: String,
414    config: TransportConfig,
415}
416
417impl McpClientBuilder {
418    /// 创建构建器
419    pub fn new(name: impl Into<String>) -> Self {
420        Self {
421            server_name: name.into(),
422            config: TransportConfig::stdio("", vec![]),
423        }
424    }
425    
426    /// 使用 stdio 传输
427    pub fn stdio(mut self, command: impl Into<String>, args: Vec<String>) -> Self {
428        self.config = TransportConfig::stdio(command, args);
429        self
430    }
431    
432    /// 使用 SSE 传输
433    pub fn sse(mut self, url: impl Into<String>) -> Self {
434        self.config = TransportConfig::sse(url);
435        self
436    }
437    
438    /// 构建并连接
439    pub async fn connect(self) -> Result<McpClient> {
440        McpClient::connect(self.server_name, self.config).await
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Choosing a stdio transport keeps the registered server name.
    #[test]
    fn test_client_builder() {
        let builder = McpClientBuilder::new("test")
            .stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);

        assert_eq!(builder.server_name, "test");
    }

    /// Switching transports (stdio, then SSE) never touches the registered server name.
    #[test]
    fn test_client_builder_sse_keeps_name() {
        let builder = McpClientBuilder::new("remote")
            .stdio("npx", vec![])
            .sse("http://localhost:8080/sse");

        assert_eq!(builder.server_name, "remote");
    }
}