Skip to main content

matrixcode_core/mcp/
transport.rs

1//! MCP Transport Layer
2//!
3//! 提供两种传输方式:
4//! - StdioTransport: 通过 stdin/stdout 与子进程通信(最常用)
5//! - SseTransport: 通过 HTTP SSE 连接远程服务器
6
7use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{timeout, Duration};
14
15// ============================================================================
16// Transport Trait
17// ============================================================================
18
19/// MCP 传输层抽象
20#[async_trait]
21pub trait Transport: Send + Sync {
22    /// 发送请求并等待响应
23    async fn send(&self, message: &str) -> Result<String>;
24    
25    /// 发送通知(无需响应)
26    async fn notify(&self, message: &str) -> Result<()>;
27    
28    /// 接收一条消息
29    async fn receive(&self) -> Result<String>;
30    
31    /// 关闭连接
32    async fn close(&self) -> Result<()>;
33}
34
35// ============================================================================
36// Stdio Transport
37// ============================================================================
38
39/// Stdio 传输 - 通过子进程的 stdin/stdout 通信
40pub struct StdioTransport {
41    /// 子进程
42    process: Arc<Mutex<Option<Child>>>,
43    /// 写入端 (进程 stdin)
44    writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45    /// 读取端 (进程 stdout)
46    reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47    /// 服务器名称(用于日志)
48    server_name: String,
49}
50
51impl StdioTransport {
52    /// 启动 MCP 服务器进程
53    pub async fn spawn(
54        name: impl Into<String>,
55        command: &str,
56        args: &[String],
57        env: Option<Vec<(String, String)>>,
58    ) -> Result<Self> {
59        let server_name = name.into();
60        
61        // Windows 兼容性:npx, npm 等需要通过 cmd.exe 运行
62        let (actual_command, actual_args) = if cfg!(target_os = "windows") 
63            && (command == "npx" || command == "npm" || command == "node") {
64            let mut full_args = vec!["/c".to_string(), command.to_string()];
65            full_args.extend(args.iter().cloned());
66            ("cmd.exe".to_string(), full_args)
67        } else {
68            (command.to_string(), args.to_vec())
69        };
70        
71        // 使用 tokio 异步 Command
72        let mut cmd = Command::new(&actual_command);
73        cmd.args(&actual_args)
74            .stdin(std::process::Stdio::piped())
75            .stdout(std::process::Stdio::piped())
76            .stderr(std::process::Stdio::piped())
77            .kill_on_drop(true); // 确保进程在 drop 时被杀死
78        
79        // 设置环境变量
80        if let Some(env_vars) = env {
81            for (key, value) in env_vars {
82                cmd.env(key, value);
83            }
84        }
85        
86        // 启动进程
87        tracing::debug!("Spawning MCP server '{}' with command: {} {:?}", server_name, actual_command, actual_args);
88        let mut child = cmd.spawn()
89            .map_err(|e| anyhow!("Failed to spawn MCP server '{}': {} (command: {} {:?})", 
90                server_name, e, actual_command, actual_args))?;
91        
92        tracing::debug!("MCP server '{}' process spawned successfully", server_name);
93        
94        // 获取 stdin/stdout (tokio 异步版本)
95        let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(child.stdin.take()
96            .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?);
97        let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(child.stdout.take()
98            .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?);
99        
100        tracing::info!("MCP server '{}' started: {} {:?}", server_name, actual_command, actual_args);
101        
102        Ok(Self {
103            process: Arc::new(Mutex::new(Some(child))),
104            writer: Arc::new(Mutex::new(Some(stdin))),
105            reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
106            server_name,
107        })
108    }
109    
110    /// 读取一行响应(带超时)
111    async fn read_line(&self) -> Result<String> {
112        let mut reader_lock = self.reader.lock().await;
113        let reader = reader_lock.as_mut()
114            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
115        
116        let mut line = String::new();
117        
118        // 添加 30 秒超时
119        tracing::debug!("Reading from '{}' (timeout: 30s)...", self.server_name);
120        let read_result = tokio::time::timeout(
121            Duration::from_secs(30),
122            reader.read_line(&mut line)
123        ).await;
124        
125        tracing::debug!("Read result from '{}': {:?}", self.server_name, read_result.is_ok());
126        
127        match read_result {
128            Ok(Ok(_)) => {
129                if line.is_empty() {
130                    return Err(anyhow!("EOF reached for server '{}'", self.server_name));
131                }
132                // 移除换行符
133                Ok(line.trim_end().to_string())
134            }
135            Ok(Err(e)) => Err(anyhow!("Read error for server '{}': {}", self.server_name, e)),
136            Err(_) => Err(anyhow!("Read timeout for server '{}' after 30s", self.server_name)),
137        }
138    }
139}
140
141#[async_trait]
142impl Transport for StdioTransport {
143    async fn send(&self, message: &str) -> Result<String> {
144        tracing::debug!("MCP send to '{}': {}", self.server_name, message.chars().take(200).collect::<String>());
145        
146        let mut writer_lock = self.writer.lock().await;
147        let writer = writer_lock.as_mut()
148            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
149        
150        // 发送请求(带换行符)
151        writer.write_all(format!("{}\n", message).as_bytes()).await?;
152        writer.flush().await?;
153        
154        tracing::debug!("MCP sent, waiting for response from '{}'...", self.server_name);
155        
156        // 等待响应
157        let response = self.read_line().await?;
158        tracing::debug!("MCP received from '{}': {}", self.server_name, response.chars().take(200).collect::<String>());
159        Ok(response)
160    }
161    
162    async fn notify(&self, message: &str) -> Result<()> {
163        let mut writer_lock = self.writer.lock().await;
164        let writer = writer_lock.as_mut()
165            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
166        
167        tracing::info!("MCP >> '{}' : {}", self.server_name, message.chars().take(100).collect::<String>());
168        writer.write_all(format!("{}\n", message).as_bytes()).await?;
169        writer.flush().await?;
170        Ok(())
171    }
172    
173    async fn receive(&self) -> Result<String> {
174        let line = self.read_line().await?;
175        tracing::info!("MCP << '{}' : {}", self.server_name, line.chars().take(100).collect::<String>());
176        Ok(line)
177    }
178    
179    async fn close(&self) -> Result<()> {
180        let mut process_lock = self.process.lock().await;
181        if let Some(mut child) = process_lock.take() {
182            child.kill().await.map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
183            tracing::info!("MCP server '{}' stopped", self.server_name);
184        }
185        
186        *self.writer.lock().await = None;
187        *self.reader.lock().await = None;
188        Ok(())
189    }
190}
191
192// ============================================================================
193// SSE Transport (HTTP)
194// ============================================================================
195
196/// SSE 传输 - 通过 HTTP Server-Sent Events 通信
197pub struct SseTransport {
198    /// 基础 URL
199    base_url: String,
200    /// HTTP 客户端
201    client: reqwest::Client,
202    /// 服务器名称
203    server_name: String,
204    /// 请求超时
205    timeout_ms: u64,
206}
207
208impl SseTransport {
209    /// 创建 SSE 传输
210    pub fn new(
211        name: impl Into<String>,
212        base_url: impl Into<String>,
213        timeout_ms: Option<u64>,
214    ) -> Self {
215        Self {
216            base_url: base_url.into(),
217            client: reqwest::Client::new(),
218            server_name: name.into(),
219            timeout_ms: timeout_ms.unwrap_or(30000),
220        }
221    }
222    
223    /// 发送 HTTP 请求
224    async fn send_http(&self, body: &str) -> Result<String> {
225        let url = format!("{}/mcp", self.base_url);
226        
227        let response = timeout(
228            Duration::from_millis(self.timeout_ms),
229            self.client
230                .post(&url)
231                .header("Content-Type", "application/json")
232                .body(body.to_string())
233                .send()
234        ).await
235            .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
236            .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
237        
238        let text = response.text().await?;
239        Ok(text)
240    }
241}
242
243#[async_trait]
244impl Transport for SseTransport {
245    async fn send(&self, message: &str) -> Result<String> {
246        self.send_http(message).await
247    }
248    
249    async fn notify(&self, message: &str) -> Result<()> {
250        // SSE 通知也是通过 HTTP POST
251        self.send_http(message).await?;
252        Ok(())
253    }
254    
255    async fn receive(&self) -> Result<String> {
256        // SSE 需要等待 HTTP 响应,通常 send 已包含响应
257        // 这里作为简化实现,实际 SSE 场景可能需要单独处理
258        Err(anyhow!("SSE receive not implemented - use send() for request/response"))
259    }
260    
261    async fn close(&self) -> Result<()> {
262        // HTTP 连接无需关闭
263        Ok(())
264    }
265}
266
267// ============================================================================
268// Transport Factory
269// ============================================================================
270
/// Transport configuration.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Settings for a stdio (child process) transport.
    Stdio {
        command: String,
        args: Vec<String>,
        env: Option<Vec<(String, String)>>,
    },
    /// Settings for an SSE (HTTP) transport.
    Sse {
        url: String,
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a stdio configuration with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio { command, args, env: None }
    }

    /// Builds an SSE configuration with no explicit timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse { url, timeout_ms: None }
    }
}
305
306/// 创建传输实例
307pub async fn create_transport(
308    server_name: &str,
309    config: &TransportConfig,
310) -> Result<Box<dyn Transport>> {
311    match config {
312        TransportConfig::Stdio { command, args, env } => {
313            Ok(Box::new(StdioTransport::spawn(
314                server_name,
315                command,
316                args,
317                env.clone(),
318            ).await?))
319        }
320        TransportConfig::Sse { url, timeout_ms } => {
321            Ok(Box::new(SseTransport::new(
322                server_name,
323                url,
324                *timeout_ms,
325            )))
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// `TransportConfig::stdio` keeps the command and argument list as given.
    #[test]
    fn test_transport_config_stdio() {
        let cli_args: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        let config = TransportConfig::stdio("npx", cli_args);
        if let TransportConfig::Stdio { command, args, .. } = config {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio variant");
        }
    }

    /// `TransportConfig::sse` keeps the URL as given.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");
        if let TransportConfig::Sse { url, .. } = config {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected Sse variant");
        }
    }
}