1use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{timeout, Duration};
14
15#[async_trait]
21pub trait Transport: Send + Sync {
22 async fn send(&self, message: &str) -> Result<String>;
24
25 async fn notify(&self, message: &str) -> Result<()>;
27
28 async fn receive(&self) -> Result<String>;
30
31 async fn close(&self) -> Result<()>;
33}
34
35pub struct StdioTransport {
41 process: Arc<Mutex<Option<Child>>>,
43 writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45 reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47 server_name: String,
49}
50
51impl StdioTransport {
52 pub async fn spawn(
54 name: impl Into<String>,
55 command: &str,
56 args: &[String],
57 env: Option<Vec<(String, String)>>,
58 ) -> Result<Self> {
59 let server_name = name.into();
60
61 let (actual_command, actual_args) = if cfg!(target_os = "windows")
63 && (command == "npx" || command == "npm" || command == "node") {
64 let mut full_args = vec!["/c".to_string(), command.to_string()];
65 full_args.extend(args.iter().cloned());
66 ("cmd.exe".to_string(), full_args)
67 } else {
68 (command.to_string(), args.to_vec())
69 };
70
71 let mut cmd = Command::new(&actual_command);
73 cmd.args(&actual_args)
74 .stdin(std::process::Stdio::piped())
75 .stdout(std::process::Stdio::piped())
76 .stderr(std::process::Stdio::piped())
77 .kill_on_drop(true); if let Some(env_vars) = env {
81 for (key, value) in env_vars {
82 cmd.env(key, value);
83 }
84 }
85
86 tracing::debug!("Spawning MCP server '{}' with command: {} {:?}", server_name, actual_command, actual_args);
88 let mut child = cmd.spawn()
89 .map_err(|e| anyhow!("Failed to spawn MCP server '{}': {} (command: {} {:?})",
90 server_name, e, actual_command, actual_args))?;
91
92 tracing::debug!("MCP server '{}' process spawned successfully", server_name);
93
94 let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(child.stdin.take()
96 .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?);
97 let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(child.stdout.take()
98 .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?);
99
100 tracing::info!("MCP server '{}' started: {} {:?}", server_name, actual_command, actual_args);
101
102 Ok(Self {
103 process: Arc::new(Mutex::new(Some(child))),
104 writer: Arc::new(Mutex::new(Some(stdin))),
105 reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
106 server_name,
107 })
108 }
109
110 async fn read_line(&self) -> Result<String> {
112 let mut reader_lock = self.reader.lock().await;
113 let reader = reader_lock.as_mut()
114 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
115
116 let mut line = String::new();
117
118 tracing::debug!("Reading from '{}' (timeout: 30s)...", self.server_name);
120 let read_result = tokio::time::timeout(
121 Duration::from_secs(30),
122 reader.read_line(&mut line)
123 ).await;
124
125 tracing::debug!("Read result from '{}': {:?}", self.server_name, read_result.is_ok());
126
127 match read_result {
128 Ok(Ok(_)) => {
129 if line.is_empty() {
130 return Err(anyhow!("EOF reached for server '{}'", self.server_name));
131 }
132 Ok(line.trim_end().to_string())
134 }
135 Ok(Err(e)) => Err(anyhow!("Read error for server '{}': {}", self.server_name, e)),
136 Err(_) => Err(anyhow!("Read timeout for server '{}' after 30s", self.server_name)),
137 }
138 }
139}
140
141#[async_trait]
142impl Transport for StdioTransport {
143 async fn send(&self, message: &str) -> Result<String> {
144 tracing::debug!("MCP send to '{}': {}", self.server_name, message.chars().take(200).collect::<String>());
145
146 let mut writer_lock = self.writer.lock().await;
147 let writer = writer_lock.as_mut()
148 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
149
150 writer.write_all(format!("{}\n", message).as_bytes()).await?;
152 writer.flush().await?;
153
154 tracing::debug!("MCP sent, waiting for response from '{}'...", self.server_name);
155
156 let response = self.read_line().await?;
158 tracing::debug!("MCP received from '{}': {}", self.server_name, response.chars().take(200).collect::<String>());
159 Ok(response)
160 }
161
162 async fn notify(&self, message: &str) -> Result<()> {
163 let mut writer_lock = self.writer.lock().await;
164 let writer = writer_lock.as_mut()
165 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
166
167 tracing::info!("MCP >> '{}' : {}", self.server_name, message.chars().take(100).collect::<String>());
168 writer.write_all(format!("{}\n", message).as_bytes()).await?;
169 writer.flush().await?;
170 Ok(())
171 }
172
173 async fn receive(&self) -> Result<String> {
174 let line = self.read_line().await?;
175 tracing::info!("MCP << '{}' : {}", self.server_name, line.chars().take(100).collect::<String>());
176 Ok(line)
177 }
178
179 async fn close(&self) -> Result<()> {
180 let mut process_lock = self.process.lock().await;
181 if let Some(mut child) = process_lock.take() {
182 child.kill().await.map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
183 tracing::info!("MCP server '{}' stopped", self.server_name);
184 }
185
186 *self.writer.lock().await = None;
187 *self.reader.lock().await = None;
188 Ok(())
189 }
190}
191
192pub struct SseTransport {
198 base_url: String,
200 client: reqwest::Client,
202 server_name: String,
204 timeout_ms: u64,
206}
207
208impl SseTransport {
209 pub fn new(
211 name: impl Into<String>,
212 base_url: impl Into<String>,
213 timeout_ms: Option<u64>,
214 ) -> Self {
215 Self {
216 base_url: base_url.into(),
217 client: reqwest::Client::new(),
218 server_name: name.into(),
219 timeout_ms: timeout_ms.unwrap_or(30000),
220 }
221 }
222
223 async fn send_http(&self, body: &str) -> Result<String> {
225 let url = format!("{}/mcp", self.base_url);
226
227 let response = timeout(
228 Duration::from_millis(self.timeout_ms),
229 self.client
230 .post(&url)
231 .header("Content-Type", "application/json")
232 .body(body.to_string())
233 .send()
234 ).await
235 .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
236 .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
237
238 let text = response.text().await?;
239 Ok(text)
240 }
241}
242
243#[async_trait]
244impl Transport for SseTransport {
245 async fn send(&self, message: &str) -> Result<String> {
246 self.send_http(message).await
247 }
248
249 async fn notify(&self, message: &str) -> Result<()> {
250 self.send_http(message).await?;
252 Ok(())
253 }
254
255 async fn receive(&self) -> Result<String> {
256 Err(anyhow!("SSE receive not implemented - use send() for request/response"))
259 }
260
261 async fn close(&self) -> Result<()> {
262 Ok(())
264 }
265}
266
/// Describes how to reach an MCP server.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Run the server as a child process and talk to it over stdio.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Optional extra environment variables as `(key, value)` pairs.
        env: Option<Vec<(String, String)>>,
    },
    /// Talk to an already-running server over HTTP.
    Sse {
        /// Server root URL.
        url: String,
        /// Optional request deadline in milliseconds.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Creates a stdio configuration with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        TransportConfig::Stdio {
            command,
            args,
            env: None,
        }
    }

    /// Creates an HTTP configuration that leaves the timeout unset.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        TransportConfig::Sse {
            url,
            timeout_ms: None,
        }
    }
}
305
306pub async fn create_transport(
308 server_name: &str,
309 config: &TransportConfig,
310) -> Result<Box<dyn Transport>> {
311 match config {
312 TransportConfig::Stdio { command, args, env } => {
313 Ok(Box::new(StdioTransport::spawn(
314 server_name,
315 command,
316 args,
317 env.clone(),
318 ).await?))
319 }
320 TransportConfig::Sse { url, timeout_ms } => {
321 Ok(Box::new(SseTransport::new(
322 server_name,
323 url,
324 *timeout_ms,
325 )))
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// `TransportConfig::stdio` keeps the command and every argument.
    #[test]
    fn test_transport_config_stdio() {
        let npx_args: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        let config = TransportConfig::stdio("npx", npx_args);

        if let TransportConfig::Stdio { command, args, .. } = config {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio variant");
        }
    }

    /// `TransportConfig::sse` keeps the URL it was given.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");

        if let TransportConfig::Sse { url, .. } = config {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected Sse variant");
        }
    }
}