Skip to main content

matrixcode_core/mcp/
transport.rs

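//! MCP Transport Layer
//!
//! Provides two transports:
//! - StdioTransport: talks to a child process over its stdin/stdout (the most common case)
//! - SseTransport: talks to a remote server over HTTP (plain JSON POST requests; despite the name, no SSE event stream is consumed)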
6
7use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{timeout, Duration};
14
15// ============================================================================
16// Transport Trait
17// ============================================================================
18
19/// MCP 传输层抽象
20#[async_trait]
21pub trait Transport: Send + Sync {
22    /// 发送请求并等待响应
23    async fn send(&self, message: &str) -> Result<String>;
24    
25    /// 发送通知(无需响应)
26    async fn notify(&self, message: &str) -> Result<()>;
27    
28    /// 接收一条消息
29    async fn receive(&self) -> Result<String>;
30    
31    /// 关闭连接
32    async fn close(&self) -> Result<()>;
33}
34
35// ============================================================================
36// Stdio Transport
37// ============================================================================
38
39/// Stdio 传输 - 通过子进程的 stdin/stdout 通信
40pub struct StdioTransport {
41    /// 子进程
42    process: Arc<Mutex<Option<Child>>>,
43    /// 写入端 (进程 stdin)
44    writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45    /// 读取端 (进程 stdout)
46    reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47    /// 服务器名称(用于日志)
48    server_name: String,
49}
50
51impl StdioTransport {
52    /// 启动 MCP 服务器进程
53    pub async fn spawn(
54        name: impl Into<String>,
55        command: &str,
56        args: &[String],
57        env: Option<Vec<(String, String)>>,
58    ) -> Result<Self> {
59        let server_name = name.into();
60        
61        // Windows 兼容性:npx, npm 等需要通过 cmd.exe 运行
62        let (actual_command, actual_args) = if cfg!(target_os = "windows") 
63            && (command == "npx" || command == "npm" || command == "node") {
64            let mut full_args = vec!["/c".to_string(), command.to_string()];
65            full_args.extend(args.iter().cloned());
66            ("cmd.exe".to_string(), full_args)
67        } else {
68            (command.to_string(), args.to_vec())
69        };
70        
71        // 使用 tokio 异步 Command
72        let mut cmd = Command::new(&actual_command);
73        cmd.args(&actual_args)
74            .stdin(std::process::Stdio::piped())
75            .stdout(std::process::Stdio::piped())
76            .stderr(std::process::Stdio::piped())
77            .kill_on_drop(true); // 确保进程在 drop 时被杀死
78        
79        // 设置环境变量
80        if let Some(env_vars) = env {
81            for (key, value) in env_vars {
82                cmd.env(key, value);
83            }
84        }
85        
86        // 启动进程
87        let mut child = cmd.spawn()
88            .map_err(|e| anyhow!("Failed to spawn MCP server '{}': {} (command: {} {:?})", 
89                server_name, e, actual_command, actual_args))?;
90        
91        // 获取 stdin/stdout (tokio 异步版本)
92        let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(child.stdin.take()
93            .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?);
94        let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(child.stdout.take()
95            .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?);
96        
97        tracing::info!("MCP server '{}' started: {} {:?}", server_name, actual_command, actual_args);
98        
99        Ok(Self {
100            process: Arc::new(Mutex::new(Some(child))),
101            writer: Arc::new(Mutex::new(Some(stdin))),
102            reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
103            server_name,
104        })
105    }
106    
107    /// 读取一行响应
108    async fn read_line(&self) -> Result<String> {
109        let mut reader_lock = self.reader.lock().await;
110        let reader = reader_lock.as_mut()
111            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
112        
113        let mut line = String::new();
114        reader.read_line(&mut line).await?;
115        
116        if line.is_empty() {
117            return Err(anyhow!("EOF reached for server '{}'", self.server_name));
118        }
119        
120        // 移除换行符
121        let line = line.trim_end().to_string();
122        Ok(line)
123    }
124}
125
126#[async_trait]
127impl Transport for StdioTransport {
128    async fn send(&self, message: &str) -> Result<String> {
129        let mut writer_lock = self.writer.lock().await;
130        let writer = writer_lock.as_mut()
131            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
132        
133        // 发送请求(带换行符)
134        writer.write_all(format!("{}\n", message).as_bytes()).await?;
135        writer.flush().await?;
136        
137        // 等待响应
138        let response = self.read_line().await?;
139        Ok(response)
140    }
141    
142    async fn notify(&self, message: &str) -> Result<()> {
143        let mut writer_lock = self.writer.lock().await;
144        let writer = writer_lock.as_mut()
145            .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
146        
147        writer.write_all(format!("{}\n", message).as_bytes()).await?;
148        writer.flush().await?;
149        Ok(())
150    }
151    
152    async fn receive(&self) -> Result<String> {
153        self.read_line().await
154    }
155    
156    async fn close(&self) -> Result<()> {
157        let mut process_lock = self.process.lock().await;
158        if let Some(mut child) = process_lock.take() {
159            child.kill().await.map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
160            tracing::info!("MCP server '{}' stopped", self.server_name);
161        }
162        
163        *self.writer.lock().await = None;
164        *self.reader.lock().await = None;
165        Ok(())
166    }
167}
168
169// ============================================================================
170// SSE Transport (HTTP)
171// ============================================================================
172
173/// SSE 传输 - 通过 HTTP Server-Sent Events 通信
174pub struct SseTransport {
175    /// 基础 URL
176    base_url: String,
177    /// HTTP 客户端
178    client: reqwest::Client,
179    /// 服务器名称
180    server_name: String,
181    /// 请求超时
182    timeout_ms: u64,
183}
184
185impl SseTransport {
186    /// 创建 SSE 传输
187    pub fn new(
188        name: impl Into<String>,
189        base_url: impl Into<String>,
190        timeout_ms: Option<u64>,
191    ) -> Self {
192        Self {
193            base_url: base_url.into(),
194            client: reqwest::Client::new(),
195            server_name: name.into(),
196            timeout_ms: timeout_ms.unwrap_or(30000),
197        }
198    }
199    
200    /// 发送 HTTP 请求
201    async fn send_http(&self, body: &str) -> Result<String> {
202        let url = format!("{}/mcp", self.base_url);
203        
204        let response = timeout(
205            Duration::from_millis(self.timeout_ms),
206            self.client
207                .post(&url)
208                .header("Content-Type", "application/json")
209                .body(body.to_string())
210                .send()
211        ).await
212            .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
213            .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
214        
215        let text = response.text().await?;
216        Ok(text)
217    }
218}
219
220#[async_trait]
221impl Transport for SseTransport {
222    async fn send(&self, message: &str) -> Result<String> {
223        self.send_http(message).await
224    }
225    
226    async fn notify(&self, message: &str) -> Result<()> {
227        // SSE 通知也是通过 HTTP POST
228        self.send_http(message).await?;
229        Ok(())
230    }
231    
232    async fn receive(&self) -> Result<String> {
233        // SSE 需要等待 HTTP 响应,通常 send 已包含响应
234        // 这里作为简化实现,实际 SSE 场景可能需要单独处理
235        Err(anyhow!("SSE receive not implemented - use send() for request/response"))
236    }
237    
238    async fn close(&self) -> Result<()> {
239        // HTTP 连接无需关闭
240        Ok(())
241    }
242}
243
244// ============================================================================
245// Transport Factory
246// ============================================================================
247
/// Configuration describing how to reach an MCP server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportConfig {
    /// Launch the server as a child process and talk over its stdin/stdout.
    Stdio {
        /// Executable to run (e.g. `npx`).
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra environment variables for the child process, if any.
        env: Option<Vec<(String, String)>>,
    },
    /// Reach the server over HTTP.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Per-request timeout in milliseconds; `None` uses the transport's default.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a [`TransportConfig::Stdio`] with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        Self::Stdio {
            command: command.into(),
            args,
            env: None,
        }
    }

    /// Builds a [`TransportConfig::Sse`] that uses the default request timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        Self::Sse {
            url: url.into(),
            timeout_ms: None,
        }
    }
}
282
283/// 创建传输实例
284pub async fn create_transport(
285    server_name: &str,
286    config: &TransportConfig,
287) -> Result<Box<dyn Transport>> {
288    match config {
289        TransportConfig::Stdio { command, args, env } => {
290            Ok(Box::new(StdioTransport::spawn(
291                server_name,
292                command,
293                args,
294                env.clone(),
295            ).await?))
296        }
297        TransportConfig::Sse { url, timeout_ms } => {
298            Ok(Box::new(SseTransport::new(
299                server_name,
300                url,
301                *timeout_ms,
302            )))
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// `stdio()` keeps the command and arguments verbatim and sets no env vars.
    #[test]
    fn test_transport_config_stdio() {
        let config = TransportConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        match config {
            TransportConfig::Stdio { command, args, env } => {
                assert_eq!(command, "npx");
                assert_eq!(args, vec!["-y".to_string(), "@playwright/mcp".to_string()]);
                assert!(env.is_none(), "stdio() must not set environment variables");
            }
            other => panic!("Expected Stdio variant, got {:?}", other),
        }
    }

    /// `sse()` keeps the URL verbatim and leaves the timeout unset.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");
        match config {
            TransportConfig::Sse { url, timeout_ms } => {
                assert_eq!(url, "http://localhost:3000");
                assert!(timeout_ms.is_none(), "sse() must leave the timeout at its default");
            }
            other => panic!("Expected Sse variant, got {:?}", other),
        }
    }
}
333}