Skip to main content

rns_net/holepunch/
probe.rs

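//! STUN-like probe service for endpoint discovery.
//!
//! The server replies with the source address and port it observed for the request,
//! so a client behind NAT learns its public endpoint.
//!
//! Wire format (raw UDP, outside Reticulum framing):
//!
//! Request:  [MAGIC:"RNSP" 4B] [VERSION:1B] [NONCE:16B]          = 21 bytes
//! Response: [MAGIC:"RNSP" 4B] [VERSION:1B] [NONCE:16B (echo)]
//!           [ADDR_TYPE:1B (4=IPv4,6=IPv6)] [ADDR:4|16B] [PORT:2B big-endian]
//!           = 28 bytes (IPv4) or 40 bytes (IPv6)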
8
9use std::io;
10use std::net::{SocketAddr, UdpSocket};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::thread;
14use std::time::Duration;
15
/// Magic prefix that starts every probe request and response.
const PROBE_MAGIC: &[u8; 4] = b"RNSP";
/// Wire-format version byte that follows the magic.
const PROBE_VERSION: u8 = 1;
/// Exact size of a probe request: magic (4) + version (1) + nonce (16).
const PROBE_REQUEST_LEN: usize = 4 + 1 + 16;
/// Address-type tag in a response carrying an IPv4 address.
const ADDR_TYPE_IPV4: u8 = 4;
/// Address-type tag in a response carrying an IPv6 address.
const ADDR_TYPE_IPV6: u8 = 6;
21
22/// Start a probe server on the given address. Runs in a background thread.
23///
24/// Returns a handle to stop the server.
25pub fn start_probe_server(listen_addr: SocketAddr) -> io::Result<ProbeServerHandle> {
26    let socket = UdpSocket::bind(listen_addr)?;
27    socket.set_read_timeout(Some(Duration::from_secs(1)))?;
28
29    let running = Arc::new(AtomicBool::new(true));
30    let running_clone = running.clone();
31
32    let handle = thread::Builder::new()
33        .name("probe-server".into())
34        .spawn(move || {
35            run_probe_server(socket, running_clone);
36        })?;
37
38    Ok(ProbeServerHandle {
39        running,
40        thread: Some(handle),
41    })
42}
43
44fn run_probe_server(socket: UdpSocket, running: Arc<AtomicBool>) {
45    let mut buf = [0u8; 64];
46    while running.load(Ordering::Relaxed) {
47        let (len, src) = match socket.recv_from(&mut buf) {
48            Ok(r) => r,
49            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => {
50                continue;
51            }
52            Err(e) => {
53                log::warn!("Probe server recv error: {}", e);
54                continue;
55            }
56        };
57
58        if len != PROBE_REQUEST_LEN {
59            continue;
60        }
61        if &buf[..4] != PROBE_MAGIC {
62            continue;
63        }
64        if buf[4] != PROBE_VERSION {
65            continue;
66        }
67
68        let nonce = &buf[5..21];
69        let response = build_probe_response(nonce, &src);
70        if let Err(e) = socket.send_to(&response, src) {
71            log::debug!("Probe server send error: {}", e);
72        }
73    }
74}
75
76fn build_probe_response(nonce: &[u8], src: &SocketAddr) -> Vec<u8> {
77    let mut resp = Vec::with_capacity(36);
78    resp.extend_from_slice(PROBE_MAGIC);
79    resp.push(PROBE_VERSION);
80    resp.extend_from_slice(nonce);
81
82    match src {
83        SocketAddr::V4(addr) => {
84            resp.push(ADDR_TYPE_IPV4);
85            resp.extend_from_slice(&addr.ip().octets());
86            resp.extend_from_slice(&addr.port().to_be_bytes());
87        }
88        SocketAddr::V6(addr) => {
89            resp.push(ADDR_TYPE_IPV6);
90            resp.extend_from_slice(&addr.ip().octets());
91            resp.extend_from_slice(&addr.port().to_be_bytes());
92        }
93    }
94
95    resp
96}
97
/// Handle to a running probe server.
///
/// Calling [`ProbeServerHandle::stop`], or dropping the handle, clears the
/// shared run flag and joins the server thread. The thread only re-checks the
/// flag between socket reads, so stopping can block until its current read
/// returns. NOTE(review): that wait is bounded by the socket's read timeout,
/// which `start_probe_server` sets to one second.
#[derive(Debug)]
#[must_use = "dropping a ProbeServerHandle immediately stops the probe server"]
pub struct ProbeServerHandle {
    /// Run flag shared with the server thread; storing `false` asks it to exit.
    running: Arc<AtomicBool>,
    /// The server thread, or `None` once it has been joined.
    thread: Option<thread::JoinHandle<()>>,
}

impl ProbeServerHandle {
    /// Signal the server thread to exit and wait for it to finish.
    ///
    /// Safe to call more than once: the thread is joined at most once. A panic
    /// inside the server thread is not propagated to the caller.
    pub fn stop(&mut self) {
        self.running.store(false, Ordering::Relaxed);
        if let Some(handle) = self.thread.take() {
            // Best effort: a panicked server thread leaves nothing for us to clean up.
            let _ = handle.join();
        }
    }
}

impl Drop for ProbeServerHandle {
    fn drop(&mut self) {
        self.stop();
    }
}
118
119/// Probe client: discover our public endpoint by sending a probe to a server.
120///
121/// Binds a new UDP socket (or uses an existing one), sends a probe request,
122/// and returns the observed public endpoint.
123///
124/// The socket is returned so it can be reused for hole punching (same NAT mapping).
125pub fn probe_endpoint(
126    probe_server: SocketAddr,
127    existing_socket: Option<UdpSocket>,
128    timeout: Duration,
129    device: Option<&str>,
130) -> io::Result<(SocketAddr, UdpSocket)> {
131    let socket = match existing_socket {
132        Some(s) => s,
133        None => {
134            let bind_addr: SocketAddr = if probe_server.is_ipv4() {
135                "0.0.0.0:0".parse().unwrap()
136            } else {
137                "[::]:0".parse().unwrap()
138            };
139            let sock = UdpSocket::bind(bind_addr)?;
140            #[cfg(target_os = "linux")]
141            if let Some(dev) = device {
142                use std::os::unix::io::AsRawFd;
143                crate::interface::bind_to_device(sock.as_raw_fd(), dev)?;
144            }
145            sock
146        }
147    };
148    socket.set_read_timeout(Some(timeout))?;
149
150    // Build request with a nonce for response matching
151    let mut nonce = [0u8; 16];
152    let now = std::time::SystemTime::now()
153        .duration_since(std::time::UNIX_EPOCH)
154        .unwrap_or_default();
155    let nanos = now.as_nanos();
156    nonce[..8].copy_from_slice(&nanos.to_le_bytes()[..8]);
157    // Fill remaining bytes: local port + thread ID bits + subsec nanos (reversed)
158    let local_port = socket.local_addr().map(|a| a.port()).unwrap_or(0);
159    nonce[8..10].copy_from_slice(&local_port.to_be_bytes());
160    let thread_id = std::thread::current().id();
161    let thread_hash = format!("{:?}", thread_id);
162    for (i, b) in thread_hash.bytes().enumerate() {
163        if 10 + i >= 16 { break; }
164        nonce[10 + i] = b;
165    }
166
167    let mut request = Vec::with_capacity(PROBE_REQUEST_LEN);
168    request.extend_from_slice(PROBE_MAGIC);
169    request.push(PROBE_VERSION);
170    request.extend_from_slice(&nonce);
171
172    socket.send_to(&request, probe_server)?;
173
174    // Wait for response
175    let mut buf = [0u8; 64];
176    let (len, _) = socket.recv_from(&mut buf)?;
177
178    parse_probe_response(&buf[..len], &nonce)
179        .map(|addr| (addr, socket))
180        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid probe response"))
181}
182
183fn parse_probe_response(data: &[u8], expected_nonce: &[u8; 16]) -> Option<SocketAddr> {
184    if data.len() < 24 {
185        return None;
186    }
187    if &data[..4] != PROBE_MAGIC {
188        return None;
189    }
190    if data[4] != PROBE_VERSION {
191        return None;
192    }
193    if &data[5..21] != expected_nonce {
194        return None;
195    }
196
197    let addr_type = data[21];
198    match addr_type {
199        ADDR_TYPE_IPV4 => {
200            if data.len() < 28 {
201                return None;
202            }
203            let ip = std::net::Ipv4Addr::new(data[22], data[23], data[24], data[25]);
204            let port = u16::from_be_bytes([data[26], data[27]]);
205            Some(SocketAddr::new(ip.into(), port))
206        }
207        ADDR_TYPE_IPV6 => {
208            if data.len() < 40 {
209                return None;
210            }
211            let mut octets = [0u8; 16];
212            octets.copy_from_slice(&data[22..38]);
213            let ip = std::net::Ipv6Addr::from(octets);
214            let port = u16::from_be_bytes([data[38], data[39]]);
215            Some(SocketAddr::new(ip.into(), port))
216        }
217        _ => None,
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_probe_server_and_client() {
        // Bind up front so the real port is known. No start-up sleep is needed:
        // a request sent before the server thread reaches recv_from simply waits
        // in the socket's receive buffer.
        let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let socket = UdpSocket::bind(server_addr).unwrap();
        let actual_addr = socket.local_addr().unwrap();
        socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();

        let running = Arc::new(AtomicBool::new(true));
        let running_clone = running.clone();
        let server_thread = thread::spawn(move || {
            run_probe_server(socket, running_clone);
        });

        let (observed, client_socket) =
            probe_endpoint(actual_addr, None, Duration::from_secs(3), None).unwrap();

        // On loopback there is no NAT: the server sees our own address and port.
        assert_eq!(observed.ip(), std::net::Ipv4Addr::LOCALHOST);
        assert_eq!(observed.port(), client_socket.local_addr().unwrap().port());

        running.store(false, Ordering::Relaxed);
        let _ = server_thread.join();
    }

    #[test]
    fn test_probe_response_roundtrip() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let parsed = parse_probe_response(&response, &nonce).unwrap();
        assert_eq!(parsed, addr);
    }

    #[test]
    fn test_probe_response_ipv6() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "[::1]:52000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let parsed = parse_probe_response(&response, &nonce).unwrap();
        assert_eq!(parsed, addr);
    }

    #[test]
    fn test_probe_response_bad_nonce() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let wrong_nonce = [0x99u8; 16];
        assert!(parse_probe_response(&response, &wrong_nonce).is_none());
    }

    #[test]
    fn test_probe_response_lengths() {
        let nonce = [0x42u8; 16];
        let v4: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let v6: SocketAddr = "[2001:db8::1]:52000".parse().unwrap();
        // 4 magic + 1 version + 16 nonce + 1 type + addr + 2 port.
        assert_eq!(build_probe_response(&nonce, &v4).len(), 28);
        assert_eq!(build_probe_response(&nonce, &v6).len(), 40);
    }

    #[test]
    fn test_probe_response_truncated() {
        let nonce = [0x42u8; 16];
        let addrs: [SocketAddr; 2] = [
            "1.2.3.4:41000".parse().unwrap(),
            "[2001:db8::1]:52000".parse().unwrap(),
        ];
        for addr in addrs.iter() {
            let response = build_probe_response(&nonce, addr);
            let truncated = &response[..response.len() - 1];
            assert!(parse_probe_response(truncated, &nonce).is_none());
        }
    }

    #[test]
    fn test_probe_response_rejects_bad_header_and_addr_type() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);

        let mut bad_magic = response.clone();
        bad_magic[0] ^= 0xff;
        assert!(parse_probe_response(&bad_magic, &nonce).is_none());

        let mut bad_version = response.clone();
        bad_version[4] = PROBE_VERSION.wrapping_add(1);
        assert!(parse_probe_response(&bad_version, &nonce).is_none());

        let mut bad_type = response;
        bad_type[21] = 5;
        assert!(parse_probe_response(&bad_type, &nonce).is_none());
    }
}