use std::collections::HashMap;
use std::io::{Read, Write};
use crate::error::{Error, Result};
/// Chunk size used by readers and writers until `set_chunk_size` is called.
pub const DEFAULT_CHUNK_SIZE: usize = 0x80;
/// Largest chunk size `set_chunk_size` accepts: the maximum value of a 24-bit field.
pub const MAX_CHUNK_SIZE: usize = (1 << 24) - 1;
/// A complete, reassembled chunk-stream message.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly, e.g.
/// when checking that a write/read round trip is lossless.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Message type id carried in fmt 0/1 headers.
    pub msg_type_id: u8,
    /// Message stream id; encoded as 4 little-endian bytes in fmt 0 headers.
    pub msg_stream_id: u32,
    /// Absolute timestamp, after any header deltas have been applied.
    pub timestamp: u32,
    /// Message body.
    pub payload: Vec<u8>,
}
/// Classification of a message stream id, as returned by `Message::stream_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageStreamKind {
    /// Message stream id 0.
    Control,
    /// Non-zero id whose high byte is clear (fits in 24 bits).
    NetStream(u32),
    /// Id with a non-zero high byte; rejected by
    /// `Message::validate_protocol_control_invariants`.
    Reserved(u32),
}
impl Message {
    /// Classifies `msg_stream_id`: `0` is the control stream, ids whose top
    /// byte is clear are net streams, anything else is reserved.
    pub fn stream_kind(&self) -> MessageStreamKind {
        let id = self.msg_stream_id;
        if id == 0 {
            MessageStreamKind::Control
        } else if id >> 24 == 0 {
            MessageStreamKind::NetStream(id)
        } else {
            MessageStreamKind::Reserved(id)
        }
    }

    /// Returns `true` when the message is on message stream 0.
    pub fn is_control_stream(&self) -> bool {
        self.stream_kind() == MessageStreamKind::Control
    }

    /// Checks stream-id invariants for this message.
    ///
    /// # Errors
    ///
    /// Returns `Error::ProtocolViolation` if the stream id sets its reserved
    /// high byte, or if a protocol-control message (type 1..=6) is not on
    /// message stream 0.
    pub fn validate_protocol_control_invariants(&self) -> Result<()> {
        if let MessageStreamKind::Reserved(_) = self.stream_kind() {
            return Err(Error::ProtocolViolation(format!(
                "message stream id {:#010x} sets reserved high byte (spec §4.1: 3-byte field)",
                self.msg_stream_id
            )));
        }
        let is_protocol_control_type = (1..=6).contains(&self.msg_type_id);
        if is_protocol_control_type && self.msg_stream_id != 0 {
            return Err(Error::ProtocolViolation(format!(
                "protocol-control message type {} carries non-zero msg_stream_id {} (spec §5 requires 0)",
                self.msg_type_id, self.msg_stream_id
            )));
        }
        Ok(())
    }
}
/// Per-chunk-stream state kept by `ChunkReader` so that compressed
/// (fmt 1/2/3) headers can be resolved against earlier headers.
#[derive(Default, Debug, Clone)]
struct InState {
    /// Message type id from the latest fmt 0 or fmt 1 header.
    msg_type_id: u8,
    /// Message stream id from the latest fmt 0 header.
    msg_stream_id: u32,
    /// Payload length of the current message, from the latest fmt 0/1 header.
    msg_length: u32,
    /// Absolute timestamp of the current message: the fmt 0 value, or the
    /// previous timestamp plus a delta for fmt 1/2/3 message starts.
    timestamp: u32,
    /// Delta re-applied by a fmt 3 chunk that starts a new message: the fmt 0
    /// timestamp, or the fmt 1/2 delta.
    last_delta: u32,
    /// Whether the latest fmt 0/1/2 header used an extended timestamp; if so,
    /// fmt 3 chunks on this stream also carry a 4-byte extended field.
    last_had_ext_ts: bool,
    /// Payload bytes received so far for the message being reassembled.
    partial: Vec<u8>,
}
/// Reads chunk-framed messages from `stream`, reassembling chunks per chunk
/// stream id and counting received bytes for acknowledgement purposes.
pub struct ChunkReader<R: Read> {
    stream: R,
    /// Maximum payload bytes per incoming chunk.
    chunk_size: usize,
    /// Header and reassembly state keyed by chunk stream id.
    states: HashMap<u32, InState>,
    /// Total bytes consumed from `stream` (wrapping at `u32::MAX`).
    received_bytes: u32,
    /// Acknowledgement window in bytes; `0` disables `ack_due`.
    window_ack_size: u32,
    /// `received_bytes` at the last reported ack or window change.
    last_ack_bytes: u32,
}
impl<R: Read> ChunkReader<R> {
pub fn new(stream: R) -> Self {
Self {
stream,
chunk_size: DEFAULT_CHUNK_SIZE,
states: HashMap::new(),
received_bytes: 0,
window_ack_size: 0,
last_ack_bytes: 0,
}
}
fn read_exact_counted(&mut self, buf: &mut [u8]) -> Result<()> {
self.stream.read_exact(buf)?;
self.received_bytes = self.received_bytes.wrapping_add(buf.len() as u32);
Ok(())
}
pub fn received_bytes(&self) -> u32 {
self.received_bytes
}
pub fn window_ack_size(&self) -> u32 {
self.window_ack_size
}
pub fn set_window_ack_size(&mut self, size: u32) {
self.window_ack_size = size;
self.last_ack_bytes = self.received_bytes;
}
pub fn ack_due(&mut self) -> Option<u32> {
if self.window_ack_size == 0 {
return None;
}
let since = self.received_bytes.wrapping_sub(self.last_ack_bytes);
if since >= self.window_ack_size {
self.last_ack_bytes = self.received_bytes;
Some(self.received_bytes)
} else {
None
}
}
pub fn set_chunk_size(&mut self, size: usize) {
self.chunk_size = size.clamp(1, MAX_CHUNK_SIZE);
}
pub fn chunk_size(&self) -> usize {
self.chunk_size
}
pub fn abort_partial(&mut self, chunk_stream_id: u32) -> bool {
match self.states.get_mut(&chunk_stream_id) {
Some(st) if !st.partial.is_empty() => {
st.partial.clear();
true
}
_ => false,
}
}
pub fn inner_mut(&mut self) -> &mut R {
&mut self.stream
}
pub fn read_message(&mut self) -> Result<Message> {
loop {
let (csid, fmt) = self.read_basic_header()?;
match fmt {
0 => self.read_fmt0_header(csid)?,
1 => self.read_fmt1_header(csid)?,
2 => self.read_fmt2_header(csid)?,
3 => self.read_fmt3_header(csid)?,
_ => unreachable!("fmt is 2 bits"),
}
let take = {
let state = self.states.get(&csid).ok_or_else(|| {
Error::InvalidChunk(format!(
"fmt {fmt} chunk on csid {csid} without prior fmt-0 state"
))
})?;
let need = state.msg_length as usize - state.partial.len();
need.min(self.chunk_size)
};
let mut buf = vec![0u8; take];
self.read_exact_counted(&mut buf)?;
let state = self
.states
.get_mut(&csid)
.expect("csid state present (checked immediately above)");
state.partial.extend_from_slice(&buf);
if state.partial.len() as u32 >= state.msg_length {
let payload = std::mem::take(&mut state.partial);
let msg = Message {
msg_type_id: state.msg_type_id,
msg_stream_id: state.msg_stream_id,
timestamp: state.timestamp,
payload,
};
return Ok(msg);
}
}
}
fn read_basic_header(&mut self) -> Result<(u32, u8)> {
let mut b = [0u8; 1];
self.read_exact_counted(&mut b)?;
let fmt = (b[0] >> 6) & 0x03;
let low = b[0] & 0x3F;
let csid = match low {
0 => {
let mut b1 = [0u8; 1];
self.read_exact_counted(&mut b1)?;
b1[0] as u32 + 64
}
1 => {
let mut b2 = [0u8; 2];
self.read_exact_counted(&mut b2)?;
b2[0] as u32 + (b2[1] as u32) * 256 + 64
}
other => other as u32,
};
Ok((csid, fmt))
}
fn read_u24_be(&mut self) -> Result<u32> {
let mut b = [0u8; 3];
self.read_exact_counted(&mut b)?;
Ok(((b[0] as u32) << 16) | ((b[1] as u32) << 8) | (b[2] as u32))
}
fn read_u32_le_stream_id(&mut self) -> Result<u32> {
let mut b = [0u8; 4];
self.read_exact_counted(&mut b)?;
Ok(u32::from_le_bytes(b))
}
fn read_fmt0_header(&mut self, csid: u32) -> Result<()> {
let mut ts = self.read_u24_be()?;
let len = self.read_u24_be()?;
let mut t = [0u8; 1];
self.read_exact_counted(&mut t)?;
let ty = t[0];
let stream_id = self.read_u32_le_stream_id()?;
let had_ext_ts = ts == 0x00FF_FFFF;
if had_ext_ts {
ts = self.read_u32_be()?;
}
let st = self.states.entry(csid).or_default();
st.partial.clear();
st.msg_type_id = ty;
st.msg_stream_id = stream_id;
st.msg_length = len;
st.timestamp = ts;
st.last_delta = ts;
st.last_had_ext_ts = had_ext_ts;
Ok(())
}
fn read_fmt1_header(&mut self, csid: u32) -> Result<()> {
let mut delta = self.read_u24_be()?;
let len = self.read_u24_be()?;
let mut t = [0u8; 1];
self.read_exact_counted(&mut t)?;
let ty = t[0];
let had_ext_ts = delta == 0x00FF_FFFF;
if had_ext_ts {
delta = self.read_u32_be()?;
}
let st = self
.states
.get_mut(&csid)
.ok_or_else(|| Error::InvalidChunk("fmt 1 without prior fmt 0".into()))?;
st.msg_type_id = ty;
st.msg_length = len;
st.timestamp = st.timestamp.wrapping_add(delta);
st.last_delta = delta;
st.last_had_ext_ts = had_ext_ts;
st.partial.clear();
Ok(())
}
fn read_fmt2_header(&mut self, csid: u32) -> Result<()> {
let mut delta = self.read_u24_be()?;
let had_ext_ts = delta == 0x00FF_FFFF;
if had_ext_ts {
delta = self.read_u32_be()?;
}
let st = self
.states
.get_mut(&csid)
.ok_or_else(|| Error::InvalidChunk("fmt 2 without prior fmt 0/1".into()))?;
st.timestamp = st.timestamp.wrapping_add(delta);
st.last_delta = delta;
st.last_had_ext_ts = had_ext_ts;
st.partial.clear();
Ok(())
}
fn read_fmt3_header(&mut self, csid: u32) -> Result<()> {
let (had_ext_ts, partial_empty, last_delta) = {
let st = self
.states
.get(&csid)
.ok_or_else(|| Error::InvalidChunk("fmt 3 without prior fmt 0/1/2".into()))?;
(st.last_had_ext_ts, st.partial.is_empty(), st.last_delta)
};
if had_ext_ts {
let _dup = self.read_u32_be()?;
}
if partial_empty {
let st = self.states.get_mut(&csid).unwrap();
st.timestamp = st.timestamp.wrapping_add(last_delta);
}
Ok(())
}
fn read_u32_be(&mut self) -> Result<u32> {
let mut b = [0u8; 4];
self.read_exact_counted(&mut b)?;
Ok(u32::from_be_bytes(b))
}
}
/// Per-chunk-stream state kept by `ChunkWriter`: the header values of the
/// last message written on that chunk stream, used to pick a compact header.
#[derive(Default, Debug, Clone)]
struct OutState {
    /// Type id of the last message written.
    msg_type_id: u8,
    /// Message stream id of the last message written.
    msg_stream_id: u32,
    /// Payload length of the last message written.
    msg_length: u32,
    /// Absolute timestamp of the last message written.
    timestamp: u32,
    /// Timestamp delta recorded for the last message: its timestamp when it
    /// was sent with fmt 0, otherwise its delta.
    last_delta: u32,
    /// Whether an extended timestamp was recorded for the last message.
    last_had_ext_ts: bool,
    /// `false` until a message has been written on this chunk stream.
    primed: bool,
}
/// Writes messages to `stream`, splitting them into chunks and compressing
/// chunk headers against the previous message on the same chunk stream.
pub struct ChunkWriter<W: Write> {
    stream: W,
    /// Maximum payload bytes per outgoing chunk.
    chunk_size: usize,
    /// Header-compression state keyed by chunk stream id.
    states: HashMap<u32, OutState>,
}
impl<W: Write> ChunkWriter<W> {
pub fn new(stream: W) -> Self {
Self {
stream,
chunk_size: DEFAULT_CHUNK_SIZE,
states: HashMap::new(),
}
}
pub fn set_chunk_size(&mut self, size: usize) {
self.chunk_size = size.clamp(1, MAX_CHUNK_SIZE);
}
pub fn chunk_size(&self) -> usize {
self.chunk_size
}
pub fn inner_mut(&mut self) -> &mut W {
&mut self.stream
}
pub fn flush(&mut self) -> Result<()> {
self.stream.flush()?;
Ok(())
}
pub fn write_message(&mut self, csid: u32, msg: &Message) -> Result<()> {
let payload_len = msg.payload.len() as u32;
let (prev_primed, prev_stream_id, prev_type_id, prev_length, prev_timestamp) = {
let st = self.states.entry(csid).or_default();
(
st.primed,
st.msg_stream_id,
st.msg_type_id,
st.msg_length,
st.timestamp,
)
};
let fmt = if !prev_primed || prev_stream_id != msg.msg_stream_id {
0
} else if prev_type_id != msg.msg_type_id || prev_length != payload_len {
1
} else if msg.timestamp < prev_timestamp {
0
} else if msg.timestamp == prev_timestamp {
3
} else {
2
};
let delta = msg.timestamp.wrapping_sub(prev_timestamp);
let ext_ts_needed = (fmt == 0 && msg.timestamp >= 0x00FF_FFFF)
|| (fmt != 0 && fmt != 3 && delta >= 0x00FF_FFFF);
let prev_last_had_ext_ts = self
.states
.get(&csid)
.map(|s| s.last_had_ext_ts)
.unwrap_or(false);
let chunk_size = self.chunk_size;
let mut first_chunk_done = false;
let mut cursor = 0usize;
while cursor < msg.payload.len() || !first_chunk_done {
let chunk_fmt = if !first_chunk_done { fmt } else { 3 };
self.write_basic_header(chunk_fmt, csid)?;
match chunk_fmt {
0 => {
let ts_field = if ext_ts_needed {
0x00FF_FFFF
} else {
msg.timestamp
};
self.write_u24_be(ts_field)?;
self.write_u24_be(payload_len)?;
self.stream.write_all(&[msg.msg_type_id])?;
self.stream.write_all(&msg.msg_stream_id.to_le_bytes())?;
if ext_ts_needed {
self.stream.write_all(&msg.timestamp.to_be_bytes())?;
}
}
1 => {
let ts_field = if ext_ts_needed { 0x00FF_FFFF } else { delta };
self.write_u24_be(ts_field)?;
self.write_u24_be(payload_len)?;
self.stream.write_all(&[msg.msg_type_id])?;
if ext_ts_needed {
self.stream.write_all(&msg.timestamp.to_be_bytes())?;
}
}
2 => {
let ts_field = if ext_ts_needed { 0x00FF_FFFF } else { delta };
self.write_u24_be(ts_field)?;
if ext_ts_needed {
self.stream.write_all(&msg.timestamp.to_be_bytes())?;
}
}
3 => {
let ext_repeat = if !first_chunk_done {
ext_ts_needed
} else {
prev_last_had_ext_ts && cursor == 0
};
if ext_repeat {
self.stream.write_all(&msg.timestamp.to_be_bytes())?;
}
}
_ => unreachable!(),
}
let end = (cursor + chunk_size).min(msg.payload.len());
self.stream.write_all(&msg.payload[cursor..end])?;
cursor = end;
first_chunk_done = true;
}
let st = self.states.entry(csid).or_default();
st.msg_type_id = msg.msg_type_id;
st.msg_stream_id = msg.msg_stream_id;
st.msg_length = payload_len;
st.timestamp = msg.timestamp;
st.last_delta = if fmt == 0 { msg.timestamp } else { delta };
st.last_had_ext_ts = ext_ts_needed;
st.primed = true;
Ok(())
}
fn write_basic_header(&mut self, fmt: u8, csid: u32) -> Result<()> {
match csid {
2..=63 => {
self.stream.write_all(&[(fmt << 6) | (csid as u8)])?;
}
64..=319 => {
self.stream.write_all(&[fmt << 6, (csid - 64) as u8])?;
}
320..=65_599 => {
let v = (csid - 64) as u16;
self.stream
.write_all(&[(fmt << 6) | 1, (v & 0xFF) as u8, (v >> 8) as u8])?;
}
other => {
return Err(Error::ProtocolViolation(format!(
"chunk stream id {other} out of range"
)))
}
}
Ok(())
}
fn write_u24_be(&mut self, v: u32) -> Result<()> {
let v = v & 0x00FF_FFFF;
self.stream
.write_all(&[(v >> 16) as u8, (v >> 8) as u8, v as u8])?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[test]
fn chunk_roundtrip_short_message() {
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.write_message(
3,
&Message {
msg_type_id: 20,
msg_stream_id: 0,
timestamp: 12345,
payload: b"hello world".to_vec(),
},
)
.unwrap();
}
let mut r = ChunkReader::new(Cursor::new(&buf));
let msg = r.read_message().unwrap();
assert_eq!(msg.msg_type_id, 20);
assert_eq!(msg.timestamp, 12345);
assert_eq!(msg.payload, b"hello world");
}
#[test]
fn chunk_roundtrip_multi_chunk_message() {
let payload: Vec<u8> = (0..4096u16).map(|i| (i & 0xFF) as u8).collect();
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.set_chunk_size(128);
w.write_message(
3,
&Message {
msg_type_id: 9,
msg_stream_id: 1,
timestamp: 7000,
payload: payload.clone(),
},
)
.unwrap();
}
let mut r = ChunkReader::new(Cursor::new(&buf));
r.set_chunk_size(128);
let msg = r.read_message().unwrap();
assert_eq!(msg.payload, payload);
assert_eq!(msg.msg_type_id, 9);
assert_eq!(msg.timestamp, 7000);
}
#[test]
fn received_bytes_counts_full_wire_size() {
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.write_message(
3,
&Message {
msg_type_id: 20,
msg_stream_id: 0,
timestamp: 1000,
payload: b"abcdef".to_vec(),
},
)
.unwrap();
}
let wire_len = buf.len() as u32;
let mut r = ChunkReader::new(Cursor::new(&buf));
assert_eq!(r.received_bytes(), 0);
let _ = r.read_message().unwrap();
assert_eq!(r.received_bytes(), wire_len);
}
#[test]
fn ack_not_due_without_window() {
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.write_message(
3,
&Message {
msg_type_id: 20,
msg_stream_id: 0,
timestamp: 0,
payload: vec![0u8; 500],
},
)
.unwrap();
}
let mut r = ChunkReader::new(Cursor::new(&buf));
let _ = r.read_message().unwrap();
assert_eq!(r.window_ack_size(), 0);
assert_eq!(r.ack_due(), None);
}
#[test]
fn ack_due_fires_once_per_window() {
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
for ts in [10u32, 20] {
w.write_message(
4,
&Message {
msg_type_id: 8,
msg_stream_id: 1,
timestamp: ts,
payload: vec![0xAB; 200],
},
)
.unwrap();
}
}
let mut r = ChunkReader::new(Cursor::new(&buf));
r.set_window_ack_size(150);
assert_eq!(r.window_ack_size(), 150);
let _ = r.read_message().unwrap();
let first = r.ack_due().expect("first ack due after window crossed");
assert_eq!(first, r.received_bytes());
assert_eq!(r.ack_due(), None, "ack must not re-fire within a window");
let _ = r.read_message().unwrap();
let second = r.ack_due().expect("second ack due after second window");
assert!(second > first);
assert_eq!(second, r.received_bytes());
}
#[test]
fn set_window_rebases_accounting() {
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.write_message(
4,
&Message {
msg_type_id: 8,
msg_stream_id: 1,
timestamp: 0,
payload: vec![0u8; 400],
},
)
.unwrap();
}
let mut r = ChunkReader::new(Cursor::new(&buf));
let _ = r.read_message().unwrap();
r.set_window_ack_size(100);
assert_eq!(r.ack_due(), None);
}
#[test]
fn abort_partial_discards_in_flight_message() {
let payload: Vec<u8> = (0..200u16).map(|i| (i & 0xFF) as u8).collect();
let mut full = Vec::new();
{
let mut w = ChunkWriter::new(&mut full);
w.set_chunk_size(128);
w.write_message(
5,
&Message {
msg_type_id: 9,
msg_stream_id: 1,
timestamp: 1000,
payload: payload.clone(),
},
)
.unwrap();
}
let first_chunk = &full[..12 + 128];
let mut r = ChunkReader::new(Cursor::new(first_chunk));
r.set_chunk_size(128);
let err = r.read_message().unwrap_err();
assert!(matches!(err, Error::Io(_) | Error::UnexpectedEof));
assert!(r.abort_partial(5), "first abort should discard 128 bytes");
assert!(!r.abort_partial(5), "second abort has nothing to discard");
assert!(!r.abort_partial(9));
}
#[test]
fn back_to_back_same_message_uses_fmt3() {
let msg = Message {
msg_type_id: 9,
msg_stream_id: 1,
timestamp: 1000,
payload: vec![0xAA; 32],
};
let mut buf = Vec::new();
{
let mut w = ChunkWriter::new(&mut buf);
w.write_message(5, &msg).unwrap();
w.write_message(5, &msg).unwrap();
}
let first_headers_len = 1 + 11 + 32; assert_eq!(buf[first_headers_len], 0xC5);
}
}