mcp_protocol_sdk/transport/
stdio.rs

1//! STDIO transport implementation for MCP
2//!
3//! This module provides STDIO-based transport for MCP communication,
4//! which is commonly used for command-line tools and process communication.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{Mutex, mpsc};
14use tokio::time::{Duration, timeout};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20/// STDIO transport for MCP clients
21///
22/// This transport communicates with an MCP server via STDIO (standard input/output).
23/// It's typically used when the server is a separate process.
24pub struct StdioClientTransport {
25    child: Option<Child>,
26    stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27    #[allow(dead_code)]
28    stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
29    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
30    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
31    config: TransportConfig,
32    state: ConnectionState,
33}
34
35impl StdioClientTransport {
36    /// Create a new STDIO client transport
37    ///
38    /// # Arguments
39    /// * `command` - Command to execute for the MCP server
40    /// * `args` - Arguments to pass to the command
41    ///
42    /// # Returns
43    /// Result containing the transport or an error
44    pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45        Self::with_config(command, args, TransportConfig::default()).await
46    }
47
48    /// Create a new STDIO client transport with custom configuration
49    ///
50    /// # Arguments
51    /// * `command` - Command to execute for the MCP server
52    /// * `args` - Arguments to pass to the command
53    /// * `config` - Transport configuration
54    ///
55    /// # Returns
56    /// Result containing the transport or an error
57    pub async fn with_config<S: AsRef<str>>(
58        command: S,
59        args: Vec<S>,
60        config: TransportConfig,
61    ) -> McpResult<Self> {
62        let command_str = command.as_ref();
63        let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65        tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67        let mut child = Command::new(command_str)
68            .args(&args_str)
69            .stdin(Stdio::piped())
70            .stdout(Stdio::piped())
71            .stderr(Stdio::piped())
72            .spawn()
73            .map_err(|e| McpError::transport(format!("Failed to start server process: {e}")))?;
74
75        let stdin = child
76            .stdin
77            .take()
78            .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79        let stdout = child
80            .stdout
81            .take()
82            .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84        let stdin_writer = BufWriter::new(stdin);
85        let stdout_reader = BufReader::new(stdout);
86
87        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90        // Start message processing task
91        let reader_pending_requests = pending_requests.clone();
92        let reader = stdout_reader;
93        tokio::spawn(async move {
94            Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95        });
96
97        Ok(Self {
98            child: Some(child),
99            stdin_writer: Some(stdin_writer),
100            stdout_reader: None, // Moved to processor task
101            notification_receiver: Some(notification_receiver),
102            pending_requests,
103            config,
104            state: ConnectionState::Connected,
105        })
106    }
107
108    async fn message_processor(
109        mut reader: BufReader<tokio::process::ChildStdout>,
110        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112    ) {
113        let mut line = String::new();
114
115        loop {
116            line.clear();
117            match reader.read_line(&mut line).await {
118                Ok(0) => {
119                    tracing::debug!("STDIO reader reached EOF");
120                    break;
121                }
122                Ok(_) => {
123                    let line = line.trim();
124                    if line.is_empty() {
125                        continue;
126                    }
127
128                    tracing::trace!("Received: {}", line);
129
130                    // Try to parse as response first
131                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132                        let mut pending = pending_requests.lock().await;
133                        match pending.remove(&response.id) {
134                            Some(sender) => {
135                                let _ = sender.send(response);
136                            }
137                            _ => {
138                                tracing::warn!(
139                                    "Received response for unknown request ID: {:?}",
140                                    response.id
141                                );
142                            }
143                        }
144                    }
145                    // Try to parse as notification
146                    else if let Ok(notification) =
147                        serde_json::from_str::<JsonRpcNotification>(line)
148                    {
149                        if notification_sender.send(notification).is_err() {
150                            tracing::debug!("Notification receiver dropped");
151                            break;
152                        }
153                    } else {
154                        tracing::warn!("Failed to parse message: {}", line);
155                    }
156                }
157                Err(e) => {
158                    tracing::error!("Error reading from stdout: {}", e);
159                    break;
160                }
161            }
162        }
163    }
164}
165
166#[async_trait]
167impl Transport for StdioClientTransport {
168    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
169        let writer = self
170            .stdin_writer
171            .as_mut()
172            .ok_or_else(|| McpError::transport("Transport not connected"))?;
173
174        let (sender, receiver) = tokio::sync::oneshot::channel();
175
176        // Store the pending request
177        {
178            let mut pending = self.pending_requests.lock().await;
179            pending.insert(request.id.clone(), sender);
180        }
181
182        // Send the request
183        let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
184
185        tracing::trace!("Sending: {}", request_line);
186
187        writer
188            .write_all(request_line.as_bytes())
189            .await
190            .map_err(|e| McpError::transport(format!("Failed to write request: {e}")))?;
191        writer
192            .write_all(b"\n")
193            .await
194            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
195        writer
196            .flush()
197            .await
198            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
199
200        // Wait for response with timeout
201        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
202
203        let response = timeout(timeout_duration, receiver)
204            .await
205            .map_err(|_| McpError::timeout("Request timeout"))?
206            .map_err(|_| McpError::transport("Response channel closed"))?;
207
208        Ok(response)
209    }
210
211    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
212        let writer = self
213            .stdin_writer
214            .as_mut()
215            .ok_or_else(|| McpError::transport("Transport not connected"))?;
216
217        let notification_line =
218            serde_json::to_string(&notification).map_err(McpError::serialization)?;
219
220        tracing::trace!("Sending notification: {}", notification_line);
221
222        writer
223            .write_all(notification_line.as_bytes())
224            .await
225            .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
226        writer
227            .write_all(b"\n")
228            .await
229            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
230        writer
231            .flush()
232            .await
233            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
234
235        Ok(())
236    }
237
238    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
239        if let Some(ref mut receiver) = self.notification_receiver {
240            match receiver.try_recv() {
241                Ok(notification) => Ok(Some(notification)),
242                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
243                Err(mpsc::error::TryRecvError::Disconnected) => {
244                    Err(McpError::transport("Notification channel disconnected"))
245                }
246            }
247        } else {
248            Ok(None)
249        }
250    }
251
252    async fn close(&mut self) -> McpResult<()> {
253        tracing::debug!("Closing STDIO transport");
254
255        self.state = ConnectionState::Closing;
256
257        // Close stdin to signal the server to shut down
258        if let Some(mut writer) = self.stdin_writer.take() {
259            let _ = writer.shutdown().await;
260        }
261
262        // Wait for the child process to exit
263        if let Some(mut child) = self.child.take() {
264            match timeout(Duration::from_secs(5), child.wait()).await {
265                Ok(Ok(status)) => {
266                    tracing::debug!("Server process exited with status: {}", status);
267                }
268                Ok(Err(e)) => {
269                    tracing::warn!("Error waiting for server process: {}", e);
270                }
271                Err(_) => {
272                    tracing::warn!("Timeout waiting for server process, killing it");
273                    let _ = child.kill().await;
274                }
275            }
276        }
277
278        self.state = ConnectionState::Disconnected;
279        Ok(())
280    }
281
282    fn is_connected(&self) -> bool {
283        matches!(self.state, ConnectionState::Connected)
284    }
285
286    fn connection_info(&self) -> String {
287        let state = &self.state;
288        format!("STDIO transport (state: {state:?})")
289    }
290}
291
292/// STDIO transport for MCP servers
293///
294/// This transport communicates with an MCP client via STDIO (standard input/output).
295/// It reads requests from stdin and writes responses to stdout.
296pub struct StdioServerTransport {
297    stdin_reader: Option<BufReader<tokio::io::Stdin>>,
298    stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
299    #[allow(dead_code)]
300    config: TransportConfig,
301    running: bool,
302    request_handler: Option<
303        Box<
304            dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
305        >,
306    >,
307}
308
309impl StdioServerTransport {
310    /// Create a new STDIO server transport
311    ///
312    /// # Returns
313    /// New STDIO server transport instance
314    pub fn new() -> Self {
315        Self::with_config(TransportConfig::default())
316    }
317
318    /// Create a new STDIO server transport with custom configuration
319    ///
320    /// # Arguments
321    /// * `config` - Transport configuration
322    ///
323    /// # Returns
324    /// New STDIO server transport instance
325    pub fn with_config(config: TransportConfig) -> Self {
326        let stdin_reader = BufReader::new(tokio::io::stdin());
327        let stdout_writer = BufWriter::new(tokio::io::stdout());
328
329        Self {
330            stdin_reader: Some(stdin_reader),
331            stdout_writer: Some(stdout_writer),
332            config,
333            running: false,
334            request_handler: None,
335        }
336    }
337
338    /// Set the request handler function
339    ///
340    /// # Arguments
341    /// * `handler` - Function that processes incoming requests
342    pub fn set_request_handler<F>(&mut self, handler: F)
343    where
344        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
345            + Send
346            + Sync
347            + 'static,
348    {
349        self.request_handler = Some(Box::new(handler));
350    }
351}
352
353#[async_trait]
354impl ServerTransport for StdioServerTransport {
355    async fn start(&mut self) -> McpResult<()> {
356        tracing::debug!("Starting STDIO server transport");
357
358        let mut reader = self
359            .stdin_reader
360            .take()
361            .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
362        let mut writer = self
363            .stdout_writer
364            .take()
365            .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
366
367        self.running = true;
368
369        let mut line = String::new();
370        while self.running {
371            line.clear();
372
373            match reader.read_line(&mut line).await {
374                Ok(0) => {
375                    tracing::debug!("STDIN closed, stopping server");
376                    break;
377                }
378                Ok(_) => {
379                    let line = line.trim();
380                    if line.is_empty() {
381                        continue;
382                    }
383
384                    tracing::trace!("Received: {}", line);
385
386                    // Parse the request
387                    match serde_json::from_str::<JsonRpcRequest>(line) {
388                        Ok(request) => {
389                            let response_or_error = match self.handle_request(request.clone()).await
390                            {
391                                Ok(response) => serde_json::to_string(&response),
392                                Err(error) => {
393                                    // Convert McpError to JsonRpcError
394                                    let json_rpc_error = crate::protocol::types::JsonRpcError {
395                                        jsonrpc: "2.0".to_string(),
396                                        id: request.id,
397                                        error: crate::protocol::types::ErrorObject {
398                                            code: match error {
399                                                McpError::Protocol(ref msg) if msg.contains("not found") => {
400                                                    error_codes::METHOD_NOT_FOUND
401                                                }
402                                                _ => crate::protocol::types::error_codes::INTERNAL_ERROR,
403                                            },
404                                            message: error.to_string(),
405                                            data: None,
406                                        },
407                                    };
408                                    serde_json::to_string(&json_rpc_error)
409                                }
410                            };
411
412                            let response_line =
413                                response_or_error.map_err(McpError::serialization)?;
414
415                            tracing::trace!("Sending: {}", response_line);
416
417                            writer
418                                .write_all(response_line.as_bytes())
419                                .await
420                                .map_err(|e| {
421                                    McpError::transport(format!("Failed to write response: {e}"))
422                                })?;
423                            writer.write_all(b"\n").await.map_err(|e| {
424                                McpError::transport(format!("Failed to write newline: {e}"))
425                            })?;
426                            writer.flush().await.map_err(|e| {
427                                McpError::transport(format!("Failed to flush: {e}"))
428                            })?;
429                        }
430                        Err(e) => {
431                            tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
432                            // Send parse error response if we can extract an ID
433                            // For now, just continue
434                        }
435                    }
436                }
437                Err(e) => {
438                    tracing::error!("Error reading from stdin: {}", e);
439                    return Err(McpError::io(e));
440                }
441            }
442        }
443
444        Ok(())
445    }
446
447    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
448        // Default implementation - return method not found error
449        Err(McpError::protocol(format!(
450            "Method '{}' not found",
451            request.method
452        )))
453    }
454
455    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
456        let writer = self
457            .stdout_writer
458            .as_mut()
459            .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
460
461        let notification_line =
462            serde_json::to_string(&notification).map_err(McpError::serialization)?;
463
464        tracing::trace!("Sending notification: {}", notification_line);
465
466        writer
467            .write_all(notification_line.as_bytes())
468            .await
469            .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
470        writer
471            .write_all(b"\n")
472            .await
473            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
474        writer
475            .flush()
476            .await
477            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
478
479        Ok(())
480    }
481
482    async fn stop(&mut self) -> McpResult<()> {
483        tracing::debug!("Stopping STDIO server transport");
484        self.running = false;
485        Ok(())
486    }
487
488    fn is_running(&self) -> bool {
489        self.running
490    }
491
492    fn server_info(&self) -> String {
493        format!("STDIO server transport (running: {})", self.running)
494    }
495}
496
497impl Default for StdioServerTransport {
498    fn default() -> Self {
499        Self::new()
500    }
501}
502
503impl Drop for StdioClientTransport {
504    fn drop(&mut self) {
505        if let Some(mut child) = self.child.take() {
506            // Try to kill the child process if it's still running
507            let _ = child.start_kill();
508        }
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_stdio_server_creation() {
        let transport = StdioServerTransport::new();
        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
    }

    #[test]
    fn test_stdio_server_with_config() {
        let config = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        };

        let transport = StdioServerTransport::with_config(config);
        assert_eq!(transport.config.read_timeout_ms, Some(30_000));
    }

    #[test]
    fn test_stdio_server_default() {
        let transport = StdioServerTransport::default();
        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
        assert!(transport.request_handler.is_none());
    }

    #[test]
    fn test_stdio_server_info_reports_running_state() {
        let mut transport = StdioServerTransport::new();
        assert_eq!(
            transport.server_info(),
            "STDIO server transport (running: false)"
        );
        transport.running = true;
        assert_eq!(
            transport.server_info(),
            "STDIO server transport (running: true)"
        );
    }

    #[tokio::test]
    async fn test_stdio_server_stop_clears_running() {
        let mut transport = StdioServerTransport::new();
        transport.running = true;
        assert!(transport.stop().await.is_ok());
        assert!(!transport.is_running());
    }

    #[test]
    fn test_stdio_server_set_request_handler() {
        let mut transport = StdioServerTransport::new();
        transport.set_request_handler(|_request| {
            let (_sender, receiver) = tokio::sync::oneshot::channel();
            receiver
        });
        assert!(transport.request_handler.is_some());
    }

    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut transport = StdioServerTransport::new();

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };

        let result = transport.handle_request(request).await;
        assert!(result.is_err());

        match result.unwrap_err() {
            McpError::Protocol(msg) => assert!(msg.contains("unknown_method")),
            _ => panic!("Expected Protocol error"),
        }
    }

    // Note: Integration tests with actual processes would go in tests/integration/
}