// proof_engine/netcode/transport.rs
//! Network transport layer: reliable/unreliable channels, fragmentation,
//! bandwidth throttling, and connection state machine.

use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
/// Configuration for the transport layer.
///
/// All `_ms` fields are durations in milliseconds; sizes and bandwidth are
/// in bytes unless the name says otherwise.
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Largest packet (header + payload) handed to the socket.
    pub max_packet_size: usize,
    /// Payload size above which data is split into fragments.
    pub fragment_size: usize,
    /// Upper bound on fragments a single payload may be split into.
    pub max_fragments_per_packet: usize,
    /// Maximum number of unacked reliable packets in flight.
    pub reliable_window_size: usize,
    /// Retransmit attempts before a reliable packet is declared lost.
    pub max_retransmits: u32,
    /// Base retransmission timeout (lower bound for the computed RTO).
    pub retransmit_timeout_ms: u64,
    /// Idle time after which the connection is considered dead.
    pub connection_timeout_ms: u64,
    /// Interval between keepalive packets on an otherwise idle link.
    pub keepalive_interval_ms: u64,
    /// Token-bucket bandwidth cap for outgoing data.
    pub max_bandwidth_bytes_per_sec: usize,
    // NOTE(review): not referenced in this file's visible code — presumably
    // how many packets each ack is redundantly repeated in; confirm at call
    // sites.
    pub ack_redundancy: usize,
}

impl Default for TransportConfig {
    fn default() -> Self {
        Self {
            max_packet_size: 1200,
            fragment_size: 1024,
            max_fragments_per_packet: 256,
            reliable_window_size: 256,
            max_retransmits: 10,
            retransmit_timeout_ms: 200,
            connection_timeout_ms: 10000,
            keepalive_interval_ms: 1000,
            max_bandwidth_bytes_per_sec: 65536,
            ack_redundancy: 3,
        }
    }
}

/// Transport-layer statistics.
///
/// Plain counters and gauges. `packet_loss_ratio` is only refreshed when
/// `update_loss_ratio` is called.
#[derive(Debug, Clone, Default)]
pub struct TransportStats {
    pub packets_sent: u64,
    pub packets_received: u64,
    pub packets_lost: u64,
    pub packets_acked: u64,
    pub bytes_sent: u64,
    pub bytes_received: u64,
    pub retransmissions: u64,
    pub rtt_ms: f64,
    pub rtt_variance_ms: f64,
    pub packet_loss_ratio: f64,
    pub bandwidth_used_bytes_per_sec: f64,
    pub fragments_sent: u64,
    pub fragments_reassembled: u64,
}

impl TransportStats {
    /// Recompute `packet_loss_ratio` as lost / sent. Left untouched while
    /// nothing has been sent yet.
    pub fn update_loss_ratio(&mut self) {
        if self.packets_sent > 0 {
            self.packet_loss_ratio = self.packets_lost as f64 / self.packets_sent as f64;
        }
    }

    /// Zero every counter and gauge.
    pub fn reset(&mut self) {
        *self = TransportStats::default();
    }
}

/// Types of packets in the protocol.
///
/// The wire value of each variant is its declaration index (0..=8); see
/// `to_u8` / `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    ConnectionRequest,
    ConnectionAccept,
    ConnectionDeny,
    Disconnect,
    Keepalive,
    Reliable,
    Unreliable,
    Fragment,
    Ack,
}

impl PacketType {
    /// Wire value of this packet type.
    pub fn to_u8(self) -> u8 {
        // Discriminants are assigned in declaration order and match the
        // protocol's wire values exactly; do not reorder the variants.
        self as u8
    }

    /// Parse a wire value; `None` for anything outside 0..=8.
    pub fn from_u8(v: u8) -> Option<Self> {
        let ty = match v {
            0 => PacketType::ConnectionRequest,
            1 => PacketType::ConnectionAccept,
            2 => PacketType::ConnectionDeny,
            3 => PacketType::Disconnect,
            4 => PacketType::Keepalive,
            5 => PacketType::Reliable,
            6 => PacketType::Unreliable,
            7 => PacketType::Fragment,
            8 => PacketType::Ack,
            _ => return None,
        };
        Some(ty)
    }
}

/// Header prepended to every packet.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Magic protocol identifier; `0x50524F46` ("PROF") for this protocol.
    pub protocol_id: u32,
    /// Kind of packet carried.
    pub packet_type: PacketType,
    /// Sender-assigned sequence number (wraps at the u16 boundary).
    pub sequence: u16,
    /// Newest remote sequence the sender has received (piggybacked ack).
    pub ack: u16,
    /// History bitfield of sequences received before `ack`.
    pub ack_bits: u32,
    /// Send timestamp in milliseconds. `new` leaves it 0 and nothing in this
    /// file's visible code populates it.
    pub timestamp_ms: u64,
    /// Length in bytes of the payload that follows the header.
    pub payload_size: u16,
}

127impl PacketHeader {
128    pub const SERIALIZED_SIZE: usize = 4 + 1 + 2 + 2 + 4 + 8 + 2;
129
130    pub fn new(packet_type: PacketType, sequence: u16) -> Self {
131        Self {
132            protocol_id: 0x50524F46, // "PROF"
133            packet_type,
134            sequence,
135            ack: 0,
136            ack_bits: 0,
137            timestamp_ms: 0,
138            payload_size: 0,
139        }
140    }
141
142    pub fn serialize(&self) -> Vec<u8> {
143        let mut buf = Vec::with_capacity(Self::SERIALIZED_SIZE);
144        buf.extend_from_slice(&self.protocol_id.to_le_bytes());
145        buf.push(self.packet_type.to_u8());
146        buf.extend_from_slice(&self.sequence.to_le_bytes());
147        buf.extend_from_slice(&self.ack.to_le_bytes());
148        buf.extend_from_slice(&self.ack_bits.to_le_bytes());
149        buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
150        buf.extend_from_slice(&self.payload_size.to_le_bytes());
151        buf
152    }
153
154    pub fn deserialize(data: &[u8]) -> Option<Self> {
155        if data.len() < Self::SERIALIZED_SIZE {
156            return None;
157        }
158        let protocol_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
159        let packet_type = PacketType::from_u8(data[4])?;
160        let sequence = u16::from_le_bytes([data[5], data[6]]);
161        let ack = u16::from_le_bytes([data[7], data[8]]);
162        let ack_bits = u32::from_le_bytes([data[9], data[10], data[11], data[12]]);
163        let timestamp_ms = u64::from_le_bytes([
164            data[13], data[14], data[15], data[16],
165            data[17], data[18], data[19], data[20],
166        ]);
167        let payload_size = u16::from_le_bytes([data[21], data[22]]);
168
169        Some(Self {
170            protocol_id,
171            packet_type,
172            sequence,
173            ack,
174            ack_bits,
175            timestamp_ms,
176            payload_size,
177        })
178    }
179
180    pub fn validate_protocol(&self) -> bool {
181        self.protocol_id == 0x50524F46
182    }
183}
184
/// Sequence number math: is `a` "newer" than `b` in wrapping u16 space?
///
/// Half the sequence space (32768) is the newer/older cutoff, so the
/// comparison is wrap-safe.
fn sequence_greater_than(a: u16, b: u16) -> bool {
    if a == b {
        return false;
    }
    // Forward distance from b to a, modulo 65536.
    let forward = a.wrapping_sub(b);
    if a > b {
        forward <= 32768
    } else {
        forward < 32768
    }
}

/// Signed shortest-path difference `a - b` in wrapping u16 sequence space.
///
/// Positive when `a` is newer than `b`, negative when older, 0 when equal.
/// The result lies in `[-32768, 32768]`; a distance of exactly 32768 takes
/// the sign implied by the raw ordering of `a` and `b`.
fn sequence_difference(a: u16, b: u16) -> i32 {
    let forward = a.wrapping_sub(b) as i32; // (a - b) mod 65536
    if forward == 0 {
        0
    } else if forward < 32768 || (forward == 32768 && a > b) {
        forward
    } else {
        forward - 65536
    }
}

/// An entry in the reliable send window.
#[derive(Debug, Clone)]
struct ReliableEntry {
    /// Sequence number assigned when the payload was queued.
    sequence: u16,
    /// Payload bytes, retained for retransmission.
    data: Vec<u8>,
    /// Time of the most recent (re)transmission; used for RTT sampling.
    send_time: Instant,
    /// How many times this entry has been retransmitted so far.
    retransmit_count: u32,
    /// Set once the remote side acknowledged this sequence.
    acked: bool,
    /// Earliest time the next retransmission may be attempted.
    next_retransmit: Instant,
    /// Owning channel's id. Stored but never read in this file's visible
    /// code — presumably kept for debugging/multiplexing.
    channel_id: u8,
}

/// Reliable delivery channel with retransmission and RTT estimation.
pub struct ReliableChannel {
    /// Transport tuning parameters.
    config: TransportConfig,
    /// Next sequence number to assign to an outgoing payload.
    local_sequence: u16,
    /// Highest sequence number seen from the remote side.
    remote_sequence: u16,
    /// Sent packets awaiting acknowledgement (pruned from the front).
    send_window: VecDeque<ReliableEntry>,
    /// Out-of-order payloads awaiting in-order delivery, keyed by sequence.
    receive_buffer: HashMap<u16, Vec<u8>>,
    /// Next sequence to hand to the application (ordered-delivery cursor).
    next_deliver_sequence: u16,
    /// Sequences received but not yet reported back (drained externally).
    pending_acks: Vec<u16>,
    /// Most recent raw RTT sample, in milliseconds.
    rtt_estimate_ms: f64,
    /// Mean deviation of the RTT, in milliseconds.
    rtt_variance_ms: f64,
    /// Exponentially smoothed RTT, in milliseconds.
    smoothed_rtt_ms: f64,
    /// Counters and gauges for this channel.
    stats: TransportStats,
    /// Identifier of this channel.
    channel_id: u8,
    /// History bitfield of recently received remote sequences (bit k set
    /// means `remote_sequence - (k + 1)` was received).
    ack_bitfield: u32,
    /// Not read in this file's visible code — possibly reserved for ack
    /// bookkeeping.
    last_ack_sequence: u16,
    /// Payloads queued by `send` and packetized by `flush`.
    send_queue: VecDeque<Vec<u8>>,
}

231impl ReliableChannel {
232    pub fn new(channel_id: u8, config: TransportConfig) -> Self {
233        Self {
234            config,
235            local_sequence: 0,
236            remote_sequence: 0,
237            send_window: VecDeque::new(),
238            receive_buffer: HashMap::new(),
239            next_deliver_sequence: 0,
240            pending_acks: Vec::new(),
241            rtt_estimate_ms: 100.0,
242            rtt_variance_ms: 50.0,
243            smoothed_rtt_ms: 100.0,
244            stats: TransportStats::default(),
245            channel_id,
246            ack_bitfield: 0,
247            last_ack_sequence: 0,
248            send_queue: VecDeque::new(),
249        }
250    }
251
252    pub fn channel_id(&self) -> u8 {
253        self.channel_id
254    }
255
256    pub fn stats(&self) -> &TransportStats {
257        &self.stats
258    }
259
260    pub fn rtt_ms(&self) -> f64 {
261        self.smoothed_rtt_ms
262    }
263
264    pub fn local_sequence(&self) -> u16 {
265        self.local_sequence
266    }
267
268    pub fn remote_sequence(&self) -> u16 {
269        self.remote_sequence
270    }
271
272    /// Queue data for reliable sending. Returns the assigned sequence number.
273    pub fn send(&mut self, data: Vec<u8>) -> u16 {
274        let seq = self.local_sequence;
275        self.local_sequence = self.local_sequence.wrapping_add(1);
276        self.send_queue.push_back(data);
277        seq
278    }
279
280    /// Flush the send queue, producing packets ready to send.
281    pub fn flush(&mut self, now: Instant) -> Vec<(PacketHeader, Vec<u8>)> {
282        let mut packets = Vec::new();
283
284        while let Some(data) = self.send_queue.pop_front() {
285            let seq = self.local_sequence.wrapping_sub(self.send_queue.len() as u16 + 1);
286            let rto = self.compute_rto();
287
288            let entry = ReliableEntry {
289                sequence: seq,
290                data: data.clone(),
291                send_time: now,
292                retransmit_count: 0,
293                acked: false,
294                next_retransmit: now + rto,
295                channel_id: self.channel_id,
296            };
297
298            let mut header = PacketHeader::new(PacketType::Reliable, seq);
299            header.ack = self.remote_sequence;
300            header.ack_bits = self.ack_bitfield;
301            header.payload_size = data.len() as u16;
302
303            self.send_window.push_back(entry);
304            self.stats.packets_sent += 1;
305            self.stats.bytes_sent += (PacketHeader::SERIALIZED_SIZE + data.len()) as u64;
306
307            packets.push((header, data));
308        }
309
310        // Check for retransmissions
311        let rto = self.compute_rto();
312        let max_retransmits = self.config.max_retransmits;
313        let remote_seq = self.remote_sequence;
314        let ack_bits = self.ack_bitfield;
315        let mut retransmissions = 0u64;
316        let mut pkts_sent = 0u64;
317        let mut bytes_sent = 0u64;
318        let mut pkts_lost = 0u64;
319        for entry in self.send_window.iter_mut() {
320            if !entry.acked && now >= entry.next_retransmit {
321                if entry.retransmit_count < max_retransmits {
322                    entry.retransmit_count += 1;
323                    // Exponential backoff
324                    let backoff = rto * (1 << entry.retransmit_count.min(5));
325                    entry.next_retransmit = now + backoff;
326                    entry.send_time = now;
327
328                    let mut header = PacketHeader::new(PacketType::Reliable, entry.sequence);
329                    header.ack = remote_seq;
330                    header.ack_bits = ack_bits;
331                    header.payload_size = entry.data.len() as u16;
332
333                    retransmissions += 1;
334                    pkts_sent += 1;
335                    bytes_sent += (PacketHeader::SERIALIZED_SIZE + entry.data.len()) as u64;
336
337                    packets.push((header, entry.data.clone()));
338                } else {
339                    pkts_lost += 1;
340                }
341            }
342        }
343        self.stats.retransmissions += retransmissions;
344        self.stats.packets_sent += pkts_sent;
345        self.stats.bytes_sent += bytes_sent;
346        self.stats.packets_lost += pkts_lost;
347
348        // Prune old acked entries from the window
349        while let Some(front) = self.send_window.front() {
350            if front.acked || front.retransmit_count >= self.config.max_retransmits {
351                self.send_window.pop_front();
352            } else {
353                break;
354            }
355        }
356
357        packets
358    }
359
360    fn compute_rto(&self) -> Duration {
361        // Jacobson/Karels algorithm: RTO = SRTT + max(G, 4*RTTVAR)
362        let rto_ms = self.smoothed_rtt_ms + 4.0 * self.rtt_variance_ms;
363        let rto_ms = rto_ms.max(self.config.retransmit_timeout_ms as f64);
364        Duration::from_millis(rto_ms as u64)
365    }
366
367    /// Process an incoming reliable packet.
368    pub fn receive(&mut self, header: &PacketHeader, payload: Vec<u8>, now: Instant) {
369        self.stats.packets_received += 1;
370        self.stats.bytes_received += (PacketHeader::SERIALIZED_SIZE + payload.len()) as u64;
371
372        let seq = header.sequence;
373
374        // Update remote sequence tracking
375        if sequence_greater_than(seq, self.remote_sequence) {
376            let diff = sequence_difference(seq, self.remote_sequence);
377            // Shift ack bitfield
378            if diff > 0 && diff <= 32 {
379                self.ack_bitfield <<= diff as u32;
380                self.ack_bitfield |= 1 << (diff as u32 - 1);
381            } else if diff > 32 {
382                self.ack_bitfield = 0;
383            }
384            self.remote_sequence = seq;
385        } else {
386            let diff = sequence_difference(self.remote_sequence, seq);
387            if diff > 0 && diff <= 32 {
388                self.ack_bitfield |= 1 << (diff as u32 - 1);
389            }
390        }
391
392        // Store in receive buffer for ordered delivery
393        self.receive_buffer.insert(seq, payload);
394        self.pending_acks.push(seq);
395
396        // Process acks from the remote side
397        self.process_acks(header.ack, header.ack_bits, now);
398    }
399
400    fn process_acks(&mut self, ack: u16, ack_bits: u32, now: Instant) {
401        let mut acked_count = 0u64;
402        let mut rtt_samples = Vec::new();
403        for entry in self.send_window.iter_mut() {
404            if entry.acked {
405                continue;
406            }
407            let seq = entry.sequence;
408            let is_acked = if seq == ack {
409                true
410            } else {
411                let diff = sequence_difference(ack, seq);
412                diff > 0 && diff <= 32 && (ack_bits & (1 << (diff - 1))) != 0
413            };
414
415            if is_acked {
416                entry.acked = true;
417                acked_count += 1;
418
419                if entry.retransmit_count == 0 {
420                    let rtt = now.duration_since(entry.send_time).as_secs_f64() * 1000.0;
421                    rtt_samples.push(rtt);
422                }
423            }
424        }
425        self.stats.packets_acked += acked_count;
426        for rtt in rtt_samples {
427            self.update_rtt(rtt);
428        }
429    }
430
431    fn update_rtt(&mut self, sample_ms: f64) {
432        // Jacobson/Karels RTT estimation
433        let alpha = 0.125;
434        let beta = 0.25;
435
436        let err = sample_ms - self.smoothed_rtt_ms;
437        self.smoothed_rtt_ms += alpha * err;
438        self.rtt_variance_ms += beta * (err.abs() - self.rtt_variance_ms);
439        self.rtt_estimate_ms = sample_ms;
440    }
441
442    /// Drain delivered messages in order.
443    pub fn drain_received(&mut self) -> Vec<Vec<u8>> {
444        let mut messages = Vec::new();
445        loop {
446            if let Some(data) = self.receive_buffer.remove(&self.next_deliver_sequence) {
447                messages.push(data);
448                self.next_deliver_sequence = self.next_deliver_sequence.wrapping_add(1);
449            } else {
450                break;
451            }
452        }
453        messages
454    }
455
456    /// Get pending acks to piggyback on outgoing packets.
457    pub fn drain_pending_acks(&mut self) -> Vec<u16> {
458        std::mem::take(&mut self.pending_acks)
459    }
460
461    /// Number of unacked packets in the send window.
462    pub fn in_flight(&self) -> usize {
463        self.send_window.iter().filter(|e| !e.acked).count()
464    }
465
466    /// Whether the send window is full.
467    pub fn is_congested(&self) -> bool {
468        self.in_flight() >= self.config.reliable_window_size
469    }
470
471    /// Reset the channel state.
472    pub fn reset(&mut self) {
473        self.local_sequence = 0;
474        self.remote_sequence = 0;
475        self.next_deliver_sequence = 0;
476        self.send_window.clear();
477        self.receive_buffer.clear();
478        self.pending_acks.clear();
479        self.send_queue.clear();
480        self.smoothed_rtt_ms = 100.0;
481        self.rtt_variance_ms = 50.0;
482        self.stats.reset();
483    }
484}
485
/// Unreliable channel: fire-and-forget with sequence numbers for ordering.
pub struct UnreliableChannel {
    /// Next sequence number to assign to an outgoing packet.
    local_sequence: u16,
    /// Highest sequence number seen from the remote side.
    remote_sequence: u16,
    /// Counters and gauges for this channel.
    stats: TransportStats,
    /// Received payloads awaiting `drain_received`, in arrival order.
    received_buffer: VecDeque<Vec<u8>>,
    /// Cap on `received_buffer`; oldest entries are evicted past this.
    max_buffer_size: usize,
    /// History bitfield of recently received remote sequences.
    ack_bitfield: u32,
    /// When true, packets older than the newest seen are discarded.
    drop_out_of_order: bool,
}

497impl UnreliableChannel {
498    pub fn new() -> Self {
499        Self {
500            local_sequence: 0,
501            remote_sequence: 0,
502            stats: TransportStats::default(),
503            received_buffer: VecDeque::new(),
504            max_buffer_size: 256,
505            ack_bitfield: 0,
506            drop_out_of_order: false,
507        }
508    }
509
510    pub fn with_max_buffer(mut self, size: usize) -> Self {
511        self.max_buffer_size = size;
512        self
513    }
514
515    pub fn set_drop_out_of_order(&mut self, drop: bool) {
516        self.drop_out_of_order = drop;
517    }
518
519    pub fn stats(&self) -> &TransportStats {
520        &self.stats
521    }
522
523    pub fn local_sequence(&self) -> u16 {
524        self.local_sequence
525    }
526
527    /// Prepare a packet for unreliable sending.
528    pub fn send(&mut self, data: Vec<u8>) -> (PacketHeader, Vec<u8>) {
529        let seq = self.local_sequence;
530        self.local_sequence = self.local_sequence.wrapping_add(1);
531
532        let mut header = PacketHeader::new(PacketType::Unreliable, seq);
533        header.ack = self.remote_sequence;
534        header.ack_bits = self.ack_bitfield;
535        header.payload_size = data.len() as u16;
536
537        self.stats.packets_sent += 1;
538        self.stats.bytes_sent += (PacketHeader::SERIALIZED_SIZE + data.len()) as u64;
539
540        (header, data)
541    }
542
543    /// Process an incoming unreliable packet.
544    pub fn receive(&mut self, header: &PacketHeader, payload: Vec<u8>) {
545        self.stats.packets_received += 1;
546        self.stats.bytes_received += (PacketHeader::SERIALIZED_SIZE + payload.len()) as u64;
547
548        let seq = header.sequence;
549
550        // Check if this is newer than what we have
551        if self.drop_out_of_order && !sequence_greater_than(seq, self.remote_sequence) && seq != self.remote_sequence {
552            // Drop out-of-order packet
553            return;
554        }
555
556        // Update remote sequence tracking
557        if sequence_greater_than(seq, self.remote_sequence) {
558            let diff = sequence_difference(seq, self.remote_sequence);
559            if diff > 0 && diff <= 32 {
560                self.ack_bitfield <<= diff as u32;
561                self.ack_bitfield |= 1 << (diff as u32 - 1);
562            } else if diff > 32 {
563                self.ack_bitfield = 0;
564            }
565            self.remote_sequence = seq;
566        } else {
567            let diff = sequence_difference(self.remote_sequence, seq);
568            if diff > 0 && diff <= 32 {
569                self.ack_bitfield |= 1 << (diff as u32 - 1);
570            }
571        }
572
573        // Buffer the payload
574        self.received_buffer.push_back(payload);
575        while self.received_buffer.len() > self.max_buffer_size {
576            self.received_buffer.pop_front();
577        }
578    }
579
580    /// Drain all received messages.
581    pub fn drain_received(&mut self) -> Vec<Vec<u8>> {
582        self.received_buffer.drain(..).collect()
583    }
584
585    pub fn reset(&mut self) {
586        self.local_sequence = 0;
587        self.remote_sequence = 0;
588        self.received_buffer.clear();
589        self.stats.reset();
590    }
591}
592
/// Header for a fragment of a larger packet.
#[derive(Debug, Clone)]
pub struct FragmentHeader {
    /// Id shared by all fragments of one original payload.
    pub group_id: u16,
    /// Zero-based position of this fragment within the group.
    pub fragment_index: u8,
    /// Total number of fragments in the group.
    pub total_fragments: u8,
    /// Size in bytes of this fragment's payload.
    pub fragment_size: u16,
}

impl FragmentHeader {
    /// Bytes in a serialized fragment header:
    /// group_id(2) + index(1) + total(1) + size(2).
    pub const SERIALIZED_SIZE: usize = 2 + 1 + 1 + 2;

    /// Encode as `SERIALIZED_SIZE` little-endian bytes.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(Self::SERIALIZED_SIZE);
        out.extend_from_slice(&self.group_id.to_le_bytes());
        out.push(self.fragment_index);
        out.push(self.total_fragments);
        out.extend_from_slice(&self.fragment_size.to_le_bytes());
        out
    }

    /// Decode from the first `SERIALIZED_SIZE` bytes; `None` if too short.
    pub fn deserialize(data: &[u8]) -> Option<Self> {
        if data.len() < Self::SERIALIZED_SIZE {
            return None;
        }
        let group_id = u16::from_le_bytes([data[0], data[1]]);
        let fragment_index = data[2];
        let total_fragments = data[3];
        let fragment_size = u16::from_le_bytes([data[4], data[5]]);
        Some(Self { group_id, fragment_index, total_fragments, fragment_size })
    }
}

/// Reassembly state for a group of fragments.
struct ReassemblyGroup {
    /// Group id (the buffer also keys groups by this; kept for diagnostics).
    group_id: u16,
    /// Number of fragments expected for the complete payload.
    total_fragments: u8,
    /// Per-index slots; `Some` once that fragment has arrived.
    ///
    /// This replaces the previous 64-bit received bitmask: `1u64 << index`
    /// overflowed for index >= 64 even though the config allows up to 256
    /// fragments per packet, and the capped mask made `is_complete` report
    /// completion incorrectly for large groups.
    fragments: Vec<Option<Vec<u8>>>,
    /// When the group was created; used for timeout-based cleanup.
    creation_time: Instant,
    /// Running sum of received fragment sizes (preallocation for assembly).
    total_size: usize,
}

impl ReassemblyGroup {
    /// Create an empty group expecting `total_fragments` pieces.
    fn new(group_id: u16, total_fragments: u8, now: Instant) -> Self {
        Self {
            group_id,
            total_fragments,
            fragments: vec![None; total_fragments as usize],
            creation_time: now,
            total_size: 0,
        }
    }

    /// Store one fragment. Out-of-range indices and duplicates are ignored
    /// (returning `false`); otherwise returns whether the group is now
    /// complete.
    fn insert(&mut self, index: u8, data: Vec<u8>) -> bool {
        let slot = index as usize;
        if slot >= self.fragments.len() {
            return false;
        }
        if self.fragments[slot].is_some() {
            return false; // duplicate
        }
        self.total_size += data.len();
        self.fragments[slot] = Some(data);
        self.is_complete()
    }

    /// True once every expected fragment slot is filled.
    fn is_complete(&self) -> bool {
        self.fragments.iter().all(|f| f.is_some())
    }

    /// Concatenate fragments in index order; `None` until complete.
    fn assemble(&self) -> Option<Vec<u8>> {
        if !self.is_complete() {
            return None;
        }
        let mut result = Vec::with_capacity(self.total_size);
        for frag in self.fragments.iter().flatten() {
            result.extend_from_slice(frag);
        }
        Some(result)
    }

    /// Number of fragments received so far.
    fn received_count(&self) -> u8 {
        self.fragments.iter().filter(|f| f.is_some()).count() as u8
    }

    /// How long the group has existed.
    fn age(&self, now: Instant) -> Duration {
        now.duration_since(self.creation_time)
    }
}

700/// Buffer managing reassembly of fragmented packets.
701pub struct ReassemblyBuffer {
702    groups: HashMap<u16, ReassemblyGroup>,
703    timeout: Duration,
704    max_groups: usize,
705}
706
707impl ReassemblyBuffer {
708    pub fn new(timeout_ms: u64, max_groups: usize) -> Self {
709        Self {
710            groups: HashMap::new(),
711            timeout: Duration::from_millis(timeout_ms),
712            max_groups,
713        }
714    }
715
716    pub fn insert(&mut self, header: &FragmentHeader, data: Vec<u8>, now: Instant) -> Option<Vec<u8>> {
717        // Clean up expired groups
718        self.cleanup(now);
719
720        let group = self.groups
721            .entry(header.group_id)
722            .or_insert_with(|| ReassemblyGroup::new(header.group_id, header.total_fragments, now));
723
724        if group.insert(header.fragment_index, data) {
725            let assembled = group.assemble();
726            self.groups.remove(&header.group_id);
727            assembled
728        } else {
729            None
730        }
731    }
732
733    fn cleanup(&mut self, now: Instant) {
734        self.groups.retain(|_, group| group.age(now) < self.timeout);
735
736        // If still over capacity, remove oldest
737        while self.groups.len() > self.max_groups {
738            let oldest = self.groups.iter()
739                .min_by_key(|(_, g)| g.creation_time)
740                .map(|(&id, _)| id);
741            if let Some(id) = oldest {
742                self.groups.remove(&id);
743            } else {
744                break;
745            }
746        }
747    }
748
749    pub fn pending_groups(&self) -> usize {
750        self.groups.len()
751    }
752
753    pub fn clear(&mut self) {
754        self.groups.clear();
755    }
756}
757
758/// Splits large payloads into fragments for transmission.
759pub struct PacketFragmenter {
760    fragment_size: usize,
761    max_fragments: usize,
762    next_group_id: u16,
763}
764
765impl PacketFragmenter {
766    pub fn new(fragment_size: usize, max_fragments: usize) -> Self {
767        Self {
768            fragment_size: fragment_size.max(64),
769            max_fragments: max_fragments.max(1),
770            next_group_id: 0,
771        }
772    }
773
774    /// Check if data needs fragmentation.
775    pub fn needs_fragmentation(&self, data_len: usize) -> bool {
776        data_len > self.fragment_size
777    }
778
779    /// Fragment a payload into multiple pieces with headers.
780    pub fn fragment(&mut self, data: &[u8]) -> Vec<(FragmentHeader, Vec<u8>)> {
781        if data.is_empty() {
782            return Vec::new();
783        }
784
785        let total_fragments = ((data.len() + self.fragment_size - 1) / self.fragment_size).min(self.max_fragments);
786        let group_id = self.next_group_id;
787        self.next_group_id = self.next_group_id.wrapping_add(1);
788
789        let mut fragments = Vec::with_capacity(total_fragments);
790        let mut offset = 0;
791
792        for i in 0..total_fragments {
793            let end = (offset + self.fragment_size).min(data.len());
794            let fragment_data = data[offset..end].to_vec();
795
796            let header = FragmentHeader {
797                group_id,
798                fragment_index: i as u8,
799                total_fragments: total_fragments as u8,
800                fragment_size: fragment_data.len() as u16,
801            };
802
803            fragments.push((header, fragment_data));
804            offset = end;
805
806            if offset >= data.len() {
807                break;
808            }
809        }
810
811        fragments
812    }
813
814    pub fn max_payload_size(&self) -> usize {
815        self.fragment_size * self.max_fragments
816    }
817}
818
/// Bandwidth throttle using a token bucket algorithm.
///
/// Tokens are bytes: the bucket refills at `max_bytes_per_sec` and can hold
/// a burst of up to 1.5x one second's budget.
pub struct BandwidthThrottle {
    /// Refill rate, in bytes per second.
    max_bytes_per_sec: f64,
    /// Currently available tokens (bytes); `consume` may drive it negative.
    tokens: f64,
    /// Bucket capacity: 1.5x the per-second budget.
    max_tokens: f64,
    /// Last time `refill` ran.
    last_refill: Instant,
    /// Bytes consumed in the current one-second window (for `utilization`).
    bytes_sent_this_second: usize,
    /// Start of the current one-second window.
    second_start: Instant,
    /// When false, every check passes and the bucket is bypassed.
    enabled: bool,
}

impl BandwidthThrottle {
    /// Create an enabled throttle with a full bucket at the given rate.
    pub fn new(max_bytes_per_sec: usize) -> Self {
        let now = Instant::now();
        let rate = max_bytes_per_sec as f64;
        Self {
            max_bytes_per_sec: rate,
            tokens: rate,
            max_tokens: rate * 1.5,
            last_refill: now,
            bytes_sent_this_second: 0,
            second_start: now,
            enabled: true,
        }
    }

    /// Turn throttling on or off.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether throttling is currently applied.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Change the rate; the burst capacity scales with it.
    pub fn set_max_bytes_per_sec(&mut self, max: usize) {
        let rate = max as f64;
        self.max_bytes_per_sec = rate;
        self.max_tokens = rate * 1.5;
    }

    /// Configured rate, in bytes per second.
    pub fn max_bytes_per_sec(&self) -> usize {
        self.max_bytes_per_sec as usize
    }

    /// Add tokens for the elapsed time (capped at the bucket size) and roll
    /// the per-second accounting window once a second has passed.
    pub fn refill(&mut self, now: Instant) {
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + self.max_bytes_per_sec * elapsed).min(self.max_tokens);
        self.last_refill = now;

        if now.duration_since(self.second_start).as_secs_f64() >= 1.0 {
            self.bytes_sent_this_second = 0;
            self.second_start = now;
        }
    }

    /// Whether `bytes` could be sent right now (always true when disabled).
    pub fn can_send(&self, bytes: usize) -> bool {
        !self.enabled || self.tokens >= bytes as f64
    }

    /// Deduct tokens for a send. Does not check affordability itself, so the
    /// balance may go negative; pair with `can_send` or use `try_send`.
    pub fn consume(&mut self, bytes: usize) {
        self.tokens -= bytes as f64;
        self.bytes_sent_this_second += bytes;
    }

    /// Refill, then consume if affordable. Returns whether the send is
    /// allowed.
    pub fn try_send(&mut self, bytes: usize, now: Instant) -> bool {
        self.refill(now);
        let allowed = self.can_send(bytes);
        if allowed {
            self.consume(bytes);
        }
        allowed
    }

    /// Tokens currently available, or `usize::MAX` when disabled.
    pub fn available_bytes(&self) -> usize {
        if self.enabled {
            self.tokens.max(0.0) as usize
        } else {
            usize::MAX
        }
    }

    /// Fraction of this second's budget already spent.
    pub fn utilization(&self) -> f64 {
        if self.max_bytes_per_sec > 0.0 {
            self.bytes_sent_this_second as f64 / self.max_bytes_per_sec
        } else {
            0.0
        }
    }

    /// Refill the bucket to one second's budget and restart accounting.
    pub fn reset(&mut self) {
        self.tokens = self.max_bytes_per_sec;
        self.bytes_sent_this_second = 0;
        self.last_refill = Instant::now();
        self.second_start = self.last_refill;
    }
}

/// Connection states in the state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection; the initial state (also re-entered on denial).
    Disconnected,
    /// A connection attempt is in progress (entered by `connect`).
    Connecting,
    /// Handshake completed (entered by `on_accepted`).
    Connected,
    /// Graceful shutdown in progress. Not entered in this file's visible
    /// code — presumably set by the disconnect path further down.
    Disconnecting,
}

/// Events emitted by the connection state machine.
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    /// The handshake completed (emitted by `on_accepted`).
    Connected,
    /// The connection ended, with the reason recorded. Not emitted in this
    /// file's visible code.
    Disconnected { reason: DisconnectReason },
    /// An outgoing connection attempt was rejected (emitted by `on_denied`).
    ConnectionFailed { reason: String },
    /// The connection went silent. Not emitted in this file's visible code —
    /// presumably driven by `connection_timeout_ms`.
    TimedOut,
}

/// Reasons for disconnection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisconnectReason {
    /// Clean, intentional shutdown.
    Graceful,
    /// The peer went silent past the connection timeout.
    Timeout,
    /// The remote side forcibly removed us.
    Kicked,
    /// A malformed or unexpected packet broke the protocol.
    ProtocolError,
    /// The remote side had no capacity for a new connection.
    Full,
}

/// Connection state machine managing the lifecycle of a network connection.
pub struct ConnectionStateMachine {
    /// Current lifecycle state.
    state: ConnectionState,
    /// Transport tuning parameters (timeouts, keepalive interval).
    config: TransportConfig,
    /// When the current connect attempt began (`None` when not connecting).
    connect_started: Option<Instant>,
    /// Last time any packet arrived (updated by `on_packet_received`).
    last_packet_received: Option<Instant>,
    /// Last time we sent a packet; updated by the send-marking method that
    /// continues past this view.
    last_packet_sent: Option<Instant>,
    /// Connection request attempts made so far in this connect cycle.
    connect_attempts: u32,
    /// Give up connecting after this many attempts (fixed at 5 by `new`).
    max_connect_attempts: u32,
    /// Delay between connection request retries (fixed at 500 ms by `new`).
    connect_retry_interval: Duration,
    /// Queue of events for the application to consume.
    events: VecDeque<ConnectionEvent>,
    /// Why the last disconnect happened, if any.
    disconnect_reason: Option<DisconnectReason>,
    /// Id generated when connecting (0 before the first attempt).
    session_id: u64,
    /// Flag that a keepalive should be sent. Not updated in this file's
    /// visible code.
    keepalive_due: bool,
}

impl ConnectionStateMachine {
    /// Create a machine in `Disconnected` with a default retry policy of
    /// 5 attempts spaced 500 ms apart.
    pub fn new(config: TransportConfig) -> Self {
        Self {
            state: ConnectionState::Disconnected,
            config,
            connect_started: None,
            last_packet_received: None,
            last_packet_sent: None,
            connect_attempts: 0,
            max_connect_attempts: 5,
            connect_retry_interval: Duration::from_millis(500),
            events: VecDeque::new(),
            disconnect_reason: None,
            session_id: 0,
            keepalive_due: false,
        }
    }

    /// Current lifecycle state.
    pub fn state(&self) -> ConnectionState {
        self.state
    }

    /// True iff the state is `Connected`.
    pub fn is_connected(&self) -> bool {
        self.state == ConnectionState::Connected
    }

    /// Session ID from the most recent `connect()` (0 before any connect).
    pub fn session_id(&self) -> u64 {
        self.session_id
    }

    /// Begin connecting. No-op unless currently `Disconnected`.
    /// Resets the retry counter and derives a fresh session ID from `now`.
    pub fn connect(&mut self, now: Instant) {
        if self.state != ConnectionState::Disconnected {
            return;
        }
        self.state = ConnectionState::Connecting;
        self.connect_started = Some(now);
        self.connect_attempts = 0;
        self.session_id = generate_session_id(now);
    }

    /// Handle acceptance of our connection: `Connecting -> Connected`,
    /// queues a `Connected` event. Ignored in any other state.
    pub fn on_accepted(&mut self, now: Instant) {
        if self.state == ConnectionState::Connecting {
            self.state = ConnectionState::Connected;
            self.last_packet_received = Some(now);
            self.events.push_back(ConnectionEvent::Connected);
        }
    }

    /// Handle denial of our connection: back to `Disconnected` with a
    /// `ConnectionFailed` event. Ignored unless `Connecting`.
    pub fn on_denied(&mut self, reason: String) {
        if self.state == ConnectionState::Connecting {
            self.state = ConnectionState::Disconnected;
            self.events.push_back(ConnectionEvent::ConnectionFailed { reason });
        }
    }

    /// Handle an incoming packet (any type) to update keep-alive tracking.
    /// Refreshes the receive timestamp used for timeout detection.
    pub fn on_packet_received(&mut self, now: Instant) {
        self.last_packet_received = Some(now);
    }

    /// Mark that we sent a packet; any pending keep-alive is satisfied.
    pub fn on_packet_sent(&mut self, now: Instant) {
        self.last_packet_sent = Some(now);
        self.keepalive_due = false;
    }

    /// Initiate graceful disconnection. Only valid from `Connected`;
    /// completion is signalled via `on_disconnect_complete()`.
    pub fn disconnect(&mut self) {
        if self.state == ConnectionState::Connected {
            self.state = ConnectionState::Disconnecting;
            self.disconnect_reason = Some(DisconnectReason::Graceful);
        }
    }

    /// Complete disconnection (after sending disconnect packet).
    /// NOTE(review): runs unconditionally — calling it while already
    /// `Disconnected` still queues a `Disconnected` event; callers are
    /// expected to guard.
    pub fn on_disconnect_complete(&mut self) {
        let reason = self.disconnect_reason.unwrap_or(DisconnectReason::Graceful);
        self.state = ConnectionState::Disconnected;
        self.events.push_back(ConnectionEvent::Disconnected { reason });
        self.connect_started = None;
        self.last_packet_received = None;
    }

    /// Handle remote disconnect: drop to `Disconnected` and queue an event.
    /// Ignored when already disconnected or disconnecting.
    pub fn on_remote_disconnect(&mut self, reason: DisconnectReason) {
        if self.state == ConnectionState::Connected || self.state == ConnectionState::Connecting {
            self.state = ConnectionState::Disconnected;
            self.disconnect_reason = Some(reason);
            self.events.push_back(ConnectionEvent::Disconnected { reason });
        }
    }

    /// Tick the state machine. Returns true if a connect retry is needed.
    ///
    /// While `Connecting`: fails with `TimedOut` after
    /// `connection_timeout_ms`, otherwise paces retries at
    /// `connect_retry_interval` up to `max_connect_attempts`.
    /// While `Connected`: times out on receive silence and flags a
    /// keep-alive on send silence. `Disconnecting` completes immediately.
    pub fn update(&mut self, now: Instant) -> bool {
        let mut needs_retry = false;

        match self.state {
            ConnectionState::Connecting => {
                if let Some(started) = self.connect_started {
                    let elapsed = now.duration_since(started);
                    if elapsed >= Duration::from_millis(self.config.connection_timeout_ms) {
                        self.state = ConnectionState::Disconnected;
                        self.events.push_back(ConnectionEvent::TimedOut);
                        return false;
                    }

                    // Check if we need to retry
                    // `.max(1)` means the first update after `connect()`
                    // already reports one expected attempt, so the initial
                    // request is sent immediately rather than after a full
                    // retry interval.
                    let expected_attempts = (elapsed.as_millis() / self.connect_retry_interval.as_millis()).max(1) as u32;
                    if expected_attempts > self.connect_attempts && self.connect_attempts < self.max_connect_attempts {
                        self.connect_attempts += 1;
                        needs_retry = true;
                    }
                }
            }
            ConnectionState::Connected => {
                // Check for timeout
                if let Some(last_recv) = self.last_packet_received {
                    let since_recv = now.duration_since(last_recv);
                    if since_recv >= Duration::from_millis(self.config.connection_timeout_ms) {
                        self.state = ConnectionState::Disconnected;
                        self.disconnect_reason = Some(DisconnectReason::Timeout);
                        self.events.push_back(ConnectionEvent::TimedOut);
                        return false;
                    }
                }

                // Check if keepalive is needed
                if let Some(last_sent) = self.last_packet_sent {
                    let since_sent = now.duration_since(last_sent);
                    if since_sent >= Duration::from_millis(self.config.keepalive_interval_ms) {
                        self.keepalive_due = true;
                    }
                } else {
                    // Nothing sent yet on this connection — prompt an
                    // immediate keep-alive.
                    self.keepalive_due = true;
                }
            }
            ConnectionState::Disconnecting => {
                // Just transition to disconnected
                self.on_disconnect_complete();
            }
            ConnectionState::Disconnected => {}
        }

        needs_retry
    }

    /// Whether a keepalive should be sent (only meaningful while connected).
    pub fn needs_keepalive(&self) -> bool {
        self.keepalive_due && self.state == ConnectionState::Connected
    }

    /// Drain pending events in the order they were queued.
    pub fn drain_events(&mut self) -> Vec<ConnectionEvent> {
        self.events.drain(..).collect()
    }

    /// Force transition to disconnected.
    /// NOTE(review): unlike `on_remote_disconnect`, this does not check the
    /// current state, so it can queue a `Disconnected` event even when
    /// already disconnected — confirm callers rely on that.
    pub fn force_disconnect(&mut self, reason: DisconnectReason) {
        self.state = ConnectionState::Disconnected;
        self.disconnect_reason = Some(reason);
        self.events.push_back(ConnectionEvent::Disconnected { reason });
    }

    /// Return to a pristine `Disconnected` state, clearing timers, counters,
    /// and any queued events.
    /// NOTE(review): `session_id` is not cleared here — it is regenerated by
    /// the next `connect()`; confirm that is intentional.
    pub fn reset(&mut self) {
        self.state = ConnectionState::Disconnected;
        self.connect_started = None;
        self.last_packet_received = None;
        self.last_packet_sent = None;
        self.connect_attempts = 0;
        self.events.clear();
        self.disconnect_reason = None;
        self.keepalive_due = false;
    }

    /// Time since last received packet, or `None` if nothing was received.
    pub fn time_since_last_received(&self, now: Instant) -> Option<Duration> {
        self.last_packet_received.map(|t| now.duration_since(t))
    }
}
1150
/// Generate a pseudo-unique session ID.
///
/// BUG FIX: the previous implementation hashed `now.elapsed()`, which
/// measures the time elapsed *since* `now` — effectively always near zero
/// at the call site — so every connect produced nearly the same ID. We now
/// mix wall-clock nanoseconds since the Unix epoch (distinct across
/// connects) with the tiny `now.elapsed()` jitter, and spread the bits with
/// FNV-1a as before.
fn generate_session_id(now: Instant) -> u64 {
    // Wall-clock entropy; falls back to 0 if the system clock is somehow
    // before the epoch rather than panicking.
    let wall_nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0);
    // Keep `now` in the mix for a little extra variation between calls that
    // land in the same wall-clock tick.
    let jitter = now.elapsed().as_nanos() as u64;

    // FNV-1a hash over both entropy sources for bit spreading.
    let mut hash: u64 = 0xcbf29ce484222325;
    for byte in wall_nanos
        .to_le_bytes()
        .iter()
        .chain(jitter.to_le_bytes().iter())
    {
        hash ^= *byte as u64;
        hash = hash.wrapping_mul(0x100000001b3);
    }
    hash
}
1163
/// A complete outgoing packet ready for serialization.
#[derive(Debug, Clone)]
pub struct OutgoingPacket {
    /// Common wire header written first.
    pub header: PacketHeader,
    /// Application bytes carried after the header(s).
    pub payload: Vec<u8>,
    /// Present only for fragment packets; serialized between the main
    /// header and the payload.
    pub fragment_header: Option<FragmentHeader>,
}
1171
1172impl OutgoingPacket {
1173    pub fn serialize(&self) -> Vec<u8> {
1174        let mut buf = self.header.serialize();
1175        if let Some(ref fh) = self.fragment_header {
1176            buf.extend_from_slice(&fh.serialize());
1177        }
1178        buf.extend_from_slice(&self.payload);
1179        buf
1180    }
1181
1182    pub fn total_size(&self) -> usize {
1183        PacketHeader::SERIALIZED_SIZE
1184            + self.fragment_header.as_ref().map_or(0, |_| FragmentHeader::SERIALIZED_SIZE)
1185            + self.payload.len()
1186    }
1187}
1188
1189/// Deserialize an incoming raw packet into its components.
1190pub fn deserialize_packet(data: &[u8]) -> Option<(PacketHeader, Option<FragmentHeader>, Vec<u8>)> {
1191    let header = PacketHeader::deserialize(data)?;
1192    if !header.validate_protocol() {
1193        return None;
1194    }
1195
1196    let mut offset = PacketHeader::SERIALIZED_SIZE;
1197
1198    let fragment_header = if header.packet_type == PacketType::Fragment {
1199        if data.len() < offset + FragmentHeader::SERIALIZED_SIZE {
1200            return None;
1201        }
1202        let fh = FragmentHeader::deserialize(&data[offset..])?;
1203        offset += FragmentHeader::SERIALIZED_SIZE;
1204        Some(fh)
1205    } else {
1206        None
1207    };
1208
1209    let payload = if offset < data.len() {
1210        data[offset..].to_vec()
1211    } else {
1212        Vec::new()
1213    };
1214
1215    Some((header, fragment_header, payload))
1216}
1217
/// Jitter buffer for smoothing network packet delivery timing.
///
/// Packets are held, sorted by sender timestamp, until they are at least
/// `delay_ms` old, trading a little latency for steadier delivery.
pub struct JitterBuffer {
    // Packets as `(timestamp_ms, payload)`, kept sorted by timestamp.
    buffer: VecDeque<(u64, Vec<u8>)>,
    // How long (ms) a packet must age before it is released.
    delay_ms: u64,
    // Hard cap on buffered packets; overflow evicts from the front.
    max_size: usize,
}

impl JitterBuffer {
    /// Create a buffer that releases packets `delay_ms` after their
    /// timestamp and holds at most `max_size` of them.
    pub fn new(delay_ms: u64, max_size: usize) -> Self {
        Self {
            buffer: VecDeque::new(),
            delay_ms,
            max_size,
        }
    }

    /// Insert a packet keeping timestamp order; equal timestamps preserve
    /// arrival order. On overflow the front (earliest) entries are dropped.
    pub fn push(&mut self, timestamp_ms: u64, data: Vec<u8>) {
        let slot = self
            .buffer
            .iter()
            .position(|entry| entry.0 > timestamp_ms)
            .unwrap_or(self.buffer.len());
        self.buffer.insert(slot, (timestamp_ms, data));

        while self.buffer.len() > self.max_size {
            self.buffer.pop_front();
        }
    }

    /// Remove and return, in timestamp order, every packet that is at least
    /// `delay_ms` older than `current_time_ms`.
    pub fn drain_ready(&mut self, current_time_ms: u64) -> Vec<Vec<u8>> {
        let cutoff = current_time_ms.saturating_sub(self.delay_ms);
        // The buffer is sorted, so the ready packets form a prefix.
        let ready_len = self
            .buffer
            .iter()
            .take_while(|entry| entry.0 <= cutoff)
            .count();
        self.buffer
            .drain(..ready_len)
            .map(|(_, data)| data)
            .collect()
    }

    /// Adjust the delivery delay used by subsequent drains.
    pub fn set_delay(&mut self, delay_ms: u64) {
        self.delay_ms = delay_ms;
    }

    /// Number of buffered packets.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// True when nothing is buffered.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Discard all buffered packets.
    pub fn clear(&mut self) {
        self.buffer.clear();
    }
}
1279
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips a fully-populated header through serialize/deserialize.
    #[test]
    fn test_packet_header_roundtrip() {
        let mut h = PacketHeader::new(PacketType::Reliable, 42);
        h.ack = 10;
        h.ack_bits = 0xFF00FF00;
        h.timestamp_ms = 123456789;
        h.payload_size = 512;

        let data = h.serialize();
        let h2 = PacketHeader::deserialize(&data).unwrap();
        assert_eq!(h2.sequence, 42);
        assert_eq!(h2.ack, 10);
        assert_eq!(h2.ack_bits, 0xFF00FF00);
        assert_eq!(h2.timestamp_ms, 123456789);
        assert_eq!(h2.payload_size, 512);
        assert!(h2.validate_protocol());
    }

    // Sequence comparison must handle wraparound at the 16-bit boundary.
    #[test]
    fn test_sequence_greater_than() {
        assert!(sequence_greater_than(1, 0));
        assert!(sequence_greater_than(100, 99));
        // Wraparound
        assert!(sequence_greater_than(0, 65535));
        assert!(!sequence_greater_than(65535, 0));
    }

    // Fragment header fields survive a serialize/deserialize round trip.
    #[test]
    fn test_fragment_header_roundtrip() {
        let fh = FragmentHeader {
            group_id: 7,
            fragment_index: 3,
            total_fragments: 10,
            fragment_size: 1024,
        };
        let data = fh.serialize();
        let fh2 = FragmentHeader::deserialize(&data).unwrap();
        assert_eq!(fh2.group_id, 7);
        assert_eq!(fh2.fragment_index, 3);
        assert_eq!(fh2.total_fragments, 10);
        assert_eq!(fh2.fragment_size, 1024);
    }

    // 350 bytes at fragment size 100 -> 4 fragments; reassembly returns the
    // original payload once the last fragment arrives.
    #[test]
    fn test_fragmentation_and_reassembly() {
        let mut fragmenter = PacketFragmenter::new(100, 256);
        let data: Vec<u8> = (0..350).map(|i| (i % 256) as u8).collect();

        let fragments = fragmenter.fragment(&data);
        assert_eq!(fragments.len(), 4);

        let now = Instant::now();
        let mut reassembly = ReassemblyBuffer::new(5000, 32);

        let mut result = None;
        for (fh, fdata) in &fragments {
            result = reassembly.insert(fh, fdata.clone(), now);
        }

        let assembled = result.unwrap();
        assert_eq!(assembled, data);
    }

    // A 1000-byte/sec budget allows exactly 1000 bytes before refusing.
    #[test]
    fn test_bandwidth_throttle() {
        let mut throttle = BandwidthThrottle::new(1000);
        let now = Instant::now();
        throttle.refill(now);

        assert!(throttle.can_send(500));
        throttle.consume(500);
        assert!(throttle.can_send(500));
        throttle.consume(500);
        assert!(!throttle.can_send(100));
    }

    // Happy path: connect -> accept transitions and a single Connected event.
    #[test]
    fn test_connection_state_machine() {
        let config = TransportConfig::default();
        let mut csm = ConnectionStateMachine::new(config);
        let now = Instant::now();

        assert_eq!(csm.state(), ConnectionState::Disconnected);
        csm.connect(now);
        assert_eq!(csm.state(), ConnectionState::Connecting);
        csm.on_accepted(now);
        assert_eq!(csm.state(), ConnectionState::Connected);

        let events = csm.drain_events();
        assert_eq!(events.len(), 1);
    }

    // Unreliable channel round trip: sequence starts at 0; payload loops back.
    #[test]
    fn test_unreliable_channel() {
        let mut ch = UnreliableChannel::new();
        let (header, payload) = ch.send(vec![1, 2, 3]);
        assert_eq!(header.sequence, 0);

        ch.receive(&header, payload);
        let msgs = ch.drain_received();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0], vec![1, 2, 3]);
    }

    // Out-of-order pushes come back sorted by timestamp, released only once
    // they are at least `delay_ms` (50) old.
    #[test]
    fn test_jitter_buffer() {
        let mut jb = JitterBuffer::new(50, 100);
        jb.push(100, vec![1]);
        jb.push(120, vec![2]);
        jb.push(90, vec![0]);

        let ready = jb.drain_ready(140);
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], vec![0]);

        let ready2 = jb.drain_ready(160);
        assert_eq!(ready2.len(), 1);
        assert_eq!(ready2[0], vec![1]);
    }

    // Non-fragment packets deserialize with no fragment header and the raw
    // trailing bytes as payload.
    #[test]
    fn test_deserialize_packet() {
        let mut header = PacketHeader::new(PacketType::Unreliable, 5);
        header.payload_size = 3;
        let mut data = header.serialize();
        data.extend_from_slice(&[10, 20, 30]);

        let (h, fh, payload) = deserialize_packet(&data).unwrap();
        assert_eq!(h.sequence, 5);
        assert!(fh.is_none());
        assert_eq!(payload, vec![10, 20, 30]);
    }
}