mcp_protocol_sdk/transport/stdio.rs

1//! STDIO transport implementation for MCP
2//!
3//! This module provides STDIO-based transport for MCP communication,
4//! which is commonly used for command-line tools and process communication.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20/// STDIO transport for MCP clients
21///
22/// This transport communicates with an MCP server via STDIO (standard input/output).
23/// It's typically used when the server is a separate process.
24pub struct StdioClientTransport {
25    child: Option<Child>,
26    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
28    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
29    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
30    config: TransportConfig,
31    state: ConnectionState,
32}
33
34impl StdioClientTransport {
35    /// Create a new STDIO client transport
36    ///
37    /// # Arguments
38    /// * `command` - Command to execute for the MCP server
39    /// * `args` - Arguments to pass to the command
40    ///
41    /// # Returns
42    /// Result containing the transport or an error
43    pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
44        Self::with_config(command, args, TransportConfig::default()).await
45    }
46
47    /// Create a new STDIO client transport with custom configuration
48    ///
49    /// # Arguments
50    /// * `command` - Command to execute for the MCP server
51    /// * `args` - Arguments to pass to the command
52    /// * `config` - Transport configuration
53    ///
54    /// # Returns
55    /// Result containing the transport or an error
56    pub async fn with_config<S: AsRef<str>>(
57        command: S,
58        args: Vec<S>,
59        config: TransportConfig,
60    ) -> McpResult<Self> {
61        let command_str = command.as_ref();
62        let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
63
64        tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
65
66        let mut child = Command::new(command_str)
67            .args(&args_str)
68            .stdin(Stdio::piped())
69            .stdout(Stdio::piped())
70            .stderr(Stdio::piped())
71            .spawn()
72            .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
73
74        let stdin = child
75            .stdin
76            .take()
77            .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
78        let stdout = child
79            .stdout
80            .take()
81            .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
82
83        let stdin_writer = BufWriter::new(stdin);
84        let stdout_reader = BufReader::new(stdout);
85
86        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
87        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
88
89        // Start message processing task
90        let reader_pending_requests = pending_requests.clone();
91        let mut reader = stdout_reader;
92        tokio::spawn(async move {
93            Self::message_processor(reader, notification_sender, reader_pending_requests).await;
94        });
95
96        Ok(Self {
97            child: Some(child),
98            stdin_writer: Some(stdin_writer),
99            stdout_reader: None, // Moved to processor task
100            notification_receiver: Some(notification_receiver),
101            pending_requests,
102            config,
103            state: ConnectionState::Connected,
104        })
105    }
106
107    async fn message_processor(
108        mut reader: BufReader<tokio::process::ChildStdout>,
109        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
110        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
111    ) {
112        let mut line = String::new();
113
114        loop {
115            line.clear();
116            match reader.read_line(&mut line).await {
117                Ok(0) => {
118                    tracing::debug!("STDIO reader reached EOF");
119                    break;
120                }
121                Ok(_) => {
122                    let line = line.trim();
123                    if line.is_empty() {
124                        continue;
125                    }
126
127                    tracing::trace!("Received: {}", line);
128
129                    // Try to parse as response first
130                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
131                        let mut pending = pending_requests.lock().await;
132                        if let Some(sender) = pending.remove(&response.id) {
133                            let _ = sender.send(response);
134                        } else {
135                            tracing::warn!(
136                                "Received response for unknown request ID: {:?}",
137                                response.id
138                            );
139                        }
140                    }
141                    // Try to parse as notification
142                    else if let Ok(notification) =
143                        serde_json::from_str::<JsonRpcNotification>(line)
144                    {
145                        if notification_sender.send(notification).is_err() {
146                            tracing::debug!("Notification receiver dropped");
147                            break;
148                        }
149                    } else {
150                        tracing::warn!("Failed to parse message: {}", line);
151                    }
152                }
153                Err(e) => {
154                    tracing::error!("Error reading from stdout: {}", e);
155                    break;
156                }
157            }
158        }
159    }
160}
161
162#[async_trait]
163impl Transport for StdioClientTransport {
164    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
165        let writer = self
166            .stdin_writer
167            .as_mut()
168            .ok_or_else(|| McpError::transport("Transport not connected"))?;
169
170        let (sender, receiver) = tokio::sync::oneshot::channel();
171
172        // Store the pending request
173        {
174            let mut pending = self.pending_requests.lock().await;
175            pending.insert(request.id.clone(), sender);
176        }
177
178        // Send the request
179        let request_line =
180            serde_json::to_string(&request).map_err(|e| McpError::serialization(e))?;
181
182        tracing::trace!("Sending: {}", request_line);
183
184        writer
185            .write_all(request_line.as_bytes())
186            .await
187            .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188        writer
189            .write_all(b"\n")
190            .await
191            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192        writer
193            .flush()
194            .await
195            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197        // Wait for response with timeout
198        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200        let response = timeout(timeout_duration, receiver)
201            .await
202            .map_err(|_| McpError::timeout("Request timeout"))?
203            .map_err(|_| McpError::transport("Response channel closed"))?;
204
205        Ok(response)
206    }
207
208    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209        let writer = self
210            .stdin_writer
211            .as_mut()
212            .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214        let notification_line =
215            serde_json::to_string(&notification).map_err(|e| McpError::serialization(e))?;
216
217        tracing::trace!("Sending notification: {}", notification_line);
218
219        writer
220            .write_all(notification_line.as_bytes())
221            .await
222            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223        writer
224            .write_all(b"\n")
225            .await
226            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227        writer
228            .flush()
229            .await
230            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232        Ok(())
233    }
234
235    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236        if let Some(ref mut receiver) = self.notification_receiver {
237            match receiver.try_recv() {
238                Ok(notification) => Ok(Some(notification)),
239                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240                Err(mpsc::error::TryRecvError::Disconnected) => {
241                    Err(McpError::transport("Notification channel disconnected"))
242                }
243            }
244        } else {
245            Ok(None)
246        }
247    }
248
249    async fn close(&mut self) -> McpResult<()> {
250        tracing::debug!("Closing STDIO transport");
251
252        self.state = ConnectionState::Closing;
253
254        // Close stdin to signal the server to shut down
255        if let Some(mut writer) = self.stdin_writer.take() {
256            let _ = writer.shutdown().await;
257        }
258
259        // Wait for the child process to exit
260        if let Some(mut child) = self.child.take() {
261            match timeout(Duration::from_secs(5), child.wait()).await {
262                Ok(Ok(status)) => {
263                    tracing::debug!("Server process exited with status: {}", status);
264                }
265                Ok(Err(e)) => {
266                    tracing::warn!("Error waiting for server process: {}", e);
267                }
268                Err(_) => {
269                    tracing::warn!("Timeout waiting for server process, killing it");
270                    let _ = child.kill().await;
271                }
272            }
273        }
274
275        self.state = ConnectionState::Disconnected;
276        Ok(())
277    }
278
279    fn is_connected(&self) -> bool {
280        matches!(self.state, ConnectionState::Connected)
281    }
282
283    fn connection_info(&self) -> String {
284        format!("STDIO transport (state: {:?})", self.state)
285    }
286}
287
288/// STDIO transport for MCP servers
289///
290/// This transport communicates with an MCP client via STDIO (standard input/output).
291/// It reads requests from stdin and writes responses to stdout.
292pub struct StdioServerTransport {
293    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
294    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
295    config: TransportConfig,
296    running: bool,
297    request_handler: Option<
298        Box<
299            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
300        >,
301    >,
302}
303
304impl StdioServerTransport {
305    /// Create a new STDIO server transport
306    ///
307    /// # Returns
308    /// New STDIO server transport instance
309    pub fn new() -> Self {
310        Self::with_config(TransportConfig::default())
311    }
312
313    /// Create a new STDIO server transport with custom configuration
314    ///
315    /// # Arguments
316    /// * `config` - Transport configuration
317    ///
318    /// # Returns
319    /// New STDIO server transport instance
320    pub fn with_config(config: TransportConfig) -> Self {
321        let stdin_reader = BufReader::new(tokio::io::stdin());
322        let stdout_writer = BufWriter::new(tokio::io::stdout());
323
324        Self {
325            stdin_reader: Some(stdin_reader),
326            stdout_writer: Some(stdout_writer),
327            config,
328            running: false,
329            request_handler: None,
330        }
331    }
332
333    /// Set the request handler function
334    ///
335    /// # Arguments
336    /// * `handler` - Function that processes incoming requests
337    pub fn set_request_handler<F>(&mut self, handler: F)
338    where
339        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
340            + Send
341            + Sync
342            + 'static,
343    {
344        self.request_handler = Some(Box::new(handler));
345    }
346}
347
348#[async_trait]
349impl ServerTransport for StdioServerTransport {
350    async fn start(&mut self) -> McpResult<()> {
351        tracing::debug!("Starting STDIO server transport");
352
353        let mut reader = self
354            .stdin_reader
355            .take()
356            .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
357        let mut writer = self
358            .stdout_writer
359            .take()
360            .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
361
362        self.running = true;
363
364        let mut line = String::new();
365        while self.running {
366            line.clear();
367
368            match reader.read_line(&mut line).await {
369                Ok(0) => {
370                    tracing::debug!("STDIN closed, stopping server");
371                    break;
372                }
373                Ok(_) => {
374                    let line = line.trim();
375                    if line.is_empty() {
376                        continue;
377                    }
378
379                    tracing::trace!("Received: {}", line);
380
381                    // Parse the request
382                    match serde_json::from_str::<JsonRpcRequest>(line) {
383                        Ok(request) => {
384                            let response = self.handle_request(request).await?;
385
386                            let response_line = serde_json::to_string(&response)
387                                .map_err(|e| McpError::serialization(e))?;
388
389                            tracing::trace!("Sending: {}", response_line);
390
391                            writer
392                                .write_all(response_line.as_bytes())
393                                .await
394                                .map_err(|e| {
395                                    McpError::transport(format!("Failed to write response: {}", e))
396                                })?;
397                            writer.write_all(b"\n").await.map_err(|e| {
398                                McpError::transport(format!("Failed to write newline: {}", e))
399                            })?;
400                            writer.flush().await.map_err(|e| {
401                                McpError::transport(format!("Failed to flush: {}", e))
402                            })?;
403                        }
404                        Err(e) => {
405                            tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
406                            // Send parse error response if we can extract an ID
407                            // For now, just continue
408                        }
409                    }
410                }
411                Err(e) => {
412                    tracing::error!("Error reading from stdin: {}", e);
413                    return Err(McpError::io(e));
414                }
415            }
416        }
417
418        Ok(())
419    }
420
421    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
422        // Default implementation - return method not found
423        Ok(JsonRpcResponse {
424            jsonrpc: "2.0".to_string(),
425            id: request.id,
426            result: None,
427            error: Some(crate::protocol::types::JsonRpcError {
428                code: crate::protocol::types::METHOD_NOT_FOUND,
429                message: format!("Method '{}' not found", request.method),
430                data: None,
431            }),
432        })
433    }
434
435    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
436        let writer = self
437            .stdout_writer
438            .as_mut()
439            .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
440
441        let notification_line =
442            serde_json::to_string(&notification).map_err(|e| McpError::serialization(e))?;
443
444        tracing::trace!("Sending notification: {}", notification_line);
445
446        writer
447            .write_all(notification_line.as_bytes())
448            .await
449            .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
450        writer
451            .write_all(b"\n")
452            .await
453            .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
454        writer
455            .flush()
456            .await
457            .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
458
459        Ok(())
460    }
461
462    async fn stop(&mut self) -> McpResult<()> {
463        tracing::debug!("Stopping STDIO server transport");
464        self.running = false;
465        Ok(())
466    }
467
468    fn is_running(&self) -> bool {
469        self.running
470    }
471
472    fn server_info(&self) -> String {
473        format!("STDIO server transport (running: {})", self.running)
474    }
475}
476
477impl Default for StdioServerTransport {
478    fn default() -> Self {
479        Self::new()
480    }
481}
482
483impl Drop for StdioClientTransport {
484    fn drop(&mut self) {
485        if let Some(mut child) = self.child.take() {
486            // Try to kill the child process if it's still running
487            let _ = child.start_kill();
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly built server transport holds both stdio handles and is idle.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();

        assert!(server.stdin_reader.is_some());
        assert!(server.stdout_writer.is_some());
        assert!(!server.is_running());
    }

    /// `with_config` keeps the caller-supplied configuration.
    #[test]
    fn test_stdio_server_with_config() {
        let custom_timeout = Some(30_000);
        let mut config = TransportConfig::default();
        config.read_timeout_ms = custom_timeout;

        let server = StdioServerTransport::with_config(config);
        assert_eq!(server.config.read_timeout_ms, custom_timeout);
    }

    /// With no handler installed, an unknown method yields a "method not found" error.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();
        let unknown_call = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let reply = server.handle_request(unknown_call).await.unwrap();
        assert_eq!(reply.jsonrpc, "2.0");
        assert_eq!(reply.id, json!(1));
        assert!(reply.result.is_none());

        let failure = reply
            .error
            .expect("an unknown method must produce an error response");
        assert_eq!(failure.code, crate::protocol::types::METHOD_NOT_FOUND);
    }

    // Note: Integration tests with actual processes would go in tests/integration/
}
534
535    // Note: Integration tests with actual processes would go in tests/integration/
536}