use bytes::{Buf, BufMut, Bytes, BytesMut};
use thiserror::Error;
use crate::CodecKind;
/// Four-byte tag written at the start of every data frame.
pub const FRAME_MAGIC: &[u8; 4] = b"S4F2";
/// Four-byte tag written at the start of every padding record.
pub const PADDING_MAGIC: &[u8; 4] = b"S4P1";
/// Encoded size of a frame header in bytes: the magic, a `u32` codec id, a
/// `u64` original size, a `u64` compressed size and a `u32` CRC.
pub const FRAME_HEADER_BYTES: usize = FRAME_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u32>();
/// Encoded size of a padding header in bytes: the magic followed by a `u64`
/// count of filler bytes.
pub const PADDING_HEADER_BYTES: usize = PADDING_MAGIC.len() + std::mem::size_of::<u64>();
/// 5 MiB. Going by its name, this is the smallest size S3 accepts for a
/// non-final multipart-upload part (NOTE(review): confirm against the S3 limits).
pub const S3_MULTIPART_MIN_PART_BYTES: usize = 5 << 20;
/// Metadata stored in front of each frame's payload.
///
/// `write_frame` serializes these fields after the frame magic as
/// little-endian integers, and `read_frame` parses them back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
/// Codec recorded for the payload; serialized as the `u32` from `CodecKind::id`.
pub codec: CodecKind,
/// Presumably the payload's size in bytes before compression; the code
/// visible here stores and returns it without interpreting it.
pub original_size: u64,
/// Length in bytes of the payload that follows the header.
pub compressed_size: u64,
/// Presumably a CRC-32C checksum; the code visible here neither computes nor
/// verifies it (NOTE(review): confirm which bytes it covers).
pub crc32c: u32,
}
/// Errors produced while decoding frames or padding records.
#[derive(Debug, Error)]
pub enum FrameError {
/// Fewer bytes are available than the header being parsed requires; the
/// field is the number of bytes that were available. `FrameIter` also
/// returns this for a truncated padding header, where the real requirement
/// is smaller than the `FRAME_HEADER_BYTES` quoted in the message.
#[error("frame too short: need at least {FRAME_HEADER_BYTES} bytes, have {0}")]
TooShort(usize),
/// The first four bytes did not match the expected magic.
#[error("bad frame magic: expected {expected:?}, got {got:?}")]
BadMagic { expected: [u8; 4], got: [u8; 4] },
/// The declared length exceeds the bytes left after the header.
/// `FrameIter` also uses this for padding records, in which case
/// `compressed_size` carries the declared padding length.
#[error("frame compressed_size {compressed_size} exceeds remaining buffer {remaining}")]
PayloadTruncated {
compressed_size: u64,
remaining: usize,
},
/// `CodecKind::from_id` did not recognise the codec id read from the header.
#[error("unknown codec id {0} in frame header (decoder out of date?)")]
UnknownCodec(u32),
}
pub fn write_frame(dst: &mut BytesMut, header: FrameHeader, payload: &[u8]) {
debug_assert_eq!(payload.len() as u64, header.compressed_size);
dst.reserve(FRAME_HEADER_BYTES + payload.len());
dst.put_slice(FRAME_MAGIC);
dst.put_u32_le(header.codec.id());
dst.put_u64_le(header.original_size);
dst.put_u64_le(header.compressed_size);
dst.put_u32_le(header.crc32c);
dst.put_slice(payload);
}
/// Grows `dst` to at least `min_total` bytes by appending one padding record.
///
/// The record is [`PADDING_MAGIC`], a little-endian `u64` filler length, and
/// that many zero bytes. Nothing is written when `dst` is already at least
/// `min_total` long.
///
/// A padding header is always written in full, so when the shortfall is
/// smaller than [`PADDING_HEADER_BYTES`] the buffer ends up slightly longer
/// than `min_total`.
pub fn pad_to_minimum(dst: &mut BytesMut, min_total: usize) {
    let shortfall = match min_total.checked_sub(dst.len()) {
        Some(missing) if missing > 0 => missing,
        // Already long enough (or longer): leave the buffer untouched.
        _ => return,
    };

    // The header itself counts toward the target; only the rest is zero fill.
    let zero_fill = shortfall.saturating_sub(PADDING_HEADER_BYTES);

    dst.reserve(PADDING_HEADER_BYTES + zero_fill);
    dst.put_slice(PADDING_MAGIC);
    dst.put_slice(&(zero_fill as u64).to_le_bytes());
    dst.put_bytes(0, zero_fill);
}
/// Decodes the single frame at the start of `input`.
///
/// Returns `(header, payload, rest)`, where `payload` is exactly
/// `header.compressed_size` bytes and `rest` is whatever follows the frame.
/// The checksum in the header is returned as-is; it is not verified here.
///
/// # Errors
///
/// Checked in this order:
/// * [`FrameError::TooShort`] — `input` is shorter than [`FRAME_HEADER_BYTES`].
/// * [`FrameError::BadMagic`] — the first four bytes are not [`FRAME_MAGIC`].
/// * [`FrameError::UnknownCodec`] — `CodecKind::from_id` rejects the codec id.
/// * [`FrameError::PayloadTruncated`] — fewer than `compressed_size` bytes
///   remain after the header.
pub fn read_frame(mut input: Bytes) -> Result<(FrameHeader, Bytes, Bytes), FrameError> {
    if input.len() < FRAME_HEADER_BYTES {
        return Err(FrameError::TooShort(input.len()));
    }

    let mut magic = [0u8; 4];
    magic.copy_from_slice(&input[..4]);
    if &magic != FRAME_MAGIC {
        return Err(FrameError::BadMagic {
            expected: *FRAME_MAGIC,
            got: magic,
        });
    }
    input.advance(magic.len());

    // The length check above guarantees the fixed-size fields are present.
    let codec_id = input.get_u32_le();
    let codec = CodecKind::from_id(codec_id).ok_or(FrameError::UnknownCodec(codec_id))?;
    let original_size = input.get_u64_le();
    let compressed_size = input.get_u64_le();
    let crc32c = input.get_u32_le();

    // Compare in u64: casting the declared size down to usize first would
    // truncate it where usize is narrower than 64 bits, letting an oversized
    // value slip past this check.
    let remaining = input.len();
    if compressed_size > remaining as u64 {
        return Err(FrameError::PayloadTruncated {
            compressed_size,
            remaining,
        });
    }
    // Lossless: compressed_size <= remaining, and remaining is a usize.
    let payload = input.split_to(compressed_size as usize);

    Ok((
        FrameHeader {
            codec,
            original_size,
            compressed_size,
            crc32c,
        },
        payload,
        input,
    ))
}
/// Iterator over the data frames in a buffer; padding records are skipped
/// rather than yielded.
pub struct FrameIter {
// Bytes not yet consumed by the iterator.
rest: Bytes,
// Set once an error has been yielded; every later `next` call returns `None`.
fused: bool,
}
impl FrameIter {
pub fn new(input: Bytes) -> Self {
Self {
rest: input,
fused: false,
}
}
}
/// Yields each data frame as `(header, payload)`, skipping padding records.
///
/// Iteration ends with `None` once the buffer is exhausted. The first decode
/// error is yielded once; after that every call returns `None`.
///
/// Errors: [`FrameError::TooShort`] when fewer than four bytes remain or a
/// padding header is cut off, [`FrameError::PayloadTruncated`] when a padding
/// record declares more filler than remains, plus anything [`read_frame`]
/// reports for a data frame.
impl Iterator for FrameIter {
    type Item = Result<(FrameHeader, Bytes), FrameError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.fused {
            return None;
        }

        // Each pass either skips one padding record or returns.
        while !self.rest.is_empty() {
            let available = self.rest.len();
            if available < PADDING_MAGIC.len() {
                self.fused = true;
                return Some(Err(FrameError::TooShort(available)));
            }

            if !self.rest.starts_with(PADDING_MAGIC) {
                // Not padding: hand the whole remainder to `read_frame`, which
                // validates the frame magic itself.
                return match read_frame(std::mem::take(&mut self.rest)) {
                    Ok((hdr, payload, remainder)) => {
                        self.rest = remainder;
                        Some(Ok((hdr, payload)))
                    }
                    Err(e) => {
                        self.fused = true;
                        Some(Err(e))
                    }
                };
            }

            if available < PADDING_HEADER_BYTES {
                self.fused = true;
                return Some(Err(FrameError::TooShort(available)));
            }
            self.rest.advance(PADDING_MAGIC.len());
            let pad_len = self.rest.get_u64_le();

            // Compare in u64: casting the declared length down to usize first
            // would truncate it where usize is narrower than 64 bits, so the
            // record would be under-skipped and the stream misparsed.
            let remaining = self.rest.len();
            if pad_len > remaining as u64 {
                self.fused = true;
                return Some(Err(FrameError::PayloadTruncated {
                    compressed_size: pad_len,
                    remaining,
                }));
            }
            // Lossless: pad_len <= remaining, and remaining is a usize.
            self.rest.advance(pad_len as usize);
        }

        None
    }
}
// Unit tests for frame encoding, decoding, iteration and padding.
#[cfg(test)]
mod tests {
use super::*;
// A single frame written with `write_frame` decodes back to the same header
// and payload, with nothing left over.
#[test]
fn frame_roundtrip_single() {
let payload = Bytes::from_static(b"hello frame payload");
let header = FrameHeader {
codec: CodecKind::CpuZstd,
original_size: 999,
compressed_size: payload.len() as u64,
crc32c: 0xdead_beef,
};
let mut buf = BytesMut::new();
write_frame(&mut buf, header, &payload);
assert_eq!(buf.len(), FRAME_HEADER_BYTES + payload.len());
let bytes = buf.freeze();
let (got_header, got_payload, rest) = read_frame(bytes).unwrap();
assert_eq!(got_header, header);
assert_eq!(got_payload, payload);
assert!(rest.is_empty());
}
// Five back-to-back frames, each with a different codec and payload length,
// come out of `FrameIter` in order with their per-frame header fields intact.
#[test]
fn frame_iter_walks_all_frames_with_mixed_codecs() {
let codecs = [
CodecKind::Passthrough,
CodecKind::CpuZstd,
CodecKind::NvcompZstd,
CodecKind::NvcompBitcomp,
CodecKind::DietGpuAns,
];
let mut buf = BytesMut::new();
for (i, codec) in codecs.iter().enumerate() {
let payload = vec![i as u8; (i + 1) * 4];
let h = FrameHeader {
codec: *codec,
original_size: 100 + i as u64,
compressed_size: payload.len() as u64,
crc32c: i as u32,
};
write_frame(&mut buf, h, &payload);
}
let total = FrameIter::new(buf.freeze())
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(total.len(), 5);
for (i, (h, payload)) in total.iter().enumerate() {
assert_eq!(h.codec, codecs[i], "codec must be preserved per frame");
assert_eq!(h.original_size, 100 + i as u64);
assert_eq!(h.crc32c, i as u32);
assert_eq!(payload.len(), (i + 1) * 4);
}
}
// A full-length header whose first four bytes are not the frame magic is
// rejected with `BadMagic`.
#[test]
fn frame_bad_magic_rejected() {
let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
buf.put_slice(b"BAD!");
// Field order after the magic: codec id, original size, compressed size, crc.
buf.put_u32_le(0); buf.put_u64_le(0);
buf.put_u64_le(0);
buf.put_u32_le(0);
let err = read_frame(buf.freeze()).unwrap_err();
assert!(matches!(err, FrameError::BadMagic { .. }));
}
// The header declares a 100-byte payload but no payload bytes follow it.
#[test]
fn frame_truncated_rejected() {
let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
buf.put_slice(FRAME_MAGIC);
buf.put_u32_le(CodecKind::CpuZstd.id());
buf.put_u64_le(100);
buf.put_u64_le(100);
buf.put_u32_le(0);
let err = read_frame(buf.freeze()).unwrap_err();
assert!(matches!(err, FrameError::PayloadTruncated { .. }));
}
// Codec id 99 is expected to be unknown to `CodecKind::from_id`, so decoding
// fails with `UnknownCodec(99)`.
#[test]
fn frame_unknown_codec_rejected() {
let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
buf.put_slice(FRAME_MAGIC);
// Field order after the magic: codec id, original size, compressed size, crc.
buf.put_u32_le(99); buf.put_u64_le(0);
buf.put_u64_le(0);
buf.put_u32_le(0);
let err = read_frame(buf.freeze()).unwrap_err();
assert!(matches!(err, FrameError::UnknownCodec(99)));
}
// Nine bytes is shorter than a frame header, so the length check fires
// before the magic is examined.
#[test]
fn frame_too_short_for_header_rejected() {
let buf = Bytes::from_static(b"shortdata");
let err = read_frame(buf).unwrap_err();
assert!(matches!(err, FrameError::TooShort(_)));
}
// A padding record placed between two data frames is skipped by the
// iterator; only the two data frames are yielded.
#[test]
fn padding_skipped_by_iter() {
let mut buf = BytesMut::new();
let p1 = Bytes::from_static(b"first frame");
write_frame(
&mut buf,
FrameHeader {
codec: CodecKind::CpuZstd,
original_size: 11,
compressed_size: p1.len() as u64,
crc32c: 0,
},
&p1,
);
pad_to_minimum(&mut buf, 1024);
assert!(buf.len() >= 1024);
let p2 = Bytes::from_static(b"second frame");
write_frame(
&mut buf,
FrameHeader {
codec: CodecKind::CpuZstd,
original_size: 12,
compressed_size: p2.len() as u64,
crc32c: 0,
},
&p2,
);
let frames: Vec<_> = FrameIter::new(buf.freeze())
.collect::<Result<_, _>>()
.unwrap();
assert_eq!(
frames.len(),
2,
"padding must be skipped, only data yielded"
);
assert_eq!(frames[0].1, p1);
assert_eq!(frames[1].1, p2);
}
// A buffer already longer than the requested minimum is left untouched.
#[test]
fn pad_to_minimum_is_noop_when_already_above() {
let mut buf = BytesMut::new();
buf.extend_from_slice(&[0u8; 1024]);
pad_to_minimum(&mut buf, 100);
assert_eq!(buf.len(), 1024);
}
// Padding an empty-payload frame up to 5,000,000 bytes reaches the target
// without overshooting it by 64 bytes or more.
#[test]
fn pad_to_minimum_grows_to_target() {
let mut buf = BytesMut::new();
write_frame(
&mut buf,
FrameHeader {
codec: CodecKind::Passthrough,
original_size: 0,
compressed_size: 0,
crc32c: 0,
},
&[],
);
let before = buf.len();
pad_to_minimum(&mut buf, 5_000_000);
assert!(buf.len() >= 5_000_000);
assert!(buf.len() < 5_000_000 + 64, "no excessive overshoot");
assert!(buf.len() > before);
}
}