use anyhow::Result;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
/// Upper bound, in bytes, on the data read for a single frame in either framing mode.
/// Frames above this size are rejected with `FramingError::FrameTooLarge`.
const MAX_FRAME_BYTES: usize = 128 * 1024;
/// Maximum number of ASCII digits accepted in an octet-counted length prefix before the
/// prefix is rejected with `FramingError::LengthPrefixTooLong`.
const MAX_LENGTH_DIGITS: usize = 7;
/// How individual frames are delimited on the byte stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Framing {
    /// Each frame is preceded by its length as ASCII decimal digits followed by one space.
    OctetCounted,
    /// Each frame ends at a line feed (`\n`), optionally preceded by a carriage return.
    LfTerminated,
}
/// Errors produced while detecting the framing mode or reading a frame.
#[derive(Debug, thiserror::Error)]
pub enum FramingError {
    /// The first byte of the stream is neither an ASCII digit nor `<`; carries that byte.
    #[error("malformed frame: unexpected first byte 0x{0:02x}")]
    Malformed(u8),

    /// The octet-counted length prefix has more digits than `MAX_LENGTH_DIGITS`.
    #[error("octet-counted length prefix exceeds {MAX_LENGTH_DIGITS} digits")]
    LengthPrefixTooLong,

    /// The frame is larger than `MAX_FRAME_BYTES`.
    #[error("frame exceeds maximum size of {MAX_FRAME_BYTES} bytes")]
    FrameTooLarge,

    /// The octet-counted length prefix could not be parsed; carries a description.
    #[error("invalid octet-counted length prefix: {0}")]
    InvalidLength(String),

    /// An I/O error from the underlying reader, displayed as-is.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
/// Peeks at the first buffered byte of `reader` to decide which framing the stream uses.
///
/// The byte is only inspected through `fill_buf`; this function never calls `consume`, so
/// the byte is still available to a subsequent [`read_frame`] call.
///
/// Returns `Ok(None)` when the reader has no more data, `Ok(Some(Framing::OctetCounted))`
/// when the first byte is an ASCII digit, and `Ok(Some(Framing::LfTerminated))` when it
/// is `<`.
///
/// # Errors
///
/// * [`FramingError::Malformed`] if the first byte is anything else.
/// * [`FramingError::Io`] if filling the buffer fails.
pub async fn detect_framing<R>(reader: &mut R) -> Result<Option<Framing>, FramingError>
where
    R: AsyncBufRead + Unpin,
{
    let peeked = reader.fill_buf().await?.first().copied();
    match peeked {
        None => Ok(None),
        Some(b'0'..=b'9') => Ok(Some(Framing::OctetCounted)),
        Some(b'<') => Ok(Some(Framing::LfTerminated)),
        Some(other) => Err(FramingError::Malformed(other)),
    }
}
/// Reads the next frame from `reader` using the given `framing` mode.
///
/// Returns `Ok(Some(bytes))` with the frame payload (length prefix or line terminator
/// removed), or `Ok(None)` when the stream ends cleanly before a new frame starts.
///
/// # Errors
///
/// Propagates any [`FramingError`] raised by the mode-specific reader.
pub async fn read_frame<R>(
    reader: &mut R,
    framing: Framing,
) -> Result<Option<Vec<u8>>, FramingError>
where
    R: AsyncBufRead + Unpin,
{
    match framing {
        Framing::LfTerminated => read_lf_terminated(reader).await,
        Framing::OctetCounted => read_octet_counted(reader).await,
    }
}
/// Reads one octet-counted frame: `<decimal length><space><payload>`.
///
/// Returns `Ok(None)` if the stream ends before any byte of the length prefix is read.
///
/// # Errors
///
/// * [`FramingError::LengthPrefixTooLong`] if the prefix has more than
///   `MAX_LENGTH_DIGITS` digits.
/// * [`FramingError::InvalidLength`] if the prefix contains a byte that is neither a digit
///   nor the terminating space (a space before any digit is also rejected), or if the
///   digits fail to parse.
/// * [`FramingError::FrameTooLarge`] if the declared length exceeds `MAX_FRAME_BYTES`.
/// * [`FramingError::Io`] on any I/O failure, including end of stream in the middle of the
///   prefix or the payload.
async fn read_octet_counted<R>(reader: &mut R) -> Result<Option<Vec<u8>>, FramingError>
where
    R: AsyncBufRead + Unpin,
{
    // The prefix is bounded, so a fixed-size array is enough to collect its digits.
    let mut prefix = [0u8; MAX_LENGTH_DIGITS];
    let mut used = 0usize;

    loop {
        let mut next = [0u8; 1];
        if let Err(e) = reader.read_exact(&mut next).await {
            // End of stream is only "clean" when no prefix digit has been read yet.
            let clean_eof = used == 0 && e.kind() == std::io::ErrorKind::UnexpectedEof;
            return if clean_eof {
                Ok(None)
            } else {
                Err(FramingError::Io(e))
            };
        }

        match next[0] {
            b' ' if used > 0 => break,
            digit @ b'0'..=b'9' => {
                let Some(slot) = prefix.get_mut(used) else {
                    return Err(FramingError::LengthPrefixTooLong);
                };
                *slot = digit;
                used += 1;
            }
            other => {
                return Err(FramingError::InvalidLength(format!(
                    "unexpected byte 0x{other:02x} in length prefix"
                )));
            }
        }
    }

    let len = std::str::from_utf8(&prefix[..used])
        .map_err(|e| FramingError::InvalidLength(e.to_string()))?
        .parse::<usize>()
        .map_err(|e| FramingError::InvalidLength(e.to_string()))?;

    // Check the declared size before allocating the payload buffer.
    if len > MAX_FRAME_BYTES {
        return Err(FramingError::FrameTooLarge);
    }

    let mut payload = vec![0u8; len];
    reader.read_exact(&mut payload).await?;
    Ok(Some(payload))
}
/// Reads one LF-terminated frame, stripping the trailing `\n` and an optional `\r` before it.
///
/// Lines that are empty after stripping are skipped. Data at the end of the stream that is
/// not followed by `\n` is returned as a final frame. Returns `Ok(None)` once the stream is
/// exhausted.
///
/// Each line is read through a limiter of `MAX_FRAME_BYTES + 1` bytes, so a peer that never
/// sends a line feed cannot make this function buffer an unbounded amount of data.
///
/// # Errors
///
/// * [`FramingError::FrameTooLarge`] if a line, including its terminator, is longer than
///   `MAX_FRAME_BYTES`.
/// * [`FramingError::Io`] if the underlying read fails.
async fn read_lf_terminated<R>(reader: &mut R) -> Result<Option<Vec<u8>>, FramingError>
where
    R: AsyncBufRead + Unpin,
{
    // One byte past the cap is enough to tell "exactly at the cap" from "over the cap".
    // The cast is lossless: `usize` is at most 64 bits wide on supported targets.
    let read_limit = MAX_FRAME_BYTES as u64 + 1;

    // Reused across skipped blank lines so they do not each cost an allocation.
    let mut line = Vec::with_capacity(512);
    loop {
        line.clear();

        // A fresh limiter per line: the budget applies to each frame, not to the stream.
        let mut limited = (&mut *reader).take(read_limit);
        let n = limited.read_until(b'\n', &mut line).await?;
        if n == 0 {
            return Ok(None);
        }
        if line.len() > MAX_FRAME_BYTES {
            return Err(FramingError::FrameTooLarge);
        }

        if line.last() == Some(&b'\n') {
            line.pop();
        }
        if line.last() == Some(&b'\r') {
            line.pop();
        }
        if !line.is_empty() {
            return Ok(Some(line));
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for framing detection and both frame readers, driven by in-memory
    //! byte slices wrapped in a `BufReader`.

    use super::*;
    use tokio::io::BufReader;

    #[tokio::test]
    async fn detects_octet_counted_from_digit() {
        let data: &[u8] = b"42 hello";
        let mut r = BufReader::new(data);
        assert_eq!(
            detect_framing(&mut r).await.unwrap(),
            Some(Framing::OctetCounted)
        );
    }

    #[tokio::test]
    async fn detects_lf_from_angle_bracket() {
        let data: &[u8] = b"<14>foo\n";
        let mut r = BufReader::new(data);
        assert_eq!(
            detect_framing(&mut r).await.unwrap(),
            Some(Framing::LfTerminated)
        );
    }

    #[tokio::test]
    async fn detects_eof_on_empty_stream() {
        let data: &[u8] = b"";
        let mut r = BufReader::new(data);
        assert_eq!(detect_framing(&mut r).await.unwrap(), None);
    }

    #[tokio::test]
    async fn rejects_garbage_first_byte() {
        let data: &[u8] = b"garbage";
        let mut r = BufReader::new(data);
        let err = detect_framing(&mut r).await.unwrap_err();
        assert!(matches!(err, FramingError::Malformed(_)));
    }

    #[tokio::test]
    async fn reads_octet_counted_single_frame() {
        let data: &[u8] = b"5 hello";
        let mut r = BufReader::new(data);
        let frame = read_frame(&mut r, Framing::OctetCounted).await.unwrap();
        assert_eq!(frame.as_deref(), Some(&b"hello"[..]));
    }

    #[tokio::test]
    async fn reads_octet_counted_two_frames() {
        let data: &[u8] = b"5 hello3 abc";
        let mut r = BufReader::new(data);
        let f1 = read_frame(&mut r, Framing::OctetCounted).await.unwrap();
        let f2 = read_frame(&mut r, Framing::OctetCounted).await.unwrap();
        let f3 = read_frame(&mut r, Framing::OctetCounted).await.unwrap();
        assert_eq!(f1.as_deref(), Some(&b"hello"[..]));
        assert_eq!(f2.as_deref(), Some(&b"abc"[..]));
        assert_eq!(f3, None);
    }

    #[tokio::test]
    async fn reads_lf_terminated_single_frame() {
        let data: &[u8] = b"<14>1 - host app - - - hello\n";
        let mut r = BufReader::new(data);
        let frame = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        assert_eq!(frame.as_deref(), Some(&b"<14>1 - host app - - - hello"[..]));
    }

    #[tokio::test]
    async fn reads_lf_terminated_strips_crlf() {
        let data: &[u8] = b"<14>foo\r\n<14>bar\n";
        let mut r = BufReader::new(data);
        let f1 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        let f2 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        assert_eq!(f1.as_deref(), Some(&b"<14>foo"[..]));
        assert_eq!(f2.as_deref(), Some(&b"<14>bar"[..]));
    }

    #[tokio::test]
    async fn reads_lf_terminated_skips_blank_lines() {
        let data: &[u8] = b"\n\r\n<14>foo\n";
        let mut r = BufReader::new(data);
        let f1 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        let f2 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        assert_eq!(f1.as_deref(), Some(&b"<14>foo"[..]));
        assert_eq!(f2, None);
    }

    #[tokio::test]
    async fn reads_lf_terminated_final_frame_without_newline() {
        let data: &[u8] = b"<14>foo\n<14>bar";
        let mut r = BufReader::new(data);
        let f1 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        let f2 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        let f3 = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        assert_eq!(f1.as_deref(), Some(&b"<14>foo"[..]));
        assert_eq!(f2.as_deref(), Some(&b"<14>bar"[..]));
        assert_eq!(f3, None);
    }

    #[tokio::test]
    async fn accepts_lf_terminated_frame_at_size_limit() {
        // The terminator counts toward the cap, so the payload is one byte shorter.
        let mut data = vec![b'a'; MAX_FRAME_BYTES - 1];
        data.push(b'\n');
        let mut r = BufReader::new(data.as_slice());
        let frame = read_frame(&mut r, Framing::LfTerminated).await.unwrap();
        assert_eq!(frame.map(|f| f.len()), Some(MAX_FRAME_BYTES - 1));
    }

    #[tokio::test]
    async fn rejects_oversize_lf_terminated_frame() {
        let data = vec![b'a'; MAX_FRAME_BYTES + 1];
        let mut r = BufReader::new(data.as_slice());
        let err = read_frame(&mut r, Framing::LfTerminated).await.unwrap_err();
        assert!(matches!(err, FramingError::FrameTooLarge));
    }

    #[tokio::test]
    async fn rejects_oversize_octet_counted_length_prefix() {
        let data: &[u8] = b"12345678 x";
        let mut r = BufReader::new(data);
        let err = read_frame(&mut r, Framing::OctetCounted).await.unwrap_err();
        assert!(matches!(err, FramingError::LengthPrefixTooLong));
    }

    #[tokio::test]
    async fn rejects_octet_counted_length_over_frame_limit() {
        let data: &[u8] = b"999999 x";
        let mut r = BufReader::new(data);
        let err = read_frame(&mut r, Framing::OctetCounted).await.unwrap_err();
        assert!(matches!(err, FramingError::FrameTooLarge));
    }

    #[tokio::test]
    async fn rejects_non_digit_in_octet_counted_prefix() {
        let data: &[u8] = b"5x hello";
        let mut r = BufReader::new(data);
        let err = read_frame(&mut r, Framing::OctetCounted).await.unwrap_err();
        assert!(matches!(err, FramingError::InvalidLength(_)));
    }

    #[tokio::test]
    async fn errors_on_truncated_octet_counted_payload() {
        let data: &[u8] = b"10 short";
        let mut r = BufReader::new(data);
        let err = read_frame(&mut r, Framing::OctetCounted).await.unwrap_err();
        assert!(matches!(
            err,
            FramingError::Io(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }

    #[tokio::test]
    async fn errors_on_eof_inside_length_prefix() {
        let data: &[u8] = b"12";
        let mut r = BufReader::new(data);
        let err = read_frame(&mut r, Framing::OctetCounted).await.unwrap_err();
        assert!(matches!(
            err,
            FramingError::Io(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof
        ));
    }

    #[tokio::test]
    async fn returns_none_on_clean_eof_between_frames() {
        let data: &[u8] = b"";
        let mut r = BufReader::new(data);
        let f = read_frame(&mut r, Framing::OctetCounted).await.unwrap();
        assert_eq!(f, None);
    }
}