1use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 time::Duration,
5};
6
7use tracing::{trace, debug};
8
9use crate::{
10 Instant, VarInt,
11};
12
/// NAT traversal bookkeeping: local and remote address candidates, the
/// candidate pairs built from them, in-flight path validations, and the
/// current hole-punching coordination round (if any).
#[derive(Debug)]
pub(super) struct NatTraversalState {
    /// Role this endpoint plays in NAT traversal.
    pub(super) role: NatTraversalRole,
    /// Candidates for this endpoint, keyed by the sequence number assigned in
    /// `add_local_candidate`.
    pub(super) local_candidates: HashMap<VarInt, AddressCandidate>,
    /// Candidates advertised by the peer, keyed by the peer-supplied sequence
    /// number passed to `add_remote_candidate`.
    pub(super) remote_candidates: HashMap<VarInt, AddressCandidate>,
    /// Local/remote pairings; `generate_candidate_pairs` rebuilds this list and
    /// sorts it by descending pair priority.
    pub(super) candidate_pairs: Vec<CandidatePair>,
    /// In-flight path validations, keyed by remote address.
    pub(super) active_validations: HashMap<SocketAddr, PathValidationState>,
    /// The current coordination round, or `None` when no round is active.
    pub(super) coordination: Option<CoordinationState>,
    /// Next sequence number to hand out; shared by local candidates and
    /// coordination rounds.
    pub(super) next_sequence: VarInt,
    /// Upper bound on the number of entries in `remote_candidates`.
    pub(super) max_candidates: u32,
    /// How long a coordination round may run before it counts as timed out.
    pub(super) coordination_timeout: Duration,
    /// Counters updated by the methods of this type.
    pub(super) stats: NatTraversalStats,
}
40
/// Role this endpoint plays in NAT traversal.
///
/// NOTE(review): the variants are not interpreted anywhere in this file; the
/// descriptions below are inferred from the names and should be confirmed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatTraversalRole {
    /// Presumably an endpoint initiating connections.
    Client,
    /// Presumably an endpoint accepting connections; `can_relay` presumably
    /// says whether it may relay traffic for others.
    Server { can_relay: bool },
    /// Presumably a bootstrap / coordinator node.
    Bootstrap,
}
51
/// A single address that may be usable for a direct path.
#[derive(Debug, Clone)]
pub(super) struct AddressCandidate {
    /// The candidate transport address.
    pub(super) address: SocketAddr,
    /// Candidate priority. Local candidates get it from
    /// `calculate_candidate_priority`; remote ones use the peer-supplied value.
    pub(super) priority: u32,
    /// How the candidate was learned.
    pub(super) source: CandidateSource,
    /// When the candidate was recorded.
    pub(super) discovered_at: Instant,
    /// Lifecycle state of the candidate.
    pub(super) state: CandidateState,
    /// Number of validations started for this candidate (`start_validation`).
    pub(super) attempt_count: u32,
    /// When the most recent validation was started, if any.
    pub(super) last_attempt: Option<Instant>,
}
70
/// How an address candidate was learned.
///
/// `classify_candidate_type` maps `Local` to a host candidate, `Observed` and
/// `Predicted` to server-reflexive candidates, and `Peer` to peer-reflexive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidateSource {
    /// An address of a local interface.
    Local,
    /// An address observed by another node; `by_node` presumably identifies the
    /// observer when known (TODO confirm).
    Observed { by_node: Option<VarInt> },
    /// An address advertised by the peer (used by `add_remote_candidate`).
    Peer,
    /// A predicted address (e.g. via port prediction — TODO confirm).
    Predicted,
}
83
/// Lifecycle state of an [`AddressCandidate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidateState {
    /// Recorded but not yet validated.
    New,
    /// A path validation is in flight (`start_validation`).
    Validating,
    /// Validation succeeded (`handle_validation_success`).
    Valid,
    /// Validation failed (`handle_validation_failure`).
    Failed,
    /// Withdrawn (`remove_candidate`); kept in the map but ignored for pairing.
    Removed,
}
98
/// Bookkeeping for one in-flight path validation towards a remote address.
#[derive(Debug)]
pub(super) struct PathValidationState {
    /// Challenge value that a valid response must echo.
    pub(super) challenge: u64,
    /// When the validation was started.
    pub(super) sent_at: Instant,
    /// Retries performed so far (initialised to 0).
    pub(super) retry_count: u32,
    /// Retry budget (`start_validation` sets 3).
    pub(super) max_retries: u32,
    /// Coordination round active when the validation started, if any.
    pub(super) coordination_round: Option<VarInt>,
}
113
/// State of one coordinated hole-punching round.
#[derive(Debug)]
pub(super) struct CoordinationState {
    /// Round identifier, allocated from `next_sequence`.
    pub(super) round: VarInt,
    /// Addresses to punch towards during this round.
    pub(super) punch_targets: Vec<PunchTarget>,
    /// When the round (or its latest retry) started; used for timeouts.
    pub(super) round_start: Instant,
    /// Earliest time at which punching may begin.
    pub(super) punch_start: Instant,
    /// Maximum duration of the round before it times out.
    pub(super) round_duration: Duration,
    /// Current phase of the round.
    pub(super) state: CoordinationPhase,
    /// Whether our punch request has been sent for the current attempt.
    pub(super) punch_request_sent: bool,
    /// Whether the peer's coordination for this round has been received.
    pub(super) peer_punch_received: bool,
    /// Failed attempts so far.
    pub(super) retry_count: u32,
    /// Failure budget; reaching it marks the round `Failed`.
    pub(super) max_retries: u32,
}
138
/// Phase of a coordination round, advanced by the `NatTraversalState` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoordinationPhase {
    /// Not assigned by any method in this file.
    Idle,
    /// Round started (or retried); our punch request has not been sent yet.
    Requesting,
    /// Punch request sent; waiting for the peer's coordination.
    Coordinating,
    /// Peer coordination received; waiting until `punch_start`.
    Preparing,
    /// Actively punching towards the targets.
    Punching,
    /// Punching done; waiting for a validated path.
    Validating,
    /// A punch target was validated.
    Succeeded,
    /// Retries exhausted.
    Failed,
}
159
/// One address pair to punch towards during a coordination round.
#[derive(Debug, Clone)]
pub(super) struct PunchTarget {
    /// Remote address to send punch packets to.
    pub(super) remote_addr: SocketAddr,
    /// Local address to send from.
    pub(super) local_addr: SocketAddr,
    /// Sequence number of the corresponding remote candidate.
    pub(super) remote_sequence: VarInt,
    /// Challenge value associated with this target.
    pub(super) challenge: u64,
}
172
/// A pairing of one local and one remote candidate.
#[derive(Debug, Clone)]
pub(super) struct CandidatePair {
    /// Sequence number of the local candidate.
    pub(super) local_sequence: VarInt,
    /// Sequence number of the remote candidate.
    pub(super) remote_sequence: VarInt,
    /// Address of the local candidate.
    pub(super) local_addr: SocketAddr,
    /// Address of the remote candidate.
    pub(super) remote_addr: SocketAddr,
    /// Pair priority from `calculate_pair_priority`.
    pub(super) priority: u64,
    /// Check state of the pair.
    pub(super) state: PairState,
    /// Classification derived from both candidates' types.
    pub(super) pair_type: PairType,
    /// When the pair was (re)generated.
    pub(super) created_at: Instant,
    /// When the pair was last checked; not updated by any method in this file.
    pub(super) last_check: Option<Instant>,
}
195
/// Check state of a [`CandidatePair`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum PairState {
    /// Eligible for checking (`get_next_validation_pairs` selects these).
    Waiting,
    /// Not assigned by any method in this file.
    InProgress,
    /// Set by `mark_pair_succeeded`.
    Succeeded,
    /// Not assigned by any method in this file.
    Failed,
    /// Lower-priority pair of the same type as a succeeded pair.
    Frozen,
}
210
/// Classification of a pair by the candidate types on each side
/// (local first, remote second); see `classify_pair_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum PairType {
    /// Host ↔ host.
    HostToHost,
    /// Local host ↔ remote server-reflexive.
    HostToServerReflexive,
    /// Local server-reflexive ↔ remote host.
    ServerReflexiveToHost,
    /// Server-reflexive on both sides.
    ServerReflexiveToServerReflexive,
    /// Either side peer-reflexive (and neither relayed).
    PeerReflexive,
    /// Either side relayed.
    Relayed,
}
227
/// ICE-style candidate type, used to pick the type preference in
/// `calculate_candidate_priority`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum CandidateType {
    /// Address of a local interface.
    Host,
    /// Address as seen from outside the local NAT.
    ServerReflexive,
    /// Address learned from the peer.
    PeerReflexive,
    /// Address on a relay; not produced by `classify_candidate_type`.
    Relayed,
}
240
241fn calculate_candidate_priority(
244 candidate_type: CandidateType,
245 local_preference: u16,
246 component_id: u8,
247) -> u32 {
248 let type_preference = match candidate_type {
249 CandidateType::Host => 126,
250 CandidateType::PeerReflexive => 110,
251 CandidateType::ServerReflexive => 100,
252 CandidateType::Relayed => 0,
253 };
254
255 (1u32 << 24) * type_preference
257 + (1u32 << 8) * local_preference as u32
258 + component_id as u32
259}
260
/// Computes a candidate-pair priority in the RFC 8445 §6.1.2.3 form
/// `2^32 * MIN(G, D) + 2 * MAX(G, D) + (G > D ? 1 : 0)`, with the local
/// candidate's priority playing the role of `G`.
fn calculate_pair_priority(local_priority: u32, remote_priority: u32) -> u64 {
    let (lower, higher) = if local_priority <= remote_priority {
        (local_priority, remote_priority)
    } else {
        (remote_priority, local_priority)
    };
    // Tie-breaker bit: set only when the local side strictly wins.
    let local_wins = u64::from(local_priority > remote_priority);

    (u64::from(lower) << 32) + 2 * u64::from(higher) + local_wins
}
270
271fn classify_candidate_type(source: CandidateSource) -> CandidateType {
273 match source {
274 CandidateSource::Local => CandidateType::Host,
275 CandidateSource::Observed { .. } => CandidateType::ServerReflexive,
276 CandidateSource::Peer => CandidateType::PeerReflexive,
277 CandidateSource::Predicted => CandidateType::ServerReflexive, }
279}
280
281fn classify_pair_type(local_type: CandidateType, remote_type: CandidateType) -> PairType {
283 match (local_type, remote_type) {
284 (CandidateType::Host, CandidateType::Host) => PairType::HostToHost,
285 (CandidateType::Host, CandidateType::ServerReflexive) => PairType::HostToServerReflexive,
286 (CandidateType::ServerReflexive, CandidateType::Host) => PairType::ServerReflexiveToHost,
287 (CandidateType::ServerReflexive, CandidateType::ServerReflexive) => PairType::ServerReflexiveToServerReflexive,
288 (CandidateType::Relayed, _) | (_, CandidateType::Relayed) => PairType::Relayed,
289 (CandidateType::PeerReflexive, _) | (_, CandidateType::PeerReflexive) => PairType::PeerReflexive,
290 }
291}
292
293fn are_candidates_compatible(local: &AddressCandidate, remote: &AddressCandidate) -> bool {
295 match (local.address, remote.address) {
297 (SocketAddr::V4(_), SocketAddr::V4(_)) => true,
298 (SocketAddr::V6(_), SocketAddr::V6(_)) => true,
299 _ => false, }
301}
302
/// Counters maintained by [`NatTraversalState`]; all start at zero.
#[derive(Debug, Default)]
pub(super) struct NatTraversalStats {
    /// Incremented for each remote candidate accepted by `add_remote_candidate`.
    pub(super) remote_candidates_received: u32,
    /// Incremented for each local candidate registered by `add_local_candidate`.
    pub(super) local_candidates_sent: u32,
    /// Incremented by `handle_validation_success`.
    pub(super) validations_succeeded: u32,
    /// Incremented by `handle_validation_failure` when a candidate matched.
    pub(super) validations_failed: u32,
    /// Incremented by `start_coordination_round`.
    pub(super) coordination_rounds: u32,
    /// Incremented by `handle_coordination_success`.
    pub(super) direct_connections: u32,
}
319
320impl NatTraversalState {
321 pub(super) fn new(
323 role: NatTraversalRole,
324 max_candidates: u32,
325 coordination_timeout: Duration,
326 ) -> Self {
327 Self {
328 role,
329 local_candidates: HashMap::new(),
330 remote_candidates: HashMap::new(),
331 candidate_pairs: Vec::new(),
332 active_validations: HashMap::new(),
333 coordination: None,
334 next_sequence: VarInt::from_u32(1),
335 max_candidates,
336 coordination_timeout,
337 stats: NatTraversalStats::default(),
338 }
339 }
340
341 pub(super) fn add_remote_candidate(
343 &mut self,
344 sequence: VarInt,
345 address: SocketAddr,
346 priority: VarInt,
347 now: Instant,
348 ) -> Result<(), NatTraversalError> {
349 if self.remote_candidates.len() >= self.max_candidates as usize {
350 return Err(NatTraversalError::TooManyCandidates);
351 }
352
353 if self.remote_candidates.values()
355 .any(|c| c.address == address && c.state != CandidateState::Removed)
356 {
357 return Err(NatTraversalError::DuplicateAddress);
358 }
359
360 let candidate = AddressCandidate {
361 address,
362 priority: priority.into_inner() as u32,
363 source: CandidateSource::Peer,
364 discovered_at: now,
365 state: CandidateState::New,
366 attempt_count: 0,
367 last_attempt: None,
368 };
369
370 self.remote_candidates.insert(sequence, candidate);
371 self.stats.remote_candidates_received += 1;
372
373 Ok(())
374 }
375
376 pub(super) fn remove_candidate(&mut self, sequence: VarInt) -> bool {
378 if let Some(candidate) = self.remote_candidates.get_mut(&sequence) {
379 candidate.state = CandidateState::Removed;
380
381 self.active_validations.remove(&candidate.address);
383 true
384 } else {
385 false
386 }
387 }
388
389 pub(super) fn add_local_candidate(
391 &mut self,
392 address: SocketAddr,
393 source: CandidateSource,
394 now: Instant,
395 ) -> VarInt {
396 let sequence = self.next_sequence;
397 self.next_sequence = VarInt::from_u64(self.next_sequence.into_inner() + 1)
398 .expect("sequence number overflow");
399
400 let candidate_type = classify_candidate_type(source);
402 let local_preference = self.calculate_local_preference(address);
403 let priority = calculate_candidate_priority(candidate_type, local_preference, 1);
404
405 let candidate = AddressCandidate {
406 address,
407 priority,
408 source,
409 discovered_at: now,
410 state: CandidateState::New,
411 attempt_count: 0,
412 last_attempt: None,
413 };
414
415 self.local_candidates.insert(sequence, candidate);
416 self.stats.local_candidates_sent += 1;
417
418 self.generate_candidate_pairs(now);
420
421 sequence
422 }
423
424 fn calculate_local_preference(&self, addr: SocketAddr) -> u16 {
426 match addr {
427 SocketAddr::V4(v4) => {
428 if v4.ip().is_loopback() {
429 0 } else if v4.ip().is_private() {
431 65000 } else {
433 32000 }
435 }
436 SocketAddr::V6(v6) => {
437 if v6.ip().is_loopback() {
438 0
439 } else if v6.ip().is_unicast_link_local() {
440 30000 } else {
442 50000 }
444 }
445 }
446 }
447
448 pub(super) fn generate_candidate_pairs(&mut self, now: Instant) {
450 self.candidate_pairs.clear();
451
452 for (local_seq, local_candidate) in &self.local_candidates {
453 for (remote_seq, remote_candidate) in &self.remote_candidates {
454 if local_candidate.state == CandidateState::Removed
456 || remote_candidate.state == CandidateState::Removed {
457 continue;
458 }
459
460 if !are_candidates_compatible(local_candidate, remote_candidate) {
462 continue;
463 }
464
465 let pair_priority = calculate_pair_priority(
467 local_candidate.priority,
468 remote_candidate.priority
469 );
470
471 let local_type = classify_candidate_type(local_candidate.source);
473 let remote_type = classify_candidate_type(remote_candidate.source);
474 let pair_type = classify_pair_type(local_type, remote_type);
475
476 let pair = CandidatePair {
477 local_sequence: *local_seq,
478 remote_sequence: *remote_seq,
479 local_addr: local_candidate.address,
480 remote_addr: remote_candidate.address,
481 priority: pair_priority,
482 state: PairState::Waiting,
483 pair_type,
484 created_at: now,
485 last_check: None,
486 };
487
488 self.candidate_pairs.push(pair);
489 }
490 }
491
492 self.candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
494
495 trace!("Generated {} candidate pairs", self.candidate_pairs.len());
496 }
497
498 pub(super) fn get_next_validation_pairs(&mut self, max_concurrent: usize) -> Vec<&mut CandidatePair> {
500 self.candidate_pairs
501 .iter_mut()
502 .filter(|pair| pair.state == PairState::Waiting)
503 .take(max_concurrent)
504 .collect()
505 }
506
507 pub(super) fn find_pair_by_remote_addr(&mut self, addr: SocketAddr) -> Option<&mut CandidatePair> {
509 self.candidate_pairs
510 .iter_mut()
511 .find(|pair| pair.remote_addr == addr)
512 }
513
514 pub(super) fn mark_pair_succeeded(&mut self, remote_addr: SocketAddr) -> bool {
516 let (succeeded_type, succeeded_priority) = {
518 if let Some(pair) = self.find_pair_by_remote_addr(remote_addr) {
519 pair.state = PairState::Succeeded;
520 (pair.pair_type, pair.priority)
521 } else {
522 return false;
523 }
524 };
525
526 for other_pair in &mut self.candidate_pairs {
528 if other_pair.pair_type == succeeded_type
529 && other_pair.priority < succeeded_priority
530 && other_pair.state == PairState::Waiting {
531 other_pair.state = PairState::Frozen;
532 }
533 }
534
535 true
536 }
537
538 pub(super) fn get_best_succeeded_pairs(&self) -> Vec<&CandidatePair> {
540 let mut best_ipv4: Option<&CandidatePair> = None;
541 let mut best_ipv6: Option<&CandidatePair> = None;
542
543 for pair in &self.candidate_pairs {
544 if pair.state != PairState::Succeeded {
545 continue;
546 }
547
548 match pair.remote_addr {
549 SocketAddr::V4(_) => {
550 if best_ipv4.map_or(true, |best| pair.priority > best.priority) {
551 best_ipv4 = Some(pair);
552 }
553 }
554 SocketAddr::V6(_) => {
555 if best_ipv6.map_or(true, |best| pair.priority > best.priority) {
556 best_ipv6 = Some(pair);
557 }
558 }
559 }
560 }
561
562 let mut result = Vec::new();
563 if let Some(pair) = best_ipv4 {
564 result.push(pair);
565 }
566 if let Some(pair) = best_ipv6 {
567 result.push(pair);
568 }
569 result
570 }
571
572 pub(super) fn get_validation_candidates(&self) -> Vec<(VarInt, &AddressCandidate)> {
574 let mut candidates: Vec<_> = self.remote_candidates
575 .iter()
576 .filter(|(_, c)| c.state == CandidateState::New)
577 .map(|(k, v)| (*k, v))
578 .collect();
579
580 candidates.sort_by(|a, b| b.1.priority.cmp(&a.1.priority));
582 candidates
583 }
584
585 pub(super) fn start_validation(
587 &mut self,
588 sequence: VarInt,
589 challenge: u64,
590 now: Instant,
591 ) -> Result<(), NatTraversalError> {
592 let candidate = self.remote_candidates.get_mut(&sequence)
593 .ok_or(NatTraversalError::UnknownCandidate)?;
594
595 if candidate.state != CandidateState::New {
596 return Err(NatTraversalError::InvalidCandidateState);
597 }
598
599 candidate.state = CandidateState::Validating;
601 candidate.attempt_count += 1;
602 candidate.last_attempt = Some(now);
603
604 let validation = PathValidationState {
606 challenge,
607 sent_at: now,
608 retry_count: 0,
609 max_retries: 3, coordination_round: self.coordination.as_ref().map(|c| c.round),
611 };
612
613 self.active_validations.insert(candidate.address, validation);
614 Ok(())
615 }
616
617 pub(super) fn handle_validation_success(
619 &mut self,
620 remote_addr: SocketAddr,
621 challenge: u64,
622 ) -> Result<VarInt, NatTraversalError> {
623 let sequence = self.remote_candidates
625 .iter()
626 .find(|(_, c)| c.address == remote_addr)
627 .map(|(seq, _)| *seq)
628 .ok_or(NatTraversalError::UnknownCandidate)?;
629
630 let validation = self.active_validations.get(&remote_addr)
632 .ok_or(NatTraversalError::NoActiveValidation)?;
633
634 if validation.challenge != challenge {
635 return Err(NatTraversalError::ChallengeMismatch);
636 }
637
638 let candidate = self.remote_candidates.get_mut(&sequence)
640 .ok_or(NatTraversalError::UnknownCandidate)?;
641
642 candidate.state = CandidateState::Valid;
643 self.active_validations.remove(&remote_addr);
644 self.stats.validations_succeeded += 1;
645
646 Ok(sequence)
647 }
648
649 pub(super) fn handle_validation_failure(
651 &mut self,
652 remote_addr: SocketAddr,
653 ) -> Option<VarInt> {
654 self.active_validations.remove(&remote_addr);
655
656 let sequence = self.remote_candidates
658 .iter_mut()
659 .find(|(_, c)| c.address == remote_addr)
660 .map(|(seq, candidate)| {
661 candidate.state = CandidateState::Failed;
662 *seq
663 });
664
665 if sequence.is_some() {
666 self.stats.validations_failed += 1;
667 }
668
669 sequence
670 }
671
672 pub(super) fn get_best_candidate(&self) -> Option<(VarInt, &AddressCandidate)> {
674 self.remote_candidates
675 .iter()
676 .filter(|(_, c)| c.state == CandidateState::Valid)
677 .max_by_key(|(_, c)| c.priority)
678 .map(|(k, v)| (*k, v))
679 }
680
681 pub(super) fn start_coordination_round(
683 &mut self,
684 targets: Vec<PunchTarget>,
685 now: Instant,
686 ) -> VarInt {
687 let round = self.next_sequence;
688 self.next_sequence = VarInt::from_u64(self.next_sequence.into_inner() + 1)
689 .expect("sequence number overflow");
690
691 let coordination_grace = Duration::from_millis(500); let punch_start = now + coordination_grace;
694
695 self.coordination = Some(CoordinationState {
696 round,
697 punch_targets: targets,
698 round_start: now,
699 punch_start,
700 round_duration: self.coordination_timeout,
701 state: CoordinationPhase::Requesting,
702 punch_request_sent: false,
703 peer_punch_received: false,
704 retry_count: 0,
705 max_retries: 3,
706 });
707
708 self.stats.coordination_rounds += 1;
709 trace!("Started coordination round {} with {} targets", round, self.coordination.as_ref().unwrap().punch_targets.len());
710 round
711 }
712
713 pub(super) fn get_coordination_phase(&self) -> Option<CoordinationPhase> {
715 self.coordination.as_ref().map(|c| c.state)
716 }
717
718 pub(super) fn should_send_punch_request(&self) -> bool {
720 if let Some(coord) = &self.coordination {
721 coord.state == CoordinationPhase::Requesting && !coord.punch_request_sent
722 } else {
723 false
724 }
725 }
726
727 pub(super) fn mark_punch_request_sent(&mut self) {
729 if let Some(coord) = &mut self.coordination {
730 coord.punch_request_sent = true;
731 coord.state = CoordinationPhase::Coordinating;
732 trace!("PUNCH_ME_NOW sent, waiting for peer coordination");
733 }
734 }
735
736 pub(super) fn handle_peer_punch_request(&mut self, peer_round: VarInt, now: Instant) -> bool {
738 if let Some(coord) = &mut self.coordination {
739 if coord.round == peer_round && coord.state == CoordinationPhase::Coordinating {
740 coord.peer_punch_received = true;
741 coord.state = CoordinationPhase::Preparing;
742
743 let remaining_grace = Duration::from_millis(200); coord.punch_start = now + remaining_grace;
746
747 trace!("Peer coordination received, punch starts in {:?}", remaining_grace);
748 true
749 } else {
750 debug!("Received coordination for wrong round or phase: {} vs {}, {:?}",
751 peer_round, coord.round, coord.state);
752 false
753 }
754 } else {
755 debug!("Received peer coordination but no active round");
756 false
757 }
758 }
759
760 pub(super) fn should_start_punching(&self, now: Instant) -> bool {
762 if let Some(coord) = &self.coordination {
763 coord.state == CoordinationPhase::Preparing && now >= coord.punch_start
764 } else {
765 false
766 }
767 }
768
769 pub(super) fn start_punching_phase(&mut self, _now: Instant) {
771 if let Some(coord) = &mut self.coordination {
772 coord.state = CoordinationPhase::Punching;
773 trace!("Starting synchronized hole punching with {} targets", coord.punch_targets.len());
774 }
775 }
776
777 pub(super) fn get_punch_targets(&self) -> Option<&[PunchTarget]> {
779 self.coordination.as_ref().map(|c| c.punch_targets.as_slice())
780 }
781
782 pub(super) fn mark_coordination_validating(&mut self) {
784 if let Some(coord) = &mut self.coordination {
785 if coord.state == CoordinationPhase::Punching {
786 coord.state = CoordinationPhase::Validating;
787 trace!("Coordination moved to validation phase");
788 }
789 }
790 }
791
792 pub(super) fn handle_coordination_success(&mut self, remote_addr: SocketAddr) -> bool {
794 if let Some(coord) = &mut self.coordination {
795 let was_target = coord.punch_targets.iter().any(|target| target.remote_addr == remote_addr);
797
798 if was_target && coord.state == CoordinationPhase::Validating {
799 coord.state = CoordinationPhase::Succeeded;
800 self.stats.direct_connections += 1;
801 trace!("Coordination succeeded via {}", remote_addr);
802 true
803 } else {
804 false
805 }
806 } else {
807 false
808 }
809 }
810
811 pub(super) fn handle_coordination_failure(&mut self, now: Instant) -> bool {
813 if let Some(coord) = &mut self.coordination {
814 coord.retry_count += 1;
815
816 if coord.retry_count < coord.max_retries {
817 coord.state = CoordinationPhase::Requesting;
819 coord.punch_request_sent = false;
820 coord.peer_punch_received = false;
821 coord.round_start = now;
822 coord.punch_start = now + Duration::from_millis(500);
823
824 trace!("Coordination failed, retrying round {} (attempt {})",
825 coord.round, coord.retry_count + 1);
826 true
827 } else {
828 coord.state = CoordinationPhase::Failed;
829 trace!("Coordination failed after {} attempts", coord.retry_count);
830 false
831 }
832 } else {
833 false
834 }
835 }
836
837 pub(super) fn check_coordination_timeout(&mut self, now: Instant) -> bool {
839 if let Some(coord) = &mut self.coordination {
840 let elapsed = now.duration_since(coord.round_start);
841
842 if elapsed > coord.round_duration {
843 trace!("Coordination round {} timed out after {:?}", coord.round, elapsed);
844 self.handle_coordination_failure(now);
845 true
846 } else {
847 false
848 }
849 } else {
850 false
851 }
852 }
853
854 pub(super) fn is_coordination_expired(&self, now: Instant) -> bool {
856 self.coordination.as_ref()
857 .map_or(false, |c| now.duration_since(c.round_start) > c.round_duration)
858 }
859
860 pub(super) fn complete_coordination(&mut self) {
862 self.coordination = None;
863 }
864}
865
/// Errors returned by NAT traversal state operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatTraversalError {
    /// The remote candidate limit has been reached.
    TooManyCandidates,
    /// A live candidate already uses the given address.
    DuplicateAddress,
    /// No candidate matches the given sequence number or address.
    UnknownCandidate,
    /// The candidate is not in the state required by the operation.
    InvalidCandidateState,
    /// No path validation is in flight for the given address.
    NoActiveValidation,
    /// The response challenge does not match the one sent.
    ChallengeMismatch,
    /// No coordination round is active.
    NoActiveCoordination,
}

impl std::fmt::Display for NatTraversalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::TooManyCandidates => "too many candidates",
            Self::DuplicateAddress => "duplicate address",
            Self::UnknownCandidate => "unknown candidate",
            Self::InvalidCandidateState => "invalid candidate state",
            Self::NoActiveValidation => "no active validation",
            Self::ChallengeMismatch => "challenge mismatch",
            Self::NoActiveCoordination => "no active coordination",
        };
        // write_str, like write! with a literal, ignores width/fill flags.
        f.write_str(message)
    }
}

impl std::error::Error for NatTraversalError {}
900
// Unit tests for this module live in separate source files.
#[cfg(test)]
mod tests;

#[cfg(test)]
#[path = "nat_traversal_tests.rs"]
mod nat_traversal_tests;