pmcp/shared/
stdio.rs

1//! Standard I/O transport implementation.
2//!
//! This transport uses stdin/stdout for communication. Each message is framed
//! with a `Content-Length` header (as in the Language Server Protocol) so that
//! message boundaries are preserved.
5
6use crate::error::{Result, TransportError};
7use crate::shared::transport::{Transport, TransportMessage};
8use async_trait::async_trait;
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
11#[cfg(not(target_arch = "wasm32"))]
12use tokio::sync::Mutex;
13
/// Prefix of the `Content-Length` header line that precedes every framed message.
///
/// Includes the trailing `": "` separator so it can be written directly in front
/// of the decimal byte length of the JSON payload.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
16
17/// stdio transport for MCP communication.
18///
19/// Uses length-prefixed framing compatible with the TypeScript SDK.
20///
21/// # Examples
22///
23/// ```rust,no_run
24/// use pmcp::shared::StdioTransport;
25///
26/// # async fn example() -> pmcp::Result<()> {
27/// let transport = StdioTransport::new();
28/// // Use with Client or Server
29/// # Ok(())
30/// # }
31/// ```
32#[derive(Debug)]
33pub struct StdioTransport {
34    stdin: Mutex<BufReader<tokio::io::Stdin>>,
35    stdout: Mutex<tokio::io::Stdout>,
36    closed: std::sync::atomic::AtomicBool,
37}
38
39impl StdioTransport {
40    /// Create a new stdio transport.
41    ///
42    /// # Examples
43    ///
44    /// ```rust
45    /// use pmcp::shared::StdioTransport;
46    ///
47    /// let transport = StdioTransport::new();
48    /// // Transport is ready to use
49    /// ```
50    pub fn new() -> Self {
51        Self {
52            stdin: Mutex::new(BufReader::new(tokio::io::stdin())),
53            stdout: Mutex::new(tokio::io::stdout()),
54            closed: std::sync::atomic::AtomicBool::new(false),
55        }
56    }
57
58    /// Parse a content-length header.
59    ///
60    /// Parses lines like "Content-Length: 42" to extract the length.
61    fn parse_content_length(line: &str) -> Option<usize> {
62        line.strip_prefix(CONTENT_LENGTH_HEADER)
63            .and_then(|content| content.trim().parse().ok())
64    }
65}
66
67impl Default for StdioTransport {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73#[async_trait]
74impl Transport for StdioTransport {
75    async fn send(&mut self, message: TransportMessage) -> Result<()> {
76        if self.closed.load(std::sync::atomic::Ordering::Acquire) {
77            return Err(TransportError::ConnectionClosed.into());
78        }
79
80        let json_bytes = Self::serialize_message(&message)?;
81        self.write_message(&json_bytes).await
82    }
83
84    async fn receive(&mut self) -> Result<TransportMessage> {
85        if self.closed.load(std::sync::atomic::Ordering::Acquire) {
86            return Err(TransportError::ConnectionClosed.into());
87        }
88
89        let content_length = self.read_headers().await?;
90        let buffer = self.read_message_body(content_length).await?;
91        Self::parse_message(&buffer)
92    }
93
94    async fn close(&mut self) -> Result<()> {
95        self.closed
96            .store(true, std::sync::atomic::Ordering::Release);
97
98        // Flush any pending output
99        let mut stdout = self.stdout.lock().await;
100        stdout.flush().await.map_err(TransportError::from)?;
101        drop(stdout);
102
103        Ok(())
104    }
105
106    fn is_connected(&self) -> bool {
107        !self.closed.load(std::sync::atomic::Ordering::Acquire)
108    }
109
110    fn transport_type(&self) -> &'static str {
111        "stdio"
112    }
113}
114
115impl StdioTransport {
116    /// Serialize transport message to JSON bytes.
117    pub fn serialize_message(message: &TransportMessage) -> Result<Vec<u8>> {
118        match message {
119            TransportMessage::Request { id, request } => {
120                let jsonrpc_request = crate::shared::create_request(id.clone(), request.clone());
121                serde_json::to_vec(&jsonrpc_request).map_err(|e| {
122                    TransportError::InvalidMessage(format!("Failed to serialize request: {}", e))
123                        .into()
124                })
125            },
126            TransportMessage::Response(response) => serde_json::to_vec(response).map_err(|e| {
127                TransportError::InvalidMessage(format!("Failed to serialize response: {}", e))
128                    .into()
129            }),
130            TransportMessage::Notification(notification) => {
131                let jsonrpc_notification = crate::shared::create_notification(notification.clone());
132                serde_json::to_vec(&jsonrpc_notification).map_err(|e| {
133                    TransportError::InvalidMessage(format!(
134                        "Failed to serialize notification: {}",
135                        e
136                    ))
137                    .into()
138                })
139            },
140        }
141    }
142
143    /// Write framed message to stdout.
144    async fn write_message(&self, json_bytes: &[u8]) -> Result<()> {
145        let mut stdout = self.stdout.lock().await;
146
147        // Write content-length header
148        let header = format!("{}{}\r\n\r\n", CONTENT_LENGTH_HEADER, json_bytes.len());
149        stdout
150            .write_all(header.as_bytes())
151            .await
152            .map_err(TransportError::from)?;
153
154        // Write message payload
155        stdout
156            .write_all(json_bytes)
157            .await
158            .map_err(TransportError::from)?;
159
160        // Always flush stdio
161        stdout.flush().await.map_err(TransportError::from)?;
162        drop(stdout);
163
164        Ok(())
165    }
166
167    /// Read headers and extract content length.
168    async fn read_headers(&self) -> Result<usize> {
169        let mut stdin = self.stdin.lock().await;
170        let mut line = String::new();
171        let mut content_length = None;
172
173        // Read headers until we find content-length
174        loop {
175            line.clear();
176            let bytes_read = stdin
177                .read_line(&mut line)
178                .await
179                .map_err(TransportError::from)?;
180
181            if bytes_read == 0 {
182                // EOF reached
183                drop(stdin);
184                self.closed
185                    .store(true, std::sync::atomic::Ordering::Release);
186                return Err(TransportError::ConnectionClosed.into());
187            }
188
189            let line = line.trim();
190
191            if line.is_empty() {
192                // End of headers
193                break;
194            }
195
196            if let Some(length) = Self::parse_content_length(line) {
197                content_length = Some(length);
198            }
199        }
200        drop(stdin);
201
202        content_length.ok_or_else(|| {
203            TransportError::InvalidMessage("Missing Content-Length header".to_string()).into()
204        })
205    }
206
207    /// Read message body with specified content length.
208    async fn read_message_body(&self, content_length: usize) -> Result<Vec<u8>> {
209        let mut stdin = self.stdin.lock().await;
210        let mut buffer = vec![0u8; content_length];
211        stdin
212            .read_exact(&mut buffer)
213            .await
214            .map_err(TransportError::from)?;
215        drop(stdin);
216        Ok(buffer)
217    }
218
219    /// Parse JSON message and determine its type.
220    pub fn parse_message(buffer: &[u8]) -> Result<TransportMessage> {
221        let json_value: serde_json::Value = serde_json::from_slice(buffer)
222            .map_err(|e| TransportError::InvalidMessage(format!("Invalid JSON: {}", e)))?;
223
224        if json_value.get("method").is_some() {
225            Self::parse_method_message(json_value)
226        } else if json_value.get("result").is_some() || json_value.get("error").is_some() {
227            Self::parse_response_message(json_value)
228        } else {
229            Err(TransportError::InvalidMessage("Unknown message type".to_string()).into())
230        }
231    }
232
233    /// Parse message with method field (request or notification).
234    fn parse_method_message(json_value: serde_json::Value) -> Result<TransportMessage> {
235        if json_value.get("id").is_some() {
236            // It's a request
237            let request: crate::types::JSONRPCRequest<serde_json::Value> =
238                serde_json::from_value(json_value).map_err(|e| {
239                    TransportError::InvalidMessage(format!("Invalid request: {}", e))
240                })?;
241
242            let parsed_request = crate::shared::parse_request(request)
243                .map_err(|e| TransportError::InvalidMessage(format!("Invalid request: {}", e)))?;
244
245            Ok(TransportMessage::Request {
246                id: parsed_request.0,
247                request: parsed_request.1,
248            })
249        } else {
250            // It's a notification
251            let parsed_notification =
252                crate::shared::parse_notification(json_value).map_err(|e| {
253                    TransportError::InvalidMessage(format!("Invalid notification: {}", e))
254                })?;
255
256            Ok(TransportMessage::Notification(parsed_notification))
257        }
258    }
259
260    /// Parse response message.
261    fn parse_response_message(json_value: serde_json::Value) -> Result<TransportMessage> {
262        let response: crate::types::JSONRPCResponse = serde_json::from_value(json_value)
263            .map_err(|e| TransportError::InvalidMessage(format!("Invalid response: {}", e)))?;
264
265        Ok(TransportMessage::Response(response))
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed `Content-Length` lines yield the declared length.
    #[test]
    fn parse_content_length_valid() {
        let accepted: [(&str, usize); 4] = [
            ("Content-Length: 42", 42),
            ("Content-Length: 0", 0),
            ("Content-Length: 999999", 999_999),
            // Extra whitespace around the value is tolerated.
            ("Content-Length:  42  ", 42),
        ];
        for &(line, expected) in &accepted {
            assert_eq!(
                StdioTransport::parse_content_length(line),
                Some(expected),
                "input: {:?}",
                line
            );
        }
    }

    /// Other headers and malformed values are rejected.
    #[test]
    fn parse_content_length_invalid() {
        let rejected = [
            "Content-Type: application/json", // a different header
            "Content-Length: abc",            // non-numeric value
            "",                               // empty line
            "Content-Length: -42",            // negative value
            "Content-Length",                 // no separator or value
        ];
        for &line in &rejected {
            assert_eq!(
                StdioTransport::parse_content_length(line),
                None,
                "input: {:?}",
                line
            );
        }
    }

    #[tokio::test]
    async fn transport_properties() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), "stdio");
        assert!(transport.is_connected());
    }

    #[tokio::test]
    async fn test_close() {
        let mut transport = StdioTransport::new();
        assert!(transport.is_connected(), "a fresh transport should be connected");

        transport.close().await.expect("closing stdio should succeed");
        assert!(
            !transport.is_connected(),
            "transport should report closed after close()"
        );
    }
}