mcp_protocol_sdk/transport/
stdio.rs1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{error_codes, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20pub struct StdioClientTransport {
25 child: Option<Child>,
26 stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27 #[allow(dead_code)]
28 stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
29 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
30 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
31 config: TransportConfig,
32 state: ConnectionState,
33}
34
35impl StdioClientTransport {
36 pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45 Self::with_config(command, args, TransportConfig::default()).await
46 }
47
48 pub async fn with_config<S: AsRef<str>>(
58 command: S,
59 args: Vec<S>,
60 config: TransportConfig,
61 ) -> McpResult<Self> {
62 let command_str = command.as_ref();
63 let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65 tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67 let mut child = Command::new(command_str)
68 .args(&args_str)
69 .stdin(Stdio::piped())
70 .stdout(Stdio::piped())
71 .stderr(Stdio::piped())
72 .spawn()
73 .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
74
75 let stdin = child
76 .stdin
77 .take()
78 .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79 let stdout = child
80 .stdout
81 .take()
82 .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84 let stdin_writer = BufWriter::new(stdin);
85 let stdout_reader = BufReader::new(stdout);
86
87 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90 let reader_pending_requests = pending_requests.clone();
92 let reader = stdout_reader;
93 tokio::spawn(async move {
94 Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95 });
96
97 Ok(Self {
98 child: Some(child),
99 stdin_writer: Some(stdin_writer),
100 stdout_reader: None, notification_receiver: Some(notification_receiver),
102 pending_requests,
103 config,
104 state: ConnectionState::Connected,
105 })
106 }
107
108 async fn message_processor(
109 mut reader: BufReader<tokio::process::ChildStdout>,
110 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112 ) {
113 let mut line = String::new();
114
115 loop {
116 line.clear();
117 match reader.read_line(&mut line).await {
118 Ok(0) => {
119 tracing::debug!("STDIO reader reached EOF");
120 break;
121 }
122 Ok(_) => {
123 let line = line.trim();
124 if line.is_empty() {
125 continue;
126 }
127
128 tracing::trace!("Received: {}", line);
129
130 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132 let mut pending = pending_requests.lock().await;
133 if let Some(sender) = pending.remove(&response.id) {
134 let _ = sender.send(response);
135 } else {
136 tracing::warn!(
137 "Received response for unknown request ID: {:?}",
138 response.id
139 );
140 }
141 }
142 else if let Ok(notification) =
144 serde_json::from_str::<JsonRpcNotification>(line)
145 {
146 if notification_sender.send(notification).is_err() {
147 tracing::debug!("Notification receiver dropped");
148 break;
149 }
150 } else {
151 tracing::warn!("Failed to parse message: {}", line);
152 }
153 }
154 Err(e) => {
155 tracing::error!("Error reading from stdout: {}", e);
156 break;
157 }
158 }
159 }
160 }
161}
162
163#[async_trait]
164impl Transport for StdioClientTransport {
165 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
166 let writer = self
167 .stdin_writer
168 .as_mut()
169 .ok_or_else(|| McpError::transport("Transport not connected"))?;
170
171 let (sender, receiver) = tokio::sync::oneshot::channel();
172
173 {
175 let mut pending = self.pending_requests.lock().await;
176 pending.insert(request.id.clone(), sender);
177 }
178
179 let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
181
182 tracing::trace!("Sending: {}", request_line);
183
184 writer
185 .write_all(request_line.as_bytes())
186 .await
187 .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188 writer
189 .write_all(b"\n")
190 .await
191 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192 writer
193 .flush()
194 .await
195 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200 let response = timeout(timeout_duration, receiver)
201 .await
202 .map_err(|_| McpError::timeout("Request timeout"))?
203 .map_err(|_| McpError::transport("Response channel closed"))?;
204
205 Ok(response)
206 }
207
208 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209 let writer = self
210 .stdin_writer
211 .as_mut()
212 .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214 let notification_line =
215 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
216
217 tracing::trace!("Sending notification: {}", notification_line);
218
219 writer
220 .write_all(notification_line.as_bytes())
221 .await
222 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223 writer
224 .write_all(b"\n")
225 .await
226 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227 writer
228 .flush()
229 .await
230 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232 Ok(())
233 }
234
235 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236 if let Some(ref mut receiver) = self.notification_receiver {
237 match receiver.try_recv() {
238 Ok(notification) => Ok(Some(notification)),
239 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240 Err(mpsc::error::TryRecvError::Disconnected) => {
241 Err(McpError::transport("Notification channel disconnected"))
242 }
243 }
244 } else {
245 Ok(None)
246 }
247 }
248
249 async fn close(&mut self) -> McpResult<()> {
250 tracing::debug!("Closing STDIO transport");
251
252 self.state = ConnectionState::Closing;
253
254 if let Some(mut writer) = self.stdin_writer.take() {
256 let _ = writer.shutdown().await;
257 }
258
259 if let Some(mut child) = self.child.take() {
261 match timeout(Duration::from_secs(5), child.wait()).await {
262 Ok(Ok(status)) => {
263 tracing::debug!("Server process exited with status: {}", status);
264 }
265 Ok(Err(e)) => {
266 tracing::warn!("Error waiting for server process: {}", e);
267 }
268 Err(_) => {
269 tracing::warn!("Timeout waiting for server process, killing it");
270 let _ = child.kill().await;
271 }
272 }
273 }
274
275 self.state = ConnectionState::Disconnected;
276 Ok(())
277 }
278
279 fn is_connected(&self) -> bool {
280 matches!(self.state, ConnectionState::Connected)
281 }
282
283 fn connection_info(&self) -> String {
284 format!("STDIO transport (state: {:?})", self.state)
285 }
286}
287
288pub struct StdioServerTransport {
293 stdin_reader: Option<BufReader<tokio::io::Stdin>>,
294 stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
295 #[allow(dead_code)]
296 config: TransportConfig,
297 running: bool,
298 request_handler: Option<
299 Box<
300 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
301 >,
302 >,
303}
304
305impl StdioServerTransport {
306 pub fn new() -> Self {
311 Self::with_config(TransportConfig::default())
312 }
313
314 pub fn with_config(config: TransportConfig) -> Self {
322 let stdin_reader = BufReader::new(tokio::io::stdin());
323 let stdout_writer = BufWriter::new(tokio::io::stdout());
324
325 Self {
326 stdin_reader: Some(stdin_reader),
327 stdout_writer: Some(stdout_writer),
328 config,
329 running: false,
330 request_handler: None,
331 }
332 }
333
334 pub fn set_request_handler<F>(&mut self, handler: F)
339 where
340 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
341 + Send
342 + Sync
343 + 'static,
344 {
345 self.request_handler = Some(Box::new(handler));
346 }
347}
348
349#[async_trait]
350impl ServerTransport for StdioServerTransport {
351 async fn start(&mut self) -> McpResult<()> {
352 tracing::debug!("Starting STDIO server transport");
353
354 let mut reader = self
355 .stdin_reader
356 .take()
357 .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
358 let mut writer = self
359 .stdout_writer
360 .take()
361 .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
362
363 self.running = true;
364
365 let mut line = String::new();
366 while self.running {
367 line.clear();
368
369 match reader.read_line(&mut line).await {
370 Ok(0) => {
371 tracing::debug!("STDIN closed, stopping server");
372 break;
373 }
374 Ok(_) => {
375 let line = line.trim();
376 if line.is_empty() {
377 continue;
378 }
379
380 tracing::trace!("Received: {}", line);
381
382 match serde_json::from_str::<JsonRpcRequest>(line) {
384 Ok(request) => {
385 let response_or_error = match self.handle_request(request.clone()).await
386 {
387 Ok(response) => serde_json::to_string(&response),
388 Err(error) => {
389 let json_rpc_error = crate::protocol::types::JsonRpcError {
391 jsonrpc: "2.0".to_string(),
392 id: request.id,
393 error: crate::protocol::types::ErrorObject {
394 code: match error {
395 McpError::Protocol(ref msg) if msg.contains("not found") => {
396 error_codes::METHOD_NOT_FOUND
397 }
398 _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
399 },
400 message: error.to_string(),
401 data: None,
402 },
403 };
404 serde_json::to_string(&json_rpc_error)
405 }
406 };
407
408 let response_line =
409 response_or_error.map_err(McpError::serialization)?;
410
411 tracing::trace!("Sending: {}", response_line);
412
413 writer
414 .write_all(response_line.as_bytes())
415 .await
416 .map_err(|e| {
417 McpError::transport(format!("Failed to write response: {}", e))
418 })?;
419 writer.write_all(b"\n").await.map_err(|e| {
420 McpError::transport(format!("Failed to write newline: {}", e))
421 })?;
422 writer.flush().await.map_err(|e| {
423 McpError::transport(format!("Failed to flush: {}", e))
424 })?;
425 }
426 Err(e) => {
427 tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
428 }
431 }
432 }
433 Err(e) => {
434 tracing::error!("Error reading from stdin: {}", e);
435 return Err(McpError::io(e));
436 }
437 }
438 }
439
440 Ok(())
441 }
442
443 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
444 Err(McpError::protocol(format!(
446 "Method '{}' not found",
447 request.method
448 )))
449 }
450
451 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
452 let writer = self
453 .stdout_writer
454 .as_mut()
455 .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
456
457 let notification_line =
458 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
459
460 tracing::trace!("Sending notification: {}", notification_line);
461
462 writer
463 .write_all(notification_line.as_bytes())
464 .await
465 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
466 writer
467 .write_all(b"\n")
468 .await
469 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
470 writer
471 .flush()
472 .await
473 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
474
475 Ok(())
476 }
477
478 async fn stop(&mut self) -> McpResult<()> {
479 tracing::debug!("Stopping STDIO server transport");
480 self.running = false;
481 Ok(())
482 }
483
484 fn is_running(&self) -> bool {
485 self.running
486 }
487
488 fn server_info(&self) -> String {
489 format!("STDIO server transport (running: {})", self.running)
490 }
491}
492
493impl Default for StdioServerTransport {
494 fn default() -> Self {
495 Self::new()
496 }
497}
498
499impl Drop for StdioClientTransport {
500 fn drop(&mut self) {
501 if let Some(mut child) = self.child.take() {
502 let _ = child.start_kill();
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fresh server transport owns both stdio handles and is not yet running.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();

        assert!(!server.is_running());
        assert!(server.stdin_reader.is_some());
        assert!(server.stdout_writer.is_some());
    }

    /// The supplied config is stored verbatim.
    #[test]
    fn test_stdio_server_with_config() {
        let server = StdioServerTransport::with_config(TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        });

        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    /// Without a registered handler, any method is reported as not found.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();
        let unknown = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let outcome = server.handle_request(unknown).await;
        assert!(outcome.is_err());

        if let Err(McpError::Protocol(msg)) = outcome {
            assert!(msg.contains("unknown_method"));
        } else {
            panic!("Expected Protocol error");
        }
    }
}