use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::error::{FrameError, SrxError};
/// Largest payload, in bytes, allowed in one length-prefixed stream frame (4 MiB).
///
/// Both [`write_length_prefixed`] and [`read_length_prefixed`] reject frames
/// whose payload length exceeds this value.
pub const MAX_STREAM_PAYLOAD: u32 = 4 << 20;
/// Writes `payload` as one frame and then flushes `writer`.
///
/// A frame is a 4-byte big-endian `u32` length header followed by the payload bytes.
///
/// # Errors
///
/// - `SrxError::Frame(FrameError::FrameTooLarge { .. })` if `payload` is longer
///   than [`MAX_STREAM_PAYLOAD`]. Nothing is written in that case.
/// - Any I/O error raised while writing the header or payload, or while
///   flushing. These are converted through `?`.
pub async fn write_length_prefixed<W: AsyncWriteExt + Unpin>(
    writer: &mut W,
    payload: &[u8],
) -> crate::error::Result<()> {
    // Shared by both rejection paths: a length that does not fit in a u32,
    // and one that fits but exceeds the protocol limit.
    let too_large = || {
        SrxError::Frame(FrameError::FrameTooLarge {
            size: payload.len(),
            max: MAX_STREAM_PAYLOAD as usize,
        })
    };

    let header = match u32::try_from(payload.len()) {
        Ok(n) if n <= MAX_STREAM_PAYLOAD => n.to_be_bytes(),
        _ => return Err(too_large()),
    };

    writer.write_all(&header).await?;
    writer.write_all(payload).await?;
    writer.flush().await?;
    Ok(())
}
pub async fn read_length_prefixed<R: AsyncReadExt + Unpin>(
reader: &mut R,
) -> crate::error::Result<Vec<u8>> {
let mut hdr = [0u8; 4];
reader.read_exact(&mut hdr).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
SrxError::Transport(crate::error::TransportError::ChannelClosed)
} else {
SrxError::Io(e)
}
})?;
let len = u32::from_be_bytes(hdr);
if len > MAX_STREAM_PAYLOAD {
return Err(SrxError::Frame(FrameError::FrameTooLarge {
size: len as usize,
max: MAX_STREAM_PAYLOAD as usize,
}));
}
let mut buf = vec![0u8; len as usize];
reader.read_exact(&mut buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
SrxError::Frame(FrameError::Corrupted("truncated stream frame".into()))
} else {
SrxError::Io(e)
}
})?;
Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::duplex;

    /// Asserts that `err` is a `FrameTooLarge` frame error.
    fn assert_frame_too_large(err: SrxError) {
        assert!(matches!(
            err,
            SrxError::Frame(FrameError::FrameTooLarge { .. })
        ));
    }

    #[tokio::test]
    async fn roundtrip_duplex() {
        let (mut left, mut right) = duplex(256 * 1024);
        let payload = b"hello-srx-frame";
        write_length_prefixed(&mut left, payload).await.unwrap();
        let got = read_length_prefixed(&mut right).await.unwrap();
        assert_eq!(got, payload);
    }

    #[tokio::test]
    async fn roundtrip_empty_payload() {
        let mut wire = Vec::new();
        write_length_prefixed(&mut wire, &[]).await.unwrap();
        // The header alone encodes a zero length.
        assert_eq!(wire, [0u8, 0, 0, 0]);
        let mut r = std::io::Cursor::new(wire);
        let got = read_length_prefixed(&mut r).await.unwrap();
        assert!(got.is_empty());
    }

    #[tokio::test]
    async fn roundtrip_at_max_payload() {
        let payload = vec![0xABu8; MAX_STREAM_PAYLOAD as usize];
        let mut wire = Vec::new();
        write_length_prefixed(&mut wire, &payload).await.unwrap();
        let mut r = std::io::Cursor::new(wire);
        let got = read_length_prefixed(&mut r).await.unwrap();
        assert_eq!(got, payload);
    }

    #[tokio::test]
    async fn write_rejects_oversized_payload() {
        let payload = vec![0u8; MAX_STREAM_PAYLOAD as usize + 1];
        let mut wire = Vec::new();
        let err = write_length_prefixed(&mut wire, &payload)
            .await
            .unwrap_err();
        assert_frame_too_large(err);
        assert!(wire.is_empty(), "nothing may be written for a rejected frame");
    }

    #[tokio::test]
    async fn rejects_oversized_length() {
        for announced in [MAX_STREAM_PAYLOAD + 1, u32::MAX] {
            let mut r = std::io::Cursor::new(announced.to_be_bytes().to_vec());
            let err = read_length_prefixed(&mut r).await.unwrap_err();
            assert_frame_too_large(err);
        }
    }

    #[tokio::test]
    async fn clean_eof_reports_channel_closed() {
        let mut r = std::io::Cursor::new(Vec::<u8>::new());
        let err = read_length_prefixed(&mut r).await.unwrap_err();
        assert!(matches!(
            err,
            SrxError::Transport(crate::error::TransportError::ChannelClosed)
        ));
    }

    #[tokio::test]
    async fn truncated_payload_reports_corrupted() {
        // The header announces 10 bytes, but only 3 follow.
        let mut wire = 10u32.to_be_bytes().to_vec();
        wire.extend_from_slice(b"abc");
        let mut r = std::io::Cursor::new(wire);
        let err = read_length_prefixed(&mut r).await.unwrap_err();
        assert!(matches!(err, SrxError::Frame(FrameError::Corrupted(_))));
    }
}