use super::types::{ConnectionId, ConstrainedError, PacketFlags, SequenceNumber};
/// Size in bytes of the fixed wire header: a 2-byte connection id followed by
/// one byte each for the sequence number, the acknowledgement number and the flags.
pub const HEADER_SIZE: usize = 2 + 1 + 1 + 1;
/// Fixed-size header that prefixes every constrained packet on the wire.
///
/// Serialized by `to_bytes` as exactly `HEADER_SIZE` bytes, in field order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConstrainedHeader {
/// Connection identifier; occupies header bytes 0..2.
pub connection_id: ConnectionId,
/// Sequence number; occupies header byte 2.
pub seq: SequenceNumber,
/// Acknowledgement number; occupies header byte 3.
pub ack: SequenceNumber,
/// Packet flag bits; occupy header byte 4.
pub flags: PacketFlags,
}
impl ConstrainedHeader {
    /// Builds a header from explicit field values.
    pub const fn new(
        connection_id: ConnectionId,
        seq: SequenceNumber,
        ack: SequenceNumber,
        flags: PacketFlags,
    ) -> Self {
        Self { connection_id, seq, ack, flags }
    }

    /// Connection-opening header: `SYN` flag, with seq and ack both zero.
    pub fn syn(connection_id: ConnectionId) -> Self {
        let zero = SequenceNumber::new(0);
        Self::new(connection_id, zero, zero, PacketFlags::SYN)
    }

    /// Reply to a SYN: `SYN_ACK` flags, seq zero, acknowledging `ack`.
    pub fn syn_ack(connection_id: ConnectionId, ack: SequenceNumber) -> Self {
        Self::new(connection_id, SequenceNumber::new(0), ack, PacketFlags::SYN_ACK)
    }

    /// Pure acknowledgement header carrying only the `ACK` flag.
    pub fn ack(connection_id: ConnectionId, seq: SequenceNumber, ack: SequenceNumber) -> Self {
        Self::new(connection_id, seq, ack, PacketFlags::ACK)
    }

    /// Header for a data-bearing packet; sets both `DATA` and `ACK`.
    pub fn data(connection_id: ConnectionId, seq: SequenceNumber, ack: SequenceNumber) -> Self {
        let flags = PacketFlags::DATA.union(PacketFlags::ACK);
        Self::new(connection_id, seq, ack, flags)
    }

    /// Connection-closing header; sets both `FIN` and `ACK`.
    pub fn fin(connection_id: ConnectionId, seq: SequenceNumber, ack: SequenceNumber) -> Self {
        let flags = PacketFlags::FIN.union(PacketFlags::ACK);
        Self::new(connection_id, seq, ack, flags)
    }

    /// Reset header: `RST` flag, with seq and ack both zero.
    pub fn reset(connection_id: ConnectionId) -> Self {
        let zero = SequenceNumber::new(0);
        Self::new(connection_id, zero, zero, PacketFlags::RST)
    }

    /// Keep-alive probe: `PING` flag carrying `seq`, ack zero.
    pub fn ping(connection_id: ConnectionId, seq: SequenceNumber) -> Self {
        Self::new(connection_id, seq, SequenceNumber::new(0), PacketFlags::PING)
    }

    /// Reply to a ping: `PONG` flag, seq zero, echoing the probe in `ack`.
    pub fn pong(connection_id: ConnectionId, ack: SequenceNumber) -> Self {
        Self::new(connection_id, SequenceNumber::new(0), ack, PacketFlags::PONG)
    }

    /// Encodes the header as `[cid0, cid1, seq, ack, flags]`.
    pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
        let cid = self.connection_id.to_bytes();
        let (cid_first, cid_second) = (cid[0], cid[1]);
        [
            cid_first,
            cid_second,
            self.seq.value(),
            self.ack.value(),
            self.flags.value(),
        ]
    }

    /// Decodes a header from the first `HEADER_SIZE` bytes of `bytes`.
    ///
    /// Any bytes beyond the header are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ConstrainedError::PacketTooSmall`] when `bytes` is shorter
    /// than `HEADER_SIZE`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ConstrainedError> {
        match bytes {
            &[cid_first, cid_second, seq, ack, flags, ..] => Ok(Self::new(
                ConnectionId::from_bytes([cid_first, cid_second]),
                SequenceNumber::new(seq),
                SequenceNumber::new(ack),
                PacketFlags::new(flags),
            )),
            _ => Err(ConstrainedError::PacketTooSmall {
                expected: HEADER_SIZE,
                actual: bytes.len(),
            }),
        }
    }

    /// True when the SYN flag is set (this includes SYN-ACK headers).
    pub const fn is_syn(&self) -> bool {
        self.flags.is_syn()
    }

    /// True when both the SYN and ACK flags are set.
    pub const fn is_syn_ack(&self) -> bool {
        self.is_syn() && self.is_ack()
    }

    /// True when the ACK flag is set.
    pub const fn is_ack(&self) -> bool {
        self.flags.is_ack()
    }

    /// True when the FIN flag is set.
    pub const fn is_fin(&self) -> bool {
        self.flags.is_fin()
    }

    /// True when the RST flag is set.
    pub const fn is_rst(&self) -> bool {
        self.flags.is_rst()
    }

    /// True when the DATA flag is set.
    pub const fn is_data(&self) -> bool {
        self.flags.is_data()
    }

    /// True when the PING flag is set.
    pub const fn is_ping(&self) -> bool {
        self.flags.is_ping()
    }

    /// True when the PONG flag is set.
    pub const fn is_pong(&self) -> bool {
        self.flags.is_pong()
    }
}
impl std::fmt::Display for ConstrainedHeader {
    /// Renders the header as `[<connection id> <seq> <ack> <flags>]`, using
    /// each field's own `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            connection_id,
            seq,
            ack,
            flags,
        } = self;
        write!(f, "[{} {} {} {}]", connection_id, seq, ack, flags)
    }
}
#[derive(Debug, Clone)]
pub struct ConstrainedPacket {
pub header: ConstrainedHeader,
pub payload: Vec<u8>,
}
impl ConstrainedPacket {
    /// Pairs an existing header with a payload.
    pub fn new(header: ConstrainedHeader, payload: Vec<u8>) -> Self {
        Self { header, payload }
    }

    /// Builds a payload-less packet around `header`.
    pub fn control(header: ConstrainedHeader) -> Self {
        Self::new(header, Vec::new())
    }

    /// Builds a data packet whose header comes from [`ConstrainedHeader::data`].
    pub fn data(
        connection_id: ConnectionId,
        seq: SequenceNumber,
        ack: SequenceNumber,
        payload: Vec<u8>,
    ) -> Self {
        let header = ConstrainedHeader::data(connection_id, seq, ack);
        Self::new(header, payload)
    }

    /// Number of bytes `to_bytes` will produce: header plus payload.
    pub fn total_size(&self) -> usize {
        HEADER_SIZE + self.payload.len()
    }

    /// Serializes the packet as the header bytes followed by the payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let header_bytes = self.header.to_bytes();
        let mut wire = Vec::with_capacity(self.total_size());
        wire.extend_from_slice(&header_bytes);
        wire.extend_from_slice(&self.payload);
        wire
    }

    /// Parses a packet: a header followed by the remaining bytes as payload.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ConstrainedHeader::from_bytes`], such as
    /// the input being shorter than `HEADER_SIZE`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ConstrainedError> {
        let header = ConstrainedHeader::from_bytes(bytes)?;
        // The header parse succeeded, so `bytes.len() >= HEADER_SIZE` and the
        // split point is in bounds; the tail (possibly empty) is the payload.
        let (_, tail) = bytes.split_at(HEADER_SIZE);
        Ok(Self::new(header, tail.to_vec()))
    }

    /// True when the packet carries at least one payload byte.
    pub fn has_payload(&self) -> bool {
        !self.payload.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `SequenceNumber::new` to keep fixtures compact.
    fn sn(value: u8) -> SequenceNumber {
        SequenceNumber::new(value)
    }

    #[test]
    fn test_header_serialization() {
        let original = ConstrainedHeader::new(
            ConnectionId::new(0x1234),
            sn(10),
            sn(5),
            PacketFlags::DATA.union(PacketFlags::ACK),
        );
        let wire = original.to_bytes();
        assert_eq!(wire.len(), HEADER_SIZE);
        // Connection id high byte first, then seq, ack and the DATA|ACK flag byte.
        assert_eq!(wire, [0x12, 0x34, 10, 5, 0x12]);
        let decoded = ConstrainedHeader::from_bytes(&wire).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_header_from_bytes_too_short() {
        let outcome = ConstrainedHeader::from_bytes(&[1, 2, 3]);
        assert!(outcome.is_err());
        if let Err(ConstrainedError::PacketTooSmall { expected, actual }) = outcome {
            assert_eq!(expected, HEADER_SIZE);
            assert_eq!(actual, 3);
        } else {
            panic!("Expected PacketTooSmall error");
        }
    }

    #[test]
    fn test_syn_header() {
        let syn = ConstrainedHeader::syn(ConnectionId::new(0xABCD));
        assert!(syn.is_syn());
        assert!(!syn.is_ack());
        assert_eq!(syn.seq, sn(0));
    }

    #[test]
    fn test_syn_ack_header() {
        let syn_ack = ConstrainedHeader::syn_ack(ConnectionId::new(0xABCD), sn(1));
        assert!(syn_ack.is_syn());
        assert!(syn_ack.is_ack());
        assert!(syn_ack.is_syn_ack());
        assert_eq!(syn_ack.ack, sn(1));
    }

    #[test]
    fn test_data_header() {
        let data = ConstrainedHeader::data(ConnectionId::new(0x1234), sn(5), sn(3));
        assert!(data.is_data());
        assert!(data.is_ack());
        assert!(!data.is_syn());
    }

    #[test]
    fn test_fin_header() {
        let fin = ConstrainedHeader::fin(ConnectionId::new(0x1234), sn(10), sn(8));
        assert!(fin.is_fin());
        assert!(fin.is_ack());
    }

    #[test]
    fn test_reset_header() {
        let rst = ConstrainedHeader::reset(ConnectionId::new(0x1234));
        assert!(rst.is_rst());
        assert!(!rst.is_ack());
    }

    #[test]
    fn test_ping_pong_headers() {
        let cid = ConnectionId::new(0x1234);

        let ping = ConstrainedHeader::ping(cid, sn(5));
        assert!(ping.is_ping());
        assert!(!ping.is_pong());

        let pong = ConstrainedHeader::pong(cid, sn(5));
        assert!(pong.is_pong());
        assert!(!pong.is_ping());
    }

    #[test]
    fn test_header_display() {
        let rendered =
            ConstrainedHeader::data(ConnectionId::new(0xABCD), sn(10), sn(5)).to_string();
        for needle in ["ABCD", "SEQ:10", "ACK|DATA"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_packet_serialization() {
        let payload = b"Hello".to_vec();
        let packet =
            ConstrainedPacket::data(ConnectionId::new(0x1234), sn(5), sn(3), payload);
        assert_eq!(packet.total_size(), HEADER_SIZE + 5);
        assert!(packet.has_payload());

        let wire = packet.to_bytes();
        assert_eq!(wire.len(), HEADER_SIZE + 5);
        assert_eq!(&wire[HEADER_SIZE..], b"Hello");

        let decoded = ConstrainedPacket::from_bytes(&wire).unwrap();
        assert_eq!(decoded.header, packet.header);
        assert_eq!(decoded.payload, packet.payload);
    }

    #[test]
    fn test_control_packet() {
        let syn = ConstrainedHeader::syn(ConnectionId::new(0x1234));
        let packet = ConstrainedPacket::control(syn);
        assert!(!packet.has_payload());
        assert_eq!(packet.total_size(), HEADER_SIZE);
    }

    #[test]
    fn test_packet_from_bytes_header_only() {
        let ack = ConstrainedHeader::ack(ConnectionId::new(0x1234), sn(1), sn(0));
        let packet = ConstrainedPacket::from_bytes(&ack.to_bytes()).unwrap();
        assert_eq!(packet.header, ack);
        assert!(packet.payload.is_empty());
    }
}