// matrixcode_core/matrixrpc/transport/stdio.rs
use std::io;
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
34use tokio::process::{Child, Command};
35use tokio::sync::Mutex;
36
37use super::{FrameCodec, Transport, TransportConfig};
38use crate::matrixrpc::protocol::JsonRpcMessage;
39
/// JSON-RPC transport that exchanges messages framed by [`FrameCodec`] over a
/// pair of async byte streams.
///
/// The streams are this process's stdin/stdout, a spawned child's pipes, or any
/// caller-supplied reader/writer pair (see the constructors on this type).
pub struct StdioTransport {
    /// Buffered read half; `None` when no streams were supplied or after `close`.
    reader: Option<BufReader<Box<dyn AsyncRead + Send + Unpin>>>,
    /// Write half; `None` when no streams were supplied or after `close`.
    writer: Option<Box<dyn AsyncWrite + Send + Unpin>>,
    /// Encodes outgoing messages into frames and decodes frames from `read_buffer`.
    codec: FrameCodec,
    /// Size limit and read/write timeouts (a timeout of 0 means "no timeout").
    config: TransportConfig,
    /// Set once `close` has run; `send`/`receive` then fail with `BrokenPipe`.
    closed: bool,
    /// Child process whose pipes back the streams, when built via `spawn_child*`.
    child: Option<Child>,
    /// Bytes read from `reader` that have not yet been decoded into a message.
    read_buffer: Vec<u8>,
}
61
62impl StdioTransport {
63 pub fn new() -> Self {
68 Self::with_config(TransportConfig::default())
69 }
70
71 pub fn with_config(config: TransportConfig) -> Self {
73 Self {
74 reader: None,
75 writer: None,
76 codec: FrameCodec::with_max_size(config.max_message_size),
77 config,
78 closed: false,
79 child: None,
80 read_buffer: Vec::new(),
81 }
82 }
83
84 pub fn from_streams<R, W>(reader: R, writer: W, config: TransportConfig) -> Self
88 where
89 R: AsyncRead + Send + Unpin + 'static,
90 W: AsyncWrite + Send + Unpin + 'static,
91 {
92 Self {
93 reader: Some(BufReader::new(Box::new(reader))),
94 writer: Some(Box::new(writer)),
95 codec: FrameCodec::with_max_size(config.max_message_size),
96 config,
97 closed: false,
98 child: None,
99 read_buffer: Vec::new(),
100 }
101 }
102
103 pub async fn spawn_child(command: &mut Command) -> io::Result<Self> {
107 Self::spawn_child_with_config(command, TransportConfig::default()).await
108 }
109
110 pub async fn spawn_child_with_config(
112 command: &mut Command,
113 config: TransportConfig,
114 ) -> io::Result<Self> {
115 let mut child = command
116 .stdin(std::process::Stdio::piped())
117 .stdout(std::process::Stdio::piped())
118 .stderr(std::process::Stdio::null())
119 .kill_on_drop(true)
120 .spawn()?;
121
122 let stdin = child
123 .stdin
124 .take()
125 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Failed to open stdin"))?;
126
127 let stdout = child
128 .stdout
129 .take()
130 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Failed to open stdout"))?;
131
132 Ok(Self {
133 reader: Some(BufReader::new(Box::new(stdout))),
134 writer: Some(Box::new(stdin)),
135 codec: FrameCodec::with_max_size(config.max_message_size),
136 config,
137 closed: false,
138 child: Some(child),
139 read_buffer: Vec::new(),
140 })
141 }
142
143 pub fn with_tokio_stdio() -> Self {
147 let stdin = tokio::io::stdin();
148 let stdout = tokio::io::stdout();
149 Self::from_streams(stdin, stdout, TransportConfig::default())
150 }
151
152 pub fn child(&mut self) -> Option<&mut Child> {
154 self.child.as_mut()
155 }
156
157 pub fn is_child_running(&mut self) -> bool {
159 if let Some(child) = &mut self.child {
160 child.try_wait().ok().flatten().is_none()
161 } else {
162 false
163 }
164 }
165
166 pub async fn wait_child(&mut self) -> io::Result<Option<std::process::ExitStatus>> {
168 if let Some(child) = &mut self.child {
169 child.wait().await.map(Some)
170 } else {
171 Ok(None)
172 }
173 }
174}
175
176impl Default for StdioTransport {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182#[async_trait]
183impl Transport for StdioTransport {
184 async fn send(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
185 if self.closed {
186 return Err(io::Error::new(
187 io::ErrorKind::BrokenPipe,
188 "Transport is closed",
189 ));
190 }
191
192 let writer = self
193 .writer
194 .as_mut()
195 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "No writer available"))?;
196
197 let frame = self.codec.encode(message)?;
199
200 if self.config.write_timeout_ms > 0 {
202 let timeout_duration = std::time::Duration::from_millis(self.config.write_timeout_ms);
203 tokio::time::timeout(timeout_duration, async {
204 writer.write_all(&frame).await?;
205 writer.flush().await
206 })
207 .await
208 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Write timeout"))?
209 .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
210 } else {
211 writer.write_all(&frame).await?;
212 writer.flush().await?;
213 }
214
215 Ok(())
216 }
217
218 async fn receive(&mut self) -> io::Result<Option<JsonRpcMessage>> {
219 if self.closed {
220 return Err(io::Error::new(
221 io::ErrorKind::BrokenPipe,
222 "Transport is closed",
223 ));
224 }
225
226 let reader = self
227 .reader
228 .as_mut()
229 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "No reader available"))?;
230
231 if !self.read_buffer.is_empty() {
233 if let (_, Some(message)) = self.codec.decode_from_buffer(&self.read_buffer)? {
234 self.read_buffer.clear();
236 return Ok(Some(message));
237 }
238 }
239
240 let mut temp_buf = vec![0u8; 8192];
242 let bytes_read: usize = if self.config.read_timeout_ms > 0 {
243 let timeout_duration = std::time::Duration::from_millis(self.config.read_timeout_ms);
244 tokio::time::timeout(timeout_duration, reader.read(&mut temp_buf))
245 .await
246 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Read timeout"))??
247 } else {
248 reader.read(&mut temp_buf).await?
249 };
250
251 if bytes_read == 0 {
252 return Ok(None);
254 }
255
256 self.read_buffer.extend_from_slice(&temp_buf[..bytes_read]);
258
259 if self.read_buffer.len() > self.config.max_message_size {
261 return Err(io::Error::new(
262 io::ErrorKind::InvalidData,
263 format!(
264 "Buffer size {} exceeds maximum {}",
265 self.read_buffer.len(),
266 self.config.max_message_size
267 ),
268 ));
269 }
270
271 let (remaining, message) = self.codec.decode_from_buffer(&self.read_buffer)?;
273
274 self.read_buffer = remaining.to_vec();
276
277 Ok(message)
278 }
279
280 async fn close(&mut self) -> io::Result<()> {
281 if self.closed {
282 return Ok(());
283 }
284
285 if let Some(mut writer) = self.writer.take() {
287 let _ = writer.shutdown().await;
288 }
289
290 self.reader.take();
292
293 if let Some(mut child) = self.child.take() {
295 let _ = child.kill().await;
296 }
297
298 self.closed = true;
299 Ok(())
300 }
301
302 fn is_closed(&self) -> bool {
303 self.closed
304 }
305}
306
307#[allow(dead_code)]
310pub type SharedStdioTransport = Arc<Mutex<StdioTransport>>;
312#[allow(dead_code)]
313#[allow(dead_code)]
314
315#[allow(dead_code)]
317#[allow(dead_code)]
318pub fn shared_stdio_transport() -> SharedStdioTransport {
319 Arc::new(Mutex::new(StdioTransport::with_tokio_stdio()))
320}
321
322#[allow(dead_code)]
323pub fn shared_stdio_transport_with_config(config: TransportConfig) -> SharedStdioTransport {
325 Arc::new(Mutex::new(StdioTransport::from_streams(
326 tokio::io::stdin(),
327 tokio::io::stdout(),
328 config,
329 )))
330}
331
#[cfg(test)]
mod tests {
    //! Unit tests for `StdioTransport`, using in-memory duplex streams in place of real stdio.

    use super::*;
    use crate::matrixrpc::protocol::{JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse};
    use serde_json::json;
    use tokio::io::{self, AsyncReadExt};

    #[tokio::test]
    async fn test_send_and_receive() {
        let (server_read, client_write) = io::duplex(1024);
        // `_server_write` is kept alive so `client_read` does not see EOF.
        let (client_read, _server_write) = io::duplex(1024);

        let mut transport = StdioTransport::from_streams(
            client_read,
            client_write,
            TransportConfig::default(),
        );

        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test_method"));
        transport.send(&request).await.unwrap();

        let read_task = tokio::spawn(async move {
            let mut reader = server_read;
            let mut buf = vec![0u8; 1024];
            let n = reader.read(&mut buf).await.unwrap();
            let frame = String::from_utf8_lossy(&buf[..n]);
            assert!(frame.contains("Content-Length:"));
            assert!(frame.contains("\"method\":\"test_method\""));
        });

        read_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_close_transport() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        assert!(!transport.is_closed());

        transport.close().await.unwrap();
        assert!(transport.is_closed());

        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test"));
        let result = transport.send(&request).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_transport_config() {
        let config = TransportConfig::new()
            .max_message_size(1024 * 1024)
            .read_timeout(5000)
            .write_timeout(10000);

        let transport = StdioTransport::with_config(config.clone());
        assert_eq!(transport.config.max_message_size, 1024 * 1024);
        assert_eq!(transport.config.read_timeout_ms, 5000);
        assert_eq!(transport.config.write_timeout_ms, 10000);
    }

    #[tokio::test]
    async fn test_encode_decode_roundtrip() {
        let (read, write) = io::duplex(4096);

        let mut transport1 = StdioTransport::from_streams(
            tokio::io::empty(),
            write,
            TransportConfig::default(),
        );

        let request = JsonRpcMessage::Request(
            JsonRpcRequest::with_id("test_method", 42).params(json!({"arg": "value"})),
        );
        transport1.send(&request).await.unwrap();

        let mut transport2 = StdioTransport::from_streams(
            read,
            tokio::io::sink(),
            TransportConfig::default(),
        );

        let received = transport2.receive().await.unwrap();
        assert!(received.is_some());
        let msg = received.unwrap();
        assert!(msg.is_request());
        assert_eq!(msg.as_request().unwrap().method, "test_method");
    }

    #[tokio::test]
    async fn test_receive_on_closed_transport() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        transport.close().await.unwrap();

        let result = transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_send_without_writer() {
        let mut transport = StdioTransport::with_config(TransportConfig::default());
        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test"));
        let result = transport.send(&request).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_receive_without_reader() {
        let mut transport = StdioTransport::with_config(TransportConfig::default());
        let result = transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_double_close() {
        let (client_read, client_write) = io::duplex(1024);
        let mut transport =
            StdioTransport::from_streams(client_read, client_write, TransportConfig::default());

        transport.close().await.unwrap();
        assert!(transport.is_closed());
        transport.close().await.unwrap();
        assert!(transport.is_closed());
    }

    #[tokio::test]
    async fn test_send_and_receive_response() {
        let (_server_read, client_write) = io::duplex(4096);
        let (client_read, server_write) = io::duplex(4096);

        let mut client_transport = StdioTransport::from_streams(
            client_read,
            client_write,
            TransportConfig::default(),
        );

        let response =
            JsonRpcMessage::Response(JsonRpcResponse::success(1, json!({"status": "ok"})));
        let frame = FrameCodec::new().encode(&response).unwrap();

        let write_task = tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let mut writer = server_write;
            writer.write_all(&frame).await.unwrap();
            writer.flush().await.unwrap();
            // Returned so the stream stays open until the task is joined.
            writer
        });

        let received = client_transport.receive().await.unwrap();
        assert!(received.is_some());
        let msg = received.unwrap();
        assert!(msg.is_response());
        assert!(msg.as_response().unwrap().is_success());

        write_task.await.unwrap();
    }

    // Pure codec test: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_multiple_messages_codec_roundtrip() {
        let codec = FrameCodec::new();
        let mut buffer = Vec::new();

        for i in 0..5 {
            let request = JsonRpcMessage::Request(
                JsonRpcRequest::with_id("test", i).params(json!({"index": i})),
            );
            let frame = codec.encode(&request).unwrap();
            buffer.extend_from_slice(&frame);
        }

        for i in 0..5 {
            let (remaining, message) = codec.decode_from_buffer(&buffer).unwrap();
            assert!(message.is_some());
            let msg = message.unwrap();
            assert!(msg.is_request());
            assert_eq!(msg.as_request().unwrap().id, Some(JsonRpcId::Number(i)));
            buffer = remaining.to_vec();
        }

        assert!(buffer.is_empty());
    }

    #[tokio::test]
    async fn test_receive_eof() {
        // Both duplex ends go into one transport; dropping the writer end makes the reader hit EOF.
        let (read, write) = io::duplex(1024);

        let mut transport =
            StdioTransport::from_streams(read, write, TransportConfig::default());

        drop(transport.writer.take());

        let result = transport.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_default_transport() {
        let transport = StdioTransport::default();
        assert!(!transport.is_closed());
        assert!(transport.child.is_none());
    }

    #[test]
    fn test_child_process_methods() {
        let mut transport = StdioTransport::new();
        assert!(transport.child().is_none());
        assert!(!transport.is_child_running());
    }

    #[tokio::test]
    async fn test_max_message_size_exceeded_on_receive() {
        let large_params = "x".repeat(100);
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::new("test").params(json!({"data": large_params})),
        );
        let frame = FrameCodec::new().encode(&request).unwrap();

        let (read, write) = io::duplex(8192);

        let frame_clone = frame.clone();
        let write_task = tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let mut writer = write;
            writer.write_all(&frame_clone).await.unwrap();
            writer.flush().await.unwrap();
            writer
        });

        let mut small_transport = StdioTransport::from_streams(
            read,
            tokio::io::sink(),
            TransportConfig::new().max_message_size(10),
        );

        let result = small_transport.receive().await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        write_task.await.unwrap();
    }

    #[test]
    fn test_send_notification_codec() {
        let notification = JsonRpcMessage::Request(
            JsonRpcRequest::notification("log").params(json!({"message": "hello"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&notification).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("Content-Length:"));
        assert!(frame_str.contains("\"method\":\"log\""));

        // The JSON body starts after the blank line that ends the header section.
        let body_start = frame_str.find("\r\n\r\n").unwrap() + 4;
        let body = &frame_str[body_start..];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert!(parsed.get("id").is_none());
    }

    #[test]
    fn test_send_error_response_codec() {
        let error_response = JsonRpcMessage::Response(JsonRpcResponse::error(
            1,
            JsonRpcError::method_not_found("unknown"),
        ));

        let codec = FrameCodec::new();
        let frame = codec.encode(&error_response).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"error\""));
        assert!(frame_str.contains("Method 'unknown' not found"));
    }

    #[test]
    fn test_codec_roundtrip_with_string_id() {
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::with_id("test_method", "uuid-12345").params(json!({"arg": "value"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&request).unwrap();
        let (remaining, decoded) = codec.decode_from_buffer(&frame).unwrap();

        assert!(remaining.is_empty());
        assert!(decoded.is_some());
        let msg = decoded.unwrap();
        assert_eq!(
            msg.as_request().unwrap().id,
            Some(JsonRpcId::String("uuid-12345".to_string()))
        );
    }
}