use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use bytes::{BufMut, Bytes, BytesMut};
/// Size in bytes of one encoded socket address as written by `encode_addr`:
/// 1 family byte + 16 address bytes + 2 port bytes.
const ADDR_LEN: usize = 19;
/// Family tag written before an IPv4 address (only the first 4 of the 16 address bytes are used).
const FAMILY_V4: u8 = 4;
/// Family tag written before an IPv6 address (all 16 address bytes are used).
const FAMILY_V6: u8 = 6;
/// Leading kind byte of an encoded `PunchRequest`.
const KIND_PUNCH_REQUEST: u8 = 0x01;
/// Leading kind byte of an encoded `PunchIntroduce`.
const KIND_PUNCH_INTRODUCE: u8 = 0x02;
/// Leading kind byte of an encoded `PunchAck`.
const KIND_PUNCH_ACK: u8 = 0x03;
/// Exact encoded length of a `PunchRequest`: kind (1) + target (8) + address (`ADDR_LEN`).
pub const PUNCH_REQUEST_LEN: usize = 1 + 8 + ADDR_LEN;
/// Exact encoded length of a `PunchIntroduce`: kind (1) + peer (8) + address (`ADDR_LEN`) + fire_at_ms (8).
pub const PUNCH_INTRODUCE_LEN: usize = 1 + 8 + ADDR_LEN + 8;
/// Exact encoded length of a `PunchAck`: kind (1) + from_peer (8) + to_peer (8) + punch_id (4).
pub const PUNCH_ACK_LEN: usize = 1 + 8 + 8 + 4;
/// Magic value opening every keep-alive frame; written little-endian, so the
/// first two wire bytes are `0x50, 0x48`.
pub const KEEPALIVE_MAGIC: u16 = 0x4850;
/// Exact encoded length of a keep-alive frame: magic (2) + sender_node_id (8) + punch_id (4).
pub const KEEPALIVE_LEN: usize = 2 + 8 + 4;
/// Decoded contents of a keep-alive frame.
///
/// `encode_keepalive` writes `KEEPALIVE_MAGIC`, then `sender_node_id`, then
/// `punch_id`, all little-endian, for `KEEPALIVE_LEN` bytes in total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Keepalive {
/// Carried as 8 little-endian bytes at offset 2.
/// Presumably the node id of the sender — confirm against callers.
pub sender_node_id: u64,
/// Carried as 4 little-endian bytes at offset 10.
/// Presumably identifies the punch attempt being kept alive — confirm against callers.
pub punch_id: u32,
}
pub fn encode_keepalive(ka: &Keepalive) -> Bytes {
let mut buf = BytesMut::with_capacity(KEEPALIVE_LEN);
buf.put_slice(&KEEPALIVE_MAGIC.to_le_bytes());
buf.put_u64_le(ka.sender_node_id);
buf.put_u32_le(ka.punch_id);
debug_assert_eq!(buf.len(), KEEPALIVE_LEN);
buf.freeze()
}
/// Parses a keep-alive frame.
///
/// Returns `None` unless `data` is exactly `KEEPALIVE_LEN` bytes long and
/// starts with `KEEPALIVE_MAGIC` (little-endian). The remaining bytes are read
/// as a little-endian `u64` sender node id followed by a little-endian `u32`
/// punch id.
pub fn decode_keepalive(data: &[u8]) -> Option<Keepalive> {
    // Slice-to-array conversion only succeeds on an exact length match, so
    // both short and over-long inputs are rejected here.
    let frame: &[u8; KEEPALIVE_LEN] = data.try_into().ok()?;

    let (magic, body) = frame.split_at(2);
    let (node, punch) = body.split_at(8);

    match u16::from_le_bytes(magic.try_into().ok()?) {
        KEEPALIVE_MAGIC => Some(Keepalive {
            sender_node_id: u64::from_le_bytes(node.try_into().ok()?),
            punch_id: u32::from_le_bytes(punch.try_into().ok()?),
        }),
        _ => None,
    }
}
/// Body of a `KIND_PUNCH_REQUEST` message.
///
/// Encoded as the kind byte, `target` (8 bytes, big-endian), then
/// `self_reflex` as an `ADDR_LEN`-byte address record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PunchRequest {
/// Id of the target node; carried big-endian at wire offset 1.
pub target: u64,
/// Socket address carried at wire offset 9.
/// Presumably the sender's own reflexive (externally observed) address — confirm against callers.
pub self_reflex: SocketAddr,
}
/// Body of a `KIND_PUNCH_INTRODUCE` message.
///
/// Encoded as the kind byte, `peer` (8 bytes, big-endian), `peer_reflex` as an
/// `ADDR_LEN`-byte address record, then `fire_at_ms` (8 bytes, big-endian).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PunchIntroduce {
/// Carried big-endian at wire offset 1.
/// Presumably the node id of the peer being introduced — confirm against callers.
pub peer: u64,
/// Socket address carried at wire offset 9.
/// Presumably that peer's reflexive address — confirm against callers.
pub peer_reflex: SocketAddr,
/// Carried big-endian at wire offset 28.
/// NOTE(review): the name suggests a millisecond timestamp at which to start punching;
/// the clock/epoch is not visible in this file — confirm.
pub fire_at_ms: u64,
}
/// Body of a `KIND_PUNCH_ACK` message.
///
/// Encoded as the kind byte, `from_peer` (8 bytes), `to_peer` (8 bytes) and
/// `punch_id` (4 bytes), all big-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PunchAck {
/// Carried big-endian at wire offset 1.
/// Presumably the node id of the acknowledging peer — confirm against callers.
pub from_peer: u64,
/// Carried big-endian at wire offset 9.
/// Presumably the node id the acknowledgement is addressed to — confirm against callers.
pub to_peer: u64,
/// Carried big-endian at wire offset 17.
/// Presumably identifies the punch attempt being acknowledged — confirm against callers.
pub punch_id: u32,
}
/// A rendezvous message; the variant determines the leading kind byte on the
/// wire (`KIND_PUNCH_REQUEST`, `KIND_PUNCH_INTRODUCE` or `KIND_PUNCH_ACK`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RendezvousMsg {
/// Encoded with kind byte `KIND_PUNCH_REQUEST`.
PunchRequest(PunchRequest),
/// Encoded with kind byte `KIND_PUNCH_INTRODUCE`.
PunchIntroduce(PunchIntroduce),
/// Encoded with kind byte `KIND_PUNCH_ACK`.
PunchAck(PunchAck),
}
impl RendezvousMsg {
    /// Serializes the message to its wire form by delegating to the encoder
    /// for the active variant. The first byte of the result is that variant's
    /// kind byte.
    pub fn encode(&self) -> Bytes {
        match *self {
            Self::PunchRequest(ref request) => encode_punch_request(request),
            Self::PunchIntroduce(ref introduce) => encode_punch_introduce(introduce),
            Self::PunchAck(ref ack) => encode_punch_ack(ack),
        }
    }
}
/// Appends `addr` to `buf` as a fixed `ADDR_LEN`-byte record:
/// family tag (1 byte), address field (16 bytes), port (2 bytes, big-endian).
///
/// An IPv4 address occupies the first 4 bytes of the address field and the
/// remaining 12 are zero. For IPv6 only the 16 address octets and the port are
/// written; flow info and scope id are not part of the record.
fn encode_addr(buf: &mut BytesMut, addr: SocketAddr) {
    let (family, field) = match addr {
        SocketAddr::V4(v4) => {
            let mut padded = [0u8; 16];
            padded[..4].copy_from_slice(&v4.ip().octets());
            (FAMILY_V4, padded)
        }
        SocketAddr::V6(v6) => (FAMILY_V6, v6.ip().octets()),
    };

    buf.put_u8(family);
    buf.put_slice(&field);
    buf.put_u16(addr.port());
}
/// Parses one `ADDR_LEN`-byte address record: family tag, 16-byte address
/// field, big-endian port.
///
/// Returns `None` when `bytes` is not exactly `ADDR_LEN` long or the family
/// tag is neither `FAMILY_V4` nor `FAMILY_V6`. For IPv4 only the first four
/// bytes of the address field are read; the other twelve are not validated.
fn decode_addr(bytes: &[u8]) -> Option<SocketAddr> {
    // Exact-length conversion rejects both truncated and over-long records.
    let record: &[u8; ADDR_LEN] = bytes.try_into().ok()?;
    let (&family, rest) = record.split_first()?;
    let (host, port) = rest.split_at(16);

    let ip = match family {
        FAMILY_V4 => IpAddr::V4(Ipv4Addr::new(host[0], host[1], host[2], host[3])),
        FAMILY_V6 => {
            let octets: [u8; 16] = host.try_into().ok()?;
            IpAddr::V6(Ipv6Addr::from(octets))
        }
        _ => return None,
    };

    Some(SocketAddr::new(ip, u16::from_be_bytes([port[0], port[1]])))
}
fn encode_punch_request(req: &PunchRequest) -> Bytes {
let mut buf = BytesMut::with_capacity(PUNCH_REQUEST_LEN);
buf.put_u8(KIND_PUNCH_REQUEST);
buf.put_u64(req.target);
encode_addr(&mut buf, req.self_reflex);
debug_assert_eq!(buf.len(), PUNCH_REQUEST_LEN);
buf.freeze()
}
fn encode_punch_introduce(intro: &PunchIntroduce) -> Bytes {
let mut buf = BytesMut::with_capacity(PUNCH_INTRODUCE_LEN);
buf.put_u8(KIND_PUNCH_INTRODUCE);
buf.put_u64(intro.peer);
encode_addr(&mut buf, intro.peer_reflex);
buf.put_u64(intro.fire_at_ms);
debug_assert_eq!(buf.len(), PUNCH_INTRODUCE_LEN);
buf.freeze()
}
fn encode_punch_ack(ack: &PunchAck) -> Bytes {
let mut buf = BytesMut::with_capacity(PUNCH_ACK_LEN);
buf.put_u8(KIND_PUNCH_ACK);
buf.put_u64(ack.from_peer);
buf.put_u64(ack.to_peer);
buf.put_u32(ack.punch_id);
debug_assert_eq!(buf.len(), PUNCH_ACK_LEN);
buf.freeze()
}
pub fn decode(payload: &[u8]) -> Option<RendezvousMsg> {
let &kind = payload.first()?;
match kind {
KIND_PUNCH_REQUEST => {
if payload.len() != PUNCH_REQUEST_LEN {
return None;
}
let target = u64::from_be_bytes(payload[1..9].try_into().ok()?);
let self_reflex = decode_addr(&payload[9..28])?;
Some(RendezvousMsg::PunchRequest(PunchRequest {
target,
self_reflex,
}))
}
KIND_PUNCH_INTRODUCE => {
if payload.len() != PUNCH_INTRODUCE_LEN {
return None;
}
let peer = u64::from_be_bytes(payload[1..9].try_into().ok()?);
let peer_reflex = decode_addr(&payload[9..28])?;
let fire_at_ms = u64::from_be_bytes(payload[28..36].try_into().ok()?);
Some(RendezvousMsg::PunchIntroduce(PunchIntroduce {
peer,
peer_reflex,
fire_at_ms,
}))
}
KIND_PUNCH_ACK => {
if payload.len() != PUNCH_ACK_LEN {
return None;
}
let from_peer = u64::from_be_bytes(payload[1..9].try_into().ok()?);
let to_peer = u64::from_be_bytes(payload[9..17].try_into().ok()?);
let punch_id = u32::from_be_bytes(payload[17..21].try_into().ok()?);
Some(RendezvousMsg::PunchAck(PunchAck {
from_peer,
to_peer,
punch_id,
}))
}
_ => None,
}
}
/// Unit tests for the rendezvous and keep-alive codecs: round-trips, exact
/// wire layout, and rejection of malformed input.
#[cfg(test)]
mod tests {
use super::*;
/// Parses a socket-address literal; panics if the literal is invalid.
fn sa(addr: &str) -> SocketAddr {
addr.parse().unwrap()
}
/// An IPv4 `PunchRequest` encodes to `PUNCH_REQUEST_LEN` bytes and decodes back unchanged.
#[test]
fn punch_request_roundtrip_ipv4() {
let req = PunchRequest {
target: 0x1122_3344_5566_7788,
self_reflex: sa("192.0.2.1:9001"),
};
let encoded = RendezvousMsg::PunchRequest(req).encode();
assert_eq!(encoded.len(), PUNCH_REQUEST_LEN);
match decode(&encoded) {
Some(RendezvousMsg::PunchRequest(out)) => assert_eq!(out, req),
other => panic!("expected PunchRequest, got {other:?}"),
}
}
/// An IPv6 `PunchRequest` decodes back unchanged.
#[test]
fn punch_request_roundtrip_ipv6() {
let req = PunchRequest {
target: 42,
self_reflex: sa("[2001:db8::1]:443"),
};
let encoded = RendezvousMsg::PunchRequest(req).encode();
match decode(&encoded) {
Some(RendezvousMsg::PunchRequest(out)) => assert_eq!(out, req),
other => panic!("expected PunchRequest, got {other:?}"),
}
}
/// Pins the byte-for-byte layout of an IPv4 `PunchRequest`:
/// kind at 0, target at 1..9, family at 9, address field at 10..26, port at 26..28.
#[test]
fn punch_request_wire_layout_matches_doc() {
let req = PunchRequest {
target: 0x0102_0304_0506_0708,
self_reflex: sa("192.0.2.7:9001"),
};
let encoded = RendezvousMsg::PunchRequest(req).encode();
assert_eq!(encoded.len(), PUNCH_REQUEST_LEN, "total wire length");
assert_eq!(PUNCH_REQUEST_LEN, 28, "kind(1) + body(27)");
assert_eq!(
PUNCH_REQUEST_LEN - 1,
27,
"body = 8 (target) + 19 (self_reflex)",
);
assert_eq!(encoded[0], 0x01, "kind byte = KIND_PUNCH_REQUEST");
assert_eq!(
&encoded[1..9],
&0x0102_0304_0506_0708_u64.to_be_bytes(),
"target_node big-endian at offset 1",
);
assert_eq!(encoded[9], FAMILY_V4, "family byte at offset 9");
assert_eq!(&encoded[10..14], &[192, 0, 2, 7], "IPv4 in low 4 bytes");
assert_eq!(
&encoded[14..26],
&[0u8; 12],
"upper 12 bytes of address field zero-padded for IPv4",
);
assert_eq!(
&encoded[26..28],
&9001_u16.to_be_bytes(),
"port big-endian at offset 26",
);
}
/// A `PunchIntroduce` encodes to `PUNCH_INTRODUCE_LEN` bytes and decodes back unchanged.
#[test]
fn punch_introduce_roundtrip() {
let intro = PunchIntroduce {
peer: 0xDEAD_BEEF_FEED_CAFE,
peer_reflex: sa("198.51.100.5:54321"),
fire_at_ms: 1_700_000_000_500,
};
let encoded = RendezvousMsg::PunchIntroduce(intro).encode();
assert_eq!(encoded.len(), PUNCH_INTRODUCE_LEN);
match decode(&encoded) {
Some(RendezvousMsg::PunchIntroduce(out)) => assert_eq!(out, intro),
other => panic!("expected PunchIntroduce, got {other:?}"),
}
}
/// A `PunchAck` encodes to `PUNCH_ACK_LEN` bytes and decodes back unchanged.
#[test]
fn punch_ack_roundtrip() {
let ack = PunchAck {
from_peer: 7,
to_peer: 42,
punch_id: 0xCAFEBABE,
};
let encoded = RendezvousMsg::PunchAck(ack).encode();
assert_eq!(encoded.len(), PUNCH_ACK_LEN);
match decode(&encoded) {
Some(RendezvousMsg::PunchAck(out)) => assert_eq!(out, ack),
other => panic!("expected PunchAck, got {other:?}"),
}
}
/// Uses distinct bit patterns for `from_peer` and `to_peer` so a swap of the
/// two fields in either the encoder or decoder would be detected.
#[test]
fn punch_ack_from_and_to_are_distinguishable_on_wire() {
let ack = PunchAck {
from_peer: 0x1111_1111_1111_1111,
to_peer: 0x2222_2222_2222_2222,
punch_id: 0x3333_3333,
};
let encoded = RendezvousMsg::PunchAck(ack).encode();
match decode(&encoded) {
Some(RendezvousMsg::PunchAck(out)) => {
assert_eq!(out.from_peer, 0x1111_1111_1111_1111);
assert_eq!(out.to_peer, 0x2222_2222_2222_2222);
}
other => panic!("expected PunchAck, got {other:?}"),
}
}
/// A kind byte that matches no known message is rejected.
#[test]
fn unknown_kind_rejects() {
let mut payload = vec![0u8; PUNCH_ACK_LEN];
payload[0] = 0xFF;
assert!(decode(&payload).is_none());
}
/// An empty payload has no kind byte and is rejected.
#[test]
fn empty_payload_rejects() {
assert!(decode(&[]).is_none());
}
/// A valid kind byte with a length one off from that kind's constant is rejected.
#[test]
fn wrong_length_rejects_per_kind() {
let short_request = vec![KIND_PUNCH_REQUEST; PUNCH_REQUEST_LEN - 1];
assert!(decode(&short_request).is_none());
let short_introduce = vec![KIND_PUNCH_INTRODUCE; PUNCH_INTRODUCE_LEN - 1];
assert!(decode(&short_introduce).is_none());
let short_ack = vec![KIND_PUNCH_ACK; PUNCH_ACK_LEN - 1];
assert!(decode(&short_ack).is_none());
let long_ack = vec![KIND_PUNCH_ACK; PUNCH_ACK_LEN + 1];
assert!(decode(&long_ack).is_none());
}
/// A correctly sized request whose address family byte is neither 4 nor 6 is rejected.
#[test]
fn unknown_address_family_rejects() {
let mut payload = vec![0u8; PUNCH_REQUEST_LEN];
payload[0] = KIND_PUNCH_REQUEST;
// Offset 9 is the family byte of the address record.
payload[9] = 7;
assert!(decode(&payload).is_none());
}
/// A keep-alive encodes to `KEEPALIVE_LEN` bytes and decodes back unchanged.
#[test]
fn keepalive_roundtrip() {
let ka = Keepalive {
sender_node_id: 0xA1B2_C3D4_E5F6_0718,
punch_id: 0x1234_5678,
};
let encoded = encode_keepalive(&ka);
assert_eq!(encoded.len(), KEEPALIVE_LEN);
match decode_keepalive(&encoded) {
Some(out) => assert_eq!(out, ka),
None => panic!("decode_keepalive returned None on a valid packet"),
}
}
/// Pins the exact keep-alive bytes: magic, node id and punch id are each little-endian.
#[test]
fn keepalive_byte_layout_is_all_little_endian() {
let ka = Keepalive {
sender_node_id: 0x0102_0304_0506_0708,
punch_id: 0x1A2B_3C4D,
};
let encoded = encode_keepalive(&ka);
let expected: [u8; 14] = [
// KEEPALIVE_MAGIC (0x4850), low byte first.
0x50, 0x48,
// sender_node_id (8 bytes) followed by punch_id (4 bytes), low byte first.
0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x4D, 0x3C, 0x2B, 0x1A,
];
assert_eq!(
&encoded[..],
&expected[..],
"keep-alive must be little-endian on the wire — \
mixing endianness with the packet header is fragile",
);
}
/// Guards against `KEEPALIVE_MAGIC` being equal to the crate's Net packet `MAGIC`.
#[test]
fn keepalive_magic_is_distinct_from_net_magic() {
use crate::adapter::net::protocol::MAGIC;
assert_ne!(
KEEPALIVE_MAGIC, MAGIC,
"KEEPALIVE_MAGIC collides with Net packet MAGIC",
);
}
/// Keep-alive frames that are too short, too long, or carry the wrong magic are rejected.
#[test]
fn keepalive_wrong_length_rejects() {
let mut too_short = vec![0u8; KEEPALIVE_LEN - 1];
too_short[0..2].copy_from_slice(&KEEPALIVE_MAGIC.to_le_bytes());
assert!(decode_keepalive(&too_short).is_none());
let mut too_long = vec![0u8; KEEPALIVE_LEN + 1];
too_long[0..2].copy_from_slice(&KEEPALIVE_MAGIC.to_le_bytes());
assert!(decode_keepalive(&too_long).is_none());
// Correct length but a magic that does not match.
let mut wrong_magic = vec![0u8; KEEPALIVE_LEN];
wrong_magic[0..2].copy_from_slice(&0xFFFFu16.to_le_bytes());
assert!(decode_keepalive(&wrong_magic).is_none());
}
/// The first encoded byte of each variant equals its `KIND_*` constant.
#[test]
fn encoded_kind_byte_matches_discriminator() {
let req = PunchRequest {
target: 1,
self_reflex: sa("10.0.0.1:1"),
};
let intro = PunchIntroduce {
peer: 1,
peer_reflex: sa("10.0.0.1:1"),
fire_at_ms: 1,
};
let ack = PunchAck {
from_peer: 1,
to_peer: 1,
punch_id: 1,
};
assert_eq!(
RendezvousMsg::PunchRequest(req).encode()[0],
KIND_PUNCH_REQUEST
);
assert_eq!(
RendezvousMsg::PunchIntroduce(intro).encode()[0],
KIND_PUNCH_INTRODUCE
);
assert_eq!(RendezvousMsg::PunchAck(ack).encode()[0], KIND_PUNCH_ACK);
}
}