1use crate::{Error, Result};
7use serde_json::Value as JsonValue;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9use tokio::sync::mpsc;
10
11pub async fn send_message<W>(stdin: &mut W, message: JsonValue) -> Result<()>
20where
21 W: AsyncWriteExt + Unpin,
22{
23 let json_bytes = serde_json::to_vec(&message)
25 .map_err(|e| Error::TransportError(format!("Failed to serialize JSON: {}", e)))?;
26
27 let length = json_bytes.len() as u32;
28
29 stdin
31 .write_all(&length.to_le_bytes())
32 .await
33 .map_err(|e| Error::TransportError(format!("Failed to write length: {}", e)))?;
34
35 stdin
37 .write_all(&json_bytes)
38 .await
39 .map_err(|e| Error::TransportError(format!("Failed to write message: {}", e)))?;
40
41 stdin
43 .flush()
44 .await
45 .map_err(|e| Error::TransportError(format!("Failed to flush: {}", e)))?;
46
47 Ok(())
48}
49
50pub trait Transport: Send + Sync {
55 fn send(&mut self, message: JsonValue) -> impl std::future::Future<Output = Result<()>> + Send;
57}
58
59pub struct PipeTransport<W, R>
100where
101 W: AsyncWrite + Unpin + Send,
102 R: AsyncRead + Unpin + Send,
103{
104 stdin: W,
105 stdout: R,
106 message_tx: mpsc::UnboundedSender<JsonValue>,
107}
108
109pub struct PipeTransportReceiver<R>
115where
116 R: AsyncRead + Unpin + Send,
117{
118 stdout: R,
119 message_tx: mpsc::UnboundedSender<JsonValue>,
120}
121
122impl<R> PipeTransportReceiver<R>
123where
124 R: AsyncRead + Unpin + Send,
125{
126 pub async fn run(mut self) -> Result<()> {
134 const CHUNK_SIZE: usize = 32_768; loop {
137 let mut len_buf = [0u8; 4];
139 self.stdout.read_exact(&mut len_buf).await.map_err(|e| {
140 Error::TransportError(format!("Failed to read length prefix: {}", e))
141 })?;
142
143 let length = u32::from_le_bytes(len_buf) as usize;
144
145 let message_buf = if length <= CHUNK_SIZE {
148 let mut buf = vec![0u8; length];
150 self.stdout
151 .read_exact(&mut buf)
152 .await
153 .map_err(|e| Error::TransportError(format!("Failed to read message: {}", e)))?;
154 buf
155 } else {
156 let mut buf = Vec::with_capacity(length);
158 let mut remaining = length;
159
160 while remaining > 0 {
161 let to_read = std::cmp::min(remaining, CHUNK_SIZE);
162 let mut chunk = vec![0u8; to_read];
163
164 self.stdout.read_exact(&mut chunk).await.map_err(|e| {
165 Error::TransportError(format!("Failed to read message chunk: {}", e))
166 })?;
167
168 buf.extend_from_slice(&chunk);
169 remaining -= to_read;
170 }
171
172 buf
173 };
174
175 let message: JsonValue = serde_json::from_slice(&message_buf)
177 .map_err(|e| Error::ProtocolError(format!("Failed to parse JSON: {}", e)))?;
178
179 if self.message_tx.send(message).is_err() {
181 break;
183 }
184 }
185
186 Ok(())
187 }
188}
189
190impl<W, R> PipeTransport<W, R>
191where
192 W: AsyncWrite + Unpin + Send,
193 R: AsyncRead + Unpin + Send,
194{
195 pub fn new(stdin: W, stdout: R) -> (Self, mpsc::UnboundedReceiver<JsonValue>) {
235 let (message_tx, message_rx) = mpsc::unbounded_channel();
236
237 let transport = Self {
238 stdin,
239 stdout,
240 message_tx,
241 };
242
243 (transport, message_rx)
244 }
245
246 pub fn into_parts(self) -> (W, PipeTransportReceiver<R>) {
256 (
257 self.stdin,
258 PipeTransportReceiver {
259 stdout: self.stdout,
260 message_tx: self.message_tx,
261 },
262 )
263 }
264
265 pub async fn run(&mut self) -> Result<()> {
282 const CHUNK_SIZE: usize = 32_768; loop {
285 let mut len_buf = [0u8; 4];
288 self.stdout.read_exact(&mut len_buf).await.map_err(|e| {
289 Error::TransportError(format!("Failed to read length prefix: {}", e))
290 })?;
291
292 let length = u32::from_le_bytes(len_buf) as usize;
293
294 let message_buf = if length <= CHUNK_SIZE {
298 let mut buf = vec![0u8; length];
300 self.stdout
301 .read_exact(&mut buf)
302 .await
303 .map_err(|e| Error::TransportError(format!("Failed to read message: {}", e)))?;
304 buf
305 } else {
306 let mut buf = Vec::with_capacity(length);
309 let mut remaining = length;
310
311 while remaining > 0 {
312 let to_read = std::cmp::min(remaining, CHUNK_SIZE);
313 let mut chunk = vec![0u8; to_read];
314
315 self.stdout.read_exact(&mut chunk).await.map_err(|e| {
316 Error::TransportError(format!("Failed to read message chunk: {}", e))
317 })?;
318
319 buf.extend_from_slice(&chunk);
320 remaining -= to_read;
321 }
322
323 buf
324 };
325
326 let message: JsonValue = serde_json::from_slice(&message_buf)
329 .map_err(|e| Error::ProtocolError(format!("Failed to parse JSON: {}", e)))?;
330
331 if self.message_tx.send(message).is_err() {
334 break;
336 }
337 }
338
339 Ok(())
340 }
341
342 async fn send_internal(&mut self, message: JsonValue) -> Result<()> {
357 let json_bytes = serde_json::to_vec(&message)
359 .map_err(|e| Error::TransportError(format!("Failed to serialize JSON: {}", e)))?;
360
361 let length = json_bytes.len() as u32;
362
363 self.stdin
366 .write_all(&length.to_le_bytes())
367 .await
368 .map_err(|e| Error::TransportError(format!("Failed to write length: {}", e)))?;
369
370 self.stdin
372 .write_all(&json_bytes)
373 .await
374 .map_err(|e| Error::TransportError(format!("Failed to write message: {}", e)))?;
375
376 self.stdin
378 .flush()
379 .await
380 .map_err(|e| Error::TransportError(format!("Failed to flush: {}", e)))?;
381
382 Ok(())
383 }
384}
385
386impl<W, R> Transport for PipeTransport<W, R>
387where
388 W: AsyncWrite + Unpin + Send + Sync,
389 R: AsyncRead + Unpin + Send + Sync,
390{
391 async fn send(&mut self, message: JsonValue) -> Result<()> {
392 self.send_internal(message).await
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    /// Payload size above which the transport switches to chunked reads
    /// (mirrors the `CHUNK_SIZE` constant in the read loop).
    const CHUNK_SIZE: usize = 32_768;

    /// Writes `message` to `writer` as one frame (4-byte LE length + JSON) and flushes.
    async fn write_frame_to<W>(writer: &mut W, message: &serde_json::Value)
    where
        W: AsyncWrite + Unpin,
    {
        let json_bytes = serde_json::to_vec(message).unwrap();
        let length = json_bytes.len() as u32;
        writer.write_all(&length.to_le_bytes()).await.unwrap();
        writer.write_all(&json_bytes).await.unwrap();
        writer.flush().await.unwrap();
    }

    /// Reads one frame from `reader` and decodes its payload as JSON.
    async fn read_frame_from<R>(reader: &mut R) -> serde_json::Value
    where
        R: AsyncRead + Unpin,
    {
        let mut len_buf = [0u8; 4];
        reader.read_exact(&mut len_buf).await.unwrap();
        let length = u32::from_le_bytes(len_buf) as usize;

        let mut msg_buf = vec![0u8; length];
        reader.read_exact(&mut msg_buf).await.unwrap();
        serde_json::from_slice(&msg_buf).unwrap()
    }

    #[test]
    fn test_length_prefix_encoding() {
        let length: u32 = 1234;
        let bytes = length.to_le_bytes();

        assert_eq!(bytes[0], (length & 0xFF) as u8);
        assert_eq!(bytes[1], ((length >> 8) & 0xFF) as u8);
        assert_eq!(bytes[2], ((length >> 16) & 0xFF) as u8);
        assert_eq!(bytes[3], ((length >> 24) & 0xFF) as u8);

        assert_eq!(u32::from_le_bytes(bytes), length);
    }

    #[test]
    fn test_message_framing_format() {
        let message = serde_json::json!({"test": "hello"});
        let json_bytes = serde_json::to_vec(&message).unwrap();
        let length_bytes = (json_bytes.len() as u32).to_le_bytes();

        let mut frame = Vec::new();
        frame.extend_from_slice(&length_bytes);
        frame.extend_from_slice(&json_bytes);

        assert_eq!(frame.len(), 4 + json_bytes.len());
        assert_eq!(&frame[0..4], &length_bytes);
        assert_eq!(&frame[4..], &json_bytes);
    }

    #[tokio::test]
    async fn test_send_message() {
        let (mut stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, _stdout_write) = tokio::io::duplex(1024);
        let (mut transport, _rx) = PipeTransport::new(stdin_write, stdout_read);

        let test_message = serde_json::json!({
            "id": 1,
            "method": "test",
            "params": {"foo": "bar"}
        });

        transport.send(test_message.clone()).await.unwrap();

        assert_eq!(read_frame_from(&mut stdin_read).await, test_message);
    }

    #[tokio::test]
    async fn test_send_message_free_function() {
        let (mut reader, mut writer) = tokio::io::duplex(1024);
        let message = serde_json::json!({"id": 7, "method": "ping"});

        send_message(&mut writer, message.clone()).await.unwrap();

        assert_eq!(read_frame_from(&mut reader).await, message);
    }

    #[tokio::test]
    async fn test_multiple_messages_in_sequence() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(4096);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(4096);
        let (mut transport, mut rx) = PipeTransport::new(stdin_write, stdout_read);

        let read_task = tokio::spawn(async move { transport.run().await });

        let messages = vec![
            serde_json::json!({"id": 1, "method": "first"}),
            serde_json::json!({"id": 2, "method": "second"}),
            serde_json::json!({"id": 3, "method": "third"}),
        ];

        for msg in &messages {
            write_frame_to(&mut stdout_write, msg).await;
        }

        for expected in &messages {
            let received = rx.recv().await.unwrap();
            assert_eq!(&received, expected);
        }

        drop(stdout_write);
        drop(rx);
        let _ = read_task.await;
    }

    #[tokio::test]
    async fn test_receiver_from_into_parts() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024);
        let (transport, mut rx) = PipeTransport::new(stdin_write, stdout_read);

        let (_stdin, receiver) = transport.into_parts();
        let read_task = tokio::spawn(receiver.run());

        let message = serde_json::json!({"id": 2, "method": "split"});
        write_frame_to(&mut stdout_write, &message).await;
        assert_eq!(rx.recv().await.unwrap(), message);

        drop(stdout_write);
        drop(rx);
        let _ = read_task.await;
    }

    #[tokio::test]
    async fn test_large_message() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024 * 1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024 * 1024);
        let (mut transport, mut rx) = PipeTransport::new(stdin_write, stdout_read);

        let read_task = tokio::spawn(async move { transport.run().await });

        let large_message = serde_json::json!({
            "id": 1,
            "data": "x".repeat(100_000)
        });
        let length = serde_json::to_vec(&large_message).unwrap().len();
        assert!(length > CHUNK_SIZE, "Test message should be > 32KB");

        write_frame_to(&mut stdout_write, &large_message).await;

        let received = rx.recv().await.unwrap();
        assert_eq!(received, large_message);

        drop(stdout_write);
        drop(rx);
        let _ = read_task.await;
    }

    #[tokio::test]
    async fn test_chunk_boundary_sizes() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(64 * 1024);
        let (mut transport, mut rx) = PipeTransport::new(stdin_write, stdout_read);

        let read_task = tokio::spawn(async move { transport.run().await });

        // `{"d":"…"}` wraps the string in 8 bytes of JSON syntax.
        for payload_len in [CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1] {
            let message = serde_json::json!({ "d": "x".repeat(payload_len - 8) });
            assert_eq!(serde_json::to_vec(&message).unwrap().len(), payload_len);

            write_frame_to(&mut stdout_write, &message).await;
            assert_eq!(rx.recv().await.unwrap(), message);
        }

        drop(stdout_write);
        drop(rx);
        let _ = read_task.await;
    }

    #[tokio::test]
    async fn test_malformed_length_prefix() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024);
        let (mut transport, _rx) = PipeTransport::new(stdin_write, stdout_read);

        // Only two of the four prefix bytes, then end-of-stream.
        stdout_write.write_all(&[0x01, 0x02]).await.unwrap();
        stdout_write.flush().await.unwrap();
        drop(stdout_write);

        let result = transport.run().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Failed to read length prefix"));
    }

    #[tokio::test]
    async fn test_truncated_chunked_payload() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024);
        let (mut transport, _rx) = PipeTransport::new(stdin_write, stdout_read);

        // Announce a payload larger than one chunk, then hang up after a few bytes.
        let announced = (CHUNK_SIZE * 2) as u32;
        stdout_write.write_all(&announced.to_le_bytes()).await.unwrap();
        stdout_write.write_all(b"{\"partial\"").await.unwrap();
        stdout_write.flush().await.unwrap();
        drop(stdout_write);

        let result = transport.run().await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Failed to read message chunk"));
    }

    #[tokio::test]
    async fn test_invalid_json_payload() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024);
        let (mut transport, _rx) = PipeTransport::new(stdin_write, stdout_read);

        let garbage = b"not json";
        stdout_write
            .write_all(&(garbage.len() as u32).to_le_bytes())
            .await
            .unwrap();
        stdout_write.write_all(garbage).await.unwrap();
        stdout_write.flush().await.unwrap();

        let result = transport.run().await;
        assert!(matches!(result, Err(Error::ProtocolError(_))));
    }

    #[tokio::test]
    async fn test_broken_pipe() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, stdout_write) = tokio::io::duplex(1024);
        let (mut transport, _rx) = PipeTransport::new(stdin_write, stdout_read);

        drop(stdout_write);

        let read_task = tokio::spawn(async move { transport.run().await });
        let result = read_task.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (_stdin_read, stdin_write) = tokio::io::duplex(1024);
        let (stdout_read, mut stdout_write) = tokio::io::duplex(1024);
        let (mut transport, mut rx) = PipeTransport::new(stdin_write, stdout_read);

        let read_task = tokio::spawn(async move { transport.run().await });

        let message = serde_json::json!({"id": 1, "method": "test"});
        write_frame_to(&mut stdout_write, &message).await;

        let received = rx.recv().await.unwrap();
        assert_eq!(received, message);

        drop(rx);
        drop(stdout_write);

        // The loop is blocked on the next length prefix, so closing the stream
        // normally surfaces as a read error; either outcome counts as shutdown.
        let result = read_task.await.unwrap();
        assert!(result.is_ok() || result.unwrap_err().to_string().contains("Failed to read"));
    }
}