Skip to main content

matrixcode_core/mcp/
client.rs

1//! MCP Client
2//!
3//! MCP 协议客户端,负责:
4//! - 连接管理
5//! - 协议握手
6//! - 工具发现与调用
7
8use anyhow::{Result, anyhow};
9use serde_json::{Value, json};
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12
13use super::transport::{Transport, TransportConfig, create_transport};
14use super::types::*;
15
16// ============================================================================
17// MCP Client
18// ============================================================================
19
20/// MCP 客户端
21pub struct McpClient {
22    /// 服务器名称
23    server_name: String,
24    /// 传输层
25    transport: Box<dyn Transport>,
26    /// 服务器能力
27    capabilities: RwLock<Option<ServerCapabilities>>,
28    /// 服务器信息
29    server_info: RwLock<Option<Implementation>>,
30    /// 工具缓存
31    tools_cache: RwLock<Vec<Tool>>,
32    /// 请求 ID 计数器
33    request_id: RwLock<i64>,
34    /// 是否已初始化
35    initialized: RwLock<bool>,
36}
37
38impl McpClient {
39    /// 创建并初始化 MCP 客户端
40    pub async fn connect(server_name: impl Into<String>, config: TransportConfig) -> Result<Self> {
41        let server_name = server_name.into();
42        let transport = create_transport(&server_name, &config).await?;
43
44        let client = Self {
45            server_name,
46            transport,
47            capabilities: RwLock::new(None),
48            server_info: RwLock::new(None),
49            tools_cache: RwLock::new(Vec::new()),
50            request_id: RwLock::new(0),
51            initialized: RwLock::new(false),
52        };
53
54        // 执行初始化握手
55        client.initialize().await?;
56
57        Ok(client)
58    }
59
60    /// 获取服务器名称
61    pub fn server_name(&self) -> &str {
62        &self.server_name
63    }
64
65    /// 是否已初始化
66    pub async fn is_initialized(&self) -> bool {
67        *self.initialized.read().await
68    }
69
70    /// 获取服务器能力
71    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
72        self.capabilities.read().await.clone()
73    }
74
75    /// 获取服务器信息
76    pub async fn server_info(&self) -> Option<Implementation> {
77        self.server_info.read().await.clone()
78    }
79
80    // ========================================================================
81    // Protocol Methods
82    // ========================================================================
83
84    /// 生成下一个请求 ID
85    async fn next_request_id(&self) -> RequestId {
86        let mut id = self.request_id.write().await;
87        *id += 1;
88        RequestId::Number(*id)
89    }
90
91    /// 发送请求并解析响应
92    async fn send_request<T: serde::de::DeserializeOwned>(
93        &self,
94        method: &str,
95        params: Option<Value>,
96    ) -> Result<T> {
97        let id = self.next_request_id().await;
98
99        let request = JsonRpcRequest {
100            jsonrpc: "2.0".to_string(),
101            id: id.clone(),
102            method: method.to_string(),
103            params,
104        };
105
106        let message = serde_json::to_string(&request)?;
107        tracing::debug!("MCP request to '{}': {}", self.server_name, message);
108
109        // 发送请求
110        self.transport.notify(&message).await?;
111
112        // 循环读取消息直到找到匹配的响应
113        loop {
114            let response = self.transport.receive().await?;
115            tracing::debug!("MCP message from '{}': {}", self.server_name, response);
116
117            // 尝试解析为服务端请求(如 roots/list)
118            if let Ok(server_req) = serde_json::from_str::<JsonRpcRequest>(&response) {
119                // 处理服务端请求
120                self.handle_server_request(&server_req).await?;
121                continue;
122            }
123
124            // 尝试解析为成功响应
125            if let Ok(success) = serde_json::from_str::<JsonRpcResponse>(&response) {
126                if success.id != id {
127                    // 不是我们要的响应,继续等待
128                    continue;
129                }
130                return serde_json::from_value(success.result)
131                    .map_err(|e| anyhow!("Failed to parse result: {}", e));
132            }
133
134            // 尝试解析为错误响应
135            if let Ok(error) = serde_json::from_str::<JsonRpcError>(&response) {
136                if error.id != id {
137                    continue;
138                }
139                return Err(anyhow!(
140                    "MCP error from '{}': [{}] {}",
141                    self.server_name,
142                    error.error.code,
143                    error.error.message
144                ));
145            }
146
147            // 尝试解析为通知(无 id)
148            if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(&response) {
149                tracing::debug!(
150                    "MCP notification from '{}': {}",
151                    self.server_name,
152                    notification.method
153                );
154                continue;
155            }
156
157            // 无法识别的消息格式
158            tracing::warn!("Unexpected MCP message format: {}", response);
159        }
160    }
161
162    /// 处理服务端发来的请求
163    async fn handle_server_request(&self, request: &JsonRpcRequest) -> Result<()> {
164        tracing::debug!(
165            "MCP server request '{}': {}",
166            self.server_name,
167            request.method
168        );
169
170        // 根据方法名处理
171        match request.method.as_str() {
172            "roots/list" => {
173                // 返回空的 roots 列表
174                let response = JsonRpcResponse {
175                    jsonrpc: "2.0".to_string(),
176                    id: request.id.clone(),
177                    result: json!({ "roots": [] }),
178                };
179                let message = serde_json::to_string(&response)?;
180                self.transport.notify(&message).await?;
181            }
182            "ping" => {
183                // 响应 pong
184                let response = JsonRpcResponse {
185                    jsonrpc: "2.0".to_string(),
186                    id: request.id.clone(),
187                    result: json!({}),
188                };
189                let message = serde_json::to_string(&response)?;
190                self.transport.notify(&message).await?;
191            }
192            _ => {
193                tracing::warn!("Unhandled MCP server request: {}", request.method);
194                // 返回方法不存在的错误
195                let error_response = JsonRpcError {
196                    jsonrpc: "2.0".to_string(),
197                    id: request.id.clone(),
198                    error: JsonRpcErrorDetail {
199                        code: -32601,
200                        message: "Method not found".to_string(),
201                        data: None,
202                    },
203                };
204                let message = serde_json::to_string(&error_response)?;
205                self.transport.notify(&message).await?;
206            }
207        }
208
209        Ok(())
210    }
211
212    /// 发送通知(无需响应)
213    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
214        let notification = JsonRpcNotification {
215            jsonrpc: "2.0".to_string(),
216            method: method.to_string(),
217            params,
218        };
219
220        let message = serde_json::to_string(&notification)?;
221        self.transport.notify(&message).await?;
222        Ok(())
223    }
224
225    // ========================================================================
226    // Initialization
227    // ========================================================================
228
229    /// 执行初始化握手
230    async fn initialize(&self) -> Result<()> {
231        tracing::info!("Initializing MCP server '{}'", self.server_name);
232
233        // 发送 initialize 请求
234        let params = InitializeParams {
235            capabilities: ClientCapabilities {
236                roots: Some(RootsCapability {
237                    list_changed: Some(false),
238                }),
239                ..Default::default()
240            },
241            client_info: Implementation::default(),
242            protocol_version: Some("2024-11-05".to_string()),
243        };
244
245        let result: InitializeResult = self
246            .send_request("initialize", Some(serde_json::to_value(params)?))
247            .await?;
248
249        // 保存服务器信息(先 clone 用于日志)
250        let server_name = result.server_info.name.clone();
251        let server_version = result.server_info.version.clone();
252
253        *self.capabilities.write().await = Some(result.capabilities);
254        *self.server_info.write().await = Some(result.server_info);
255
256        tracing::info!(
257            "MCP server '{}' initialized: {} v{}",
258            self.server_name,
259            server_name,
260            server_version
261        );
262
263        // 发送 initialized 通知
264        self.send_notification("notifications/initialized", None)
265            .await?;
266
267        *self.initialized.write().await = true;
268        Ok(())
269    }
270
271    // ========================================================================
272    // Tools API
273    // ========================================================================
274
275    /// 列出所有工具
276    pub async fn list_tools(&self) -> Result<Vec<Tool>> {
277        if !self.is_initialized().await {
278            return Err(anyhow!("MCP client not initialized"));
279        }
280
281        let result: ListToolsResult = self.send_request("tools/list", None).await?;
282
283        // 缓存工具列表
284        *self.tools_cache.write().await = result.tools.clone();
285
286        Ok(result.tools)
287    }
288
289    /// 获取缓存的工具列表
290    pub async fn cached_tools(&self) -> Vec<Tool> {
291        self.tools_cache.read().await.clone()
292    }
293
294    /// 调用工具
295    pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
296        if !self.is_initialized().await {
297            return Err(anyhow!("MCP client not initialized"));
298        }
299
300        let params = CallToolParams {
301            name: name.to_string(),
302            arguments,
303        };
304
305        self.send_request("tools/call", Some(serde_json::to_value(params)?))
306            .await
307    }
308
309    /// 检查服务器是否支持工具
310    pub async fn supports_tools(&self) -> bool {
311        self.capabilities
312            .read()
313            .await
314            .as_ref()
315            .map(|c| c.tools.is_some())
316            .unwrap_or(false)
317    }
318
319    // ========================================================================
320    // Resources API (Optional)
321    // ========================================================================
322
323    /// 列出所有资源
324    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
325        if !self.is_initialized().await {
326            return Err(anyhow!("MCP client not initialized"));
327        }
328
329        let result: ListResourcesResult = self.send_request("resources/list", None).await?;
330        Ok(result.resources)
331    }
332
333    /// 读取资源
334    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
335        if !self.is_initialized().await {
336            return Err(anyhow!("MCP client not initialized"));
337        }
338
339        self.send_request("resources/read", Some(json!({ "uri": uri })))
340            .await
341    }
342
343    /// 检查服务器是否支持资源
344    pub async fn supports_resources(&self) -> bool {
345        self.capabilities
346            .read()
347            .await
348            .as_ref()
349            .map(|c| c.resources.is_some())
350            .unwrap_or(false)
351    }
352
353    // ========================================================================
354    // Prompts API (Optional)
355    // ========================================================================
356
357    /// 列出所有 prompt
358    pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
359        if !self.is_initialized().await {
360            return Err(anyhow!("MCP client not initialized"));
361        }
362
363        let result: ListPromptsResult = self.send_request("prompts/list", None).await?;
364        Ok(result.prompts)
365    }
366
367    /// 获取 prompt
368    pub async fn get_prompt(
369        &self,
370        name: &str,
371        arguments: Option<HashMap<String, String>>,
372    ) -> Result<Value> {
373        if !self.is_initialized().await {
374            return Err(anyhow!("MCP client not initialized"));
375        }
376
377        let mut params = json!({ "name": name });
378        if let Some(args) = arguments {
379            params["arguments"] = serde_json::to_value(args)?;
380        }
381
382        self.send_request("prompts/get", Some(params)).await
383    }
384
385    /// 检查服务器是否支持 prompt
386    pub async fn supports_prompts(&self) -> bool {
387        self.capabilities
388            .read()
389            .await
390            .as_ref()
391            .map(|c| c.prompts.is_some())
392            .unwrap_or(false)
393    }
394
395    // ========================================================================
396    // Logging API
397    // ========================================================================
398
399    /// 设置日志级别
400    pub async fn set_logging_level(&self, level: LogLevel) -> Result<()> {
401        if !self.is_initialized().await {
402            return Err(anyhow!("MCP client not initialized"));
403        }
404
405        let params = SetLoggingLevelParams { level };
406        self.send_request("logging/setLevel", Some(serde_json::to_value(params)?))
407            .await
408    }
409
410    // ========================================================================
411    // Lifecycle
412    // ========================================================================
413
414    /// 关���连接
415    pub async fn shutdown(&self) -> Result<()> {
416        tracing::info!("Shutting down MCP server '{}'", self.server_name);
417        self.transport.close().await
418    }
419}
420
421// ============================================================================
422// MCP Client Builder
423// ============================================================================
424
425/// MCP 客户端构建器
426pub struct McpClientBuilder {
427    server_name: String,
428    config: TransportConfig,
429}
430
431impl McpClientBuilder {
432    /// 创建构建器
433    pub fn new(name: impl Into<String>) -> Self {
434        Self {
435            server_name: name.into(),
436            config: TransportConfig::stdio("", vec![]),
437        }
438    }
439
440    /// 使用 stdio 传输
441    pub fn stdio(mut self, command: impl Into<String>, args: Vec<String>) -> Self {
442        self.config = TransportConfig::stdio(command, args);
443        self
444    }
445
446    /// 使用 SSE 传输
447    pub fn sse(mut self, url: impl Into<String>) -> Self {
448        self.config = TransportConfig::sse(url);
449        self
450    }
451
452    /// 构建并连接
453    pub async fn connect(self) -> Result<McpClient> {
454        McpClient::connect(self.server_name, self.config).await
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the server name it was created with after a transport is chosen.
    #[test]
    fn test_client_builder() {
        let npx_args: Vec<String> = ["-y", "@playwright/mcp"]
            .iter()
            .map(|arg| arg.to_string())
            .collect();
        let builder = McpClientBuilder::new("test").stdio("npx", npx_args);

        assert_eq!(builder.server_name, "test");
    }
}