use super::error::WireError;
use std::io::{self, Write};
/// Size in bytes of the little-endian length prefix that starts every frame.
const LEN_PREFIX: usize = 4;
/// Size in bytes of the frame kind tag that follows the length prefix.
const KIND_SIZE: usize = 1;
/// Smallest possible frame: a length prefix plus a kind byte and no payload.
const MIN_FRAME_SIZE: usize = LEN_PREFIX + KIND_SIZE;
/// Writes one length-prefixed frame to `out`.
///
/// Wire layout, in order: the body length as a little-endian `u32`, the
/// one-byte `kind` tag, then `payload`. The body length counts the kind byte
/// plus the payload bytes.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error when the payload plus the
/// kind byte does not fit in a `u32`; nothing is written in that case. Errors
/// reported by `out` are propagated unchanged.
#[inline]
pub fn encode_frame<W: Write>(kind: u8, payload: &[u8], out: &mut W) -> io::Result<()> {
    let too_large = || io::Error::new(io::ErrorKind::InvalidInput, "frame payload too large");

    // Validate the size before touching `out`, so an oversized payload never
    // leaves a partial frame behind.
    let body_len = match payload.len().checked_add(KIND_SIZE).map(u32::try_from) {
        Some(Ok(len)) => len,
        _ => return Err(too_large()),
    };

    let prefix = body_len.to_le_bytes();
    out.write_all(&prefix)?;
    out.write_all(&[kind])?;
    out.write_all(payload)
}
/// Parses a single frame from the front of `buf`.
///
/// The expected layout is a little-endian `u32` body length, a one-byte kind
/// tag, then the payload; the body length covers the kind byte plus the
/// payload. On success returns `(kind, payload, consumed)`, where `payload`
/// borrows from `buf` and `consumed` is the size of the whole frame including
/// the length prefix. Bytes after the frame are ignored, so a following frame
/// can be decoded from `buf` starting at offset `consumed`.
///
/// # Errors
///
/// * [`WireError::Truncated`] when `buf` is shorter than the minimum frame or
///   shorter than the frame its length prefix announces. The same error is
///   returned if the total frame size overflows `usize`.
/// * [`WireError::InvalidPayload`] when the announced body length is too
///   small to hold the kind byte.
#[inline]
pub fn decode_frame(buf: &[u8]) -> Result<(u8, &[u8], usize), WireError> {
    // The fixed header (length prefix + kind byte) must be fully present.
    if buf.len() < MIN_FRAME_SIZE {
        return Err(WireError::Truncated);
    }

    let mut prefix = [0u8; LEN_PREFIX];
    match buf.get(..LEN_PREFIX) {
        Some(bytes) => prefix.copy_from_slice(bytes),
        None => return Err(WireError::Truncated),
    }
    let body_len = u32::from_le_bytes(prefix) as usize;

    // Every body carries the kind byte, so anything shorter is malformed.
    if body_len < KIND_SIZE {
        return Err(WireError::InvalidPayload("frame body shorter than kind"));
    }

    // Whole-frame size; an overflowing sum is reported like a short buffer.
    let frame_len = match LEN_PREFIX.checked_add(body_len) {
        Some(len) if len <= buf.len() => len,
        _ => return Err(WireError::Truncated),
    };

    // Slice out the body once, then split it into kind tag and payload.
    let body = buf.get(LEN_PREFIX..frame_len).ok_or(WireError::Truncated)?;
    let kind = *body.first().ok_or(WireError::Truncated)?;
    let payload = body.get(KIND_SIZE..).ok_or(WireError::Truncated)?;
    Ok((kind, payload, frame_len))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A frame with no payload round-trips and consumes the whole buffer.
    #[test]
    fn roundtrip_empty_payload() {
        let mut wire = Vec::new();
        encode_frame(0x01, &[], &mut wire).expect("encode empty payload");

        let (kind, payload, consumed) = decode_frame(&wire).expect("decode empty payload");

        assert_eq!(kind, 0x01);
        assert!(payload.is_empty());
        assert_eq!(consumed, wire.len());
    }

    /// Kind and payload bytes survive an encode/decode round trip.
    #[test]
    fn roundtrip_with_payload() {
        let original = [1u8, 2, 3, 4, 5];
        let mut wire = Vec::new();
        encode_frame(0x42, &original, &mut wire).expect("encode payload");

        let (kind, payload, consumed) = decode_frame(&wire).expect("decode payload");

        assert_eq!(kind, 0x42);
        assert_eq!(payload, &original);
        assert_eq!(consumed, wire.len());
    }

    /// Three bytes cannot even hold the four-byte length prefix.
    #[test]
    fn truncated_header_returns_truncated() {
        let short = [0x05, 0x00, 0x00];
        let outcome = decode_frame(&short);
        assert_eq!(outcome, Err(WireError::Truncated));
    }

    /// The prefix announces a ten-byte body, but only one body byte follows.
    #[test]
    fn truncated_payload_returns_truncated() {
        let short = [0x0A, 0x00, 0x00, 0x00, 0x01];
        let outcome = decode_frame(&short);
        assert_eq!(outcome, Err(WireError::Truncated));
    }

    /// A body length of zero leaves no room for the kind byte.
    #[test]
    fn zero_body_length_is_invalid() {
        let frame = [0x00, 0x00, 0x00, 0x00, 0x00];
        let outcome = decode_frame(&frame);
        assert!(matches!(outcome, Err(WireError::InvalidPayload(_))));
    }

    /// Two frames written back to back decode one at a time, using the
    /// consumed count of the first to locate the second.
    #[test]
    fn decode_consumes_only_one_frame_at_a_time() {
        let mut wire = Vec::new();
        encode_frame(0x01, &[0xAA, 0xBB], &mut wire).expect("encode frame 1");
        encode_frame(0x02, &[0xCC], &mut wire).expect("encode frame 2");

        let (first_kind, first_payload, first_len) =
            decode_frame(&wire).expect("decode frame 1");
        assert_eq!(first_kind, 0x01);
        assert_eq!(first_payload, &[0xAA, 0xBB]);

        let remainder = wire.get(first_len..).expect("rest of buffer");
        let (second_kind, second_payload, second_len) =
            decode_frame(remainder).expect("decode frame 2");
        assert_eq!(second_kind, 0x02);
        assert_eq!(second_payload, &[0xCC]);

        assert_eq!(first_len + second_len, wire.len());
    }
}