Skip to main content

proof_engine/networking/
protocol.rs

1//! Multiplayer networking protocol: packet encoding, decoding, auth tokens,
2//! bandwidth tracking, and packet filtering.
3//!
4//! Custom binary wire format (no serde dependency):
//! - Fixed-width big-endian header fields (sequence/ack are `u32`), plus LEB128
//!   variable-length integer helpers for payload data
//! - Bit-packed flags in the header
//! - Delta-encoding helpers for position payloads
//! - 22-byte fixed header, variable-length payload
9
10// ─── Constants ───────────────────────────────────────────────────────────────
11
/// Wire protocol version negotiated on connect.
pub const PROTOCOL_VERSION: u8 = 1;

/// Magic bytes at the start of every packet (`PEMP` = Proof-Engine Multiplayer Protocol).
pub const MAGIC: [u8; 4] = *b"PEMP";

/// Maximum allowed payload length in bytes.
///
/// This is 65 535 (64 KiB − 1): the largest value the header's `u16`
/// payload-length field can represent, so the limit and the wire field can
/// never disagree.
pub const MAX_PAYLOAD_LEN: usize = u16::MAX as usize;

/// Maximum packets stored in replay/filter history.
///
/// NOTE(review): the per-peer replay window in `PacketFilter` tracks only 64
/// sequences; confirm where this larger limit is applied.
pub const FILTER_HISTORY_LEN: usize = 1024;
23
24// ─── PacketKind ──────────────────────────────────────────────────────────────
25
26/// Discriminant for every packet type on the wire.
27///
28/// Encoded as a `u16` in the packet header so `Custom(u16)` can carry
29/// application-defined packet kinds without colliding with the well-known range
30/// (0x0000–0x001D).
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
32pub enum PacketKind {
33    Connect,           // 0x00
34    Disconnect,        // 0x01
35    Heartbeat,         // 0x02
36    StateUpdate,       // 0x03
37    InputEvent,        // 0x04
38    ChatMessage,       // 0x05
39    SpawnEntity,       // 0x06
40    DespawnEntity,     // 0x07
41    UpdateTransform,   // 0x08
42    AnimationState,    // 0x09
43    SoundEvent,        // 0x0A
44    ParticleEvent,     // 0x0B
45    ForceFieldSync,    // 0x0C
46    CameraUpdate,      // 0x0D
47    ScriptCall,        // 0x0E
48    ScriptResult,      // 0x0F
49    Ack,               // 0x10
50    Nack,              // 0x11
51    Ping,              // 0x12
52    Pong,              // 0x13
53    LobbyJoin,         // 0x14
54    LobbyLeave,        // 0x15
55    GameStart,         // 0x16
56    GameEnd,           // 0x17
57    VoteKick,          // 0x18
58    VoteBan,           // 0x19
59    FileRequest,       // 0x1A
60    FileChunk,         // 0x1B
61    Error,             // 0x1C
62    Custom(u16),       // 0x8000–0xFFFF (high-bit set)
63}
64
65impl PacketKind {
66    /// Encode to a `u16` discriminant.
67    pub fn to_u16(self) -> u16 {
68        match self {
69            PacketKind::Connect        => 0x00,
70            PacketKind::Disconnect     => 0x01,
71            PacketKind::Heartbeat      => 0x02,
72            PacketKind::StateUpdate    => 0x03,
73            PacketKind::InputEvent     => 0x04,
74            PacketKind::ChatMessage    => 0x05,
75            PacketKind::SpawnEntity    => 0x06,
76            PacketKind::DespawnEntity  => 0x07,
77            PacketKind::UpdateTransform => 0x08,
78            PacketKind::AnimationState => 0x09,
79            PacketKind::SoundEvent     => 0x0A,
80            PacketKind::ParticleEvent  => 0x0B,
81            PacketKind::ForceFieldSync => 0x0C,
82            PacketKind::CameraUpdate   => 0x0D,
83            PacketKind::ScriptCall     => 0x0E,
84            PacketKind::ScriptResult   => 0x0F,
85            PacketKind::Ack            => 0x10,
86            PacketKind::Nack           => 0x11,
87            PacketKind::Ping           => 0x12,
88            PacketKind::Pong           => 0x13,
89            PacketKind::LobbyJoin      => 0x14,
90            PacketKind::LobbyLeave     => 0x15,
91            PacketKind::GameStart      => 0x16,
92            PacketKind::GameEnd        => 0x17,
93            PacketKind::VoteKick       => 0x18,
94            PacketKind::VoteBan        => 0x19,
95            PacketKind::FileRequest    => 0x1A,
96            PacketKind::FileChunk      => 0x1B,
97            PacketKind::Error          => 0x1C,
98            PacketKind::Custom(v)      => 0x8000 | v,
99        }
100    }
101
102    /// Decode from a `u16` discriminant.
103    pub fn from_u16(v: u16) -> Result<Self, ProtocolError> {
104        if v & 0x8000 != 0 {
105            return Ok(PacketKind::Custom(v & 0x7FFF));
106        }
107        match v {
108            0x00 => Ok(PacketKind::Connect),
109            0x01 => Ok(PacketKind::Disconnect),
110            0x02 => Ok(PacketKind::Heartbeat),
111            0x03 => Ok(PacketKind::StateUpdate),
112            0x04 => Ok(PacketKind::InputEvent),
113            0x05 => Ok(PacketKind::ChatMessage),
114            0x06 => Ok(PacketKind::SpawnEntity),
115            0x07 => Ok(PacketKind::DespawnEntity),
116            0x08 => Ok(PacketKind::UpdateTransform),
117            0x09 => Ok(PacketKind::AnimationState),
118            0x0A => Ok(PacketKind::SoundEvent),
119            0x0B => Ok(PacketKind::ParticleEvent),
120            0x0C => Ok(PacketKind::ForceFieldSync),
121            0x0D => Ok(PacketKind::CameraUpdate),
122            0x0E => Ok(PacketKind::ScriptCall),
123            0x0F => Ok(PacketKind::ScriptResult),
124            0x10 => Ok(PacketKind::Ack),
125            0x11 => Ok(PacketKind::Nack),
126            0x12 => Ok(PacketKind::Ping),
127            0x13 => Ok(PacketKind::Pong),
128            0x14 => Ok(PacketKind::LobbyJoin),
129            0x15 => Ok(PacketKind::LobbyLeave),
130            0x16 => Ok(PacketKind::GameStart),
131            0x17 => Ok(PacketKind::GameEnd),
132            0x18 => Ok(PacketKind::VoteKick),
133            0x19 => Ok(PacketKind::VoteBan),
134            0x1A => Ok(PacketKind::FileRequest),
135            0x1B => Ok(PacketKind::FileChunk),
136            0x1C => Ok(PacketKind::Error),
137            other => Err(ProtocolError::UnknownPacketKind(other)),
138        }
139    }
140}
141
142// ─── CompressionHint ─────────────────────────────────────────────────────────
143
144/// Indicates how the payload bytes are compressed.
145/// The receiver must use matching decompression.
146#[derive(Debug, Clone, Copy, PartialEq, Eq)]
147pub enum CompressionHint {
148    None,
149    Zlib,
150    Lz4,
151}
152
153impl CompressionHint {
154    pub fn to_u8(self) -> u8 {
155        match self {
156            CompressionHint::None => 0,
157            CompressionHint::Zlib => 1,
158            CompressionHint::Lz4  => 2,
159        }
160    }
161
162    pub fn from_u8(v: u8) -> Result<Self, ProtocolError> {
163        match v {
164            0 => Ok(CompressionHint::None),
165            1 => Ok(CompressionHint::Zlib),
166            2 => Ok(CompressionHint::Lz4),
167            _ => Err(ProtocolError::InvalidCompression(v)),
168        }
169    }
170}
171
172// ─── PacketHeader ─────────────────────────────────────────────────────────────
173
174/// Fixed-size header that precedes every packet on the wire.
175///
176/// Wire layout (20 bytes):
177/// ```text
178/// [0..4]   magic       PEMP
179/// [4]      version     u8
180/// [5]      flags       u8  (bits: 0-1 compression, 2 reliable, 3 ordered, 4 fragmented, 5-7 reserved)
181/// [6..8]   kind        u16 big-endian
182/// [8..12]  sequence    u32 big-endian
183/// [12..16] ack         u32 big-endian
184/// [16..20] ack_bits    u32 big-endian
185/// [20..22] payload_len u16 big-endian
186/// ```
187/// Total header: 22 bytes.
188#[derive(Debug, Clone, PartialEq, Eq)]
189pub struct PacketHeader {
190    pub version:     u8,
191    pub flags:       u8,
192    pub kind:        PacketKind,
193    pub sequence:    u32,
194    pub ack:         u32,
195    pub ack_bits:    u32,
196    pub payload_len: u16,
197}
198
199impl PacketHeader {
200    pub const SIZE: usize = 22;
201
202    /// Flag bit: payload is reliable (must be acked).
203    pub const FLAG_RELIABLE:   u8 = 0b0000_0100;
204    /// Flag bit: channel is ordered.
205    pub const FLAG_ORDERED:    u8 = 0b0000_1000;
206    /// Flag bit: packet is a fragment of a larger message.
207    pub const FLAG_FRAGMENTED: u8 = 0b0001_0000;
208
209    /// Extract compression hint from flags bits 0-1.
210    pub fn compression(&self) -> Result<CompressionHint, ProtocolError> {
211        CompressionHint::from_u8(self.flags & 0x03)
212    }
213
214    pub fn is_reliable(&self) -> bool {
215        self.flags & Self::FLAG_RELIABLE != 0
216    }
217
218    pub fn is_ordered(&self) -> bool {
219        self.flags & Self::FLAG_ORDERED != 0
220    }
221
222    pub fn is_fragmented(&self) -> bool {
223        self.flags & Self::FLAG_FRAGMENTED != 0
224    }
225}
226
227// ─── Packet ───────────────────────────────────────────────────────────────────
228
229/// A fully-parsed network packet ready for dispatch.
230#[derive(Debug, Clone, PartialEq, Eq)]
231pub struct Packet {
232    pub kind:     PacketKind,
233    pub sequence: u32,
234    pub ack:      u32,
235    pub ack_bits: u32,
236    pub payload:  Vec<u8>,
237    /// Extra header flags preserved for routing decisions.
238    pub flags:    u8,
239}
240
241impl Packet {
242    pub fn new(kind: PacketKind, sequence: u32, ack: u32, ack_bits: u32, payload: Vec<u8>) -> Self {
243        Self { kind, sequence, ack, ack_bits, payload, flags: 0 }
244    }
245
246    pub fn with_flags(mut self, flags: u8) -> Self {
247        self.flags = flags;
248        self
249    }
250
251    pub fn is_reliable(&self) -> bool {
252        self.flags & PacketHeader::FLAG_RELIABLE != 0
253    }
254
255    /// Heartbeat shorthand.
256    pub fn heartbeat(sequence: u32, ack: u32, ack_bits: u32) -> Self {
257        Self::new(PacketKind::Heartbeat, sequence, ack, ack_bits, Vec::new())
258    }
259
260    /// Ping with 8-byte timestamp payload.
261    pub fn ping(sequence: u32, ack: u32, ack_bits: u32, timestamp_us: u64) -> Self {
262        let mut payload = Vec::with_capacity(8);
263        payload.extend_from_slice(&timestamp_us.to_be_bytes());
264        Self::new(PacketKind::Ping, sequence, ack, ack_bits, payload)
265    }
266
267    /// Pong mirrors the ping timestamp plus local receive timestamp.
268    pub fn pong(sequence: u32, ack: u32, ack_bits: u32, ping_ts: u64, recv_ts: u64) -> Self {
269        let mut payload = Vec::with_capacity(16);
270        payload.extend_from_slice(&ping_ts.to_be_bytes());
271        payload.extend_from_slice(&recv_ts.to_be_bytes());
272        Self::new(PacketKind::Pong, sequence, ack, ack_bits, payload)
273    }
274}
275
276// ─── ProtocolError ────────────────────────────────────────────────────────────
277
/// Every failure that packet encoding, decoding or filtering can report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolError {
    /// The input ended before a complete structure could be read.
    BufferTooShort { needed: usize, got: usize },
    /// The first four bytes are not [`MAGIC`].
    BadMagic([u8; 4]),
    /// The packet's version byte differs from the expected one.
    VersionMismatch { expected: u8, got: u8 },
    /// The header declares more payload than the buffer holds.
    PayloadTruncated { declared: usize, available: usize },
    /// The payload is larger than the protocol allows.
    PayloadTooLarge(usize),
    /// The kind discriminant is not a known packet kind.
    UnknownPacketKind(u16),
    /// The compression tag is not a known compression hint.
    InvalidCompression(u8),
    /// The sequence number was already seen (or is too old) for this peer.
    ReplayDetected { sequence: u32 },
    /// Free-form encoding/decoding failure.
    EncodeError(String),
}

impl std::fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BufferTooShort { needed, got } => {
                write!(f, "buffer too short: need {} bytes, got {}", needed, got)
            }
            Self::BadMagic([a, b, c, d]) => {
                write!(f, "bad magic bytes: {:02x}{:02x}{:02x}{:02x}", a, b, c, d)
            }
            Self::VersionMismatch { expected, got } => {
                write!(f, "version mismatch: expected {}, got {}", expected, got)
            }
            Self::PayloadTruncated { declared, available } => write!(
                f,
                "payload truncated: declared {} bytes, only {} available",
                declared, available
            ),
            Self::PayloadTooLarge(len) => write!(f, "payload too large: {} bytes", len),
            Self::UnknownPacketKind(kind) => write!(f, "unknown packet kind: 0x{:04x}", kind),
            Self::InvalidCompression(tag) => write!(f, "invalid compression tag: {}", tag),
            Self::ReplayDetected { sequence } => {
                write!(f, "replay attack detected: sequence {}", sequence)
            }
            Self::EncodeError(msg) => write!(f, "encode error: {}", msg),
        }
    }
}

impl std::error::Error for ProtocolError {}
327
328// ─── PacketEncoder ────────────────────────────────────────────────────────────
329
330/// Serializes `Packet` values to a contiguous byte buffer.
331///
332/// The encoder does NOT own any state between calls; call `encode` for each
333/// packet and write the returned `Vec<u8>` to your socket.
334pub struct PacketEncoder {
335    /// Compression hint written into the header flags.
336    pub compression: CompressionHint,
337    /// Whether the reliable flag should be set on encoded packets.
338    pub reliable: bool,
339    /// Whether the ordered flag should be set on encoded packets.
340    pub ordered: bool,
341}
342
343impl Default for PacketEncoder {
344    fn default() -> Self {
345        Self {
346            compression: CompressionHint::None,
347            reliable: false,
348            ordered: false,
349        }
350    }
351}
352
353impl PacketEncoder {
354    pub fn new() -> Self {
355        Self::default()
356    }
357
358    /// Build flags byte from encoder settings and any packet-level flags.
359    fn build_flags(&self, extra: u8) -> u8 {
360        let mut f = self.compression.to_u8(); // bits 0-1
361        if self.reliable { f |= PacketHeader::FLAG_RELIABLE; }
362        if self.ordered  { f |= PacketHeader::FLAG_ORDERED; }
363        f | (extra & PacketHeader::FLAG_FRAGMENTED)
364    }
365
366    /// Encode a `Packet` to bytes.  Returns `Err` if payload exceeds limit.
367    pub fn encode(&self, packet: &Packet) -> Result<Vec<u8>, ProtocolError> {
368        let payload_len = packet.payload.len();
369        if payload_len > MAX_PAYLOAD_LEN {
370            return Err(ProtocolError::PayloadTooLarge(payload_len));
371        }
372        let total = PacketHeader::SIZE + payload_len;
373        let mut buf = Vec::with_capacity(total);
374
375        // Magic
376        buf.extend_from_slice(&MAGIC);
377        // Version
378        buf.push(PROTOCOL_VERSION);
379        // Flags
380        buf.push(self.build_flags(packet.flags));
381        // Kind (u16 BE)
382        buf.extend_from_slice(&packet.kind.to_u16().to_be_bytes());
383        // Sequence (u32 BE)
384        buf.extend_from_slice(&packet.sequence.to_be_bytes());
385        // Ack (u32 BE)
386        buf.extend_from_slice(&packet.ack.to_be_bytes());
387        // Ack-bits (u32 BE)
388        buf.extend_from_slice(&packet.ack_bits.to_be_bytes());
389        // Payload length (u16 BE)
390        buf.extend_from_slice(&(payload_len as u16).to_be_bytes());
391        // Payload
392        buf.extend_from_slice(&packet.payload);
393
394        debug_assert_eq!(buf.len(), total);
395        Ok(buf)
396    }
397
398    /// Encode a sequence of packets back-to-back (UDP batch / TCP stream).
399    pub fn encode_batch(&self, packets: &[Packet]) -> Result<Vec<u8>, ProtocolError> {
400        let mut out = Vec::new();
401        for p in packets {
402            out.extend(self.encode(p)?);
403        }
404        Ok(out)
405    }
406
407    /// Variable-length encode a u64 (LEB128).
408    pub fn encode_varint(value: u64, out: &mut Vec<u8>) {
409        let mut v = value;
410        loop {
411            let byte = (v & 0x7F) as u8;
412            v >>= 7;
413            if v == 0 {
414                out.push(byte);
415                break;
416            } else {
417                out.push(byte | 0x80);
418            }
419        }
420    }
421
422    /// Delta-encode an f32 position value as a 16-bit fixed-point delta.
423    /// Returns the quantised delta in centimetres (±327.67 m range at 1 cm precision).
424    pub fn encode_position_delta(from: f32, to: f32) -> i16 {
425        let delta_cm = ((to - from) * 100.0).round();
426        delta_cm.clamp(i16::MIN as f32, i16::MAX as f32) as i16
427    }
428
429    /// Bit-pack up to 8 booleans into a single byte.
430    pub fn pack_bools(flags: &[bool]) -> u8 {
431        let mut byte = 0u8;
432        for (i, &b) in flags.iter().take(8).enumerate() {
433            if b { byte |= 1 << i; }
434        }
435        byte
436    }
437}
438
439// ─── PacketDecoder ────────────────────────────────────────────────────────────
440
441/// Deserializes packets from a byte buffer with strict bounds checking.
442pub struct PacketDecoder {
443    /// When `true` reject packets whose version differs from `PROTOCOL_VERSION`.
444    pub strict_version: bool,
445}
446
447impl Default for PacketDecoder {
448    fn default() -> Self {
449        Self { strict_version: true }
450    }
451}
452
453impl PacketDecoder {
454    pub fn new() -> Self {
455        Self::default()
456    }
457
458    /// Decode exactly one packet starting at the beginning of `buf`.
459    /// Returns `(packet, bytes_consumed)` on success.
460    pub fn decode(&self, buf: &[u8]) -> Result<(Packet, usize), ProtocolError> {
461        // Minimum header check
462        if buf.len() < PacketHeader::SIZE {
463            return Err(ProtocolError::BufferTooShort {
464                needed: PacketHeader::SIZE,
465                got:    buf.len(),
466            });
467        }
468
469        // Magic
470        let magic: [u8; 4] = buf[0..4].try_into().unwrap();
471        if magic != MAGIC {
472            return Err(ProtocolError::BadMagic(magic));
473        }
474
475        let version     = buf[4];
476        let flags       = buf[5];
477        let kind_raw    = u16::from_be_bytes([buf[6], buf[7]]);
478        let sequence    = u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]);
479        let ack         = u32::from_be_bytes([buf[12], buf[13], buf[14], buf[15]]);
480        let ack_bits    = u32::from_be_bytes([buf[16], buf[17], buf[18], buf[19]]);
481        let payload_len = u16::from_be_bytes([buf[20], buf[21]]) as usize;
482
483        if self.strict_version && version != PROTOCOL_VERSION {
484            return Err(ProtocolError::VersionMismatch {
485                expected: PROTOCOL_VERSION,
486                got:      version,
487            });
488        }
489
490        if payload_len > MAX_PAYLOAD_LEN {
491            return Err(ProtocolError::PayloadTooLarge(payload_len));
492        }
493
494        let total = PacketHeader::SIZE + payload_len;
495        if buf.len() < total {
496            return Err(ProtocolError::PayloadTruncated {
497                declared:  payload_len,
498                available: buf.len().saturating_sub(PacketHeader::SIZE),
499            });
500        }
501
502        let kind = PacketKind::from_u16(kind_raw)?;
503        let payload = buf[PacketHeader::SIZE..total].to_vec();
504
505        let packet = Packet {
506            kind,
507            sequence,
508            ack,
509            ack_bits,
510            payload,
511            flags,
512        };
513
514        Ok((packet, total))
515    }
516
517    /// Decode all packets packed end-to-end in `buf`.
518    pub fn decode_all(&self, buf: &[u8]) -> Result<Vec<Packet>, ProtocolError> {
519        let mut packets = Vec::new();
520        let mut offset  = 0usize;
521        while offset < buf.len() {
522            let (pkt, consumed) = self.decode(&buf[offset..])?;
523            packets.push(pkt);
524            offset += consumed;
525        }
526        Ok(packets)
527    }
528
529    /// Decode a variable-length integer (LEB128) from `buf` at `offset`.
530    /// Returns `(value, new_offset)`.
531    pub fn decode_varint(buf: &[u8], offset: usize) -> Result<(u64, usize), ProtocolError> {
532        let mut result = 0u64;
533        let mut shift  = 0u32;
534        let mut pos    = offset;
535        loop {
536            if pos >= buf.len() {
537                return Err(ProtocolError::BufferTooShort {
538                    needed: pos + 1,
539                    got:    buf.len(),
540                });
541            }
542            let byte = buf[pos] as u64;
543            pos += 1;
544            result |= (byte & 0x7F) << shift;
545            if byte & 0x80 == 0 {
546                break;
547            }
548            shift += 7;
549            if shift >= 64 {
550                return Err(ProtocolError::EncodeError("varint overflow".into()));
551            }
552        }
553        Ok((result, pos))
554    }
555
556    /// Unpack up to 8 booleans from a flags byte.
557    pub fn unpack_bools(byte: u8, count: usize) -> [bool; 8] {
558        let mut out = [false; 8];
559        for i in 0..count.min(8) {
560            out[i] = (byte >> i) & 1 != 0;
561        }
562        out
563    }
564
565    /// Decode a 16-bit position delta back to an f32 offset.
566    pub fn decode_position_delta(delta: i16) -> f32 {
567        delta as f32 / 100.0
568    }
569}
570
571// ─── ConnectionToken ─────────────────────────────────────────────────────────
572
573/// Opaque token issued by the auth server, embedded in the Connect packet.
574///
575/// The server verifies `client_id` matches `server_key` HMAC and that
576/// `expires_at` (Unix seconds) has not elapsed.
577#[derive(Debug, Clone, PartialEq, Eq)]
578pub struct ConnectionToken {
579    pub client_id:  u64,
580    pub server_key: [u8; 16],
581    pub expires_at: u64,
582}
583
584impl ConnectionToken {
585    pub const SIZE: usize = 32; // 8 + 16 + 8
586
587    pub fn new(client_id: u64, server_key: [u8; 16], expires_at: u64) -> Self {
588        Self { client_id, server_key, expires_at }
589    }
590
591    /// Serialize to exactly 32 bytes.
592    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
593        let mut out = [0u8; Self::SIZE];
594        out[0..8].copy_from_slice(&self.client_id.to_be_bytes());
595        out[8..24].copy_from_slice(&self.server_key);
596        out[24..32].copy_from_slice(&self.expires_at.to_be_bytes());
597        out
598    }
599
600    /// Deserialize from 32 bytes.
601    pub fn from_bytes(b: &[u8]) -> Result<Self, ProtocolError> {
602        if b.len() < Self::SIZE {
603            return Err(ProtocolError::BufferTooShort {
604                needed: Self::SIZE,
605                got:    b.len(),
606            });
607        }
608        let client_id = u64::from_be_bytes(b[0..8].try_into().unwrap());
609        let server_key: [u8; 16] = b[8..24].try_into().unwrap();
610        let expires_at = u64::from_be_bytes(b[24..32].try_into().unwrap());
611        Ok(Self { client_id, server_key, expires_at })
612    }
613
614    /// Returns `true` when the token has not expired relative to `now_secs`.
615    pub fn is_valid(&self, now_secs: u64) -> bool {
616        now_secs < self.expires_at
617    }
618
619    /// Produce a simple 4-byte checksum used to verify the key field.
620    /// Real deployments should use HMAC-SHA256; this is a stand-in.
621    pub fn checksum(&self) -> u32 {
622        let mut h = 0x811c9dc5u32;
623        for b in &self.server_key {
624            h ^= *b as u32;
625            h = h.wrapping_mul(0x01000193);
626        }
627        h ^= self.client_id as u32;
628        h = h.wrapping_mul(0x01000193);
629        h
630    }
631}
632
633// ─── PacketFilter ─────────────────────────────────────────────────────────────
634
635/// Tracks recently seen sequence numbers to detect replay attacks and
636/// malformed packets before they reach higher-level code.
637pub struct PacketFilter {
638    /// Circular buffer of recently seen sequence numbers per peer.
639    seen: std::collections::HashMap<u64, SeenWindow>,
640    /// Maximum packets that can arrive with the same sequence before rejection.
641    max_duplicates: u32,
642}
643
/// Sliding-window replay detection for a single peer.
///
/// Keeps the highest accepted sequence plus a 64-bit history of which of the
/// 64 sequences at or below it have been seen. "Ahead" and "behind" use
/// wrap-around (serial-number) arithmetic on `u32`.
struct SeenWindow {
    /// Highest sequence seen so far (only meaningful once `bits != 0`).
    highest: u32,
    /// Bitset: bit i set means (highest - i) has been seen.
    /// Every accepted packet sets at least one bit, so `0` means "nothing seen yet".
    bits: u64,
}

impl SeenWindow {
    fn new() -> Self {
        Self { highest: 0, bits: 0 }
    }

    /// Record `seq`. Returns `true` if the sequence is new, or `false` if it is
    /// a duplicate, or if it is 64 or more sequences behind the highest accepted
    /// one (too old to tell).
    fn check_and_insert(&mut self, seq: u32) -> bool {
        if self.bits == 0 {
            // First packet from this peer: anchor the window on it. Without this,
            // the placeholder `highest = 0` made any first sequence >= 2^31 look
            // "too old", and every later packet from that peer was rejected until
            // the counter wrapped.
            self.highest = seq;
            self.bits = 1;
            return true;
        }
        let diff = self.highest.wrapping_sub(seq);
        if seq == self.highest && self.bits & 1 != 0 {
            // exact duplicate of highest
            return false;
        }
        if diff < 64 && diff > 0 {
            // Older packet within window
            let mask = 1u64 << diff;
            if self.bits & mask != 0 {
                return false; // already seen
            }
            self.bits |= mask;
            return true;
        }
        let advance = seq.wrapping_sub(self.highest);
        if advance < 0x8000_0000 {
            // New highest: slide the history up by `advance` positions.
            if advance >= 64 {
                self.bits = 1;
            } else {
                self.bits = (self.bits << advance) | 1;
            }
            self.highest = seq;
            return true;
        }
        // Packet is too old (64 or more below highest) — reject
        false
    }
}
688
689impl PacketFilter {
690    pub fn new() -> Self {
691        Self {
692            seen: std::collections::HashMap::new(),
693            max_duplicates: 0,
694        }
695    }
696
697    /// Register a peer and allow tracking for it.
698    pub fn register_peer(&mut self, peer_id: u64) {
699        self.seen.entry(peer_id).or_insert_with(SeenWindow::new);
700    }
701
702    /// Remove tracking for a disconnected peer.
703    pub fn remove_peer(&mut self, peer_id: u64) {
704        self.seen.remove(&peer_id);
705    }
706
707    /// Returns `Ok(())` if `packet` should be accepted, `Err` otherwise.
708    pub fn check(&mut self, peer_id: u64, packet: &Packet) -> Result<(), ProtocolError> {
709        // Basic sanity: payload size within declared limits
710        if packet.payload.len() > MAX_PAYLOAD_LEN {
711            return Err(ProtocolError::PayloadTooLarge(packet.payload.len()));
712        }
713
714        // Replay detection
715        let window = self.seen.entry(peer_id).or_insert_with(SeenWindow::new);
716        if !window.check_and_insert(packet.sequence) {
717            return Err(ProtocolError::ReplayDetected { sequence: packet.sequence });
718        }
719
720        Ok(())
721    }
722
723    /// Resets all state (e.g. on reconnect).
724    pub fn reset(&mut self) {
725        self.seen.clear();
726    }
727}
728
729impl Default for PacketFilter {
730    fn default() -> Self {
731        Self::new()
732    }
733}
734
735// ─── BandwidthTracker ─────────────────────────────────────────────────────────
736
/// Rolling-window bandwidth meter.
///
/// Usage:
/// - Call [`record_send`](Self::record_send) / [`record_recv`](Self::record_recv)
///   with a byte count and the current millisecond timestamp.
/// - Query [`bytes_per_sec_up`](Self::bytes_per_sec_up) /
///   [`bytes_per_sec_down`](Self::bytes_per_sec_down) for the rolling rate.
///
/// Each record is stored as a `(timestamp_ms, bytes)` entry. It is dropped once
/// it falls out of the window on a later `record_*` call.
pub struct BandwidthTracker {
    /// Length of the rolling window in milliseconds.
    /// With a value of 0 the rate queries return `0.0`.
    pub window_ms: u64,
    send_buckets: std::collections::VecDeque<(u64, usize)>,
    recv_buckets: std::collections::VecDeque<(u64, usize)>,
    total_sent: u64,
    total_recv: u64,
}

impl BandwidthTracker {
    pub fn new(window_ms: u64) -> Self {
        Self {
            window_ms,
            send_buckets: std::collections::VecDeque::new(),
            recv_buckets: std::collections::VecDeque::new(),
            total_sent: 0,
            total_recv: 0,
        }
    }

    /// Default 1-second rolling window.
    pub fn default_window() -> Self {
        Self::new(1000)
    }

    /// Record `bytes` sent at time `now_ms`.
    pub fn record_send(&mut self, bytes: usize, now_ms: u64) {
        self.total_sent += bytes as u64;
        self.send_buckets.push_back((now_ms, bytes));
        self.evict_old(now_ms);
    }

    /// Record `bytes` received at time `now_ms`.
    pub fn record_recv(&mut self, bytes: usize, now_ms: u64) {
        self.total_recv += bytes as u64;
        self.recv_buckets.push_back((now_ms, bytes));
        self.evict_old(now_ms);
    }

    /// Drop entries older than the window, in both directions.
    fn evict_old(&mut self, now_ms: u64) {
        let cutoff = now_ms.saturating_sub(self.window_ms);
        Self::evict_before(&mut self.send_buckets, cutoff);
        Self::evict_before(&mut self.recv_buckets, cutoff);
    }

    /// Pop leading entries stamped before `cutoff`. Stops at the first entry
    /// that is not older, so it assumes timestamps are recorded in order.
    fn evict_before(buckets: &mut std::collections::VecDeque<(u64, usize)>, cutoff: u64) {
        while buckets.front().map_or(false, |&(ts, _)| ts < cutoff) {
            buckets.pop_front();
        }
    }

    /// Bytes per second upload over the rolling window.
    pub fn bytes_per_sec_up(&self, now_ms: u64) -> f64 {
        self.rate(&self.send_buckets, now_ms)
    }

    /// Bytes per second download over the rolling window.
    pub fn bytes_per_sec_down(&self, now_ms: u64) -> f64 {
        self.rate(&self.recv_buckets, now_ms)
    }

    /// Sum of bytes stamped in `[now_ms - window_ms, ∞)`, divided by the window length.
    ///
    /// Returns `0.0` for a zero-length window rather than the `inf`/`NaN` a
    /// division by zero would produce, since those poison downstream averages.
    fn rate(&self, buckets: &std::collections::VecDeque<(u64, usize)>, now_ms: u64) -> f64 {
        if self.window_ms == 0 {
            return 0.0;
        }
        let cutoff = now_ms.saturating_sub(self.window_ms);
        let sum: usize = buckets
            .iter()
            .filter(|&&(ts, _)| ts >= cutoff)
            .map(|&(_, b)| b)
            .sum();
        (sum as f64) / (self.window_ms as f64 / 1000.0)
    }

    /// Lifetime total of bytes passed to `record_send`.
    pub fn total_bytes_sent(&self) -> u64 { self.total_sent }
    /// Lifetime total of bytes passed to `record_recv`.
    pub fn total_bytes_recv(&self) -> u64 { self.total_recv }
}
814
815// ─── Tests ────────────────────────────────────────────────────────────────────
816
#[cfg(test)]
mod tests {
    //! Unit tests for the wire protocol: packet-kind discriminants, header
    //! encode/decode, varints, connection tokens, replay filtering, bandwidth
    //! accounting, position deltas and compression hints.

    use super::*;

    /// Builds a test packet from a kind, a sequence number and a payload.
    ///
    /// The two numeric arguments passed to `Packet::new` after `seq` are
    /// always `0` here. NOTE(review): their meaning (ack / flags?) is defined
    /// by `Packet::new`, which is outside this view — confirm there.
    fn make_packet(kind: PacketKind, seq: u32, payload: Vec<u8>) -> Packet {
        Packet::new(kind, seq, 0, 0, payload)
    }

    // ── PacketKind round-trip ─────────────────────────────────────────────────

    #[test]
    fn test_packet_kind_roundtrip_well_known() {
        // Every non-`Custom` variant must survive `to_u16` -> `from_u16` unchanged.
        let well_known = [
            PacketKind::Connect,
            PacketKind::Disconnect,
            PacketKind::Heartbeat,
            PacketKind::StateUpdate,
            PacketKind::InputEvent,
            PacketKind::ChatMessage,
            PacketKind::SpawnEntity,
            PacketKind::DespawnEntity,
            PacketKind::UpdateTransform,
            PacketKind::AnimationState,
            PacketKind::SoundEvent,
            PacketKind::ParticleEvent,
            PacketKind::ForceFieldSync,
            PacketKind::CameraUpdate,
            PacketKind::ScriptCall,
            PacketKind::ScriptResult,
            PacketKind::Ack,
            PacketKind::Nack,
            PacketKind::Ping,
            PacketKind::Pong,
            PacketKind::LobbyJoin,
            PacketKind::LobbyLeave,
            PacketKind::GameStart,
            PacketKind::GameEnd,
            PacketKind::VoteKick,
            PacketKind::VoteBan,
            PacketKind::FileRequest,
            PacketKind::FileChunk,
            PacketKind::Error,
        ];
        for &k in well_known.iter() {
            let wire = k.to_u16();
            let decoded = PacketKind::from_u16(wire).unwrap();
            assert_eq!(decoded, k, "round-trip for {k:?}");
        }
    }

    #[test]
    fn test_packet_kind_custom_roundtrip() {
        // Application-defined kinds must decode back to the same `Custom` value.
        let custom = PacketKind::Custom(42);
        let wire = custom.to_u16();
        assert_eq!(PacketKind::from_u16(wire).unwrap(), custom);
    }

    #[test]
    fn test_packet_kind_unknown_returns_err() {
        // 0x1D is expected to be an unassigned discriminant and must be rejected.
        let outcome = PacketKind::from_u16(0x1D);
        assert!(outcome.is_err());
    }

    // ── Encoder / Decoder round-trip ─────────────────────────────────────────

    #[test]
    fn test_encode_decode_roundtrip() {
        let encoder = PacketEncoder::new();
        let decoder = PacketDecoder::new();
        let original = make_packet(PacketKind::ChatMessage, 7, b"hello world".to_vec());

        let wire = encoder.encode(&original).unwrap();
        let (decoded, consumed) = decoder.decode(&wire).unwrap();

        // A single packet must account for every encoded byte.
        assert_eq!(consumed, wire.len());
        assert_eq!(decoded.kind, original.kind);
        assert_eq!(decoded.sequence, original.sequence);
        assert_eq!(decoded.payload, original.payload);
    }

    #[test]
    fn test_decode_too_short_returns_err() {
        let decoder = PacketDecoder::new();
        // Five bytes cannot hold a full header.
        let truncated = [0u8; 5];
        assert!(matches!(
            decoder.decode(&truncated),
            Err(ProtocolError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_decode_bad_magic() {
        let decoder = PacketDecoder::new();
        let mut header = vec![0u8; PacketHeader::SIZE];

        // All-zero header: the magic bytes are wrong, so that check fails first.
        assert!(matches!(decoder.decode(&header), Err(ProtocolError::BadMagic(_))));

        // Correct magic, but an unsupported protocol version in byte 4.
        header[..4].copy_from_slice(&MAGIC);
        header[4] = 99;
        assert!(matches!(
            decoder.decode(&header),
            Err(ProtocolError::VersionMismatch { .. })
        ));
    }

    #[test]
    fn test_encode_batch_and_decode_all() {
        let encoder = PacketEncoder::new();
        let decoder = PacketDecoder::new();
        // Payloads of 8, 16 and 0 bytes exercise differing lengths in one buffer.
        let batch = vec![
            make_packet(PacketKind::Ping, 1, (1..=8).collect()),
            make_packet(PacketKind::Pong, 2, (1..=16).collect()),
            make_packet(PacketKind::Heartbeat, 3, Vec::new()),
        ];

        let wire = encoder.encode_batch(&batch).unwrap();
        let decoded = decoder.decode_all(&wire).unwrap();

        assert_eq!(decoded.len(), batch.len());
        for (expected, actual) in batch.iter().zip(&decoded) {
            assert_eq!(expected.kind, actual.kind);
            assert_eq!(expected.sequence, actual.sequence);
            assert_eq!(expected.payload, actual.payload);
        }
    }

    // ── Varint ────────────────────────────────────────────────────────────────

    #[test]
    fn test_varint_roundtrip() {
        // Samples straddle 127/128 and 16383/16384 — the group boundaries of a
        // 7-bit-per-byte varint (assumed LEB128-style; confirm in encode_varint).
        let samples: [u64; 10] = [
            0,
            1,
            127,
            128,
            255,
            300,
            16383,
            16384,
            u32::MAX as u64,
            u64::MAX / 2,
        ];
        for v in samples {
            let mut encoded = Vec::new();
            PacketEncoder::encode_varint(v, &mut encoded);
            let (decoded, _) = PacketDecoder::decode_varint(&encoded, 0).unwrap();
            assert_eq!(decoded, v, "varint roundtrip for {v}");
        }
    }

    // ── ConnectionToken ───────────────────────────────────────────────────────

    #[test]
    fn test_connection_token_roundtrip() {
        let token = ConnectionToken::new(0xDEADBEEF_CAFEBABE, [0xAB; 16], 9999999999);
        let serialized = token.to_bytes();
        let restored = ConnectionToken::from_bytes(&serialized).unwrap();
        assert_eq!(token, restored);
    }

    #[test]
    fn test_connection_token_validity() {
        // Expiry is exclusive: valid strictly before 1000, invalid at and after it.
        let token = ConnectionToken::new(1, [0u8; 16], 1000);
        assert!(token.is_valid(999));
        for now in [1000, 1001] {
            assert!(!token.is_valid(now));
        }
    }

    // ── PacketFilter ──────────────────────────────────────────────────────────

    #[test]
    fn test_packet_filter_accepts_new_sequences() {
        let mut filter = PacketFilter::new();
        let peer: u64 = 1;
        for seq in 0..10u32 {
            let pkt = make_packet(PacketKind::StateUpdate, seq, Vec::new());
            assert!(filter.check(peer, &pkt).is_ok(), "seq {seq} should be accepted");
        }
    }

    #[test]
    fn test_packet_filter_rejects_replay() {
        let mut filter = PacketFilter::new();
        let peer: u64 = 42;
        let pkt = make_packet(PacketKind::StateUpdate, 5, Vec::new());

        assert!(filter.check(peer, &pkt).is_ok());

        // Presenting the same sequence number a second time is a replay.
        let second = filter.check(peer, &pkt);
        assert!(matches!(
            second,
            Err(ProtocolError::ReplayDetected { sequence: 5 })
        ));
    }

    // ── BandwidthTracker ──────────────────────────────────────────────────────

    #[test]
    fn test_bandwidth_tracker_basic() {
        // NOTE(review): the constructor argument looks like a window in ms — confirm.
        let mut tracker = BandwidthTracker::new(1000);
        for (bytes, at) in [(500, 0), (500, 500)] {
            tracker.record_send(bytes, at);
        }
        tracker.record_recv(1024, 0);

        assert_eq!(tracker.total_bytes_sent(), 1000);
        assert_eq!(tracker.total_bytes_recv(), 1024);

        // Both sends still fall inside the window at t = 999.
        let up = tracker.bytes_per_sec_up(999);
        assert!(up > 0.0);
    }

    #[test]
    fn test_bandwidth_tracker_evicts_old() {
        let mut tracker = BandwidthTracker::new(1000);
        tracker.record_send(9999, 0);
        // A send far past the window should push the large early sample out.
        tracker.record_send(1, 2001);

        let up = tracker.bytes_per_sec_up(2001);
        // Only the single recent byte may remain in the rate estimate.
        assert!(up < 5.0, "old data should be evicted, up={up}");
    }

    // ── Position delta ────────────────────────────────────────────────────────

    #[test]
    fn test_position_delta_encoding() {
        let start = 10.0f32;
        let to = 10.5f32;

        let delta = PacketEncoder::encode_position_delta(start, to);
        let recovered = start + PacketDecoder::decode_position_delta(delta);

        // The delta is lossy; a 0.01 tolerance is accepted.
        assert!((recovered - to).abs() < 0.01, "recovered={recovered}, expected={to}");
    }

    // ── CompressionHint ───────────────────────────────────────────────────────

    #[test]
    fn test_compression_hint_roundtrip() {
        let hints = [CompressionHint::None, CompressionHint::Zlib, CompressionHint::Lz4];
        for hint in hints {
            let byte = hint.to_u8();
            assert_eq!(CompressionHint::from_u8(byte).unwrap(), hint);
        }
        // An out-of-range tag byte must be rejected rather than mapped silently.
        assert!(CompressionHint::from_u8(99).is_err());
    }
}