// matrixcode_core/matrixrpc/transport/stdio.rs

//! Stdio Transport for MatrixRPC
//!
//! Provides transport over stdin/stdout using Content-Length framed messages.
//! Used for parent-child process communication (e.g., LSP/MCP style).
//!
//! # Example
//!
//! ```no_run
//! use matrixcode_core::matrixrpc::transport::{StdioTransport, Transport};
//! use matrixcode_core::matrixrpc::protocol::JsonRpcRequest;
//!
//! #[tokio::main]
//! async fn main() -> std::io::Result<()> {
//!     // Attach the transport to this process's stdin/stdout.
//!     // (`StdioTransport::new()` has no streams attached and cannot send.)
//!     let mut transport = StdioTransport::with_tokio_stdio();
//!
//!     // Send a request
//!     let request = JsonRpcRequest::new("initialize").into();
//!     transport.send(&request).await?;
//!
//!     // Receive a response
//!     if let Some(response) = transport.receive().await? {
//!         println!("Received: {:?}", response);
//!     }
//!
//!     Ok(())
//! }
//! ```
28
29use std::io;
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
34use tokio::process::{Child, Command};
35use tokio::sync::Mutex;
36
37use super::{FrameCodec, Transport, TransportConfig};
38use crate::matrixrpc::protocol::JsonRpcMessage;
39
/// Stdio transport for JSON-RPC communication
///
/// Messages are framed with `Content-Length` headers by a [`FrameCodec`].
/// A transport can be created in several ways:
/// 1. Native mode: [`StdioTransport::with_tokio_stdio`] reads this process's
///    stdin and writes its stdout.
/// 2. Child mode: [`StdioTransport::spawn_child`] communicates with a spawned
///    child process over its stdin/stdout pipes.
/// 3. [`StdioTransport::from_streams`] wraps arbitrary async reader/writer
///    streams (e.g. in-memory duplex pipes).
///
/// [`StdioTransport::new`] and [`StdioTransport::with_config`] create a
/// transport with no reader or writer attached.
pub struct StdioTransport {
    /// Reader for incoming messages; `None` if no input stream is attached or after `close`
    reader: Option<BufReader<Box<dyn AsyncRead + Send + Unpin>>>,
    /// Writer for outgoing messages; `None` if no output stream is attached or after `close`
    writer: Option<Box<dyn AsyncWrite + Send + Unpin>>,
    /// Frame codec for encoding/decoding, limited to `config.max_message_size`
    codec: FrameCodec,
    /// Configuration (message size limit, read/write timeouts)
    config: TransportConfig,
    /// Set by `close`; once set, `send` and `receive` fail with `BrokenPipe`
    closed: bool,
    /// Child process owned by this transport (only when created via `spawn_child*`); killed by `close`
    child: Option<Child>,
    /// Bytes read from `reader` that have not yet been decoded into a message
    read_buffer: Vec<u8>,
}
61
62impl StdioTransport {
63    /// Create a new stdio transport using stdin/stdout
64    ///
65    /// This is used for server-side communication where the process
66    /// reads from stdin and writes to stdout.
67    pub fn new() -> Self {
68        Self::with_config(TransportConfig::default())
69    }
70
71    /// Create a new stdio transport with custom configuration
72    pub fn with_config(config: TransportConfig) -> Self {
73        Self {
74            reader: None,
75            writer: None,
76            codec: FrameCodec::with_max_size(config.max_message_size),
77            config,
78            closed: false,
79            child: None,
80            read_buffer: Vec::new(),
81        }
82    }
83
84    /// Create a stdio transport from async read/write streams
85    ///
86    /// Wraps the reader in a BufReader for efficient reading.
87    pub fn from_streams<R, W>(reader: R, writer: W, config: TransportConfig) -> Self
88    where
89        R: AsyncRead + Send + Unpin + 'static,
90        W: AsyncWrite + Send + Unpin + 'static,
91    {
92        Self {
93            reader: Some(BufReader::new(Box::new(reader))),
94            writer: Some(Box::new(writer)),
95            codec: FrameCodec::with_max_size(config.max_message_size),
96            config,
97            closed: false,
98            child: None,
99            read_buffer: Vec::new(),
100        }
101    }
102
103    /// Spawn a child process and create a transport to communicate with it
104    ///
105    /// The child process should implement the JSON-RPC protocol over stdio.
106    pub async fn spawn_child(command: &mut Command) -> io::Result<Self> {
107        Self::spawn_child_with_config(command, TransportConfig::default()).await
108    }
109
110    /// Spawn a child process with custom configuration
111    pub async fn spawn_child_with_config(
112        command: &mut Command,
113        config: TransportConfig,
114    ) -> io::Result<Self> {
115        let mut child = command
116            .stdin(std::process::Stdio::piped())
117            .stdout(std::process::Stdio::piped())
118            .stderr(std::process::Stdio::null())
119            .kill_on_drop(true)
120            .spawn()?;
121
122        let stdin = child
123            .stdin
124            .take()
125            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Failed to open stdin"))?;
126
127        let stdout = child
128            .stdout
129            .take()
130            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Failed to open stdout"))?;
131
132        Ok(Self {
133            reader: Some(BufReader::new(Box::new(stdout))),
134            writer: Some(Box::new(stdin)),
135            codec: FrameCodec::with_max_size(config.max_message_size),
136            config,
137            closed: false,
138            child: Some(child),
139            read_buffer: Vec::new(),
140        })
141    }
142
143    /// Initialize with tokio stdin/stdout
144    ///
145    /// Must be called within a tokio runtime context.
146    pub fn with_tokio_stdio() -> Self {
147        let stdin = tokio::io::stdin();
148        let stdout = tokio::io::stdout();
149        Self::from_streams(stdin, stdout, TransportConfig::default())
150    }
151
152    /// Get the child process if this transport was created via `spawn_child`
153    pub fn child(&mut self) -> Option<&mut Child> {
154        self.child.as_mut()
155    }
156
157    /// Check if there's a child process and if it's still running
158    pub fn is_child_running(&mut self) -> bool {
159        if let Some(child) = &mut self.child {
160            child.try_wait().ok().flatten().is_none()
161        } else {
162            false
163        }
164    }
165
166    /// Wait for the child process to exit and return its status
167    pub async fn wait_child(&mut self) -> io::Result<Option<std::process::ExitStatus>> {
168        if let Some(child) = &mut self.child {
169            child.wait().await.map(Some)
170        } else {
171            Ok(None)
172        }
173    }
174}
175
176impl Default for StdioTransport {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182#[async_trait]
183impl Transport for StdioTransport {
184    async fn send(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
185        if self.closed {
186            return Err(io::Error::new(
187                io::ErrorKind::BrokenPipe,
188                "Transport is closed",
189            ));
190        }
191
192        let writer = self
193            .writer
194            .as_mut()
195            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "No writer available"))?;
196
197        // Encode the message with Content-Length frame
198        let frame = self.codec.encode(message)?;
199
200        // Write with optional timeout
201        if self.config.write_timeout_ms > 0 {
202            let timeout_duration = std::time::Duration::from_millis(self.config.write_timeout_ms);
203            tokio::time::timeout(timeout_duration, async {
204                writer.write_all(&frame).await?;
205                writer.flush().await
206            })
207            .await
208            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Write timeout"))?
209            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
210        } else {
211            writer.write_all(&frame).await?;
212            writer.flush().await?;
213        }
214
215        Ok(())
216    }
217
218    async fn receive(&mut self) -> io::Result<Option<JsonRpcMessage>> {
219        if self.closed {
220            return Err(io::Error::new(
221                io::ErrorKind::BrokenPipe,
222                "Transport is closed",
223            ));
224        }
225
226        let reader = self
227            .reader
228            .as_mut()
229            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "No reader available"))?;
230
231        // Try to parse a complete message from existing buffer first
232        if !self.read_buffer.is_empty() {
233            if let (_, Some(message)) = self.codec.decode_from_buffer(&self.read_buffer)? {
234                // Remove consumed data
235                self.read_buffer.clear();
236                return Ok(Some(message));
237            }
238        }
239
240        // Read more data
241        let mut temp_buf = vec![0u8; 8192];
242        let bytes_read: usize = if self.config.read_timeout_ms > 0 {
243            let timeout_duration = std::time::Duration::from_millis(self.config.read_timeout_ms);
244            tokio::time::timeout(timeout_duration, reader.read(&mut temp_buf))
245                .await
246                .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Read timeout"))??
247        } else {
248            reader.read(&mut temp_buf).await?
249        };
250
251        if bytes_read == 0 {
252            // EOF
253            return Ok(None);
254        }
255
256        // Append to read buffer
257        self.read_buffer.extend_from_slice(&temp_buf[..bytes_read]);
258
259        // Check max message size
260        if self.read_buffer.len() > self.config.max_message_size {
261            return Err(io::Error::new(
262                io::ErrorKind::InvalidData,
263                format!(
264                    "Buffer size {} exceeds maximum {}",
265                    self.read_buffer.len(),
266                    self.config.max_message_size
267                ),
268            ));
269        }
270
271        // Try to parse
272        let (remaining, message) = self.codec.decode_from_buffer(&self.read_buffer)?;
273
274        // Update buffer with remaining data
275        self.read_buffer = remaining.to_vec();
276
277        Ok(message)
278    }
279
280    async fn close(&mut self) -> io::Result<()> {
281        if self.closed {
282            return Ok(());
283        }
284
285        // Flush and close writer
286        if let Some(mut writer) = self.writer.take() {
287            let _ = writer.shutdown().await;
288        }
289
290        // Close reader
291        self.reader.take();
292
293        // Kill child process if we own one
294        if let Some(mut child) = self.child.take() {
295            let _ = child.kill().await;
296        }
297
298        self.closed = true;
299        Ok(())
300    }
301
302    fn is_closed(&self) -> bool {
303        self.closed
304    }
305}
306
307/// A thread-safe wrapper around StdioTransport for concurrent access
308///
309#[allow(dead_code)]
310/// Uses Arc<Mutex<>> to allow sharing between multiple tasks.
311pub type SharedStdioTransport = Arc<Mutex<StdioTransport>>;
312#[allow(dead_code)]
313#[allow(dead_code)]
314
315/// Create a shared stdio transport
316#[allow(dead_code)]
317#[allow(dead_code)]
318pub fn shared_stdio_transport() -> SharedStdioTransport {
319    Arc::new(Mutex::new(StdioTransport::with_tokio_stdio()))
320}
321
322#[allow(dead_code)]
323/// Create a shared stdio transport with custom config
324pub fn shared_stdio_transport_with_config(config: TransportConfig) -> SharedStdioTransport {
325    Arc::new(Mutex::new(StdioTransport::from_streams(
326        tokio::io::stdin(),
327        tokio::io::stdout(),
328        config,
329    )))
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::protocol::{JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse};
    use serde_json::json;
    use tokio::io::{self, AsyncReadExt};

    #[tokio::test]
    async fn test_send_and_receive() {
        // Only the send path is exercised here: whatever the transport writes
        // into `client_write` is read back from the other end, `server_read`.
        let (server_read, client_write) = io::duplex(1024);
        // Separate duplex for the transport's input side (never written to here)
        let (client_read, _server_write) = io::duplex(1024);

        // Create transport from client side - pass ownership
        let mut transport = StdioTransport::from_streams(
            client_read,
            client_write,
            TransportConfig::default(),
        );

        // Client sends request
        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test_method"));
        transport.send(&request).await.unwrap();

        // Read the frame on the server side in a separate task
        let read_task = tokio::spawn(async move {
            let mut reader = server_read;
            let mut buf = vec![0u8; 1024];
            let n = reader.read(&mut buf).await.unwrap();
            let frame = String::from_utf8_lossy(&buf[..n]);
            assert!(frame.contains("Content-Length:"));
            assert!(frame.contains("\"method\":\"test_method\""));
        });

        read_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_close_transport() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        assert!(!transport.is_closed());

        transport.close().await.unwrap();
        assert!(transport.is_closed());

        // Sending after close should fail
        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test"));
        let result = transport.send(&request).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_transport_config() {
        let config = TransportConfig::new()
            .max_message_size(1024 * 1024)
            .read_timeout(5000)
            .write_timeout(10000);

        let transport = StdioTransport::with_config(config.clone());
        assert_eq!(transport.config.max_message_size, 1024 * 1024);
        assert_eq!(transport.config.read_timeout_ms, 5000);
        assert_eq!(transport.config.write_timeout_ms, 10000);
    }

    #[tokio::test]
    async fn test_encode_decode_roundtrip() {
        let (read, write) = io::duplex(4096);

        // Create transport to send
        let mut transport1 = StdioTransport::from_streams(
            tokio::io::empty(), // No input
            write,
            TransportConfig::default(),
        );

        // Send message
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::with_id("test_method", 42).params(json!({"arg": "value"})),
        );
        transport1.send(&request).await.unwrap();

        // Read on the other side
        let mut transport2 = StdioTransport::from_streams(
            read,
            tokio::io::sink(), // Discard output
            TransportConfig::default(),
        );

        let received = transport2.receive().await.unwrap();
        assert!(received.is_some());
        let msg = received.unwrap();
        assert!(msg.is_request());
        assert_eq!(msg.as_request().unwrap().method, "test_method");
    }

    // ==================== Additional Edge Case Tests ====================

    #[tokio::test]
    async fn test_receive_on_closed_transport() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        transport.close().await.unwrap();

        // Receiving after close should fail
        let result = transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_send_without_writer() {
        let mut transport = StdioTransport::with_config(TransportConfig::default());
        // transport has no reader/writer set

        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test"));
        let result = transport.send(&request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_receive_without_reader() {
        let mut transport = StdioTransport::with_config(TransportConfig::default());
        // transport has no reader/writer set

        let result = transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_double_close() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        // Close twice should be idempotent
        transport.close().await.unwrap();
        assert!(transport.is_closed());
        transport.close().await.unwrap();
        assert!(transport.is_closed());
    }

    #[tokio::test]
    async fn test_send_and_receive_response() {
        // Create a pair of connected duplex streams
        let (_server_read, client_write) = io::duplex(4096);
        let (client_read, server_write) = io::duplex(4096);

        // Client transport - pass ownership of streams
        let mut client_transport = StdioTransport::from_streams(
            client_read,
            client_write,
            TransportConfig::default(),
        );

        // Server sends a response using the codec directly
        let response =
            JsonRpcMessage::Response(JsonRpcResponse::success(1, json!({"status": "ok"})));
        let frame = FrameCodec::new().encode(&response).unwrap();

        // Use a separate task to write from server side
        let write_task = tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let mut writer = server_write;
            writer.write_all(&frame).await.unwrap();
            writer.flush().await.unwrap();
            writer
        });

        // Client receives the response
        let received = client_transport.receive().await.unwrap();
        assert!(received.is_some());
        let msg = received.unwrap();
        assert!(msg.is_response());
        assert!(msg.as_response().unwrap().is_success());

        // Wait for write task to complete
        write_task.await.unwrap();
    }

    #[test]
    fn test_multiple_messages_codec_roundtrip() {
        // Test encoding/decoding multiple messages using the codec directly
        // (purely synchronous, so no async runtime is needed)
        let codec = FrameCodec::new();
        let mut buffer = Vec::new();

        // Encode multiple messages
        for i in 0..5 {
            let request = JsonRpcMessage::Request(
                JsonRpcRequest::with_id("test", i).params(json!({"index": i})),
            );
            let frame = codec.encode(&request).unwrap();
            buffer.extend_from_slice(&frame);
        }

        // Decode all messages sequentially
        for i in 0..5 {
            let (remaining, message) = codec.decode_from_buffer(&buffer).unwrap();
            assert!(message.is_some());
            let msg = message.unwrap();
            assert!(msg.is_request());
            assert_eq!(msg.as_request().unwrap().id, Some(JsonRpcId::Number(i)));
            buffer = remaining.to_vec();
        }

        assert!(buffer.is_empty());
    }

    #[tokio::test]
    async fn test_receive_eof() {
        // `read` and `write` are the two ends of one duplex pipe: the transport
        // reads from `read`, which only ever receives what `write` sends.
        let (read, write) = io::duplex(1024);

        let mut transport =
            StdioTransport::from_streams(read, write, TransportConfig::default());

        // Dropping the `write` end closes the pipe, so `read` reports EOF
        drop(transport.writer.take());

        // With nothing buffered and EOF on the reader, receive returns Ok(None)
        let result = transport.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_default_transport() {
        let transport = StdioTransport::default();
        assert!(!transport.is_closed());
        assert!(transport.child.is_none());
    }

    #[test]
    fn test_child_process_methods() {
        let mut transport = StdioTransport::new();
        assert!(transport.child().is_none());
        assert!(!transport.is_child_running());
    }

    #[tokio::test]
    async fn test_max_message_size_exceeded_on_receive() {
        // Encode a message whose frame is far larger than the 10-byte limit below
        let large_params = "x".repeat(100);
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::new("test").params(json!({"data": large_params})),
        );
        let frame = FrameCodec::new().encode(&request).unwrap();

        let (read, write) = io::duplex(8192);

        // Write the frame from a separate task (the frame is moved into it)
        let write_task = tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let mut writer = write;
            writer.write_all(&frame).await.unwrap();
            writer.flush().await.unwrap();
            writer
        });

        // Transport with a tiny maximum message size
        let mut small_transport = StdioTransport::from_streams(
            read,
            tokio::io::sink(),
            TransportConfig::new().max_message_size(10),
        );

        // Receive should fail with InvalidData
        let result = small_transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        write_task.await.unwrap();
    }

    #[test]
    fn test_send_notification_codec() {
        // Test notification encoding via codec (no async needed)
        let notification = JsonRpcMessage::Request(
            JsonRpcRequest::notification("log")
                .params(json!({"message": "hello"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&notification).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("Content-Length:"));
        assert!(frame_str.contains("\"method\":\"log\""));

        // Verify notification has no id in the JSON body
        let body_start = frame_str.find("\r\n\r\n").unwrap() + 4;
        let body = &frame_str[body_start..];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert!(parsed.get("id").is_none());
    }

    #[test]
    fn test_send_error_response_codec() {
        // Test error response encoding via codec
        let error_response = JsonRpcMessage::Response(
            JsonRpcResponse::error(1, JsonRpcError::method_not_found("unknown")),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&error_response).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"error\""));
        assert!(frame_str.contains("Method 'unknown' not found"));
    }

    #[test]
    fn test_codec_roundtrip_with_string_id() {
        // Test with string ID
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::with_id("test_method", "uuid-12345")
                .params(json!({"arg": "value"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&request).unwrap();
        let (remaining, decoded) = codec.decode_from_buffer(&frame).unwrap();

        assert!(remaining.is_empty());
        assert!(decoded.is_some());
        let msg = decoded.unwrap();
        assert_eq!(
            msg.as_request().unwrap().id,
            Some(JsonRpcId::String("uuid-12345".to_string()))
        );
    }
}