use crate::error::Error;
use super::frame::Flags;
/// The only protocol version `Header::decode` accepts in byte 0 of a header.
pub(crate) const PROTOCOL_VERSION: u8 = 0x01;
/// Size of an encoded header in bytes:
/// version (1) + frame type (1) + flags (2) + stream id (4) + length (4).
pub(crate) const HEADER_LEN: usize = 1 + 1 + 2 + 4 + 4;
/// Frame kind, stored as a single byte (offset 1) in the encoded header.
///
/// The explicit `#[repr(u8)]` discriminants are the on-wire values, so they
/// must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub(crate) enum FrameType {
    /// Wire value `0x00`.
    Data = 0x00,
    /// Wire value `0x01`.
    WindowUpdate = 0x01,
    /// Wire value `0x02`.
    Ping = 0x02,
    /// Wire value `0x03`.
    GoAway = 0x03,
}
impl TryFrom<u8> for FrameType {
type Error = Error;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(FrameType::Data),
1 => Ok(FrameType::WindowUpdate),
2 => Ok(FrameType::Ping),
3 => Ok(FrameType::GoAway),
_ => Err(Error::Protocol("unknown frame type")),
}
}
}
/// A decoded frame header.
///
/// It is written to and read from a fixed `[u8; HEADER_LEN]` buffer by
/// `Header::encode` and `Header::decode`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct Header {
    /// Version byte (offset 0); `decode` only accepts `PROTOCOL_VERSION`.
    pub version: u8,
    /// Frame kind byte (offset 1).
    pub frame_type: FrameType,
    /// Flag word, big-endian `u16` at offsets 2..4.
    pub flags: Flags,
    /// Stream identifier, big-endian `u32` at offsets 4..8.
    pub stream_id: u32,
    /// Big-endian `u32` at offsets 8..12.
    /// NOTE(review): presumably the payload size that follows the header — confirm with the frame reader.
    pub length: u32,
}
impl Header {
pub(crate) fn encode(&self, out: &mut [u8; HEADER_LEN]) {
out[0] = self.version;
out[1] = self.frame_type as u8;
out[2..4].copy_from_slice(&self.flags.bits().to_be_bytes());
out[4..8].copy_from_slice(&self.stream_id.to_be_bytes());
out[8..12].copy_from_slice(&self.length.to_be_bytes());
}
pub(crate) fn decode(bytes: &[u8; HEADER_LEN]) -> Result<Self, Error> {
let version = bytes[0];
if version != PROTOCOL_VERSION {
return Err(Error::Protocol("unsupported protocol version"));
}
let frame_type = FrameType::try_from(bytes[1])?;
let flags_bits = u16::from_be_bytes([bytes[2], bytes[3]]);
let flags = Flags::from_bits(flags_bits).ok_or(Error::Protocol("unknown flag bits"))?;
let stream_id = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
let length = u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
Ok(Self {
version,
frame_type,
flags,
stream_id,
length,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field survives an encode/decode cycle.
    #[test]
    fn roundtrip() {
        let h = Header {
            version: PROTOCOL_VERSION,
            frame_type: FrameType::Data,
            flags: Flags::SYN | Flags::ACK,
            stream_id: 0xDEAD_BEEF,
            length: 12345,
        };
        let mut buf = [0u8; HEADER_LEN];
        h.encode(&mut buf);
        let decoded = Header::decode(&buf).unwrap();
        assert_eq!(h, decoded);
    }

    /// Every frame type round-trips, including extreme stream-id/length values.
    #[test]
    fn roundtrip_every_frame_type() {
        for frame_type in [
            FrameType::Data,
            FrameType::WindowUpdate,
            FrameType::Ping,
            FrameType::GoAway,
        ] {
            for (stream_id, length) in [(0, 0), (u32::MAX, u32::MAX)] {
                let h = Header {
                    version: PROTOCOL_VERSION,
                    frame_type,
                    flags: Flags::SYN,
                    stream_id,
                    length,
                };
                let mut buf = [0u8; HEADER_LEN];
                h.encode(&mut buf);
                assert_eq!(Header::decode(&buf).unwrap(), h);
            }
        }
    }

    /// Pins the exact byte layout.
    ///
    /// A pure roundtrip test would still pass if encode and decode shared
    /// the same byte-order mistake.
    #[test]
    fn encodes_big_endian_wire_layout() {
        let flags = Flags::SYN | Flags::ACK;
        let h = Header {
            version: PROTOCOL_VERSION,
            frame_type: FrameType::WindowUpdate,
            flags,
            stream_id: 0x0102_0304,
            length: 0x0A0B_0C0D,
        };
        let mut buf = [0u8; HEADER_LEN];
        h.encode(&mut buf);
        let [flags_hi, flags_lo] = flags.bits().to_be_bytes();
        assert_eq!(
            buf,
            [
                PROTOCOL_VERSION,
                1,
                flags_hi,
                flags_lo,
                0x01,
                0x02,
                0x03,
                0x04,
                0x0A,
                0x0B,
                0x0C,
                0x0D,
            ]
        );
    }

    #[test]
    fn rejects_bad_version() {
        let mut buf = [0u8; HEADER_LEN];
        buf[0] = 99;
        assert!(matches!(
            Header::decode(&buf),
            Err(Error::Protocol("unsupported protocol version"))
        ));

        // The version check runs first, so it wins even when every other field
        // is also invalid.
        let all_ones = [0xFFu8; HEADER_LEN];
        assert!(matches!(
            Header::decode(&all_ones),
            Err(Error::Protocol("unsupported protocol version"))
        ));
    }

    #[test]
    fn rejects_unknown_type() {
        // 4 is the first byte past the last defined frame type.
        for raw_type in [4u8, 99, u8::MAX] {
            let mut buf = [0u8; HEADER_LEN];
            buf[0] = PROTOCOL_VERSION;
            buf[1] = raw_type;
            assert!(matches!(
                Header::decode(&buf),
                Err(Error::Protocol("unknown frame type"))
            ));
        }
    }

    /// Each single flag bit is accepted iff `Flags::from_bits` accepts it.
    ///
    /// Accepted bits must land in the decoded `flags` (big-endian); rejected
    /// bits must produce the flag error.
    #[test]
    fn validates_each_flag_bit() {
        for bit in 0..16u32 {
            let bits = 1u16 << bit;
            let mut buf = [0u8; HEADER_LEN];
            buf[0] = PROTOCOL_VERSION;
            buf[2..4].copy_from_slice(&bits.to_be_bytes());
            match Flags::from_bits(bits) {
                Some(expected) => assert_eq!(Header::decode(&buf).unwrap().flags, expected),
                None => assert!(matches!(
                    Header::decode(&buf),
                    Err(Error::Protocol("unknown flag bits"))
                )),
            }
        }
    }
}