1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
10use std::time::Duration;
11
12use get_if_addrs::get_if_addrs;
13use tokio::net::UdpSocket;
14use tokio::sync::mpsc;
15use tokio::time::{interval, timeout};
16
17use rift_core::PeerId;
18use rift_metrics as metrics;
19use tracing::debug;
20use rand::RngCore;
21
22mod turn;
23pub use turn::{
24 TurnCandidate, TurnError, TurnRelay, TurnServerConfig, allocate_turn_relay,
25 parse_turn_server, spawn_turn_keepalive,
26};
27
/// Tunables for NAT traversal: STUN discovery, UDP hole punching and TURN relaying.
#[derive(Debug, Clone)]
pub struct NatConfig {
    /// Local UDP ports to bind for hole punching and STUN queries. When empty, a single
    /// OS-assigned ephemeral port (port 0) is used instead.
    pub local_ports: Vec<u16>,
    /// STUN servers queried (from every local port) to learn public/mapped addresses.
    pub stun_servers: Vec<SocketAddr>,
    /// Timeout, in milliseconds, for each individual STUN binding request.
    pub stun_timeout_ms: u64,
    /// Delay between rounds of punch probes, in milliseconds (clamped to at least 50 ms).
    pub punch_interval_ms: u64,
    /// Overall time to wait for a hole punch to succeed, in milliseconds
    /// (clamped to at least 500 ms).
    pub punch_timeout_ms: u64,
    /// TURN servers tried, one after another, when gathering relay candidates.
    pub turn_servers: Vec<TurnServerConfig>,
    /// Timeout value, in milliseconds, handed to `allocate_turn_relay` for each server.
    pub turn_timeout_ms: u64,
    /// NOTE(review): not read by the code in this section; presumably the interval for
    /// `spawn_turn_keepalive` — confirm against callers.
    pub turn_keepalive_ms: u64,
}
47
/// Coarse NAT classification produced by `detect_nat_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatType {
    /// No public addresses were available to compare against.
    Unknown,
    /// At least one public (STUN-mapped) address equals one of the local addresses.
    OpenInternet,
    /// Public addresses exist but none of them matches a local address.
    Natted,
}
54
/// Addresses at which a remote peer may be reachable, used as hole-punch targets.
#[derive(Debug, Clone)]
pub struct PeerEndpoint {
    /// Identifier of the remote peer (not consulted by `build_target_addrs`).
    pub peer_id: PeerId,
    /// Externally visible addresses of the peer; each one is punched as-is.
    pub external_addrs: Vec<SocketAddr>,
    /// Extra ports to try on the IP of every external address.
    pub punch_ports: Vec<u16>,
}
64
/// Failure modes of `attempt_hole_punch`.
#[derive(Debug, thiserror::Error)]
pub enum HolePunchError {
    /// None of the configured local ports (or the ephemeral fallback) could be bound.
    #[error("no local ports could be bound")]
    NoLocalPorts,
    /// The peer endpoint yielded no addresses to send punch probes to.
    #[error("no remote addresses to punch")]
    NoRemoteAddrs,
    /// No punch probe or reply arrived from the peer before the punch timeout elapsed.
    #[error("timeout while punching")]
    Timeout,
    /// Underlying socket error.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
80
/// Failure modes of STUN-based public address discovery.
#[derive(Debug, thiserror::Error)]
pub enum StunError {
    /// `NatConfig::stun_servers` was empty.
    #[error("no stun servers configured")]
    NoServers,
    /// No STUN server answered (in time).
    #[error("no stun responses received")]
    NoResponses,
    /// A reply could not be decoded as a matching Binding Success Response.
    #[error("invalid stun response")]
    InvalidResponse,
    /// Underlying socket error.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
96
// --- Application payloads carried over punched / kept-alive UDP paths -------------------

/// Probe datagram sent repeatedly to every candidate address while hole punching.
const PUNCH_SYN: &[u8] = b"RIFT_PUNCH";
/// Reply sent back after a punch probe (or a punch reply) has been received.
const PUNCH_ACK: &[u8] = b"RIFT_ACK";
/// Payload sent periodically by `spawn_keepalive`.
const KEEPALIVE_BYTES: &[u8] = b"RIFT_KEEPALIVE";

// --- STUN wire constants (RFC 5389) -----------------------------------------------------

/// Magic cookie carried in bytes 4..8 of every STUN message header.
const STUN_MAGIC_COOKIE: u32 = 0x2112_A442;
/// Message type of a Binding Request.
const STUN_BINDING_REQUEST: u16 = 0x0001;
/// Message type of a Binding Success Response.
const STUN_BINDING_RESPONSE: u16 = 0x0101;
/// Attribute carrying the mapped address in the clear.
const STUN_ATTR_MAPPED_ADDRESS: u16 = 0x0001;
/// Attribute carrying the mapped address XORed with the cookie / transaction id.
const STUN_ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
106
107pub async fn gather_turn_candidates(nat_cfg: &NatConfig) -> Result<Vec<TurnCandidate>, TurnError> {
109 if nat_cfg.turn_servers.is_empty() {
110 return Err(TurnError::NoServers);
111 }
112 let mut out = Vec::new();
113 for server in nat_cfg.turn_servers.clone() {
114 match allocate_turn_relay(server, nat_cfg.turn_timeout_ms).await {
115 Ok(candidate) => out.push(candidate),
116 Err(err) => {
117 metrics::inc_counter("rift_turn_failures", &[("reason", "allocate")]);
118 debug!("turn allocate failed: {err}");
119 }
120 }
121 }
122 if out.is_empty() {
123 Err(TurnError::AllocationFailed)
124 } else {
125 Ok(out)
126 }
127}
128
129pub async fn attempt_hole_punch(
131 nat_cfg: &NatConfig,
132 peer: &PeerEndpoint,
133) -> Result<(UdpSocket, SocketAddr), HolePunchError> {
134 metrics::inc_counter("rift_hole_punch_attempts", &[]);
135 let ports = if nat_cfg.local_ports.is_empty() {
136 vec![0]
137 } else {
138 nat_cfg.local_ports.clone()
139 };
140
141 let mut sockets = Vec::new();
142 for port in ports {
143 if let Ok(socket) = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, port)).await {
144 sockets.push(socket);
145 }
146 }
147
148 if sockets.is_empty() {
149 debug!("hole punch failed: no local ports");
150 metrics::inc_counter("rift_hole_punch_failures", &[("reason", "no_local_ports")]);
151 return Err(HolePunchError::NoLocalPorts);
152 }
153
154 let target_addrs = build_target_addrs(peer);
155 if target_addrs.is_empty() {
156 debug!("hole punch failed: no remote addrs");
157 metrics::inc_counter("rift_hole_punch_failures", &[("reason", "no_remote_addrs")]);
158 return Err(HolePunchError::NoRemoteAddrs);
159 }
160
161 let punch_interval_ms = nat_cfg.punch_interval_ms;
162 let done = Arc::new(AtomicBool::new(false));
163 let (tx, mut rx) = mpsc::channel::<(UdpSocket, SocketAddr)>(1);
164
165 for socket in sockets {
166 let targets = target_addrs.clone();
167 let done = done.clone();
168 let tx = tx.clone();
169 tokio::spawn(async move {
170 if done.load(Ordering::Relaxed) {
171 return;
172 }
173 let mut tick = interval(Duration::from_millis(punch_interval_ms.max(50)));
174 let mut buf = [0u8; 1024];
175
176 loop {
177 tokio::select! {
178 _ = tick.tick() => {
179 if done.load(Ordering::Relaxed) {
180 return;
181 }
182 for addr in &targets {
183 let _ = socket.send_to(PUNCH_SYN, addr).await;
184 }
185 }
186 recv = socket.recv_from(&mut buf) => {
187 let Ok((len, addr)) = recv else { continue; };
188 if done.load(Ordering::Relaxed) {
189 return;
190 }
191 if !targets.contains(&addr) {
192 continue;
193 }
194 let data = &buf[..len];
195 if data == PUNCH_SYN {
196 let _ = socket.send_to(PUNCH_ACK, addr).await;
197 } else if data == PUNCH_ACK {
198 let _ = socket.send_to(PUNCH_ACK, addr).await;
199 }
200 done.store(true, Ordering::Relaxed);
201 let _ = tx.send((socket, addr)).await;
202 return;
203 }
204 }
205 }
206 });
207 }
208
209 let timeout_ms = nat_cfg.punch_timeout_ms.max(500);
210 let result = timeout(Duration::from_millis(timeout_ms), rx.recv()).await;
211 match result {
212 Ok(Some((socket, addr))) => {
213 debug!(%addr, "hole punch success");
214 metrics::inc_counter("rift_hole_punch_success", &[]);
215 Ok((socket, addr))
216 }
217 _ => {
218 debug!("hole punch timeout");
219 metrics::inc_counter("rift_hole_punch_failures", &[("reason", "timeout")]);
220 Err(HolePunchError::Timeout)
221 }
222 }
223}
224
225pub fn gather_local_candidates(listen_port: u16) -> Vec<SocketAddr> {
228 let mut addrs = Vec::new();
229 if let Ok(ifaces) = get_if_addrs() {
230 for iface in ifaces {
231 let ip = iface.ip();
232 if ip.is_loopback() || ip.is_unspecified() {
233 continue;
234 }
235 if let IpAddr::V6(v6) = ip {
236 if v6.is_unicast_link_local() {
237 continue;
238 }
239 }
240 addrs.push(SocketAddr::new(ip, listen_port));
241 }
242 }
243 addrs.sort();
244 addrs.dedup();
245 addrs
246}
247
248pub fn detect_nat_type(local_addrs: &[SocketAddr], public_addrs: &[SocketAddr]) -> NatType {
250 if public_addrs.is_empty() {
251 return NatType::Unknown;
252 }
253 for public in public_addrs {
254 if local_addrs.iter().any(|local| local == public) {
255 return NatType::OpenInternet;
256 }
257 }
258 NatType::Natted
259}
260
261pub async fn gather_public_addrs(nat_cfg: &NatConfig) -> Result<Vec<SocketAddr>, StunError> {
263 if nat_cfg.stun_servers.is_empty() {
264 return Err(StunError::NoServers);
265 }
266 let ports = if nat_cfg.local_ports.is_empty() {
267 vec![0]
268 } else {
269 nat_cfg.local_ports.clone()
270 };
271
272 let mut results = Vec::new();
273 for port in ports {
274 for server in &nat_cfg.stun_servers {
275 if let Ok(addr) = stun_binding_request(*server, port, nat_cfg.stun_timeout_ms).await {
276 results.push(addr);
277 }
278 }
279 }
280
281 results.sort();
282 results.dedup();
283 if results.is_empty() {
284 Err(StunError::NoResponses)
285 } else {
286 Ok(results)
287 }
288}
289
290pub fn spawn_keepalive(
292 socket: Arc<UdpSocket>,
293 targets: Vec<SocketAddr>,
294 interval_ms: u64,
295) -> tokio::task::JoinHandle<()> {
296 tokio::spawn(async move {
297 if targets.is_empty() {
298 return;
299 }
300 let mut tick = interval(Duration::from_millis(interval_ms.max(200)));
301 loop {
302 tick.tick().await;
303 for addr in &targets {
304 let _ = socket.send_to(KEEPALIVE_BYTES, addr).await;
305 }
306 }
307 })
308}
309
310fn build_target_addrs(peer: &PeerEndpoint) -> Vec<SocketAddr> {
312 let mut addrs = Vec::new();
313 for addr in &peer.external_addrs {
314 addrs.push(*addr);
315 for port in &peer.punch_ports {
316 addrs.push(SocketAddr::new(addr.ip(), *port));
317 }
318 }
319 addrs.sort();
320 addrs.dedup();
321 addrs
322}
323
324async fn stun_binding_request(
326 server: SocketAddr,
327 local_port: u16,
328 timeout_ms: u64,
329) -> Result<SocketAddr, StunError> {
330 let socket = match server.ip() {
331 IpAddr::V4(_) => UdpSocket::bind((Ipv4Addr::UNSPECIFIED, local_port)).await?,
332 IpAddr::V6(_) => UdpSocket::bind((IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), local_port)).await?,
333 };
334 let mut tx_id = [0u8; 12];
335 rand::rngs::OsRng.fill_bytes(&mut tx_id);
336
337 let mut req = Vec::with_capacity(20);
338 req.extend_from_slice(&STUN_BINDING_REQUEST.to_be_bytes());
339 req.extend_from_slice(&0u16.to_be_bytes());
340 req.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
341 req.extend_from_slice(&tx_id);
342
343 socket.send_to(&req, server).await?;
344 let mut buf = [0u8; 1024];
345 let (len, _) = timeout(Duration::from_millis(timeout_ms), socket.recv_from(&mut buf))
346 .await
347 .map_err(|_| StunError::NoResponses)??;
348 parse_stun_response(&buf[..len], &tx_id)
349}
350
351fn parse_stun_response(buf: &[u8], tx_id: &[u8; 12]) -> Result<SocketAddr, StunError> {
353 if buf.len() < 20 {
354 return Err(StunError::InvalidResponse);
355 }
356 let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
357 let msg_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
358 let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
359 if msg_type != STUN_BINDING_RESPONSE || cookie != STUN_MAGIC_COOKIE {
360 return Err(StunError::InvalidResponse);
361 }
362 if &buf[8..20] != tx_id {
363 return Err(StunError::InvalidResponse);
364 }
365
366 let mut offset = 20usize;
367 let end = 20 + msg_len.min(buf.len().saturating_sub(20));
368 while offset + 4 <= end {
369 let attr_type = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
370 let attr_len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
371 let value_start = offset + 4;
372 let value_end = value_start + attr_len;
373 if value_end > buf.len() {
374 break;
375 }
376 if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS || attr_type == STUN_ATTR_MAPPED_ADDRESS {
377 if let Ok(addr) = parse_mapped_address(&buf[value_start..value_end], attr_type, tx_id) {
378 return Ok(addr);
379 }
380 }
381 let padded = (attr_len + 3) & !3;
382 offset = value_start + padded;
383 }
384 Err(StunError::InvalidResponse)
385}
386
387fn parse_mapped_address(
388 value: &[u8],
389 attr_type: u16,
390 tx_id: &[u8; 12],
391) -> Result<SocketAddr, StunError> {
392 if value.len() < 4 {
393 return Err(StunError::InvalidResponse);
394 }
395 let family = value[1];
396 let port = u16::from_be_bytes([value[2], value[3]]);
397 let port = if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
398 port ^ ((STUN_MAGIC_COOKIE >> 16) as u16)
399 } else {
400 port
401 };
402 match family {
403 0x01 => {
404 if value.len() < 8 {
405 return Err(StunError::InvalidResponse);
406 }
407 let mut ip = [0u8; 4];
408 ip.copy_from_slice(&value[4..8]);
409 if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
410 let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
411 for i in 0..4 {
412 ip[i] ^= cookie[i];
413 }
414 }
415 Ok(SocketAddr::new(IpAddr::V4(ip.into()), port))
416 }
417 0x02 => {
418 if value.len() < 20 {
419 return Err(StunError::InvalidResponse);
420 }
421 let mut ip = [0u8; 16];
422 ip.copy_from_slice(&value[4..20]);
423 if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
424 let mut xor = [0u8; 16];
425 xor[..4].copy_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
426 xor[4..].copy_from_slice(tx_id);
427 for i in 0..16 {
428 ip[i] ^= xor[i];
429 }
430 }
431 Ok(SocketAddr::new(IpAddr::V6(ip.into()), port))
432 }
433 _ => Err(StunError::InvalidResponse),
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::task::JoinHandle;

    /// Starts a one-shot mock STUN server that answers a single request with `mapped`,
    /// encoded as an XOR-MAPPED-ADDRESS.
    ///
    /// The socket is bound on an OS-assigned loopback port *before* this returns. The server
    /// is therefore already listening when the test sends its request, and the port cannot
    /// clash with another process or a concurrently running test. Returns the server address
    /// and the task handle.
    async fn spawn_mock_stun(mapped: SocketAddr) -> (SocketAddr, JoinHandle<()>) {
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))
            .await
            .expect("bind stun");
        let addr = socket.local_addr().expect("stun local addr");
        let handle = tokio::spawn(async move {
            let mut buf = [0u8; 1024];
            let Ok((len, peer)) = socket.recv_from(&mut buf).await else {
                return;
            };
            if len < 20 {
                return;
            }
            let tx_id: [u8; 12] = buf[8..20].try_into().unwrap();
            let response = build_stun_response(&tx_id, mapped);
            let _ = socket.send_to(&response, peer).await;
        });
        (addr, handle)
    }

    /// Builds a Binding Success Response with a single XOR-MAPPED-ADDRESS for `mapped`.
    fn build_stun_response(tx_id: &[u8; 12], mapped: SocketAddr) -> Vec<u8> {
        let mut out = Vec::with_capacity(64);
        out.extend_from_slice(&STUN_BINDING_RESPONSE.to_be_bytes());
        out.extend_from_slice(&0u16.to_be_bytes());
        out.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
        out.extend_from_slice(tx_id);

        match mapped {
            SocketAddr::V4(addr) => {
                let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
                let ip = u32::from(*addr.ip()) ^ STUN_MAGIC_COOKIE;
                let mut attr = Vec::with_capacity(12);
                attr.extend_from_slice(&STUN_ATTR_XOR_MAPPED_ADDRESS.to_be_bytes());
                attr.extend_from_slice(&8u16.to_be_bytes());
                attr.push(0);
                attr.push(0x01);
                attr.extend_from_slice(&port.to_be_bytes());
                attr.extend_from_slice(&ip.to_be_bytes());
                let len = attr.len() as u16;
                out[2..4].copy_from_slice(&len.to_be_bytes());
                out.extend_from_slice(&attr);
            }
            SocketAddr::V6(addr) => {
                let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
                let mut ip = addr.ip().octets();
                let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
                for i in 0..4 {
                    ip[i] ^= cookie[i];
                }
                for i in 0..12 {
                    ip[4 + i] ^= tx_id[i];
                }
                let mut attr = Vec::with_capacity(24);
                attr.extend_from_slice(&STUN_ATTR_XOR_MAPPED_ADDRESS.to_be_bytes());
                attr.extend_from_slice(&20u16.to_be_bytes());
                attr.push(0);
                attr.push(0x02);
                attr.extend_from_slice(&port.to_be_bytes());
                attr.extend_from_slice(&ip);
                let len = attr.len() as u16;
                out[2..4].copy_from_slice(&len.to_be_bytes());
                out.extend_from_slice(&attr);
            }
        }
        out
    }

    #[tokio::test]
    async fn stun_binding_returns_mapped_addr() {
        let mapped = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 54321);
        let (stun_addr, _handle) = spawn_mock_stun(mapped).await;

        let addr = stun_binding_request(stun_addr, 0, 1000).await.unwrap();
        assert_eq!(addr, mapped);
    }

    #[test]
    fn parse_xor_mapped_ipv6() {
        let tx_id = [7u8; 12];
        let ip = std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let mapped = SocketAddr::new(IpAddr::V6(ip), 3478);
        let response = build_stun_response(&tx_id, mapped);
        assert_eq!(parse_stun_response(&response, &tx_id).unwrap(), mapped);
    }

    #[test]
    fn parse_rejects_mismatched_transaction_id() {
        let mapped = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 54321);
        let response = build_stun_response(&[1u8; 12], mapped);
        assert!(parse_stun_response(&response, &[2u8; 12]).is_err());
    }

    #[test]
    fn local_candidates_exclude_loopback() {
        let list = gather_local_candidates(9999);
        for addr in list {
            assert!(!addr.ip().is_loopback());
        }
    }
}