mcp_protocol_sdk/transport/
stdio.rs

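//! STDIO transport implementation for MCP
//!
//! This module provides STDIO-based transports for MCP communication, which is
//! commonly used for command-line tools and inter-process communication. It contains
//! a client transport that spawns the server as a child process, and a server
//! transport that speaks over the current process's stdin/stdout.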
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{error_codes, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20/// STDIO transport for MCP clients
21///
22/// This transport communicates with an MCP server via STDIO (standard input/output).
23/// It's typically used when the server is a separate process.
24pub struct StdioClientTransport {
25    child: Option<Child>,
26    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27    #[allow(dead_code)]
28    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
29    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
30    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
31    config: TransportConfig,
32    state: ConnectionState,
33}
34
35impl StdioClientTransport {
36    /// Create a new STDIO client transport
37    ///
38    /// # Arguments
39    /// * `command` - Command to execute for the MCP server
40    /// * `args` - Arguments to pass to the command
41    ///
42    /// # Returns
43    /// Result containing the transport or an error
44    pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45        Self::with_config(command, args, TransportConfig::default()).await
46    }
47
48    /// Create a new STDIO client transport with custom configuration
49    ///
50    /// # Arguments
51    /// * `command` - Command to execute for the MCP server
52    /// * `args` - Arguments to pass to the command
53    /// * `config` - Transport configuration
54    ///
55    /// # Returns
56    /// Result containing the transport or an error
57    pub async fn with_config<S: AsRef<str>>(
58        command: S,
59        args: Vec<S>,
60        config: TransportConfig,
61    ) -> McpResult<Self> {
62        let command_str = command.as_ref();
63        let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65        tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67        let mut child = Command::new(command_str)
68            .args(&args_str)
69            .stdin(Stdio::piped())
70            .stdout(Stdio::piped())
71            .stderr(Stdio::piped())
72            .spawn()
73            .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
74
75        let stdin = child
76            .stdin
77            .take()
78            .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79        let stdout = child
80            .stdout
81            .take()
82            .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84        let stdin_writer = BufWriter::new(stdin);
85        let stdout_reader = BufReader::new(stdout);
86
87        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90        // Start message processing task
91        let reader_pending_requests = pending_requests.clone();
92        let reader = stdout_reader;
93        tokio::spawn(async move {
94            Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95        });
96
97        Ok(Self {
98            child: Some(child),
99            stdin_writer: Some(stdin_writer),
100            stdout_reader: None, // Moved to processor task
101            notification_receiver: Some(notification_receiver),
102            pending_requests,
103            config,
104            state: ConnectionState::Connected,
105        })
106    }
107
108    async fn message_processor(
109        mut reader: BufReader<tokio::process::ChildStdout>,
110        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112    ) {
113        let mut line = String::new();
114
115        loop {
116            line.clear();
117            match reader.read_line(&mut line).await {
118                Ok(0) => {
119                    tracing::debug!("STDIO reader reached EOF");
120                    break;
121                }
122                Ok(_) => {
123                    let line = line.trim();
124                    if line.is_empty() {
125                        continue;
126                    }
127
128                    tracing::trace!("Received: {}", line);
129
130                    // Try to parse as response first
131                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132                        let mut pending = pending_requests.lock().await;
133                        if let Some(sender) = pending.remove(&response.id) {
134                            let _ = sender.send(response);
135                        } else {
136                            tracing::warn!(
137                                "Received response for unknown request ID: {:?}",
138                                response.id
139                            );
140                        }
141                    }
142                    // Try to parse as notification
143                    else if let Ok(notification) =
144                        serde_json::from_str::<JsonRpcNotification>(line)
145                    {
146                        if notification_sender.send(notification).is_err() {
147                            tracing::debug!("Notification receiver dropped");
148                            break;
149                        }
150                    } else {
151                        tracing::warn!("Failed to parse message: {}", line);
152                    }
153                }
154                Err(e) => {
155                    tracing::error!("Error reading from stdout: {}", e);
156                    break;
157                }
158            }
159        }
160    }
161}
162
163#[async_trait]
164impl Transport for StdioClientTransport {
165    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
166        let writer = self
167            .stdin_writer
168            .as_mut()
169            .ok_or_else(|| McpError::transport("Transport not connected"))?;
170
171        let (sender, receiver) = tokio::sync::oneshot::channel();
172
173        // Store the pending request
174        {
175            let mut pending = self.pending_requests.lock().await;
176            pending.insert(request.id.clone(), sender);
177        }
178
179        // Send the request
180        let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
181
182        tracing::trace!("Sending: {}", request_line);
183
184        writer
185            .write_all(request_line.as_bytes())
186            .await
187            .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188        writer
189            .write_all(b"\n")
190            .await
191            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192        writer
193            .flush()
194            .await
195            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197        // Wait for response with timeout
198        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200        let response = timeout(timeout_duration, receiver)
201            .await
202            .map_err(|_| McpError::timeout("Request timeout"))?
203            .map_err(|_| McpError::transport("Response channel closed"))?;
204
205        Ok(response)
206    }
207
208    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209        let writer = self
210            .stdin_writer
211            .as_mut()
212            .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214        let notification_line =
215            serde_json::to_string(&notification).map_err(McpError::serialization)?;
216
217        tracing::trace!("Sending notification: {}", notification_line);
218
219        writer
220            .write_all(notification_line.as_bytes())
221            .await
222            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223        writer
224            .write_all(b"\n")
225            .await
226            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227        writer
228            .flush()
229            .await
230            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232        Ok(())
233    }
234
235    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236        if let Some(ref mut receiver) = self.notification_receiver {
237            match receiver.try_recv() {
238                Ok(notification) => Ok(Some(notification)),
239                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240                Err(mpsc::error::TryRecvError::Disconnected) => {
241                    Err(McpError::transport("Notification channel disconnected"))
242                }
243            }
244        } else {
245            Ok(None)
246        }
247    }
248
249    async fn close(&mut self) -> McpResult<()> {
250        tracing::debug!("Closing STDIO transport");
251
252        self.state = ConnectionState::Closing;
253
254        // Close stdin to signal the server to shut down
255        if let Some(mut writer) = self.stdin_writer.take() {
256            let _ = writer.shutdown().await;
257        }
258
259        // Wait for the child process to exit
260        if let Some(mut child) = self.child.take() {
261            match timeout(Duration::from_secs(5), child.wait()).await {
262                Ok(Ok(status)) => {
263                    tracing::debug!("Server process exited with status: {}", status);
264                }
265                Ok(Err(e)) => {
266                    tracing::warn!("Error waiting for server process: {}", e);
267                }
268                Err(_) => {
269                    tracing::warn!("Timeout waiting for server process, killing it");
270                    let _ = child.kill().await;
271                }
272            }
273        }
274
275        self.state = ConnectionState::Disconnected;
276        Ok(())
277    }
278
279    fn is_connected(&self) -> bool {
280        matches!(self.state, ConnectionState::Connected)
281    }
282
283    fn connection_info(&self) -> String {
284        format!("STDIO transport (state: {:?})", self.state)
285    }
286}
287
288/// STDIO transport for MCP servers
289///
290/// This transport communicates with an MCP client via STDIO (standard input/output).
291/// It reads requests from stdin and writes responses to stdout.
292pub struct StdioServerTransport {
293    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
294    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
295    #[allow(dead_code)]
296    config: TransportConfig,
297    running: bool,
298    request_handler: Option<
299        Box<
300            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
301        >,
302    >,
303}
304
305impl StdioServerTransport {
306    /// Create a new STDIO server transport
307    ///
308    /// # Returns
309    /// New STDIO server transport instance
310    pub fn new() -> Self {
311        Self::with_config(TransportConfig::default())
312    }
313
314    /// Create a new STDIO server transport with custom configuration
315    ///
316    /// # Arguments
317    /// * `config` - Transport configuration
318    ///
319    /// # Returns
320    /// New STDIO server transport instance
321    pub fn with_config(config: TransportConfig) -> Self {
322        let stdin_reader = BufReader::new(tokio::io::stdin());
323        let stdout_writer = BufWriter::new(tokio::io::stdout());
324
325        Self {
326            stdin_reader: Some(stdin_reader),
327            stdout_writer: Some(stdout_writer),
328            config,
329            running: false,
330            request_handler: None,
331        }
332    }
333
334    /// Set the request handler function
335    ///
336    /// # Arguments
337    /// * `handler` - Function that processes incoming requests
338    pub fn set_request_handler<F>(&mut self, handler: F)
339    where
340        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
341            + Send
342            + Sync
343            + 'static,
344    {
345        self.request_handler = Some(Box::new(handler));
346    }
347}
348
349#[async_trait]
350impl ServerTransport for StdioServerTransport {
351    async fn start(&mut self) -> McpResult<()> {
352        tracing::debug!("Starting STDIO server transport");
353
354        let mut reader = self
355            .stdin_reader
356            .take()
357            .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
358        let mut writer = self
359            .stdout_writer
360            .take()
361            .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
362
363        self.running = true;
364
365        let mut line = String::new();
366        while self.running {
367            line.clear();
368
369            match reader.read_line(&mut line).await {
370                Ok(0) => {
371                    tracing::debug!("STDIN closed, stopping server");
372                    break;
373                }
374                Ok(_) => {
375                    let line = line.trim();
376                    if line.is_empty() {
377                        continue;
378                    }
379
380                    tracing::trace!("Received: {}", line);
381
382                    // Parse the request
383                    match serde_json::from_str::<JsonRpcRequest>(line) {
384                        Ok(request) => {
385                            let response_or_error = match self.handle_request(request.clone()).await
386                            {
387                                Ok(response) => serde_json::to_string(&response),
388                                Err(error) => {
389                                    // Convert McpError to JsonRpcError
390                                    let json_rpc_error = crate::protocol::types::JsonRpcError {
391                                        jsonrpc: "2.0".to_string(),
392                                        id: request.id,
393                                        error: crate::protocol::types::ErrorObject {
394                                            code: match error {
395                                                McpError::Protocol(ref msg) if msg.contains("not found") => {
396                                                    error_codes::METHOD_NOT_FOUND
397                                                }
398                                                _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
399                                            },
400                                            message: error.to_string(),
401                                            data: None,
402                                        },
403                                    };
404                                    serde_json::to_string(&json_rpc_error)
405                                }
406                            };
407
408                            let response_line =
409                                response_or_error.map_err(McpError::serialization)?;
410
411                            tracing::trace!("Sending: {}", response_line);
412
413                            writer
414                                .write_all(response_line.as_bytes())
415                                .await
416                                .map_err(|e| {
417                                    McpError::transport(format!("Failed to write response: {}", e))
418                                })?;
419                            writer.write_all(b"\n").await.map_err(|e| {
420                                McpError::transport(format!("Failed to write newline: {}", e))
421                            })?;
422                            writer.flush().await.map_err(|e| {
423                                McpError::transport(format!("Failed to flush: {}", e))
424                            })?;
425                        }
426                        Err(e) => {
427                            tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
428                            // Send parse error response if we can extract an ID
429                            // For now, just continue
430                        }
431                    }
432                }
433                Err(e) => {
434                    tracing::error!("Error reading from stdin: {}", e);
435                    return Err(McpError::io(e));
436                }
437            }
438        }
439
440        Ok(())
441    }
442
443    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
444        // Default implementation - return method not found error
445        Err(McpError::protocol(format!(
446            "Method '{}' not found",
447            request.method
448        )))
449    }
450
451    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
452        let writer = self
453            .stdout_writer
454            .as_mut()
455            .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
456
457        let notification_line =
458            serde_json::to_string(&notification).map_err(McpError::serialization)?;
459
460        tracing::trace!("Sending notification: {}", notification_line);
461
462        writer
463            .write_all(notification_line.as_bytes())
464            .await
465            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
466        writer
467            .write_all(b"\n")
468            .await
469            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
470        writer
471            .flush()
472            .await
473            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
474
475        Ok(())
476    }
477
478    async fn stop(&mut self) -> McpResult<()> {
479        tracing::debug!("Stopping STDIO server transport");
480        self.running = false;
481        Ok(())
482    }
483
484    fn is_running(&self) -> bool {
485        self.running
486    }
487
488    fn server_info(&self) -> String {
489        format!("STDIO server transport (running: {})", self.running)
490    }
491}
492
493impl Default for StdioServerTransport {
494    fn default() -> Self {
495        Self::new()
496    }
497}
498
499impl Drop for StdioClientTransport {
500    fn drop(&mut self) {
501        if let Some(mut child) = self.child.take() {
502            // Try to kill the child process if it's still running
503            let _ = child.start_kill();
504        }
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fresh server is idle and still owns both standard-stream handles.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();

        assert!(!server.is_running());
        assert!(server.stdin_reader.is_some());
        assert!(server.stdout_writer.is_some());
    }

    /// The configuration passed to `with_config` is stored unchanged.
    #[test]
    fn test_stdio_server_with_config() {
        let custom = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..TransportConfig::default()
        };

        let server = StdioServerTransport::with_config(custom);

        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    /// Without a registered handler, every method is reported as not found.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();
        let unknown = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let outcome = server.handle_request(unknown).await;
        assert!(outcome.is_err());

        if let Err(McpError::Protocol(message)) = outcome {
            assert!(message.contains("unknown_method"));
        } else {
            panic!("Expected Protocol error");
        }
    }

    // Note: Integration tests with actual processes would go in tests/integration/
}