1use anyhow::{Result, anyhow};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{Duration, timeout};
14
15#[async_trait]
21pub trait Transport: Send + Sync {
22 async fn send(&self, message: &str) -> Result<String>;
24
25 async fn notify(&self, message: &str) -> Result<()>;
27
28 async fn receive(&self) -> Result<String>;
30
31 async fn close(&self) -> Result<()>;
33}
34
35pub struct StdioTransport {
41 process: Arc<Mutex<Option<Child>>>,
43 writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45 reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47 server_name: String,
49}
50
51impl StdioTransport {
52 pub async fn spawn(
54 name: impl Into<String>,
55 command: &str,
56 args: &[String],
57 env: Option<Vec<(String, String)>>,
58 ) -> Result<Self> {
59 let server_name = name.into();
60
61 let (actual_command, actual_args) = if cfg!(target_os = "windows")
63 && (command == "npx" || command == "npm" || command == "node")
64 {
65 let mut full_args = vec!["/c".to_string(), command.to_string()];
66 full_args.extend(args.iter().cloned());
67 ("cmd.exe".to_string(), full_args)
68 } else {
69 (command.to_string(), args.to_vec())
70 };
71
72 let mut cmd = Command::new(&actual_command);
74 cmd.args(&actual_args)
75 .stdin(std::process::Stdio::piped())
76 .stdout(std::process::Stdio::piped())
77 .stderr(std::process::Stdio::piped())
78 .kill_on_drop(true); if let Some(env_vars) = env {
82 for (key, value) in env_vars {
83 cmd.env(key, value);
84 }
85 }
86
87 tracing::debug!(
89 "Spawning MCP server '{}' with command: {} {:?}",
90 server_name,
91 actual_command,
92 actual_args
93 );
94 let mut child = cmd.spawn().map_err(|e| {
95 anyhow!(
96 "Failed to spawn MCP server '{}': {} (command: {} {:?})",
97 server_name,
98 e,
99 actual_command,
100 actual_args
101 )
102 })?;
103
104 tracing::debug!("MCP server '{}' process spawned successfully", server_name);
105
106 let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(
108 child
109 .stdin
110 .take()
111 .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?,
112 );
113 let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(
114 child
115 .stdout
116 .take()
117 .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?,
118 );
119
120 tracing::info!(
121 "MCP server '{}' started: {} {:?}",
122 server_name,
123 actual_command,
124 actual_args
125 );
126
127 Ok(Self {
128 process: Arc::new(Mutex::new(Some(child))),
129 writer: Arc::new(Mutex::new(Some(stdin))),
130 reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
131 server_name,
132 })
133 }
134
135 async fn read_line(&self) -> Result<String> {
137 let mut reader_lock = self.reader.lock().await;
138 let reader = reader_lock
139 .as_mut()
140 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
141
142 let mut line = String::new();
143
144 tracing::debug!("Reading from '{}' (timeout: 30s)...", self.server_name);
146 let read_result =
147 tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut line)).await;
148
149 tracing::debug!(
150 "Read result from '{}': {:?}",
151 self.server_name,
152 read_result.is_ok()
153 );
154
155 match read_result {
156 Ok(Ok(_)) => {
157 if line.is_empty() {
158 return Err(anyhow!("EOF reached for server '{}'", self.server_name));
159 }
160 Ok(line.trim_end().to_string())
162 }
163 Ok(Err(e)) => Err(anyhow!(
164 "Read error for server '{}': {}",
165 self.server_name,
166 e
167 )),
168 Err(_) => Err(anyhow!(
169 "Read timeout for server '{}' after 30s",
170 self.server_name
171 )),
172 }
173 }
174}
175
176#[async_trait]
177impl Transport for StdioTransport {
178 async fn send(&self, message: &str) -> Result<String> {
179 tracing::debug!(
180 "MCP send to '{}': {}",
181 self.server_name,
182 message.chars().take(200).collect::<String>()
183 );
184
185 let mut writer_lock = self.writer.lock().await;
186 let writer = writer_lock
187 .as_mut()
188 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
189
190 writer
192 .write_all(format!("{}\n", message).as_bytes())
193 .await?;
194 writer.flush().await?;
195
196 tracing::debug!(
197 "MCP sent, waiting for response from '{}'...",
198 self.server_name
199 );
200
201 let response = self.read_line().await?;
203 tracing::debug!(
204 "MCP received from '{}': {}",
205 self.server_name,
206 response.chars().take(200).collect::<String>()
207 );
208 Ok(response)
209 }
210
211 async fn notify(&self, message: &str) -> Result<()> {
212 let mut writer_lock = self.writer.lock().await;
213 let writer = writer_lock
214 .as_mut()
215 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
216
217 tracing::info!(
218 "MCP >> '{}' : {}",
219 self.server_name,
220 message.chars().take(100).collect::<String>()
221 );
222 writer
223 .write_all(format!("{}\n", message).as_bytes())
224 .await?;
225 writer.flush().await?;
226 Ok(())
227 }
228
229 async fn receive(&self) -> Result<String> {
230 let line = self.read_line().await?;
231 tracing::info!(
232 "MCP << '{}' : {}",
233 self.server_name,
234 line.chars().take(100).collect::<String>()
235 );
236 Ok(line)
237 }
238
239 async fn close(&self) -> Result<()> {
240 let mut process_lock = self.process.lock().await;
241 if let Some(mut child) = process_lock.take() {
242 child
243 .kill()
244 .await
245 .map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
246 tracing::info!("MCP server '{}' stopped", self.server_name);
247 }
248
249 *self.writer.lock().await = None;
250 *self.reader.lock().await = None;
251 Ok(())
252 }
253}
254
255pub struct SseTransport {
261 base_url: String,
263 client: reqwest::Client,
265 server_name: String,
267 timeout_ms: u64,
269}
270
271impl SseTransport {
272 pub fn new(
274 name: impl Into<String>,
275 base_url: impl Into<String>,
276 timeout_ms: Option<u64>,
277 ) -> Self {
278 Self {
279 base_url: base_url.into(),
280 client: reqwest::Client::new(),
281 server_name: name.into(),
282 timeout_ms: timeout_ms.unwrap_or(30000),
283 }
284 }
285
286 async fn send_http(&self, body: &str) -> Result<String> {
288 let url = format!("{}/mcp", self.base_url);
289
290 let response = timeout(
291 Duration::from_millis(self.timeout_ms),
292 self.client
293 .post(&url)
294 .header("Content-Type", "application/json")
295 .body(body.to_string())
296 .send(),
297 )
298 .await
299 .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
300 .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
301
302 let text = response.text().await?;
303 Ok(text)
304 }
305}
306
307#[async_trait]
308impl Transport for SseTransport {
309 async fn send(&self, message: &str) -> Result<String> {
310 self.send_http(message).await
311 }
312
313 async fn notify(&self, message: &str) -> Result<()> {
314 self.send_http(message).await?;
316 Ok(())
317 }
318
319 async fn receive(&self) -> Result<String> {
320 Err(anyhow!(
323 "SSE receive not implemented - use send() for request/response"
324 ))
325 }
326
327 async fn close(&self) -> Result<()> {
328 Ok(())
330 }
331}
332
/// Describes how to connect to an MCP server.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Spawn a local process and exchange line-delimited messages over stdio.
    Stdio {
        /// Executable to launch.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra environment variables for the child, if any.
        env: Option<Vec<(String, String)>>,
    },
    /// Send messages as HTTP requests to a server URL.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Per-request timeout in milliseconds; `None` uses the transport default.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Stdio configuration with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio {
            command,
            args,
            env: None,
        }
    }

    /// HTTP configuration using the transport's default timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse {
            url,
            timeout_ms: None,
        }
    }
}
371
372pub async fn create_transport(
374 server_name: &str,
375 config: &TransportConfig,
376) -> Result<Box<dyn Transport>> {
377 match config {
378 TransportConfig::Stdio { command, args, env } => Ok(Box::new(
379 StdioTransport::spawn(server_name, command, args, env.clone()).await?,
380 )),
381 TransportConfig::Sse { url, timeout_ms } => {
382 Ok(Box::new(SseTransport::new(server_name, url, *timeout_ms)))
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// `stdio()` keeps the command and arguments as given and sets no env vars.
    #[test]
    fn test_transport_config_stdio() {
        let config = TransportConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        let TransportConfig::Stdio { command, args, env } = config else {
            panic!("Expected Stdio variant");
        };
        assert_eq!(command, "npx");
        assert_eq!(args, ["-y", "@playwright/mcp"]);
        assert!(env.is_none(), "stdio() should not set environment variables");
    }

    /// `sse()` keeps the URL as given and leaves the timeout to the transport default.
    #[test]
    fn test_transport_config_sse() {
        let config = TransportConfig::sse("http://localhost:3000");
        let TransportConfig::Sse { url, timeout_ms } = config else {
            panic!("Expected Sse variant");
        };
        assert_eq!(url, "http://localhost:3000");
        assert!(timeout_ms.is_none(), "sse() should not set a timeout");
    }
}