1use async_trait::async_trait;
7use serde_json::Value;
8use std::io::{BufRead, BufReader, Write};
9use std::process::{Child, Command, Stdio};
10use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::sync::oneshot;
14use tokio::time::timeout;
15
16use crate::protocol::*;
17use crate::transport::{InitializeParams, McpTransport, McpTransportError, TransportTypeId};
18
19pub struct StdioTransport {
21 process: Arc<Mutex<Child>>,
22 next_id: Arc<AtomicI64>,
23 alive: Arc<AtomicBool>,
24}
25
26impl StdioTransport {
27 pub fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
29 Self::spawn_with_env(command, args, std::collections::HashMap::new())
30 }
31
32 pub fn spawn_with_env(
34 command: &str,
35 args: &[String],
36 env: std::collections::HashMap<String, String>,
37 ) -> Result<Self, McpTransportError> {
38 let mut cmd = Command::new(command);
39 cmd.args(args)
40 .stdin(Stdio::piped())
41 .stdout(Stdio::piped())
42 .stderr(Stdio::piped());
43
44 for (key, value) in env {
45 cmd.env(key, value);
46 }
47
48 let child = cmd.spawn().map_err(|e| {
49 McpTransportError::TransportError(format!(
50 "Failed to spawn process '{}': {}",
51 command, e
52 ))
53 })?;
54
55 let mut process = child;
57 if let Some(status) = process.try_wait().map_err(|e| {
58 McpTransportError::TransportError(format!("Process check failed: {}", e))
59 })? {
60 return Err(McpTransportError::TransportError(format!(
61 "Process exited immediately with status: {}",
62 status
63 )));
64 }
65
66 Ok(Self {
67 process: Arc::new(Mutex::new(process)),
68 next_id: Arc::new(AtomicI64::new(1)),
69 alive: Arc::new(AtomicBool::new(true)),
70 })
71 }
72
73 pub fn send_request_sync(
75 &self,
76 method: &str,
77 params: Option<Value>,
78 ) -> Result<Value, McpTransportError> {
79 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
80 let request = JsonRpcRequest::new(JsonRpcId::Number(id), method, params);
81
82 let mut process = self
83 .process
84 .lock()
85 .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
86
87 let stdin = process
89 .stdin
90 .as_mut()
91 .ok_or_else(|| McpTransportError::TransportError("Failed to get stdin".to_string()))?;
92
93 let request_json = serde_json::to_string(&request)?;
95
96 writeln!(stdin, "{}", request_json).map_err(|e| McpTransportError::IoError(e))?;
97
98 stdin.flush().map_err(|e| McpTransportError::IoError(e))?;
99
100 let stdout = process
102 .stdout
103 .as_mut()
104 .ok_or_else(|| McpTransportError::TransportError("Failed to get stdout".to_string()))?;
105
106 let mut reader = BufReader::new(stdout);
107 let mut response_line = String::new();
108
109 reader
110 .read_line(&mut response_line)
111 .map_err(|e| McpTransportError::IoError(e))?;
112
113 if response_line.is_empty() {
114 self.alive.store(false, Ordering::SeqCst);
115 return Err(McpTransportError::ConnectionClosed);
116 }
117
118 let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
120
121 match response.payload {
123 JsonRpcPayload::Success { result } => Ok(result),
124 JsonRpcPayload::Error { error } => Err(McpTransportError::ServerError(format!(
125 "MCP Error: {}",
126 error
127 ))),
128 }
129 }
130
131 pub fn is_alive(&self) -> bool {
133 if !self.alive.load(Ordering::SeqCst) {
134 return false;
135 }
136
137 if let Ok(mut process) = self.process.lock() {
138 let alive = process.try_wait().ok().flatten().is_none();
139 self.alive.store(alive, Ordering::SeqCst);
140 alive
141 } else {
142 false
143 }
144 }
145
146 pub fn stop(&self) -> Result<(), McpTransportError> {
148 self.alive.store(false, Ordering::SeqCst);
149
150 let mut process = self
151 .process
152 .lock()
153 .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
154
155 process.kill().map_err(|e| McpTransportError::IoError(e))?;
156
157 process.wait().map_err(|e| McpTransportError::IoError(e))?;
158
159 Ok(())
160 }
161}
162
163impl Drop for StdioTransport {
164 fn drop(&mut self) {
165 let _ = self.stop();
166 }
167}
168
169pub struct AsyncStdioTransport {
171 inner: StdioTransport,
172}
173
174impl AsyncStdioTransport {
175 pub fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
177 Ok(Self {
178 inner: StdioTransport::spawn(command, args)?,
179 })
180 }
181
182 pub fn spawn_with_env(
184 command: &str,
185 args: &[String],
186 env: std::collections::HashMap<String, String>,
187 ) -> Result<Self, McpTransportError> {
188 Ok(Self {
189 inner: StdioTransport::spawn_with_env(command, args, env)?,
190 })
191 }
192
193 pub async fn send_request_with_timeout(
195 &self,
196 method: &str,
197 params: Option<Value>,
198 timeout_duration: Duration,
199 ) -> Result<Value, McpTransportError> {
200 let method = method.to_string();
201 let process = Arc::clone(&self.inner.process);
202 let next_id = Arc::clone(&self.inner.next_id);
203 let alive = Arc::clone(&self.inner.alive);
204
205 let (tx, rx) = oneshot::channel();
206
207 tokio::task::spawn_blocking(move || {
209 let id = next_id.fetch_add(1, Ordering::SeqCst);
210 let request = JsonRpcRequest::new(JsonRpcId::Number(id), method, params);
211
212 let result: Result<Value, McpTransportError> = (|| {
213 let mut process = process
214 .lock()
215 .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
216
217 let stdin = process.stdin.as_mut().ok_or_else(|| {
218 McpTransportError::TransportError("Failed to get stdin".to_string())
219 })?;
220
221 let request_json = serde_json::to_string(&request)?;
222
223 writeln!(stdin, "{}", request_json).map_err(|e| McpTransportError::IoError(e))?;
224
225 stdin.flush().map_err(|e| McpTransportError::IoError(e))?;
226
227 let stdout = process.stdout.as_mut().ok_or_else(|| {
228 McpTransportError::TransportError("Failed to get stdout".to_string())
229 })?;
230
231 let mut reader = BufReader::new(stdout);
232 let mut response_line = String::new();
233
234 reader
235 .read_line(&mut response_line)
236 .map_err(|e| McpTransportError::IoError(e))?;
237
238 if response_line.is_empty() {
239 alive.store(false, Ordering::SeqCst);
240 return Err(McpTransportError::ConnectionClosed);
241 }
242
243 let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
244
245 match response.payload {
246 JsonRpcPayload::Success { result } => Ok(result),
247 JsonRpcPayload::Error { error } => Err(McpTransportError::ServerError(
248 format!("MCP Error: {}", error),
249 )),
250 }
251 })();
252
253 let _ = tx.send(result);
254 });
255
256 match timeout(timeout_duration, rx).await {
258 Ok(Ok(result)) => result,
259 Ok(Err(_)) => Err(McpTransportError::TransportError(
260 "Channel closed".to_string(),
261 )),
262 Err(_) => Err(McpTransportError::Timeout(format!(
263 "Request timed out after {:?}",
264 timeout_duration
265 ))),
266 }
267 }
268
269 pub fn is_alive(&self) -> bool {
271 self.inner.is_alive()
272 }
273
274 pub fn stop(&self) -> Result<(), McpTransportError> {
276 self.inner.stop()
277 }
278}
279
280pub struct StdioTransportAdapter {
282 inner: AsyncStdioTransport,
283 timeout: Duration,
284}
285
286impl StdioTransportAdapter {
287 pub async fn connect(
289 command: &str,
290 args: &[String],
291 config: Option<Value>,
292 timeout: Duration,
293 ) -> Result<Self, McpTransportError> {
294 Self::connect_with_env(
295 command,
296 args,
297 std::collections::HashMap::new(),
298 config,
299 timeout,
300 )
301 .await
302 }
303
304 pub async fn connect_with_env(
306 command: &str,
307 args: &[String],
308 env: std::collections::HashMap<String, String>,
309 config: Option<Value>,
310 timeout: Duration,
311 ) -> Result<Self, McpTransportError> {
312 let inner = AsyncStdioTransport::spawn_with_env(command, args, env)?;
313
314 let adapter = Self { inner, timeout };
315
316 let init_params = InitializeParams::new(config);
318 let _init_result = adapter
319 .inner
320 .send_request_with_timeout(
321 "initialize",
322 Some(serde_json::to_value(&init_params)?),
323 adapter.timeout,
324 )
325 .await?;
326
327 let _ = adapter
330 .inner
331 .send_request_with_timeout(
332 "notifications/initialized",
333 Some(serde_json::json!({})),
334 adapter.timeout,
335 )
336 .await;
337
338 Ok(adapter)
339 }
340}
341
342#[async_trait]
343impl McpTransport for StdioTransportAdapter {
344 async fn list_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError> {
345 let result = self
346 .inner
347 .send_request_with_timeout("tools/list", Some(serde_json::json!({})), self.timeout)
348 .await?;
349
350 let list_result: ListToolsResult = serde_json::from_value(result)?;
351
352 Ok(list_result
353 .tools
354 .into_iter()
355 .map(ToolDefinition::from)
356 .collect())
357 }
358
359 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
360 let params = CallToolParams {
361 name: name.to_string(),
362 arguments: Some(args),
363 };
364
365 let result = self
366 .inner
367 .send_request_with_timeout(
368 "tools/call",
369 Some(serde_json::to_value(¶ms)?),
370 self.timeout,
371 )
372 .await?;
373
374 let call_result: CallToolResult = serde_json::from_value(result)?;
375
376 if call_result.is_error == Some(true) {
377 let error_text = call_result
378 .content
379 .first()
380 .and_then(|c| c.as_text())
381 .unwrap_or("Unknown error");
382 return Err(McpTransportError::ServerError(error_text.to_string()));
383 }
384
385 let text = call_result
386 .content
387 .iter()
388 .filter_map(|c| c.as_text())
389 .collect::<Vec<_>>()
390 .join("\n");
391
392 Ok(Value::String(text))
393 }
394
395 async fn shutdown(&self) -> Result<(), McpTransportError> {
396 self.inner.stop()
397 }
398
399 fn is_alive(&self) -> bool {
400 self.inner.is_alive()
401 }
402
403 fn transport_type(&self) -> TransportTypeId {
404 TransportTypeId::Stdio
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// The stdio transport identifier must render as the lowercase name `stdio`.
    #[test]
    fn test_transport_type() {
        let expected = "stdio";
        let rendered = TransportTypeId::Stdio.to_string();
        assert_eq!(rendered, expected);
    }
}