#![cfg_attr(coverage_nightly, coverage(off))]
use bytes::{Bytes, BytesMut};
use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
/// Length-prefixed framing over a child process's stdin/stdout pipes.
///
/// Each outgoing frame (header + payload) is kept within
/// [`StdioTransport::PIPE_BUF`] bytes so it is handed to the pipe as a single write.
#[derive(Debug)]
pub struct StdioTransport {
    /// Write side: the child's stdin.
    stdin: ChildStdin,
    /// Read side: the child's stdout, wrapped in a buffered reader.
    stdout: BufReader<ChildStdout>,
    /// Scratch buffer reused to assemble each outgoing frame.
    write_buf: BytesMut,
    /// Scratch buffer reused to receive each incoming payload.
    read_buf: BytesMut,
    /// Sequence number stamped into the next outgoing frame.
    sequence_num: AtomicU64,
}
impl StdioTransport {
    /// Upper bound, in bytes, on one complete frame (header + payload).
    ///
    /// Chosen to match Linux's `PIPE_BUF` (4096). A pipe write of at most that many
    /// bytes is not interleaved with other writers.
    /// NOTE(review): POSIX only guarantees `PIPE_BUF >= 512`; confirm for non-Linux targets.
    pub const PIPE_BUF: usize = 4096;
    /// Header layout: 4-byte magic, 8-byte little-endian sequence number,
    /// 4-byte little-endian payload length.
    const FRAME_HEADER_SIZE: usize = 16;
    /// Largest payload that keeps a whole frame within [`Self::PIPE_BUF`].
    ///
    /// Shared by the writer and the reader so both enforce the same limit.
    const MAX_PAYLOAD: usize = Self::PIPE_BUF - Self::FRAME_HEADER_SIZE;
    /// Marker bytes that open every frame.
    const MAGIC: &'static [u8; 4] = b"PMAT";

    /// Wraps the child's stdin/stdout pipes.
    ///
    /// Both scratch buffers are preallocated to `PIPE_BUF` bytes, and the
    /// sequence counter starts at 0.
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        Self {
            stdin,
            stdout: BufReader::new(stdout),
            write_buf: BytesMut::with_capacity(Self::PIPE_BUF),
            read_buf: BytesMut::with_capacity(Self::PIPE_BUF),
            sequence_num: AtomicU64::new(0),
        }
    }

    /// Encodes `payload` as one frame and writes it to the child's stdin.
    ///
    /// The frame is `MAGIC | seq (u64 LE) | len (u32 LE) | payload`. It is assembled
    /// in a reusable buffer, written with a single `write_all`, and then flushed.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `payload` is longer than `PIPE_BUF - FRAME_HEADER_SIZE`
    ///   bytes. In that case no sequence number is consumed.
    /// * Any I/O error raised while writing to or flushing stdin.
    pub async fn send_atomic(&mut self, payload: &[u8]) -> io::Result<()> {
        if payload.len() > Self::MAX_PAYLOAD {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Payload {} bytes exceeds atomic limit {}",
                    payload.len(),
                    Self::MAX_PAYLOAD
                ),
            ));
        }
        // Take the sequence number only after validation, so rejected payloads
        // do not leave gaps in the sequence seen by the peer.
        let seq = self.sequence_num.fetch_add(1, Ordering::AcqRel);
        // Cannot truncate: MAX_PAYLOAD is far below u32::MAX.
        let len = payload.len() as u32;

        self.write_buf.clear();
        self.write_buf.extend_from_slice(Self::MAGIC);
        self.write_buf.extend_from_slice(&seq.to_le_bytes());
        self.write_buf.extend_from_slice(&len.to_le_bytes());
        self.write_buf.extend_from_slice(payload);

        // The whole frame is at most PIPE_BUF bytes. NOTE(review): this relies on
        // stdin being a pipe, where such a write is all-or-nothing.
        self.stdin.write_all(&self.write_buf).await?;
        self.stdin.flush().await?;
        Ok(())
    }

    /// Reads one frame from the child's stdout and returns its payload.
    ///
    /// The header's sequence-number field (bytes 4..12) is read but not validated.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if the magic bytes do not match.
    /// * `InvalidData` if the declared payload length exceeds
    ///   `PIPE_BUF - FRAME_HEADER_SIZE`, which is the limit `send_atomic` enforces.
    /// * Any I/O error from the underlying read, including `UnexpectedEof` when
    ///   the stream ends mid-frame.
    pub async fn read_frame(&mut self) -> io::Result<Bytes> {
        let mut header = [0u8; Self::FRAME_HEADER_SIZE];
        self.stdout.read_exact(&mut header).await?;
        if &header[0..4] != Self::MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid magic bytes in frame",
            ));
        }
        let len = u32::from_le_bytes(
            header[12..16]
                .try_into()
                .expect("header[12..16] is exactly 4 bytes"),
        ) as usize;
        // Compare against the payload limit rather than PIPE_BUF. PIPE_BUF bounds
        // the whole frame, so using it here would accept payloads that no
        // conforming writer can produce.
        if len > Self::MAX_PAYLOAD {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Frame length {} exceeds maximum {}", len, Self::MAX_PAYLOAD),
            ));
        }
        self.read_buf.clear();
        self.read_buf.resize(len, 0);
        self.stdout.read_exact(&mut self.read_buf).await?;
        Ok(self.read_buf.split().freeze())
    }

    /// Alias for [`Self::send_atomic`].
    pub async fn write_frame(&mut self, msg: &[u8]) -> io::Result<()> {
        self.send_atomic(msg).await
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the header size: 4-byte magic + 8-byte sequence number + 4-byte length.
    #[test]
    fn test_frame_header_size() {
        let header_size = StdioTransport::FRAME_HEADER_SIZE;
        assert_eq!(header_size, 16);
    }

    /// Pins the largest payload that still fits, with its header, in one `PIPE_BUF` frame.
    #[test]
    fn test_max_payload_size() {
        let budget = StdioTransport::PIPE_BUF;
        let overhead = StdioTransport::FRAME_HEADER_SIZE;
        assert_eq!(budget - overhead, 4080);
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod transport_atomicity_proof {
    use super::*;

    /// Pushes a hand-assembled frame through an in-memory duplex pipe, then checks
    /// that the magic, the declared length and the payload read back unchanged.
    #[tokio::test]
    async fn test_frame_round_trip() {
        let (mut rx, mut tx) = tokio::io::duplex(8192);
        let body: &[u8] = b"test message";

        let seq_field = 0u64.to_le_bytes();
        let len_field = (body.len() as u32).to_le_bytes();
        let wire: Vec<u8> = [
            &StdioTransport::MAGIC[..],
            &seq_field[..],
            &len_field[..],
            body,
        ]
        .concat();
        tx.write_all(&wire).await.expect("internal error");

        let mut hdr = [0u8; 16];
        rx.read_exact(&mut hdr).await.expect("internal error");
        assert_eq!(&hdr[0..4], StdioTransport::MAGIC);

        let declared_len =
            u32::from_le_bytes(hdr[12..16].try_into().expect("internal error")) as usize;
        assert_eq!(declared_len, body.len());

        let mut received = vec![0u8; declared_len];
        rx.read_exact(&mut received).await.expect("internal error");
        assert_eq!(received.as_slice(), body);
    }
}