ant_quic/nat_traversal/
rfc_migration.rs1use crate::{
8 TransportError, VarInt,
9 frame::{Frame, FrameType},
10};
11use std::net::SocketAddr;
12
13#[derive(Debug, Clone)]
15pub struct NatMigrationConfig {
16 pub accept_legacy_frames: bool,
18 pub send_rfc_frames: bool,
20 pub priority_strategy: PriorityCalculation,
22}
23
24impl Default for NatMigrationConfig {
25 fn default() -> Self {
26 Self {
27 accept_legacy_frames: true,
29 send_rfc_frames: false,
30 priority_strategy: PriorityCalculation::IceLike,
31 }
32 }
33}
34
/// Strategy used to derive a priority value for an advertised address.
///
/// Also implements `PartialEq`/`Eq`/`Hash` so callers can compare
/// configured strategies (e.g. in tests or when diffing configs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PriorityCalculation {
    /// RFC 8445-style formula: type preference by address scope, local
    /// preference by address family, fixed component id.
    IceLike,
    /// Fixed per-family / per-scope score table.
    Simple,
    /// Always use the contained value, whatever the address.
    Fixed(u32),
}
45
46impl NatMigrationConfig {
47 pub fn rfc_compliant() -> Self {
49 Self {
50 accept_legacy_frames: false,
51 send_rfc_frames: true,
52 priority_strategy: PriorityCalculation::IceLike,
53 }
54 }
55
56 pub fn legacy_only() -> Self {
58 Self {
59 accept_legacy_frames: true,
60 send_rfc_frames: false,
61 priority_strategy: PriorityCalculation::IceLike,
62 }
63 }
64}
65
66pub fn calculate_address_priority(addr: &SocketAddr, strategy: PriorityCalculation) -> u32 {
68 match strategy {
69 PriorityCalculation::Fixed(p) => p,
70 PriorityCalculation::Simple => simple_priority(addr),
71 PriorityCalculation::IceLike => ice_like_priority(addr),
72 }
73}
74
/// Scores `addr` from a fixed table; larger means more preferred.
///
/// * IPv4: public 300 > private (RFC 1918) 200 > link-local 150 > loopback 100
/// * IPv6: public 250 > unique-local 200 > link-local 150 > loopback 50
///
/// Link-local (169.254.0.0/16) and IPv6 unique-local (fc00::/7) addresses are
/// classified explicitly. Otherwise they fall through to the "public" bucket,
/// which would rank non-globally-routable addresses above everything else.
fn simple_priority(addr: &SocketAddr) -> u32 {
    match addr {
        SocketAddr::V4(v4) => {
            let ip = v4.ip();
            if ip.is_loopback() {
                100
            } else if ip.is_link_local() {
                // Mirrors the IPv6 link-local score below.
                150
            } else if ip.is_private() {
                200
            } else {
                300
            }
        }
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            if ip.is_loopback() {
                50
            } else if ip.is_unicast_link_local() {
                150
            } else if ip.is_unique_local() {
                // IPv6 analogue of RFC 1918 space; same score as IPv4 private.
                200
            } else {
                250
            }
        }
    }
}
100
/// Computes an ICE-style priority using the RFC 8445 recommended formula:
/// `2^24 * type_pref + 2^8 * local_pref + (256 - component_id)`.
///
/// * `type_pref` comes from the address scope: public 126 (IPv4) / 120 (IPv6),
///   private / unique-local 100, link-local 90, loopback 0.
/// * `local_pref` slightly favours IPv4 (65535) over IPv6 (65534).
/// * The component id is fixed at 1.
///
/// The largest possible value (126·2^24 + 65535·2^8 + 255) fits in a `u32`.
fn ice_like_priority(addr: &SocketAddr) -> u32 {
    const COMPONENT_ID: u32 = 1;

    let (type_pref, local_pref): (u32, u32) = match addr {
        SocketAddr::V4(v4) => {
            let ip = v4.ip();
            let type_pref = if ip.is_loopback() {
                0
            } else if ip.is_link_local() {
                // 169.254.0.0/16 is not routable; rank it like IPv6 link-local
                // instead of letting it fall into the public bucket.
                90
            } else if ip.is_private() {
                100
            } else {
                126
            };
            (type_pref, 65535)
        }
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            let type_pref = if ip.is_loopback() {
                0
            } else if ip.is_unicast_link_local() {
                90
            } else if ip.is_unique_local() {
                // fc00::/7 is the IPv6 counterpart of RFC 1918 private space.
                100
            } else {
                120
            };
            (type_pref, 65534)
        }
    };

    (type_pref << 24) + (local_pref << 8) + (256 - COMPONENT_ID)
}
140
141pub struct FrameMigrator {
143 config: NatMigrationConfig,
144}
145
146impl FrameMigrator {
147 pub fn new(config: NatMigrationConfig) -> Self {
148 Self { config }
149 }
150
151 pub fn should_send_rfc_frames(&self) -> bool {
153 self.config.send_rfc_frames
154 }
155
156 pub fn process_incoming_frame(
158 &self,
159 _frame_type: FrameType,
160 frame: Frame,
161 _sender_addr: SocketAddr,
162 ) -> Result<Frame, TransportError> {
163 match frame {
164 Frame::AddAddress(mut add) => {
165 if add.priority == VarInt::from_u32(0) {
167 add.priority = VarInt::from_u32(calculate_address_priority(
168 &add.address,
169 self.config.priority_strategy,
170 ));
171 }
172 Ok(Frame::AddAddress(add))
173 }
174 Frame::PunchMeNow(punch) => {
175 Ok(Frame::PunchMeNow(punch))
177 }
178 _ => Ok(frame),
179 }
180 }
181
182 pub fn should_accept_frame(&self, frame_type: FrameType) -> bool {
184 if self.config.accept_legacy_frames {
185 true
187 } else {
188 matches!(
190 frame_type,
191 FrameType::ADD_ADDRESS_IPV4
192 | FrameType::ADD_ADDRESS_IPV6
193 | FrameType::PUNCH_ME_NOW_IPV4
194 | FrameType::PUNCH_ME_NOW_IPV6
195 | FrameType::REMOVE_ADDRESS
196 )
197 }
198 }
199}
200
/// Capability record for a single peer.
#[derive(Debug, Clone)]
pub struct PeerCapabilities {
    /// Peer identifier; the same bytes are used as the key in `CapabilityTracker`.
    pub peer_id: Vec<u8>,
    /// Whether the peer supports RFC NAT traversal (always `true` for records
    /// created by `CapabilityTracker::mark_rfc_capable`).
    pub supports_rfc_nat: bool,
    /// When the record was created; `CapabilityTracker::cleanup_old_entries`
    /// ages entries from this instant.
    pub discovered_at: std::time::Instant,
}

/// Tracks which peers are known to support RFC NAT traversal, keyed by peer id.
#[derive(Debug, Default)]
pub struct CapabilityTracker {
    peers: std::collections::HashMap<Vec<u8>, PeerCapabilities>,
}

impl CapabilityTracker {
    /// Creates an empty tracker; equivalent to `CapabilityTracker::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `peer_id` as RFC-capable.
    ///
    /// Any existing record is replaced, which resets its `discovered_at` to now.
    pub fn mark_rfc_capable(&mut self, peer_id: Vec<u8>) {
        // The map key and the stored record each need their own owned copy.
        self.peers.insert(
            peer_id.clone(),
            PeerCapabilities {
                peer_id,
                supports_rfc_nat: true,
                discovered_at: std::time::Instant::now(),
            },
        );
    }

    /// Returns `true` only if `peer_id` has a record with `supports_rfc_nat` set.
    pub fn is_rfc_capable(&self, peer_id: &[u8]) -> bool {
        self.peers
            .get(peer_id)
            .is_some_and(|caps| caps.supports_rfc_nat)
    }

    /// Removes records whose age is at least `max_age`.
    ///
    /// Ages are measured against a single `now` snapshot, so one call applies
    /// a consistent cutoff to every entry.
    pub fn cleanup_old_entries(&mut self, max_age: std::time::Duration) {
        let now = std::time::Instant::now();
        self.peers
            .retain(|_, caps| now.duration_since(caps.discovered_at) < max_age);
    }
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a fixture socket address.
    fn sock(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_priority_calculation() {
        let public = sock("8.8.8.8:53");
        let private = sock("192.168.1.1:80");
        let loopback = sock("127.0.0.1:8080");

        // Both scoring strategies must rank public > private > loopback.
        for strategy in [PriorityCalculation::Simple, PriorityCalculation::IceLike] {
            let score = |a: &SocketAddr| calculate_address_priority(a, strategy);
            assert!(score(&public) > score(&private));
            assert!(score(&private) > score(&loopback));
        }

        // A fixed strategy returns its value regardless of the address.
        assert_eq!(
            calculate_address_priority(&public, PriorityCalculation::Fixed(12345)),
            12345
        );
    }

    #[test]
    fn test_migration_configs() {
        // (config, expected accept_legacy_frames, expected send_rfc_frames)
        let cases = [
            (NatMigrationConfig::default(), true, false),
            (NatMigrationConfig::rfc_compliant(), false, true),
            (NatMigrationConfig::legacy_only(), true, false),
        ];
        for (config, accepts_legacy, sends_rfc) in cases {
            assert_eq!(config.accept_legacy_frames, accepts_legacy);
            assert_eq!(config.send_rfc_frames, sends_rfc);
        }
    }

    #[test]
    fn test_capability_tracker() {
        let mut tracker = CapabilityTracker::new();
        let id: Vec<u8> = vec![1, 2, 3, 4];

        assert!(!tracker.is_rfc_capable(&id));

        tracker.mark_rfc_capable(id.clone());
        assert!(tracker.is_rfc_capable(&id));

        // A freshly marked entry survives a one-hour cutoff.
        tracker.cleanup_old_entries(std::time::Duration::from_secs(3600));
        assert!(tracker.is_rfc_capable(&id));
    }
}