mcp_protocol_sdk/transport/
stdio.rs

1//! STDIO transport implementation for MCP
2//!
3//! This module provides STDIO-based transport for MCP communication,
4//! which is commonly used for command-line tools and process communication.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20/// STDIO transport for MCP clients
21///
22/// This transport communicates with an MCP server via STDIO (standard input/output).
23/// It's typically used when the server is a separate process.
24pub struct StdioClientTransport {
25    child: Option<Child>,
26    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27    #[allow(dead_code)]
28    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
29    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
30    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
31    config: TransportConfig,
32    state: ConnectionState,
33}
34
35impl StdioClientTransport {
36    /// Create a new STDIO client transport
37    ///
38    /// # Arguments
39    /// * `command` - Command to execute for the MCP server
40    /// * `args` - Arguments to pass to the command
41    ///
42    /// # Returns
43    /// Result containing the transport or an error
44    pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45        Self::with_config(command, args, TransportConfig::default()).await
46    }
47
48    /// Create a new STDIO client transport with custom configuration
49    ///
50    /// # Arguments
51    /// * `command` - Command to execute for the MCP server
52    /// * `args` - Arguments to pass to the command
53    /// * `config` - Transport configuration
54    ///
55    /// # Returns
56    /// Result containing the transport or an error
57    pub async fn with_config<S: AsRef<str>>(
58        command: S,
59        args: Vec<S>,
60        config: TransportConfig,
61    ) -> McpResult<Self> {
62        let command_str = command.as_ref();
63        let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65        tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67        let mut child = Command::new(command_str)
68            .args(&args_str)
69            .stdin(Stdio::piped())
70            .stdout(Stdio::piped())
71            .stderr(Stdio::piped())
72            .spawn()
73            .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
74
75        let stdin = child
76            .stdin
77            .take()
78            .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79        let stdout = child
80            .stdout
81            .take()
82            .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84        let stdin_writer = BufWriter::new(stdin);
85        let stdout_reader = BufReader::new(stdout);
86
87        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90        // Start message processing task
91        let reader_pending_requests = pending_requests.clone();
92        let reader = stdout_reader;
93        tokio::spawn(async move {
94            Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95        });
96
97        Ok(Self {
98            child: Some(child),
99            stdin_writer: Some(stdin_writer),
100            stdout_reader: None, // Moved to processor task
101            notification_receiver: Some(notification_receiver),
102            pending_requests,
103            config,
104            state: ConnectionState::Connected,
105        })
106    }
107
108    async fn message_processor(
109        mut reader: BufReader<tokio::process::ChildStdout>,
110        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112    ) {
113        let mut line = String::new();
114
115        loop {
116            line.clear();
117            match reader.read_line(&mut line).await {
118                Ok(0) => {
119                    tracing::debug!("STDIO reader reached EOF");
120                    break;
121                }
122                Ok(_) => {
123                    let line = line.trim();
124                    if line.is_empty() {
125                        continue;
126                    }
127
128                    tracing::trace!("Received: {}", line);
129
130                    // Try to parse as response first
131                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132                        let mut pending = pending_requests.lock().await;
133                        if let Some(sender) = pending.remove(&response.id) {
134                            let _ = sender.send(response);
135                        } else {
136                            tracing::warn!(
137                                "Received response for unknown request ID: {:?}",
138                                response.id
139                            );
140                        }
141                    }
142                    // Try to parse as notification
143                    else if let Ok(notification) =
144                        serde_json::from_str::<JsonRpcNotification>(line)
145                    {
146                        if notification_sender.send(notification).is_err() {
147                            tracing::debug!("Notification receiver dropped");
148                            break;
149                        }
150                    } else {
151                        tracing::warn!("Failed to parse message: {}", line);
152                    }
153                }
154                Err(e) => {
155                    tracing::error!("Error reading from stdout: {}", e);
156                    break;
157                }
158            }
159        }
160    }
161}
162
163#[async_trait]
164impl Transport for StdioClientTransport {
165    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
166        let writer = self
167            .stdin_writer
168            .as_mut()
169            .ok_or_else(|| McpError::transport("Transport not connected"))?;
170
171        let (sender, receiver) = tokio::sync::oneshot::channel();
172
173        // Store the pending request
174        {
175            let mut pending = self.pending_requests.lock().await;
176            pending.insert(request.id.clone(), sender);
177        }
178
179        // Send the request
180        let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
181
182        tracing::trace!("Sending: {}", request_line);
183
184        writer
185            .write_all(request_line.as_bytes())
186            .await
187            .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188        writer
189            .write_all(b"\n")
190            .await
191            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192        writer
193            .flush()
194            .await
195            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197        // Wait for response with timeout
198        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200        let response = timeout(timeout_duration, receiver)
201            .await
202            .map_err(|_| McpError::timeout("Request timeout"))?
203            .map_err(|_| McpError::transport("Response channel closed"))?;
204
205        Ok(response)
206    }
207
208    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209        let writer = self
210            .stdin_writer
211            .as_mut()
212            .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214        let notification_line =
215            serde_json::to_string(&notification).map_err(McpError::serialization)?;
216
217        tracing::trace!("Sending notification: {}", notification_line);
218
219        writer
220            .write_all(notification_line.as_bytes())
221            .await
222            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223        writer
224            .write_all(b"\n")
225            .await
226            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227        writer
228            .flush()
229            .await
230            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232        Ok(())
233    }
234
235    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236        if let Some(ref mut receiver) = self.notification_receiver {
237            match receiver.try_recv() {
238                Ok(notification) => Ok(Some(notification)),
239                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240                Err(mpsc::error::TryRecvError::Disconnected) => {
241                    Err(McpError::transport("Notification channel disconnected"))
242                }
243            }
244        } else {
245            Ok(None)
246        }
247    }
248
249    async fn close(&mut self) -> McpResult<()> {
250        tracing::debug!("Closing STDIO transport");
251
252        self.state = ConnectionState::Closing;
253
254        // Close stdin to signal the server to shut down
255        if let Some(mut writer) = self.stdin_writer.take() {
256            let _ = writer.shutdown().await;
257        }
258
259        // Wait for the child process to exit
260        if let Some(mut child) = self.child.take() {
261            match timeout(Duration::from_secs(5), child.wait()).await {
262                Ok(Ok(status)) => {
263                    tracing::debug!("Server process exited with status: {}", status);
264                }
265                Ok(Err(e)) => {
266                    tracing::warn!("Error waiting for server process: {}", e);
267                }
268                Err(_) => {
269                    tracing::warn!("Timeout waiting for server process, killing it");
270                    let _ = child.kill().await;
271                }
272            }
273        }
274
275        self.state = ConnectionState::Disconnected;
276        Ok(())
277    }
278
279    fn is_connected(&self) -> bool {
280        matches!(self.state, ConnectionState::Connected)
281    }
282
283    fn connection_info(&self) -> String {
284        format!("STDIO transport (state: {:?})", self.state)
285    }
286}
287
288/// STDIO transport for MCP servers
289///
290/// This transport communicates with an MCP client via STDIO (standard input/output).
291/// It reads requests from stdin and writes responses to stdout.
292pub struct StdioServerTransport {
293    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
294    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
295    #[allow(dead_code)]
296    config: TransportConfig,
297    running: bool,
298    request_handler: Option<
299        Box<
300            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
301        >,
302    >,
303}
304
305impl StdioServerTransport {
306    /// Create a new STDIO server transport
307    ///
308    /// # Returns
309    /// New STDIO server transport instance
310    pub fn new() -> Self {
311        Self::with_config(TransportConfig::default())
312    }
313
314    /// Create a new STDIO server transport with custom configuration
315    ///
316    /// # Arguments
317    /// * `config` - Transport configuration
318    ///
319    /// # Returns
320    /// New STDIO server transport instance
321    pub fn with_config(config: TransportConfig) -> Self {
322        let stdin_reader = BufReader::new(tokio::io::stdin());
323        let stdout_writer = BufWriter::new(tokio::io::stdout());
324
325        Self {
326            stdin_reader: Some(stdin_reader),
327            stdout_writer: Some(stdout_writer),
328            config,
329            running: false,
330            request_handler: None,
331        }
332    }
333
334    /// Set the request handler function
335    ///
336    /// # Arguments
337    /// * `handler` - Function that processes incoming requests
338    pub fn set_request_handler<F>(&mut self, handler: F)
339    where
340        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
341            + Send
342            + Sync
343            + 'static,
344    {
345        self.request_handler = Some(Box::new(handler));
346    }
347}
348
349#[async_trait]
350impl ServerTransport for StdioServerTransport {
351    async fn start(&mut self) -> McpResult<()> {
352        tracing::debug!("Starting STDIO server transport");
353
354        let mut reader = self
355            .stdin_reader
356            .take()
357            .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
358        let mut writer = self
359            .stdout_writer
360            .take()
361            .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
362
363        self.running = true;
364
365        let mut line = String::new();
366        while self.running {
367            line.clear();
368
369            match reader.read_line(&mut line).await {
370                Ok(0) => {
371                    tracing::debug!("STDIN closed, stopping server");
372                    break;
373                }
374                Ok(_) => {
375                    let line = line.trim();
376                    if line.is_empty() {
377                        continue;
378                    }
379
380                    tracing::trace!("Received: {}", line);
381
382                    // Parse the request
383                    match serde_json::from_str::<JsonRpcRequest>(line) {
384                        Ok(request) => {
385                            let response = self.handle_request(request).await?;
386
387                            let response_line = serde_json::to_string(&response)
388                                .map_err(McpError::serialization)?;
389
390                            tracing::trace!("Sending: {}", response_line);
391
392                            writer
393                                .write_all(response_line.as_bytes())
394                                .await
395                                .map_err(|e| {
396                                    McpError::transport(format!("Failed to write response: {}", e))
397                                })?;
398                            writer.write_all(b"\n").await.map_err(|e| {
399                                McpError::transport(format!("Failed to write newline: {}", e))
400                            })?;
401                            writer.flush().await.map_err(|e| {
402                                McpError::transport(format!("Failed to flush: {}", e))
403                            })?;
404                        }
405                        Err(e) => {
406                            tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
407                            // Send parse error response if we can extract an ID
408                            // For now, just continue
409                        }
410                    }
411                }
412                Err(e) => {
413                    tracing::error!("Error reading from stdin: {}", e);
414                    return Err(McpError::io(e));
415                }
416            }
417        }
418
419        Ok(())
420    }
421
422    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
423        // Default implementation - return method not found
424        Ok(JsonRpcResponse {
425            jsonrpc: "2.0".to_string(),
426            id: request.id,
427            result: None,
428            error: Some(crate::protocol::types::JsonRpcError {
429                code: crate::protocol::types::METHOD_NOT_FOUND,
430                message: format!("Method '{}' not found", request.method),
431                data: None,
432            }),
433        })
434    }
435
436    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
437        let writer = self
438            .stdout_writer
439            .as_mut()
440            .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
441
442        let notification_line =
443            serde_json::to_string(&notification).map_err(McpError::serialization)?;
444
445        tracing::trace!("Sending notification: {}", notification_line);
446
447        writer
448            .write_all(notification_line.as_bytes())
449            .await
450            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
451        writer
452            .write_all(b"\n")
453            .await
454            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
455        writer
456            .flush()
457            .await
458            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
459
460        Ok(())
461    }
462
463    async fn stop(&mut self) -> McpResult<()> {
464        tracing::debug!("Stopping STDIO server transport");
465        self.running = false;
466        Ok(())
467    }
468
469    fn is_running(&self) -> bool {
470        self.running
471    }
472
473    fn server_info(&self) -> String {
474        format!("STDIO server transport (running: {})", self.running)
475    }
476}
477
478impl Default for StdioServerTransport {
479    fn default() -> Self {
480        Self::new()
481    }
482}
483
484impl Drop for StdioClientTransport {
485    fn drop(&mut self) {
486        if let Some(mut child) = self.child.take() {
487            // Try to kill the child process if it's still running
488            let _ = child.start_kill();
489        }
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_stdio_server_creation() {
        let transport = StdioServerTransport::new();
        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
    }

    #[test]
    fn test_stdio_server_default() {
        let transport = StdioServerTransport::default();
        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
        assert!(transport.request_handler.is_none());
    }

    #[test]
    fn test_stdio_server_with_config() {
        let config = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        };

        let transport = StdioServerTransport::with_config(config);
        assert_eq!(transport.config.read_timeout_ms, Some(30_000));
    }

    #[test]
    fn test_stdio_server_set_request_handler_stores_handler() {
        let mut transport = StdioServerTransport::new();
        assert!(transport.request_handler.is_none());

        transport.set_request_handler(|_request| {
            let (_sender, receiver) = tokio::sync::oneshot::channel::<JsonRpcResponse>();
            receiver
        });
        assert!(transport.request_handler.is_some());
    }

    #[tokio::test]
    async fn test_stdio_server_stop() {
        let mut transport = StdioServerTransport::new();
        transport.stop().await.unwrap();
        assert!(!transport.is_running());
        assert_eq!(
            transport.server_info(),
            "STDIO server transport (running: false)"
        );
    }

    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut transport = StdioServerTransport::new();

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let response = transport.handle_request(request).await.unwrap();
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, json!(1));
        assert!(response.error.is_some());
        assert!(response.result.is_none());

        let error = response.error.unwrap();
        assert_eq!(error.code, crate::protocol::types::METHOD_NOT_FOUND);
        assert_eq!(error.message, "Method 'unknown_method' not found");
    }

    // Note: Integration tests with actual processes would go in tests/integration/
}
537
538    // Note: Integration tests with actual processes would go in tests/integration/
539}