use crate::buffer::{ReadBuffer, WriteBuffer};
use crate::constants::{self, PacketType, PACKET_HEADER_SIZE};
use crate::error::{Error, Result};
/// Fixed-size header that precedes every packet on the wire.
///
/// All multi-byte fields are encoded big-endian by `read`/`write`. The length
/// field is 16 bits wide in small-SDU mode and 32 bits wide in large-SDU mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketHeader {
    /// Total packet length in bytes, header included (see `payload_length`).
    pub length: u32,
    /// Packet checksum; only encoded on the wire in small-SDU mode (see `write`).
    pub packet_checksum: u16,
    /// Kind of packet; encoded as a single byte.
    pub packet_type: PacketType,
    /// Flag bits such as `packet_flags::TLS_RENEG` and `packet_flags::REDIRECT`.
    pub flags: u8,
    /// Header checksum; the final 16-bit field of the header.
    pub header_checksum: u16,
}
impl PacketHeader {
pub fn new(packet_type: PacketType, length: u32) -> Self {
Self {
length,
packet_checksum: 0,
packet_type,
flags: 0,
header_checksum: 0,
}
}
pub fn with_flags(packet_type: PacketType, length: u32, flags: u8) -> Self {
Self {
length,
packet_checksum: 0,
packet_type,
flags,
header_checksum: 0,
}
}
pub fn parse(data: &[u8]) -> Result<Self> {
if data.len() < PACKET_HEADER_SIZE {
return Err(Error::PacketTooShort {
expected: PACKET_HEADER_SIZE,
actual: data.len(),
});
}
let mut buf = ReadBuffer::from_slice(data);
Self::read(&mut buf, false)
}
pub fn parse_large_sdu(data: &[u8]) -> Result<Self> {
if data.len() < PACKET_HEADER_SIZE {
return Err(Error::PacketTooShort {
expected: PACKET_HEADER_SIZE,
actual: data.len(),
});
}
let mut buf = ReadBuffer::from_slice(data);
Self::read(&mut buf, true)
}
pub fn read(buf: &mut ReadBuffer, large_sdu: bool) -> Result<Self> {
let length = if large_sdu {
buf.read_u32_be()?
} else {
let len = buf.read_u16_be()? as u32;
buf.skip(2)?; len
};
let packet_checksum = if large_sdu {
0 } else {
0
};
let packet_type = PacketType::try_from(buf.read_u8()?)?;
let flags = buf.read_u8()?;
let header_checksum = buf.read_u16_be()?;
Ok(Self {
length,
packet_checksum,
packet_type,
flags,
header_checksum,
})
}
pub fn write(&self, buf: &mut WriteBuffer, large_sdu: bool) -> Result<()> {
if large_sdu {
buf.write_u32_be(self.length)?;
} else {
buf.write_u16_be(self.length as u16)?;
buf.write_u16_be(self.packet_checksum)?;
}
buf.write_u8(self.packet_type as u8)?;
buf.write_u8(self.flags)?;
buf.write_u16_be(self.header_checksum)?;
Ok(())
}
pub fn to_bytes(&self, large_sdu: bool) -> Result<bytes::Bytes> {
let mut buf = WriteBuffer::with_capacity(PACKET_HEADER_SIZE);
self.write(&mut buf, large_sdu)?;
Ok(buf.freeze())
}
pub fn payload_length(&self) -> usize {
(self.length as usize).saturating_sub(PACKET_HEADER_SIZE)
}
pub fn has_tls_reneg_flag(&self) -> bool {
(self.flags & constants::packet_flags::TLS_RENEG) != 0
}
pub fn has_redirect_flag(&self) -> bool {
(self.flags & constants::packet_flags::REDIRECT) != 0
}
}
impl Default for PacketHeader {
    /// A header-only `Data` packet: the length equals the header size, and the
    /// flags and both checksums are zero.
    fn default() -> Self {
        Self::new(PacketType::Data, PACKET_HEADER_SIZE as u32)
    }
}
/// Test-only helper that assembles a complete packet (header followed by
/// payload) and fills in the header's `length` field when `build` is called.
#[cfg(test)]
#[derive(Debug)]
pub struct PacketBuilder {
    /// Header template; `build` overwrites its `length`.
    header: PacketHeader,
    /// Bytes appended after the header.
    payload: WriteBuffer,
    /// When `true`, the header is encoded with the 32-bit length field.
    large_sdu: bool,
}
#[cfg(test)]
impl PacketBuilder {
    /// Starts a small-SDU packet of `packet_type` with an empty payload.
    pub fn new(packet_type: PacketType) -> Self {
        let header = PacketHeader::new(packet_type, PACKET_HEADER_SIZE as u32);
        Self {
            header,
            payload: WriteBuffer::new(),
            large_sdu: false,
        }
    }

    /// Chooses between the 16-bit and 32-bit length encodings.
    // Optional builder knob; not every test build exercises it.
    #[allow(dead_code)]
    pub fn large_sdu(self, large_sdu: bool) -> Self {
        Self { large_sdu, ..self }
    }

    /// Sets the header's flags byte.
    // Optional builder knob; not every test build exercises it.
    #[allow(dead_code)]
    pub fn flags(mut self, flags: u8) -> Self {
        self.header.flags = flags;
        self
    }

    /// Mutable access to the payload buffer, for appending body bytes.
    pub fn payload(&mut self) -> &mut WriteBuffer {
        &mut self.payload
    }

    /// Serialises the header (with `length` = header size + payload size)
    /// followed by the payload.
    pub fn build(self) -> Result<bytes::Bytes> {
        let Self {
            mut header,
            payload,
            large_sdu,
        } = self;
        let total_len = PACKET_HEADER_SIZE + payload.len();
        header.length = total_len as u32;

        let mut out = WriteBuffer::with_capacity(total_len);
        header.write(&mut out, large_sdu)?;
        out.write_bytes(payload.as_slice())?;
        Ok(out.freeze())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `header` into a fresh buffer and returns the bytes written.
    fn encode(header: &PacketHeader, large_sdu: bool) -> Vec<u8> {
        let mut buf = WriteBuffer::new();
        header.write(&mut buf, large_sdu).unwrap();
        buf.as_slice().to_vec()
    }

    /// Encodes `original`, decodes it in the same SDU mode, and checks that the
    /// length, packet type and flags survive the trip.
    fn check_roundtrip(original: PacketHeader, large_sdu: bool) {
        let wire = encode(&original, large_sdu);
        let decoded = if large_sdu {
            PacketHeader::parse_large_sdu(&wire)
        } else {
            PacketHeader::parse(&wire)
        }
        .unwrap();
        assert_eq!(original.length, decoded.length);
        assert_eq!(original.packet_type, decoded.packet_type);
        assert_eq!(original.flags, decoded.flags);
    }

    #[test]
    fn test_header_new() {
        let PacketHeader {
            packet_type,
            length,
            flags,
            ..
        } = PacketHeader::new(PacketType::Connect, 100);
        assert_eq!(packet_type, PacketType::Connect);
        assert_eq!(length, 100);
        assert_eq!(flags, 0);
    }

    #[test]
    fn test_header_parse_small_sdu() {
        // length=100, packet checksum=0, type=0x01, flags=0x08, header checksum=0
        let wire = [0x00, 0x64, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00];
        let header = PacketHeader::parse(&wire).unwrap();
        assert_eq!(header.length, 100);
        assert_eq!(header.packet_type, PacketType::Connect);
        assert_eq!(header.flags, 0x08);
        assert!(header.has_tls_reneg_flag());
    }

    #[test]
    fn test_header_parse_large_sdu() {
        // 32-bit length=8192, type=0x06, flags=0, header checksum=0
        let wire = [0x00, 0x00, 0x20, 0x00, 0x06, 0x00, 0x00, 0x00];
        let header = PacketHeader::parse_large_sdu(&wire).unwrap();
        assert_eq!(header.length, 8192);
        assert_eq!(header.packet_type, PacketType::Data);
    }

    #[test]
    fn test_header_write_small_sdu() {
        let wire = encode(&PacketHeader::new(PacketType::Connect, 100), false);
        assert_eq!(wire, [0x00u8, 0x64, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00]);
    }

    #[test]
    fn test_header_write_large_sdu() {
        let wire = encode(&PacketHeader::new(PacketType::Data, 8192), true);
        assert_eq!(wire, [0x00u8, 0x00, 0x20, 0x00, 0x06, 0x00, 0x00, 0x00]);
    }

    #[test]
    fn test_header_roundtrip_small_sdu() {
        check_roundtrip(PacketHeader::with_flags(PacketType::Accept, 256, 0x04), false);
    }

    #[test]
    fn test_header_roundtrip_large_sdu() {
        check_roundtrip(PacketHeader::with_flags(PacketType::Data, 32768, 0x08), true);
    }

    #[test]
    fn test_payload_length() {
        let header = PacketHeader::new(PacketType::Data, 100);
        assert_eq!(header.payload_length(), 100 - PACKET_HEADER_SIZE);
    }

    #[test]
    fn test_packet_builder() {
        let mut builder = PacketBuilder::new(PacketType::Connect);
        builder.payload().write_bytes(b"ABC").unwrap();
        let packet = builder.build().unwrap();
        assert_eq!(packet.len(), 11);

        let header = PacketHeader::parse(&packet).unwrap();
        assert_eq!(header.length, 11);
        assert_eq!(header.packet_type, PacketType::Connect);
        assert_eq!(&packet[8..], b"ABC");
    }

    #[test]
    fn test_header_parse_too_short() {
        assert!(PacketHeader::parse(&[0x00, 0x01, 0x02]).is_err());
    }

    #[test]
    fn test_header_invalid_packet_type() {
        // 0xFF in the type position must be rejected.
        let wire = [0x00, 0x08, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00];
        assert!(PacketHeader::parse(&wire).is_err());
    }

    #[test]
    fn test_all_packet_types() {
        let cases = [
            (PacketType::Connect, 0x01u8),
            (PacketType::Accept, 0x02),
            (PacketType::Refuse, 0x04),
            (PacketType::Redirect, 0x05),
            (PacketType::Data, 0x06),
            (PacketType::Marker, 0x0C),
            (PacketType::Control, 0x0E),
        ];
        for (packet_type, expected_byte) in cases {
            let wire = encode(&PacketHeader::new(packet_type, 8), false);
            // Byte 4 is the packet-type byte in the small-SDU layout.
            assert_eq!(wire[4], expected_byte);
        }
    }
}