mcp_protocol_sdk/transport/
stdio.rs1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{Mutex, mpsc};
14use tokio::time::{Duration, timeout};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
/// Client-side STDIO transport: spawns an MCP server as a child process and
/// exchanges newline-delimited JSON-RPC messages over the child's stdin/stdout.
pub struct StdioClientTransport {
    /// Handle to the spawned server process; taken (set to `None`) by `close` and `drop`.
    child: Option<Child>,
    /// Buffered writer over the child's stdin; outgoing requests and notifications go here.
    /// Taken by `close`, after which sends report "Transport not connected".
    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
    // `with_config` moves the stdout reader into the background reader task and
    // stores `None` here, so this field is not used by the visible code.
    #[allow(dead_code)]
    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
    /// Receiving end of the channel the background reader task forwards notifications into.
    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
    /// In-flight requests keyed by JSON-RPC id; each entry is the oneshot sender the
    /// background reader task uses to deliver the matching response.
    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
    /// Transport settings; `read_timeout_ms` bounds how long `send_request` waits.
    config: TransportConfig,
    /// Current lifecycle state reported by `is_connected` / `connection_info`.
    state: ConnectionState,
}
34
35impl StdioClientTransport {
36 pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45 Self::with_config(command, args, TransportConfig::default()).await
46 }
47
48 pub async fn with_config<S: AsRef<str>>(
58 command: S,
59 args: Vec<S>,
60 config: TransportConfig,
61 ) -> McpResult<Self> {
62 let command_str = command.as_ref();
63 let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65 tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67 let mut child = Command::new(command_str)
68 .args(&args_str)
69 .stdin(Stdio::piped())
70 .stdout(Stdio::piped())
71 .stderr(Stdio::piped())
72 .spawn()
73 .map_err(|e| McpError::transport(format!("Failed to start server process: {e}")))?;
74
75 let stdin = child
76 .stdin
77 .take()
78 .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79 let stdout = child
80 .stdout
81 .take()
82 .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84 let stdin_writer = BufWriter::new(stdin);
85 let stdout_reader = BufReader::new(stdout);
86
87 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90 let reader_pending_requests = pending_requests.clone();
92 let reader = stdout_reader;
93 tokio::spawn(async move {
94 Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95 });
96
97 Ok(Self {
98 child: Some(child),
99 stdin_writer: Some(stdin_writer),
100 stdout_reader: None, notification_receiver: Some(notification_receiver),
102 pending_requests,
103 config,
104 state: ConnectionState::Connected,
105 })
106 }
107
108 async fn message_processor(
109 mut reader: BufReader<tokio::process::ChildStdout>,
110 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112 ) {
113 let mut line = String::new();
114
115 loop {
116 line.clear();
117 match reader.read_line(&mut line).await {
118 Ok(0) => {
119 tracing::debug!("STDIO reader reached EOF");
120 break;
121 }
122 Ok(_) => {
123 let line = line.trim();
124 if line.is_empty() {
125 continue;
126 }
127
128 tracing::trace!("Received: {}", line);
129
130 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132 let mut pending = pending_requests.lock().await;
133 match pending.remove(&response.id) {
134 Some(sender) => {
135 let _ = sender.send(response);
136 }
137 _ => {
138 tracing::warn!(
139 "Received response for unknown request ID: {:?}",
140 response.id
141 );
142 }
143 }
144 }
145 else if let Ok(notification) =
147 serde_json::from_str::<JsonRpcNotification>(line)
148 {
149 if notification_sender.send(notification).is_err() {
150 tracing::debug!("Notification receiver dropped");
151 break;
152 }
153 } else {
154 tracing::warn!("Failed to parse message: {}", line);
155 }
156 }
157 Err(e) => {
158 tracing::error!("Error reading from stdout: {}", e);
159 break;
160 }
161 }
162 }
163 }
164}
165
166#[async_trait]
167impl Transport for StdioClientTransport {
168 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
169 let writer = self
170 .stdin_writer
171 .as_mut()
172 .ok_or_else(|| McpError::transport("Transport not connected"))?;
173
174 let (sender, receiver) = tokio::sync::oneshot::channel();
175
176 {
178 let mut pending = self.pending_requests.lock().await;
179 pending.insert(request.id.clone(), sender);
180 }
181
182 let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
184
185 tracing::trace!("Sending: {}", request_line);
186
187 writer
188 .write_all(request_line.as_bytes())
189 .await
190 .map_err(|e| McpError::transport(format!("Failed to write request: {e}")))?;
191 writer
192 .write_all(b"\n")
193 .await
194 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
195 writer
196 .flush()
197 .await
198 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
199
200 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
202
203 let response = timeout(timeout_duration, receiver)
204 .await
205 .map_err(|_| McpError::timeout("Request timeout"))?
206 .map_err(|_| McpError::transport("Response channel closed"))?;
207
208 Ok(response)
209 }
210
211 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
212 let writer = self
213 .stdin_writer
214 .as_mut()
215 .ok_or_else(|| McpError::transport("Transport not connected"))?;
216
217 let notification_line =
218 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
219
220 tracing::trace!("Sending notification: {}", notification_line);
221
222 writer
223 .write_all(notification_line.as_bytes())
224 .await
225 .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
226 writer
227 .write_all(b"\n")
228 .await
229 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
230 writer
231 .flush()
232 .await
233 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
234
235 Ok(())
236 }
237
238 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
239 if let Some(ref mut receiver) = self.notification_receiver {
240 match receiver.try_recv() {
241 Ok(notification) => Ok(Some(notification)),
242 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
243 Err(mpsc::error::TryRecvError::Disconnected) => {
244 Err(McpError::transport("Notification channel disconnected"))
245 }
246 }
247 } else {
248 Ok(None)
249 }
250 }
251
252 async fn close(&mut self) -> McpResult<()> {
253 tracing::debug!("Closing STDIO transport");
254
255 self.state = ConnectionState::Closing;
256
257 if let Some(mut writer) = self.stdin_writer.take() {
259 let _ = writer.shutdown().await;
260 }
261
262 if let Some(mut child) = self.child.take() {
264 match timeout(Duration::from_secs(5), child.wait()).await {
265 Ok(Ok(status)) => {
266 tracing::debug!("Server process exited with status: {}", status);
267 }
268 Ok(Err(e)) => {
269 tracing::warn!("Error waiting for server process: {}", e);
270 }
271 Err(_) => {
272 tracing::warn!("Timeout waiting for server process, killing it");
273 let _ = child.kill().await;
274 }
275 }
276 }
277
278 self.state = ConnectionState::Disconnected;
279 Ok(())
280 }
281
282 fn is_connected(&self) -> bool {
283 matches!(self.state, ConnectionState::Connected)
284 }
285
286 fn connection_info(&self) -> String {
287 let state = &self.state;
288 format!("STDIO transport (state: {state:?})")
289 }
290}
291
/// Server-side STDIO transport: reads newline-delimited JSON-RPC requests from
/// this process's stdin and writes responses/notifications to its stdout.
pub struct StdioServerTransport {
    /// Buffered stdin; taken by `start` for the duration of its read loop.
    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
    /// Buffered stdout; taken by `start`, after which `send_notification` reports
    /// "STDOUT writer not available".
    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
    // Only read by the unit tests in this file; kept for configuration parity
    // with the client transport.
    #[allow(dead_code)]
    config: TransportConfig,
    /// Whether the `start` loop is running / should keep running; `stop` clears it.
    running: bool,
    /// Optional callback registered via `set_request_handler`: it receives a
    /// request and returns a oneshot receiver that yields the response.
    request_handler: Option<
        Box<
            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
        >,
    >,
}
308
309impl StdioServerTransport {
310 pub fn new() -> Self {
315 Self::with_config(TransportConfig::default())
316 }
317
318 pub fn with_config(config: TransportConfig) -> Self {
326 let stdin_reader = BufReader::new(tokio::io::stdin());
327 let stdout_writer = BufWriter::new(tokio::io::stdout());
328
329 Self {
330 stdin_reader: Some(stdin_reader),
331 stdout_writer: Some(stdout_writer),
332 config,
333 running: false,
334 request_handler: None,
335 }
336 }
337
338 pub fn set_request_handler<F>(&mut self, handler: F)
343 where
344 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
345 + Send
346 + Sync
347 + 'static,
348 {
349 self.request_handler = Some(Box::new(handler));
350 }
351}
352
353#[async_trait]
354impl ServerTransport for StdioServerTransport {
355 async fn start(&mut self) -> McpResult<()> {
356 tracing::debug!("Starting STDIO server transport");
357
358 let mut reader = self
359 .stdin_reader
360 .take()
361 .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
362 let mut writer = self
363 .stdout_writer
364 .take()
365 .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
366
367 self.running = true;
368
369 let mut line = String::new();
370 while self.running {
371 line.clear();
372
373 match reader.read_line(&mut line).await {
374 Ok(0) => {
375 tracing::debug!("STDIN closed, stopping server");
376 break;
377 }
378 Ok(_) => {
379 let line = line.trim();
380 if line.is_empty() {
381 continue;
382 }
383
384 tracing::trace!("Received: {}", line);
385
386 match serde_json::from_str::<JsonRpcRequest>(line) {
388 Ok(request) => {
389 let response_or_error = match self.handle_request(request.clone()).await
390 {
391 Ok(response) => serde_json::to_string(&response),
392 Err(error) => {
393 let json_rpc_error = crate::protocol::types::JsonRpcError {
395 jsonrpc: "2.0".to_string(),
396 id: request.id,
397 error: crate::protocol::types::ErrorObject {
398 code: match error {
399 McpError::Protocol(ref msg) if msg.contains("not found") => {
400 error_codes::METHOD_NOT_FOUND
401 }
402 _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
403 },
404 message: error.to_string(),
405 data: None,
406 },
407 };
408 serde_json::to_string(&json_rpc_error)
409 }
410 };
411
412 let response_line =
413 response_or_error.map_err(McpError::serialization)?;
414
415 tracing::trace!("Sending: {}", response_line);
416
417 writer
418 .write_all(response_line.as_bytes())
419 .await
420 .map_err(|e| {
421 McpError::transport(format!("Failed to write response: {e}"))
422 })?;
423 writer.write_all(b"\n").await.map_err(|e| {
424 McpError::transport(format!("Failed to write newline: {e}"))
425 })?;
426 writer.flush().await.map_err(|e| {
427 McpError::transport(format!("Failed to flush: {e}"))
428 })?;
429 }
430 Err(e) => {
431 tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
432 }
435 }
436 }
437 Err(e) => {
438 tracing::error!("Error reading from stdin: {}", e);
439 return Err(McpError::io(e));
440 }
441 }
442 }
443
444 Ok(())
445 }
446
447 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
448 Err(McpError::protocol(format!(
450 "Method '{}' not found",
451 request.method
452 )))
453 }
454
455 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
456 let writer = self
457 .stdout_writer
458 .as_mut()
459 .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
460
461 let notification_line =
462 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
463
464 tracing::trace!("Sending notification: {}", notification_line);
465
466 writer
467 .write_all(notification_line.as_bytes())
468 .await
469 .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
470 writer
471 .write_all(b"\n")
472 .await
473 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
474 writer
475 .flush()
476 .await
477 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
478
479 Ok(())
480 }
481
482 async fn stop(&mut self) -> McpResult<()> {
483 tracing::debug!("Stopping STDIO server transport");
484 self.running = false;
485 Ok(())
486 }
487
488 fn is_running(&self) -> bool {
489 self.running
490 }
491
492 fn server_info(&self) -> String {
493 format!("STDIO server transport (running: {})", self.running)
494 }
495}
496
497impl Default for StdioServerTransport {
498 fn default() -> Self {
499 Self::new()
500 }
501}
502
503impl Drop for StdioClientTransport {
504 fn drop(&mut self) {
505 if let Some(mut child) = self.child.take() {
506 let _ = child.start_kill();
508 }
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly built server is idle and owns both stdio handles.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();
        assert!(!server.is_running(), "new server must not be running");
        assert!(server.stdin_reader.is_some(), "stdin reader must be present");
        assert!(server.stdout_writer.is_some(), "stdout writer must be present");
    }

    /// `with_config` keeps the supplied configuration.
    #[test]
    fn test_stdio_server_with_config() {
        let custom = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        };

        let server = StdioServerTransport::with_config(custom);
        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    /// Without a registered handler, any method is reported as a protocol error
    /// that names the method.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();

        let unknown = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        match server.handle_request(unknown).await {
            Err(McpError::Protocol(msg)) => assert!(msg.contains("unknown_method")),
            _ => panic!("Expected Protocol error"),
        }
    }
}