// rns_net/holepunch/probe.rs
//
// Endpoint probing for hole punching: a client sends a small UDP probe request
// and a probe server answers with the source address it observed.

use std::io;
use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// Magic prefix carried by every probe request and response.
const PROBE_MAGIC: &[u8; 4] = b"RNSP";
/// Wire-format version; datagrams carrying any other version are ignored.
const PROBE_VERSION: u8 = 1;
/// Exact request size: magic (4) + version (1) + nonce (16).
const PROBE_REQUEST_LEN: usize = 4 + 1 + 16;
/// Address-type tag marking an IPv4 observed address in a response.
const ADDR_TYPE_IPV4: u8 = 4;
/// Address-type tag marking an IPv6 observed address in a response.
const ADDR_TYPE_IPV6: u8 = 6;
22pub fn start_probe_server(listen_addr: SocketAddr) -> io::Result<ProbeServerHandle> {
26 let socket = UdpSocket::bind(listen_addr)?;
27 socket.set_read_timeout(Some(Duration::from_secs(1)))?;
28
29 let running = Arc::new(AtomicBool::new(true));
30 let running_clone = running.clone();
31
32 let handle = thread::Builder::new()
33 .name("probe-server".into())
34 .spawn(move || {
35 run_probe_server(socket, running_clone);
36 })?;
37
38 Ok(ProbeServerHandle {
39 running,
40 thread: Some(handle),
41 })
42}
43
44fn run_probe_server(socket: UdpSocket, running: Arc<AtomicBool>) {
45 let mut buf = [0u8; 64];
46 while running.load(Ordering::Relaxed) {
47 let (len, src) = match socket.recv_from(&mut buf) {
48 Ok(r) => r,
49 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => {
50 continue;
51 }
52 Err(e) => {
53 log::warn!("Probe server recv error: {}", e);
54 continue;
55 }
56 };
57
58 if len != PROBE_REQUEST_LEN {
59 continue;
60 }
61 if &buf[..4] != PROBE_MAGIC {
62 continue;
63 }
64 if buf[4] != PROBE_VERSION {
65 continue;
66 }
67
68 let nonce = &buf[5..21];
69 let response = build_probe_response(nonce, &src);
70 if let Err(e) = socket.send_to(&response, src) {
71 log::debug!("Probe server send error: {}", e);
72 }
73 }
74}
75
76fn build_probe_response(nonce: &[u8], src: &SocketAddr) -> Vec<u8> {
77 let mut resp = Vec::with_capacity(36);
78 resp.extend_from_slice(PROBE_MAGIC);
79 resp.push(PROBE_VERSION);
80 resp.extend_from_slice(nonce);
81
82 match src {
83 SocketAddr::V4(addr) => {
84 resp.push(ADDR_TYPE_IPV4);
85 resp.extend_from_slice(&addr.ip().octets());
86 resp.extend_from_slice(&addr.port().to_be_bytes());
87 }
88 SocketAddr::V6(addr) => {
89 resp.push(ADDR_TYPE_IPV6);
90 resp.extend_from_slice(&addr.ip().octets());
91 resp.extend_from_slice(&addr.port().to_be_bytes());
92 }
93 }
94
95 resp
96}
97
/// Handle to a running probe server started by `start_probe_server`.
///
/// Dropping the handle stops the server and waits for its thread to exit.
#[derive(Debug)]
pub struct ProbeServerHandle {
    /// Cleared to ask the server thread to leave its receive loop.
    running: Arc<AtomicBool>,
    /// Server thread; `None` once it has been joined.
    thread: Option<thread::JoinHandle<()>>,
}

impl ProbeServerHandle {
    /// Stops the probe server and blocks until its thread has exited.
    ///
    /// The server thread only checks the stop flag between receives, so this
    /// can take up to the server socket's read timeout. Calling `stop` again
    /// is harmless: later calls find no thread and return immediately. A panic
    /// on the server thread is discarded.
    pub fn stop(&mut self) {
        self.running.store(false, Ordering::Relaxed);
        if let Some(handle) = self.thread.take() {
            // The thread returns `()`; an `Err` here only carries a panic payload.
            let _ = handle.join();
        }
    }
}

impl Drop for ProbeServerHandle {
    fn drop(&mut self) {
        self.stop();
    }
}
119pub fn probe_endpoint(
126 probe_server: SocketAddr,
127 existing_socket: Option<UdpSocket>,
128 timeout: Duration,
129 device: Option<&str>,
130) -> io::Result<(SocketAddr, UdpSocket)> {
131 let socket = match existing_socket {
132 Some(s) => s,
133 None => {
134 let bind_addr: SocketAddr = if probe_server.is_ipv4() {
135 "0.0.0.0:0".parse().unwrap()
136 } else {
137 "[::]:0".parse().unwrap()
138 };
139 let sock = UdpSocket::bind(bind_addr)?;
140 #[cfg(target_os = "linux")]
141 if let Some(dev) = device {
142 use std::os::unix::io::AsRawFd;
143 crate::interface::bind_to_device(sock.as_raw_fd(), dev)?;
144 }
145 sock
146 }
147 };
148 socket.set_read_timeout(Some(timeout))?;
149
150 let mut nonce = [0u8; 16];
152 let now = std::time::SystemTime::now()
153 .duration_since(std::time::UNIX_EPOCH)
154 .unwrap_or_default();
155 let nanos = now.as_nanos();
156 nonce[..8].copy_from_slice(&nanos.to_le_bytes()[..8]);
157 let local_port = socket.local_addr().map(|a| a.port()).unwrap_or(0);
159 nonce[8..10].copy_from_slice(&local_port.to_be_bytes());
160 let thread_id = std::thread::current().id();
161 let thread_hash = format!("{:?}", thread_id);
162 for (i, b) in thread_hash.bytes().enumerate() {
163 if 10 + i >= 16 { break; }
164 nonce[10 + i] = b;
165 }
166
167 let mut request = Vec::with_capacity(PROBE_REQUEST_LEN);
168 request.extend_from_slice(PROBE_MAGIC);
169 request.push(PROBE_VERSION);
170 request.extend_from_slice(&nonce);
171
172 socket.send_to(&request, probe_server)?;
173
174 let mut buf = [0u8; 64];
176 let (len, _) = socket.recv_from(&mut buf)?;
177
178 parse_probe_response(&buf[..len], &nonce)
179 .map(|addr| (addr, socket))
180 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid probe response"))
181}
182
183fn parse_probe_response(data: &[u8], expected_nonce: &[u8; 16]) -> Option<SocketAddr> {
184 if data.len() < 24 {
185 return None;
186 }
187 if &data[..4] != PROBE_MAGIC {
188 return None;
189 }
190 if data[4] != PROBE_VERSION {
191 return None;
192 }
193 if &data[5..21] != expected_nonce {
194 return None;
195 }
196
197 let addr_type = data[21];
198 match addr_type {
199 ADDR_TYPE_IPV4 => {
200 if data.len() < 28 {
201 return None;
202 }
203 let ip = std::net::Ipv4Addr::new(data[22], data[23], data[24], data[25]);
204 let port = u16::from_be_bytes([data[26], data[27]]);
205 Some(SocketAddr::new(ip.into(), port))
206 }
207 ADDR_TYPE_IPV6 => {
208 if data.len() < 40 {
209 return None;
210 }
211 let mut octets = [0u8; 16];
212 octets.copy_from_slice(&data[22..38]);
213 let ip = std::net::Ipv6Addr::from(octets);
214 let port = u16::from_be_bytes([data[38], data[39]]);
215 Some(SocketAddr::new(ip.into(), port))
216 }
217 _ => None,
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_probe_server_and_client() {
        let server_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let socket = UdpSocket::bind(server_addr).unwrap();
        let actual_addr = socket.local_addr().unwrap();
        socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();

        let running = Arc::new(AtomicBool::new(true));
        let running_clone = running.clone();
        let server_thread = thread::spawn(move || {
            run_probe_server(socket, running_clone);
        });

        // No start-up delay is needed: the server socket is already bound, so
        // the kernel queues the probe until the server thread calls recv_from.
        let result = probe_endpoint(actual_addr, None, Duration::from_secs(3), None);

        // Stop the server before asserting so a failure does not leave it running.
        running.store(false, Ordering::Relaxed);
        let _ = server_thread.join();

        let (observed, client_socket) = result.unwrap();
        assert_eq!(observed.ip(), std::net::Ipv4Addr::new(127, 0, 0, 1));
        assert!(observed.port() > 0);
        // On loopback there is no NAT, so the observed port is the client's own.
        assert_eq!(observed.port(), client_socket.local_addr().unwrap().port());
    }

    #[test]
    fn test_probe_response_roundtrip() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let parsed = parse_probe_response(&response, &nonce).unwrap();
        assert_eq!(parsed, addr);
    }

    #[test]
    fn test_probe_response_ipv6() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "[::1]:52000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let parsed = parse_probe_response(&response, &nonce).unwrap();
        assert_eq!(parsed, addr);
    }

    #[test]
    fn test_probe_response_bad_nonce() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let response = build_probe_response(&nonce, &addr);
        let wrong_nonce = [0x99u8; 16];
        assert!(parse_probe_response(&response, &wrong_nonce).is_none());
    }

    #[test]
    fn test_probe_response_truncated() {
        let nonce = [0x42u8; 16];
        for addr in ["1.2.3.4:41000", "[::1]:52000"] {
            let addr: SocketAddr = addr.parse().unwrap();
            let response = build_probe_response(&nonce, &addr);
            for len in 0..response.len() {
                assert!(
                    parse_probe_response(&response[..len], &nonce).is_none(),
                    "accepted a {}-byte prefix of a response for {}",
                    len,
                    addr
                );
            }
        }
    }

    #[test]
    fn test_probe_response_unknown_addr_type() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();
        let mut response = build_probe_response(&nonce, &addr);
        response[21] = 9;
        assert!(parse_probe_response(&response, &nonce).is_none());
    }

    #[test]
    fn test_probe_response_bad_header() {
        let nonce = [0x42u8; 16];
        let addr: SocketAddr = "1.2.3.4:41000".parse().unwrap();

        let mut bad_magic = build_probe_response(&nonce, &addr);
        bad_magic[0] ^= 0xFF;
        assert!(parse_probe_response(&bad_magic, &nonce).is_none());

        let mut bad_version = build_probe_response(&nonce, &addr);
        bad_version[4] = PROBE_VERSION.wrapping_add(1);
        assert!(parse_probe_response(&bad_version, &nonce).is_none());
    }
}