use crate::error::WireError;
use crate::frame_type::FrameType;
use crate::varint;
/// Bit 7 of the first header byte: mirrors `FrameFlags::control`.
const FLAG_CONTROL: u8 = 1 << 7;
/// Bit 6 of the first header byte: mirrors `FrameFlags::priority`.
const FLAG_PRIORITY: u8 = 1 << 6;
/// Bit 5 of the first header byte: mirrors `FrameFlags::encrypted`.
const FLAG_ENCRYPTED: u8 = 1 << 5;
/// Bit 4 of the first header byte. Nothing in this file sets or inspects it.
const _FLAG_RESERVED: u8 = 1 << 4;
/// Low nibble of the first header byte, which carries the frame type.
const TYPE_MASK: u8 = 0b0000_1111;
/// Boolean flags carried in the upper bits of the first header byte.
///
/// The `Default` value has every flag cleared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FrameFlags {
/// Stored in bit `0b1000_0000` of the first header byte.
pub control: bool,
/// Stored in bit `0b0100_0000` of the first header byte.
pub priority: bool,
/// Stored in bit `0b0010_0000` of the first header byte.
pub encrypted: bool,
}
impl FrameFlags {
    /// Packs the three flags into their bit positions within a header byte.
    ///
    /// Only the `FLAG_CONTROL`, `FLAG_PRIORITY` and `FLAG_ENCRYPTED` bits can
    /// be set in the result; every other bit is zero.
    #[inline]
    const fn to_bits(self) -> u8 {
        let control = if self.control { FLAG_CONTROL } else { 0 };
        let priority = if self.priority { FLAG_PRIORITY } else { 0 };
        let encrypted = if self.encrypted { FLAG_ENCRYPTED } else { 0 };
        control | priority | encrypted
    }

    /// Extracts the three flags from a header byte.
    ///
    /// Bits other than the three flag bits are ignored, so the frame-type
    /// nibble and the reserved bit may hold any value.
    #[inline]
    const fn from_bits(byte: u8) -> Self {
        let control = (byte & FLAG_CONTROL) != 0;
        let priority = (byte & FLAG_PRIORITY) != 0;
        let encrypted = (byte & FLAG_ENCRYPTED) != 0;
        Self { control, priority, encrypted }
    }
}
/// Decoded form of a frame header.
///
/// On the wire the header is one byte combining the flags and the frame type,
/// followed by three varints: channel id, stream id and payload length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
/// Frame type; its `u8` value occupies the low four bits of the first byte.
pub frame_type: FrameType,
/// Flags stored in the upper bits of the first byte.
pub flags: FrameFlags,
/// Channel identifier, widened to `u64` and written as the first varint.
pub channel_id: u32,
/// Stream identifier, written as the second varint.
pub stream_id: u64,
/// Written as the third varint.
/// Presumably the byte length of the payload following the header; confirm
/// against callers.
pub payload_len: u64,
}
impl FrameHeader {
    /// Creates a `FrameType::Data` header for `channel_id` with the given
    /// `payload_len`. All flags are cleared and the stream id is `0`.
    #[inline]
    pub const fn data(channel_id: u32, payload_len: u64) -> Self {
        let flags = FrameFlags { control: false, priority: false, encrypted: false };
        Self { frame_type: FrameType::Data, flags, channel_id, stream_id: 0, payload_len }
    }

    /// Creates a header of the given `frame_type` with only the control flag
    /// set and with channel id, stream id and payload length all `0`.
    #[inline]
    pub const fn control(frame_type: FrameType) -> Self {
        let flags = FrameFlags { control: true, priority: false, encrypted: false };
        Self { frame_type, flags, channel_id: 0, stream_id: 0, payload_len: 0 }
    }

    /// Returns a copy of this header with the encrypted flag set.
    #[inline]
    pub const fn with_encrypted(self) -> Self {
        Self { flags: FrameFlags { encrypted: true, ..self.flags }, ..self }
    }

    /// Returns a copy of this header with the priority flag set.
    #[inline]
    pub const fn with_priority(self) -> Self {
        Self { flags: FrameFlags { priority: true, ..self.flags }, ..self }
    }

    /// Returns a copy of this header with `stream_id` replaced.
    #[inline]
    pub const fn with_stream(self, stream_id: u64) -> Self {
        Self { stream_id, ..self }
    }

    /// Number of bytes `encode` writes for this header: one flags/type byte
    /// plus the varint lengths of channel id, stream id and payload length.
    #[inline]
    pub const fn encoded_len(&self) -> usize {
        let channel = varint::encoded_len(self.channel_id as u64);
        let stream = varint::encoded_len(self.stream_id);
        let payload = varint::encoded_len(self.payload_len);
        1 + channel + stream + payload
    }

    /// Serializes the header into the front of `buf`.
    ///
    /// Returns the number of bytes written, which equals `encoded_len()`.
    ///
    /// # Errors
    ///
    /// Returns `WireError::BufferTooSmall` (before touching `buf`) when `buf`
    /// is shorter than `encoded_len()`. Any error produced by
    /// `varint::encode` is propagated.
    pub fn encode(&self, buf: &mut [u8]) -> Result<usize, WireError> {
        let needed = self.encoded_len();
        let available = buf.len();
        if available < needed {
            return Err(WireError::BufferTooSmall { needed, available });
        }

        // Byte 0: flag bits on top, frame type in the low nibble.
        buf[0] = self.flags.to_bits() | ((self.frame_type as u8) & TYPE_MASK);

        // The three varint fields follow back to back, in wire order.
        let mut written = 1;
        for value in [u64::from(self.channel_id), self.stream_id, self.payload_len] {
            written += varint::encode(value, &mut buf[written..])?;
        }

        debug_assert_eq!(written, needed);
        Ok(written)
    }

    /// Parses a header from the front of `buf`.
    ///
    /// On success returns the header together with the number of bytes it
    /// occupied in `buf`.
    ///
    /// # Errors
    ///
    /// * `WireError::Incomplete` when `buf` is empty, or when a varint field
    ///   runs past the end of `buf`. In the latter case `needed_min` is
    ///   reported relative to the start of `buf` and `available` is
    ///   `buf.len()`.
    /// * Whatever `FrameType::from_u8` returns for an unrecognized type
    ///   nibble.
    /// * Any other `varint::decode` error, unchanged.
    pub fn decode(buf: &[u8]) -> Result<(Self, usize), WireError> {
        // Reads one varint starting at `buf[at..]` and returns its value along
        // with the offset just past it. `Incomplete` errors are rebased so
        // they describe the whole header buffer instead of the sub-slice.
        fn field(buf: &[u8], at: usize) -> Result<(u64, usize), WireError> {
            match varint::decode(&buf[at..]) {
                Ok((value, len)) => Ok((value, at + len)),
                Err(WireError::Incomplete { needed_min, .. }) => Err(WireError::Incomplete {
                    needed_min: at + needed_min,
                    available: buf.len(),
                }),
                Err(other) => Err(other),
            }
        }

        let byte0 = match buf.first() {
            Some(&byte) => byte,
            None => return Err(WireError::Incomplete { needed_min: 4, available: 0 }),
        };
        let flags = FrameFlags::from_bits(byte0);
        let frame_type = FrameType::from_u8(byte0 & TYPE_MASK)?;

        let (channel_raw, offset) = field(buf, 1)?;
        let (stream_id, offset) = field(buf, offset)?;
        let (payload_len, offset) = field(buf, offset)?;

        // NOTE(review): a decoded channel id above `u32::MAX` is truncated by
        // this cast rather than rejected — confirm this is intended.
        let header = Self {
            frame_type,
            flags,
            channel_id: channel_raw as u32,
            stream_id,
            payload_len,
        };
        Ok((header, offset))
    }
}
impl core::fmt::Display for FrameHeader {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"[{}] ch={} stream={} len={} flags=[{}{}{}]",
self.frame_type,
self.channel_id,
self.stream_id,
self.payload_len,
if self.flags.control { "C" } else { "" },
if self.flags.priority { "P" } else { "" },
if self.flags.encrypted { "E" } else { "" },
)
}
}
#[cfg(test)]
mod tests {
    extern crate alloc;

    use alloc::format;

    use super::*;

    /// Encodes `header` into a 32-byte scratch buffer and decodes it back.
    ///
    /// Returns the scratch buffer, the encoded length, the decoded header and
    /// the number of bytes the decoder consumed.
    fn roundtrip(header: &FrameHeader) -> ([u8; 32], usize, FrameHeader, usize) {
        let mut scratch = [0u8; 32];
        let written = header.encode(&mut scratch).unwrap();
        let (decoded, consumed) = FrameHeader::decode(&scratch[..written]).unwrap();
        (scratch, written, decoded, consumed)
    }

    /// Builds a flag-less header of the given type on channel 0, stream 0.
    fn bare(frame_type: FrameType, payload_len: u64) -> FrameHeader {
        FrameHeader {
            frame_type,
            flags: FrameFlags::default(),
            channel_id: 0,
            stream_id: 0,
            payload_len,
        }
    }

    /// An all-zero data header is four bytes long with an all-zero first byte.
    #[test]
    fn minimal_data_header() {
        let header = FrameHeader::data(0, 0);
        let (bytes, written, decoded, consumed) = roundtrip(&header);

        assert_eq!(written, 4, "minimum header should be 4 bytes");
        assert_eq!(bytes[0] & TYPE_MASK, 0x00);
        assert_eq!(bytes[0] & 0xF0, 0x00);
        assert_eq!(decoded, header);
        assert_eq!(consumed, written);
    }

    /// Priority and encrypted flags land in the expected bits of byte 0.
    #[test]
    fn data_with_flags() {
        let header = FrameHeader::data(1, 100)
            .with_encrypted()
            .with_priority()
            .with_stream(42);
        let (bytes, written, decoded, consumed) = roundtrip(&header);

        assert_eq!(bytes[0], 0x60);
        assert_eq!(decoded, header);
        assert_eq!(consumed, written);
    }

    /// A control ping sets the control bit next to the ping type nibble.
    #[test]
    fn control_ping() {
        let header = FrameHeader::control(FrameType::Ping);
        let (bytes, written, decoded, _) = roundtrip(&header);

        assert_eq!(bytes[0], 0x84);
        assert_eq!(written, 4);
        assert!(decoded.flags.control);
        assert_eq!(decoded.frame_type, FrameType::Ping);
    }

    /// Every handshake type survives a round trip and is still a handshake.
    #[test]
    fn handshake_type() {
        let handshakes = [
            FrameType::HandshakeInit,
            FrameType::HandshakeReply,
            FrameType::HandshakeComplete,
        ];
        for ft in handshakes {
            let (_, _, decoded, _) = roundtrip(&bare(ft, 128));
            assert_eq!(decoded.frame_type, ft);
            assert!(decoded.frame_type.is_handshake());
        }
    }

    /// Multi-byte varints in all three numeric fields round-trip unchanged.
    #[test]
    fn large_values() {
        let header = FrameHeader {
            frame_type: FrameType::FileChunk,
            flags: FrameFlags { control: false, priority: true, encrypted: true },
            channel_id: 1_000_000,
            stream_id: 9_999_999_999,
            payload_len: 4_294_967_296,
        };
        let (_, written, decoded, consumed) = roundtrip(&header);

        assert_eq!(decoded, header);
        assert_eq!(consumed, written);
    }

    /// `encoded_len` agrees with what `encode` actually writes.
    #[test]
    fn encoded_len_accurate() {
        let header = FrameHeader::data(42, 100).with_stream(12345);
        let mut scratch = [0u8; 32];
        assert_eq!(header.encoded_len(), header.encode(&mut scratch).unwrap());
    }

    /// A header cut off after the channel id reports `Incomplete`.
    #[test]
    fn incomplete_decode() {
        let truncated = [0x00, 0x00];
        let outcome = FrameHeader::decode(&truncated);
        if !matches!(outcome, Err(WireError::Incomplete { .. })) {
            panic!("expected Incomplete, got {outcome:?}");
        }
    }

    /// A type nibble of 0x0F is rejected as an unknown frame type.
    #[test]
    fn unknown_type_rejected() {
        let bytes = [0x0F, 0x00, 0x00, 0x00];
        let outcome = FrameHeader::decode(&bytes);
        assert!(matches!(outcome, Err(WireError::UnknownFrameType(0x0F))));
    }

    /// The `Display` output mentions the type, channel, length and flag letter.
    #[test]
    fn display_format() {
        let header = FrameHeader::data(5, 42).with_encrypted();
        let rendered = format!("{header}");
        for needle in ["DATA", "ch=5", "len=42", "E"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Each listed frame type round-trips through encode/decode.
    #[test]
    fn all_frame_types_encode_decode() {
        let types = [
            FrameType::Data,
            FrameType::HandshakeInit,
            FrameType::HandshakeReply,
            FrameType::HandshakeComplete,
            FrameType::Ping,
            FrameType::Pong,
            FrameType::Close,
            FrameType::FileHeader,
            FrameType::FileChunk,
            FrameType::Ack,
        ];
        for ft in types {
            let (_, _, decoded, _) = roundtrip(&bare(ft, 0));
            assert_eq!(decoded.frame_type, ft, "type {ft} failed roundtrip");
        }
    }
}