proof_engine/networking/transport.rs

//! Reliable UDP transport layer.
//!
//! Builds reliable, ordered delivery on top of raw UDP without any async
//! runtime.  Drive the stack each frame by calling `ConnectionManager::poll()`.
//!
//! ## Architecture
//! ```text
//!  Application
//!      │  send(channel, data)
//!      ▼
//!  ConnectionManager  ──  maintains one ReliableUdp connection per peer
//!      │
//!      ▼
//!  ReliableUdp  ──  sequencing, acks, retransmits, reordering
//!      │
//!      ▼
//!  NonBlockingSocket  ──  wraps std::net::UdpSocket
//! ```
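//!
//! ## Usage sketch
//! A minimal per-frame driving loop, illustrative only: the `use` path assumes
//! the crate is named `proof_engine` with these modules public, and
//! `server_addr` and the payload are placeholder values (the block is marked
//! `ignore` for that reason).
//! ```ignore
//! use proof_engine::networking::transport::{Channel, ConnectionManager};
//!
//! let mut mgr = ConnectionManager::bind("127.0.0.1:0".parse().unwrap()).unwrap();
//! let server_addr = "127.0.0.1:40000".parse().unwrap();
//! mgr.connect(server_addr);
//!
//! loop {
//!     // Drain incoming packets and drive retransmits, keepalives, and timeouts.
//!     for received in mgr.poll() {
//!         println!("{} bytes from {}", received.packet.payload.len(), received.from);
//!     }
//!     // Queue outgoing state for this frame.
//!     mgr.send(server_addr, Channel::ReliableOrdered, vec![0u8; 64]);
//! }
//! ```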

use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::net::{SocketAddr, UdpSocket};
use std::time::{Duration, Instant};

use crate::networking::protocol::{
    Packet, PacketEncoder, PacketDecoder, PacketKind, PacketHeader,
};

// ─── Constants ───────────────────────────────────────────────────────────────

/// Maximum Transmission Unit for outgoing packets (bytes).
pub const MTU: usize = 1400;
/// Base retransmit timeout in milliseconds.
pub const RETRANSMIT_BASE_MS: u64 = 100;
/// Maximum retransmit timeout after backoff (ms).
pub const RETRANSMIT_MAX_MS: u64 = 8000;
/// Maximum number of retransmit attempts before giving up.
pub const MAX_RETRANSMIT: u32 = 10;
/// Keepalive interval in milliseconds.
pub const KEEPALIVE_MS: u64 = 500;
/// Peer timeout after last received packet (ms).
pub const PEER_TIMEOUT_MS: u64 = 10_000;
/// EWMA smoothing factor for RTT estimation.
pub const RTT_ALPHA: f64 = 0.125;
/// EWMA smoothing factor for jitter.
pub const JITTER_ALPHA: f64 = 0.25;
/// Fragment timeout in milliseconds (drop partial reassembly).
pub const FRAGMENT_TIMEOUT_MS: u64 = 5_000;
/// Number of sequence numbers in the ack window.
pub const ACK_WINDOW: u32 = 32;

// ─── Channel ─────────────────────────────────────────────────────────────────

/// Delivery channel semantics for outgoing data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Channel {
    /// Guaranteed delivery, no ordering guarantee.
    Reliable,
    /// Best-effort, no ordering guarantee (fire-and-forget).
    Unreliable,
    /// Guaranteed delivery in the order sent.
    ReliableOrdered,
    /// Best-effort, but older out-of-order packets are discarded.
    UnreliableOrdered,
}

impl Channel {
    pub fn is_reliable(self) -> bool {
        matches!(self, Channel::Reliable | Channel::ReliableOrdered)
    }
    pub fn is_ordered(self) -> bool {
        matches!(self, Channel::ReliableOrdered | Channel::UnreliableOrdered)
    }
}

// ─── ConnectionState ─────────────────────────────────────────────────────────

/// Lifecycle state of a single UDP connection to a remote peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection.
    Disconnected,
    /// Sent a Connect packet, waiting for server acknowledgement.
    Connecting,
    /// Fully established; can send and receive data.
    Connected,
    /// No packet received within `PEER_TIMEOUT_MS`.
    TimedOut,
    /// Remote explicitly requested disconnect (kick/ban).
    Kicked,
}

// ─── TransportStats ──────────────────────────────────────────────────────────

/// Snapshot of transport-layer statistics for a single peer connection.
#[derive(Debug, Clone, Default)]
pub struct TransportStats {
    /// Smoothed round-trip time in milliseconds.
    pub rtt_ms: f64,
    /// Estimated packet loss as a percentage (0.0–100.0).
    pub packet_loss_pct: f64,
    /// Smoothed jitter in milliseconds.
    pub jitter_ms: f64,
    /// Total outgoing bytes; sample and diff over time to derive upload bytes/sec.
    pub bandwidth_up: f64,
    /// Total incoming bytes; sample and diff over time to derive download bytes/sec.
    pub bandwidth_down: f64,
    /// Total packets sent.
    pub packets_sent: u64,
    /// Total packets received.
    pub packets_recv: u64,
    /// Total retransmissions.
    pub retransmits: u64,
}

// ─── ReceivedPacket ──────────────────────────────────────────────────────────

/// A packet received from a specific peer, ready for the application layer.
#[derive(Debug, Clone)]
pub struct ReceivedPacket {
    pub from:   SocketAddr,
    pub packet: Packet,
}

// ─── NonBlockingSocket ───────────────────────────────────────────────────────

/// Non-blocking UDP socket wrapper with poll-based receive.
pub struct NonBlockingSocket {
    socket: UdpSocket,
    /// Local address this socket is bound to.
    pub local_addr: SocketAddr,
}

impl NonBlockingSocket {
    /// Bind to `addr` and set non-blocking mode.
    pub fn bind(addr: SocketAddr) -> Result<Self, std::io::Error> {
        let socket = UdpSocket::bind(addr)?;
        socket.set_nonblocking(true)?;
        let local_addr = socket.local_addr()?;
        Ok(Self { socket, local_addr })
    }

    /// Send `data` to `dest`.  Returns bytes written.
    pub fn send_to(&self, data: &[u8], dest: SocketAddr) -> Result<usize, std::io::Error> {
        self.socket.send_to(data, dest)
    }

    /// Poll for available packets.  Returns all currently-buffered datagrams.
    /// Stops on `WouldBlock` (nothing more to read right now).
    pub fn poll(&self, buf: &mut Vec<u8>) -> Vec<(SocketAddr, Vec<u8>)> {
        let mut results = Vec::new();
        buf.resize(65535, 0);
        loop {
            match self.socket.recv_from(buf) {
                Ok((len, addr)) => {
                    results.push((addr, buf[..len].to_vec()));
                }
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                // Any other error: stop reading for this frame.
                Err(_) => break,
            }
        }
        results
    }
}

// ─── Fragmenter ──────────────────────────────────────────────────────────────

/// Fragment header prepended to each fragment payload.
/// Layout: packet_id(u16 BE) + fragment_idx(u8) + total_fragments(u8) = 4 bytes.
#[derive(Debug, Clone)]
struct FragmentHeader {
    packet_id:       u16,
    fragment_idx:    u8,
    total_fragments: u8,
}

impl FragmentHeader {
    const SIZE: usize = 4;

    fn encode(&self) -> [u8; Self::SIZE] {
        let id_bytes = self.packet_id.to_be_bytes();
        [id_bytes[0], id_bytes[1], self.fragment_idx, self.total_fragments]
    }

    fn decode(b: &[u8]) -> Option<Self> {
        if b.len() < Self::SIZE { return None; }
        Some(Self {
            packet_id:       u16::from_be_bytes([b[0], b[1]]),
            fragment_idx:    b[2],
            total_fragments: b[3],
        })
    }
}

/// Partial reassembly state for one large message.
#[derive(Debug)]
struct PartialMessage {
    total_fragments: u8,
    received:        Vec<Option<Vec<u8>>>,
    created_at:      Instant,
}

impl PartialMessage {
    fn new(total: u8) -> Self {
        Self {
            total_fragments: total,
            received:        vec![None; total as usize],
            created_at:      Instant::now(),
        }
    }

    fn is_complete(&self) -> bool {
        self.received.iter().all(|s| s.is_some())
    }

    fn is_expired(&self) -> bool {
        self.created_at.elapsed() > Duration::from_millis(FRAGMENT_TIMEOUT_MS)
    }

    /// Concatenate fragments in index order.  Only called once `is_complete()`
    /// returns true, so the unwraps cannot panic.
    fn reassemble(&self) -> Vec<u8> {
        self.received.iter().flat_map(|s| s.as_ref().unwrap().iter().copied()).collect()
    }
}

/// Splits large payloads into MTU-safe fragments and reassembles on the
/// receiver side.
pub struct Fragmenter {
    next_packet_id: u16,
    /// In-progress reassembly: keyed by (peer_addr_hash, packet_id).
    reassembly:     HashMap<(u64, u16), PartialMessage>,
}

impl Fragmenter {
    pub fn new() -> Self {
        Self {
            next_packet_id: 0,
            reassembly:     HashMap::new(),
        }
    }

    /// Split `data` into MTU-sized chunks.  Returns vec of raw datagrams ready
    /// to send.  Each datagram has the 4-byte fragment header prepended.
    /// Payloads needing more than 255 fragments are truncated to the first 255.
    pub fn fragment(&mut self, data: &[u8]) -> Vec<Vec<u8>> {
        let max_body = MTU - PacketHeader::SIZE - FragmentHeader::SIZE;
        let chunks: Vec<&[u8]> = data.chunks(max_body).collect();
        let total = chunks.len().min(255) as u8;
        let id    = self.next_packet_id;
        self.next_packet_id = self.next_packet_id.wrapping_add(1);

        chunks.iter().enumerate().take(255).map(|(i, chunk)| {
            let fh = FragmentHeader {
                packet_id:       id,
                fragment_idx:    i as u8,
                total_fragments: total,
            };
            let mut out = Vec::with_capacity(FragmentHeader::SIZE + chunk.len());
            out.extend_from_slice(&fh.encode());
            out.extend_from_slice(chunk);
            out
        }).collect()
    }

    /// Feed an incoming fragment.  Returns `Some(reassembled_data)` when all
    /// fragments for a message have arrived.
    pub fn receive_fragment(&mut self, peer_key: u64, raw: &[u8]) -> Option<Vec<u8>> {
        let fh = FragmentHeader::decode(raw)?;
        let body = raw[FragmentHeader::SIZE..].to_vec();

        let entry = self.reassembly
            .entry((peer_key, fh.packet_id))
            .or_insert_with(|| PartialMessage::new(fh.total_fragments));

        if fh.fragment_idx as usize >= entry.received.len() {
            return None; // malformed
        }
        entry.received[fh.fragment_idx as usize] = Some(body);

        if entry.is_complete() {
            let data = entry.reassemble();
            self.reassembly.remove(&(peer_key, fh.packet_id));
            Some(data)
        } else {
            None
        }
    }

    /// Evict stale partial messages to free memory.
    pub fn gc(&mut self) {
        self.reassembly.retain(|_, v| !v.is_expired());
    }
}

impl Default for Fragmenter {
    fn default() -> Self { Self::new() }
}

// ─── SendEntry ────────────────────────────────────────────────────────────────

/// A reliable packet sitting in the retransmit queue.
#[derive(Debug, Clone)]
struct SendEntry {
    sequence:          u32,
    data:              Vec<u8>,
    sent_at:           Instant,
    next_retransmit:   Instant,
    retransmit_count:  u32,
    retransmit_delay:  Duration,
}

impl SendEntry {
    fn new(sequence: u32, data: Vec<u8>, now: Instant) -> Self {
        let delay = Duration::from_millis(RETRANSMIT_BASE_MS);
        Self {
            sequence,
            data,
            sent_at: now,
            next_retransmit: now + delay,
            retransmit_count: 0,
            retransmit_delay: delay,
        }
    }

    /// Advance the retransmit timer with exponential backoff.
    fn backoff(&mut self, now: Instant) {
        self.retransmit_count += 1;
        self.retransmit_delay = Duration::from_millis(
            (self.retransmit_delay.as_millis() as u64 * 2).min(RETRANSMIT_MAX_MS),
        );
        self.next_retransmit = now + self.retransmit_delay;
    }

    fn is_due(&self, now: Instant) -> bool {
        now >= self.next_retransmit
    }
}

// ─── ReorderBuffer ───────────────────────────────────────────────────────────

/// Holds out-of-order packets until gaps are filled, then delivers in sequence.
struct ReorderBuffer {
    /// Next expected sequence number.
    next_expected: u32,
    /// Buffered out-of-order packets keyed by sequence.
    buffer: HashMap<u32, Packet>,
    /// Maximum buffered packets before we advance anyway.
    max_hold: usize,
}

impl ReorderBuffer {
    fn new() -> Self {
        Self { next_expected: 0, buffer: HashMap::new(), max_hold: 64 }
    }

    /// Insert a packet.  Returns a vector of in-order packets now deliverable.
    fn insert(&mut self, pkt: Packet) -> Vec<Packet> {
        let seq = pkt.sequence;

        if seq == self.next_expected {
            // In-order — deliver immediately plus any buffered that follow.
            let mut out = vec![pkt];
            self.next_expected = self.next_expected.wrapping_add(1);
            loop {
                if let Some(next) = self.buffer.remove(&self.next_expected) {
                    out.push(next);
                    self.next_expected = self.next_expected.wrapping_add(1);
                } else {
                    break;
                }
            }
            out
        } else {
            // Out-of-order — buffer it.
            self.buffer.insert(seq, pkt);
            // If buffer is overflowing, flush everything and skip ahead.
            if self.buffer.len() > self.max_hold {
                let mut all: Vec<Packet> = self.buffer.drain().map(|(_, p)| p).collect();
                all.sort_by_key(|p| p.sequence);
                if let Some(last) = all.last() {
                    self.next_expected = last.sequence.wrapping_add(1);
                }
                return all;
            }
            Vec::new()
        }
    }

    fn reset(&mut self) {
        self.next_expected = 0;
        self.buffer.clear();
    }
}

// ─── AckAccumulator ──────────────────────────────────────────────────────────

/// Maintains the ack + ack_bits fields sent with each outgoing packet.
struct AckAccumulator {
    last_received: u32,
    ack_bits:      u32,
}

impl AckAccumulator {
    fn new() -> Self { Self { last_received: 0, ack_bits: 0 } }

    /// Record that we received `seq`.
    fn record(&mut self, seq: u32) {
        let diff = self.last_received.wrapping_sub(seq);
        if seq == self.last_received {
            // duplicate — ignore
        } else if seq.wrapping_sub(self.last_received) < 0x8000_0000 {
            // newer
            let advance = seq.wrapping_sub(self.last_received);
            if advance >= ACK_WINDOW {
                self.ack_bits = 0;
            } else {
                self.ack_bits <<= advance;
                self.ack_bits |= 1 << (advance - 1);
            }
            self.last_received = seq;
        } else if diff < ACK_WINDOW {
            // older but within window
            self.ack_bits |= 1 << (diff - 1);
        }
    }

    fn ack(&self)      -> u32 { self.last_received }
    fn ack_bits(&self) -> u32 { self.ack_bits }
}

// ─── CongestionControl ───────────────────────────────────────────────────────

/// Simple AIMD (Additive Increase Multiplicative Decrease) congestion window.
struct CongestionControl {
    /// Current congestion window in packets.
    pub cwnd: u32,
    ssthresh: u32,
}

impl CongestionControl {
    fn new() -> Self { Self { cwnd: 16, ssthresh: 64 } }

    /// Called on each ack — increase window.
    fn on_ack(&mut self) {
        if self.cwnd < self.ssthresh {
            // Below ssthresh: grow faster, +2 per ack (capped at 256)
            self.cwnd = (self.cwnd + 2).min(256);
        } else {
            // Congestion avoidance: +1 per ack (capped at 256)
            self.cwnd = (self.cwnd + 1).min(256);
        }
    }

    /// Called on detected loss — halve window.
    fn on_loss(&mut self) {
        self.ssthresh = (self.cwnd / 2).max(4);
        self.cwnd = self.ssthresh;
    }

    fn can_send(&self, in_flight: u32) -> bool {
        in_flight < self.cwnd
    }
}

// ─── ReliableUdp ─────────────────────────────────────────────────────────────

/// Reliable ordered UDP transport for a single peer connection.
///
/// Caller owns the `NonBlockingSocket` and passes it in for send operations.
/// `tick()` must be called frequently (each game frame) to drive retransmits.
pub struct ReliableUdp {
    pub peer_addr:     SocketAddr,
    state:             ConnectionState,
    next_sequence:     u32,
    ack_accum:         AckAccumulator,
    send_queue:        VecDeque<SendEntry>,
    reorder_buf:       ReorderBuffer,
    congestion:        CongestionControl,
    rtt_ms:            f64,
    jitter_ms:         f64,
    last_recv:         Instant,
    last_keepalive:    Instant,
    /// Pending timestamps for RTT calculation: seq -> sent_at.
    ping_map:          HashMap<u32, Instant>,
    encoder:           PacketEncoder,
    decoder:           PacketDecoder,
    stats:             TransportStats,
    /// Number of reliable packets currently in flight.
    in_flight:         u32,
}

impl ReliableUdp {
    pub fn new(peer_addr: SocketAddr) -> Self {
        let now = Instant::now();
        Self {
            peer_addr,
            state:          ConnectionState::Connecting,
            next_sequence:  0,
            ack_accum:      AckAccumulator::new(),
            send_queue:     VecDeque::new(),
            reorder_buf:    ReorderBuffer::new(),
            congestion:     CongestionControl::new(),
            rtt_ms:         50.0,
            jitter_ms:      0.0,
            last_recv:      now,
            last_keepalive: now,
            ping_map:       HashMap::new(),
            encoder:        PacketEncoder { reliable: true, ..PacketEncoder::default() },
            decoder:        PacketDecoder::new(),
            stats:          TransportStats::default(),
            in_flight:      0,
        }
    }

    pub fn state(&self) -> ConnectionState { self.state }
    pub fn stats(&self) -> &TransportStats { &self.stats }
    pub fn rtt_ms(&self) -> f64 { self.rtt_ms }

    /// Allocate a new sequence number.
    fn next_seq(&mut self) -> u32 {
        let s = self.next_sequence;
        self.next_sequence = self.next_sequence.wrapping_add(1);
        s
    }

    /// Enqueue a reliable packet for delivery.
    pub fn send_reliable(&mut self, socket: &NonBlockingSocket, mut packet: Packet) {
        packet.sequence = self.next_seq();
        packet.ack      = self.ack_accum.ack();
        packet.ack_bits = self.ack_accum.ack_bits();
        packet.flags    |= PacketHeader::FLAG_RELIABLE;

        if let Ok(data) = self.encoder.encode(&packet) {
            // Send immediately if the congestion window allows; otherwise the
            // queued entry goes out when its retransmit timer fires.
            if self.congestion.can_send(self.in_flight) {
                let _ = socket.send_to(&data, self.peer_addr);
                self.in_flight += 1;
                self.stats.packets_sent += 1;
                self.stats.bandwidth_up += data.len() as f64;

                if packet.kind == PacketKind::Ping {
                    self.ping_map.insert(packet.sequence, Instant::now());
                }
            }
            self.send_queue.push_back(SendEntry::new(packet.sequence, data, Instant::now()));
        }
    }

    /// Send a best-effort (unreliable) packet.
    pub fn send_unreliable(&mut self, socket: &NonBlockingSocket, mut packet: Packet) {
        packet.sequence = self.next_seq();
        packet.ack      = self.ack_accum.ack();
        packet.ack_bits = self.ack_accum.ack_bits();
        if let Ok(data) = self.encoder.encode(&packet) {
            let _ = socket.send_to(&data, self.peer_addr);
            self.stats.packets_sent += 1;
            self.stats.bandwidth_up += data.len() as f64;
        }
    }

    /// Called by `ConnectionManager` when a raw datagram arrives from this peer.
    /// Returns decoded in-order packets ready for the application.
    pub fn receive(&mut self, raw: &[u8]) -> Vec<Packet> {
        self.last_recv = Instant::now();
        self.stats.bandwidth_down += raw.len() as f64;

        let (pkt, _) = match self.decoder.decode(raw) {
            Ok(p)  => p,
            Err(_) => return Vec::new(),
        };

        self.stats.packets_recv += 1;

        // Process ack / ack_bits from incoming packet
        self.process_acks(pkt.ack, pkt.ack_bits);

        // Record this packet in our ack accumulator
        self.ack_accum.record(pkt.sequence);

        // Handle control packets internally
        match pkt.kind {
            PacketKind::Pong => {
                self.handle_pong(&pkt);
                return Vec::new();
            }
            PacketKind::Heartbeat => {
                if self.state == ConnectionState::Connecting {
                    self.state = ConnectionState::Connected;
                }
                return Vec::new();
            }
            PacketKind::Disconnect => {
                self.state = ConnectionState::Disconnected;
                return Vec::new();
            }
            PacketKind::Connect => {
                self.state = ConnectionState::Connected;
                return Vec::new();
            }
            _ => {}
        }

        if self.state == ConnectionState::Connecting {
            self.state = ConnectionState::Connected;
        }

        // For ordered channels, buffer and reorder.
        if pkt.is_reliable() || pkt.flags & PacketHeader::FLAG_ORDERED != 0 {
            self.reorder_buf.insert(pkt)
        } else {
            vec![pkt]
        }
    }

    /// Process acks received in the incoming packet's header fields.
    fn process_acks(&mut self, ack: u32, ack_bits: u32) {
        // The `ack` field is the highest sequence the remote has received.
        // Bits in `ack_bits` indicate which of the 32 prior sequences were also received.
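        //
        // Worked example (illustrative values): ack = 100 with ack_bits = 0b101
        // acknowledges sequence 100 itself plus 99 (bit 0) and 97 (bit 2); in
        // general, a set bit `i` means sequence `ack - (i + 1)` was received.
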
        let mut acked_seqs = Vec::new();
        acked_seqs.push(ack);
        for i in 0..ACK_WINDOW {
            if ack_bits & (1 << i) != 0 {
                acked_seqs.push(ack.wrapping_sub(i + 1));
            }
        }

        let now = Instant::now();
        let mut any_acked = false;

        self.send_queue.retain(|entry| {
            if acked_seqs.contains(&entry.sequence) {
                any_acked = true;
                if let Some(sent_at) = self.ping_map.remove(&entry.sequence) {
                    let rtt = now.duration_since(sent_at).as_secs_f64() * 1000.0;
                    let err = (rtt - self.rtt_ms).abs();
                    self.jitter_ms = JITTER_ALPHA * err + (1.0 - JITTER_ALPHA) * self.jitter_ms;
                    self.rtt_ms    = RTT_ALPHA * rtt + (1.0 - RTT_ALPHA) * self.rtt_ms;
                }
                self.in_flight = self.in_flight.saturating_sub(1);
                false // remove from queue
            } else {
                true // keep
            }
        });

        if any_acked {
            self.congestion.on_ack();
        }

        // Compute packet loss from ack gaps
        let loss = self.estimate_packet_loss(ack, ack_bits);
        self.stats.packet_loss_pct = loss;
        self.stats.rtt_ms          = self.rtt_ms;
        self.stats.jitter_ms       = self.jitter_ms;
    }

    fn estimate_packet_loss(&self, _ack: u32, ack_bits: u32) -> f64 {
        // Count set bits in ack_bits; bits NOT set indicate lost packets.
        // This is a rough estimate over the last ACK_WINDOW sequences and
        // over-reports loss early in a connection, before the window has filled.
        let received = ack_bits.count_ones();
        let window   = ACK_WINDOW;
        let lost     = window - received;
        (lost as f64 / window as f64) * 100.0
    }

    fn handle_pong(&mut self, pkt: &Packet) {
        if pkt.payload.len() < 8 { return; }
        let ts_bytes: [u8; 8] = pkt.payload[0..8].try_into().unwrap_or_default();
        let _ping_ts = u64::from_be_bytes(ts_bytes);
        // RTT already updated in process_acks via ping_map; nothing else to do.
    }

    /// Drive retransmits, keepalives, and timeout detection.
    /// Call every frame; any due packets are sent directly on `socket`.
    pub fn tick(&mut self, socket: &NonBlockingSocket) {
        let now = Instant::now();

        // Timeout detection
        if now.duration_since(self.last_recv) > Duration::from_millis(PEER_TIMEOUT_MS) {
            self.state = ConnectionState::TimedOut;
            return;
        }

        // Retransmits
        let mut lost_count = 0u32;
        for entry in self.send_queue.iter_mut() {
            if entry.is_due(now) {
                if entry.retransmit_count >= MAX_RETRANSMIT {
                    // Give up — will be cleaned up below
                    lost_count += 1;
                    continue;
                }
                let _ = socket.send_to(&entry.data, self.peer_addr);
                self.stats.retransmits += 1;
                entry.backoff(now);
            }
        }

        // Remove exhausted entries
        self.send_queue.retain(|e| e.retransmit_count < MAX_RETRANSMIT);

        // Signal loss to congestion control
        if lost_count > 0 {
            self.congestion.on_loss();
            self.in_flight = self.in_flight.saturating_sub(lost_count);
        }

        // Keepalive
        if now.duration_since(self.last_keepalive) > Duration::from_millis(KEEPALIVE_MS) {
            self.last_keepalive = now;
            let seq = self.next_seq();
            let hb = Packet::heartbeat(seq, self.ack_accum.ack(), self.ack_accum.ack_bits());
            if let Ok(data) = self.encoder.encode(&hb) {
                let _ = socket.send_to(&data, self.peer_addr);
                self.stats.packets_sent += 1;
            }
        }
    }

    pub fn disconnect(&mut self, socket: &NonBlockingSocket) {
        let seq = self.next_seq();
        let pkt = Packet::new(
            PacketKind::Disconnect, seq,
            self.ack_accum.ack(), self.ack_accum.ack_bits(), Vec::new(),
        );
        if let Ok(data) = self.encoder.encode(&pkt) {
            let _ = socket.send_to(&data, self.peer_addr);
        }
        self.state = ConnectionState::Disconnected;
    }

    pub fn reset_reorder(&mut self) {
        self.reorder_buf.reset();
    }
}

// ─── ConnectionManager ────────────────────────────────────────────────────────

/// Stable reassembly key for a peer, derived by hashing its socket address.
/// Used to scope fragment reassembly in `Fragmenter` per remote peer.
fn peer_key(addr: SocketAddr) -> u64 {
    let mut hasher = DefaultHasher::new();
    addr.hash(&mut hasher);
    hasher.finish()
}

/// Manages multiple UDP peer connections over a single local socket.
pub struct ConnectionManager {
    socket:     NonBlockingSocket,
    peers:      HashMap<SocketAddr, ReliableUdp>,
    fragmenter: Fragmenter,
    encoder:    PacketEncoder,
    recv_buf:   Vec<u8>,
}

impl ConnectionManager {
    /// Create a `ConnectionManager` bound to `local_addr`.
    pub fn bind(local_addr: SocketAddr) -> Result<Self, std::io::Error> {
        Ok(Self {
            socket:     NonBlockingSocket::bind(local_addr)?,
            peers:      HashMap::new(),
            fragmenter: Fragmenter::new(),
            encoder:    PacketEncoder::new(),
            recv_buf:   vec![0u8; 65535],
        })
    }

    /// Initiate a connection to `addr`.
    pub fn connect(&mut self, addr: SocketAddr) {
        let mut conn = ReliableUdp::new(addr);
        let pkt = Packet::new(PacketKind::Connect, 0, 0, 0, Vec::new());
        conn.send_reliable(&self.socket, pkt);
        self.peers.insert(addr, conn);
    }

    /// Gracefully disconnect a peer.
    pub fn disconnect(&mut self, addr: SocketAddr) {
        if let Some(conn) = self.peers.get_mut(&addr) {
            conn.disconnect(&self.socket);
        }
        self.peers.remove(&addr);
    }

    /// Send `data` to `addr` on `channel`.  Fragments if larger than MTU.
    pub fn send(&mut self, addr: SocketAddr, channel: Channel, data: Vec<u8>) {
        let needs_fragment = data.len() > MTU - PacketHeader::SIZE;

        let peer = self.peers.entry(addr).or_insert_with(|| ReliableUdp::new(addr));

        if needs_fragment {
            let frags = self.fragmenter.fragment(&data);
            for frag in frags {
                let mut pkt = Packet::new(
                    PacketKind::StateUpdate,
                    0, 0, 0, frag,
                );
                pkt.flags |= PacketHeader::FLAG_FRAGMENTED;
                if channel.is_reliable() {
                    peer.send_reliable(&self.socket, pkt);
                } else {
                    peer.send_unreliable(&self.socket, pkt);
                }
            }
        } else {
            let pkt = Packet::new(PacketKind::StateUpdate, 0, 0, 0, data);
            if channel.is_reliable() {
                peer.send_reliable(&self.socket, pkt);
            } else {
                peer.send_unreliable(&self.socket, pkt);
            }
        }
    }

    /// Send a typed `Packet` to `addr` on `channel`.
    pub fn send_packet(&mut self, addr: SocketAddr, channel: Channel, packet: Packet) {
        let peer = self.peers.entry(addr).or_insert_with(|| ReliableUdp::new(addr));
        if channel.is_reliable() {
            peer.send_reliable(&self.socket, packet);
        } else {
            peer.send_unreliable(&self.socket, packet);
        }
    }

    /// Poll the socket for incoming datagrams and drive retransmits.
    /// Returns all application-level packets received this frame.
    pub fn poll(&mut self) -> Vec<ReceivedPacket> {
        let mut out = Vec::new();

        // Receive all pending datagrams
        let datagrams = self.socket.poll(&mut self.recv_buf);
        for (addr, raw) in datagrams {
            let peer = self.peers.entry(addr).or_insert_with(|| ReliableUdp::new(addr));
            let packets = peer.receive(&raw);
            for pkt in packets {
                if pkt.flags & PacketHeader::FLAG_FRAGMENTED != 0 {
                    // Fragmented payload: feed it to the reassembler and only
                    // surface a packet once every fragment has arrived.
                    if let Some(data) = self.fragmenter.receive_fragment(peer_key(addr), &pkt.payload) {
                        let mut whole = pkt;
                        whole.payload = data;
                        whole.flags &= !PacketHeader::FLAG_FRAGMENTED;
                        out.push(ReceivedPacket { from: addr, packet: whole });
                    }
                } else {
                    out.push(ReceivedPacket { from: addr, packet: pkt });
                }
            }
        }

        // Tick all peers (retransmit / keepalive / timeout)
        for conn in self.peers.values_mut() {
            conn.tick(&self.socket);
        }

        // Clean up timed-out / disconnected peers
        self.peers.retain(|_, conn| {
            !matches!(conn.state(), ConnectionState::TimedOut | ConnectionState::Disconnected)
        });

        // Fragment GC
        self.fragmenter.gc();

        out
    }

    /// Returns the current state of a peer connection.
    pub fn peer_state(&self, addr: SocketAddr) -> Option<ConnectionState> {
        self.peers.get(&addr).map(|c| c.state())
    }

    /// Returns transport stats for a peer.
    pub fn peer_stats(&self, addr: SocketAddr) -> Option<&TransportStats> {
        self.peers.get(&addr).map(|c| c.stats())
    }

    /// Returns count of connected peers.
    pub fn peer_count(&self) -> usize {
        self.peers.len()
    }

    /// Returns all connected peer addresses.
    pub fn peer_addrs(&self) -> Vec<SocketAddr> {
        self.peers.keys().copied().collect()
    }

    /// Broadcast a packet to all connected peers on `channel`.
    pub fn broadcast(&mut self, channel: Channel, packet: Packet) {
        let addrs: Vec<SocketAddr> = self.peers.keys().copied().collect();
        for addr in addrs {
            self.send_packet(addr, channel, packet.clone());
        }
    }
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    fn loopback(port: u16) -> SocketAddr {
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    // ── Channel flags ─────────────────────────────────────────────────────────

    #[test]
    fn test_channel_flags() {
        assert!(Channel::Reliable.is_reliable());
        assert!(Channel::ReliableOrdered.is_reliable());
        assert!(!Channel::Unreliable.is_reliable());
        assert!(!Channel::UnreliableOrdered.is_reliable());

        assert!(Channel::ReliableOrdered.is_ordered());
        assert!(Channel::UnreliableOrdered.is_ordered());
        assert!(!Channel::Reliable.is_ordered());
        assert!(!Channel::Unreliable.is_ordered());
    }

    // ── Fragmenter ────────────────────────────────────────────────────────────

    #[test]
    fn test_fragmenter_roundtrip_small() {
        let mut f = Fragmenter::new();
        let data = vec![0xABu8; 100];
        let frags = f.fragment(&data);
        assert_eq!(frags.len(), 1);
        let result = f.receive_fragment(1, &frags[0]);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), data);
    }

    #[test]
    fn test_fragmenter_roundtrip_large() {
        let mut f = Fragmenter::new();
        let data: Vec<u8> = (0..4000).map(|i| (i % 251) as u8).collect();
        let frags = f.fragment(&data);
        assert!(frags.len() > 1);

        let mut assembled = None;
        for frag in frags {
            assembled = f.receive_fragment(99, &frag);
        }
        assert!(assembled.is_some());
        assert_eq!(assembled.unwrap(), data);
    }

    #[test]
    fn test_fragmenter_out_of_order() {
        let mut f = Fragmenter::new();
        let data: Vec<u8> = (0..4000).map(|i| (i % 127) as u8).collect();
        let mut frags = f.fragment(&data);
        // Reverse fragment order
        frags.reverse();
        let mut assembled = None;
        for frag in frags {
            assembled = f.receive_fragment(7, &frag);
        }
        assert!(assembled.is_some());
        // Fragments carry their index, so reassembly restores the original byte
        // order even though they arrived reversed.
        assert_eq!(assembled.unwrap(), data);
    }

    // ── AckAccumulator ────────────────────────────────────────────────────────

    #[test]
    fn test_ack_accumulator_basic() {
        let mut acc = AckAccumulator::new();
        acc.record(0);
        acc.record(1);
        acc.record(2);
        assert_eq!(acc.ack(), 2);
        // bit 0 set means seq=1 was received, bit 1 means seq=0 was received
        assert!(acc.ack_bits() & 1 != 0); // seq=1
        assert!(acc.ack_bits() & 2 != 0); // seq=0
    }
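
    // A sketch of the windowing in `AckAccumulator::record`, assuming the
    // 32-sequence `ACK_WINDOW`: a forward jump of 32 or more clears the
    // bitfield, and an older sequence inside the window sets its bit.
    #[test]
    fn test_ack_accumulator_window_jump() {
        let mut acc = AckAccumulator::new();
        acc.record(40); // jump far ahead of the initial state → bitfield cleared
        assert_eq!(acc.ack(), 40);
        assert_eq!(acc.ack_bits(), 0);
        acc.record(39); // one behind the head → bit 0 set
        assert_eq!(acc.ack_bits() & 1, 1);
    }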

    // ── CongestionControl ─────────────────────────────────────────────────────

    #[test]
    fn test_congestion_control_aimd() {
        let mut cc = CongestionControl::new();
        let initial = cc.cwnd;
        cc.on_ack();
        cc.on_ack();
        assert!(cc.cwnd >= initial); // window grew or stayed
        let before_loss = cc.cwnd;
        cc.on_loss();
        assert!(cc.cwnd < before_loss); // window shrank
    }

    // ── ReorderBuffer ─────────────────────────────────────────────────────────

    #[test]
    fn test_reorder_buffer_in_order() {
        let mut rb = ReorderBuffer::new();
        let p0 = Packet::new(PacketKind::StateUpdate, 0, 0, 0, vec![]);
        let p1 = Packet::new(PacketKind::StateUpdate, 1, 0, 0, vec![]);
        let out0 = rb.insert(p0);
        let out1 = rb.insert(p1);
        assert_eq!(out0.len(), 1);
        assert_eq!(out1.len(), 1);
    }

    #[test]
    fn test_reorder_buffer_out_of_order() {
        let mut rb = ReorderBuffer::new();
        let p0 = Packet::new(PacketKind::StateUpdate, 0, 0, 0, vec![1]);
        let p2 = Packet::new(PacketKind::StateUpdate, 2, 0, 0, vec![3]);
        let p1 = Packet::new(PacketKind::StateUpdate, 1, 0, 0, vec![2]);

        let out_p0 = rb.insert(p0); // seq=0 → delivered immediately
        let out_p2 = rb.insert(p2); // seq=2 → buffered
        assert_eq!(out_p0.len(), 1);
        assert_eq!(out_p2.len(), 0);

        let out_p1 = rb.insert(p1); // seq=1 → delivers 1 and 2
        assert_eq!(out_p1.len(), 2);
    }
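
    // A sketch of the overflow path in `ReorderBuffer::insert`, assuming the
    // default `max_hold` of 64: once the backlog exceeds that limit the whole
    // buffer is flushed in sequence order and `next_expected` skips ahead.
    #[test]
    fn test_reorder_buffer_overflow_flush() {
        let mut rb = ReorderBuffer::new();
        // Leave seq 0 missing so every insert is out of order and gets buffered.
        for seq in 1..=64u32 {
            let pkt = Packet::new(PacketKind::StateUpdate, seq, 0, 0, vec![]);
            assert!(rb.insert(pkt).is_empty());
        }
        // The 65th buffered packet exceeds `max_hold` and flushes everything.
        let flushed = rb.insert(Packet::new(PacketKind::StateUpdate, 65, 0, 0, vec![]));
        assert_eq!(flushed.len(), 65);
        assert!(flushed.windows(2).all(|w| w[0].sequence <= w[1].sequence));
        // Delivery resumes in order after the skip.
        let next = rb.insert(Packet::new(PacketKind::StateUpdate, 66, 0, 0, vec![]));
        assert_eq!(next.len(), 1);
    }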

    // ── ConnectionManager bind ────────────────────────────────────────────────

    #[test]
    fn test_connection_manager_bind() {
        // Just verify we can bind on an ephemeral port
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mgr = ConnectionManager::bind(addr);
        assert!(mgr.is_ok());
    }

    #[test]
    fn test_connection_manager_peer_count() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mgr = ConnectionManager::bind(addr).unwrap();
        assert_eq!(mgr.peer_count(), 0);
    }

    // ── NonBlockingSocket ─────────────────────────────────────────────────────

    #[test]
    fn test_non_blocking_socket_bind() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let sock = NonBlockingSocket::bind(addr);
        assert!(sock.is_ok());
        let sock = sock.unwrap();
        // Port should be non-zero (assigned by OS)
        assert_ne!(sock.local_addr.port(), 0);
    }

    // ── SendEntry backoff ─────────────────────────────────────────────────────

    #[test]
    fn test_send_entry_backoff_growth() {
        let mut entry = SendEntry::new(1, vec![0u8; 10], Instant::now());
        let d0 = entry.retransmit_delay;
        entry.backoff(Instant::now());
        let d1 = entry.retransmit_delay;
        entry.backoff(Instant::now());
        let d2 = entry.retransmit_delay;
        assert!(d1 >= d0);
        assert!(d2 >= d1);
        assert!(d2.as_millis() <= RETRANSMIT_MAX_MS as u128);
    }
}
1050}