use std::io::{self, Cursor, Read};
/// Magic number marking a skippable frame header.
///
/// In this file it is only compared, as a little-endian `u32` read from the first
/// four header bytes, by the test-module header validator.
/// NOTE(review): the value is presumably the first magic of the LZ4/Zstandard
/// skippable-frame range (`0x184D2A50..=0x184D2A5F`) — confirm whether callers
/// should accept the other magics in that range too.
pub const SKIPPABLE_FRAME_MAGIC: u32 = 0x184D2A50;
/// Capacity, in bytes, of the header buffer held by `FrameReader::Standard`.
/// The `header_len` passed to `FrameReader::new_standard` is expected not to exceed it.
pub const MAX_HEADER_SIZE: usize = 16;
/// Fills `buf` from `reader`, stopping early only at end-of-stream.
///
/// Unlike [`Read::read_exact`], running out of input is not an error. The return
/// value is the number of bytes written to the front of `buf`: `buf.len()` on a
/// full read, and something smaller (possibly 0) if the stream ended first. Reads
/// that fail with `ErrorKind::Interrupted` are retried. An empty `buf` returns
/// `Ok(0)` without calling `reader` at all.
///
/// # Errors
///
/// Returns the first error from `reader` whose kind is not `Interrupted`. Bytes
/// copied before that error remain in `buf`, but their count is not reported.
pub fn read_full_or_eof<R: Read>(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
    let wanted = buf.len();
    let mut filled = 0;
    loop {
        if filled >= wanted {
            return Ok(filled);
        }
        let got = match reader.read(&mut buf[filled..]) {
            Ok(got) => got,
            // Interrupted before any data moved: simply ask again.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if got == 0 {
            // End of stream: report however much was collected.
            return Ok(filled);
        }
        filled += got;
    }
}
/// A `Read` adapter that exposes one frame of an underlying byte stream.
///
/// `HEADER_SIZE` is the number of header bytes `try_read_next_frame` reads and
/// passes to its validator when advancing a `Skippable` reader to the next frame.
pub enum FrameReader<R: Read, const HEADER_SIZE: usize> {
/// Yields no data: every `read` returns `Ok(0)`.
Empty,
/// Emits `header_buf[..header_len]` first, then passes reads through to `reader`.
Standard {
/// Source of the bytes that follow the replayed header.
reader: R,
/// Fixed-size header storage; the cursor position records how many header
/// bytes have been emitted so far.
header_buf: Cursor<[u8; MAX_HEADER_SIZE]>,
/// Number of meaningful bytes at the front of `header_buf`.
/// NOTE(review): must be <= MAX_HEADER_SIZE, otherwise `read` slices out of bounds.
header_len: usize,
/// Set by `read` once it finds no header bytes left to replay.
header_done: bool,
},
/// Yields at most `remaining` bytes from `reader`, and reports `UnexpectedEof`
/// if `reader` ends before they are all delivered.
Skippable {
/// Source of the frame payload (and of any following frame headers).
reader: R,
/// Payload bytes of the current frame not yet returned by `read`.
remaining: u32,
/// Set by `read` once `remaining` reaches zero; `try_read_next_frame`
/// consults it before parsing another header.
frame_done: bool,
},
}
impl<R: Read, const HEADER_SIZE: usize> FrameReader<R, HEADER_SIZE> {
    /// Creates a reader that emits `header[..header_len]` first and then streams
    /// `reader` through unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `header_len > MAX_HEADER_SIZE`. Such a length can never be served
    /// from the fixed-size buffer. Checking here reports the mistake at the call
    /// site, rather than as an out-of-bounds slice inside a later `read`.
    pub fn new_standard(reader: R, header: [u8; MAX_HEADER_SIZE], header_len: usize) -> Self {
        assert!(
            header_len <= MAX_HEADER_SIZE,
            "header_len ({}) exceeds MAX_HEADER_SIZE ({})",
            header_len,
            MAX_HEADER_SIZE
        );
        Self::Standard {
            reader,
            header_buf: Cursor::new(header),
            header_len,
            header_done: false,
        }
    }

    /// Creates a reader over a skippable frame payload of `compressed_size` bytes.
    /// It yields at most that many bytes from `reader`.
    pub fn new_skippable(reader: R, compressed_size: u32) -> Self {
        Self::Skippable {
            reader,
            remaining: compressed_size,
            frame_done: false,
        }
    }

    /// Attempts to advance a `Skippable` reader to the next skippable frame.
    ///
    /// Reads `HEADER_SIZE` bytes from the underlying reader and passes them to
    /// `validate_header`. The validator returns the new frame's payload size, or
    /// `None` to reject the header. On success the reader is re-armed to yield that
    /// many payload bytes, and the size is returned as `Ok(Some(size))`.
    ///
    /// Returns `Ok(None)` without reading anything in these cases:
    /// * `self` is `Empty` or `Standard`.
    /// * The current frame still has unread payload, since parsing payload bytes
    ///   as a header would desynchronise the stream.
    ///
    /// Also returns `Ok(None)` when the stream is at a clean end before any header byte.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::UnexpectedEof` if the stream ends part-way through a header.
    /// * `ErrorKind::InvalidData` if `validate_header` rejects the header.
    /// * Any non-`Interrupted` error from the underlying reader.
    ///
    /// On error, header bytes already read are consumed, and `remaining` /
    /// `frame_done` are left unchanged.
    #[must_use = "ignoring the result may cause data loss"]
    pub fn try_read_next_frame<F>(&mut self, validate_header: F) -> io::Result<Option<u32>>
    where
        F: Fn(&[u8; HEADER_SIZE]) -> Option<u32>,
    {
        match self {
            Self::Empty | Self::Standard { .. } => Ok(None),
            Self::Skippable {
                reader,
                remaining,
                frame_done,
            } => {
                // This is the same "frame finished" test that `Read::read` uses. A frame
                // with no payload left counts as finished even before `read` notices,
                // so zero-length frames don't look like end-of-stream.
                if !*frame_done && *remaining != 0 {
                    return Ok(None);
                }
                // The buffer is sized by HEADER_SIZE itself, so headers of any length
                // are supported (including ones larger than MAX_HEADER_SIZE).
                let mut header = [0u8; HEADER_SIZE];
                let n = read_full_or_eof(reader, &mut header)?;
                if n == 0 {
                    return Ok(None);
                }
                if n < HEADER_SIZE {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!(
                            "truncated skippable frame header: expected {} bytes, got {}",
                            HEADER_SIZE, n
                        ),
                    ));
                }
                let compressed_size = validate_header(&header).ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "invalid skippable frame header after valid frame",
                    )
                })?;
                *remaining = compressed_size;
                *frame_done = false;
                Ok(Some(compressed_size))
            }
        }
    }
}
impl<R: Read, const HEADER_SIZE: usize> Read for FrameReader<R, HEADER_SIZE> {
    /// Reads the next chunk of the frame into `buf`.
    ///
    /// * `Empty` always reports end-of-stream (`Ok(0)`).
    /// * `Standard` first returns the buffered header bytes, never mixed with
    ///   inner-reader bytes in the same call, and then delegates to the inner reader.
    /// * `Skippable` returns at most the remaining payload bytes. Once they are
    ///   exhausted it reports end-of-stream, even if the inner reader has more data.
    ///
    /// # Errors
    ///
    /// `Skippable` returns `ErrorKind::UnexpectedEof` if the inner reader ends before
    /// the payload is complete. Errors from the inner reader are passed through.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Empty => Ok(0),
            Self::Standard {
                reader,
                header_buf,
                header_len,
                header_done,
            } => {
                if !*header_done {
                    let pos = header_buf.position() as usize;
                    let remaining_header = *header_len - pos;
                    if remaining_header > 0 {
                        let to_copy = buf.len().min(remaining_header);
                        let inner = header_buf.get_ref();
                        buf[..to_copy].copy_from_slice(&inner[pos..pos + to_copy]);
                        header_buf.set_position((pos + to_copy) as u64);
                        return Ok(to_copy);
                    }
                    *header_done = true;
                }
                reader.read(buf)
            }
            Self::Skippable {
                reader,
                remaining,
                frame_done,
            } => {
                if *frame_done || *remaining == 0 {
                    *frame_done = true;
                    return Ok(0);
                }
                // An empty destination cannot make progress. Without this check, the
                // inner reader's legitimate `Ok(0)` would be misreported as a
                // truncated payload.
                if buf.is_empty() {
                    return Ok(0);
                }
                let to_read = buf.len().min(*remaining as usize);
                let n = reader.read(&mut buf[..to_read])?;
                if n == 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!(
                            "truncated skippable frame payload: {} bytes remaining",
                            *remaining
                        ),
                    ));
                }
                // The `Read` contract guarantees n <= to_read <= remaining. So the cast
                // is lossless and the subtraction cannot underflow.
                *remaining -= n as u32;
                if *remaining == 0 {
                    *frame_done = true;
                }
                Ok(n)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    const TEST_HEADER_SIZE: usize = 12;

    /// Reads the little-endian `u32` stored at `offset` in a test header.
    fn le_word(header: &[u8; TEST_HEADER_SIZE], offset: usize) -> u32 {
        u32::from_le_bytes([
            header[offset],
            header[offset + 1],
            header[offset + 2],
            header[offset + 3],
        ])
    }

    /// Test validator. It accepts a header laid out as three little-endian `u32`s,
    /// `magic | 4 | payload_size`, with `magic == SKIPPABLE_FRAME_MAGIC`, and returns
    /// `payload_size`.
    fn validate_test_header(header: &[u8; TEST_HEADER_SIZE]) -> Option<u32> {
        let recognised = le_word(header, 0) == SKIPPABLE_FRAME_MAGIC && le_word(header, 4) == 4;
        recognised.then(|| le_word(header, 8))
    }

    /// Serialises `words` back-to-back in little-endian order.
    fn encode_words(words: &[u32]) -> Vec<u8> {
        words.iter().flat_map(|word| word.to_le_bytes()).collect()
    }

    /// Builds a skippable reader over `data` whose current (empty) frame is already
    /// marked finished, so `try_read_next_frame` parses the next header from `data`.
    fn finished_skippable(data: Vec<u8>) -> FrameReader<Cursor<Vec<u8>>, TEST_HEADER_SIZE> {
        let mut reader = FrameReader::new_skippable(Cursor::new(data), 0);
        if let FrameReader::Skippable { frame_done, .. } = &mut reader {
            *frame_done = true;
        }
        reader
    }

    #[test]
    fn standard_reader_replays_header() {
        let mut header = [0u8; MAX_HEADER_SIZE];
        header[..6].copy_from_slice(b"hello ");
        let mut reader: FrameReader<_, TEST_HEADER_SIZE> =
            FrameReader::new_standard(Cursor::new(&b"world!"[..]), header, 6);
        let mut output = Vec::new();
        reader.read_to_end(&mut output).unwrap();
        assert_eq!(output, b"hello world!");
    }

    #[test]
    fn skippable_reader_limits_bytes() {
        let mut reader: FrameReader<_, TEST_HEADER_SIZE> =
            FrameReader::new_skippable(Cursor::new(&b"hello world! extra garbage"[..]), 12);
        let mut output = Vec::new();
        reader.read_to_end(&mut output).unwrap();
        assert_eq!(output, b"hello world!");
    }

    #[test]
    fn skippable_reader_truncated_payload() {
        let mut reader: FrameReader<_, TEST_HEADER_SIZE> =
            FrameReader::new_skippable(Cursor::new(&b"short"[..]), 12);
        let mut output = Vec::new();
        let err = reader.read_to_end(&mut output).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn try_read_next_frame_valid() {
        let mut reader = finished_skippable(encode_words(&[SKIPPABLE_FRAME_MAGIC, 4, 100]));
        let next = reader.try_read_next_frame(validate_test_header).unwrap();
        assert_eq!(next, Some(100));
    }

    #[test]
    fn try_read_next_frame_truncated_header() {
        let mut reader = finished_skippable(vec![0x50, 0x2A, 0x4D, 0x18, 0x04]);
        let err = reader.try_read_next_frame(validate_test_header).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn try_read_next_frame_invalid_magic() {
        let mut reader = finished_skippable(encode_words(&[0xDEAD_BEEF, 4, 100]));
        let err = reader.try_read_next_frame(validate_test_header).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn try_read_next_frame_clean_eof() {
        let mut reader = finished_skippable(Vec::new());
        let next = reader.try_read_next_frame(validate_test_header).unwrap();
        assert_eq!(next, None);
    }

    #[test]
    fn empty_reader() {
        let mut reader: FrameReader<Cursor<Vec<u8>>, TEST_HEADER_SIZE> = FrameReader::Empty;
        let mut buf = [0u8; 10];
        assert_eq!(reader.read(&mut buf).unwrap(), 0);
    }
}