mcp_protocol_sdk/transport/
stdio.rs1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
/// Client-side MCP transport that talks to a server child process over the
/// child's stdin/stdout, one JSON-RPC message per line.
///
/// Built by [`StdioClientTransport::new`] / [`StdioClientTransport::with_config`],
/// which spawn the server process and a background task that reads its stdout.
pub struct StdioClientTransport {
    /// Handle to the spawned server process. Taken by `close`; if still present
    /// when the transport is dropped, `Drop` asks it to terminate.
    child: Option<Child>,
    /// Buffered writer into the child's stdin. Outgoing requests and
    /// notifications are written here followed by `\n` and flushed.
    /// Taken and shut down by `close`.
    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
    /// Not populated after construction: `with_config` moves the stdout reader
    /// into the background reader task and stores `None` here.
    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
    /// Receives notifications that the background reader task parsed from the
    /// child's stdout.
    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
    /// Requests awaiting a response, keyed by request id. `send_request` inserts
    /// an entry; the reader task removes it and delivers the response.
    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
    /// Transport settings; `read_timeout_ms` bounds how long `send_request`
    /// waits for a response.
    config: TransportConfig,
    /// Connection state reported by `is_connected` and `connection_info`.
    state: ConnectionState,
}
33
34impl StdioClientTransport {
35 pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
44 Self::with_config(command, args, TransportConfig::default()).await
45 }
46
47 pub async fn with_config<S: AsRef<str>>(
57 command: S,
58 args: Vec<S>,
59 config: TransportConfig,
60 ) -> McpResult<Self> {
61 let command_str = command.as_ref();
62 let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
63
64 tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
65
66 let mut child = Command::new(command_str)
67 .args(&args_str)
68 .stdin(Stdio::piped())
69 .stdout(Stdio::piped())
70 .stderr(Stdio::piped())
71 .spawn()
72 .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
73
74 let stdin = child
75 .stdin
76 .take()
77 .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
78 let stdout = child
79 .stdout
80 .take()
81 .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
82
83 let stdin_writer = BufWriter::new(stdin);
84 let stdout_reader = BufReader::new(stdout);
85
86 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
87 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
88
89 let reader_pending_requests = pending_requests.clone();
91 let mut reader = stdout_reader;
92 tokio::spawn(async move {
93 Self::message_processor(reader, notification_sender, reader_pending_requests).await;
94 });
95
96 Ok(Self {
97 child: Some(child),
98 stdin_writer: Some(stdin_writer),
99 stdout_reader: None, notification_receiver: Some(notification_receiver),
101 pending_requests,
102 config,
103 state: ConnectionState::Connected,
104 })
105 }
106
107 async fn message_processor(
108 mut reader: BufReader<tokio::process::ChildStdout>,
109 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
110 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
111 ) {
112 let mut line = String::new();
113
114 loop {
115 line.clear();
116 match reader.read_line(&mut line).await {
117 Ok(0) => {
118 tracing::debug!("STDIO reader reached EOF");
119 break;
120 }
121 Ok(_) => {
122 let line = line.trim();
123 if line.is_empty() {
124 continue;
125 }
126
127 tracing::trace!("Received: {}", line);
128
129 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
131 let mut pending = pending_requests.lock().await;
132 if let Some(sender) = pending.remove(&response.id) {
133 let _ = sender.send(response);
134 } else {
135 tracing::warn!(
136 "Received response for unknown request ID: {:?}",
137 response.id
138 );
139 }
140 }
141 else if let Ok(notification) =
143 serde_json::from_str::<JsonRpcNotification>(line)
144 {
145 if notification_sender.send(notification).is_err() {
146 tracing::debug!("Notification receiver dropped");
147 break;
148 }
149 } else {
150 tracing::warn!("Failed to parse message: {}", line);
151 }
152 }
153 Err(e) => {
154 tracing::error!("Error reading from stdout: {}", e);
155 break;
156 }
157 }
158 }
159 }
160}
161
162#[async_trait]
163impl Transport for StdioClientTransport {
164 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
165 let writer = self
166 .stdin_writer
167 .as_mut()
168 .ok_or_else(|| McpError::transport("Transport not connected"))?;
169
170 let (sender, receiver) = tokio::sync::oneshot::channel();
171
172 {
174 let mut pending = self.pending_requests.lock().await;
175 pending.insert(request.id.clone(), sender);
176 }
177
178 let request_line =
180 serde_json::to_string(&request).map_err(|e| McpError::serialization(e))?;
181
182 tracing::trace!("Sending: {}", request_line);
183
184 writer
185 .write_all(request_line.as_bytes())
186 .await
187 .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188 writer
189 .write_all(b"\n")
190 .await
191 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192 writer
193 .flush()
194 .await
195 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200 let response = timeout(timeout_duration, receiver)
201 .await
202 .map_err(|_| McpError::timeout("Request timeout"))?
203 .map_err(|_| McpError::transport("Response channel closed"))?;
204
205 Ok(response)
206 }
207
208 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209 let writer = self
210 .stdin_writer
211 .as_mut()
212 .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214 let notification_line =
215 serde_json::to_string(¬ification).map_err(|e| McpError::serialization(e))?;
216
217 tracing::trace!("Sending notification: {}", notification_line);
218
219 writer
220 .write_all(notification_line.as_bytes())
221 .await
222 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223 writer
224 .write_all(b"\n")
225 .await
226 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227 writer
228 .flush()
229 .await
230 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232 Ok(())
233 }
234
235 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236 if let Some(ref mut receiver) = self.notification_receiver {
237 match receiver.try_recv() {
238 Ok(notification) => Ok(Some(notification)),
239 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240 Err(mpsc::error::TryRecvError::Disconnected) => {
241 Err(McpError::transport("Notification channel disconnected"))
242 }
243 }
244 } else {
245 Ok(None)
246 }
247 }
248
249 async fn close(&mut self) -> McpResult<()> {
250 tracing::debug!("Closing STDIO transport");
251
252 self.state = ConnectionState::Closing;
253
254 if let Some(mut writer) = self.stdin_writer.take() {
256 let _ = writer.shutdown().await;
257 }
258
259 if let Some(mut child) = self.child.take() {
261 match timeout(Duration::from_secs(5), child.wait()).await {
262 Ok(Ok(status)) => {
263 tracing::debug!("Server process exited with status: {}", status);
264 }
265 Ok(Err(e)) => {
266 tracing::warn!("Error waiting for server process: {}", e);
267 }
268 Err(_) => {
269 tracing::warn!("Timeout waiting for server process, killing it");
270 let _ = child.kill().await;
271 }
272 }
273 }
274
275 self.state = ConnectionState::Disconnected;
276 Ok(())
277 }
278
279 fn is_connected(&self) -> bool {
280 matches!(self.state, ConnectionState::Connected)
281 }
282
283 fn connection_info(&self) -> String {
284 format!("STDIO transport (state: {:?})", self.state)
285 }
286}
287
/// Server-side MCP transport that reads newline-delimited JSON-RPC requests
/// from this process's stdin and writes responses to its stdout.
pub struct StdioServerTransport {
    /// Buffered reader over the process's stdin; moved out by `start`.
    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
    /// Buffered writer over the process's stdout; moved out by `start`, and
    /// used by `send_notification` whenever it is present.
    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
    /// Transport settings supplied at construction.
    config: TransportConfig,
    /// Set to `true` when `start` begins serving and to `false` by `stop`.
    /// The request loop checks it before each read.
    running: bool,
    /// Callback registered via `set_request_handler`. It is given a request and
    /// returns a oneshot receiver that yields a `JsonRpcResponse`.
    request_handler: Option<
        Box<
            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
        >,
    >,
}
303
304impl StdioServerTransport {
305 pub fn new() -> Self {
310 Self::with_config(TransportConfig::default())
311 }
312
313 pub fn with_config(config: TransportConfig) -> Self {
321 let stdin_reader = BufReader::new(tokio::io::stdin());
322 let stdout_writer = BufWriter::new(tokio::io::stdout());
323
324 Self {
325 stdin_reader: Some(stdin_reader),
326 stdout_writer: Some(stdout_writer),
327 config,
328 running: false,
329 request_handler: None,
330 }
331 }
332
333 pub fn set_request_handler<F>(&mut self, handler: F)
338 where
339 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
340 + Send
341 + Sync
342 + 'static,
343 {
344 self.request_handler = Some(Box::new(handler));
345 }
346}
347
348#[async_trait]
349impl ServerTransport for StdioServerTransport {
350 async fn start(&mut self) -> McpResult<()> {
351 tracing::debug!("Starting STDIO server transport");
352
353 let mut reader = self
354 .stdin_reader
355 .take()
356 .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
357 let mut writer = self
358 .stdout_writer
359 .take()
360 .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
361
362 self.running = true;
363
364 let mut line = String::new();
365 while self.running {
366 line.clear();
367
368 match reader.read_line(&mut line).await {
369 Ok(0) => {
370 tracing::debug!("STDIN closed, stopping server");
371 break;
372 }
373 Ok(_) => {
374 let line = line.trim();
375 if line.is_empty() {
376 continue;
377 }
378
379 tracing::trace!("Received: {}", line);
380
381 match serde_json::from_str::<JsonRpcRequest>(line) {
383 Ok(request) => {
384 let response = self.handle_request(request).await?;
385
386 let response_line = serde_json::to_string(&response)
387 .map_err(|e| McpError::serialization(e))?;
388
389 tracing::trace!("Sending: {}", response_line);
390
391 writer
392 .write_all(response_line.as_bytes())
393 .await
394 .map_err(|e| {
395 McpError::transport(format!("Failed to write response: {}", e))
396 })?;
397 writer.write_all(b"\n").await.map_err(|e| {
398 McpError::transport(format!("Failed to write newline: {}", e))
399 })?;
400 writer.flush().await.map_err(|e| {
401 McpError::transport(format!("Failed to flush: {}", e))
402 })?;
403 }
404 Err(e) => {
405 tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
406 }
409 }
410 }
411 Err(e) => {
412 tracing::error!("Error reading from stdin: {}", e);
413 return Err(McpError::io(e));
414 }
415 }
416 }
417
418 Ok(())
419 }
420
421 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
422 Ok(JsonRpcResponse {
424 jsonrpc: "2.0".to_string(),
425 id: request.id,
426 result: None,
427 error: Some(crate::protocol::types::JsonRpcError {
428 code: crate::protocol::types::METHOD_NOT_FOUND,
429 message: format!("Method '{}' not found", request.method),
430 data: None,
431 }),
432 })
433 }
434
435 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
436 let writer = self
437 .stdout_writer
438 .as_mut()
439 .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
440
441 let notification_line =
442 serde_json::to_string(¬ification).map_err(|e| McpError::serialization(e))?;
443
444 tracing::trace!("Sending notification: {}", notification_line);
445
446 writer
447 .write_all(notification_line.as_bytes())
448 .await
449 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
450 writer
451 .write_all(b"\n")
452 .await
453 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
454 writer
455 .flush()
456 .await
457 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
458
459 Ok(())
460 }
461
462 async fn stop(&mut self) -> McpResult<()> {
463 tracing::debug!("Stopping STDIO server transport");
464 self.running = false;
465 Ok(())
466 }
467
468 fn is_running(&self) -> bool {
469 self.running
470 }
471
472 fn server_info(&self) -> String {
473 format!("STDIO server transport (running: {})", self.running)
474 }
475}
476
477impl Default for StdioServerTransport {
478 fn default() -> Self {
479 Self::new()
480 }
481}
482
483impl Drop for StdioClientTransport {
484 fn drop(&mut self) {
485 if let Some(mut child) = self.child.take() {
486 let _ = child.start_kill();
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly built server owns both stdio handles and is not running.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();
        let running = server.is_running();
        assert!(!running);
        assert!(matches!(server.stdin_reader, Some(_)));
        assert!(matches!(server.stdout_writer, Some(_)));
    }

    /// `with_config` keeps the configuration it was given.
    #[test]
    fn test_stdio_server_with_config() {
        let config = {
            let mut cfg = TransportConfig::default();
            cfg.read_timeout_ms = Some(30_000);
            cfg
        };

        let server = StdioServerTransport::with_config(config);
        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    /// With no handler registered, an unknown method yields METHOD_NOT_FOUND.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();

        let JsonRpcResponse {
            jsonrpc,
            id,
            result,
            error,
        } = server
            .handle_request(JsonRpcRequest {
                jsonrpc: "2.0".to_string(),
                id: json!(1),
                method: "unknown_method".to_string(),
                params: None,
            })
            .await
            .unwrap();

        assert_eq!(jsonrpc, "2.0");
        assert_eq!(id, json!(1));
        assert!(error.is_some());
        assert!(result.is_none());

        let rpc_error = error.unwrap();
        assert_eq!(rpc_error.code, crate::protocol::types::METHOD_NOT_FOUND);
    }
}