1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11
12use bytes::{BufMut, Bytes, BytesMut};
13
14use crate::error::{Error, Result};
15
/// Kind of holepunch message, stored in the first byte of the encoded form.
///
/// The explicit discriminants are the on-wire byte values, so
/// `msg_type as u8` is exactly what gets written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum HolepunchMsgType {
    /// Wire value `0`.
    Rendezvous = 0,
    /// Wire value `1`.
    Connect = 1,
    /// Wire value `2`; the message's error code field carries the reason.
    Error = 2,
}
27
/// Numeric reasons carried in the `error_code` field of an
/// `Error`-type holepunch message.
///
/// The discriminants are the wire values, so `err as u32` is the encoded code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum HolepunchError {
    /// Code `1`; displayed as "no such peer".
    NoSuchPeer = 0x1,
    /// Code `2`; displayed as "not connected to target".
    NotConnected = 0x2,
    /// Code `3`; displayed as "target does not support holepunch".
    NoSupport = 0x3,
    /// Code `4`; displayed as "cannot holepunch to self".
    NoSelf = 0x4,
}
41
42impl HolepunchError {
43 pub fn from_u32(code: u32) -> Option<Self> {
45 match code {
46 1 => Some(Self::NoSuchPeer),
47 2 => Some(Self::NotConnected),
48 3 => Some(Self::NoSupport),
49 4 => Some(Self::NoSelf),
50 _ => None,
51 }
52 }
53}
54
55impl std::fmt::Display for HolepunchError {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Self::NoSuchPeer => write!(f, "no such peer"),
59 Self::NotConnected => write!(f, "not connected to target"),
60 Self::NoSupport => write!(f, "target does not support holepunch"),
61 Self::NoSelf => write!(f, "cannot holepunch to self"),
62 }
63 }
64}
65
/// A single holepunch message, either built for sending or decoded from bytes.
///
/// Every message type carries an address and an error code on the wire; the
/// `rendezvous` and `connect` constructors set `error_code` to 0.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HolepunchMessage {
    /// Which kind of message this is; encoded as the first byte.
    pub msg_type: HolepunchMsgType,
    /// The peer address the message refers to. Only the IP and port are
    /// encoded, so an IPv6 address's flow info and scope id do not survive
    /// an encode/decode round trip.
    pub addr: SocketAddr,
    /// Raw error code as sent on the wire; `HolepunchError::from_u32`
    /// decodes it, returning `None` for codes without a variant.
    pub error_code: u32,
}
76
77impl HolepunchMessage {
78 pub fn rendezvous(target: SocketAddr) -> Self {
80 Self {
81 msg_type: HolepunchMsgType::Rendezvous,
82 addr: target,
83 error_code: 0,
84 }
85 }
86
87 pub fn connect(addr: SocketAddr) -> Self {
89 Self {
90 msg_type: HolepunchMsgType::Connect,
91 addr,
92 error_code: 0,
93 }
94 }
95
96 pub fn error(addr: SocketAddr, error: HolepunchError) -> Self {
98 Self {
99 msg_type: HolepunchMsgType::Error,
100 addr,
101 error_code: error as u32,
102 }
103 }
104
105 fn wire_size(&self) -> usize {
107 let addr_len = match self.addr.ip() {
109 IpAddr::V4(_) => 4,
110 IpAddr::V6(_) => 16,
111 };
112 1 + 1 + addr_len + 2 + 4
113 }
114
115 pub fn to_bytes(&self) -> Bytes {
117 let mut buf = BytesMut::with_capacity(self.wire_size());
118 buf.put_u8(self.msg_type as u8);
119 match self.addr.ip() {
120 IpAddr::V4(ip) => {
121 buf.put_u8(0x00);
122 buf.put_slice(&ip.octets());
123 }
124 IpAddr::V6(ip) => {
125 buf.put_u8(0x01);
126 buf.put_slice(&ip.octets());
127 }
128 }
129 buf.put_u16(self.addr.port());
130 buf.put_u32(self.error_code);
131 buf.freeze()
132 }
133
134 pub fn from_bytes(data: &[u8]) -> Result<Self> {
136 if data.len() < 2 {
137 return Err(Error::InvalidExtended("holepunch message too short".into()));
138 }
139
140 let msg_type = match data[0] {
141 0x00 => HolepunchMsgType::Rendezvous,
142 0x01 => HolepunchMsgType::Connect,
143 0x02 => HolepunchMsgType::Error,
144 n => {
145 return Err(Error::InvalidExtended(format!(
146 "unknown holepunch msg_type {n:#04x}"
147 )));
148 }
149 };
150
151 let addr_type = data[1];
152 let (addr_len, expected_total) = match addr_type {
153 0x00 => (4usize, 12usize), 0x01 => (16usize, 24usize), n => {
156 return Err(Error::InvalidExtended(format!(
157 "unknown holepunch addr_type {n:#04x}"
158 )));
159 }
160 };
161
162 if data.len() < expected_total {
163 return Err(Error::InvalidExtended(format!(
164 "holepunch message too short: need {expected_total} bytes, got {}",
165 data.len()
166 )));
167 }
168
169 let addr_start = 2;
170 let ip: IpAddr = if addr_type == 0x00 {
171 let o = &data[addr_start..addr_start + 4];
172 IpAddr::V4(Ipv4Addr::new(o[0], o[1], o[2], o[3]))
173 } else {
174 let mut octets = [0u8; 16];
175 octets.copy_from_slice(&data[addr_start..addr_start + 16]);
176 IpAddr::V6(Ipv6Addr::from(octets))
177 };
178
179 let port_start = addr_start + addr_len;
180 let port = u16::from_be_bytes([data[port_start], data[port_start + 1]]);
181
182 let err_start = port_start + 2;
183 let error_code = u32::from_be_bytes([
184 data[err_start],
185 data[err_start + 1],
186 data[err_start + 2],
187 data[err_start + 3],
188 ]);
189
190 Ok(HolepunchMessage {
191 msg_type,
192 addr: SocketAddr::new(ip, port),
193 error_code,
194 })
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rendezvous_ipv4_round_trip() {
        let addr: SocketAddr = "192.168.1.100:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);
        assert_eq!(msg.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(msg.addr, addr);
        assert_eq!(msg.error_code, 0);

        let bytes = msg.to_bytes();
        // 1 (msg_type) + 1 (addr_type) + 4 (IPv4) + 2 (port) + 4 (err_code).
        assert_eq!(bytes.len(), 12);
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn connect_ipv4_round_trip() {
        let addr: SocketAddr = "10.0.0.1:8080".parse().unwrap();
        let msg = HolepunchMessage::connect(addr);
        assert_eq!(msg.msg_type, HolepunchMsgType::Connect);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn error_ipv4_round_trip() {
        let addr: SocketAddr = "172.16.0.5:51413".parse().unwrap();
        let msg = HolepunchMessage::error(addr, HolepunchError::NotConnected);
        assert_eq!(msg.msg_type, HolepunchMsgType::Error);
        assert_eq!(msg.error_code, 2);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn rendezvous_ipv6_round_trip() {
        let addr: SocketAddr = "[2001:db8::1]:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);

        let bytes = msg.to_bytes();
        // 1 (msg_type) + 1 (addr_type) + 16 (IPv6) + 2 (port) + 4 (err_code).
        assert_eq!(bytes.len(), 24);
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn connect_ipv6_round_trip() {
        let addr: SocketAddr = "[::1]:8080".parse().unwrap();
        let msg = HolepunchMessage::connect(addr);

        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn error_ipv6_all_error_codes() {
        let addr: SocketAddr = "[fe80::1]:9999".parse().unwrap();

        for (code, variant) in [
            (1, HolepunchError::NoSuchPeer),
            (2, HolepunchError::NotConnected),
            (3, HolepunchError::NoSupport),
            (4, HolepunchError::NoSelf),
        ] {
            let msg = HolepunchMessage::error(addr, variant);
            assert_eq!(msg.error_code, code);

            let bytes = msg.to_bytes();
            let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
            assert_eq!(parsed.error_code, code);
            assert_eq!(HolepunchError::from_u32(code), Some(variant));
        }
    }

    #[test]
    fn unknown_msg_type_rejected() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        // 0x03 is one past the highest defined msg_type (Error = 0x02).
        data[0] = 0x03;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn unknown_addr_type_rejected() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        // 0x02 is neither IPv4 (0x00) nor IPv6 (0x01).
        data[1] = 0x02;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn too_short_rejected() {
        assert!(HolepunchMessage::from_bytes(&[]).is_err());
        assert!(HolepunchMessage::from_bytes(&[0x00]).is_err());
        // IPv4 message with the 4-byte err_code missing.
        assert!(HolepunchMessage::from_bytes(&[0x00, 0x00, 1, 2, 3, 4, 0, 80]).is_err());
    }

    #[test]
    fn ipv6_too_short_rejected() {
        // addr_type says IPv6, but only 12 of the required 24 bytes are present.
        let data = [0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn every_truncated_prefix_rejected() {
        // Every strict prefix of a valid message, for both address families,
        // must be rejected rather than panicking or parsing garbage.
        for addr in ["1.2.3.4:80", "[::1]:80"] {
            let bytes = HolepunchMessage::connect(addr.parse().unwrap()).to_bytes();
            for len in 0..bytes.len() {
                assert!(
                    HolepunchMessage::from_bytes(&bytes[..len]).is_err(),
                    "{addr}: prefix of {len} bytes was accepted"
                );
            }
        }
    }

    #[test]
    fn error_code_unknown_parses_as_none() {
        assert!(HolepunchError::from_u32(0).is_none());
        assert!(HolepunchError::from_u32(5).is_none());
        assert!(HolepunchError::from_u32(u32::MAX).is_none());
    }

    #[test]
    fn unknown_error_code_round_trips_raw() {
        // The codec must not reject or rewrite codes it has no variant for.
        let mut msg =
            HolepunchMessage::error("1.2.3.4:80".parse().unwrap(), HolepunchError::NoSelf);
        msg.error_code = 99;

        let parsed = HolepunchMessage::from_bytes(&msg.to_bytes()).unwrap();
        assert_eq!(parsed.error_code, 99);
        assert_eq!(HolepunchError::from_u32(parsed.error_code), None);
    }

    #[test]
    fn error_display() {
        assert_eq!(HolepunchError::NoSuchPeer.to_string(), "no such peer");
        assert_eq!(
            HolepunchError::NotConnected.to_string(),
            "not connected to target"
        );
        assert_eq!(
            HolepunchError::NoSupport.to_string(),
            "target does not support holepunch"
        );
        assert_eq!(
            HolepunchError::NoSelf.to_string(),
            "cannot holepunch to self"
        );
    }

    #[test]
    fn wire_size_ipv4() {
        let msg = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap());
        assert_eq!(msg.wire_size(), 12);
    }

    #[test]
    fn wire_size_ipv6() {
        let msg = HolepunchMessage::rendezvous("[::1]:80".parse().unwrap());
        assert_eq!(msg.wire_size(), 24);
    }

    #[test]
    fn exact_wire_bytes_ipv4_rendezvous() {
        let addr: SocketAddr = "192.168.1.100:6881".parse().unwrap();
        let msg = HolepunchMessage::rendezvous(addr);
        let bytes = msg.to_bytes();

        assert_eq!(bytes[0], 0x00); // msg_type: Rendezvous
        assert_eq!(bytes[1], 0x00); // addr_type: IPv4
        assert_eq!(&bytes[2..6], &[192, 168, 1, 100]); // address
        assert_eq!(u16::from_be_bytes([bytes[6], bytes[7]]), 6881); // port
        assert_eq!(
            u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]),
            0
        ); // err_code
    }

    #[test]
    fn exact_wire_bytes_ipv6_error() {
        let addr: SocketAddr = "[2001:db8::1]:6881".parse().unwrap();
        let bytes = HolepunchMessage::error(addr, HolepunchError::NoSupport).to_bytes();

        assert_eq!(bytes.len(), 24);
        assert_eq!(bytes[0], 0x02); // msg_type: Error
        assert_eq!(bytes[1], 0x01); // addr_type: IPv6
        assert_eq!(
            &bytes[2..18],
            &[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01]
        ); // address
        assert_eq!(u16::from_be_bytes([bytes[18], bytes[19]]), 6881); // port
        assert_eq!(
            u32::from_be_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]),
            3
        ); // err_code: NoSupport
    }

    #[test]
    fn extra_trailing_bytes_ignored() {
        let mut data = HolepunchMessage::rendezvous("1.2.3.4:80".parse().unwrap())
            .to_bytes()
            .to_vec();
        data.push(0xFF);
        data.push(0xAA);
        let parsed = HolepunchMessage::from_bytes(&data).unwrap();
        assert_eq!(parsed.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(parsed.addr, "1.2.3.4:80".parse().unwrap());
    }

    #[test]
    fn port_zero_accepted() {
        let msg = HolepunchMessage::rendezvous("1.2.3.4:0".parse().unwrap());
        let bytes = msg.to_bytes();
        let parsed = HolepunchMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.addr.port(), 0);
    }
}