use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Two-byte marker at the start of every serialized header (written little-endian).
/// `NetHeader::from_bytes` rejects input that does not start with it.
pub const MAGIC: u16 = 0x4E45;
/// Protocol version stamped by `NetHeader::new` and required by `NetHeader::validate`.
pub const VERSION: u8 = 1;
/// Length in bytes of a serialized `NetHeader`. The in-memory struct is larger
/// because of alignment padding.
pub const HEADER_SIZE: usize = 68;
/// Byte offset of the little-endian `payload_len` field in a serialized header.
/// Only the two-byte `payload_len` and two-byte `event_count` follow it, so it
/// sits 4 bytes before the end.
pub const PAYLOAD_LEN_OFFSET: usize = HEADER_SIZE - 4;
/// Bytes reserved after the payload when sizing a packet.
/// Presumably the AEAD authentication tag — confirm against the crypto layer.
pub const TAG_SIZE: usize = 16;
/// Length of the per-packet nonce carried in the header.
pub const NONCE_SIZE: usize = 12;
/// Upper bound on a whole packet: header + payload + tag.
pub const MAX_PACKET_SIZE: usize = 8 * 1024;
/// Largest payload that still fits in `MAX_PACKET_SIZE` next to the header and tag.
pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - (HEADER_SIZE + TAG_SIZE);
/// Bit set stored in the one-byte `flags` field of a packet header.
///
/// Build sets with `with`, clear bits with `without`, and test them with
/// `contains` or the `is_*` shortcuts. `from_bits`/`bits` round-trip any byte
/// unchanged, including bits that have no named constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(transparent)]
pub struct PacketFlags(u8);

impl PacketFlags {
    /// The empty set (identical to `PacketFlags::default()`).
    pub const NONE: Self = Self(0);
    // Named single-bit flags, bits 0 through 5.
    pub const RELIABLE: Self = Self(1 << 0);
    pub const NACK: Self = Self(1 << 1);
    pub const PRIORITY: Self = Self(1 << 2);
    pub const FIN: Self = Self(1 << 3);
    pub const HANDSHAKE: Self = Self(1 << 4);
    pub const HEARTBEAT: Self = Self(1 << 5);

    /// Wraps a raw flag byte (e.g. as read from the wire) without masking.
    #[inline]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns the raw flag byte.
    #[inline]
    pub const fn bits(self) -> u8 {
        let Self(raw) = self;
        raw
    }

    /// True when every bit set in `other` is also set in `self`.
    /// As a consequence `contains(NONE)` is always true.
    #[inline]
    pub const fn contains(self, other: Self) -> bool {
        // No bit of `other` may be missing from `self`.
        (other.0 & !self.0) == 0
    }

    /// Returns the union of `self` and `other`.
    #[inline]
    pub const fn with(self, other: Self) -> Self {
        Self(other.0 | self.0)
    }

    /// Returns `self` with every bit of `other` cleared.
    #[inline]
    pub const fn without(self, other: Self) -> Self {
        Self(!other.0 & self.0)
    }

    /// Shortcut for `contains(HANDSHAKE)`.
    #[inline]
    pub const fn is_handshake(self) -> bool {
        self.contains(Self::HANDSHAKE)
    }

    /// Shortcut for `contains(HEARTBEAT)`.
    #[inline]
    pub const fn is_heartbeat(self) -> bool {
        self.contains(Self::HEARTBEAT)
    }

    /// Shortcut for `contains(RELIABLE)`.
    #[inline]
    pub const fn is_reliable(self) -> bool {
        self.contains(Self::RELIABLE)
    }

    /// Shortcut for `contains(NACK)`.
    #[inline]
    pub const fn is_nack(self) -> bool {
        self.contains(Self::NACK)
    }
}
/// Packet header in host (in-memory) form.
///
/// Field order matches the serialized layout produced by `to_bytes`: 68 bytes,
/// multi-byte integers little-endian. The in-memory struct is 72 bytes because
/// of the `align(8)` padding, so always convert with `to_bytes`/`from_bytes`
/// rather than copying raw memory.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(8))]
pub struct NetHeader {
    /// Must equal `MAGIC`; `from_bytes` rejects any other value.
    pub magic: u16,
    /// Protocol version; `validate` requires `VERSION`.
    pub version: u8,
    /// Packet-level flag bits.
    pub flags: PacketFlags,
    /// Set via `with_priority`; zero by default.
    pub priority: u8,
    /// Hop limit set via `with_hops`; covered by the AAD.
    pub hop_ttl: u8,
    /// Hops taken so far. Reset to 0 by `with_hops` and deliberately zeroed in
    /// the AAD because it may change in transit.
    pub hop_count: u8,
    /// Fragmentation flag bits, set via `with_fragment`.
    pub frag_flags: u8,
    /// Set via `with_subprotocol`.
    pub subprotocol_id: u16,
    /// Set via `with_channel_hash`.
    pub channel_hash: u16,
    /// Per-packet nonce. Serialized on the wire but excluded from the AAD.
    /// Presumably the AEAD nonce — confirm against the crypto layer.
    pub nonce: [u8; NONCE_SIZE],
    pub session_id: u64,
    pub stream_id: u64,
    pub sequence: u64,
    /// All 64 bits are serialized and covered by the AAD.
    pub origin_hash: u64,
    /// Set via `with_subnet`.
    pub subnet_id: u32,
    /// Set via `with_fragment`.
    pub fragment_id: u16,
    /// Set via `with_fragment`.
    pub fragment_offset: u16,
    /// Payload length in bytes; `validate` caps it at `MAX_PAYLOAD_SIZE`.
    /// Serialized at byte offset `PAYLOAD_LEN_OFFSET`.
    pub payload_len: u16,
    /// Number of events in the payload; `validate` caps it at
    /// `NetHeader::MAX_EVENTS_PER_PACKET`.
    pub event_count: u16,
}
// Compile-time guard on the in-memory size: the 68 bytes of fields (no interior
// padding) are rounded up to a multiple of the 8-byte alignment.
const _: () = assert!(std::mem::size_of::<NetHeader>() == 72);
impl NetHeader {
    /// Largest `event_count` accepted by `validate`.
    ///
    /// Every event costs at least an `EventFrame::LEN_SIZE`-byte length prefix,
    /// so a payload of `MAX_PAYLOAD_SIZE` bytes cannot hold more events than
    /// this (2027 with the current constants).
    pub const MAX_EVENTS_PER_PACKET: u16 = (MAX_PAYLOAD_SIZE / EventFrame::LEN_SIZE) as u16;

    /// Creates a header carrying the given identifiers, nonce, payload
    /// description and flags.
    ///
    /// `magic` and `version` are set to `MAGIC` and `VERSION`. Priority, hop,
    /// subprotocol, channel, origin, subnet and fragmentation fields all start
    /// at zero; fill them in with the `with_*` builders.
    #[inline]
    pub fn new(
        session_id: u64,
        stream_id: u64,
        sequence: u64,
        nonce: [u8; NONCE_SIZE],
        payload_len: u16,
        event_count: u16,
        flags: PacketFlags,
    ) -> Self {
        Self {
            magic: MAGIC,
            version: VERSION,
            flags,
            nonce,
            session_id,
            stream_id,
            sequence,
            payload_len,
            event_count,
            // Optional routing / fragmentation metadata defaults to zero.
            priority: 0,
            hop_ttl: 0,
            hop_count: 0,
            frag_flags: 0,
            subprotocol_id: 0,
            channel_hash: 0,
            origin_hash: 0,
            subnet_id: 0,
            fragment_id: 0,
            fragment_offset: 0,
        }
    }

    /// Header for a handshake packet.
    ///
    /// Only `payload_len` and the `HANDSHAKE` flag are set. Session, stream,
    /// sequence, event count and nonce are all zero.
    #[inline]
    pub fn handshake(payload_len: u16) -> Self {
        let unset_nonce = [0u8; NONCE_SIZE];
        Self::new(0, 0, 0, unset_nonce, payload_len, 0, PacketFlags::HANDSHAKE)
    }

    /// Header for a heartbeat on `session_id`.
    ///
    /// It has no payload and no events, a zero stream, sequence and nonce, and
    /// the `HEARTBEAT` flag.
    #[inline]
    pub fn heartbeat(session_id: u64) -> Self {
        let unset_nonce = [0u8; NONCE_SIZE];
        Self::new(session_id, 0, 0, unset_nonce, 0, 0, PacketFlags::HEARTBEAT)
    }

    /// Returns `self` with `priority` replaced.
    #[inline]
    pub fn with_priority(self, priority: u8) -> Self {
        Self { priority, ..self }
    }

    /// Returns `self` with hop limit `ttl` and the hop counter reset to zero.
    #[inline]
    pub fn with_hops(self, ttl: u8) -> Self {
        Self {
            hop_ttl: ttl,
            hop_count: 0,
            ..self
        }
    }

    /// Returns `self` with `subprotocol_id` replaced.
    #[inline]
    pub fn with_subprotocol(self, id: u16) -> Self {
        Self {
            subprotocol_id: id,
            ..self
        }
    }

    /// Returns `self` with `channel_hash` replaced.
    #[inline]
    pub fn with_channel_hash(self, hash: u16) -> Self {
        Self {
            channel_hash: hash,
            ..self
        }
    }

    /// Returns `self` with `subnet_id` replaced.
    #[inline]
    pub fn with_subnet(self, subnet_id: u32) -> Self {
        Self { subnet_id, ..self }
    }

    /// Returns `self` with `origin_hash` replaced.
    #[inline]
    pub fn with_origin(self, origin_hash: u64) -> Self {
        Self { origin_hash, ..self }
    }

    /// Returns `self` with fragment id, offset and fragment flag bits replaced.
    #[inline]
    pub fn with_fragment(self, id: u16, offset: u16, flags: u8) -> Self {
        Self {
            fragment_id: id,
            fragment_offset: offset,
            frag_flags: flags,
            ..self
        }
    }

    /// Additional authenticated data for this header (56 bytes).
    ///
    /// This is exactly the serialized header from `to_bytes` with two changes:
    /// - the nonce (wire bytes 12..24) is removed;
    /// - the `hop_count` byte (offset 6) is forced to zero, because `hop_count`
    ///   may change in transit.
    ///
    /// Every other field is covered, including all 64 bits of `origin_hash`.
    #[inline]
    pub fn aad(&self) -> [u8; 56] {
        // Offsets within the serialized header produced by `to_bytes`.
        const HOP_COUNT_AT: usize = 6;
        const NONCE_AT: usize = 12;

        let wire = self.to_bytes();
        let (before_nonce, from_nonce) = wire.split_at(NONCE_AT);
        let after_nonce = &from_nonce[NONCE_SIZE..];

        let mut aad = [0u8; 56];
        aad[..NONCE_AT].copy_from_slice(before_nonce);
        aad[NONCE_AT..].copy_from_slice(after_nonce);
        aad[HOP_COUNT_AT] = 0;
        aad
    }

    /// Serializes to the 68-byte wire form.
    ///
    /// Fields are written in declaration order and every multi-byte integer
    /// is little-endian.
    #[inline]
    pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
        let mut wire = [0u8; HEADER_SIZE];
        let mut pos = 0;
        // Copies `chunk` to the current position and advances past it.
        let mut emit = |chunk: &[u8]| {
            let end = pos + chunk.len();
            wire[pos..end].copy_from_slice(chunk);
            pos = end;
        };
        emit(&self.magic.to_le_bytes());
        emit(&[
            self.version,
            self.flags.bits(),
            self.priority,
            self.hop_ttl,
            self.hop_count,
            self.frag_flags,
        ]);
        emit(&self.subprotocol_id.to_le_bytes());
        emit(&self.channel_hash.to_le_bytes());
        emit(&self.nonce);
        emit(&self.session_id.to_le_bytes());
        emit(&self.stream_id.to_le_bytes());
        emit(&self.sequence.to_le_bytes());
        emit(&self.origin_hash.to_le_bytes());
        emit(&self.subnet_id.to_le_bytes());
        emit(&self.fragment_id.to_le_bytes());
        emit(&self.fragment_offset.to_le_bytes());
        emit(&self.payload_len.to_le_bytes());
        emit(&self.event_count.to_le_bytes());
        wire
    }

    /// Parses the first `HEADER_SIZE` bytes of `data`; any trailing bytes are
    /// ignored.
    ///
    /// Returns `None` if `data` is shorter than a header or the magic does not
    /// match. The version and size limits are NOT checked here; call
    /// `validate` for that.
    ///
    /// Wire layout (byte offsets, little-endian):
    /// ```text
    ///  0 magic u16 | 2 version | 3 flags | 4 priority | 5 hop_ttl | 6 hop_count
    ///  7 frag_flags | 8 subprotocol_id u16 | 10 channel_hash u16 | 12 nonce[12]
    /// 24 session_id u64 | 32 stream_id u64 | 40 sequence u64 | 48 origin_hash u64
    /// 56 subnet_id u32 | 60 fragment_id u16 | 62 fragment_offset u16
    /// 64 payload_len u16 | 66 event_count u16
    /// ```
    #[inline]
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        let raw = data.get(..HEADER_SIZE)?;
        let rd16 = |at: usize| u16::from_le_bytes([raw[at], raw[at + 1]]);
        let rd32 = |at: usize| {
            u32::from_le_bytes([raw[at], raw[at + 1], raw[at + 2], raw[at + 3]])
        };
        let rd64 = |at: usize| {
            let mut word = [0u8; 8];
            word.copy_from_slice(&raw[at..at + 8]);
            u64::from_le_bytes(word)
        };

        let magic = rd16(0);
        if magic != MAGIC {
            return None;
        }

        let mut nonce = [0u8; NONCE_SIZE];
        nonce.copy_from_slice(&raw[12..12 + NONCE_SIZE]);

        Some(Self {
            magic,
            version: raw[2],
            flags: PacketFlags::from_bits(raw[3]),
            priority: raw[4],
            hop_ttl: raw[5],
            hop_count: raw[6],
            frag_flags: raw[7],
            subprotocol_id: rd16(8),
            channel_hash: rd16(10),
            nonce,
            session_id: rd64(24),
            stream_id: rd64(32),
            sequence: rd64(40),
            origin_hash: rd64(48),
            subnet_id: rd32(56),
            fragment_id: rd16(60),
            fragment_offset: rd16(62),
            payload_len: rd16(64),
            event_count: rd16(66),
        })
    }

    /// True when the header is acceptable for this build.
    ///
    /// That means `magic` and `version` match `MAGIC`/`VERSION`, `payload_len`
    /// is at most `MAX_PAYLOAD_SIZE`, and `event_count` is at most
    /// `MAX_EVENTS_PER_PACKET`.
    #[inline]
    pub fn validate(&self) -> bool {
        let identity_ok = self.magic == MAGIC && self.version == VERSION;
        let payload_ok = usize::from(self.payload_len) <= MAX_PAYLOAD_SIZE;
        let events_ok = self.event_count <= Self::MAX_EVENTS_PER_PACKET;
        identity_ok && payload_ok && events_ok
    }
}
/// Encoding of events inside a packet payload.
///
/// Each event is a 4-byte little-endian length prefix followed by that many
/// bytes of event data.
pub struct EventFrame;

impl EventFrame {
    /// Width in bytes of the per-event length prefix.
    pub const LEN_SIZE: usize = 4;

    /// Appends every event to `buf` as `[u32 LE length][bytes]` and returns how
    /// many bytes were appended.
    ///
    /// # Panics
    ///
    /// Panics if an event is longer than `u32::MAX` bytes, since its length
    /// cannot be represented in the prefix.
    #[inline]
    #[expect(
        clippy::expect_used,
        reason = "events larger than u32::MAX (~4 GiB) are an invariant violation upstream — a panic on encode is better than a silent length-prefix truncation that would corrupt the framed stream"
    )]
    pub fn write_events(events: &[Bytes], buf: &mut BytesMut) -> usize {
        let len_before = buf.len();
        for frame in events {
            let prefix = u32::try_from(frame.len())
                .expect("event length exceeds u32::MAX — cannot encode in 4-byte length prefix");
            buf.put_u32_le(prefix);
            buf.put_slice(frame);
        }
        buf.len() - len_before
    }

    /// Decodes up to `count` length-prefixed events from the front of `data`.
    ///
    /// Decoding stops early, returning fewer than `count` events, when fewer
    /// than `LEN_SIZE` bytes remain for a prefix or the announced length
    /// exceeds the remaining bytes. Bytes left after the last decoded event
    /// are ignored. Returned events are slices of `data` (via `split_to`).
    #[inline]
    pub fn read_events(mut data: Bytes, count: u16) -> Vec<Bytes> {
        // Every event needs at least a prefix, so the buffer length bounds the
        // event count. An oversized `count` therefore cannot force a large
        // up-front allocation.
        let capacity_hint = usize::from(count).min(data.remaining() / Self::LEN_SIZE);
        let mut decoded = Vec::with_capacity(capacity_hint);
        let mut pending = count;
        while pending > 0 && data.remaining() >= Self::LEN_SIZE {
            let frame_len = data.get_u32_le() as usize;
            if frame_len > data.remaining() {
                break;
            }
            decoded.push(data.split_to(frame_len));
            pending -= 1;
        }
        decoded
    }

    /// Number of bytes `write_events` would append for `events`.
    #[inline]
    pub fn calculate_size(events: &[Bytes]) -> usize {
        events
            .iter()
            .fold(0, |total, frame| total + (Self::LEN_SIZE + frame.len()))
    }
}
/// Body of a NACK packet (`SIZE` = 16 bytes on the wire).
///
/// `next_expected` is always reported as missing. In addition, bit `i` of
/// `missing_bitmap` reports sequence `next_expected + 1 + i` as missing.
#[derive(Clone, Copy)]
pub struct NackPayload {
    /// Base sequence number; always yielded first by `missing_sequences`.
    pub next_expected: u64,
    /// Bit `i` set means sequence `next_expected + 1 + i` is missing.
    pub missing_bitmap: u64,
}

// Hand-written so the bitmap prints in binary; a decimal u64 hides which
// sequences are missing.
impl std::fmt::Debug for NackPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NackPayload")
            .field("next_expected", &self.next_expected)
            .field(
                "missing_bitmap",
                &format_args!("{:#b}", self.missing_bitmap),
            )
            .finish()
    }
}

impl NackPayload {
    /// Serialized size: two little-endian `u64`s.
    pub const SIZE: usize = 16;

    /// Serializes as `next_expected` (u64 LE) followed by `missing_bitmap` (u64 LE).
    #[inline]
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        let mut buf = [0u8; Self::SIZE];
        buf[0..8].copy_from_slice(&self.next_expected.to_le_bytes());
        buf[8..16].copy_from_slice(&self.missing_bitmap.to_le_bytes());
        buf
    }

    /// Parses a payload produced by `to_bytes`.
    ///
    /// Returns `None` unless `data` is exactly `SIZE` bytes long; both short
    /// input and trailing bytes are rejected.
    #[inline]
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        if data.len() != Self::SIZE {
            return None;
        }
        let next_expected = u64::from_le_bytes(data[0..8].try_into().ok()?);
        let missing_bitmap = u64::from_le_bytes(data[8..16].try_into().ok()?);
        Some(Self {
            next_expected,
            missing_bitmap,
        })
    }

    /// Iterates, in ascending order, over every sequence number this NACK
    /// reports as missing.
    ///
    /// It yields `next_expected` first, then `next_expected + 1 + i` for each
    /// set bit `i` of `missing_bitmap`. Sequences that would exceed `u64::MAX`
    /// are skipped. Saturating arithmetic would instead repeat `u64::MAX` once
    /// per high set bit, requesting the same sequence several times.
    #[inline]
    pub fn missing_sequences(&self) -> impl Iterator<Item = u64> + '_ {
        let base = self.next_expected;
        let bitmap = self.missing_bitmap;
        std::iter::once(base).chain(
            (0..u64::BITS)
                .filter(move |&bit| ((bitmap >> bit) & 1) != 0)
                .filter_map(move |bit| base.checked_add(u64::from(bit) + 1)),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_header_size() {
        assert_eq!(HEADER_SIZE, 68);
        assert_eq!(std::mem::size_of::<NetHeader>(), 72);
    }

    #[test]
    fn test_header_roundtrip() {
        let nonce = [0x42u8; NONCE_SIZE];
        let header = NetHeader::new(
            0x1234567890ABCDEF,
            0xFEDCBA0987654321,
            42,
            nonce,
            1024,
            10,
            PacketFlags::RELIABLE,
        )
        .with_priority(7)
        .with_hops(32)
        .with_subprotocol(0x0100)
        .with_channel_hash(0xABCD)
        .with_subnet(0x12345678)
        .with_origin(0xDEADBEEF)
        .with_fragment(1, 512, 0x01);
        let bytes = header.to_bytes();
        let parsed = NetHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.magic, MAGIC);
        assert_eq!(parsed.version, VERSION);
        assert_eq!(parsed.flags, PacketFlags::RELIABLE);
        assert_eq!(parsed.priority, 7);
        assert_eq!(parsed.hop_ttl, 32);
        assert_eq!(parsed.hop_count, 0);
        assert_eq!(parsed.frag_flags, 0x01);
        assert_eq!(parsed.subprotocol_id, 0x0100);
        assert_eq!(parsed.channel_hash, 0xABCD);
        assert_eq!(parsed.nonce, nonce);
        assert_eq!(parsed.session_id, 0x1234567890ABCDEF);
        assert_eq!(parsed.stream_id, 0xFEDCBA0987654321);
        assert_eq!(parsed.sequence, 42);
        assert_eq!(parsed.subnet_id, 0x12345678);
        assert_eq!(parsed.origin_hash, 0xDEADBEEF);
        assert_eq!(parsed.fragment_id, 1);
        assert_eq!(parsed.fragment_offset, 512);
        assert_eq!(parsed.payload_len, 1024);
        assert_eq!(parsed.event_count, 10);
    }

    // `PAYLOAD_LEN_OFFSET` must point at the bytes `to_bytes` actually writes
    // `payload_len` to; the constant and the serializer are maintained separately.
    #[test]
    fn test_payload_len_offset_matches_wire_layout() {
        let header = NetHeader::new(0, 0, 0, [0u8; NONCE_SIZE], 0x0BCD, 7, PacketFlags::NONE);
        let bytes = header.to_bytes();
        let on_wire =
            u16::from_le_bytes([bytes[PAYLOAD_LEN_OFFSET], bytes[PAYLOAD_LEN_OFFSET + 1]]);
        assert_eq!(on_wire, 0x0BCD);
    }

    #[test]
    fn origin_hash_preserves_high_bits_across_wire_roundtrip() {
        const HIGH_BITS_HASH: u64 = 0xCAFE_BABE_DEAD_BEEF;
        // Checked at compile time: a `debug_assert!` here would be silently
        // skipped under `cargo test --release`.
        const _: () = assert!(
            HIGH_BITS_HASH > u32::MAX as u64,
            "test hash must exercise the high 32 bits"
        );
        let header = NetHeader::new(
            0xAAAA_BBBB_CCCC_DDDD,
            0,
            1,
            [0u8; NONCE_SIZE],
            64,
            1,
            PacketFlags::NONE,
        )
        .with_origin(HIGH_BITS_HASH);
        assert_eq!(header.origin_hash, HIGH_BITS_HASH);
        let bytes = header.to_bytes();
        let parsed = NetHeader::from_bytes(&bytes).expect("parse");
        assert_eq!(parsed.origin_hash, HIGH_BITS_HASH);
        let aad_full = header.aad();
        let mut tampered = header;
        tampered.origin_hash ^= 1u64 << 33;
        let aad_tampered = tampered.aad();
        assert_ne!(
            aad_full, aad_tampered,
            "AAD must authenticate all 64 bits of origin_hash"
        );
    }

    // Two origin hashes that share their low 32 bits must stay distinct after
    // a wire round-trip.
    #[test]
    fn header_validation_field_isolation() {
        const LOW_COMMON: u32 = 0xDEAD_BEEF;
        let a: u64 = LOW_COMMON as u64;
        let b: u64 = (0x4242_4242u64 << 32) | (LOW_COMMON as u64);
        assert_eq!(a as u32, b as u32);
        assert_ne!(a, b);
        let mk = |h: u64| -> NetHeader {
            NetHeader::new(1, 0, 1, [0u8; NONCE_SIZE], 0, 0, PacketFlags::NONE).with_origin(h)
        };
        let ha = NetHeader::from_bytes(&mk(a).to_bytes()).unwrap();
        let hb = NetHeader::from_bytes(&mk(b).to_bytes()).unwrap();
        assert_ne!(
            ha.origin_hash, hb.origin_hash,
            "low-32-bit collision must not collapse on the wire"
        );
    }

    #[test]
    fn test_header_validation() {
        let header = NetHeader::new(0, 0, 0, [0u8; NONCE_SIZE], 1024, 0, PacketFlags::NONE);
        assert!(header.validate());
        let mut bytes = header.to_bytes();
        bytes[0] = 0xFF;
        let invalid = NetHeader::from_bytes(&bytes);
        assert!(invalid.is_none());
    }

    #[test]
    fn test_packet_flags() {
        let flags = PacketFlags::NONE
            .with(PacketFlags::RELIABLE)
            .with(PacketFlags::PRIORITY);
        assert!(flags.is_reliable());
        assert!(flags.contains(PacketFlags::PRIORITY));
        assert!(!flags.is_handshake());
        let cleared = flags.without(PacketFlags::RELIABLE);
        assert!(!cleared.is_reliable());
        assert!(cleared.contains(PacketFlags::PRIORITY));
    }

    #[test]
    fn test_aad() {
        let header = NetHeader::new(
            0x1234567890ABCDEF,
            0xFEDCBA0987654321,
            42,
            [0u8; NONCE_SIZE],
            1024,
            10,
            PacketFlags::RELIABLE,
        )
        .with_priority(5)
        .with_subnet(0x42);
        let aad = header.aad();
        assert_eq!(aad.len(), 56);
        assert_eq!(u16::from_le_bytes([aad[0], aad[1]]), MAGIC);
        assert_eq!(aad[2], VERSION);
        assert_eq!(aad[3], PacketFlags::RELIABLE.bits());
        assert_eq!(aad[4], 5);
        // The hop_count slot is always masked to zero.
        assert_eq!(aad[6], 0);
        // Pin the remaining field offsets so layout drift is caught.
        assert_eq!(aad[12..20], 0x1234567890ABCDEFu64.to_le_bytes());
        assert_eq!(aad[20..28], 0xFEDCBA0987654321u64.to_le_bytes());
        assert_eq!(aad[28..36], 42u64.to_le_bytes());
        assert_eq!(aad[44..48], 0x42u32.to_le_bytes());
        assert_eq!(aad[52..54], 1024u16.to_le_bytes());
        assert_eq!(aad[54..56], 10u16.to_le_bytes());
    }

    #[test]
    fn test_aad_excludes_nonce() {
        let header = NetHeader::new(7, 8, 9, [0u8; NONCE_SIZE], 32, 1, PacketFlags::NONE);
        let mut renonced = header;
        renonced.nonce = [0xA5; NONCE_SIZE];
        assert_eq!(
            header.aad(),
            renonced.aad(),
            "nonce is carried on the wire but is not part of the AAD"
        );
    }

    #[test]
    fn test_event_frame_roundtrip() {
        let events = vec![
            Bytes::from_static(b"event1"),
            Bytes::from_static(b"event2"),
            Bytes::from_static(b"event3"),
        ];
        let mut buf = BytesMut::with_capacity(256);
        let size = EventFrame::write_events(&events, &mut buf);
        assert_eq!(size, 3 * 4 + 6 + 6 + 6);
        let parsed = EventFrame::read_events(buf.freeze(), 3);
        assert_eq!(parsed.len(), 3);
        assert_eq!(&parsed[0][..], b"event1");
        assert_eq!(&parsed[1][..], b"event2");
        assert_eq!(&parsed[2][..], b"event3");
    }

    #[test]
    fn test_event_frame_length_prefix_fits_u32() {
        let big = Bytes::from(vec![0xABu8; 64 * 1024]);
        let events = vec![big.clone()];
        let mut buf = BytesMut::with_capacity(64 * 1024 + 8);
        let size = EventFrame::write_events(&events, &mut buf);
        assert_eq!(size, EventFrame::LEN_SIZE + big.len());
        let prefix = u32::from_le_bytes(buf[0..4].try_into().unwrap());
        assert_eq!(prefix as usize, big.len());
    }

    #[test]
    fn test_nack_payload_roundtrip() {
        let nack = NackPayload {
            next_expected: 100,
            missing_bitmap: 0b1010_0101,
        };
        let bytes = nack.to_bytes();
        let parsed = NackPayload::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.next_expected, 100);
        assert_eq!(parsed.missing_bitmap, 0b1010_0101);
        let missing: Vec<_> = parsed.missing_sequences().collect();
        assert_eq!(missing, vec![100, 101, 103, 106, 108]);
    }

    #[test]
    fn test_nack_payload_rejects_trailing_bytes() {
        let nack = NackPayload {
            next_expected: 1,
            missing_bitmap: 0b10,
        };
        let mut bytes = nack.to_bytes().to_vec();
        bytes.push(0xFF);
        assert!(
            NackPayload::from_bytes(&bytes).is_none(),
            "NackPayload::from_bytes must reject buffers longer than SIZE"
        );
    }

    #[test]
    fn test_validate_rejects_excessive_event_count() {
        let header = NetHeader::new(0, 0, 0, [0u8; NONCE_SIZE], 100, 10, PacketFlags::NONE);
        assert!(header.validate());
        let header = NetHeader::new(
            0,
            0,
            0,
            [0u8; NONCE_SIZE],
            100,
            NetHeader::MAX_EVENTS_PER_PACKET + 1,
            PacketFlags::NONE,
        );
        assert!(!header.validate());
    }

    #[test]
    fn test_read_events_caps_allocation() {
        let data = Bytes::from_static(b"");
        let events = EventFrame::read_events(data, u16::MAX);
        assert!(events.is_empty());
        assert!(events.capacity() <= 1);
    }

    #[test]
    fn test_regression_hop_count_excluded_from_aad() {
        let header1 = NetHeader::new(
            0x1234,
            0x5678,
            42,
            [0u8; NONCE_SIZE],
            100,
            5,
            PacketFlags::NONE,
        );
        let mut header2 = header1;
        header2.hop_count = 99;
        assert_eq!(
            header1.aad(),
            header2.aad(),
            "AAD must be identical regardless of hop_count (mutable in transit)"
        );
    }
}