Skip to main content

matrixcode_core/mcp/
transport.rs

1//! MCP Transport Layer
2//!
3//! 提供两种传输方式:
4//! - StdioTransport: 通过 stdin/stdout 与子进程通信(最常用)
5//! - SseTransport: 通过 HTTP SSE 连接远程服务器
6
7use anyhow::{Result, anyhow};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{Duration, timeout};
14
15// ============================================================================
16// Transport Trait
17// ============================================================================
18
19/// MCP 传输层抽象
20#[async_trait]
21pub trait Transport: Send + Sync {
22    /// 发送请求并等待响应
23    async fn send(&self, message: &str) -> Result<String>;
24
25    /// 发送通知(无需响应)
26    async fn notify(&self, message: &str) -> Result<()>;
27
28    /// 接收一条消息
29    async fn receive(&self) -> Result<String>;
30
31    /// 关闭连接
32    async fn close(&self) -> Result<()>;
33}
34
35// ============================================================================
36// Stdio Transport
37// ============================================================================
38
39/// Stdio 传输 - 通过子进程的 stdin/stdout 通信
40pub struct StdioTransport {
41    /// 子进程
42    process: Arc<Mutex<Option<Child>>>,
43    /// 写入端 (进程 stdin)
44    writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45    /// 读取端 (进程 stdout)
46    reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47    /// 服务器名称(用于日志)
48    server_name: String,
49}
50
51impl StdioTransport {
52    /// 启动 MCP 服务器进程
53    pub async fn spawn(
54        name: impl Into<String>,
55        command: &str,
56        args: &[String],
57        env: Option<Vec<(String, String)>>,
58    ) -> Result<Self> {
59        let server_name = name.into();
60
61        // Windows 兼容性:npx, npm 等需要通过 cmd.exe 运行
62        let (actual_command, actual_args) = if cfg!(target_os = "windows")
63            && (command == "npx" || command == "npm" || command == "node")
64        {
65            let mut full_args = vec!["/c".to_string(), command.to_string()];
66            full_args.extend(args.iter().cloned());
67            ("cmd.exe".to_string(), full_args)
68        } else {
69            (command.to_string(), args.to_vec())
70        };
71
72        // 使用 tokio 异步 Command
73        let mut cmd = Command::new(&actual_command);
74        cmd.args(&actual_args)
75            .stdin(std::process::Stdio::piped())
76            .stdout(std::process::Stdio::piped())
77            .stderr(std::process::Stdio::piped())
78            .kill_on_drop(true); // 确保进程在 drop 时被杀死
79
80        // 设置环境变量
81        if let Some(env_vars) = env {
82            for (key, value) in env_vars {
83                cmd.env(key, value);
84            }
85        }
86
87        // 启动进程
88        tracing::debug!(
89            "Spawning MCP server '{}' with command: {} {:?}",
90            server_name,
91            actual_command,
92            actual_args
93        );
94        let mut child = cmd.spawn().map_err(|e| {
95            anyhow!(
96                "Failed to spawn MCP server '{}': {} (command: {} {:?})",
97                server_name,
98                e,
99                actual_command,
100                actual_args
101            )
102        })?;
103
104        tracing::debug!("MCP server '{}' process spawned successfully", server_name);
105
106        // 获取 stdin/stdout (tokio 异步版本)
107        let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(
108            child
109                .stdin
110                .take()
111                .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?,
112        );
113        let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(
114            child
115                .stdout
116                .take()
117                .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?,
118        );
119
120        tracing::info!(
121            "MCP server '{}' started: {} {:?}",
122            server_name,
123            actual_command,
124            actual_args
125        );
126
127        Ok(Self {
128            process: Arc::new(Mutex::new(Some(child))),
129            writer: Arc::new(Mutex::new(Some(stdin))),
130            reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
131            server_name,
132        })
133    }
134
135    /// 读取一行响应(带超时)
136    async fn read_line(&self) -> Result<String> {
137        let mut reader_lock = self.reader.lock().await;
138        let reader = reader_lock
139            .as_mut()
140            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
141
142        let mut line = String::new();
143
144        // 添加 30 秒超时
145        tracing::debug!("Reading from '{}' (timeout: 30s)...", self.server_name);
146        let read_result =
147            tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut line)).await;
148
149        tracing::debug!(
150            "Read result from '{}': {:?}",
151            self.server_name,
152            read_result.is_ok()
153        );
154
155        match read_result {
156            Ok(Ok(_)) => {
157                if line.is_empty() {
158                    return Err(anyhow!("EOF reached for server '{}'", self.server_name));
159                }
160                // 移除换行符
161                Ok(line.trim_end().to_string())
162            }
163            Ok(Err(e)) => Err(anyhow!(
164                "Read error for server '{}': {}",
165                self.server_name,
166                e
167            )),
168            Err(_) => Err(anyhow!(
169                "Read timeout for server '{}' after 30s",
170                self.server_name
171            )),
172        }
173    }
174}
175
176#[async_trait]
177impl Transport for StdioTransport {
178    async fn send(&self, message: &str) -> Result<String> {
179        tracing::debug!(
180            "MCP send to '{}': {}",
181            self.server_name,
182            message.chars().take(200).collect::<String>()
183        );
184
185        let mut writer_lock = self.writer.lock().await;
186        let writer = writer_lock
187            .as_mut()
188            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
189
190        // 发送请求(带换行符)
191        writer
192            .write_all(format!("{}\n", message).as_bytes())
193            .await?;
194        writer.flush().await?;
195
196        tracing::debug!(
197            "MCP sent, waiting for response from '{}'...",
198            self.server_name
199        );
200
201        // 等待响应
202        let response = self.read_line().await?;
203        tracing::debug!(
204            "MCP received from '{}': {}",
205            self.server_name,
206            response.chars().take(200).collect::<String>()
207        );
208        Ok(response)
209    }
210
211    async fn notify(&self, message: &str) -> Result<()> {
212        let mut writer_lock = self.writer.lock().await;
213        let writer = writer_lock
214            .as_mut()
215            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
216
217        tracing::info!(
218            "MCP >> '{}' : {}",
219            self.server_name,
220            message.chars().take(100).collect::<String>()
221        );
222        writer
223            .write_all(format!("{}\n", message).as_bytes())
224            .await?;
225        writer.flush().await?;
226        Ok(())
227    }
228
229    async fn receive(&self) -> Result<String> {
230        let line = self.read_line().await?;
231        tracing::info!(
232            "MCP << '{}' : {}",
233            self.server_name,
234            line.chars().take(100).collect::<String>()
235        );
236        Ok(line)
237    }
238
239    async fn close(&self) -> Result<()> {
240        let mut process_lock = self.process.lock().await;
241        if let Some(mut child) = process_lock.take() {
242            child
243                .kill()
244                .await
245                .map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
246            tracing::info!("MCP server '{}' stopped", self.server_name);
247        }
248
249        *self.writer.lock().await = None;
250        *self.reader.lock().await = None;
251        Ok(())
252    }
253}
254
255// ============================================================================
256// SSE Transport (HTTP)
257// ============================================================================
258
259/// SSE 传输 - 通过 HTTP Server-Sent Events 通信
260pub struct SseTransport {
261    /// 基础 URL
262    base_url: String,
263    /// HTTP 客户端
264    client: reqwest::Client,
265    /// 服务器名称
266    server_name: String,
267    /// 请求超时
268    timeout_ms: u64,
269}
270
271impl SseTransport {
272    /// 创建 SSE 传输
273    pub fn new(
274        name: impl Into<String>,
275        base_url: impl Into<String>,
276        timeout_ms: Option<u64>,
277    ) -> Self {
278        Self {
279            base_url: base_url.into(),
280            client: reqwest::Client::new(),
281            server_name: name.into(),
282            timeout_ms: timeout_ms.unwrap_or(30000),
283        }
284    }
285
286    /// 发送 HTTP 请求
287    async fn send_http(&self, body: &str) -> Result<String> {
288        let url = format!("{}/mcp", self.base_url);
289
290        let response = timeout(
291            Duration::from_millis(self.timeout_ms),
292            self.client
293                .post(&url)
294                .header("Content-Type", "application/json")
295                .body(body.to_string())
296                .send(),
297        )
298        .await
299        .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
300        .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
301
302        let text = response.text().await?;
303        Ok(text)
304    }
305}
306
307#[async_trait]
308impl Transport for SseTransport {
309    async fn send(&self, message: &str) -> Result<String> {
310        self.send_http(message).await
311    }
312
313    async fn notify(&self, message: &str) -> Result<()> {
314        // SSE 通知也是通过 HTTP POST
315        self.send_http(message).await?;
316        Ok(())
317    }
318
319    async fn receive(&self) -> Result<String> {
320        // SSE 需要等待 HTTP 响应,通常 send 已包含响应
321        // 这里作为简化实现,实际 SSE 场景可能需要单独处理
322        Err(anyhow!(
323            "SSE receive not implemented - use send() for request/response"
324        ))
325    }
326
327    async fn close(&self) -> Result<()> {
328        // HTTP 连接无需关闭
329        Ok(())
330    }
331}
332
333// ============================================================================
334// Transport Factory
335// ============================================================================
336
/// Transport configuration: which transport to use for an MCP server and how to reach it.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Launch a local server process and talk to it over stdin/stdout.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra `(name, value)` environment variables for the child process, if any.
        env: Option<Vec<(String, String)>>,
    },
    /// Talk to a remote server over HTTP.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Request timeout in milliseconds; `None` uses the transport's default.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a [`TransportConfig::Stdio`] with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio {
            command,
            args,
            env: None,
        }
    }

    /// Builds a [`TransportConfig::Sse`] that uses the default timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse {
            url,
            timeout_ms: None,
        }
    }
}
371
372/// 创建传输实例
373pub async fn create_transport(
374    server_name: &str,
375    config: &TransportConfig,
376) -> Result<Box<dyn Transport>> {
377    match config {
378        TransportConfig::Stdio { command, args, env } => Ok(Box::new(
379            StdioTransport::spawn(server_name, command, args, env.clone()).await?,
380        )),
381        TransportConfig::Sse { url, timeout_ms } => {
382            Ok(Box::new(SseTransport::new(server_name, url, *timeout_ms)))
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// `TransportConfig::stdio` keeps the command and all of its arguments.
    #[test]
    fn test_transport_config_stdio() {
        let config = TransportConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        let TransportConfig::Stdio { command, args, .. } = config else {
            panic!("Expected Stdio variant");
        };
        assert_eq!(command, "npx");
        assert_eq!(args.len(), 2);
    }

    /// `TransportConfig::sse` stores the URL verbatim.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");
        let TransportConfig::Sse { url, .. } = config else {
            panic!("Expected Sse variant");
        };
        assert_eq!(url, "http://localhost:3000");
    }
}