ant_quic/connection/
nat_traversal.rs

1use std::{
2    collections::HashMap,
3    net::SocketAddr,
4    time::Duration,
5};
6
7use tracing::{trace, debug};
8
9use crate::{
10    Instant, VarInt,
11};
12
13/// NAT traversal state for a QUIC connection
14/// 
15/// This manages address candidate discovery, validation, and coordination
16/// for establishing direct P2P connections through NATs.
17#[derive(Debug)]
18pub(super) struct NatTraversalState {
19    /// Our role in NAT traversal (from transport parameters)
20    pub(super) role: NatTraversalRole,
21    /// Candidate addresses we've advertised to the peer
22    pub(super) local_candidates: HashMap<VarInt, AddressCandidate>,
23    /// Candidate addresses received from the peer
24    pub(super) remote_candidates: HashMap<VarInt, AddressCandidate>, 
25    /// Generated candidate pairs for connectivity testing
26    pub(super) candidate_pairs: Vec<CandidatePair>,
27    /// Currently active path validation attempts
28    pub(super) active_validations: HashMap<SocketAddr, PathValidationState>,
29    /// Coordination state for simultaneous hole punching
30    pub(super) coordination: Option<CoordinationState>,
31    /// Sequence number for address advertisements
32    pub(super) next_sequence: VarInt,
33    /// Maximum candidates we're willing to handle
34    pub(super) max_candidates: u32,
35    /// Timeout for coordination rounds
36    pub(super) coordination_timeout: Duration,
37    /// Statistics for this NAT traversal session
38    pub(super) stats: NatTraversalStats,
39}
40
/// The part an endpoint plays when coordinating NAT traversal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NatTraversalRole {
    /// Client endpoint: initiates connections on demand
    Client,
    /// Server endpoint: accepts connections and is expected to stay reachable
    Server {
        /// Whether this server may also relay for others (semantics defined by callers — confirm)
        can_relay: bool,
    },
    /// Bootstrap/relay endpoint: publicly reachable and coordinates traversal
    Bootstrap,
}
51
52/// Address candidate with metadata
53#[derive(Debug, Clone)]
54pub(super) struct AddressCandidate {
55    /// The socket address
56    pub(super) address: SocketAddr,
57    /// Priority for ICE-like selection (higher = better)
58    pub(super) priority: u32,
59    /// How this candidate was discovered
60    pub(super) source: CandidateSource,
61    /// When this candidate was first learned
62    pub(super) discovered_at: Instant,
63    /// Current state of this candidate
64    pub(super) state: CandidateState,
65    /// Number of validation attempts for this candidate
66    pub(super) attempt_count: u32,
67    /// Last validation attempt time
68    pub(super) last_attempt: Option<Instant>,
69}
70
71/// How an address candidate was discovered
72#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73pub enum CandidateSource {
74    /// Local network interface
75    Local,
76    /// Observed by a bootstrap node
77    Observed { by_node: Option<VarInt> },
78    /// Received from peer via AddAddress frame
79    Peer,
80    /// Generated prediction for symmetric NAT
81    Predicted,
82}
83
/// Lifecycle of an address candidate.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CandidateState {
    /// Just learned; no validation attempted yet
    New,
    /// A path validation is in flight
    Validating,
    /// Validation succeeded; the address is usable
    Valid,
    /// Validation did not succeed
    Failed,
    /// Withdrawn by the peer or expired
    Removed,
}
98
99/// State of an individual path validation attempt
100#[derive(Debug)]
101pub(super) struct PathValidationState {
102    /// Challenge value sent
103    pub(super) challenge: u64,
104    /// When the challenge was sent
105    pub(super) sent_at: Instant,
106    /// Number of retransmissions
107    pub(super) retry_count: u32,
108    /// Maximum retries allowed
109    pub(super) max_retries: u32,
110    /// Associated with a coordination round (if any)
111    pub(super) coordination_round: Option<VarInt>,
112}
113
114/// Coordination state for simultaneous hole punching
115#[derive(Debug)]
116pub(super) struct CoordinationState {
117    /// Current coordination round number
118    pub(super) round: VarInt,
119    /// Addresses we're punching to in this round
120    pub(super) punch_targets: Vec<PunchTarget>,
121    /// When this round started (coordination phase)
122    pub(super) round_start: Instant,
123    /// When hole punching should begin (synchronized time)
124    pub(super) punch_start: Instant,
125    /// Duration of this coordination round
126    pub(super) round_duration: Duration,
127    /// Current state of this coordination round
128    pub(super) state: CoordinationPhase,
129    /// Whether we've sent our PUNCH_ME_NOW to coordinator
130    pub(super) punch_request_sent: bool,
131    /// Whether we've received peer's PUNCH_ME_NOW via coordinator
132    pub(super) peer_punch_received: bool,
133    /// Retry count for this round
134    pub(super) retry_count: u32,
135    /// Maximum retries before giving up
136    pub(super) max_retries: u32,
137}
138
/// Phases a coordination round moves through.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CoordinationPhase {
    /// Nothing started yet
    Idle,
    /// About to send PUNCH_ME_NOW to the coordinator
    Requesting,
    /// PUNCH_ME_NOW sent; waiting for the peer's one via the coordinator
    Coordinating,
    /// Both sides ready; grace period before synchronized punching
    Preparing,
    /// Sending PATH_CHALLENGE packets to the punch targets
    Punching,
    /// Awaiting PATH_RESPONSE for the challenges
    Validating,
    /// The round produced a working path
    Succeeded,
    /// The round did not succeed (may be retried)
    Failed,
}
159
160/// Target for hole punching in a coordination round
161#[derive(Debug, Clone)]
162pub(super) struct PunchTarget {
163    /// Remote address to punch to
164    pub(super) remote_addr: SocketAddr,
165    /// Our local address for this punch
166    pub(super) local_addr: SocketAddr,
167    /// Sequence number of the remote candidate
168    pub(super) remote_sequence: VarInt,
169    /// Challenge value for validation
170    pub(super) challenge: u64,
171}
172
173/// Candidate pair for ICE-like connectivity testing
174#[derive(Debug, Clone)]
175pub(super) struct CandidatePair {
176    /// Sequence of our local candidate
177    pub(super) local_sequence: VarInt,
178    /// Sequence of remote candidate  
179    pub(super) remote_sequence: VarInt,
180    /// Our local address for this pair
181    pub(super) local_addr: SocketAddr,
182    /// Remote address we're testing connectivity to
183    pub(super) remote_addr: SocketAddr,
184    /// Combined priority for pair ordering (higher = better)
185    pub(super) priority: u64,
186    /// Current state of this pair
187    pub(super) state: PairState,
188    /// Type classification for this pair
189    pub(super) pair_type: PairType,
190    /// When this pair was created
191    pub(super) created_at: Instant,
192    /// When validation was last attempted
193    pub(super) last_check: Option<Instant>,
194}
195
196/// State of a candidate pair during validation
197#[derive(Debug, Clone, Copy, PartialEq, Eq)]
198pub(super) enum PairState {
199    /// Waiting to be tested
200    Waiting,
201    /// Currently being validated
202    InProgress,
203    /// Validation succeeded - this pair works
204    Succeeded,
205    /// Validation failed 
206    Failed,
207    /// Temporarily frozen (waiting for other pairs)
208    Frozen,
209}
210
211/// Type classification for candidate pairs (based on ICE)
212#[derive(Debug, Clone, Copy, PartialEq, Eq)]
213pub(super) enum PairType {
214    /// Both candidates are on local network
215    HostToHost,
216    /// Local is host, remote is server reflexive (through NAT)
217    HostToServerReflexive,
218    /// Local is server reflexive, remote is host
219    ServerReflexiveToHost,
220    /// Both are server reflexive (both behind NAT)
221    ServerReflexiveToServerReflexive,
222    /// One side is peer reflexive (learned from peer)
223    PeerReflexive,
224    /// Using relay servers
225    Relayed,
226}
227
228/// Type of address candidate (following ICE terminology)
229#[derive(Debug, Clone, Copy, PartialEq, Eq)]
230pub(super) enum CandidateType {
231    /// Host candidate - directly reachable local interface
232    Host,
233    /// Server reflexive - public address observed by STUN-like server
234    ServerReflexive,
235    /// Peer reflexive - address learned from incoming packets
236    PeerReflexive,
237    /// Relayed - address of relay server
238    Relayed,
239}
240
241/// Calculate ICE-like priority for an address candidate
242/// Based on RFC 8445 Section 5.1.2.1
243fn calculate_candidate_priority(
244    candidate_type: CandidateType,
245    local_preference: u16,
246    component_id: u8,
247) -> u32 {
248    let type_preference = match candidate_type {
249        CandidateType::Host => 126,
250        CandidateType::PeerReflexive => 110,
251        CandidateType::ServerReflexive => 100,
252        CandidateType::Relayed => 0,
253    };
254    
255    // ICE priority formula: (2^24 * type_pref) + (2^8 * local_pref) + component_id
256    (1u32 << 24) * type_preference 
257        + (1u32 << 8) * local_preference as u32 
258        + component_id as u32
259}
260
/// Calculate the combined priority of a candidate pair.
///
/// Uses the RFC 8445 Section 6.1.2.3 formula
/// `2^32 * MIN(G, D) + 2 * MAX(G, D) + (G > D ? 1 : 0)`. Here `G` is the local priority
/// and `D` the remote one. (The RFC assigns `G`/`D` by controlling/controlled role, not
/// local/remote, so the two peers may compute different values for the same pair.)
///
/// The RFC caps candidate priorities at `2^31 - 1`, and below that cap the result always
/// fits in a `u64`. Priorities are not validated here (remote ones come from the peer), so
/// the final addition saturates instead of overflowing. The only input that would
/// otherwise overflow is `G == D == u32::MAX`, and it maps to `u64::MAX`.
fn calculate_pair_priority(local_priority: u32, remote_priority: u32) -> u64 {
    let g = u64::from(local_priority);
    let d = u64::from(remote_priority);

    // `min < 2^32`, so the shift loses no bits; `2 * max + 1 < 2^34` cannot overflow.
    let high = g.min(d) << 32;
    let low = 2 * g.max(d) + u64::from(g > d);
    high.saturating_add(low)
}
270
271/// Determine candidate type from source information
272fn classify_candidate_type(source: CandidateSource) -> CandidateType {
273    match source {
274        CandidateSource::Local => CandidateType::Host,
275        CandidateSource::Observed { .. } => CandidateType::ServerReflexive,
276        CandidateSource::Peer => CandidateType::PeerReflexive,
277        CandidateSource::Predicted => CandidateType::ServerReflexive, // Symmetric NAT prediction
278    }
279}
280
281/// Determine pair type from individual candidate types
282fn classify_pair_type(local_type: CandidateType, remote_type: CandidateType) -> PairType {
283    match (local_type, remote_type) {
284        (CandidateType::Host, CandidateType::Host) => PairType::HostToHost,
285        (CandidateType::Host, CandidateType::ServerReflexive) => PairType::HostToServerReflexive,
286        (CandidateType::ServerReflexive, CandidateType::Host) => PairType::ServerReflexiveToHost,
287        (CandidateType::ServerReflexive, CandidateType::ServerReflexive) => PairType::ServerReflexiveToServerReflexive,
288        (CandidateType::Relayed, _) | (_, CandidateType::Relayed) => PairType::Relayed,
289        (CandidateType::PeerReflexive, _) | (_, CandidateType::PeerReflexive) => PairType::PeerReflexive,
290    }
291}
292
293/// Check if two candidates are compatible for pairing
294fn are_candidates_compatible(local: &AddressCandidate, remote: &AddressCandidate) -> bool {
295    // Must be same address family (IPv4 with IPv4, IPv6 with IPv6)
296    match (local.address, remote.address) {
297        (SocketAddr::V4(_), SocketAddr::V4(_)) => true,
298        (SocketAddr::V6(_), SocketAddr::V6(_)) => true,
299        _ => false, // No IPv4/IPv6 mixing for now
300    }
301}
302
303/// Statistics for NAT traversal attempts
304#[derive(Debug, Default)]
305pub(super) struct NatTraversalStats {
306    /// Total candidates received from peer
307    pub(super) remote_candidates_received: u32,
308    /// Total candidates we've advertised
309    pub(super) local_candidates_sent: u32,
310    /// Successful validations
311    pub(super) validations_succeeded: u32,
312    /// Failed validations
313    pub(super) validations_failed: u32,
314    /// Coordination rounds attempted
315    pub(super) coordination_rounds: u32,
316    /// Successful direct connections established
317    pub(super) direct_connections: u32,
318}
319
320impl NatTraversalState {
321    /// Create new NAT traversal state with given role and configuration
322    pub(super) fn new(
323        role: NatTraversalRole,
324        max_candidates: u32,
325        coordination_timeout: Duration,
326    ) -> Self {
327        Self {
328            role,
329            local_candidates: HashMap::new(),
330            remote_candidates: HashMap::new(),
331            candidate_pairs: Vec::new(),
332            active_validations: HashMap::new(),
333            coordination: None,
334            next_sequence: VarInt::from_u32(1),
335            max_candidates,
336            coordination_timeout,
337            stats: NatTraversalStats::default(),
338        }
339    }
340
341    /// Add a remote candidate from AddAddress frame
342    pub(super) fn add_remote_candidate(
343        &mut self,
344        sequence: VarInt,
345        address: SocketAddr,
346        priority: VarInt,
347        now: Instant,
348    ) -> Result<(), NatTraversalError> {
349        if self.remote_candidates.len() >= self.max_candidates as usize {
350            return Err(NatTraversalError::TooManyCandidates);
351        }
352
353        // Check for duplicate addresses (different sequence, same address)
354        if self.remote_candidates.values()
355            .any(|c| c.address == address && c.state != CandidateState::Removed) 
356        {
357            return Err(NatTraversalError::DuplicateAddress);
358        }
359
360        let candidate = AddressCandidate {
361            address,
362            priority: priority.into_inner() as u32,
363            source: CandidateSource::Peer,
364            discovered_at: now,
365            state: CandidateState::New,
366            attempt_count: 0,
367            last_attempt: None,
368        };
369
370        self.remote_candidates.insert(sequence, candidate);
371        self.stats.remote_candidates_received += 1;
372        
373        Ok(())
374    }
375
376    /// Remove a candidate by sequence number
377    pub(super) fn remove_candidate(&mut self, sequence: VarInt) -> bool {
378        if let Some(candidate) = self.remote_candidates.get_mut(&sequence) {
379            candidate.state = CandidateState::Removed;
380            
381            // Cancel any active validation for this address
382            self.active_validations.remove(&candidate.address);
383            true
384        } else {
385            false
386        }
387    }
388
389    /// Add a local candidate that we've discovered
390    pub(super) fn add_local_candidate(
391        &mut self,
392        address: SocketAddr,
393        source: CandidateSource,
394        now: Instant,
395    ) -> VarInt {
396        let sequence = self.next_sequence;
397        self.next_sequence = VarInt::from_u64(self.next_sequence.into_inner() + 1)
398            .expect("sequence number overflow");
399
400        // Calculate priority for this candidate
401        let candidate_type = classify_candidate_type(source);
402        let local_preference = self.calculate_local_preference(address);
403        let priority = calculate_candidate_priority(candidate_type, local_preference, 1);
404
405        let candidate = AddressCandidate {
406            address,
407            priority,
408            source,
409            discovered_at: now,
410            state: CandidateState::New,
411            attempt_count: 0,
412            last_attempt: None,
413        };
414
415        self.local_candidates.insert(sequence, candidate);
416        self.stats.local_candidates_sent += 1;
417
418        // Regenerate pairs when we add a new local candidate
419        self.generate_candidate_pairs(now);
420        
421        sequence
422    }
423
424    /// Calculate local preference for address prioritization
425    fn calculate_local_preference(&self, addr: SocketAddr) -> u16 {
426        match addr {
427            SocketAddr::V4(v4) => {
428                if v4.ip().is_loopback() {
429                    0 // Lowest priority
430                } else if v4.ip().is_private() {
431                    65000 // High priority for local network
432                } else {
433                    32000 // Medium priority for public addresses
434                }
435            }
436            SocketAddr::V6(v6) => {
437                if v6.ip().is_loopback() {
438                    0
439                } else if v6.ip().is_unicast_link_local() {
440                    30000 // Link-local gets medium-low priority
441                } else {
442                    50000 // IPv6 generally gets good priority
443                }
444            }
445        }
446    }
447
448    /// Generate all possible candidate pairs from local and remote candidates
449    pub(super) fn generate_candidate_pairs(&mut self, now: Instant) {
450        self.candidate_pairs.clear();
451        
452        for (local_seq, local_candidate) in &self.local_candidates {
453            for (remote_seq, remote_candidate) in &self.remote_candidates {
454                // Skip removed candidates
455                if local_candidate.state == CandidateState::Removed 
456                    || remote_candidate.state == CandidateState::Removed {
457                    continue;
458                }
459
460                // Check compatibility
461                if !are_candidates_compatible(local_candidate, remote_candidate) {
462                    continue;
463                }
464
465                // Calculate combined priority
466                let pair_priority = calculate_pair_priority(
467                    local_candidate.priority, 
468                    remote_candidate.priority
469                );
470
471                // Classify pair type
472                let local_type = classify_candidate_type(local_candidate.source);
473                let remote_type = classify_candidate_type(remote_candidate.source);
474                let pair_type = classify_pair_type(local_type, remote_type);
475
476                let pair = CandidatePair {
477                    local_sequence: *local_seq,
478                    remote_sequence: *remote_seq,
479                    local_addr: local_candidate.address,
480                    remote_addr: remote_candidate.address,
481                    priority: pair_priority,
482                    state: PairState::Waiting,
483                    pair_type,
484                    created_at: now,
485                    last_check: None,
486                };
487
488                self.candidate_pairs.push(pair);
489            }
490        }
491
492        // Sort pairs by priority (highest first)
493        self.candidate_pairs.sort_by(|a, b| b.priority.cmp(&a.priority));
494
495        trace!("Generated {} candidate pairs", self.candidate_pairs.len());
496    }
497
498    /// Get the highest priority pairs ready for validation
499    pub(super) fn get_next_validation_pairs(&mut self, max_concurrent: usize) -> Vec<&mut CandidatePair> {
500        self.candidate_pairs
501            .iter_mut()
502            .filter(|pair| pair.state == PairState::Waiting)
503            .take(max_concurrent)
504            .collect()
505    }
506
507    /// Find a candidate pair by remote address
508    pub(super) fn find_pair_by_remote_addr(&mut self, addr: SocketAddr) -> Option<&mut CandidatePair> {
509        self.candidate_pairs
510            .iter_mut()
511            .find(|pair| pair.remote_addr == addr)
512    }
513
514    /// Mark a pair as succeeded and handle promotion
515    pub(super) fn mark_pair_succeeded(&mut self, remote_addr: SocketAddr) -> bool {
516        // Find the pair and get its type and priority
517        let (succeeded_type, succeeded_priority) = {
518            if let Some(pair) = self.find_pair_by_remote_addr(remote_addr) {
519                pair.state = PairState::Succeeded;
520                (pair.pair_type, pair.priority)
521            } else {
522                return false;
523            }
524        };
525        
526        // Freeze lower priority pairs of the same type to avoid unnecessary testing
527        for other_pair in &mut self.candidate_pairs {
528            if other_pair.pair_type == succeeded_type 
529                && other_pair.priority < succeeded_priority 
530                && other_pair.state == PairState::Waiting {
531                other_pair.state = PairState::Frozen;
532            }
533        }
534        
535        true
536    }
537
538    /// Get the best succeeded pair for each address family
539    pub(super) fn get_best_succeeded_pairs(&self) -> Vec<&CandidatePair> {
540        let mut best_ipv4: Option<&CandidatePair> = None;
541        let mut best_ipv6: Option<&CandidatePair> = None;
542        
543        for pair in &self.candidate_pairs {
544            if pair.state != PairState::Succeeded {
545                continue;
546            }
547            
548            match pair.remote_addr {
549                SocketAddr::V4(_) => {
550                    if best_ipv4.map_or(true, |best| pair.priority > best.priority) {
551                        best_ipv4 = Some(pair);
552                    }
553                }
554                SocketAddr::V6(_) => {
555                    if best_ipv6.map_or(true, |best| pair.priority > best.priority) {
556                        best_ipv6 = Some(pair);
557                    }
558                }
559            }
560        }
561        
562        let mut result = Vec::new();
563        if let Some(pair) = best_ipv4 {
564            result.push(pair);
565        }
566        if let Some(pair) = best_ipv6 {
567            result.push(pair);
568        }
569        result
570    }
571
572    /// Get candidates ready for validation, sorted by priority
573    pub(super) fn get_validation_candidates(&self) -> Vec<(VarInt, &AddressCandidate)> {
574        let mut candidates: Vec<_> = self.remote_candidates
575            .iter()
576            .filter(|(_, c)| c.state == CandidateState::New)
577            .map(|(k, v)| (*k, v))
578            .collect();
579        
580        // Sort by priority (higher priority first)
581        candidates.sort_by(|a, b| b.1.priority.cmp(&a.1.priority));
582        candidates
583    }
584
585    /// Start validation for a candidate address
586    pub(super) fn start_validation(
587        &mut self,
588        sequence: VarInt,
589        challenge: u64,
590        now: Instant,
591    ) -> Result<(), NatTraversalError> {
592        let candidate = self.remote_candidates.get_mut(&sequence)
593            .ok_or(NatTraversalError::UnknownCandidate)?;
594        
595        if candidate.state != CandidateState::New {
596            return Err(NatTraversalError::InvalidCandidateState);
597        }
598
599        // Update candidate state
600        candidate.state = CandidateState::Validating;
601        candidate.attempt_count += 1;
602        candidate.last_attempt = Some(now);
603
604        // Track validation state
605        let validation = PathValidationState {
606            challenge,
607            sent_at: now,
608            retry_count: 0,
609            max_retries: 3, // TODO: Make configurable
610            coordination_round: self.coordination.as_ref().map(|c| c.round),
611        };
612
613        self.active_validations.insert(candidate.address, validation);
614        Ok(())
615    }
616
617    /// Handle successful validation response
618    pub(super) fn handle_validation_success(
619        &mut self,
620        remote_addr: SocketAddr,
621        challenge: u64,
622    ) -> Result<VarInt, NatTraversalError> {
623        // Find the candidate with this address
624        let sequence = self.remote_candidates
625            .iter()
626            .find(|(_, c)| c.address == remote_addr)
627            .map(|(seq, _)| *seq)
628            .ok_or(NatTraversalError::UnknownCandidate)?;
629
630        // Verify challenge matches
631        let validation = self.active_validations.get(&remote_addr)
632            .ok_or(NatTraversalError::NoActiveValidation)?;
633        
634        if validation.challenge != challenge {
635            return Err(NatTraversalError::ChallengeMismatch);
636        }
637
638        // Update candidate state
639        let candidate = self.remote_candidates.get_mut(&sequence)
640            .ok_or(NatTraversalError::UnknownCandidate)?;
641        
642        candidate.state = CandidateState::Valid;
643        self.active_validations.remove(&remote_addr);
644        self.stats.validations_succeeded += 1;
645
646        Ok(sequence)
647    }
648
649    /// Handle failed validation (timeout or error)
650    pub(super) fn handle_validation_failure(
651        &mut self,
652        remote_addr: SocketAddr,
653    ) -> Option<VarInt> {
654        self.active_validations.remove(&remote_addr);
655        
656        // Find and mark candidate as failed
657        let sequence = self.remote_candidates
658            .iter_mut()
659            .find(|(_, c)| c.address == remote_addr)
660            .map(|(seq, candidate)| {
661                candidate.state = CandidateState::Failed;
662                *seq
663            });
664
665        if sequence.is_some() {
666            self.stats.validations_failed += 1;
667        }
668
669        sequence
670    }
671
672    /// Get the highest priority valid candidate
673    pub(super) fn get_best_candidate(&self) -> Option<(VarInt, &AddressCandidate)> {
674        self.remote_candidates
675            .iter()
676            .filter(|(_, c)| c.state == CandidateState::Valid)
677            .max_by_key(|(_, c)| c.priority)
678            .map(|(k, v)| (*k, v))
679    }
680
681    /// Start a new coordination round for simultaneous hole punching
682    pub(super) fn start_coordination_round(
683        &mut self,
684        targets: Vec<PunchTarget>,
685        now: Instant,
686    ) -> VarInt {
687        let round = self.next_sequence;
688        self.next_sequence = VarInt::from_u64(self.next_sequence.into_inner() + 1)
689            .expect("sequence number overflow");
690
691        // Calculate synchronized punch time (grace period for coordination)
692        let coordination_grace = Duration::from_millis(500); // 500ms for coordination
693        let punch_start = now + coordination_grace;
694
695        self.coordination = Some(CoordinationState {
696            round,
697            punch_targets: targets,
698            round_start: now,
699            punch_start,
700            round_duration: self.coordination_timeout,
701            state: CoordinationPhase::Requesting,
702            punch_request_sent: false,
703            peer_punch_received: false,
704            retry_count: 0,
705            max_retries: 3,
706        });
707
708        self.stats.coordination_rounds += 1;
709        trace!("Started coordination round {} with {} targets", round, self.coordination.as_ref().unwrap().punch_targets.len());
710        round
711    }
712
713    /// Get the current coordination phase
714    pub(super) fn get_coordination_phase(&self) -> Option<CoordinationPhase> {
715        self.coordination.as_ref().map(|c| c.state)
716    }
717
718    /// Check if we need to send PUNCH_ME_NOW frame
719    pub(super) fn should_send_punch_request(&self) -> bool {
720        if let Some(coord) = &self.coordination {
721            coord.state == CoordinationPhase::Requesting && !coord.punch_request_sent
722        } else {
723            false
724        }
725    }
726
727    /// Mark that we've sent our PUNCH_ME_NOW request
728    pub(super) fn mark_punch_request_sent(&mut self) {
729        if let Some(coord) = &mut self.coordination {
730            coord.punch_request_sent = true;
731            coord.state = CoordinationPhase::Coordinating;
732            trace!("PUNCH_ME_NOW sent, waiting for peer coordination");
733        }
734    }
735
736    /// Handle receiving peer's PUNCH_ME_NOW (via coordinator)
737    pub(super) fn handle_peer_punch_request(&mut self, peer_round: VarInt, now: Instant) -> bool {
738        if let Some(coord) = &mut self.coordination {
739            if coord.round == peer_round && coord.state == CoordinationPhase::Coordinating {
740                coord.peer_punch_received = true;
741                coord.state = CoordinationPhase::Preparing;
742                
743                // Recalculate punch time based on when we received coordination
744                let remaining_grace = Duration::from_millis(200); // 200ms remaining grace
745                coord.punch_start = now + remaining_grace;
746                
747                trace!("Peer coordination received, punch starts in {:?}", remaining_grace);
748                true
749            } else {
750                debug!("Received coordination for wrong round or phase: {} vs {}, {:?}", 
751                       peer_round, coord.round, coord.state);
752                false
753            }
754        } else {
755            debug!("Received peer coordination but no active round");
756            false
757        }
758    }
759
760    /// Check if it's time to start hole punching
761    pub(super) fn should_start_punching(&self, now: Instant) -> bool {
762        if let Some(coord) = &self.coordination {
763            coord.state == CoordinationPhase::Preparing && now >= coord.punch_start
764        } else {
765            false
766        }
767    }
768
769    /// Start the synchronized hole punching phase
770    pub(super) fn start_punching_phase(&mut self, _now: Instant) {
771        if let Some(coord) = &mut self.coordination {
772            coord.state = CoordinationPhase::Punching;
773            trace!("Starting synchronized hole punching with {} targets", coord.punch_targets.len());
774        }
775    }
776
777    /// Get punch targets for the current round
778    pub(super) fn get_punch_targets(&self) -> Option<&[PunchTarget]> {
779        self.coordination.as_ref().map(|c| c.punch_targets.as_slice())
780    }
781
782    /// Mark coordination as validating (PATH_CHALLENGE sent)
783    pub(super) fn mark_coordination_validating(&mut self) {
784        if let Some(coord) = &mut self.coordination {
785            if coord.state == CoordinationPhase::Punching {
786                coord.state = CoordinationPhase::Validating;
787                trace!("Coordination moved to validation phase");
788            }
789        }
790    }
791
792    /// Handle successful path validation during coordination
793    pub(super) fn handle_coordination_success(&mut self, remote_addr: SocketAddr) -> bool {
794        if let Some(coord) = &mut self.coordination {
795            // Check if this address was one of our punch targets
796            let was_target = coord.punch_targets.iter().any(|target| target.remote_addr == remote_addr);
797            
798            if was_target && coord.state == CoordinationPhase::Validating {
799                coord.state = CoordinationPhase::Succeeded;
800                self.stats.direct_connections += 1;
801                trace!("Coordination succeeded via {}", remote_addr);
802                true
803            } else {
804                false
805            }
806        } else {
807            false
808        }
809    }
810
811    /// Handle coordination failure and determine if we should retry
812    pub(super) fn handle_coordination_failure(&mut self, now: Instant) -> bool {
813        if let Some(coord) = &mut self.coordination {
814            coord.retry_count += 1;
815            
816            if coord.retry_count < coord.max_retries {
817                // Retry with next best candidates
818                coord.state = CoordinationPhase::Requesting;
819                coord.punch_request_sent = false;
820                coord.peer_punch_received = false;
821                coord.round_start = now;
822                coord.punch_start = now + Duration::from_millis(500);
823                
824                trace!("Coordination failed, retrying round {} (attempt {})", 
825                       coord.round, coord.retry_count + 1);
826                true
827            } else {
828                coord.state = CoordinationPhase::Failed;
829                trace!("Coordination failed after {} attempts", coord.retry_count);
830                false
831            }
832        } else {
833            false
834        }
835    }
836
837    /// Check if the current coordination round has timed out
838    pub(super) fn check_coordination_timeout(&mut self, now: Instant) -> bool {
839        if let Some(coord) = &mut self.coordination {
840            let elapsed = now.duration_since(coord.round_start);
841            
842            if elapsed > coord.round_duration {
843                trace!("Coordination round {} timed out after {:?}", coord.round, elapsed);
844                self.handle_coordination_failure(now);
845                true
846            } else {
847                false
848            }
849        } else {
850            false
851        }
852    }
853
854    /// Check if coordination round has timed out
855    pub(super) fn is_coordination_expired(&self, now: Instant) -> bool {
856        self.coordination.as_ref()
857            .map_or(false, |c| now.duration_since(c.round_start) > c.round_duration)
858    }
859
860    /// Complete coordination round
861    pub(super) fn complete_coordination(&mut self) {
862        self.coordination = None;
863    }
864}
865
/// Failure reasons reported by the NAT traversal state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NatTraversalError {
    /// The remote candidate limit has been reached
    TooManyCandidates,
    /// The address is already registered under another sequence number
    DuplicateAddress,
    /// No candidate matches the given sequence number or address
    UnknownCandidate,
    /// The candidate's current state does not allow the requested operation
    InvalidCandidateState,
    /// No validation is in flight for the address
    NoActiveValidation,
    /// The response did not echo the challenge that was sent
    ChallengeMismatch,
    /// There is no coordination round in progress
    NoActiveCoordination,
}
884
885impl std::fmt::Display for NatTraversalError {
886    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
887        match self {
888            Self::TooManyCandidates => write!(f, "too many candidates"),
889            Self::DuplicateAddress => write!(f, "duplicate address"),
890            Self::UnknownCandidate => write!(f, "unknown candidate"),
891            Self::InvalidCandidateState => write!(f, "invalid candidate state"),
892            Self::NoActiveValidation => write!(f, "no active validation"),
893            Self::ChallengeMismatch => write!(f, "challenge mismatch"),
894            Self::NoActiveCoordination => write!(f, "no active coordination"),
895        }
896    }
897}
898
899impl std::error::Error for NatTraversalError {}
900
// Unit tests live in separate files and are compiled only for test builds.
#[cfg(test)]
#[path = "nat_traversal_tests.rs"]
mod nat_traversal_tests;

#[cfg(test)]
mod tests;