mcp_protocol_sdk/transport/
stdio.rs1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{Mutex, mpsc};
14use tokio::time::{Duration, timeout};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes};
18use crate::transport::traits::{
19 ConnectionState, ServerRequestHandler, ServerTransport, Transport, TransportConfig,
20};
21
22pub struct StdioClientTransport {
27 child: Option<Child>,
28 stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
29 #[allow(dead_code)]
30 stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
31 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
32 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
33 config: TransportConfig,
34 state: ConnectionState,
35}
36
37impl StdioClientTransport {
38 pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
47 Self::with_config(command, args, TransportConfig::default()).await
48 }
49
50 pub async fn with_config<S: AsRef<str>>(
60 command: S,
61 args: Vec<S>,
62 config: TransportConfig,
63 ) -> McpResult<Self> {
64 let command_str = command.as_ref();
65 let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
66
67 tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
68
69 let mut child = Command::new(command_str)
70 .args(&args_str)
71 .stdin(Stdio::piped())
72 .stdout(Stdio::piped())
73 .stderr(Stdio::piped())
74 .spawn()
75 .map_err(|e| McpError::transport(format!("Failed to start server process: {e}")))?;
76
77 let stdin = child
78 .stdin
79 .take()
80 .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
81 let stdout = child
82 .stdout
83 .take()
84 .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
85
86 let stdin_writer = BufWriter::new(stdin);
87 let stdout_reader = BufReader::new(stdout);
88
89 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
90 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
91
92 let reader_pending_requests = pending_requests.clone();
94 let reader = stdout_reader;
95 tokio::spawn(async move {
96 Self::message_processor(reader, notification_sender, reader_pending_requests).await;
97 });
98
99 Ok(Self {
100 child: Some(child),
101 stdin_writer: Some(stdin_writer),
102 stdout_reader: None, notification_receiver: Some(notification_receiver),
104 pending_requests,
105 config,
106 state: ConnectionState::Connected,
107 })
108 }
109
110 async fn message_processor(
111 mut reader: BufReader<tokio::process::ChildStdout>,
112 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
113 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
114 ) {
115 let mut line = String::new();
116
117 loop {
118 line.clear();
119 match reader.read_line(&mut line).await {
120 Ok(0) => {
121 tracing::debug!("STDIO reader reached EOF");
122 break;
123 }
124 Ok(_) => {
125 let line = line.trim();
126 if line.is_empty() {
127 continue;
128 }
129
130 tracing::trace!("Received: {}", line);
131
132 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
134 let mut pending = pending_requests.lock().await;
135 match pending.remove(&response.id) {
136 Some(sender) => {
137 let _ = sender.send(response);
138 }
139 _ => {
140 tracing::warn!(
141 "Received response for unknown request ID: {:?}",
142 response.id
143 );
144 }
145 }
146 }
147 else if let Ok(notification) =
149 serde_json::from_str::<JsonRpcNotification>(line)
150 {
151 if notification_sender.send(notification).is_err() {
152 tracing::debug!("Notification receiver dropped");
153 break;
154 }
155 } else {
156 tracing::warn!("Failed to parse message: {}", line);
157 }
158 }
159 Err(e) => {
160 tracing::error!("Error reading from stdout: {}", e);
161 break;
162 }
163 }
164 }
165 }
166}
167
168#[async_trait]
169impl Transport for StdioClientTransport {
170 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
171 let writer = self
172 .stdin_writer
173 .as_mut()
174 .ok_or_else(|| McpError::transport("Transport not connected"))?;
175
176 let (sender, receiver) = tokio::sync::oneshot::channel();
177
178 {
180 let mut pending = self.pending_requests.lock().await;
181 pending.insert(request.id.clone(), sender);
182 }
183
184 let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
186
187 tracing::trace!("Sending: {}", request_line);
188
189 writer
190 .write_all(request_line.as_bytes())
191 .await
192 .map_err(|e| McpError::transport(format!("Failed to write request: {e}")))?;
193 writer
194 .write_all(b"\n")
195 .await
196 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
197 writer
198 .flush()
199 .await
200 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
201
202 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
204
205 let response = timeout(timeout_duration, receiver)
206 .await
207 .map_err(|_| McpError::timeout("Request timeout"))?
208 .map_err(|_| McpError::transport("Response channel closed"))?;
209
210 Ok(response)
211 }
212
213 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
214 let writer = self
215 .stdin_writer
216 .as_mut()
217 .ok_or_else(|| McpError::transport("Transport not connected"))?;
218
219 let notification_line =
220 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
221
222 tracing::trace!("Sending notification: {}", notification_line);
223
224 writer
225 .write_all(notification_line.as_bytes())
226 .await
227 .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
228 writer
229 .write_all(b"\n")
230 .await
231 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
232 writer
233 .flush()
234 .await
235 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
236
237 Ok(())
238 }
239
240 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
241 if let Some(ref mut receiver) = self.notification_receiver {
242 match receiver.try_recv() {
243 Ok(notification) => Ok(Some(notification)),
244 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
245 Err(mpsc::error::TryRecvError::Disconnected) => {
246 Err(McpError::transport("Notification channel disconnected"))
247 }
248 }
249 } else {
250 Ok(None)
251 }
252 }
253
254 async fn close(&mut self) -> McpResult<()> {
255 tracing::debug!("Closing STDIO transport");
256
257 self.state = ConnectionState::Closing;
258
259 if let Some(mut writer) = self.stdin_writer.take() {
261 let _ = writer.shutdown().await;
262 }
263
264 if let Some(mut child) = self.child.take() {
266 match timeout(Duration::from_secs(5), child.wait()).await {
267 Ok(Ok(status)) => {
268 tracing::debug!("Server process exited with status: {}", status);
269 }
270 Ok(Err(e)) => {
271 tracing::warn!("Error waiting for server process: {}", e);
272 }
273 Err(_) => {
274 tracing::warn!("Timeout waiting for server process, killing it");
275 let _ = child.kill().await;
276 }
277 }
278 }
279
280 self.state = ConnectionState::Disconnected;
281 Ok(())
282 }
283
284 fn is_connected(&self) -> bool {
285 matches!(self.state, ConnectionState::Connected)
286 }
287
288 fn connection_info(&self) -> String {
289 let state = &self.state;
290 format!("STDIO transport (state: {state:?})")
291 }
292}
293
294pub struct StdioServerTransport {
299 stdin_reader: Option<BufReader<tokio::io::Stdin>>,
300 stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
301 #[allow(dead_code)]
302 config: TransportConfig,
303 running: bool,
304 request_handler: Option<ServerRequestHandler>,
305}
306
307impl StdioServerTransport {
308 pub fn new() -> Self {
313 Self::with_config(TransportConfig::default())
314 }
315
316 pub fn with_config(config: TransportConfig) -> Self {
324 let stdin_reader = BufReader::new(tokio::io::stdin());
325 let stdout_writer = BufWriter::new(tokio::io::stdout());
326
327 Self {
328 stdin_reader: Some(stdin_reader),
329 stdout_writer: Some(stdout_writer),
330 config,
331 running: false,
332 request_handler: None,
333 }
334 }
335}
336
337#[async_trait]
338impl ServerTransport for StdioServerTransport {
339 async fn start(&mut self) -> McpResult<()> {
340 tracing::debug!("Starting STDIO server transport");
341
342 let mut reader = self
343 .stdin_reader
344 .take()
345 .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
346 let mut writer = self
347 .stdout_writer
348 .take()
349 .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
350
351 self.running = true;
352 let request_handler = self.request_handler.clone();
353
354 let mut line = String::new();
355 while self.running {
356 line.clear();
357
358 match reader.read_line(&mut line).await {
359 Ok(0) => {
360 tracing::debug!("STDIN closed, stopping server");
361 break;
362 }
363 Ok(_) => {
364 let line = line.trim();
365 if line.is_empty() {
366 continue;
367 }
368
369 tracing::trace!("Received: {}", line);
370
371 match serde_json::from_str::<JsonRpcRequest>(line) {
373 Ok(request) => {
374 let response_result = if let Some(ref handler) = request_handler {
375 handler(request.clone()).await
377 } else {
378 Err(McpError::protocol(format!(
380 "Method '{}' not found",
381 request.method
382 )))
383 };
384
385 let response_or_error = match response_result {
386 Ok(response) => serde_json::to_string(&response),
387 Err(error) => {
388 let json_rpc_error = crate::protocol::types::JsonRpcError {
390 jsonrpc: "2.0".to_string(),
391 id: request.id,
392 error: crate::protocol::types::ErrorObject {
393 code: match error {
394 McpError::Protocol(ref msg) if msg.contains("not found") => {
395 error_codes::METHOD_NOT_FOUND
396 }
397 _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
398 },
399 message: error.to_string(),
400 data: None,
401 },
402 };
403 serde_json::to_string(&json_rpc_error)
404 }
405 };
406
407 let response_line =
408 response_or_error.map_err(McpError::serialization)?;
409
410 tracing::trace!("Sending: {}", response_line);
411
412 writer
413 .write_all(response_line.as_bytes())
414 .await
415 .map_err(|e| {
416 McpError::transport(format!("Failed to write response: {e}"))
417 })?;
418 writer.write_all(b"\n").await.map_err(|e| {
419 McpError::transport(format!("Failed to write newline: {e}"))
420 })?;
421 writer.flush().await.map_err(|e| {
422 McpError::transport(format!("Failed to flush: {e}"))
423 })?;
424 }
425 Err(e) => {
426 tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
427 }
430 }
431 }
432 Err(e) => {
433 tracing::error!("Error reading from stdin: {}", e);
434 return Err(McpError::io(e));
435 }
436 }
437 }
438
439 Ok(())
440 }
441
442 fn set_request_handler(&mut self, handler: ServerRequestHandler) {
443 self.request_handler = Some(handler);
444 }
445
446 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
447 let writer = self
448 .stdout_writer
449 .as_mut()
450 .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
451
452 let notification_line =
453 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
454
455 tracing::trace!("Sending notification: {}", notification_line);
456
457 writer
458 .write_all(notification_line.as_bytes())
459 .await
460 .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
461 writer
462 .write_all(b"\n")
463 .await
464 .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
465 writer
466 .flush()
467 .await
468 .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
469
470 Ok(())
471 }
472
473 async fn stop(&mut self) -> McpResult<()> {
474 tracing::debug!("Stopping STDIO server transport");
475 self.running = false;
476 Ok(())
477 }
478
479 fn is_running(&self) -> bool {
480 self.running
481 }
482
483 fn server_info(&self) -> String {
484 format!("STDIO server transport (running: {})", self.running)
485 }
486}
487
488impl StdioServerTransport {
490 pub async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
493 Err(McpError::protocol(format!(
495 "Method '{}' not found (test mode)",
496 request.method
497 )))
498 }
499}
500
501impl Default for StdioServerTransport {
502 fn default() -> Self {
503 Self::new()
504 }
505}
506
507impl Drop for StdioClientTransport {
508 fn drop(&mut self) {
509 if let Some(mut child) = self.child.take() {
510 let _ = child.start_kill();
512 }
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_stdio_server_creation() {
        let transport = StdioServerTransport::new();

        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
    }

    #[test]
    fn test_stdio_server_with_config() {
        let transport = StdioServerTransport::with_config(TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        });

        assert_eq!(transport.config.read_timeout_ms, Some(30_000));
    }

    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut transport = StdioServerTransport::new();
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let result = transport.handle_request(request).await;
        assert!(result.is_err());

        if let Err(McpError::Protocol(msg)) = result {
            assert!(msg.contains("unknown_method"));
        } else {
            panic!("Expected Protocol error");
        }
    }
}