Skip to main content

ant_quic/
connection_strategy.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
//! Connection strategy state machine for progressive NAT traversal fallback.
//!
//! This module implements a state machine that tells the caller which connection
//! technique to try next, falling back through progressively more aggressive NAT
//! traversal techniques as the caller reports failures:
//!
//! 1. **Direct IPv4** - Simple direct connection (fastest when both peers have public IPv4)
//! 2. **Direct IPv6** - Many ISPs have native IPv6 even behind CGNAT
//! 3. **Hole-Punch** - Coordinated NAT traversal via a common peer
//! 4. **Relay** - MASQUE CONNECT-UDP relay (last resort; works whenever the relay itself is reachable)
//!
//! Stages that are disabled, or that lack a configured coordinator/relay address, are skipped.
//!
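//! # Example
//!
//! ```rust,ignore
//! let config = StrategyConfig::default();
//! let mut strategy = ConnectionStrategy::new(config);
//!
//! // The strategy only tracks state: every non-terminal arm must report the outcome
//! // (via `mark_connected` or a `transition_to_*` call), otherwise the loop never advances.
//! loop {
//!     match strategy.current_stage() {
//!         ConnectionStage::DirectIPv4 { .. } => {
//!             // Try direct IPv4; on failure call `strategy.transition_to_ipv6(err)`
//!         }
//!         ConnectionStage::DirectIPv6 { .. } => {
//!             // Try direct IPv6; on failure call `strategy.transition_to_holepunch(err)`
//!         }
//!         ConnectionStage::HolePunching { .. } => {
//!             // Coordinate hole-punching via common peer; on failure either retry
//!             // (`record_holepunch_error` + `increment_round` while
//!             // `should_retry_holepunch()` is true) or call `strategy.transition_to_relay(err)`
//!         }
//!         ConnectionStage::Relay { .. } => {
//!             // Connect via MASQUE relay; on failure call `strategy.transition_to_failed(err)`
//!         }
//!         ConnectionStage::Connected { via } => {
//!             println!("Connected via {:?}", via);
//!             break;
//!         }
//!         ConnectionStage::Failed { errors } => {
//!             eprintln!("All strategies failed: {:?}", errors);
//!             break;
//!         }
//!     }
//! }
//! ```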
49
50use std::net::SocketAddr;
51use std::time::{Duration, Instant};
52
/// The technique that ultimately produced a working connection.
///
/// Carried by [`ConnectionStage::Connected`] once a stage succeeds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConnectionMethod {
    /// A plain IPv4 connection worked.
    DirectIPv4,
    /// A plain IPv6 connection worked (bypassing NAT).
    DirectIPv6,
    /// Coordinated hole-punching worked.
    HolePunched {
        /// Address of the peer that coordinated the hole-punch.
        coordinator: SocketAddr,
    },
    /// Traffic flows through a relay server.
    Relayed {
        /// Address of the relay server in use.
        relay: SocketAddr,
    },
}
71
72impl std::fmt::Display for ConnectionMethod {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        match self {
75            ConnectionMethod::DirectIPv4 => write!(f, "Direct IPv4"),
76            ConnectionMethod::DirectIPv6 => write!(f, "Direct IPv6"),
77            ConnectionMethod::HolePunched { coordinator } => {
78                write!(f, "Hole-punched via {}", coordinator)
79            }
80            ConnectionMethod::Relayed { relay } => write!(f, "Relayed via {}", relay),
81        }
82    }
83}
84
85/// Error that occurred during a connection attempt
86#[derive(Debug, Clone)]
87pub struct ConnectionAttemptError {
88    /// The method that was attempted
89    pub method: AttemptedMethod,
90    /// Description of the error
91    pub error: String,
92    /// When the attempt was made
93    pub timestamp: Instant,
94}
95
/// Identifies which connection technique a recorded error belongs to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AttemptedMethod {
    /// A direct IPv4 attempt.
    DirectIPv4,
    /// A direct IPv6 attempt.
    DirectIPv6,
    /// One round of coordinated hole-punching.
    HolePunch {
        /// Which round failed (rounds are numbered from 1).
        round: u32,
    },
    /// A relay attempt.
    Relay,
}
111
112/// Current stage of the connection strategy
113#[derive(Debug, Clone)]
114pub enum ConnectionStage {
115    /// Attempting direct IPv4 connection
116    DirectIPv4 {
117        /// When this stage started
118        started: Instant,
119    },
120    /// Attempting direct IPv6 connection
121    DirectIPv6 {
122        /// When this stage started
123        started: Instant,
124    },
125    /// Attempting hole-punching via a coordinator
126    HolePunching {
127        /// The coordinator peer address
128        coordinator: SocketAddr,
129        /// Current hole-punch round (starts at 1)
130        round: u32,
131        /// When this stage started
132        started: Instant,
133    },
134    /// Attempting relay connection
135    Relay {
136        /// The relay server address
137        relay_addr: SocketAddr,
138        /// When this stage started
139        started: Instant,
140    },
141    /// Successfully connected
142    Connected {
143        /// How the connection was established
144        via: ConnectionMethod,
145    },
146    /// All methods failed
147    Failed {
148        /// Errors from all attempted methods
149        errors: Vec<ConnectionAttemptError>,
150    },
151}
152
/// Tunables for [`ConnectionStrategy`]: per-stage timeouts, feature toggles and
/// the peer addresses used by the later fallback stages.
#[derive(Debug, Clone)]
pub struct StrategyConfig {
    /// Time budget for a direct IPv4 attempt.
    pub ipv4_timeout: Duration,
    /// Time budget for a direct IPv6 attempt.
    pub ipv6_timeout: Duration,
    /// Time budget for a single hole-punch round.
    pub holepunch_timeout: Duration,
    /// Time budget for the relay attempt.
    pub relay_timeout: Duration,
    /// Upper bound on hole-punch rounds before falling back to the relay.
    pub max_holepunch_rounds: u32,
    /// When `false`, the direct IPv6 stage is skipped.
    pub ipv6_enabled: bool,
    /// When `false`, the relay stage is never entered.
    pub relay_enabled: bool,
    /// Peer used to coordinate hole-punching; `None` skips that stage.
    pub coordinator: Option<SocketAddr>,
    /// Relay server address, if one is available.
    pub relay_addr: Option<SocketAddr>,
}
175
176impl Default for StrategyConfig {
177    fn default() -> Self {
178        Self {
179            ipv4_timeout: Duration::from_secs(5),
180            ipv6_timeout: Duration::from_secs(5),
181            holepunch_timeout: Duration::from_secs(15),
182            relay_timeout: Duration::from_secs(30),
183            max_holepunch_rounds: 3,
184            ipv6_enabled: true,
185            relay_enabled: true,
186            coordinator: None,
187            relay_addr: None,
188        }
189    }
190}
191
192impl StrategyConfig {
193    /// Create a new strategy config with default values
194    pub fn new() -> Self {
195        Self::default()
196    }
197
198    /// Set the IPv4 timeout
199    pub fn with_ipv4_timeout(mut self, timeout: Duration) -> Self {
200        self.ipv4_timeout = timeout;
201        self
202    }
203
204    /// Set the IPv6 timeout
205    pub fn with_ipv6_timeout(mut self, timeout: Duration) -> Self {
206        self.ipv6_timeout = timeout;
207        self
208    }
209
210    /// Set the hole-punch timeout
211    pub fn with_holepunch_timeout(mut self, timeout: Duration) -> Self {
212        self.holepunch_timeout = timeout;
213        self
214    }
215
216    /// Set the relay timeout
217    pub fn with_relay_timeout(mut self, timeout: Duration) -> Self {
218        self.relay_timeout = timeout;
219        self
220    }
221
222    /// Set the maximum number of hole-punch rounds
223    pub fn with_max_holepunch_rounds(mut self, rounds: u32) -> Self {
224        self.max_holepunch_rounds = rounds;
225        self
226    }
227
228    /// Enable or disable IPv6 attempts
229    pub fn with_ipv6_enabled(mut self, enabled: bool) -> Self {
230        self.ipv6_enabled = enabled;
231        self
232    }
233
234    /// Enable or disable relay fallback
235    pub fn with_relay_enabled(mut self, enabled: bool) -> Self {
236        self.relay_enabled = enabled;
237        self
238    }
239
240    /// Set the coordinator address for hole-punching
241    pub fn with_coordinator(mut self, addr: SocketAddr) -> Self {
242        self.coordinator = Some(addr);
243        self
244    }
245
246    /// Set the relay server address
247    pub fn with_relay(mut self, addr: SocketAddr) -> Self {
248        self.relay_addr = Some(addr);
249        self
250    }
251}
252
253/// Connection strategy state machine
254///
255/// Manages the progression through connection methods from fastest (direct)
256/// to most reliable (relay).
257#[derive(Debug)]
258pub struct ConnectionStrategy {
259    stage: ConnectionStage,
260    config: StrategyConfig,
261    errors: Vec<ConnectionAttemptError>,
262}
263
264impl ConnectionStrategy {
265    /// Create a new connection strategy with the given configuration
266    pub fn new(config: StrategyConfig) -> Self {
267        Self {
268            stage: ConnectionStage::DirectIPv4 {
269                started: Instant::now(),
270            },
271            config,
272            errors: Vec::new(),
273        }
274    }
275
276    /// Get the current stage
277    pub fn current_stage(&self) -> &ConnectionStage {
278        &self.stage
279    }
280
281    /// Get the configuration
282    pub fn config(&self) -> &StrategyConfig {
283        &self.config
284    }
285
286    /// Get the IPv4 timeout
287    pub fn ipv4_timeout(&self) -> Duration {
288        self.config.ipv4_timeout
289    }
290
291    /// Get the IPv6 timeout
292    pub fn ipv6_timeout(&self) -> Duration {
293        self.config.ipv6_timeout
294    }
295
296    /// Get the hole-punch timeout
297    pub fn holepunch_timeout(&self) -> Duration {
298        self.config.holepunch_timeout
299    }
300
301    /// Get the relay timeout
302    pub fn relay_timeout(&self) -> Duration {
303        self.config.relay_timeout
304    }
305
306    /// Record an error and transition to IPv6 stage
307    pub fn transition_to_ipv6(&mut self, error: impl Into<String>) {
308        self.errors.push(ConnectionAttemptError {
309            method: AttemptedMethod::DirectIPv4,
310            error: error.into(),
311            timestamp: Instant::now(),
312        });
313
314        if self.config.ipv6_enabled {
315            self.stage = ConnectionStage::DirectIPv6 {
316                started: Instant::now(),
317            };
318        } else {
319            self.transition_to_holepunch_internal();
320        }
321    }
322
323    /// Record an error and transition to hole-punching stage
324    pub fn transition_to_holepunch(&mut self, error: impl Into<String>) {
325        self.errors.push(ConnectionAttemptError {
326            method: AttemptedMethod::DirectIPv6,
327            error: error.into(),
328            timestamp: Instant::now(),
329        });
330        self.transition_to_holepunch_internal();
331    }
332
333    fn transition_to_holepunch_internal(&mut self) {
334        if let Some(coordinator) = self.config.coordinator {
335            self.stage = ConnectionStage::HolePunching {
336                coordinator,
337                round: 1,
338                started: Instant::now(),
339            };
340        } else {
341            // No coordinator available, skip to relay
342            self.transition_to_relay_internal();
343        }
344    }
345
346    /// Record a hole-punch error and either retry or transition to relay
347    pub fn record_holepunch_error(&mut self, round: u32, error: impl Into<String>) {
348        self.errors.push(ConnectionAttemptError {
349            method: AttemptedMethod::HolePunch { round },
350            error: error.into(),
351            timestamp: Instant::now(),
352        });
353    }
354
355    /// Check if we should retry hole-punching
356    pub fn should_retry_holepunch(&self) -> bool {
357        if let ConnectionStage::HolePunching { round, .. } = &self.stage {
358            *round < self.config.max_holepunch_rounds
359        } else {
360            false
361        }
362    }
363
364    /// Increment the hole-punch round
365    pub fn increment_round(&mut self) {
366        if let ConnectionStage::HolePunching {
367            coordinator, round, ..
368        } = &self.stage
369        {
370            self.stage = ConnectionStage::HolePunching {
371                coordinator: *coordinator,
372                round: round + 1,
373                started: Instant::now(),
374            };
375        }
376    }
377
378    /// Transition to relay stage
379    pub fn transition_to_relay(&mut self, error: impl Into<String>) {
380        if let ConnectionStage::HolePunching { round, .. } = &self.stage {
381            self.errors.push(ConnectionAttemptError {
382                method: AttemptedMethod::HolePunch { round: *round },
383                error: error.into(),
384                timestamp: Instant::now(),
385            });
386        }
387        self.transition_to_relay_internal();
388    }
389
390    fn transition_to_relay_internal(&mut self) {
391        if self.config.relay_enabled {
392            if let Some(relay_addr) = self.config.relay_addr {
393                self.stage = ConnectionStage::Relay {
394                    relay_addr,
395                    started: Instant::now(),
396                };
397            } else {
398                // No relay available
399                self.transition_to_failed("No relay server configured");
400            }
401        } else {
402            self.transition_to_failed("Relay disabled and all other methods failed");
403        }
404    }
405
406    /// Transition to failed state
407    pub fn transition_to_failed(&mut self, error: impl Into<String>) {
408        // Record the final error if we came from relay stage
409        if let ConnectionStage::Relay { .. } = &self.stage {
410            self.errors.push(ConnectionAttemptError {
411                method: AttemptedMethod::Relay,
412                error: error.into(),
413                timestamp: Instant::now(),
414            });
415        }
416
417        self.stage = ConnectionStage::Failed {
418            errors: std::mem::take(&mut self.errors),
419        };
420    }
421
422    /// Mark connection as successful via the specified method
423    pub fn mark_connected(&mut self, method: ConnectionMethod) {
424        self.stage = ConnectionStage::Connected { via: method };
425    }
426
427    /// Check if the strategy has reached a terminal state
428    pub fn is_terminal(&self) -> bool {
429        matches!(
430            self.stage,
431            ConnectionStage::Connected { .. } | ConnectionStage::Failed { .. }
432        )
433    }
434
435    /// Get all recorded errors
436    pub fn errors(&self) -> &[ConnectionAttemptError] {
437        &self.errors
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a socket-address literal used as a test fixture.
    fn addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Build a strategy and report failures for both direct stages.
    fn past_direct_stages(config: StrategyConfig) -> ConnectionStrategy {
        let mut strategy = ConnectionStrategy::new(config);
        strategy.transition_to_ipv6("IPv4 failed");
        strategy.transition_to_holepunch("IPv6 failed");
        strategy
    }

    /// Current hole-punch round, or `None` when not hole-punching.
    fn holepunch_round(strategy: &ConnectionStrategy) -> Option<u32> {
        match strategy.current_stage() {
            ConnectionStage::HolePunching { round, .. } => Some(*round),
            _ => None,
        }
    }

    #[test]
    fn test_default_config() {
        let StrategyConfig {
            ipv4_timeout,
            ipv6_timeout,
            holepunch_timeout,
            relay_timeout,
            max_holepunch_rounds,
            ipv6_enabled,
            relay_enabled,
            ..
        } = StrategyConfig::default();
        assert_eq!(ipv4_timeout, Duration::from_secs(5));
        assert_eq!(ipv6_timeout, Duration::from_secs(5));
        assert_eq!(holepunch_timeout, Duration::from_secs(15));
        assert_eq!(relay_timeout, Duration::from_secs(30));
        assert_eq!(max_holepunch_rounds, 3);
        assert!(ipv6_enabled);
        assert!(relay_enabled);
    }

    #[test]
    fn test_config_builder() {
        let three_secs = Duration::from_secs(3);
        let config = StrategyConfig::new()
            .with_ipv4_timeout(three_secs)
            .with_ipv6_timeout(three_secs)
            .with_max_holepunch_rounds(5)
            .with_ipv6_enabled(false);

        assert_eq!(config.ipv4_timeout, three_secs);
        assert_eq!(config.max_holepunch_rounds, 5);
        assert!(!config.ipv6_enabled);
    }

    #[test]
    fn test_initial_stage() {
        let fresh = ConnectionStrategy::new(StrategyConfig::default());
        assert!(matches!(fresh.current_stage(), ConnectionStage::DirectIPv4 { .. }));
    }

    #[test]
    fn test_transition_ipv4_to_ipv6() {
        let mut strategy = ConnectionStrategy::new(StrategyConfig::default());
        strategy.transition_to_ipv6("Connection refused");

        assert!(matches!(strategy.current_stage(), ConnectionStage::DirectIPv6 { .. }));
        let recorded = strategy.errors();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].method, AttemptedMethod::DirectIPv4);
    }

    #[test]
    fn test_skip_ipv6_when_disabled() {
        let mut strategy = ConnectionStrategy::new(
            StrategyConfig::new()
                .with_ipv6_enabled(false)
                .with_coordinator(addr("127.0.0.1:9000")),
        );

        strategy.transition_to_ipv6("Connection refused");

        // IPv6 is disabled, so the strategy lands directly on hole-punch round 1.
        assert_eq!(holepunch_round(&strategy), Some(1));
    }

    #[test]
    fn test_transition_to_holepunch() {
        let strategy =
            past_direct_stages(StrategyConfig::new().with_coordinator(addr("127.0.0.1:9000")));

        match strategy.current_stage() {
            ConnectionStage::HolePunching {
                coordinator, round, ..
            } => {
                assert_eq!(*round, 1);
                assert_eq!(coordinator.port(), 9000);
            }
            other => panic!("Expected HolePunching stage, got {:?}", other),
        }
    }

    #[test]
    fn test_holepunch_rounds() {
        let mut strategy = past_direct_stages(
            StrategyConfig::new()
                .with_coordinator(addr("127.0.0.1:9000"))
                .with_max_holepunch_rounds(3),
        );

        // Rounds 1 and 2 may each be retried.
        for round in 1..=2u32 {
            assert_eq!(holepunch_round(&strategy), Some(round));
            assert!(strategy.should_retry_holepunch());
            strategy.record_holepunch_error(round, format!("Round {} failed", round));
            strategy.increment_round();
        }

        // Round 3 is the last one allowed.
        assert_eq!(holepunch_round(&strategy), Some(3));
        assert!(!strategy.should_retry_holepunch());
    }

    #[test]
    fn test_transition_to_relay() {
        let mut strategy = past_direct_stages(
            StrategyConfig::new()
                .with_coordinator(addr("127.0.0.1:9000"))
                .with_relay(addr("127.0.0.1:9001")),
        );
        strategy.transition_to_relay("Holepunch failed");

        match strategy.current_stage() {
            ConnectionStage::Relay { relay_addr, .. } => assert_eq!(relay_addr.port(), 9001),
            _ => panic!("Expected Relay stage"),
        }
    }

    #[test]
    fn test_transition_to_failed() {
        let mut strategy = past_direct_stages(
            StrategyConfig::new()
                .with_coordinator(addr("127.0.0.1:9000"))
                .with_relay(addr("127.0.0.1:9001")),
        );
        strategy.transition_to_relay("Holepunch failed");
        strategy.transition_to_failed("Relay failed");

        match strategy.current_stage() {
            ConnectionStage::Failed { errors } => assert_eq!(errors.len(), 4),
            _ => panic!("Expected Failed stage"),
        }
    }

    #[test]
    fn test_mark_connected() {
        let mut strategy = ConnectionStrategy::new(StrategyConfig::default());
        strategy.mark_connected(ConnectionMethod::DirectIPv4);

        match strategy.current_stage() {
            ConnectionStage::Connected { via } => assert_eq!(*via, ConnectionMethod::DirectIPv4),
            _ => panic!("Expected Connected stage"),
        }
        assert!(strategy.is_terminal());
    }

    #[test]
    fn test_connection_method_display() {
        let cases = [
            (ConnectionMethod::DirectIPv4, "Direct IPv4"),
            (ConnectionMethod::DirectIPv6, "Direct IPv6"),
            (
                ConnectionMethod::HolePunched {
                    coordinator: addr("1.2.3.4:9000"),
                },
                "Hole-punched via 1.2.3.4:9000",
            ),
            (
                ConnectionMethod::Relayed {
                    relay: addr("5.6.7.8:9001"),
                },
                "Relayed via 5.6.7.8:9001",
            ),
        ];

        for (method, expected) in cases.iter() {
            assert_eq!(method.to_string(), *expected);
        }
    }

    #[test]
    fn test_no_coordinator_skips_to_relay() {
        // Relay configured but no coordinator: hole-punching must be skipped.
        let strategy = past_direct_stages(StrategyConfig::new().with_relay(addr("127.0.0.1:9001")));

        assert!(matches!(strategy.current_stage(), ConnectionStage::Relay { .. }));
    }

    #[test]
    fn test_no_relay_fails() {
        let mut strategy = past_direct_stages(
            StrategyConfig::new()
                .with_coordinator(addr("127.0.0.1:9000"))
                .with_relay_enabled(false),
        );
        strategy.transition_to_relay("Holepunch failed");

        // Relay is disabled, so giving up on hole-punching ends in failure.
        assert!(matches!(strategy.current_stage(), ConnectionStage::Failed { .. }));
    }
}