1use std::collections::{HashMap, VecDeque};
5use std::time::{Duration, Instant};
6
/// Tunable parameters shared by the transport components.
///
/// Within this module section, `max_retransmits`, `retransmit_timeout_ms` and
/// `reliable_window_size` are read by `ReliableChannel`, and `connection_timeout_ms` /
/// `keepalive_interval_ms` by `ConnectionStateMachine`. The other fields are not read
/// here and are presumably consumed by code elsewhere — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Intended upper bound on a datagram's size in bytes (not enforced in this section).
    pub max_packet_size: usize,
    /// Payload bytes per fragment (presumably handed to `PacketFragmenter`).
    pub fragment_size: usize,
    /// Maximum number of fragments one message may be split into.
    ///
    /// NOTE(review): `FragmentHeader` stores the count as a `u8` and `ReassemblyGroup`
    /// tracks arrivals in a 64-bit mask, so values above 64 cannot be honoured end to end.
    pub max_fragments_per_packet: usize,
    /// Unacknowledged reliable packets at which `ReliableChannel::is_congested` reports true.
    pub reliable_window_size: usize,
    /// Retransmissions attempted before a reliable packet is abandoned.
    pub max_retransmits: u32,
    /// Lower bound, in milliseconds, on the computed retransmission timeout.
    pub retransmit_timeout_ms: u64,
    /// Time, in milliseconds, allowed for a connection attempt, and the longest
    /// silence tolerated on an established connection.
    pub connection_timeout_ms: u64,
    /// Time, in milliseconds, without a send after which a keepalive becomes due.
    pub keepalive_interval_ms: u64,
    /// Outgoing bandwidth budget in bytes per second (not read in this section).
    pub max_bandwidth_bytes_per_sec: usize,
    /// Ack repetition count (not read in this section — confirm usage).
    pub ack_redundancy: usize,
}

impl Default for TransportConfig {
    /// Defaults are:
    /// - 1200-byte packets and 1024-byte fragments;
    /// - a 200 ms minimum RTO;
    /// - a 10 s connection timeout and a 1 s keepalive;
    /// - a 64 KiB/s bandwidth budget.
    fn default() -> Self {
        // Timing values, all in milliseconds.
        let retransmit_timeout_ms = 200;
        let connection_timeout_ms = 10_000;
        let keepalive_interval_ms = 1_000;

        Self {
            max_packet_size: 1_200,
            fragment_size: 1_024,
            max_fragments_per_packet: 256,
            reliable_window_size: 256,
            max_retransmits: 10,
            retransmit_timeout_ms,
            connection_timeout_ms,
            keepalive_interval_ms,
            max_bandwidth_bytes_per_sec: 64 * 1_024,
            ack_redundancy: 3,
        }
    }
}
38
/// Counters and derived metrics collected by a channel.
///
/// The channels in this module fill in the packet, byte, ack, loss and retransmission
/// counters. The remaining metrics are not written by the channel code shown here.
#[derive(Debug, Clone, Default)]
pub struct TransportStats {
    /// Packets handed out for transmission (retransmissions included).
    pub packets_sent: u64,
    /// Packets passed to a channel's `receive`.
    pub packets_received: u64,
    /// Reliable packets abandoned after exhausting their retransmissions.
    pub packets_lost: u64,
    /// Reliable packets acknowledged by the peer.
    pub packets_acked: u64,
    /// Bytes sent, header included.
    pub bytes_sent: u64,
    /// Bytes received, header included.
    pub bytes_received: u64,
    /// Reliable retransmissions performed.
    pub retransmissions: u64,
    /// Round-trip time in milliseconds (presumably filled in by the caller).
    pub rtt_ms: f64,
    /// Round-trip time variance in milliseconds (presumably filled in by the caller).
    pub rtt_variance_ms: f64,
    /// `packets_lost / packets_sent`, refreshed by [`update_loss_ratio`](Self::update_loss_ratio).
    pub packet_loss_ratio: f64,
    /// Bandwidth used in bytes per second (not written in this section).
    pub bandwidth_used_bytes_per_sec: f64,
    /// Fragments sent (not written in this section).
    pub fragments_sent: u64,
    /// Messages reassembled from fragments (not written in this section).
    pub fragments_reassembled: u64,
}

impl TransportStats {
    /// Recomputes `packet_loss_ratio` as `packets_lost / packets_sent`.
    ///
    /// While `packets_sent` is zero, the previous ratio is kept unchanged.
    pub fn update_loss_ratio(&mut self) {
        if self.packets_sent == 0 {
            return;
        }
        let lost = self.packets_lost as f64;
        let sent = self.packets_sent as f64;
        self.packet_loss_ratio = lost / sent;
    }

    /// Zeroes every counter and metric.
    pub fn reset(&mut self) {
        *self = TransportStats::default();
    }
}
69
/// Kind of datagram announced in a packet header; encoded on the wire as one byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    ConnectionRequest = 0,
    ConnectionAccept = 1,
    ConnectionDeny = 2,
    Disconnect = 3,
    Keepalive = 4,
    Reliable = 5,
    Unreliable = 6,
    Fragment = 7,
    Ack = 8,
}

impl PacketType {
    /// Every variant, positioned at the index equal to its wire code.
    const BY_CODE: [PacketType; 9] = [
        PacketType::ConnectionRequest,
        PacketType::ConnectionAccept,
        PacketType::ConnectionDeny,
        PacketType::Disconnect,
        PacketType::Keepalive,
        PacketType::Reliable,
        PacketType::Unreliable,
        PacketType::Fragment,
        PacketType::Ack,
    ];

    /// Returns the one-byte wire code (0..=8) for this packet type.
    pub fn to_u8(self) -> u8 {
        // The explicit discriminants above are the wire codes.
        self as u8
    }

    /// Decodes a wire code, returning `None` for values outside 0..=8.
    pub fn from_u8(v: u8) -> Option<Self> {
        Self::BY_CODE.get(usize::from(v)).copied()
    }
}
114
/// Fixed-size header that prefixes every datagram.
///
/// `PacketHeader::serialize` writes these fields in declaration order, little-endian,
/// for a total of `SERIALIZED_SIZE` (23) bytes.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Magic value identifying this protocol; checked by `validate_protocol`.
    pub protocol_id: u32,
    /// Kind of packet that follows the header.
    pub packet_type: PacketType,
    /// Sender's sequence number for this packet.
    pub sequence: u16,
    /// Newest sequence number the sender has received from its peer.
    pub ack: u16,
    /// Bit `i` set means the sender also received sequence `ack - (i + 1)`.
    pub ack_bits: u32,
    /// Timestamp in milliseconds. The epoch is chosen by the caller and is not
    /// visible here — confirm.
    pub timestamp_ms: u64,
    /// Payload length in bytes, as declared by the sender.
    pub payload_size: u16,
}
126
127impl PacketHeader {
128 pub const SERIALIZED_SIZE: usize = 4 + 1 + 2 + 2 + 4 + 8 + 2;
129
130 pub fn new(packet_type: PacketType, sequence: u16) -> Self {
131 Self {
132 protocol_id: 0x50524F46, packet_type,
134 sequence,
135 ack: 0,
136 ack_bits: 0,
137 timestamp_ms: 0,
138 payload_size: 0,
139 }
140 }
141
142 pub fn serialize(&self) -> Vec<u8> {
143 let mut buf = Vec::with_capacity(Self::SERIALIZED_SIZE);
144 buf.extend_from_slice(&self.protocol_id.to_le_bytes());
145 buf.push(self.packet_type.to_u8());
146 buf.extend_from_slice(&self.sequence.to_le_bytes());
147 buf.extend_from_slice(&self.ack.to_le_bytes());
148 buf.extend_from_slice(&self.ack_bits.to_le_bytes());
149 buf.extend_from_slice(&self.timestamp_ms.to_le_bytes());
150 buf.extend_from_slice(&self.payload_size.to_le_bytes());
151 buf
152 }
153
154 pub fn deserialize(data: &[u8]) -> Option<Self> {
155 if data.len() < Self::SERIALIZED_SIZE {
156 return None;
157 }
158 let protocol_id = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
159 let packet_type = PacketType::from_u8(data[4])?;
160 let sequence = u16::from_le_bytes([data[5], data[6]]);
161 let ack = u16::from_le_bytes([data[7], data[8]]);
162 let ack_bits = u32::from_le_bytes([data[9], data[10], data[11], data[12]]);
163 let timestamp_ms = u64::from_le_bytes([
164 data[13], data[14], data[15], data[16],
165 data[17], data[18], data[19], data[20],
166 ]);
167 let payload_size = u16::from_le_bytes([data[21], data[22]]);
168
169 Some(Self {
170 protocol_id,
171 packet_type,
172 sequence,
173 ack,
174 ack_bits,
175 timestamp_ms,
176 payload_size,
177 })
178 }
179
180 pub fn validate_protocol(&self) -> bool {
181 self.protocol_id == 0x50524F46
182 }
183}
184
/// Returns `true` if sequence number `a` is newer than `b` on the circular `u16` space.
///
/// `a` counts as newer when it lies less than half the space (32768) ahead of `b`.
/// When the two are exactly half the space apart, the numerically larger one wins.
fn sequence_greater_than(a: u16, b: u16) -> bool {
    const HALF: u16 = 0x8000;
    let ahead = a.wrapping_sub(b);
    ahead != 0 && (ahead < HALF || (ahead == HALF && a > b))
}

/// Signed distance from `b` to `a` on the circular `u16` sequence space.
///
/// The result is positive when `a` is newer than `b` (per `sequence_greater_than`),
/// negative when it is older, and zero when they are equal. It always lies in
/// `-32768..=32768`.
fn sequence_difference(a: u16, b: u16) -> i32 {
    let forward = i32::from(a.wrapping_sub(b));
    if a == b {
        0
    } else if sequence_greater_than(a, b) {
        forward
    } else {
        // `a` is older: measure the (negative) distance going backwards instead.
        forward - 0x1_0000
    }
}
199
/// Book-keeping for one reliable packet that has been sent but not yet retired
/// from `ReliableChannel::send_window`.
#[derive(Debug, Clone)]
struct ReliableEntry {
    /// Sequence number the packet was sent with.
    sequence: u16,
    /// Payload, kept so the packet can be retransmitted.
    data: Vec<u8>,
    /// Time of the most recent (re)transmission.
    send_time: Instant,
    /// Retransmissions performed so far (0 = only the original send).
    retransmit_count: u32,
    /// Set once the peer's ack covers this sequence number.
    acked: bool,
    /// Earliest time at which the packet may be retransmitted again.
    next_retransmit: Instant,
    /// Owning channel's id (stored, but not read by the code in this section).
    channel_id: u8,
}
211
/// Ordered, acknowledged message channel with retransmission.
///
/// Messages are queued with `send`, put on the wire by `flush`, and delivered in
/// sequence order by `drain_received`.
pub struct ReliableChannel {
    /// Timeouts, retry budget and congestion window.
    config: TransportConfig,
    /// Next sequence number to assign to an outgoing message.
    local_sequence: u16,
    /// Newest sequence number received from the peer.
    remote_sequence: u16,
    /// Sent packets awaiting acknowledgement or retransmission, in send order.
    send_window: VecDeque<ReliableEntry>,
    /// Received payloads keyed by sequence number, waiting for in-order delivery.
    receive_buffer: HashMap<u16, Vec<u8>>,
    /// Sequence number `drain_received` will deliver next.
    next_deliver_sequence: u16,
    /// Sequence numbers received since the last `drain_pending_acks`.
    pending_acks: Vec<u16>,
    /// Most recent raw RTT sample in milliseconds.
    rtt_estimate_ms: f64,
    /// Smoothed RTT deviation in milliseconds.
    rtt_variance_ms: f64,
    /// Smoothed RTT in milliseconds; used for the retransmission timeout.
    smoothed_rtt_ms: f64,
    /// Traffic counters for this channel.
    stats: TransportStats,
    /// Identifier of this channel.
    channel_id: u8,
    /// Bit `i` set means `remote_sequence - (i + 1)` was received; sent as `ack_bits`.
    ack_bitfield: u32,
    /// Not read in this section (only initialised) — confirm whether it is still needed.
    last_ack_sequence: u16,
    /// Messages queued by `send` that `flush` has not transmitted yet.
    send_queue: VecDeque<Vec<u8>>,
}
230
231impl ReliableChannel {
232 pub fn new(channel_id: u8, config: TransportConfig) -> Self {
233 Self {
234 config,
235 local_sequence: 0,
236 remote_sequence: 0,
237 send_window: VecDeque::new(),
238 receive_buffer: HashMap::new(),
239 next_deliver_sequence: 0,
240 pending_acks: Vec::new(),
241 rtt_estimate_ms: 100.0,
242 rtt_variance_ms: 50.0,
243 smoothed_rtt_ms: 100.0,
244 stats: TransportStats::default(),
245 channel_id,
246 ack_bitfield: 0,
247 last_ack_sequence: 0,
248 send_queue: VecDeque::new(),
249 }
250 }
251
252 pub fn channel_id(&self) -> u8 {
253 self.channel_id
254 }
255
256 pub fn stats(&self) -> &TransportStats {
257 &self.stats
258 }
259
260 pub fn rtt_ms(&self) -> f64 {
261 self.smoothed_rtt_ms
262 }
263
264 pub fn local_sequence(&self) -> u16 {
265 self.local_sequence
266 }
267
268 pub fn remote_sequence(&self) -> u16 {
269 self.remote_sequence
270 }
271
272 pub fn send(&mut self, data: Vec<u8>) -> u16 {
274 let seq = self.local_sequence;
275 self.local_sequence = self.local_sequence.wrapping_add(1);
276 self.send_queue.push_back(data);
277 seq
278 }
279
280 pub fn flush(&mut self, now: Instant) -> Vec<(PacketHeader, Vec<u8>)> {
282 let mut packets = Vec::new();
283
284 while let Some(data) = self.send_queue.pop_front() {
285 let seq = self.local_sequence.wrapping_sub(self.send_queue.len() as u16 + 1);
286 let rto = self.compute_rto();
287
288 let entry = ReliableEntry {
289 sequence: seq,
290 data: data.clone(),
291 send_time: now,
292 retransmit_count: 0,
293 acked: false,
294 next_retransmit: now + rto,
295 channel_id: self.channel_id,
296 };
297
298 let mut header = PacketHeader::new(PacketType::Reliable, seq);
299 header.ack = self.remote_sequence;
300 header.ack_bits = self.ack_bitfield;
301 header.payload_size = data.len() as u16;
302
303 self.send_window.push_back(entry);
304 self.stats.packets_sent += 1;
305 self.stats.bytes_sent += (PacketHeader::SERIALIZED_SIZE + data.len()) as u64;
306
307 packets.push((header, data));
308 }
309
310 let rto = self.compute_rto();
312 let max_retransmits = self.config.max_retransmits;
313 let remote_seq = self.remote_sequence;
314 let ack_bits = self.ack_bitfield;
315 let mut retransmissions = 0u64;
316 let mut pkts_sent = 0u64;
317 let mut bytes_sent = 0u64;
318 let mut pkts_lost = 0u64;
319 for entry in self.send_window.iter_mut() {
320 if !entry.acked && now >= entry.next_retransmit {
321 if entry.retransmit_count < max_retransmits {
322 entry.retransmit_count += 1;
323 let backoff = rto * (1 << entry.retransmit_count.min(5));
325 entry.next_retransmit = now + backoff;
326 entry.send_time = now;
327
328 let mut header = PacketHeader::new(PacketType::Reliable, entry.sequence);
329 header.ack = remote_seq;
330 header.ack_bits = ack_bits;
331 header.payload_size = entry.data.len() as u16;
332
333 retransmissions += 1;
334 pkts_sent += 1;
335 bytes_sent += (PacketHeader::SERIALIZED_SIZE + entry.data.len()) as u64;
336
337 packets.push((header, entry.data.clone()));
338 } else {
339 pkts_lost += 1;
340 }
341 }
342 }
343 self.stats.retransmissions += retransmissions;
344 self.stats.packets_sent += pkts_sent;
345 self.stats.bytes_sent += bytes_sent;
346 self.stats.packets_lost += pkts_lost;
347
348 while let Some(front) = self.send_window.front() {
350 if front.acked || front.retransmit_count >= self.config.max_retransmits {
351 self.send_window.pop_front();
352 } else {
353 break;
354 }
355 }
356
357 packets
358 }
359
360 fn compute_rto(&self) -> Duration {
361 let rto_ms = self.smoothed_rtt_ms + 4.0 * self.rtt_variance_ms;
363 let rto_ms = rto_ms.max(self.config.retransmit_timeout_ms as f64);
364 Duration::from_millis(rto_ms as u64)
365 }
366
367 pub fn receive(&mut self, header: &PacketHeader, payload: Vec<u8>, now: Instant) {
369 self.stats.packets_received += 1;
370 self.stats.bytes_received += (PacketHeader::SERIALIZED_SIZE + payload.len()) as u64;
371
372 let seq = header.sequence;
373
374 if sequence_greater_than(seq, self.remote_sequence) {
376 let diff = sequence_difference(seq, self.remote_sequence);
377 if diff > 0 && diff <= 32 {
379 self.ack_bitfield <<= diff as u32;
380 self.ack_bitfield |= 1 << (diff as u32 - 1);
381 } else if diff > 32 {
382 self.ack_bitfield = 0;
383 }
384 self.remote_sequence = seq;
385 } else {
386 let diff = sequence_difference(self.remote_sequence, seq);
387 if diff > 0 && diff <= 32 {
388 self.ack_bitfield |= 1 << (diff as u32 - 1);
389 }
390 }
391
392 self.receive_buffer.insert(seq, payload);
394 self.pending_acks.push(seq);
395
396 self.process_acks(header.ack, header.ack_bits, now);
398 }
399
400 fn process_acks(&mut self, ack: u16, ack_bits: u32, now: Instant) {
401 let mut acked_count = 0u64;
402 let mut rtt_samples = Vec::new();
403 for entry in self.send_window.iter_mut() {
404 if entry.acked {
405 continue;
406 }
407 let seq = entry.sequence;
408 let is_acked = if seq == ack {
409 true
410 } else {
411 let diff = sequence_difference(ack, seq);
412 diff > 0 && diff <= 32 && (ack_bits & (1 << (diff - 1))) != 0
413 };
414
415 if is_acked {
416 entry.acked = true;
417 acked_count += 1;
418
419 if entry.retransmit_count == 0 {
420 let rtt = now.duration_since(entry.send_time).as_secs_f64() * 1000.0;
421 rtt_samples.push(rtt);
422 }
423 }
424 }
425 self.stats.packets_acked += acked_count;
426 for rtt in rtt_samples {
427 self.update_rtt(rtt);
428 }
429 }
430
431 fn update_rtt(&mut self, sample_ms: f64) {
432 let alpha = 0.125;
434 let beta = 0.25;
435
436 let err = sample_ms - self.smoothed_rtt_ms;
437 self.smoothed_rtt_ms += alpha * err;
438 self.rtt_variance_ms += beta * (err.abs() - self.rtt_variance_ms);
439 self.rtt_estimate_ms = sample_ms;
440 }
441
442 pub fn drain_received(&mut self) -> Vec<Vec<u8>> {
444 let mut messages = Vec::new();
445 loop {
446 if let Some(data) = self.receive_buffer.remove(&self.next_deliver_sequence) {
447 messages.push(data);
448 self.next_deliver_sequence = self.next_deliver_sequence.wrapping_add(1);
449 } else {
450 break;
451 }
452 }
453 messages
454 }
455
456 pub fn drain_pending_acks(&mut self) -> Vec<u16> {
458 std::mem::take(&mut self.pending_acks)
459 }
460
461 pub fn in_flight(&self) -> usize {
463 self.send_window.iter().filter(|e| !e.acked).count()
464 }
465
466 pub fn is_congested(&self) -> bool {
468 self.in_flight() >= self.config.reliable_window_size
469 }
470
471 pub fn reset(&mut self) {
473 self.local_sequence = 0;
474 self.remote_sequence = 0;
475 self.next_deliver_sequence = 0;
476 self.send_window.clear();
477 self.receive_buffer.clear();
478 self.pending_acks.clear();
479 self.send_queue.clear();
480 self.smoothed_rtt_ms = 100.0;
481 self.rtt_variance_ms = 50.0;
482 self.stats.reset();
483 }
484}
485
/// Unordered, unacknowledged message channel that still tracks sequence numbers so
/// its headers can carry acks.
pub struct UnreliableChannel {
    /// Next sequence number to stamp on an outgoing packet.
    local_sequence: u16,
    /// Newest sequence number received from the peer.
    remote_sequence: u16,
    /// Traffic counters for this channel.
    stats: TransportStats,
    /// Received payloads awaiting `drain_received`, oldest first.
    received_buffer: VecDeque<Vec<u8>>,
    /// Capacity of `received_buffer`; the oldest payloads are dropped beyond it.
    max_buffer_size: usize,
    /// Bit `i` set means `remote_sequence - (i + 1)` was received; sent as `ack_bits`.
    ack_bitfield: u32,
    /// When set, packets older than `remote_sequence` are discarded on receipt.
    drop_out_of_order: bool,
}
496
497impl UnreliableChannel {
498 pub fn new() -> Self {
499 Self {
500 local_sequence: 0,
501 remote_sequence: 0,
502 stats: TransportStats::default(),
503 received_buffer: VecDeque::new(),
504 max_buffer_size: 256,
505 ack_bitfield: 0,
506 drop_out_of_order: false,
507 }
508 }
509
510 pub fn with_max_buffer(mut self, size: usize) -> Self {
511 self.max_buffer_size = size;
512 self
513 }
514
515 pub fn set_drop_out_of_order(&mut self, drop: bool) {
516 self.drop_out_of_order = drop;
517 }
518
519 pub fn stats(&self) -> &TransportStats {
520 &self.stats
521 }
522
523 pub fn local_sequence(&self) -> u16 {
524 self.local_sequence
525 }
526
527 pub fn send(&mut self, data: Vec<u8>) -> (PacketHeader, Vec<u8>) {
529 let seq = self.local_sequence;
530 self.local_sequence = self.local_sequence.wrapping_add(1);
531
532 let mut header = PacketHeader::new(PacketType::Unreliable, seq);
533 header.ack = self.remote_sequence;
534 header.ack_bits = self.ack_bitfield;
535 header.payload_size = data.len() as u16;
536
537 self.stats.packets_sent += 1;
538 self.stats.bytes_sent += (PacketHeader::SERIALIZED_SIZE + data.len()) as u64;
539
540 (header, data)
541 }
542
543 pub fn receive(&mut self, header: &PacketHeader, payload: Vec<u8>) {
545 self.stats.packets_received += 1;
546 self.stats.bytes_received += (PacketHeader::SERIALIZED_SIZE + payload.len()) as u64;
547
548 let seq = header.sequence;
549
550 if self.drop_out_of_order && !sequence_greater_than(seq, self.remote_sequence) && seq != self.remote_sequence {
552 return;
554 }
555
556 if sequence_greater_than(seq, self.remote_sequence) {
558 let diff = sequence_difference(seq, self.remote_sequence);
559 if diff > 0 && diff <= 32 {
560 self.ack_bitfield <<= diff as u32;
561 self.ack_bitfield |= 1 << (diff as u32 - 1);
562 } else if diff > 32 {
563 self.ack_bitfield = 0;
564 }
565 self.remote_sequence = seq;
566 } else {
567 let diff = sequence_difference(self.remote_sequence, seq);
568 if diff > 0 && diff <= 32 {
569 self.ack_bitfield |= 1 << (diff as u32 - 1);
570 }
571 }
572
573 self.received_buffer.push_back(payload);
575 while self.received_buffer.len() > self.max_buffer_size {
576 self.received_buffer.pop_front();
577 }
578 }
579
580 pub fn drain_received(&mut self) -> Vec<Vec<u8>> {
582 self.received_buffer.drain(..).collect()
583 }
584
585 pub fn reset(&mut self) {
586 self.local_sequence = 0;
587 self.remote_sequence = 0;
588 self.received_buffer.clear();
589 self.stats.reset();
590 }
591}
592
/// Per-fragment header carried by `Fragment` packets, directly after the packet header.
#[derive(Debug, Clone)]
pub struct FragmentHeader {
    /// Identifies the message this fragment belongs to.
    pub group_id: u16,
    /// Zero-based position of this fragment within its message.
    pub fragment_index: u8,
    /// Number of fragments the message was split into.
    pub total_fragments: u8,
    /// Length in bytes of this fragment's data.
    pub fragment_size: u16,
}

impl FragmentHeader {
    /// Encoded size: group id (2) + index (1) + total (1) + fragment size (2).
    pub const SERIALIZED_SIZE: usize = 2 + 1 + 1 + 2;

    /// Encodes the header as `SERIALIZED_SIZE` bytes, multi-byte fields little-endian.
    pub fn serialize(&self) -> Vec<u8> {
        let [g0, g1] = self.group_id.to_le_bytes();
        let [s0, s1] = self.fragment_size.to_le_bytes();
        vec![g0, g1, self.fragment_index, self.total_fragments, s0, s1]
    }

    /// Decodes a header from the first `SERIALIZED_SIZE` bytes of `data`.
    ///
    /// Returns `None` when `data` is too short. Extra bytes are ignored.
    pub fn deserialize(data: &[u8]) -> Option<Self> {
        match data {
            [g0, g1, index, total, s0, s1, ..] => Some(Self {
                group_id: u16::from_le_bytes([*g0, *g1]),
                fragment_index: *index,
                total_fragments: *total,
                fragment_size: u16::from_le_bytes([*s0, *s1]),
            }),
            _ => None,
        }
    }
}
626
/// Collects the fragments of one message until all of them have arrived.
///
/// Arrivals are tracked in a 64-bit mask, so only fragment indices 0..64 can be
/// recorded. A group announcing more than 64 fragments is reported complete once
/// those 64 arrive, but `assemble` then returns `None`.
struct ReassemblyGroup {
    group_id: u16,
    total_fragments: u8,
    /// Bit `i` set once fragment `i` has been stored.
    received_mask: u64,
    fragments: Vec<Option<Vec<u8>>>,
    creation_time: Instant,
    /// Sum of the stored fragments' lengths, used to size the assembled buffer.
    total_size: usize,
}

impl ReassemblyGroup {
    /// Number of fragment indices `received_mask` can represent.
    const MAX_TRACKED_FRAGMENTS: u8 = 64;

    /// Creates an empty group expecting `total_fragments` fragments.
    fn new(group_id: u16, total_fragments: u8, now: Instant) -> Self {
        Self {
            group_id,
            total_fragments,
            received_mask: 0,
            fragments: vec![None; usize::from(total_fragments)],
            creation_time: now,
            total_size: 0,
        }
    }

    /// Stores fragment `index`; returns `true` once every fragment is present.
    ///
    /// Indices past `total_fragments`, indices the mask cannot track (>= 64) and
    /// duplicates are ignored and return `false`.
    fn insert(&mut self, index: u8, data: Vec<u8>) -> bool {
        // Checking the mask width first prevents `1u64 << index` from overflowing
        // (a panic in debug builds, silently wrong in release).
        if index >= self.total_fragments || index >= Self::MAX_TRACKED_FRAGMENTS {
            return false;
        }
        let bit = 1u64 << index;
        if self.received_mask & bit != 0 {
            return false;
        }
        self.received_mask |= bit;
        self.total_size += data.len();
        self.fragments[usize::from(index)] = Some(data);
        self.is_complete()
    }

    /// `true` when every trackable fragment has been received.
    fn is_complete(&self) -> bool {
        let expected = if self.total_fragments >= 64 {
            u64::MAX
        } else {
            (1u64 << self.total_fragments) - 1
        };
        self.received_mask == expected
    }

    /// Concatenates the fragments in index order.
    ///
    /// Returns `None` if any fragment is still missing.
    fn assemble(&self) -> Option<Vec<u8>> {
        if !self.is_complete() {
            return None;
        }
        let mut result = Vec::with_capacity(self.total_size);
        for fragment in &self.fragments {
            result.extend_from_slice(fragment.as_deref()?);
        }
        Some(result)
    }

    /// Number of distinct fragments received so far.
    fn received_count(&self) -> u8 {
        self.received_mask.count_ones() as u8
    }

    /// Time elapsed since the group was created.
    fn age(&self, now: Instant) -> Duration {
        now.duration_since(self.creation_time)
    }
}
699
700pub struct ReassemblyBuffer {
702 groups: HashMap<u16, ReassemblyGroup>,
703 timeout: Duration,
704 max_groups: usize,
705}
706
707impl ReassemblyBuffer {
708 pub fn new(timeout_ms: u64, max_groups: usize) -> Self {
709 Self {
710 groups: HashMap::new(),
711 timeout: Duration::from_millis(timeout_ms),
712 max_groups,
713 }
714 }
715
716 pub fn insert(&mut self, header: &FragmentHeader, data: Vec<u8>, now: Instant) -> Option<Vec<u8>> {
717 self.cleanup(now);
719
720 let group = self.groups
721 .entry(header.group_id)
722 .or_insert_with(|| ReassemblyGroup::new(header.group_id, header.total_fragments, now));
723
724 if group.insert(header.fragment_index, data) {
725 let assembled = group.assemble();
726 self.groups.remove(&header.group_id);
727 assembled
728 } else {
729 None
730 }
731 }
732
733 fn cleanup(&mut self, now: Instant) {
734 self.groups.retain(|_, group| group.age(now) < self.timeout);
735
736 while self.groups.len() > self.max_groups {
738 let oldest = self.groups.iter()
739 .min_by_key(|(_, g)| g.creation_time)
740 .map(|(&id, _)| id);
741 if let Some(id) = oldest {
742 self.groups.remove(&id);
743 } else {
744 break;
745 }
746 }
747 }
748
749 pub fn pending_groups(&self) -> usize {
750 self.groups.len()
751 }
752
753 pub fn clear(&mut self) {
754 self.groups.clear();
755 }
756}
757
758pub struct PacketFragmenter {
760 fragment_size: usize,
761 max_fragments: usize,
762 next_group_id: u16,
763}
764
765impl PacketFragmenter {
766 pub fn new(fragment_size: usize, max_fragments: usize) -> Self {
767 Self {
768 fragment_size: fragment_size.max(64),
769 max_fragments: max_fragments.max(1),
770 next_group_id: 0,
771 }
772 }
773
774 pub fn needs_fragmentation(&self, data_len: usize) -> bool {
776 data_len > self.fragment_size
777 }
778
779 pub fn fragment(&mut self, data: &[u8]) -> Vec<(FragmentHeader, Vec<u8>)> {
781 if data.is_empty() {
782 return Vec::new();
783 }
784
785 let total_fragments = ((data.len() + self.fragment_size - 1) / self.fragment_size).min(self.max_fragments);
786 let group_id = self.next_group_id;
787 self.next_group_id = self.next_group_id.wrapping_add(1);
788
789 let mut fragments = Vec::with_capacity(total_fragments);
790 let mut offset = 0;
791
792 for i in 0..total_fragments {
793 let end = (offset + self.fragment_size).min(data.len());
794 let fragment_data = data[offset..end].to_vec();
795
796 let header = FragmentHeader {
797 group_id,
798 fragment_index: i as u8,
799 total_fragments: total_fragments as u8,
800 fragment_size: fragment_data.len() as u16,
801 };
802
803 fragments.push((header, fragment_data));
804 offset = end;
805
806 if offset >= data.len() {
807 break;
808 }
809 }
810
811 fragments
812 }
813
814 pub fn max_payload_size(&self) -> usize {
815 self.fragment_size * self.max_fragments
816 }
817}
818
/// Token-bucket rate limiter for outgoing bytes.
///
/// Tokens (bytes) accrue at `max_bytes_per_sec`, up to a burst cap of 1.5 seconds'
/// worth. `bytes_sent_this_second` tracks usage over roughly one-second windows for
/// [`utilization`](Self::utilization).
pub struct BandwidthThrottle {
    max_bytes_per_sec: f64,
    /// Bytes that may currently be sent; can go negative if `consume` overdraws.
    tokens: f64,
    max_tokens: f64,
    last_refill: Instant,
    bytes_sent_this_second: usize,
    second_start: Instant,
    enabled: bool,
}

impl BandwidthThrottle {
    /// Burst allowance as a multiple of one second's budget.
    const BURST_FACTOR: f64 = 1.5;

    /// Creates an enabled throttle that starts with one second's worth of tokens.
    pub fn new(max_bytes_per_sec: usize) -> Self {
        let now = Instant::now();
        let rate = max_bytes_per_sec as f64;
        Self {
            max_bytes_per_sec: rate,
            tokens: rate,
            max_tokens: rate * Self::BURST_FACTOR,
            last_refill: now,
            bytes_sent_this_second: 0,
            second_start: now,
            enabled: true,
        }
    }

    /// Turns limiting on or off; while off, every send is allowed.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether limiting is currently on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Changes the rate limit.
    ///
    /// Tokens already held are clamped to the new burst cap.
    pub fn set_max_bytes_per_sec(&mut self, max: usize) {
        self.max_bytes_per_sec = max as f64;
        self.max_tokens = self.max_bytes_per_sec * Self::BURST_FACTOR;
        self.tokens = self.tokens.min(self.max_tokens);
    }

    /// Current rate limit in bytes per second.
    pub fn max_bytes_per_sec(&self) -> usize {
        self.max_bytes_per_sec as usize
    }

    /// Adds tokens for the time elapsed since the last refill, capped at the burst limit.
    ///
    /// Also restarts the utilization window once a second has passed.
    pub fn refill(&mut self, now: Instant) {
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + self.max_bytes_per_sec * elapsed).min(self.max_tokens);
        self.last_refill = now;

        if now.duration_since(self.second_start).as_secs_f64() >= 1.0 {
            self.bytes_sent_this_second = 0;
            self.second_start = now;
        }
    }

    /// `true` if `bytes` may be sent now. Always `true` while disabled.
    pub fn can_send(&self, bytes: usize) -> bool {
        if !self.enabled {
            return true;
        }
        self.tokens >= bytes as f64
    }

    /// Records `bytes` as sent.
    ///
    /// Tokens are only deducted while limiting is enabled, so traffic sent while
    /// disabled does not leave the bucket in debt once it is re-enabled. The bytes
    /// always count toward `utilization`.
    pub fn consume(&mut self, bytes: usize) {
        if self.enabled {
            self.tokens -= bytes as f64;
        }
        self.bytes_sent_this_second = self.bytes_sent_this_second.saturating_add(bytes);
    }

    /// Refills, then consumes `bytes` if allowed.
    ///
    /// Returns whether the send may proceed.
    pub fn try_send(&mut self, bytes: usize, now: Instant) -> bool {
        self.refill(now);
        if self.can_send(bytes) {
            self.consume(bytes);
            true
        } else {
            false
        }
    }

    /// Bytes that may be sent right now (`usize::MAX` while disabled).
    pub fn available_bytes(&self) -> usize {
        if !self.enabled {
            return usize::MAX;
        }
        self.tokens.max(0.0) as usize
    }

    /// Bytes sent in the current window as a fraction of the per-second limit.
    ///
    /// Returns 0 when the limit is not positive.
    pub fn utilization(&self) -> f64 {
        if self.max_bytes_per_sec <= 0.0 {
            return 0.0;
        }
        self.bytes_sent_this_second as f64 / self.max_bytes_per_sec
    }

    /// Restores one second's worth of tokens and restarts the timing windows.
    pub fn reset(&mut self) {
        self.tokens = self.max_bytes_per_sec;
        self.bytes_sent_this_second = 0;
        self.last_refill = Instant::now();
        self.second_start = self.last_refill;
    }
}
923
/// Lifecycle state of a `ConnectionStateMachine`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection and no attempt in progress.
    Disconnected,
    /// `connect` was called; waiting for the peer to accept.
    Connecting,
    /// Handshake completed; receive timeout and keepalive are tracked.
    Connected,
    /// A local graceful disconnect was requested; the next `update` completes it.
    Disconnecting,
}
932
/// Notifications queued by `ConnectionStateMachine` and retrieved via `drain_events`.
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    /// The peer accepted our connection request.
    Connected,
    /// The connection ended; `reason` says why.
    Disconnected { reason: DisconnectReason },
    /// The peer denied the connection request; `reason` is the text given to `on_denied`.
    ConnectionFailed { reason: String },
    /// Either a connection attempt ran longer than `connection_timeout_ms`, or nothing
    /// was received on an established connection for that long.
    TimedOut,
}
941
/// Why a connection ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisconnectReason {
    /// Orderly local disconnect via `disconnect` (also the fallback reason).
    Graceful,
    /// Nothing received within the connection timeout.
    Timeout,
    /// Presumably removed by the remote side — not produced in this section; confirm with callers.
    Kicked,
    /// Presumably a protocol violation — not produced in this section; confirm with callers.
    ProtocolError,
    /// Presumably the remote had no free slots — not produced in this section; confirm with callers.
    Full,
}
951
/// Tracks the handshake, timeouts and keepalive schedule of a single connection.
///
/// Driven by explicit `now` timestamps and by `update`.
pub struct ConnectionStateMachine {
    /// Current lifecycle state.
    state: ConnectionState,
    /// Source of `connection_timeout_ms` and `keepalive_interval_ms`.
    config: TransportConfig,
    /// When the current connection attempt began.
    connect_started: Option<Instant>,
    /// Time of the last packet received; drives the connected timeout.
    last_packet_received: Option<Instant>,
    /// Time of the last packet sent; drives the keepalive schedule.
    last_packet_sent: Option<Instant>,
    /// Connection requests signalled by `update` for the current attempt.
    connect_attempts: u32,
    /// Cap on `connect_attempts` (5, set in `new`).
    max_connect_attempts: u32,
    /// Spacing between connection requests (500 ms, set in `new`).
    connect_retry_interval: Duration,
    /// Events waiting for `drain_events`.
    events: VecDeque<ConnectionEvent>,
    /// Most recently recorded reason for leaving the connection.
    disconnect_reason: Option<DisconnectReason>,
    /// Identifier drawn by `connect` for the current attempt.
    session_id: u64,
    /// Set by `update` when a keepalive should be sent; cleared by `on_packet_sent`.
    keepalive_due: bool,
}
967
968impl ConnectionStateMachine {
969 pub fn new(config: TransportConfig) -> Self {
970 Self {
971 state: ConnectionState::Disconnected,
972 config,
973 connect_started: None,
974 last_packet_received: None,
975 last_packet_sent: None,
976 connect_attempts: 0,
977 max_connect_attempts: 5,
978 connect_retry_interval: Duration::from_millis(500),
979 events: VecDeque::new(),
980 disconnect_reason: None,
981 session_id: 0,
982 keepalive_due: false,
983 }
984 }
985
986 pub fn state(&self) -> ConnectionState {
987 self.state
988 }
989
990 pub fn is_connected(&self) -> bool {
991 self.state == ConnectionState::Connected
992 }
993
994 pub fn session_id(&self) -> u64 {
995 self.session_id
996 }
997
998 pub fn connect(&mut self, now: Instant) {
1000 if self.state != ConnectionState::Disconnected {
1001 return;
1002 }
1003 self.state = ConnectionState::Connecting;
1004 self.connect_started = Some(now);
1005 self.connect_attempts = 0;
1006 self.session_id = generate_session_id(now);
1007 }
1008
1009 pub fn on_accepted(&mut self, now: Instant) {
1011 if self.state == ConnectionState::Connecting {
1012 self.state = ConnectionState::Connected;
1013 self.last_packet_received = Some(now);
1014 self.events.push_back(ConnectionEvent::Connected);
1015 }
1016 }
1017
1018 pub fn on_denied(&mut self, reason: String) {
1020 if self.state == ConnectionState::Connecting {
1021 self.state = ConnectionState::Disconnected;
1022 self.events.push_back(ConnectionEvent::ConnectionFailed { reason });
1023 }
1024 }
1025
1026 pub fn on_packet_received(&mut self, now: Instant) {
1028 self.last_packet_received = Some(now);
1029 }
1030
1031 pub fn on_packet_sent(&mut self, now: Instant) {
1033 self.last_packet_sent = Some(now);
1034 self.keepalive_due = false;
1035 }
1036
1037 pub fn disconnect(&mut self) {
1039 if self.state == ConnectionState::Connected {
1040 self.state = ConnectionState::Disconnecting;
1041 self.disconnect_reason = Some(DisconnectReason::Graceful);
1042 }
1043 }
1044
1045 pub fn on_disconnect_complete(&mut self) {
1047 let reason = self.disconnect_reason.unwrap_or(DisconnectReason::Graceful);
1048 self.state = ConnectionState::Disconnected;
1049 self.events.push_back(ConnectionEvent::Disconnected { reason });
1050 self.connect_started = None;
1051 self.last_packet_received = None;
1052 }
1053
1054 pub fn on_remote_disconnect(&mut self, reason: DisconnectReason) {
1056 if self.state == ConnectionState::Connected || self.state == ConnectionState::Connecting {
1057 self.state = ConnectionState::Disconnected;
1058 self.disconnect_reason = Some(reason);
1059 self.events.push_back(ConnectionEvent::Disconnected { reason });
1060 }
1061 }
1062
1063 pub fn update(&mut self, now: Instant) -> bool {
1065 let mut needs_retry = false;
1066
1067 match self.state {
1068 ConnectionState::Connecting => {
1069 if let Some(started) = self.connect_started {
1070 let elapsed = now.duration_since(started);
1071 if elapsed >= Duration::from_millis(self.config.connection_timeout_ms) {
1072 self.state = ConnectionState::Disconnected;
1073 self.events.push_back(ConnectionEvent::TimedOut);
1074 return false;
1075 }
1076
1077 let expected_attempts = (elapsed.as_millis() / self.connect_retry_interval.as_millis()).max(1) as u32;
1079 if expected_attempts > self.connect_attempts && self.connect_attempts < self.max_connect_attempts {
1080 self.connect_attempts += 1;
1081 needs_retry = true;
1082 }
1083 }
1084 }
1085 ConnectionState::Connected => {
1086 if let Some(last_recv) = self.last_packet_received {
1088 let since_recv = now.duration_since(last_recv);
1089 if since_recv >= Duration::from_millis(self.config.connection_timeout_ms) {
1090 self.state = ConnectionState::Disconnected;
1091 self.disconnect_reason = Some(DisconnectReason::Timeout);
1092 self.events.push_back(ConnectionEvent::TimedOut);
1093 return false;
1094 }
1095 }
1096
1097 if let Some(last_sent) = self.last_packet_sent {
1099 let since_sent = now.duration_since(last_sent);
1100 if since_sent >= Duration::from_millis(self.config.keepalive_interval_ms) {
1101 self.keepalive_due = true;
1102 }
1103 } else {
1104 self.keepalive_due = true;
1105 }
1106 }
1107 ConnectionState::Disconnecting => {
1108 self.on_disconnect_complete();
1110 }
1111 ConnectionState::Disconnected => {}
1112 }
1113
1114 needs_retry
1115 }
1116
1117 pub fn needs_keepalive(&self) -> bool {
1119 self.keepalive_due && self.state == ConnectionState::Connected
1120 }
1121
1122 pub fn drain_events(&mut self) -> Vec<ConnectionEvent> {
1124 self.events.drain(..).collect()
1125 }
1126
1127 pub fn force_disconnect(&mut self, reason: DisconnectReason) {
1129 self.state = ConnectionState::Disconnected;
1130 self.disconnect_reason = Some(reason);
1131 self.events.push_back(ConnectionEvent::Disconnected { reason });
1132 }
1133
1134 pub fn reset(&mut self) {
1135 self.state = ConnectionState::Disconnected;
1136 self.connect_started = None;
1137 self.last_packet_received = None;
1138 self.last_packet_sent = None;
1139 self.connect_attempts = 0;
1140 self.events.clear();
1141 self.disconnect_reason = None;
1142 self.keepalive_due = false;
1143 }
1144
1145 pub fn time_since_last_received(&self, now: Instant) -> Option<Duration> {
1147 self.last_packet_received.map(|t| now.duration_since(t))
1148 }
1149}
1150
/// Produces a session identifier for a new connection attempt.
///
/// The id hashes the attempt's `Instant`, a process-wide call counter and the
/// wall-clock time, using a randomly keyed `RandomState` hasher. Ids therefore differ
/// between calls and between runs. They are not suitable as secret or authentication
/// tokens.
fn generate_session_id(now: Instant) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Guarantees distinct input even for two calls within the same clock tick.
    static CALLS: AtomicU64 = AtomicU64::new(0);

    let mut hasher = RandomState::new().build_hasher();
    now.hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    if let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) {
        since_epoch.as_nanos().hash(&mut hasher);
    }
    hasher.finish()
}
1163
/// A packet ready to be serialized onto the wire.
#[derive(Debug, Clone)]
pub struct OutgoingPacket {
    /// Packet header, always written first.
    pub header: PacketHeader,
    /// Message bytes, written last.
    pub payload: Vec<u8>,
    /// Fragment header, written between header and payload when present.
    /// `deserialize_packet` only reads one when `header.packet_type` is `Fragment`.
    pub fragment_header: Option<FragmentHeader>,
}
1171
1172impl OutgoingPacket {
1173 pub fn serialize(&self) -> Vec<u8> {
1174 let mut buf = self.header.serialize();
1175 if let Some(ref fh) = self.fragment_header {
1176 buf.extend_from_slice(&fh.serialize());
1177 }
1178 buf.extend_from_slice(&self.payload);
1179 buf
1180 }
1181
1182 pub fn total_size(&self) -> usize {
1183 PacketHeader::SERIALIZED_SIZE
1184 + self.fragment_header.as_ref().map_or(0, |_| FragmentHeader::SERIALIZED_SIZE)
1185 + self.payload.len()
1186 }
1187}
1188
1189pub fn deserialize_packet(data: &[u8]) -> Option<(PacketHeader, Option<FragmentHeader>, Vec<u8>)> {
1191 let header = PacketHeader::deserialize(data)?;
1192 if !header.validate_protocol() {
1193 return None;
1194 }
1195
1196 let mut offset = PacketHeader::SERIALIZED_SIZE;
1197
1198 let fragment_header = if header.packet_type == PacketType::Fragment {
1199 if data.len() < offset + FragmentHeader::SERIALIZED_SIZE {
1200 return None;
1201 }
1202 let fh = FragmentHeader::deserialize(&data[offset..])?;
1203 offset += FragmentHeader::SERIALIZED_SIZE;
1204 Some(fh)
1205 } else {
1206 None
1207 };
1208
1209 let payload = if offset < data.len() {
1210 data[offset..].to_vec()
1211 } else {
1212 Vec::new()
1213 };
1214
1215 Some((header, fragment_header, payload))
1216}
1217
/// Holds timestamped payloads and releases them in timestamp order once they are at
/// least `delay_ms` behind the current time.
pub struct JitterBuffer {
    /// Entries in ascending timestamp order; equal timestamps keep arrival order.
    buffer: VecDeque<(u64, Vec<u8>)>,
    /// How far behind the current time an entry must be before release.
    delay_ms: u64,
    /// Capacity; the earliest-timestamped entries are discarded beyond it.
    max_size: usize,
}

impl JitterBuffer {
    /// Creates an empty buffer.
    pub fn new(delay_ms: u64, max_size: usize) -> Self {
        Self {
            buffer: VecDeque::new(),
            delay_ms,
            max_size,
        }
    }

    /// Inserts `data` after every entry whose timestamp is `<= timestamp_ms`.
    ///
    /// Afterwards, the earliest entries are discarded while the buffer is over capacity.
    pub fn push(&mut self, timestamp_ms: u64, data: Vec<u8>) {
        let slot = self
            .buffer
            .iter()
            .take_while(|(ts, _)| *ts <= timestamp_ms)
            .count();
        self.buffer.insert(slot, (timestamp_ms, data));

        let excess = self.buffer.len().saturating_sub(self.max_size);
        self.buffer.drain(..excess);
    }

    /// Removes and returns, in timestamp order, every payload whose timestamp is at or
    /// before `current_time_ms - delay_ms`.
    pub fn drain_ready(&mut self, current_time_ms: u64) -> Vec<Vec<u8>> {
        let threshold = current_time_ms.saturating_sub(self.delay_ms);
        let ready = self
            .buffer
            .iter()
            .take_while(|(ts, _)| *ts <= threshold)
            .count();
        self.buffer.drain(..ready).map(|(_, data)| data).collect()
    }

    /// Changes the release delay.
    pub fn set_delay(&mut self, delay_ms: u64) {
        self.delay_ms = delay_ms;
    }

    /// Number of buffered payloads.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// `true` if nothing is buffered.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Discards every buffered payload.
    pub fn clear(&mut self) {
        self.buffer.clear();
    }
}
1279
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes a Reliable header with every field set, decodes it back, and checks
    // each field plus the protocol check on the decoded copy. The second argument of
    // `PacketHeader::new` is the sequence number (asserted as `h2.sequence == 42`).
    #[test]
    fn test_packet_header_roundtrip() {
        let mut h = PacketHeader::new(PacketType::Reliable, 42);
        h.ack = 10;
        h.ack_bits = 0xFF00FF00;
        h.timestamp_ms = 123456789;
        h.payload_size = 512;

        let data = h.serialize();
        let h2 = PacketHeader::deserialize(&data).unwrap();
        assert_eq!(h2.sequence, 42);
        assert_eq!(h2.ack, 10);
        assert_eq!(h2.ack_bits, 0xFF00FF00);
        assert_eq!(h2.timestamp_ms, 123456789);
        assert_eq!(h2.payload_size, 512);
        assert!(h2.validate_protocol());
    }

    // Sequence comparison must be wraparound-aware: 0 counts as newer than 65535
    // (consistent with 16-bit sequence numbers), and not the other way round.
    #[test]
    fn test_sequence_greater_than() {
        assert!(sequence_greater_than(1, 0));
        assert!(sequence_greater_than(100, 99));
        assert!(sequence_greater_than(0, 65535));
        assert!(!sequence_greater_than(65535, 0));
    }

    // Round-trips a fragment header and checks all four fields survive.
    #[test]
    fn test_fragment_header_roundtrip() {
        let fh = FragmentHeader {
            group_id: 7,
            fragment_index: 3,
            total_fragments: 10,
            fragment_size: 1024,
        };
        let data = fh.serialize();
        let fh2 = FragmentHeader::deserialize(&data).unwrap();
        assert_eq!(fh2.group_id, 7);
        assert_eq!(fh2.fragment_index, 3);
        assert_eq!(fh2.total_fragments, 10);
        assert_eq!(fh2.fragment_size, 1024);
    }

    // Splits a 350-byte payload and expects 4 fragments (presumably a 100-byte
    // fragment size from `PacketFragmenter::new(100, 256)` -- confirm argument order),
    // then feeds every fragment into a ReassemblyBuffer. Only the value returned by
    // the final `insert` is kept, and it must equal the original payload.
    #[test]
    fn test_fragmentation_and_reassembly() {
        let mut fragmenter = PacketFragmenter::new(100, 256);
        let data: Vec<u8> = (0..350).map(|i| (i % 256) as u8).collect();

        let fragments = fragmenter.fragment(&data);
        assert_eq!(fragments.len(), 4);

        let now = Instant::now();
        let mut reassembly = ReassemblyBuffer::new(5000, 32);

        let mut result = None;
        for (fh, fdata) in &fragments {
            result = reassembly.insert(fh, fdata.clone(), now);
        }

        let assembled = result.unwrap();
        assert_eq!(assembled, data);
    }

    // After a refill, a throttle constructed with 1000 accepts two 500-unit sends
    // and then rejects a further 100. Assumes `refill` at the construction instant
    // leaves the full budget available -- TODO confirm against BandwidthThrottle.
    #[test]
    fn test_bandwidth_throttle() {
        let mut throttle = BandwidthThrottle::new(1000);
        let now = Instant::now();
        throttle.refill(now);

        assert!(throttle.can_send(500));
        throttle.consume(500);
        assert!(throttle.can_send(500));
        throttle.consume(500);
        assert!(!throttle.can_send(100));
    }

    // Walks Disconnected -> Connecting -> Connected via `connect` / `on_accepted` and
    // expects exactly one queued event afterwards (which event it is is not asserted).
    #[test]
    fn test_connection_state_machine() {
        let config = TransportConfig::default();
        let mut csm = ConnectionStateMachine::new(config);
        let now = Instant::now();

        assert_eq!(csm.state(), ConnectionState::Disconnected);
        csm.connect(now);
        assert_eq!(csm.state(), ConnectionState::Connecting);
        csm.on_accepted(now);
        assert_eq!(csm.state(), ConnectionState::Connected);

        let events = csm.drain_events();
        assert_eq!(events.len(), 1);
    }

    // The first send on a fresh channel gets sequence 0; feeding that packet back
    // into the same channel's receive side yields the original payload.
    #[test]
    fn test_unreliable_channel() {
        let mut ch = UnreliableChannel::new();
        let (header, payload) = ch.send(vec![1, 2, 3]);
        assert_eq!(header.sequence, 0);

        ch.receive(&header, payload);
        let msgs = ch.drain_received();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0], vec![1, 2, 3]);
    }

    // With a 50 ms delay, an entry becomes ready once (current time - 50) reaches its
    // timestamp. The out-of-order push of ts 90 is sorted to the front and is the only
    // one released at t=140 (threshold 90). At t=160 (threshold 110) ts 100 follows,
    // while ts 120 stays buffered.
    #[test]
    fn test_jitter_buffer() {
        let mut jb = JitterBuffer::new(50, 100);
        jb.push(100, vec![1]);
        jb.push(120, vec![2]);
        jb.push(90, vec![0]);

        let ready = jb.drain_ready(140);
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], vec![0]);

        let ready2 = jb.drain_ready(160);
        assert_eq!(ready2.len(), 1);
        assert_eq!(ready2[0], vec![1]);
    }

    // An Unreliable packet decodes with no fragment header, and the bytes appended
    // after the serialized header come back unchanged as the payload.
    #[test]
    fn test_deserialize_packet() {
        let mut header = PacketHeader::new(PacketType::Unreliable, 5);
        header.payload_size = 3;
        let mut data = header.serialize();
        data.extend_from_slice(&[10, 20, 30]);

        let (h, fh, payload) = deserialize_packet(&data).unwrap();
        assert_eq!(h.sequence, 5);
        assert!(fh.is_none());
        assert_eq!(payload, vec![10, 20, 30]);
    }
}