1use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::Mutex;
13use tokio::time::{timeout, Duration};
14
15#[async_trait]
21pub trait Transport: Send + Sync {
22 async fn send(&self, message: &str) -> Result<String>;
24
25 async fn notify(&self, message: &str) -> Result<()>;
27
28 async fn receive(&self) -> Result<String>;
30
31 async fn close(&self) -> Result<()>;
33}
34
35pub struct StdioTransport {
41 process: Arc<Mutex<Option<Child>>>,
43 writer: Arc<Mutex<Option<Box<dyn AsyncWrite + Unpin + Send>>>>,
45 reader: Arc<Mutex<Option<BufReader<Box<dyn AsyncRead + Unpin + Send>>>>>,
47 server_name: String,
49}
50
51impl StdioTransport {
52 pub async fn spawn(
54 name: impl Into<String>,
55 command: &str,
56 args: &[String],
57 env: Option<Vec<(String, String)>>,
58 ) -> Result<Self> {
59 let server_name = name.into();
60
61 let (actual_command, actual_args) = if cfg!(target_os = "windows")
63 && (command == "npx" || command == "npm" || command == "node") {
64 let mut full_args = vec!["/c".to_string(), command.to_string()];
65 full_args.extend(args.iter().cloned());
66 ("cmd.exe".to_string(), full_args)
67 } else {
68 (command.to_string(), args.to_vec())
69 };
70
71 let mut cmd = Command::new(&actual_command);
73 cmd.args(&actual_args)
74 .stdin(std::process::Stdio::piped())
75 .stdout(std::process::Stdio::piped())
76 .stderr(std::process::Stdio::piped())
77 .kill_on_drop(true); if let Some(env_vars) = env {
81 for (key, value) in env_vars {
82 cmd.env(key, value);
83 }
84 }
85
86 let mut child = cmd.spawn()
88 .map_err(|e| anyhow!("Failed to spawn MCP server '{}': {} (command: {} {:?})",
89 server_name, e, actual_command, actual_args))?;
90
91 let stdin: Box<dyn AsyncWrite + Unpin + Send> = Box::new(child.stdin.take()
93 .ok_or_else(|| anyhow!("Failed to get stdin for MCP server '{}'", server_name))?);
94 let stdout: Box<dyn AsyncRead + Unpin + Send> = Box::new(child.stdout.take()
95 .ok_or_else(|| anyhow!("Failed to get stdout for MCP server '{}'", server_name))?);
96
97 tracing::info!("MCP server '{}' started: {} {:?}", server_name, actual_command, actual_args);
98
99 Ok(Self {
100 process: Arc::new(Mutex::new(Some(child))),
101 writer: Arc::new(Mutex::new(Some(stdin))),
102 reader: Arc::new(Mutex::new(Some(BufReader::new(stdout)))),
103 server_name,
104 })
105 }
106
107 async fn read_line(&self) -> Result<String> {
109 let mut reader_lock = self.reader.lock().await;
110 let reader = reader_lock.as_mut()
111 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
112
113 let mut line = String::new();
114 reader.read_line(&mut line).await?;
115
116 if line.is_empty() {
117 return Err(anyhow!("EOF reached for server '{}'", self.server_name));
118 }
119
120 let line = line.trim_end().to_string();
122 Ok(line)
123 }
124}
125
126#[async_trait]
127impl Transport for StdioTransport {
128 async fn send(&self, message: &str) -> Result<String> {
129 let mut writer_lock = self.writer.lock().await;
130 let writer = writer_lock.as_mut()
131 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
132
133 writer.write_all(format!("{}\n", message).as_bytes()).await?;
135 writer.flush().await?;
136
137 let response = self.read_line().await?;
139 Ok(response)
140 }
141
142 async fn notify(&self, message: &str) -> Result<()> {
143 let mut writer_lock = self.writer.lock().await;
144 let writer = writer_lock.as_mut()
145 .ok_or_else(|| anyhow!("Transport closed for server '{}'", self.server_name))?;
146
147 writer.write_all(format!("{}\n", message).as_bytes()).await?;
148 writer.flush().await?;
149 Ok(())
150 }
151
152 async fn receive(&self) -> Result<String> {
153 self.read_line().await
154 }
155
156 async fn close(&self) -> Result<()> {
157 let mut process_lock = self.process.lock().await;
158 if let Some(mut child) = process_lock.take() {
159 child.kill().await.map_err(|e| anyhow!("Failed to kill MCP server '{}': {}", self.server_name, e))?;
160 tracing::info!("MCP server '{}' stopped", self.server_name);
161 }
162
163 *self.writer.lock().await = None;
164 *self.reader.lock().await = None;
165 Ok(())
166 }
167}
168
169pub struct SseTransport {
175 base_url: String,
177 client: reqwest::Client,
179 server_name: String,
181 timeout_ms: u64,
183}
184
185impl SseTransport {
186 pub fn new(
188 name: impl Into<String>,
189 base_url: impl Into<String>,
190 timeout_ms: Option<u64>,
191 ) -> Self {
192 Self {
193 base_url: base_url.into(),
194 client: reqwest::Client::new(),
195 server_name: name.into(),
196 timeout_ms: timeout_ms.unwrap_or(30000),
197 }
198 }
199
200 async fn send_http(&self, body: &str) -> Result<String> {
202 let url = format!("{}/mcp", self.base_url);
203
204 let response = timeout(
205 Duration::from_millis(self.timeout_ms),
206 self.client
207 .post(&url)
208 .header("Content-Type", "application/json")
209 .body(body.to_string())
210 .send()
211 ).await
212 .map_err(|_| anyhow!("Request timeout for MCP server '{}'", self.server_name))?
213 .map_err(|e| anyhow!("HTTP error for MCP server '{}': {}", self.server_name, e))?;
214
215 let text = response.text().await?;
216 Ok(text)
217 }
218}
219
220#[async_trait]
221impl Transport for SseTransport {
222 async fn send(&self, message: &str) -> Result<String> {
223 self.send_http(message).await
224 }
225
226 async fn notify(&self, message: &str) -> Result<()> {
227 self.send_http(message).await?;
229 Ok(())
230 }
231
232 async fn receive(&self) -> Result<String> {
233 Err(anyhow!("SSE receive not implemented - use send() for request/response"))
236 }
237
238 async fn close(&self) -> Result<()> {
239 Ok(())
241 }
242}
243
/// Describes how to reach an MCP server.
#[derive(Clone, Debug)]
pub enum TransportConfig {
    /// Launch a local process and exchange messages over its stdin/stdout.
    Stdio {
        /// Program to run.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra environment variables for the process, if any.
        env: Option<Vec<(String, String)>>,
    },
    /// Exchange messages with a server over HTTP.
    Sse {
        /// Server base URL.
        url: String,
        /// Request timeout in milliseconds; `None` leaves the choice to the transport.
        timeout_ms: Option<u64>,
    },
}

impl TransportConfig {
    /// Builds a [`TransportConfig::Stdio`] with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio { command, args, env: None }
    }

    /// Builds a [`TransportConfig::Sse`] with no explicit timeout.
    pub fn sse(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Sse { url, timeout_ms: None }
    }
}
282
283pub async fn create_transport(
285 server_name: &str,
286 config: &TransportConfig,
287) -> Result<Box<dyn Transport>> {
288 match config {
289 TransportConfig::Stdio { command, args, env } => {
290 Ok(Box::new(StdioTransport::spawn(
291 server_name,
292 command,
293 args,
294 env.clone(),
295 ).await?))
296 }
297 TransportConfig::Sse { url, timeout_ms } => {
298 Ok(Box::new(SseTransport::new(
299 server_name,
300 url,
301 *timeout_ms,
302 )))
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_config_stdio() {
        let requested: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        if let TransportConfig::Stdio { command, args, .. } =
            TransportConfig::stdio("npx", requested)
        {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio variant");
        }
    }

    #[test]
    fn test_transport_config_sse() {
        if let TransportConfig::Sse { url, .. } = TransportConfig::sse("http://localhost:3000") {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected Sse variant");
        }
    }
}