Skip to main content

rift_nat/
lib.rs

1//! NAT traversal helpers (STUN, hole punching, TURN integration).
2//!
3//! This module provides:
4//! - STUN binding discovery for public addresses
5//! - UDP hole punching between peers
6//! - TURN relay allocation helpers (via `turn` module)
7
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
10use std::time::Duration;
11
12use get_if_addrs::get_if_addrs;
13use tokio::net::UdpSocket;
14use tokio::sync::mpsc;
15use tokio::time::{interval, timeout};
16
17use rift_core::PeerId;
18use rift_metrics as metrics;
19use tracing::debug;
20use rand::RngCore;
21
22mod turn;
23pub use turn::{
24    TurnCandidate, TurnError, TurnRelay, TurnServerConfig, allocate_turn_relay,
25    parse_turn_server, spawn_turn_keepalive,
26};
27
28#[derive(Debug, Clone)]
29pub struct NatConfig {
30    /// Ports to attempt for local binding (0 means OS-assigned).
31    pub local_ports: Vec<u16>,
32    /// STUN servers used for public address discovery.
33    pub stun_servers: Vec<SocketAddr>,
34    /// Timeout for STUN binding requests.
35    pub stun_timeout_ms: u64,
36    /// Interval between hole-punch packets.
37    pub punch_interval_ms: u64,
38    /// Overall hole-punch timeout.
39    pub punch_timeout_ms: u64,
40    /// TURN servers to use for relay allocation.
41    pub turn_servers: Vec<TurnServerConfig>,
42    /// Timeout for TURN allocations.
43    pub turn_timeout_ms: u64,
44    /// TURN keepalive interval.
45    pub turn_keepalive_ms: u64,
46}
47
/// Coarse NAT classification produced by [`detect_nat_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatType {
    /// No public addresses were available to compare against.
    Unknown,
    /// At least one public address is identical to a local address.
    OpenInternet,
    /// Public addresses exist, but none of them matches a local address.
    Natted,
}
54
55#[derive(Debug, Clone)]
56pub struct PeerEndpoint {
57    /// Peer id for logging/context.
58    pub peer_id: PeerId,
59    /// Public addresses advertised by the peer.
60    pub external_addrs: Vec<SocketAddr>,
61    /// Additional ports to try for hole punching.
62    pub punch_ports: Vec<u16>,
63}
64
65#[derive(Debug, thiserror::Error)]
66pub enum HolePunchError {
67    /// No local UDP sockets could be bound.
68    #[error("no local ports could be bound")]
69    NoLocalPorts,
70    /// No remote addresses provided for punching.
71    #[error("no remote addresses to punch")]
72    NoRemoteAddrs,
73    /// Hole punch timed out without success.
74    #[error("timeout while punching")]
75    Timeout,
76    /// Low-level socket I/O error.
77    #[error("io error: {0}")]
78    Io(#[from] std::io::Error),
79}
80
81#[derive(Debug, thiserror::Error)]
82pub enum StunError {
83    /// STUN servers not configured.
84    #[error("no stun servers configured")]
85    NoServers,
86    /// No response from any STUN server.
87    #[error("no stun responses received")]
88    NoResponses,
89    /// Malformed or unexpected STUN response.
90    #[error("invalid stun response")]
91    InvalidResponse,
92    /// Low-level socket I/O error.
93    #[error("io error: {0}")]
94    Io(#[from] std::io::Error),
95}
96
/// Hole-punch probe, sent to every target on each punch tick.
const PUNCH_SYN: &[u8] = b"RIFT_PUNCH";
/// Hole-punch reply, sent back when a probe or reply arrives from a target.
const PUNCH_ACK: &[u8] = b"RIFT_ACK";
/// STUN magic cookie (RFC 5389); also the XOR key for mapped addresses.
const STUN_MAGIC_COOKIE: u32 = 0x2112_A442;
/// STUN message type: Binding request.
const STUN_BINDING_REQUEST: u16 = 0x0001;
/// STUN message type: Binding success response.
const STUN_BINDING_RESPONSE: u16 = 0x0101;
/// STUN attribute type: MAPPED-ADDRESS (address in the clear).
const STUN_ATTR_MAPPED_ADDRESS: u16 = 0x0001;
/// STUN attribute type: XOR-MAPPED-ADDRESS (address XOR-obfuscated).
const STUN_ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// Payload of the periodic NAT keepalive datagram.
const KEEPALIVE_BYTES: &[u8] = b"RIFT_KEEPALIVE";
106
107/// Allocate TURN relays and return candidates if successful.
108pub async fn gather_turn_candidates(nat_cfg: &NatConfig) -> Result<Vec<TurnCandidate>, TurnError> {
109    if nat_cfg.turn_servers.is_empty() {
110        return Err(TurnError::NoServers);
111    }
112    let mut out = Vec::new();
113    for server in nat_cfg.turn_servers.clone() {
114        match allocate_turn_relay(server, nat_cfg.turn_timeout_ms).await {
115            Ok(candidate) => out.push(candidate),
116            Err(err) => {
117                metrics::inc_counter("rift_turn_failures", &[("reason", "allocate")]);
118                debug!("turn allocate failed: {err}");
119            }
120        }
121    }
122    if out.is_empty() {
123        Err(TurnError::AllocationFailed)
124    } else {
125        Ok(out)
126    }
127}
128
129/// Attempt UDP hole punching with a peer and return the first successful socket.
130pub async fn attempt_hole_punch(
131    nat_cfg: &NatConfig,
132    peer: &PeerEndpoint,
133) -> Result<(UdpSocket, SocketAddr), HolePunchError> {
134    metrics::inc_counter("rift_hole_punch_attempts", &[]);
135    let ports = if nat_cfg.local_ports.is_empty() {
136        vec![0]
137    } else {
138        nat_cfg.local_ports.clone()
139    };
140
141    let mut sockets = Vec::new();
142    for port in ports {
143        if let Ok(socket) = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, port)).await {
144            sockets.push(socket);
145        }
146    }
147
148    if sockets.is_empty() {
149        debug!("hole punch failed: no local ports");
150        metrics::inc_counter("rift_hole_punch_failures", &[("reason", "no_local_ports")]);
151        return Err(HolePunchError::NoLocalPorts);
152    }
153
154    let target_addrs = build_target_addrs(peer);
155    if target_addrs.is_empty() {
156        debug!("hole punch failed: no remote addrs");
157        metrics::inc_counter("rift_hole_punch_failures", &[("reason", "no_remote_addrs")]);
158        return Err(HolePunchError::NoRemoteAddrs);
159    }
160
161    let punch_interval_ms = nat_cfg.punch_interval_ms;
162    let done = Arc::new(AtomicBool::new(false));
163    let (tx, mut rx) = mpsc::channel::<(UdpSocket, SocketAddr)>(1);
164
165    for socket in sockets {
166        let targets = target_addrs.clone();
167        let done = done.clone();
168        let tx = tx.clone();
169        tokio::spawn(async move {
170            if done.load(Ordering::Relaxed) {
171                return;
172            }
173            let mut tick = interval(Duration::from_millis(punch_interval_ms.max(50)));
174            let mut buf = [0u8; 1024];
175
176            loop {
177                tokio::select! {
178                    _ = tick.tick() => {
179                        if done.load(Ordering::Relaxed) {
180                            return;
181                        }
182                        for addr in &targets {
183                            let _ = socket.send_to(PUNCH_SYN, addr).await;
184                        }
185                    }
186                    recv = socket.recv_from(&mut buf) => {
187                        let Ok((len, addr)) = recv else { continue; };
188                        if done.load(Ordering::Relaxed) {
189                            return;
190                        }
191                        if !targets.contains(&addr) {
192                            continue;
193                        }
194                        let data = &buf[..len];
195                        if data == PUNCH_SYN {
196                            let _ = socket.send_to(PUNCH_ACK, addr).await;
197                        } else if data == PUNCH_ACK {
198                            let _ = socket.send_to(PUNCH_ACK, addr).await;
199                        }
200                        done.store(true, Ordering::Relaxed);
201                        let _ = tx.send((socket, addr)).await;
202                        return;
203                    }
204                }
205            }
206        });
207    }
208
209    let timeout_ms = nat_cfg.punch_timeout_ms.max(500);
210    let result = timeout(Duration::from_millis(timeout_ms), rx.recv()).await;
211    match result {
212        Ok(Some((socket, addr))) => {
213            debug!(%addr, "hole punch success");
214            metrics::inc_counter("rift_hole_punch_success", &[]);
215            Ok((socket, addr))
216        }
217        _ => {
218            debug!("hole punch timeout");
219            metrics::inc_counter("rift_hole_punch_failures", &[("reason", "timeout")]);
220            Err(HolePunchError::Timeout)
221        }
222    }
223}
224
225/// Collect local (host) candidates for a given listen port.
226/// Loopback, unspecified, and link-local addresses are excluded.
227pub fn gather_local_candidates(listen_port: u16) -> Vec<SocketAddr> {
228    let mut addrs = Vec::new();
229    if let Ok(ifaces) = get_if_addrs() {
230        for iface in ifaces {
231            let ip = iface.ip();
232            if ip.is_loopback() || ip.is_unspecified() {
233                continue;
234            }
235            if let IpAddr::V6(v6) = ip {
236                if v6.is_unicast_link_local() {
237                    continue;
238                }
239            }
240            addrs.push(SocketAddr::new(ip, listen_port));
241        }
242    }
243    addrs.sort();
244    addrs.dedup();
245    addrs
246}
247
248/// Compare local and public address lists to detect NAT behavior.
249pub fn detect_nat_type(local_addrs: &[SocketAddr], public_addrs: &[SocketAddr]) -> NatType {
250    if public_addrs.is_empty() {
251        return NatType::Unknown;
252    }
253    for public in public_addrs {
254        if local_addrs.iter().any(|local| local == public) {
255            return NatType::OpenInternet;
256        }
257    }
258    NatType::Natted
259}
260
261/// Query STUN servers to discover public-facing addresses.
262pub async fn gather_public_addrs(nat_cfg: &NatConfig) -> Result<Vec<SocketAddr>, StunError> {
263    if nat_cfg.stun_servers.is_empty() {
264        return Err(StunError::NoServers);
265    }
266    let ports = if nat_cfg.local_ports.is_empty() {
267        vec![0]
268    } else {
269        nat_cfg.local_ports.clone()
270    };
271
272    let mut results = Vec::new();
273    for port in ports {
274        for server in &nat_cfg.stun_servers {
275            if let Ok(addr) = stun_binding_request(*server, port, nat_cfg.stun_timeout_ms).await {
276                results.push(addr);
277            }
278        }
279    }
280
281    results.sort();
282    results.dedup();
283    if results.is_empty() {
284        Err(StunError::NoResponses)
285    } else {
286        Ok(results)
287    }
288}
289
290/// Spawn periodic keep-alive packets to keep NAT bindings warm.
291pub fn spawn_keepalive(
292    socket: Arc<UdpSocket>,
293    targets: Vec<SocketAddr>,
294    interval_ms: u64,
295) -> tokio::task::JoinHandle<()> {
296    tokio::spawn(async move {
297        if targets.is_empty() {
298            return;
299        }
300        let mut tick = interval(Duration::from_millis(interval_ms.max(200)));
301        loop {
302            tick.tick().await;
303            for addr in &targets {
304                let _ = socket.send_to(KEEPALIVE_BYTES, addr).await;
305            }
306        }
307    })
308}
309
310/// Build all target socket addresses to try for hole punching.
311fn build_target_addrs(peer: &PeerEndpoint) -> Vec<SocketAddr> {
312    let mut addrs = Vec::new();
313    for addr in &peer.external_addrs {
314        addrs.push(*addr);
315        for port in &peer.punch_ports {
316            addrs.push(SocketAddr::new(addr.ip(), *port));
317        }
318    }
319    addrs.sort();
320    addrs.dedup();
321    addrs
322}
323
324/// Perform a single STUN binding request and parse the response.
325async fn stun_binding_request(
326    server: SocketAddr,
327    local_port: u16,
328    timeout_ms: u64,
329) -> Result<SocketAddr, StunError> {
330    let socket = match server.ip() {
331        IpAddr::V4(_) => UdpSocket::bind((Ipv4Addr::UNSPECIFIED, local_port)).await?,
332        IpAddr::V6(_) => UdpSocket::bind((IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), local_port)).await?,
333    };
334    let mut tx_id = [0u8; 12];
335    rand::rngs::OsRng.fill_bytes(&mut tx_id);
336
337    let mut req = Vec::with_capacity(20);
338    req.extend_from_slice(&STUN_BINDING_REQUEST.to_be_bytes());
339    req.extend_from_slice(&0u16.to_be_bytes());
340    req.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
341    req.extend_from_slice(&tx_id);
342
343    socket.send_to(&req, server).await?;
344    let mut buf = [0u8; 1024];
345    let (len, _) = timeout(Duration::from_millis(timeout_ms), socket.recv_from(&mut buf))
346        .await
347        .map_err(|_| StunError::NoResponses)??;
348    parse_stun_response(&buf[..len], &tx_id)
349}
350
351/// Parse a STUN binding response and extract the mapped address.
352fn parse_stun_response(buf: &[u8], tx_id: &[u8; 12]) -> Result<SocketAddr, StunError> {
353    if buf.len() < 20 {
354        return Err(StunError::InvalidResponse);
355    }
356    let msg_type = u16::from_be_bytes([buf[0], buf[1]]);
357    let msg_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
358    let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
359    if msg_type != STUN_BINDING_RESPONSE || cookie != STUN_MAGIC_COOKIE {
360        return Err(StunError::InvalidResponse);
361    }
362    if &buf[8..20] != tx_id {
363        return Err(StunError::InvalidResponse);
364    }
365
366    let mut offset = 20usize;
367    let end = 20 + msg_len.min(buf.len().saturating_sub(20));
368    while offset + 4 <= end {
369        let attr_type = u16::from_be_bytes([buf[offset], buf[offset + 1]]);
370        let attr_len = u16::from_be_bytes([buf[offset + 2], buf[offset + 3]]) as usize;
371        let value_start = offset + 4;
372        let value_end = value_start + attr_len;
373        if value_end > buf.len() {
374            break;
375        }
376        if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS || attr_type == STUN_ATTR_MAPPED_ADDRESS {
377            if let Ok(addr) = parse_mapped_address(&buf[value_start..value_end], attr_type, tx_id) {
378                return Ok(addr);
379            }
380        }
381        let padded = (attr_len + 3) & !3;
382        offset = value_start + padded;
383    }
384    Err(StunError::InvalidResponse)
385}
386
387fn parse_mapped_address(
388    value: &[u8],
389    attr_type: u16,
390    tx_id: &[u8; 12],
391) -> Result<SocketAddr, StunError> {
392    if value.len() < 4 {
393        return Err(StunError::InvalidResponse);
394    }
395    let family = value[1];
396    let port = u16::from_be_bytes([value[2], value[3]]);
397    let port = if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
398        port ^ ((STUN_MAGIC_COOKIE >> 16) as u16)
399    } else {
400        port
401    };
402    match family {
403        0x01 => {
404            if value.len() < 8 {
405                return Err(StunError::InvalidResponse);
406            }
407            let mut ip = [0u8; 4];
408            ip.copy_from_slice(&value[4..8]);
409            if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
410                let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
411                for i in 0..4 {
412                    ip[i] ^= cookie[i];
413                }
414            }
415            Ok(SocketAddr::new(IpAddr::V4(ip.into()), port))
416        }
417        0x02 => {
418            if value.len() < 20 {
419                return Err(StunError::InvalidResponse);
420            }
421            let mut ip = [0u8; 16];
422            ip.copy_from_slice(&value[4..20]);
423            if attr_type == STUN_ATTR_XOR_MAPPED_ADDRESS {
424                let mut xor = [0u8; 16];
425                xor[..4].copy_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
426                xor[4..].copy_from_slice(tx_id);
427                for i in 0..16 {
428                    ip[i] ^= xor[i];
429                }
430            }
431            Ok(SocketAddr::new(IpAddr::V6(ip.into()), port))
432        }
433        _ => Err(StunError::InvalidResponse),
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::task::JoinHandle;

    /// Start a one-shot mock STUN server on an ephemeral loopback port.
    ///
    /// The socket is bound *before* the task is spawned, so the returned
    /// address is already accepting datagrams when this function returns.
    /// Binding inside the task, as before, let the request race ahead of the
    /// bind. The task answers the first request with a XOR-MAPPED-ADDRESS of
    /// `mapped` and exits.
    async fn spawn_mock_stun(mapped: SocketAddr) -> (SocketAddr, JoinHandle<()>) {
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))
            .await
            .expect("bind stun");
        let addr = socket.local_addr().expect("stun local addr");
        let handle = tokio::spawn(async move {
            let mut buf = [0u8; 1024];
            let Ok((len, peer)) = socket.recv_from(&mut buf).await else {
                return;
            };
            if len < 20 {
                return;
            }
            let tx_id: [u8; 12] = buf[8..20].try_into().unwrap();
            let response = build_stun_response(&tx_id, mapped);
            let _ = socket.send_to(&response, peer).await;
        });
        (addr, handle)
    }

    /// Build a Binding success response carrying one XOR-MAPPED-ADDRESS.
    fn build_stun_response(tx_id: &[u8; 12], mapped: SocketAddr) -> Vec<u8> {
        let mut out = Vec::with_capacity(64);
        out.extend_from_slice(&STUN_BINDING_RESPONSE.to_be_bytes());
        out.extend_from_slice(&0u16.to_be_bytes());
        out.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
        out.extend_from_slice(tx_id);

        match mapped {
            SocketAddr::V4(addr) => {
                let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
                let ip = u32::from(*addr.ip()) ^ STUN_MAGIC_COOKIE;
                let mut attr = Vec::with_capacity(12);
                attr.extend_from_slice(&STUN_ATTR_XOR_MAPPED_ADDRESS.to_be_bytes());
                attr.extend_from_slice(&8u16.to_be_bytes());
                attr.push(0);
                attr.push(0x01);
                attr.extend_from_slice(&port.to_be_bytes());
                attr.extend_from_slice(&ip.to_be_bytes());
                let len = attr.len() as u16;
                out[2..4].copy_from_slice(&len.to_be_bytes());
                out.extend_from_slice(&attr);
            }
            SocketAddr::V6(addr) => {
                let port = addr.port() ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
                let mut ip = addr.ip().octets();
                let cookie = STUN_MAGIC_COOKIE.to_be_bytes();
                for i in 0..4 {
                    ip[i] ^= cookie[i];
                }
                for i in 0..12 {
                    ip[4 + i] ^= tx_id[i];
                }
                let mut attr = Vec::with_capacity(24);
                attr.extend_from_slice(&STUN_ATTR_XOR_MAPPED_ADDRESS.to_be_bytes());
                attr.extend_from_slice(&20u16.to_be_bytes());
                attr.push(0);
                attr.push(0x02);
                attr.extend_from_slice(&port.to_be_bytes());
                attr.extend_from_slice(&ip);
                let len = attr.len() as u16;
                out[2..4].copy_from_slice(&len.to_be_bytes());
                out.extend_from_slice(&attr);
            }
        }
        out
    }

    #[tokio::test]
    async fn stun_binding_returns_mapped_addr() {
        let mapped = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 54321);
        let (stun_addr, handle) = spawn_mock_stun(mapped).await;

        let addr = stun_binding_request(stun_addr, 0, 1000).await.unwrap();
        assert_eq!(addr, mapped);
        handle.await.expect("mock stun task");
    }

    #[test]
    fn parse_roundtrips_ipv6_mapped_addr() {
        let tx_id = [0x5a; 12];
        let mapped: SocketAddr = "[2001:db8::1]:40000".parse().unwrap();
        let response = build_stun_response(&tx_id, mapped);
        assert_eq!(parse_stun_response(&response, &tx_id).unwrap(), mapped);
    }

    #[test]
    fn parse_rejects_mismatched_transaction_id() {
        let mapped: SocketAddr = "198.51.100.7:3478".parse().unwrap();
        let response = build_stun_response(&[1; 12], mapped);
        assert!(parse_stun_response(&response, &[2; 12]).is_err());
    }

    #[test]
    fn local_candidates_exclude_loopback() {
        let list = gather_local_candidates(9999);
        for addr in list {
            assert!(!addr.ip().is_loopback());
        }
    }
}
518        for addr in list {
519            assert!(!addr.ip().is_loopback());
520        }
521    }
522}