use std::io::{Read, Write};
use crate::{DEFAULT_MAX_FRAME_PAYLOAD, FrameError, PROTOCOL_VERSION};
/// Wire tag byte for a flow-control window-update frame.
pub const FRAME_TYPE_WINDOW: u8 = b'W';
/// Wire tag byte for a JSON payload frame.
pub const FRAME_TYPE_JSON: u8 = b'J';
/// Wire tag byte for a zlib-compressed batch of inner frames.
pub const FRAME_TYPE_COMPRESSED: u8 = b'C';
/// Wire tag byte for an acknowledgement frame.
pub const FRAME_TYPE_ACK: u8 = b'A';
/// Wire tag byte for the legacy `'D'` data frame; the decoder skips over
/// these and surfaces them as [`Frame::Unknown`].
pub const FRAME_TYPE_DATA_LEGACY: u8 = b'D';
/// The frame types this codec can produce.
///
/// Note: the legacy `'D'` data frame is intentionally absent — it can be
/// decoded (as [`Frame::Unknown`]) but never encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// Flow-control window update (`'W'`).
    Window,
    /// JSON payload frame (`'J'`).
    Json,
    /// Zlib-compressed envelope of inner frames (`'C'`).
    Compressed,
    /// Acknowledgement frame (`'A'`).
    Ack,
}
impl FrameType {
#[must_use]
pub const fn wire_byte(self) -> u8 {
match self {
Self::Window => FRAME_TYPE_WINDOW,
Self::Json => FRAME_TYPE_JSON,
Self::Compressed => FRAME_TYPE_COMPRESSED,
Self::Ack => FRAME_TYPE_ACK,
}
}
}
/// A single decoded protocol frame, as produced by [`FrameDecoder::next_frame`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Frame {
    /// Flow-control window update carrying a frame count.
    Window {
        count: u32,
    },
    /// A JSON payload frame with its sequence number and raw payload bytes.
    Json {
        seq: u32,
        payload: Vec<u8>,
    },
    /// A compressed envelope, already inflated; `decompressed` holds the
    /// concatenated inner frame bytes (not yet parsed into frames).
    Compressed {
        decompressed: Vec<u8>,
    },
    /// Acknowledgement of the given sequence number.
    Ack {
        seq: u32,
    },
    /// A frame type the decoder recognizes but does not model (currently
    /// the legacy `'D'` frame). `raw` contains the entire undecoded frame,
    /// including the version and type header bytes.
    Unknown {
        frame_type: u8,
        raw: Vec<u8>,
    },
}
/// Builds a 6-byte window-update frame: version byte, `'W'` tag, then the
/// count as four big-endian bytes.
#[must_use]
pub fn encode_window(count: u32) -> [u8; 6] {
    let [c0, c1, c2, c3] = count.to_be_bytes();
    [PROTOCOL_VERSION, FRAME_TYPE_WINDOW, c0, c1, c2, c3]
}
/// Builds a 6-byte acknowledgement frame: version byte, `'A'` tag, then the
/// sequence number as four big-endian bytes.
#[must_use]
pub fn encode_ack(seq: u32) -> [u8; 6] {
    let [s0, s1, s2, s3] = seq.to_be_bytes();
    [PROTOCOL_VERSION, FRAME_TYPE_ACK, s0, s1, s2, s3]
}
/// Builds a JSON frame: version byte, `'J'` tag, big-endian `seq`,
/// big-endian payload length, then the payload bytes.
///
/// # Panics
///
/// Panics if `payload.len()` exceeds `u32::MAX`. The previous
/// `unwrap_or(u32::MAX)` silently clamped the length *field* while still
/// appending the full payload, emitting a corrupt frame whose declared
/// length disagreed with the bytes actually written; failing loudly on this
/// invariant violation matches `encode_compressed`, which also refuses
/// lengths that do not fit in the `u32` field.
#[must_use]
pub fn encode_json_frame(seq: u32, payload: &[u8]) -> Vec<u8> {
    let len = u32::try_from(payload.len())
        .expect("JSON frame payload length must fit in a u32 length field");
    // Header is 1 (version) + 1 (type) + 4 (seq) + 4 (len) = 10 bytes.
    let mut out = Vec::with_capacity(10 + payload.len());
    out.push(PROTOCOL_VERSION);
    out.push(FRAME_TYPE_JSON);
    out.extend_from_slice(&seq.to_be_bytes());
    out.extend_from_slice(&len.to_be_bytes());
    out.extend_from_slice(payload);
    out
}
pub fn encode_compressed(level: u32, inner_frames: &[u8]) -> Result<Vec<u8>, FrameError> {
use flate2::Compression;
use flate2::write::ZlibEncoder;
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(level));
encoder
.write_all(inner_frames)
.map_err(|e| FrameError::Compression(e.to_string()))?;
let compressed = encoder
.finish()
.map_err(|e| FrameError::Compression(e.to_string()))?;
let len = u32::try_from(compressed.len()).map_err(|_| FrameError::PayloadTooLarge {
requested: compressed.len(),
limit: u32::MAX as usize,
})?;
let mut out = Vec::with_capacity(6 + compressed.len());
out.push(PROTOCOL_VERSION);
out.push(FRAME_TYPE_COMPRESSED);
out.extend_from_slice(&len.to_be_bytes());
out.extend_from_slice(&compressed);
Ok(out)
}
/// Incremental frame decoder: feed it raw transport bytes with `feed` and
/// pull complete frames out with `next_frame`. Partial frames stay
/// buffered until more bytes arrive.
#[derive(Debug)]
pub struct FrameDecoder {
    buf: Vec<u8>,             // accumulated raw bytes, decoded and undecoded
    read_pos: usize,          // index into `buf` of the next unconsumed byte
    max_frame_payload: usize, // cap on any declared payload/segment length
}
impl Default for FrameDecoder {
    /// Equivalent to [`FrameDecoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl FrameDecoder {
    /// Creates a decoder using the crate-default payload cap.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_max_frame_payload(DEFAULT_MAX_FRAME_PAYLOAD)
    }

    /// Creates a decoder that rejects any declared payload length larger
    /// than `max_frame_payload`.
    #[must_use]
    pub const fn with_max_frame_payload(max_frame_payload: usize) -> Self {
        Self {
            buf: Vec::new(),
            read_pos: 0,
            max_frame_payload,
        }
    }

    /// Appends raw transport bytes to the internal buffer.
    pub fn feed(&mut self, bytes: &[u8]) {
        // Compact only once at least half the buffer has been consumed,
        // amortizing the O(n) drain instead of shifting on every frame.
        if self.read_pos > 0 && self.read_pos >= self.buf.len() / 2 {
            self.buf.drain(..self.read_pos);
            self.read_pos = 0;
        }
        self.buf.extend_from_slice(bytes);
    }

    /// Number of buffered bytes not yet consumed by a decoded frame.
    #[must_use]
    pub const fn pending(&self) -> usize {
        self.buf.len() - self.read_pos
    }

    /// Attempts to decode the next complete frame.
    ///
    /// Returns `Ok(None)` when the buffer holds only a partial frame (feed
    /// more bytes and retry), `Ok(Some(frame))` when a frame was consumed,
    /// or an error for an unsupported version byte, an unknown frame type,
    /// an over-limit payload, or a decompression failure. On error the
    /// read position is not advanced.
    pub fn next_frame(&mut self) -> Result<Option<Frame>, FrameError> {
        let avail = self.pending();
        if avail < 2 {
            return Ok(None);
        }
        let header = &self.buf[self.read_pos..self.read_pos + 2];
        if header[0] != PROTOCOL_VERSION {
            return Err(FrameError::UnsupportedVersion(header[0]));
        }
        let frame_type = header[1];
        match frame_type {
            FRAME_TYPE_WINDOW => Ok(self.try_decode_window()),
            FRAME_TYPE_ACK => Ok(self.try_decode_ack()),
            FRAME_TYPE_JSON => self.try_decode_json(),
            FRAME_TYPE_COMPRESSED => self.try_decode_compressed(),
            FRAME_TYPE_DATA_LEGACY => self.try_decode_unknown_with_seq_count(b'D'),
            other => Err(FrameError::UnknownFrameType(other)),
        }
    }

    /// Copies `M` bytes at `offset` (relative to `read_pos`) without
    /// consuming them, or `None` if that many bytes are not buffered yet.
    fn read_at<const M: usize>(&self, offset: usize) -> Option<[u8; M]> {
        let start = self.read_pos + offset;
        if self.buf.len() < start + M {
            return None;
        }
        let mut out = [0u8; M];
        out.copy_from_slice(&self.buf[start..start + M]);
        Some(out)
    }

    /// Decodes a 6-byte window frame (header + big-endian count), or
    /// returns `None` if it is not fully buffered yet.
    fn try_decode_window(&mut self) -> Option<Frame> {
        if self.pending() < 6 {
            return None;
        }
        let count = u32::from_be_bytes(
            self.read_at::<4>(2)
                .expect("just verified ≥ 6 bytes pending"),
        );
        self.read_pos += 6;
        Some(Frame::Window { count })
    }

    /// Decodes a 6-byte ack frame (header + big-endian seq), or returns
    /// `None` if it is not fully buffered yet.
    fn try_decode_ack(&mut self) -> Option<Frame> {
        if self.pending() < 6 {
            return None;
        }
        let seq = u32::from_be_bytes(
            self.read_at::<4>(2)
                .expect("just verified ≥ 6 bytes pending"),
        );
        self.read_pos += 6;
        Some(Frame::Ack { seq })
    }

    /// Decodes a JSON frame: 10-byte header (version, type, seq, len)
    /// followed by `len` payload bytes. Enforces `max_frame_payload`
    /// before waiting for the payload, so an attacker cannot make us
    /// buffer an oversized frame.
    fn try_decode_json(&mut self) -> Result<Option<Frame>, FrameError> {
        if self.pending() < 10 {
            return Ok(None);
        }
        let seq = u32::from_be_bytes(self.read_at::<4>(2).expect("≥ 10 pending"));
        let len_raw = u32::from_be_bytes(self.read_at::<4>(6).expect("≥ 10 pending"));
        let len = len_raw as usize;
        if len > self.max_frame_payload {
            return Err(FrameError::PayloadTooLarge {
                requested: len,
                limit: self.max_frame_payload,
            });
        }
        if self.pending() < 10 + len {
            return Ok(None);
        }
        let start = self.read_pos + 10;
        let payload = self.buf[start..start + len].to_vec();
        self.read_pos += 10 + len;
        Ok(Some(Frame::Json { seq, payload }))
    }

    /// Decodes a compressed frame: 6-byte header (version, type, len)
    /// followed by `len` zlib bytes, which are inflated under the
    /// `max_frame_payload` cap via `decompress_capped`.
    fn try_decode_compressed(&mut self) -> Result<Option<Frame>, FrameError> {
        if self.pending() < 6 {
            return Ok(None);
        }
        let len_raw = u32::from_be_bytes(self.read_at::<4>(2).expect("≥ 6 pending"));
        let len = len_raw as usize;
        if len > self.max_frame_payload {
            return Err(FrameError::PayloadTooLarge {
                requested: len,
                limit: self.max_frame_payload,
            });
        }
        if self.pending() < 6 + len {
            return Ok(None);
        }
        let start = self.read_pos + 6;
        let compressed = &self.buf[start..start + len];
        let decompressed = decompress_capped(compressed, self.max_frame_payload)?;
        // Only advance past the frame once decompression succeeded, so an
        // error leaves the buffer untouched.
        self.read_pos += 6 + len;
        Ok(Some(Frame::Compressed { decompressed }))
    }

    /// Skips over a legacy frame and returns it as [`Frame::Unknown`].
    ///
    /// Wire layout as parsed here: 2-byte header, 4 bytes at offset 2 that
    /// are skipped (presumably a sequence number — not interpreted), a
    /// big-endian pair count at offset 6, then `pair_count` pairs of
    /// (u32 length + bytes, u32 length + bytes). Each declared length is
    /// individually capped at `max_frame_payload`.
    fn try_decode_unknown_with_seq_count(
        &mut self,
        type_byte: u8,
    ) -> Result<Option<Frame>, FrameError> {
        if self.pending() < 10 {
            return Ok(None);
        }
        let pair_count = u32::from_be_bytes(self.read_at::<4>(6).expect("≥ 10 pending")) as usize;
        // `cursor` tracks the frame length discovered so far, relative to
        // `read_pos`; every read below is bounds-checked against `pending`.
        let mut cursor = 10;
        for _ in 0..pair_count {
            if self.pending() < cursor + 4 {
                return Ok(None);
            }
            let key_len = u32::from_be_bytes(
                self.read_at::<4>(cursor)
                    .expect("just bounded by pending check"),
            ) as usize;
            if key_len > self.max_frame_payload {
                return Err(FrameError::PayloadTooLarge {
                    requested: key_len,
                    limit: self.max_frame_payload,
                });
            }
            cursor += 4 + key_len;
            if self.pending() < cursor + 4 {
                return Ok(None);
            }
            let val_len = u32::from_be_bytes(
                self.read_at::<4>(cursor)
                    .expect("just bounded by pending check"),
            ) as usize;
            if val_len > self.max_frame_payload {
                return Err(FrameError::PayloadTooLarge {
                    requested: val_len,
                    limit: self.max_frame_payload,
                });
            }
            cursor += 4 + val_len;
        }
        if self.pending() < cursor {
            return Ok(None);
        }
        // Hand the caller the whole raw frame, header included.
        let raw = self.buf[self.read_pos..self.read_pos + cursor].to_vec();
        self.read_pos += cursor;
        Ok(Some(Frame::Unknown {
            frame_type: type_byte,
            raw,
        }))
    }
}
/// Inflates zlib-compressed data, refusing to produce more than `limit`
/// output bytes (zlib-bomb protection).
fn decompress_capped(compressed: &[u8], limit: usize) -> Result<Vec<u8>, FrameError> {
    use flate2::read::ZlibDecoder;
    // Read at most limit + 1 bytes: one extra so an over-limit stream is
    // detectable without ever inflating an unbounded amount of data.
    let cap = u64::try_from(limit).unwrap_or(u64::MAX).saturating_add(1);
    let mut reader = ZlibDecoder::new(compressed).take(cap);
    let mut inflated = Vec::new();
    reader
        .read_to_end(&mut inflated)
        .map_err(|e| FrameError::Decompression(e.to_string()))?;
    if inflated.len() > limit {
        Err(FrameError::DecompressedTooLarge { limit })
    } else {
        Ok(inflated)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // Wire layout of a window frame: b'2' version, b'W' tag, BE count.
    #[test]
    fn encode_window_layout() {
        let bytes = encode_window(42);
        assert_eq!(bytes[0], b'2');
        assert_eq!(bytes[1], b'W');
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            42
        );
    }

    // Wire layout of an ack frame: b'2' version, b'A' tag, BE seq.
    #[test]
    fn encode_ack_layout() {
        let bytes = encode_ack(7);
        assert_eq!(bytes[0], b'2');
        assert_eq!(bytes[1], b'A');
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            7
        );
    }

    // Wire layout of a JSON frame: header, BE seq, BE length, payload.
    #[test]
    fn encode_json_frame_layout() {
        let bytes = encode_json_frame(13, b"hello");
        assert_eq!(&bytes[..2], b"2J");
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            13
        );
        assert_eq!(
            u32::from_be_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
            5
        );
        assert_eq!(&bytes[10..], b"hello");
    }

    #[test]
    fn decode_window_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_window(123));
        let f = d.next_frame().unwrap().unwrap();
        assert_eq!(f, Frame::Window { count: 123 });
        // Buffer is fully consumed: no second frame.
        assert!(d.next_frame().unwrap().is_none());
    }

    #[test]
    fn decode_ack_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_ack(987_654));
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 987_654 }));
    }

    #[test]
    fn decode_json_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_json_frame(1, br#"{"k":"v"}"#));
        let f = d.next_frame().unwrap().unwrap();
        let Frame::Json { seq, payload } = f else {
            panic!("expected Json")
        };
        assert_eq!(seq, 1);
        assert_eq!(payload, br#"{"k":"v"}"#);
    }

    // Several frames fed in one call must come out one at a time, in order.
    #[test]
    fn decode_handles_concatenated_frames() {
        let mut d = FrameDecoder::new();
        let mut feed = Vec::new();
        feed.extend_from_slice(&encode_window(2));
        feed.extend_from_slice(&encode_json_frame(1, b"a"));
        feed.extend_from_slice(&encode_json_frame(2, b"bb"));
        feed.extend_from_slice(&encode_ack(2));
        d.feed(&feed);
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Window { count: 2 }));
        let Some(Frame::Json { seq: 1, payload }) = d.next_frame().unwrap() else {
            panic!()
        };
        assert_eq!(payload, b"a");
        let Some(Frame::Json { seq: 2, payload }) = d.next_frame().unwrap() else {
            panic!()
        };
        assert_eq!(payload, b"bb");
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 2 }));
        assert!(d.next_frame().unwrap().is_none());
    }

    // A partial frame must yield Ok(None) until the final byte arrives.
    #[test]
    fn decode_handles_byte_at_a_time_feeds() {
        let mut d = FrameDecoder::new();
        let frame = encode_json_frame(5, b"abcdefgh");
        for byte in &frame {
            assert!(d.next_frame().unwrap().is_none());
            d.feed(std::slice::from_ref(byte));
        }
        let Frame::Json { seq, payload } = d.next_frame().unwrap().unwrap() else {
            panic!()
        };
        assert_eq!(seq, 5);
        assert_eq!(payload, b"abcdefgh");
    }

    // Compressed envelope round-trips back to the original inner bytes.
    #[test]
    fn decode_compressed_round_trip() {
        let inner = [
            encode_json_frame(1, b"hello").as_slice(),
            encode_json_frame(2, b"world").as_slice(),
        ]
        .concat();
        let outer = encode_compressed(6, &inner).unwrap();
        let mut d = FrameDecoder::new();
        d.feed(&outer);
        let Frame::Compressed { decompressed } = d.next_frame().unwrap().unwrap() else {
            panic!()
        };
        assert_eq!(decompressed, inner);
    }

    #[test]
    fn decode_rejects_bad_version() {
        let mut d = FrameDecoder::new();
        d.feed(&[b'1', b'W', 0, 0, 0, 1]);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::UnsupportedVersion(b'1'))
        ));
    }

    #[test]
    fn decode_rejects_unknown_frame_type() {
        let mut d = FrameDecoder::new();
        d.feed(&[b'2', b'Z', 0, 0, 0, 1]);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::UnknownFrameType(b'Z'))
        ));
    }

    // Declared length (100) over the cap (16) must error before the
    // payload even arrives.
    #[test]
    fn decode_caps_oversize_json_payload() {
        let mut d = FrameDecoder::with_max_frame_payload(16);
        let mut buf = vec![b'2', b'J', 0, 0, 0, 1];
        buf.extend_from_slice(&100u32.to_be_bytes());
        d.feed(&buf);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::PayloadTooLarge { .. })
        ));
    }

    // 64 KiB of zeros compresses tiny but must not inflate past the 1 KiB
    // cap; either size-related error is acceptable.
    #[test]
    fn decode_caps_zlib_bomb() {
        let original = vec![0u8; 1024 * 64];
        let frame = encode_compressed(9, &original).unwrap();
        let mut d = FrameDecoder::with_max_frame_payload(1024);
        d.feed(&frame);
        match d.next_frame() {
            Err(FrameError::DecompressedTooLarge { .. } | FrameError::PayloadTooLarge { .. }) => {}
            other => panic!("expected size-related error, got {other:?}"),
        }
    }

    // A legacy 'D' frame (seq, pair count, then one key/value pair) is
    // surfaced as Unknown and the decoder resumes at the frame after it.
    #[test]
    fn legacy_d_frame_is_decoded_as_unknown_and_advances() {
        let mut frame = Vec::new();
        frame.push(b'2');
        frame.push(b'D');
        frame.extend_from_slice(&5u32.to_be_bytes());
        frame.extend_from_slice(&1u32.to_be_bytes());
        frame.extend_from_slice(&3u32.to_be_bytes());
        frame.extend_from_slice(b"foo");
        frame.extend_from_slice(&3u32.to_be_bytes());
        frame.extend_from_slice(b"bar");
        frame.extend_from_slice(&encode_ack(5));
        let mut d = FrameDecoder::new();
        d.feed(&frame);
        let f = d.next_frame().unwrap().unwrap();
        let Frame::Unknown { frame_type, raw } = f else {
            panic!()
        };
        assert_eq!(frame_type, b'D');
        assert_eq!(&raw[..2], b"2D");
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 5 }));
    }

    // feed()'s compaction should keep the buffer from growing without
    // bound when frames are consumed as fast as they arrive.
    #[test]
    fn decoder_compacts_after_consuming_half() {
        let mut d = FrameDecoder::new();
        for _ in 0..32 {
            d.feed(&encode_ack(1));
            let _ = d.next_frame().unwrap();
        }
        assert!(d.buf.capacity() < 1024, "buf cap = {}", d.buf.capacity());
    }

    proptest! {
        #[test]
        fn prop_json_frame_round_trip(seq: u32, payload: Vec<u8>) {
            let bytes = encode_json_frame(seq, &payload);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            let frame = d.next_frame().unwrap().unwrap();
            let Frame::Json { seq: got_seq, payload: got_payload } = frame else {
                panic!("expected Json")
            };
            prop_assert_eq!(got_seq, seq);
            prop_assert_eq!(got_payload, payload);
            prop_assert!(d.next_frame().unwrap().is_none());
        }
        #[test]
        fn prop_window_round_trip(count: u32) {
            let bytes = encode_window(count);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            prop_assert_eq!(d.next_frame().unwrap(), Some(Frame::Window { count }));
        }
        #[test]
        fn prop_ack_round_trip(seq: u32) {
            let bytes = encode_ack(seq);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            prop_assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq }));
        }
        #[test]
        fn prop_compressed_round_trip(payloads in proptest::collection::vec(any::<Vec<u8>>(), 1..16)) {
            let mut inner = Vec::new();
            for (i, p) in payloads.iter().enumerate() {
                let seq = u32::try_from(i + 1).unwrap_or(u32::MAX);
                inner.extend_from_slice(&encode_json_frame(seq, p));
            }
            let outer = encode_compressed(3, &inner).unwrap();
            let mut d = FrameDecoder::new();
            d.feed(&outer);
            let Some(Frame::Compressed { decompressed }) = d.next_frame().unwrap() else {
                panic!()
            };
            prop_assert_eq!(decompressed, inner);
        }
        // Fuzz: arbitrary bytes must never panic the decoder; loop is
        // bounded so a pathological input cannot hang the test.
        #[test]
        fn prop_decoder_does_not_panic(bytes in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let mut d = FrameDecoder::with_max_frame_payload(8 * 1024);
            d.feed(&bytes);
            for _ in 0..1024 {
                match d.next_frame() {
                    Ok(Some(_)) => {}
                    Ok(None) | Err(_) => break,
                }
            }
        }
    }
}