1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11
12use bytes::{BufMut, Bytes, BytesMut};
13
14use crate::error::{Error, Result};
15
/// Kind of a holepunch extension message.
///
/// The discriminant is the exact byte written first on the wire by
/// `HolepunchMessage::to_bytes` and matched by `HolepunchMessage::from_bytes`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum HolepunchMsgType {
    /// Request naming a target peer address (wire byte `0x00`).
    Rendezvous = 0x00,
    /// Message carrying a peer address to connect to (wire byte `0x01`).
    Connect = 0x01,
    /// Message carrying a non-zero `HolepunchError` code (wire byte `0x02`).
    Error = 0x02,
}
27
/// Error codes carried in the `error_code` field of a holepunch error message.
///
/// Each discriminant is the `u32` value that goes on the wire (big-endian);
/// decode a received value with `HolepunchError::from_u32`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum HolepunchError {
    /// Code `1`: "no such peer".
    NoSuchPeer = 0x01,
    /// Code `2`: "not connected to target".
    NotConnected = 0x02,
    /// Code `3`: "target does not support holepunch".
    NoSupport = 0x03,
    /// Code `4`: "cannot holepunch to self".
    NoSelf = 0x04,
}
41
42impl HolepunchError {
43 #[must_use]
45 pub fn from_u32(code: u32) -> Option<Self> {
46 match code {
47 1 => Some(Self::NoSuchPeer),
48 2 => Some(Self::NotConnected),
49 3 => Some(Self::NoSupport),
50 4 => Some(Self::NoSelf),
51 _ => None,
52 }
53 }
54}
55
56impl std::fmt::Display for HolepunchError {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 match self {
59 Self::NoSuchPeer => write!(f, "no such peer"),
60 Self::NotConnected => write!(f, "not connected to target"),
61 Self::NoSupport => write!(f, "target does not support holepunch"),
62 Self::NoSelf => write!(f, "cannot holepunch to self"),
63 }
64 }
65}
66
67#[derive(Debug, Clone, PartialEq, Eq)]
69pub struct HolepunchMessage {
70 pub msg_type: HolepunchMsgType,
72 pub addr: SocketAddr,
74 pub error_code: u32,
76}
77
78impl HolepunchMessage {
79 #[must_use]
81 pub fn rendezvous(target: SocketAddr) -> Self {
82 Self {
83 msg_type: HolepunchMsgType::Rendezvous,
84 addr: target,
85 error_code: 0,
86 }
87 }
88
89 #[must_use]
91 pub fn connect(addr: SocketAddr) -> Self {
92 Self {
93 msg_type: HolepunchMsgType::Connect,
94 addr,
95 error_code: 0,
96 }
97 }
98
99 #[must_use]
101 pub fn error(addr: SocketAddr, error: HolepunchError) -> Self {
102 Self {
103 msg_type: HolepunchMsgType::Error,
104 addr,
105 error_code: error as u32,
106 }
107 }
108
109 fn wire_size(&self) -> usize {
111 let addr_len = match self.addr.ip() {
113 IpAddr::V4(_) => 4,
114 IpAddr::V6(_) => 16,
115 };
116 1 + 1 + addr_len + 2 + 4
117 }
118
119 #[must_use]
121 pub fn to_bytes(&self) -> Bytes {
122 let mut buf = BytesMut::with_capacity(self.wire_size());
123 buf.put_u8(self.msg_type as u8);
124 match self.addr.ip() {
125 IpAddr::V4(ip) => {
126 buf.put_u8(0x00);
127 buf.put_slice(&ip.octets());
128 }
129 IpAddr::V6(ip) => {
130 buf.put_u8(0x01);
131 buf.put_slice(&ip.octets());
132 }
133 }
134 buf.put_u16(self.addr.port());
135 buf.put_u32(self.error_code);
136 buf.freeze()
137 }
138
139 pub fn from_bytes(data: &[u8]) -> Result<Self> {
145 if data.len() < 2 {
146 return Err(Error::InvalidExtended("holepunch message too short".into()));
147 }
148
149 let msg_type = match data[0] {
150 0x00 => HolepunchMsgType::Rendezvous,
151 0x01 => HolepunchMsgType::Connect,
152 0x02 => HolepunchMsgType::Error,
153 n => {
154 return Err(Error::InvalidExtended(format!(
155 "unknown holepunch msg_type {n:#04x}"
156 )));
157 }
158 };
159
160 let addr_type = data[1];
161 let (addr_len, expected_total) = match addr_type {
162 0x00 => (4usize, 12usize), 0x01 => (16usize, 24usize), n => {
165 return Err(Error::InvalidExtended(format!(
166 "unknown holepunch addr_type {n:#04x}"
167 )));
168 }
169 };
170
171 if data.len() < expected_total {
172 return Err(Error::InvalidExtended(format!(
173 "holepunch message too short: need {expected_total} bytes, got {}",
174 data.len()
175 )));
176 }
177
178 let addr_start = 2;
179 let ip: IpAddr = if addr_type == 0x00 {
180 let o = &data[addr_start..addr_start + 4];
181 IpAddr::V4(Ipv4Addr::new(o[0], o[1], o[2], o[3]))
182 } else {
183 let mut octets = [0u8; 16];
184 octets.copy_from_slice(&data[addr_start..addr_start + 16]);
185 IpAddr::V6(Ipv6Addr::from(octets))
186 };
187
188 let port_start = addr_start + addr_len;
189 let port = u16::from_be_bytes([data[port_start], data[port_start + 1]]);
190
191 let err_start = port_start + 2;
192 let error_code = u32::from_be_bytes([
193 data[err_start],
194 data[err_start + 1],
195 data[err_start + 2],
196 data[err_start + 3],
197 ]);
198
199 Ok(Self {
200 msg_type,
201 addr: SocketAddr::new(ip, port),
202 error_code,
203 })
204 }
205}
206
#[cfg(test)]
mod tests {
    //! Round-trip, wire-layout and rejection tests for holepunch messages.
    use super::*;

    #[test]
    fn rendezvous_ipv4_round_trip() {
        let addr: SocketAddr = "192.168.1.100:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);
        assert_eq!(msg.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(msg.addr, addr);
        assert_eq!(msg.error_code, 0);

        let bytes = msg.to_bytes();
        assert_eq!(bytes.len(), 12);
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn connect_ipv4_round_trip() {
        let addr: SocketAddr = "10.0.0.1:8080".parse().unwrap();
        let msg = HolepunchMessage::connect(addr);
        assert_eq!(msg.msg_type, HolepunchMsgType::Connect);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn error_ipv4_round_trip() {
        let addr: SocketAddr = "172.16.0.5:51413".parse().unwrap();
        let msg = HolepunchMessage::error(addr, HolepunchError::NotConnected);
        assert_eq!(msg.msg_type, HolepunchMsgType::Error);
        assert_eq!(msg.error_code, 2);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn rendezvous_ipv6_round_trip() {
        let addr: SocketAddr = "[2001:db8::1]:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);

        let bytes = msg.to_bytes();
        assert_eq!(bytes.len(), 24);
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn connect_ipv6_round_trip() {
        let addr: SocketAddr = "[::1]:8080".parse().unwrap();
        let msg = HolepunchMessage::connect(addr);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn error_ipv6_all_error_codes() {
        let addr: SocketAddr = "[fe80::1]:9999".parse().unwrap();

        for (code, variant) in [
            (1, HolepunchError::NoSuchPeer),
            (2, HolepunchError::NotConnected),
            (3, HolepunchError::NoSupport),
            (4, HolepunchError::NoSelf),
        ] {
            let msg = HolepunchMessage::error(addr, variant);
            assert_eq!(msg.error_code, code);

            let bytes = msg.to_bytes();
            let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
            assert_eq!(parsed.error_code, code);
            assert_eq!(HolepunchError::from_u32(code), Some(variant));
        }
    }

    #[test]
    fn msg_type_byte_matches_discriminant() {
        let addr: SocketAddr = "1.2.3.4:80".parse().unwrap();
        for msg in [
            HolepunchMessage::rendezvous(addr),
            HolepunchMessage::connect(addr),
            HolepunchMessage::error(addr, HolepunchError::NoSuchPeer),
        ] {
            let bytes = msg.to_bytes();
            assert_eq!(bytes[0], msg.msg_type as u8);
            let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
            assert_eq!(parsed.msg_type, msg.msg_type);
        }
    }

    #[test]
    fn unknown_msg_type_rejected() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        // 0x03 is not a defined msg_type.
        data[0] = 0x03;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn unknown_addr_type_rejected() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        // 0x02 is not a defined addr_type.
        data[1] = 0x02;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn too_short_rejected() {
        assert!(HolepunchMessage::from_bytes(&[]).is_err());
        assert!(HolepunchMessage::from_bytes(&[0x00]).is_err());
        // IPv4 header + address + port, but no error code.
        assert!(HolepunchMessage::from_bytes(&[0x00, 0x00, 1, 2, 3, 4, 0, 80]).is_err());
    }

    #[test]
    fn ipv6_too_short_rejected() {
        // 12 bytes is a full IPv4 message but far too short for addr_type IPv6.
        let data = [0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn truncated_by_one_byte_rejected() {
        for addr in ["1.2.3.4:80", "[::1]:80"] {
            let full = HolepunchMessage::connect(addr.parse().unwrap()).to_bytes();
            assert!(HolepunchMessage::from_bytes(&full).is_ok());
            assert!(HolepunchMessage::from_bytes(&full[..full.len() - 1]).is_err());
        }
    }

    #[test]
    fn error_code_unknown_parses_as_none() {
        assert!(HolepunchError::from_u32(0).is_none());
        assert!(HolepunchError::from_u32(5).is_none());
        assert!(HolepunchError::from_u32(u32::MAX).is_none());
    }

    #[test]
    fn error_display() {
        assert_eq!(HolepunchError::NoSuchPeer.to_string(), "no such peer");
        assert_eq!(
            HolepunchError::NotConnected.to_string(),
            "not connected to target"
        );
        assert_eq!(
            HolepunchError::NoSupport.to_string(),
            "target does not support holepunch"
        );
        assert_eq!(
            HolepunchError::NoSelf.to_string(),
            "cannot holepunch to self"
        );
    }

    #[test]
    fn wire_size_ipv4() {
        let msg = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap());
        assert_eq!(msg.wire_size(), 12);
    }

    #[test]
    fn wire_size_ipv6() {
        let msg = HolepunchMessage::rendezvous("[::1]:80".parse().unwrap());
        assert_eq!(msg.wire_size(), 24);
    }

    #[test]
    fn exact_wire_bytes_ipv4_rendezvous() {
        let addr: SocketAddr = "192.168.1.100:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);
        let bytes = msg.to_bytes();

        assert_eq!(bytes[0], 0x00); // msg_type: Rendezvous
        assert_eq!(bytes[1], 0x00); // addr_type: IPv4
        assert_eq!(&bytes[2..6], &[192, 168, 1, 100]);
        assert_eq!(u16::from_be_bytes([bytes[6], bytes[7]]), 6881);
        assert_eq!(
            u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]),
            0
        );
    }

    #[test]
    fn exact_wire_bytes_ipv6_error() {
        let addr: SocketAddr = "[2001:db8::1]:6881".parse().unwrap();
        let bytes = HolepunchMessage::error(addr, HolepunchError::NoSelf).to_bytes();
        let ip: Ipv6Addr = "2001:db8::1".parse().unwrap();

        assert_eq!(bytes.len(), 24);
        assert_eq!(bytes[0], 0x02); // msg_type: Error
        assert_eq!(bytes[1], 0x01); // addr_type: IPv6
        assert_eq!(&bytes[2..18], &ip.octets());
        assert_eq!(&bytes[18..20], &6881u16.to_be_bytes());
        assert_eq!(&bytes[20..24], &4u32.to_be_bytes());
    }

    #[test]
    fn extra_trailing_bytes_ignored() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        data.push(0xFF);
        data.push(0xAA);
        let parsed = HolepunchMessage::from_bytes(&data).unwrap();
        assert_eq!(parsed.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(parsed.addr, "1.2.3.4:80".parse().unwrap());
    }

    #[test]
    fn port_zero_accepted() {
        let msg = HolepunchMessage::rendezvous("1.2.3.4:0".parse().unwrap());
        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.addr.port(), 0);
    }
}