use crate::error::ProtocolError;
use crate::intent::Intent;
/// Size in bytes of an encoded [`PacketHeader`]: version/flags byte (1) +
/// intent/priority byte (1) + `payload_len` (2) + `conn_id` (4) + `seq` (4) + `ack` (4).
pub const HEADER_LEN: usize = 1 + 1 + 2 + 4 + 4 + 4;
/// Protocol version written to (and required in) the high nibble of header byte 0.
pub const VERSION: u8 = 1;
/// `flags` value for data packets; header validation requires a non-zero `payload_len`.
pub const FLAG_DATA: u8 = 0b0000;
/// `flags` value presumably marking acknowledgements; validation requires a zero
/// `payload_len` and `Intent::Reliable`.
pub const FLAG_ACK: u8 = 0b0001;
/// `flags` value presumably marking handshake packets; validation requires `Intent::Reliable`.
pub const FLAG_HS: u8 = 0b0010;
/// Decoded form of the fixed-size (`HEADER_LEN`-byte) packet header.
///
/// Multi-byte fields are big-endian on the wire. `PartialEq` is derived so that
/// decoded headers can be compared directly (e.g. encode/decode round-trips).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PacketHeader {
    /// Packet kind, stored in the low nibble of byte 0 (`VERSION` occupies the
    /// high nibble). Only `FLAG_DATA`, `FLAG_ACK` and `FLAG_HS` pass validation.
    pub flags: u8,
    /// Delivery intent, stored in the top two bits (6-7) of byte 1.
    pub intent: Intent,
    /// Priority, stored in bits 3-5 of byte 1; only values `0..=7` fit on the wire.
    pub priority: u8,
    /// Payload length field, bytes 2..4 — presumably the number of payload
    /// bytes following the header (TODO confirm against callers).
    pub payload_len: u16,
    /// Connection identifier, bytes 4..8.
    pub conn_id: u32,
    /// `seq` field, bytes 8..12.
    pub seq: u32,
    /// `ack` field, bytes 12..16.
    pub ack: u32,
}
impl PacketHeader {
    /// Largest priority representable in the 3-bit wire field (bits 3-5 of byte 1).
    const MAX_PRIORITY: u8 = 0x07;

    /// Serializes this header into the first [`HEADER_LEN`] bytes of `buf`.
    ///
    /// Wire layout (multi-byte fields big-endian):
    ///
    /// | bytes  | contents                                                          |
    /// |--------|-------------------------------------------------------------------|
    /// | 0      | [`VERSION`] in the high nibble, `flags` in the low nibble          |
    /// | 1      | low two bits of `intent as u8` in bits 6-7, `priority` in bits 3-5; bits 0-2 zero |
    /// | 2..4   | `payload_len`                                                     |
    /// | 4..8   | `conn_id`                                                         |
    /// | 8..12  | `seq`                                                             |
    /// | 12..16 | `ack`                                                             |
    ///
    /// Bytes of `buf` past `HEADER_LEN` are left untouched, and `buf` is not
    /// modified at all when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::MalformedPacket`] if `buf` is shorter than
    /// [`HEADER_LEN`] or if the header fails validation (unknown `flags`,
    /// per-kind payload/intent rules, or `priority` above 7).
    pub fn encode(&self, buf: &mut [u8]) -> Result<(), ProtocolError> {
        if buf.len() < HEADER_LEN {
            return Err(ProtocolError::MalformedPacket);
        }
        // Validate before writing so a rejected header never leaves a
        // half-written buffer behind.
        self.validate()?;

        buf[0] = (VERSION << 4) | (self.flags & 0x0F);
        // `priority` is already known to be <= MAX_PRIORITY; the mask only
        // documents the field width.
        buf[1] = ((self.intent as u8) << 6) | ((self.priority & Self::MAX_PRIORITY) << 3);
        buf[2..4].copy_from_slice(&self.payload_len.to_be_bytes());
        buf[4..8].copy_from_slice(&self.conn_id.to_be_bytes());
        buf[8..12].copy_from_slice(&self.seq.to_be_bytes());
        buf[12..16].copy_from_slice(&self.ack.to_be_bytes());
        Ok(())
    }

    /// Parses a header from the first [`HEADER_LEN`] bytes of `buf`.
    ///
    /// Bytes after the header are ignored, as are bits 0-2 of byte 1
    /// (which `encode` always writes as zero).
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::MalformedPacket`] if `buf` is shorter than
    ///   [`HEADER_LEN`] or the decoded header fails validation.
    /// - [`ProtocolError::UnsupportedVersion`] if the high nibble of byte 0 is
    ///   not [`VERSION`].
    /// - [`ProtocolError::InvalidIntent`] if `Intent::from_bits` rejects the
    ///   top two bits of byte 1.
    pub fn decode(buf: &[u8]) -> Result<Self, ProtocolError> {
        if buf.len() < HEADER_LEN {
            return Err(ProtocolError::MalformedPacket);
        }
        let version = buf[0] >> 4;
        if version != VERSION {
            return Err(ProtocolError::UnsupportedVersion);
        }
        let intent = Intent::from_bits(buf[1] >> 6).ok_or(ProtocolError::InvalidIntent)?;

        // Every index below is < HEADER_LEN, which the length check above
        // guarantees is in bounds.
        let header = Self {
            flags: buf[0] & 0x0F,
            intent,
            priority: (buf[1] >> 3) & Self::MAX_PRIORITY,
            payload_len: u16::from_be_bytes([buf[2], buf[3]]),
            conn_id: Self::read_u32_be(buf, 4),
            seq: Self::read_u32_be(buf, 8),
            ack: Self::read_u32_be(buf, 12),
        };
        header.validate()?;
        Ok(header)
    }

    /// Reads a big-endian `u32` from `buf[offset..offset + 4]`.
    ///
    /// The caller must have checked that `buf` holds at least `offset + 4`
    /// bytes; indexing panics otherwise.
    fn read_u32_be(buf: &[u8], offset: usize) -> u32 {
        u32::from_be_bytes([buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]])
    }

    /// Checks the per-kind invariants a header must satisfy.
    ///
    /// - `priority` must fit its 3-bit wire field (`0..=7`); otherwise `encode`
    ///   would silently truncate it and the header would not round-trip.
    /// - `FLAG_DATA`: `payload_len` must be non-zero.
    /// - `FLAG_ACK`: `payload_len` must be zero and `intent` must be `Intent::Reliable`.
    /// - `FLAG_HS`: `intent` must be `Intent::Reliable` (`payload_len` unrestricted).
    /// - Any other `flags` value is rejected.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::MalformedPacket`] when any rule is violated.
    fn validate(&self) -> Result<(), ProtocolError> {
        if self.priority > Self::MAX_PRIORITY {
            return Err(ProtocolError::MalformedPacket);
        }
        let well_formed = match self.flags {
            FLAG_DATA => self.payload_len != 0,
            FLAG_ACK => self.payload_len == 0 && self.intent == Intent::Reliable,
            FLAG_HS => self.intent == Intent::Reliable,
            _ => false,
        };
        if well_formed {
            Ok(())
        } else {
            Err(ProtocolError::MalformedPacket)
        }
    }
}