use serde::de::DeserializeOwned;
use serde::Serialize;
/// Upper bound, in bytes, on the body of a single frame (1 MiB).
///
/// The limit applies to the body only; the 4-byte length prefix is not counted.
pub const MAX_FRAME_BYTES: usize = 1024 * 1024;
/// Errors produced while encoding or decoding length-prefixed frames.
#[derive(Debug, thiserror::Error)]
pub enum FrameError {
/// The frame body is larger than the permitted maximum.
///
/// `got` is the offending body length in bytes and `max` is the limit it was
/// checked against.
#[error("frame body too large: {got} bytes, max {max}")]
TooLarge { got: usize, max: usize },
/// A `serde_json` error, converted automatically by `?`.
///
/// Despite the name and message, this variant is produced on both paths in
/// this file: serializing a value in `encode_frame` and parsing a frame body
/// in `decode_frame`.
#[error("serialization failed: {0}")]
Serialize(#[from] serde_json::Error),
/// An I/O error, converted automatically from `std::io::Error` by `?`.
///
/// NOTE(review): nothing in the visible code constructs this variant;
/// presumably it is for callers that move frames over a transport — confirm.
#[error("IO failure during frame transfer: {0}")]
Io(#[from] std::io::Error),
}
/// Serializes `value` as JSON and wraps it in a length-prefixed frame.
///
/// The returned buffer is a 4-byte big-endian `u32` holding the body length,
/// followed immediately by the JSON body.
///
/// # Errors
///
/// * [`FrameError::Serialize`] if `serde_json` cannot serialize `value`.
/// * [`FrameError::TooLarge`] if the serialized body exceeds
///   [`MAX_FRAME_BYTES`].
///
/// # Panics
///
/// Only if the body length does not fit in a `u32`, which the size check
/// above rules out as long as `MAX_FRAME_BYTES` itself fits in a `u32`.
pub fn encode_frame<T: Serialize>(value: &T) -> Result<Vec<u8>, FrameError> {
    let payload = serde_json::to_vec(value)?;
    let payload_len = payload.len();

    // Reject oversized bodies before building the frame.
    if payload_len > MAX_FRAME_BYTES {
        return Err(FrameError::TooLarge {
            got: payload_len,
            max: MAX_FRAME_BYTES,
        });
    }

    // Big-endian (network order) length prefix.
    let prefix = u32::try_from(payload_len)
        .expect("MAX_FRAME_BYTES fits in u32")
        .to_be_bytes();

    let mut frame = Vec::with_capacity(prefix.len() + payload_len);
    frame.extend_from_slice(&prefix);
    frame.extend_from_slice(&payload);
    Ok(frame)
}
/// Attempts to decode one length-prefixed JSON frame from the front of `buf`.
///
/// The expected layout is a 4-byte big-endian `u32` body length followed by
/// that many bytes of JSON.
///
/// # Returns
///
/// * `Ok(Some((value, consumed)))` when a full frame is present; `consumed`
///   is the number of bytes the frame occupied (prefix plus body). Any bytes
///   after that are left untouched.
/// * `Ok(None)` when `buf` does not yet hold a complete frame (fewer than
///   4 prefix bytes, or fewer body bytes than the prefix declares).
///
/// # Errors
///
/// * [`FrameError::TooLarge`] if the declared body length exceeds
///   [`MAX_FRAME_BYTES`]; this is reported even if the body has not arrived.
/// * [`FrameError::Serialize`] if the body is not valid JSON for `T`.
pub fn decode_frame<T: DeserializeOwned>(buf: &[u8]) -> Result<Option<(T, usize)>, FrameError> {
    const PREFIX_LEN: usize = 4;

    // Not enough bytes for the length prefix yet.
    let prefix: [u8; PREFIX_LEN] = match buf.get(..PREFIX_LEN) {
        Some(bytes) => bytes.try_into().expect("4 bytes"),
        None => return Ok(None),
    };

    let declared = u32::from_be_bytes(prefix) as usize;
    if declared > MAX_FRAME_BYTES {
        return Err(FrameError::TooLarge {
            got: declared,
            max: MAX_FRAME_BYTES,
        });
    }

    // `declared` is bounded by MAX_FRAME_BYTES, so this sum cannot overflow.
    let frame_len = PREFIX_LEN + declared;
    match buf.get(PREFIX_LEN..frame_len) {
        // The body has not fully arrived.
        None => Ok(None),
        Some(payload) => {
            let value: T = serde_json::from_slice(payload)?;
            Ok(Some((value, frame_len)))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal payload type used to exercise the frame codec.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct Msg {
        kind: String,
        value: i32,
    }

    /// Shape of every `decode_frame` result in these tests.
    type Decoded = Result<Option<(Msg, usize)>, FrameError>;

    /// Builds a `Msg` from a kind label and a value.
    fn msg(kind: &str, value: i32) -> Msg {
        Msg {
            kind: kind.into(),
            value,
        }
    }

    /// A frame produced by `encode_frame` decodes back to the same value and
    /// reports the whole buffer as consumed.
    #[test]
    fn encode_decode_round_trip() {
        let original = msg("hello", 42);
        let frame = encode_frame(&original).unwrap();

        let outcome: Decoded = decode_frame(&frame);
        let (value, consumed) = outcome.unwrap().unwrap();

        assert_eq!(value, original);
        assert_eq!(consumed, frame.len());
    }

    /// Fewer than four bytes cannot hold the length prefix.
    #[test]
    fn decode_with_insufficient_prefix_returns_none() {
        let short = [0x00u8; 3];
        let outcome: Decoded = decode_frame(&short);
        assert!(matches!(outcome, Ok(None)));
    }

    /// A frame missing its final body byte is reported as incomplete.
    #[test]
    fn decode_with_partial_body_returns_none() {
        let frame = encode_frame(&msg("x", 1)).unwrap();
        let truncated = &frame[..frame.len() - 1];

        let outcome: Decoded = decode_frame(truncated);
        assert!(matches!(outcome, Ok(None)));
    }

    /// Bytes following a complete frame are ignored and not counted as consumed.
    #[test]
    fn decode_with_extra_bytes_consumes_only_one_frame() {
        const TRAILER: &[u8] = b"trailing-garbage-not-part-of-frame";

        let original = msg("a", 1);
        let mut stream = encode_frame(&original).unwrap();
        let frame_len = stream.len();
        stream.extend_from_slice(TRAILER);

        let outcome: Decoded = decode_frame(&stream);
        let (value, consumed) = outcome.unwrap().unwrap();

        assert_eq!(value, original);
        assert_eq!(consumed, frame_len);
    }

    /// A length prefix of `0xFFFF_FFFF` is far beyond the frame size limit.
    #[test]
    fn oversized_length_prefix_returns_too_large_error() {
        let buf: Vec<u8> = [&[0xFFu8; 4][..], &b"{}"[..]].concat();

        let outcome: Decoded = decode_frame(&buf);
        assert!(matches!(outcome, Err(FrameError::TooLarge { .. })));
    }

    /// A correctly sized body that is not valid UTF-8/JSON yields a serde error.
    #[test]
    fn invalid_utf8_body_returns_serialize_error() {
        let prefix = 4u32.to_be_bytes();
        let body = [0xC0u8, 0xC0, 0xC0, 0xC0];
        let buf: Vec<u8> = [&prefix[..], &body[..]].concat();

        let outcome: Decoded = decode_frame(&buf);
        assert!(matches!(outcome, Err(FrameError::Serialize(_))));
    }
}