// mcp_protocol_sdk/transport/stdio.rs

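//! STDIO transport implementation for MCP
//!
//! This module provides STDIO-based transport for MCP communication, which is
//! commonly used for command-line tools and process communication. The client
//! transport spawns the server as a child process and exchanges newline-delimited
//! JSON-RPC messages over the child's stdin/stdout; the server transport serves
//! requests over the current process's own stdin/stdout.
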
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{Mutex, mpsc};
14use tokio::time::{Duration, timeout};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes};
18use crate::transport::traits::{
19    ConnectionState, ServerRequestHandler, ServerTransport, Transport, TransportConfig,
20};
21
/// STDIO transport for MCP clients
///
/// This transport communicates with an MCP server via STDIO (standard input/output).
/// It's typically used when the server is a separate process.
///
/// Outgoing messages are written as single JSON lines to the child's stdin. A
/// background task reads the child's stdout line by line and routes each message
/// either to the pending request with the matching id or to the notification channel.
pub struct StdioClientTransport {
    /// Handle to the spawned server process; taken (left `None`) by `close` or `Drop`.
    child: Option<Child>,
    /// Buffered writer over the child's stdin; taken (left `None`) by `close`.
    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
    // The stdout reader is moved into the background processor task during
    // construction, so this field is only ever `None`; hence the dead-code allow.
    #[allow(dead_code)]
    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
    /// Receives notifications forwarded by the background processor task.
    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
    /// In-flight requests keyed by JSON-RPC id. The processor task completes the
    /// matching oneshot sender when a response carrying that id arrives.
    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
    /// Transport settings; `read_timeout_ms` bounds how long a request waits for its response.
    config: TransportConfig,
    /// Connection state reported by `is_connected` and `connection_info`.
    state: ConnectionState,
}
36
37impl StdioClientTransport {
38    /// Create a new STDIO client transport
39    ///
40    /// # Arguments
41    /// * `command` - Command to execute for the MCP server
42    /// * `args` - Arguments to pass to the command
43    ///
44    /// # Returns
45    /// Result containing the transport or an error
46    pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
47        Self::with_config(command, args, TransportConfig::default()).await
48    }
49
50    /// Create a new STDIO client transport with custom configuration
51    ///
52    /// # Arguments
53    /// * `command` - Command to execute for the MCP server
54    /// * `args` - Arguments to pass to the command
55    /// * `config` - Transport configuration
56    ///
57    /// # Returns
58    /// Result containing the transport or an error
59    pub async fn with_config<S: AsRef<str>>(
60        command: S,
61        args: Vec<S>,
62        config: TransportConfig,
63    ) -> McpResult<Self> {
64        let command_str = command.as_ref();
65        let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
66
67        tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
68
69        let mut child = Command::new(command_str)
70            .args(&args_str)
71            .stdin(Stdio::piped())
72            .stdout(Stdio::piped())
73            .stderr(Stdio::piped())
74            .spawn()
75            .map_err(|e| McpError::transport(format!("Failed to start server process: {e}")))?;
76
77        let stdin = child
78            .stdin
79            .take()
80            .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
81        let stdout = child
82            .stdout
83            .take()
84            .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
85
86        let stdin_writer = BufWriter::new(stdin);
87        let stdout_reader = BufReader::new(stdout);
88
89        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
90        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
91
92        // Start message processing task
93        let reader_pending_requests = pending_requests.clone();
94        let reader = stdout_reader;
95        tokio::spawn(async move {
96            Self::message_processor(reader, notification_sender, reader_pending_requests).await;
97        });
98
99        Ok(Self {
100            child: Some(child),
101            stdin_writer: Some(stdin_writer),
102            stdout_reader: None, // Moved to processor task
103            notification_receiver: Some(notification_receiver),
104            pending_requests,
105            config,
106            state: ConnectionState::Connected,
107        })
108    }
109
110    async fn message_processor(
111        mut reader: BufReader<tokio::process::ChildStdout>,
112        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
113        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
114    ) {
115        let mut line = String::new();
116
117        loop {
118            line.clear();
119            match reader.read_line(&mut line).await {
120                Ok(0) => {
121                    tracing::debug!("STDIO reader reached EOF");
122                    break;
123                }
124                Ok(_) => {
125                    let line = line.trim();
126                    if line.is_empty() {
127                        continue;
128                    }
129
130                    tracing::trace!("Received: {}", line);
131
132                    // Try to parse as response first
133                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
134                        let mut pending = pending_requests.lock().await;
135                        match pending.remove(&response.id) {
136                            Some(sender) => {
137                                let _ = sender.send(response);
138                            }
139                            _ => {
140                                tracing::warn!(
141                                    "Received response for unknown request ID: {:?}",
142                                    response.id
143                                );
144                            }
145                        }
146                    }
147                    // Try to parse as notification
148                    else if let Ok(notification) =
149                        serde_json::from_str::<JsonRpcNotification>(line)
150                    {
151                        if notification_sender.send(notification).is_err() {
152                            tracing::debug!("Notification receiver dropped");
153                            break;
154                        }
155                    } else {
156                        tracing::warn!("Failed to parse message: {}", line);
157                    }
158                }
159                Err(e) => {
160                    tracing::error!("Error reading from stdout: {}", e);
161                    break;
162                }
163            }
164        }
165    }
166}
167
168#[async_trait]
169impl Transport for StdioClientTransport {
170    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
171        let writer = self
172            .stdin_writer
173            .as_mut()
174            .ok_or_else(|| McpError::transport("Transport not connected"))?;
175
176        let (sender, receiver) = tokio::sync::oneshot::channel();
177
178        // Store the pending request
179        {
180            let mut pending = self.pending_requests.lock().await;
181            pending.insert(request.id.clone(), sender);
182        }
183
184        // Send the request
185        let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
186
187        tracing::trace!("Sending: {}", request_line);
188
189        writer
190            .write_all(request_line.as_bytes())
191            .await
192            .map_err(|e| McpError::transport(format!("Failed to write request: {e}")))?;
193        writer
194            .write_all(b"\n")
195            .await
196            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
197        writer
198            .flush()
199            .await
200            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
201
202        // Wait for response with timeout
203        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
204
205        let response = timeout(timeout_duration, receiver)
206            .await
207            .map_err(|_| McpError::timeout("Request timeout"))?
208            .map_err(|_| McpError::transport("Response channel closed"))?;
209
210        Ok(response)
211    }
212
213    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
214        let writer = self
215            .stdin_writer
216            .as_mut()
217            .ok_or_else(|| McpError::transport("Transport not connected"))?;
218
219        let notification_line =
220            serde_json::to_string(&notification).map_err(McpError::serialization)?;
221
222        tracing::trace!("Sending notification: {}", notification_line);
223
224        writer
225            .write_all(notification_line.as_bytes())
226            .await
227            .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
228        writer
229            .write_all(b"\n")
230            .await
231            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
232        writer
233            .flush()
234            .await
235            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
236
237        Ok(())
238    }
239
240    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
241        if let Some(ref mut receiver) = self.notification_receiver {
242            match receiver.try_recv() {
243                Ok(notification) => Ok(Some(notification)),
244                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
245                Err(mpsc::error::TryRecvError::Disconnected) => {
246                    Err(McpError::transport("Notification channel disconnected"))
247                }
248            }
249        } else {
250            Ok(None)
251        }
252    }
253
254    async fn close(&mut self) -> McpResult<()> {
255        tracing::debug!("Closing STDIO transport");
256
257        self.state = ConnectionState::Closing;
258
259        // Close stdin to signal the server to shut down
260        if let Some(mut writer) = self.stdin_writer.take() {
261            let _ = writer.shutdown().await;
262        }
263
264        // Wait for the child process to exit
265        if let Some(mut child) = self.child.take() {
266            match timeout(Duration::from_secs(5), child.wait()).await {
267                Ok(Ok(status)) => {
268                    tracing::debug!("Server process exited with status: {}", status);
269                }
270                Ok(Err(e)) => {
271                    tracing::warn!("Error waiting for server process: {}", e);
272                }
273                Err(_) => {
274                    tracing::warn!("Timeout waiting for server process, killing it");
275                    let _ = child.kill().await;
276                }
277            }
278        }
279
280        self.state = ConnectionState::Disconnected;
281        Ok(())
282    }
283
284    fn is_connected(&self) -> bool {
285        matches!(self.state, ConnectionState::Connected)
286    }
287
288    fn connection_info(&self) -> String {
289        let state = &self.state;
290        format!("STDIO transport (state: {state:?})")
291    }
292}
293
/// STDIO transport for MCP servers
///
/// This transport communicates with an MCP client via STDIO (standard input/output).
/// It reads requests from stdin and writes responses to stdout.
pub struct StdioServerTransport {
    /// Buffered reader over this process's stdin; taken by `start` for its read loop.
    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
    /// Buffered writer over this process's stdout; taken by `start`, after which
    /// `send_notification` fails with "STDOUT writer not available".
    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
    // Stored but never read by the transport code itself (only the unit tests
    // inspect it), so the dead-code lint is silenced.
    #[allow(dead_code)]
    config: TransportConfig,
    /// Set by `start`, cleared by `stop`; reported by `is_running`.
    running: bool,
    /// Invoked for every parsed request; when absent, each request receives a
    /// "method not found" JSON-RPC error.
    request_handler: Option<ServerRequestHandler>,
}
306
307impl StdioServerTransport {
308    /// Create a new STDIO server transport
309    ///
310    /// # Returns
311    /// New STDIO server transport instance
312    pub fn new() -> Self {
313        Self::with_config(TransportConfig::default())
314    }
315
316    /// Create a new STDIO server transport with custom configuration
317    ///
318    /// # Arguments
319    /// * `config` - Transport configuration
320    ///
321    /// # Returns
322    /// New STDIO server transport instance
323    pub fn with_config(config: TransportConfig) -> Self {
324        let stdin_reader = BufReader::new(tokio::io::stdin());
325        let stdout_writer = BufWriter::new(tokio::io::stdout());
326
327        Self {
328            stdin_reader: Some(stdin_reader),
329            stdout_writer: Some(stdout_writer),
330            config,
331            running: false,
332            request_handler: None,
333        }
334    }
335}
336
337#[async_trait]
338impl ServerTransport for StdioServerTransport {
339    async fn start(&mut self) -> McpResult<()> {
340        tracing::debug!("Starting STDIO server transport");
341
342        let mut reader = self
343            .stdin_reader
344            .take()
345            .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
346        let mut writer = self
347            .stdout_writer
348            .take()
349            .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
350
351        self.running = true;
352        let request_handler = self.request_handler.clone();
353
354        let mut line = String::new();
355        while self.running {
356            line.clear();
357
358            match reader.read_line(&mut line).await {
359                Ok(0) => {
360                    tracing::debug!("STDIN closed, stopping server");
361                    break;
362                }
363                Ok(_) => {
364                    let line = line.trim();
365                    if line.is_empty() {
366                        continue;
367                    }
368
369                    tracing::trace!("Received: {}", line);
370
371                    // Parse the request
372                    match serde_json::from_str::<JsonRpcRequest>(line) {
373                        Ok(request) => {
374                            let response_result = if let Some(ref handler) = request_handler {
375                                // Use the provided request handler
376                                handler(request.clone()).await
377                            } else {
378                                // Fall back to error if no handler is set
379                                Err(McpError::protocol(format!(
380                                    "Method '{}' not found",
381                                    request.method
382                                )))
383                            };
384
385                            let response_or_error = match response_result {
386                                Ok(response) => serde_json::to_string(&response),
387                                Err(error) => {
388                                    // Convert McpError to JsonRpcError
389                                    let json_rpc_error = crate::protocol::types::JsonRpcError {
390                                        jsonrpc: "2.0".to_string(),
391                                        id: request.id,
392                                        error: crate::protocol::types::ErrorObject {
393                                            code: match error {
394                                                McpError::Protocol(ref msg) if msg.contains("not found") => {
395                                                    error_codes::METHOD_NOT_FOUND
396                                                }
397                                                _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
398                                            },
399                                            message: error.to_string(),
400                                            data: None,
401                                        },
402                                    };
403                                    serde_json::to_string(&json_rpc_error)
404                                }
405                            };
406
407                            let response_line =
408                                response_or_error.map_err(McpError::serialization)?;
409
410                            tracing::trace!("Sending: {}", response_line);
411
412                            writer
413                                .write_all(response_line.as_bytes())
414                                .await
415                                .map_err(|e| {
416                                    McpError::transport(format!("Failed to write response: {e}"))
417                                })?;
418                            writer.write_all(b"\n").await.map_err(|e| {
419                                McpError::transport(format!("Failed to write newline: {e}"))
420                            })?;
421                            writer.flush().await.map_err(|e| {
422                                McpError::transport(format!("Failed to flush: {e}"))
423                            })?;
424                        }
425                        Err(e) => {
426                            tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
427                            // Send parse error response if we can extract an ID
428                            // For now, just continue
429                        }
430                    }
431                }
432                Err(e) => {
433                    tracing::error!("Error reading from stdin: {}", e);
434                    return Err(McpError::io(e));
435                }
436            }
437        }
438
439        Ok(())
440    }
441
442    fn set_request_handler(&mut self, handler: ServerRequestHandler) {
443        self.request_handler = Some(handler);
444    }
445
446    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
447        let writer = self
448            .stdout_writer
449            .as_mut()
450            .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
451
452        let notification_line =
453            serde_json::to_string(&notification).map_err(McpError::serialization)?;
454
455        tracing::trace!("Sending notification: {}", notification_line);
456
457        writer
458            .write_all(notification_line.as_bytes())
459            .await
460            .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
461        writer
462            .write_all(b"\n")
463            .await
464            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
465        writer
466            .flush()
467            .await
468            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
469
470        Ok(())
471    }
472
473    async fn stop(&mut self) -> McpResult<()> {
474        tracing::debug!("Stopping STDIO server transport");
475        self.running = false;
476        Ok(())
477    }
478
479    fn is_running(&self) -> bool {
480        self.running
481    }
482
483    fn server_info(&self) -> String {
484        format!("STDIO server transport (running: {})", self.running)
485    }
486}
487
488// Backward compatibility method for tests
489impl StdioServerTransport {
490    /// Backward compatibility method for tests
491    /// This method provides a default response for testing purposes
492    pub async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
493        // Default implementation for tests - return method not found error
494        Err(McpError::protocol(format!(
495            "Method '{}' not found (test mode)",
496            request.method
497        )))
498    }
499}
500
501impl Default for StdioServerTransport {
502    fn default() -> Self {
503        Self::new()
504    }
505}
506
507impl Drop for StdioClientTransport {
508    fn drop(&mut self) {
509        if let Some(mut child) = self.child.take() {
510            // Try to kill the child process if it's still running
511            let _ = child.start_kill();
512        }
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();

        // Both stdio handles are present until `start` takes them.
        assert!(server.stdin_reader.is_some());
        assert!(server.stdout_writer.is_some());
        assert!(!server.is_running());
    }

    #[test]
    fn test_stdio_server_with_config() {
        let custom = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..TransportConfig::default()
        };

        let server = StdioServerTransport::with_config(custom);
        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();

        let unknown = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        match server.handle_request(unknown).await {
            Err(McpError::Protocol(msg)) => assert!(msg.contains("unknown_method")),
            Err(_) => panic!("Expected Protocol error"),
            Ok(_) => panic!("Expected an error for an unknown method"),
        }
    }

    // Note: Integration tests with actual processes would go in tests/integration/
}