// ant_quic/nat_traversal/rfc_migration.rs

//! RFC migration strategy for NAT traversal.
//!
//! This module provides a migration path from the current (legacy) frame
//! encoding to RFC-compliant NAT traversal frames, while maintaining backward
//! compatibility and preserving essential functionality such as
//! priority-based candidate selection.
6
7use crate::{
8    TransportError, VarInt,
9    frame::{Frame, FrameType},
10};
11use std::net::SocketAddr;
12
13/// Migration configuration for NAT traversal
14#[derive(Debug, Clone)]
15pub struct NatMigrationConfig {
16    /// Whether to accept old format frames
17    pub accept_legacy_frames: bool,
18    /// Whether to send RFC-compliant frames
19    pub send_rfc_frames: bool,
20    /// Default priority calculation strategy
21    pub priority_strategy: PriorityCalculation,
22}
23
24impl Default for NatMigrationConfig {
25    fn default() -> Self {
26        Self {
27            // Start in compatibility mode
28            accept_legacy_frames: true,
29            send_rfc_frames: false,
30            priority_strategy: PriorityCalculation::IceLike,
31        }
32    }
33}
34
/// Strategy for assigning a priority to a candidate address.
///
/// Higher values mean the address is preferred.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PriorityCalculation {
    /// ICE-style priority,
    /// `(2^24) * type_pref + (2^8) * local_pref + (256 - component_id)`,
    /// with preferences derived from the address scope and IP family.
    IceLike,
    /// Small fixed values derived from the address scope and IP family.
    Simple,
    /// The same caller-supplied priority for every address.
    Fixed(u32),
}
45
46impl NatMigrationConfig {
47    /// Create a config for full RFC compliance
48    pub fn rfc_compliant() -> Self {
49        Self {
50            accept_legacy_frames: false,
51            send_rfc_frames: true,
52            priority_strategy: PriorityCalculation::IceLike,
53        }
54    }
55
56    /// Create a config for legacy mode
57    pub fn legacy_only() -> Self {
58        Self {
59            accept_legacy_frames: true,
60            send_rfc_frames: false,
61            priority_strategy: PriorityCalculation::IceLike,
62        }
63    }
64}
65
66/// Calculate priority for an address based on its characteristics
67pub fn calculate_address_priority(addr: &SocketAddr, strategy: PriorityCalculation) -> u32 {
68    match strategy {
69        PriorityCalculation::Fixed(p) => p,
70        PriorityCalculation::Simple => simple_priority(addr),
71        PriorityCalculation::IceLike => ice_like_priority(addr),
72    }
73}
74
/// Simple priority calculation based on address scope and IP family.
///
/// Resulting tiers (higher is preferred):
/// public IPv4 (300) > public IPv6 (250) > private IPv4 (200) >
/// unique-local IPv6 (190) > link-local IPv4/IPv6 (150) >
/// IPv4 loopback (100) > IPv6 loopback (50).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are ranked as the IPv4
/// address they wrap, so a peer seen through a dual-stack socket gets the
/// same priority as when seen through an IPv4 socket.
fn simple_priority(addr: &SocketAddr) -> u32 {
    // Priority for an IPv4 address (also used for IPv4-mapped IPv6).
    fn ipv4_priority(ip: &std::net::Ipv4Addr) -> u32 {
        if ip.is_loopback() {
            100 // Lowest IPv4 tier
        } else if ip.is_private() {
            200 // RFC 1918 private ranges
        } else if ip.is_link_local() {
            // 169.254.0.0/16 is not routable beyond the local link and is not
            // covered by `is_private`, so it must not fall through to "public".
            150
        } else {
            300 // Highest
        }
    }

    match addr {
        SocketAddr::V4(v4) => ipv4_priority(v4.ip()),
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            if let Some(mapped) = ip.to_ipv4_mapped() {
                ipv4_priority(&mapped)
            } else if ip.is_loopback() {
                50 // Lower than IPv4 loopback
            } else if ip.is_unicast_link_local() {
                150 // fe80::/10
            } else if ip.is_unique_local() {
                // fc00::/7, the IPv6 analogue of RFC 1918 space; slightly
                // below private IPv4, mirroring the other IPv6 tiers.
                190
            } else {
                250 // Slightly lower than public IPv4
            }
        }
    }
}
100
/// ICE-style priority calculation (RFC 8445 Section 5.1.2.1, formerly
/// RFC 5245 Section 4.1.2.1):
///
/// `priority = (2^24) * type_pref + (2^8) * local_pref + (256 - component_id)`
///
/// Only the address is known here, not how the candidate was discovered, so
/// the type preference is derived from the address scope:
///
/// | scope                        | type_pref |
/// |------------------------------|-----------|
/// | public IPv4                  | 126       |
/// | public IPv6                  | 120       |
/// | private IPv4 / unique-local  | 100       |
/// | link-local (IPv4 or IPv6)    | 90        |
/// | loopback                     | 0         |
///
/// The local preference is 65535 for IPv4 and 65534 for IPv6, so IPv4 wins
/// ties. IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are ranked as the IPv4
/// address they wrap. QUIC has a single component, so `component_id` is 1.
fn ice_like_priority(addr: &SocketAddr) -> u32 {
    // Type preference for an IPv4 address (also used for IPv4-mapped IPv6).
    fn ipv4_type_pref(ip: &std::net::Ipv4Addr) -> u32 {
        if ip.is_loopback() {
            0
        } else if ip.is_private() {
            100
        } else if ip.is_link_local() {
            // 169.254.0.0/16 is not routable beyond the local link and is not
            // covered by `is_private`, so it must not fall through to "public".
            90
        } else {
            126
        }
    }

    const IPV4_LOCAL_PREF: u32 = 65_535;
    const IPV6_LOCAL_PREF: u32 = 65_534;
    const COMPONENT_ID: u32 = 1;

    let (type_pref, local_pref) = match addr {
        SocketAddr::V4(v4) => (ipv4_type_pref(v4.ip()), IPV4_LOCAL_PREF),
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            match ip.to_ipv4_mapped() {
                // Dual-stack sockets typically report IPv4 peers in mapped form.
                Some(mapped) => (ipv4_type_pref(&mapped), IPV4_LOCAL_PREF),
                None => {
                    let type_pref = if ip.is_loopback() {
                        0
                    } else if ip.is_unicast_link_local() {
                        90
                    } else if ip.is_unique_local() {
                        // fc00::/7, the IPv6 analogue of RFC 1918 space.
                        100
                    } else {
                        120
                    };
                    (type_pref, IPV6_LOCAL_PREF)
                }
            }
        }
    };

    // Cannot overflow: the maximum is 126 * 2^24 + 65535 * 2^8 + 255 < 2^31.
    (type_pref << 24) + (local_pref << 8) + (256 - COMPONENT_ID)
}
140
141/// Frame conversion wrapper for migration
142pub struct FrameMigrator {
143    config: NatMigrationConfig,
144}
145
146impl FrameMigrator {
147    pub fn new(config: NatMigrationConfig) -> Self {
148        Self { config }
149    }
150
151    /// Check if we should send RFC frames based on configuration
152    pub fn should_send_rfc_frames(&self) -> bool {
153        self.config.send_rfc_frames
154    }
155
156    /// Process incoming frames based on configuration
157    pub fn process_incoming_frame(
158        &self,
159        _frame_type: FrameType,
160        frame: Frame,
161        _sender_addr: SocketAddr,
162    ) -> Result<Frame, TransportError> {
163        match frame {
164            Frame::AddAddress(mut add) => {
165                // If we received an RFC frame (no priority), calculate it
166                if add.priority == VarInt::from_u32(0) {
167                    add.priority = VarInt::from_u32(calculate_address_priority(
168                        &add.address,
169                        self.config.priority_strategy,
170                    ));
171                }
172                Ok(Frame::AddAddress(add))
173            }
174            Frame::PunchMeNow(punch) => {
175                // Handle both formats
176                Ok(Frame::PunchMeNow(punch))
177            }
178            _ => Ok(frame),
179        }
180    }
181
182    /// Check if we should accept this frame type
183    pub fn should_accept_frame(&self, frame_type: FrameType) -> bool {
184        if self.config.accept_legacy_frames {
185            // Accept all NAT traversal frames
186            true
187        } else {
188            // Only accept RFC-compliant frame types
189            matches!(
190                frame_type,
191                FrameType::ADD_ADDRESS_IPV4
192                    | FrameType::ADD_ADDRESS_IPV6
193                    | FrameType::PUNCH_ME_NOW_IPV4
194                    | FrameType::PUNCH_ME_NOW_IPV6
195                    | FrameType::REMOVE_ADDRESS
196            )
197        }
198    }
199}
200
/// What is known about one peer's support for RFC NAT traversal frames.
#[derive(Debug, Clone)]
pub struct PeerCapabilities {
    /// The peer's connection ID (raw bytes).
    pub peer_id: Vec<u8>,
    /// `true` if the peer supports the RFC NAT traversal frames.
    pub supports_rfc_nat: bool,
    /// When this capability was recorded; used to expire stale entries.
    pub discovered_at: std::time::Instant,
}
211
212/// Tracks peer capabilities for gradual migration
213pub struct CapabilityTracker {
214    peers: std::collections::HashMap<Vec<u8>, PeerCapabilities>,
215}
216
217impl CapabilityTracker {
218    pub fn new() -> Self {
219        Self {
220            peers: std::collections::HashMap::new(),
221        }
222    }
223
224    /// Record that a peer supports RFC frames
225    pub fn mark_rfc_capable(&mut self, peer_id: Vec<u8>) {
226        self.peers.insert(
227            peer_id.clone(),
228            PeerCapabilities {
229                peer_id,
230                supports_rfc_nat: true,
231                discovered_at: std::time::Instant::now(),
232            },
233        );
234    }
235
236    /// Check if a peer supports RFC frames
237    pub fn is_rfc_capable(&self, peer_id: &[u8]) -> bool {
238        self.peers
239            .get(peer_id)
240            .map(|cap| cap.supports_rfc_nat)
241            .unwrap_or(false)
242    }
243
244    /// Clean up old entries
245    pub fn cleanup_old_entries(&mut self, max_age: std::time::Duration) {
246        let now = std::time::Instant::now();
247        self.peers
248            .retain(|_, cap| now.duration_since(cap.discovered_at) < max_age);
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket address literal used by these tests.
    fn addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_priority_calculation() {
        let public_v4 = addr("8.8.8.8:53");
        let private_v4 = addr("192.168.1.1:80");
        let loopback_v4 = addr("127.0.0.1:8080");

        // Both scope-based strategies must rank public > private > loopback.
        for strategy in [PriorityCalculation::Simple, PriorityCalculation::IceLike] {
            let public = calculate_address_priority(&public_v4, strategy);
            let private = calculate_address_priority(&private_v4, strategy);
            let loopback = calculate_address_priority(&loopback_v4, strategy);
            assert!(public > private);
            assert!(private > loopback);
        }

        // A fixed strategy ignores the address entirely.
        let fixed = calculate_address_priority(&public_v4, PriorityCalculation::Fixed(12345));
        assert_eq!(fixed, 12345);
    }

    #[test]
    fn test_migration_configs() {
        // (config, expected accept_legacy_frames, expected send_rfc_frames)
        let cases = [
            (NatMigrationConfig::default(), true, false),
            (NatMigrationConfig::rfc_compliant(), false, true),
            (NatMigrationConfig::legacy_only(), true, false),
        ];
        for (config, accepts_legacy, sends_rfc) in cases {
            assert_eq!(config.accept_legacy_frames, accepts_legacy);
            assert_eq!(config.send_rfc_frames, sends_rfc);
        }
    }

    #[test]
    fn test_capability_tracker() {
        let peer_id = vec![1, 2, 3, 4];
        let mut tracker = CapabilityTracker::new();
        assert!(!tracker.is_rfc_capable(&peer_id));

        tracker.mark_rfc_capable(peer_id.clone());
        assert!(tracker.is_rfc_capable(&peer_id));

        // An entry recorded moments ago is well inside a one-hour window.
        tracker.cleanup_old_entries(std::time::Duration::from_secs(3600));
        assert!(tracker.is_rfc_capable(&peer_id));
    }
}