use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use bytes::{BufMut, Bytes, BytesMut};
use crate::error::{Error, Result};
/// Kind of a holepunch message, carried on the wire as a single byte.
///
/// The discriminants are exactly the byte values that
/// `HolepunchMessage::to_bytes` writes as the first byte of a message and
/// that `HolepunchMessage::from_bytes` accepts there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum HolepunchMsgType {
    /// Wire value `0x00`.
    Rendezvous = 0x00,
    /// Wire value `0x01`.
    Connect = 0x01,
    /// Wire value `0x02`; messages of this type carry a non-zero error code
    /// when built through `HolepunchMessage::error`.
    Error = 0x02,
}
/// Error codes that can be carried in the `error_code` field of a holepunch
/// message.
///
/// The discriminants are the numeric codes themselves; `0` is not a valid
/// error (the non-error constructors of `HolepunchMessage` use it as the
/// "no error" value).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum HolepunchError {
    /// Code `1` — displayed as "no such peer".
    NoSuchPeer = 1,
    /// Code `2` — displayed as "not connected to target".
    NotConnected = 2,
    /// Code `3` — displayed as "target does not support holepunch".
    NoSupport = 3,
    /// Code `4` — displayed as "cannot holepunch to self".
    NoSelf = 4,
}
impl HolepunchError {
pub fn from_u32(code: u32) -> Option<Self> {
match code {
1 => Some(Self::NoSuchPeer),
2 => Some(Self::NotConnected),
3 => Some(Self::NoSupport),
4 => Some(Self::NoSelf),
_ => None,
}
}
}
impl std::fmt::Display for HolepunchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoSuchPeer => write!(f, "no such peer"),
Self::NotConnected => write!(f, "not connected to target"),
Self::NoSupport => write!(f, "target does not support holepunch"),
Self::NoSelf => write!(f, "cannot holepunch to self"),
}
}
}
/// A decoded holepunch message: a type tag, one socket address and an error
/// code.
///
/// NOTE(review): the field layout looks like the BitTorrent holepunch
/// extension (BEP 55) — confirm against the protocol handler that uses it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HolepunchMessage {
    /// Which kind of message this is.
    pub msg_type: HolepunchMsgType,
    /// IPv4 or IPv6 address and port carried by the message.
    pub addr: SocketAddr,
    /// Raw numeric error code. The `rendezvous` and `connect` constructors
    /// set it to `0`; `error` sets it to the `HolepunchError` discriminant.
    pub error_code: u32,
}
impl HolepunchMessage {
pub fn rendezvous(target: SocketAddr) -> Self {
Self {
msg_type: HolepunchMsgType::Rendezvous,
addr: target,
error_code: 0,
}
}
pub fn connect(addr: SocketAddr) -> Self {
Self {
msg_type: HolepunchMsgType::Connect,
addr,
error_code: 0,
}
}
pub fn error(addr: SocketAddr, error: HolepunchError) -> Self {
Self {
msg_type: HolepunchMsgType::Error,
addr,
error_code: error as u32,
}
}
fn wire_size(&self) -> usize {
let addr_len = match self.addr.ip() {
IpAddr::V4(_) => 4,
IpAddr::V6(_) => 16,
};
1 + 1 + addr_len + 2 + 4
}
pub fn to_bytes(&self) -> Bytes {
let mut buf = BytesMut::with_capacity(self.wire_size());
buf.put_u8(self.msg_type as u8);
match self.addr.ip() {
IpAddr::V4(ip) => {
buf.put_u8(0x00);
buf.put_slice(&ip.octets());
}
IpAddr::V6(ip) => {
buf.put_u8(0x01);
buf.put_slice(&ip.octets());
}
}
buf.put_u16(self.addr.port());
buf.put_u32(self.error_code);
buf.freeze()
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < 2 {
return Err(Error::InvalidExtended("holepunch message too short".into()));
}
let msg_type = match data[0] {
0x00 => HolepunchMsgType::Rendezvous,
0x01 => HolepunchMsgType::Connect,
0x02 => HolepunchMsgType::Error,
n => {
return Err(Error::InvalidExtended(format!(
"unknown holepunch msg_type {n:#04x}"
)));
}
};
let addr_type = data[1];
let (addr_len, expected_total) = match addr_type {
0x00 => (4usize, 12usize), 0x01 => (16usize, 24usize), n => {
return Err(Error::InvalidExtended(format!(
"unknown holepunch addr_type {n:#04x}"
)));
}
};
if data.len() < expected_total {
return Err(Error::InvalidExtended(format!(
"holepunch message too short: need {expected_total} bytes, got {}",
data.len()
)));
}
let addr_start = 2;
let ip: IpAddr = if addr_type == 0x00 {
let o = &data[addr_start..addr_start + 4];
IpAddr::V4(Ipv4Addr::new(o[0], o[1], o[2], o[3]))
} else {
let mut octets = [0u8; 16];
octets.copy_from_slice(&data[addr_start..addr_start + 16]);
IpAddr::V6(Ipv6Addr::from(octets))
};
let port_start = addr_start + addr_len;
let port = u16::from_be_bytes([data[port_start], data[port_start + 1]]);
let err_start = port_start + 2;
let error_code = u32::from_be_bytes([
data[err_start],
data[err_start + 1],
data[err_start + 2],
data[err_start + 3],
]);
Ok(HolepunchMessage {
msg_type,
addr: SocketAddr::new(ip, port),
error_code,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket address literal, panicking if it is malformed.
    fn sock(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Serializes `msg` and parses the resulting bytes back into a message.
    fn reparse(msg: &HolepunchMessage) -> HolepunchMessage {
        HolepunchMessage::from_bytes(&msg.to_bytes()).unwrap()
    }

    /// Wire bytes of a rendezvous message for `1.2.3.4:80`, as an owned
    /// vector so individual tests can corrupt or extend it.
    fn sample_wire() -> Vec<u8> {
        HolepunchMessage::rendezvous(sock("1.2.3.4:80"))
            .to_bytes()
            .to_vec()
    }

    #[test]
    fn rendezvous_ipv4_round_trip() {
        let addr = sock("192.168.1.100:6881");
        let msg = HolepunchMessage::rendezvous(addr);
        assert_eq!(msg.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(msg.addr, addr);
        assert_eq!(msg.error_code, 0);

        let wire = msg.to_bytes();
        assert_eq!(wire.len(), 12);
        assert_eq!(HolepunchMessage::from_bytes(&wire).unwrap(), msg);
    }

    #[test]
    fn connect_ipv4_round_trip() {
        let msg = HolepunchMessage::connect(sock("10.0.0.1:8080"));
        assert_eq!(msg.msg_type, HolepunchMsgType::Connect);
        assert_eq!(reparse(&msg), msg);
    }

    #[test]
    fn error_ipv4_round_trip() {
        let msg = HolepunchMessage::error(sock("172.16.0.5:51413"), HolepunchError::NotConnected);
        assert_eq!(msg.msg_type, HolepunchMsgType::Error);
        assert_eq!(msg.error_code, 2);
        assert_eq!(reparse(&msg), msg);
    }

    #[test]
    fn rendezvous_ipv6_round_trip() {
        let msg = HolepunchMessage::rendezvous(sock("[2001:db8::1]:6881"));
        let wire = msg.to_bytes();
        assert_eq!(wire.len(), 24);
        assert_eq!(HolepunchMessage::from_bytes(&wire).unwrap(), msg);
    }

    #[test]
    fn connect_ipv6_round_trip() {
        let msg = HolepunchMessage::connect(sock("[::1]:8080"));
        assert_eq!(reparse(&msg), msg);
    }

    #[test]
    fn error_ipv6_all_error_codes() {
        let addr = sock("[fe80::1]:9999");
        let cases = [
            (1, HolepunchError::NoSuchPeer),
            (2, HolepunchError::NotConnected),
            (3, HolepunchError::NoSupport),
            (4, HolepunchError::NoSelf),
        ];
        for (code, variant) in cases {
            let msg = HolepunchMessage::error(addr, variant);
            assert_eq!(msg.error_code, code);
            assert_eq!(reparse(&msg).error_code, code);
            assert_eq!(HolepunchError::from_u32(code), Some(variant));
        }
    }

    #[test]
    fn unknown_msg_type_rejected() {
        let mut data = sample_wire();
        // 0x03 is not one of the message types `from_bytes` recognizes.
        data[0] = 0x03;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn unknown_addr_type_rejected() {
        let mut data = sample_wire();
        // Only 0x00 (IPv4) and 0x01 (IPv6) are valid address types.
        data[1] = 0x02;
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn too_short_rejected() {
        assert!(HolepunchMessage::from_bytes(&[]).is_err());
        assert!(HolepunchMessage::from_bytes(&[0x00]).is_err());
        // IPv4 message cut off after the port: 8 bytes instead of 12.
        let truncated = [0x00, 0x00, 1, 2, 3, 4, 0, 80];
        assert!(HolepunchMessage::from_bytes(&truncated).is_err());
    }

    #[test]
    fn ipv6_too_short_rejected() {
        // Declares an IPv6 address but supplies only 12 of the 24 bytes.
        let data = [0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert!(HolepunchMessage::from_bytes(&data).is_err());
    }

    #[test]
    fn error_code_unknown_parses_as_none() {
        for code in [0, 5, u32::MAX] {
            assert!(HolepunchError::from_u32(code).is_none());
        }
    }

    #[test]
    fn error_display() {
        let expectations = [
            (HolepunchError::NoSuchPeer, "no such peer"),
            (HolepunchError::NotConnected, "not connected to target"),
            (HolepunchError::NoSupport, "target does not support holepunch"),
            (HolepunchError::NoSelf, "cannot holepunch to self"),
        ];
        for (variant, text) in expectations {
            assert_eq!(variant.to_string(), text);
        }
    }

    #[test]
    fn wire_size_ipv4() {
        let msg = HolepunchMessage::rendezvous(sock("1.2.3.4:80"));
        assert_eq!(msg.wire_size(), 12);
    }

    #[test]
    fn wire_size_ipv6() {
        let msg = HolepunchMessage::rendezvous(sock("[::1]:80"));
        assert_eq!(msg.wire_size(), 24);
    }

    #[test]
    fn exact_wire_bytes_ipv4_rendezvous() {
        let bytes = HolepunchMessage::rendezvous(sock("192.168.1.100:6881")).to_bytes();
        // Message type: rendezvous.
        assert_eq!(bytes[0], 0x00);
        // Address type: IPv4.
        assert_eq!(bytes[1], 0x00);
        // Address octets.
        assert_eq!(&bytes[2..6], &[192, 168, 1, 100]);
        // Big-endian port.
        assert_eq!(u16::from_be_bytes([bytes[6], bytes[7]]), 6881);
        // Big-endian error code, zero for a rendezvous message.
        assert_eq!(
            u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]),
            0
        );
    }

    #[test]
    fn extra_trailing_bytes_ignored() {
        let mut data = sample_wire();
        data.extend_from_slice(&[0xFF, 0xAA]);
        let parsed = HolepunchMessage::from_bytes(&data).unwrap();
        assert_eq!(parsed.msg_type, HolepunchMsgType::Rendezvous);
        assert_eq!(parsed.addr, sock("1.2.3.4:80"));
    }

    #[test]
    fn port_zero_accepted() {
        let msg = HolepunchMessage::rendezvous(sock("1.2.3.4:0"));
        assert_eq!(reparse(&msg).addr.port(), 0);
    }
}