use super::{FrameStruct, FrameType};
use crate::{
VarInt, VarIntBoundsExceeded,
coding::{BufExt, BufMutExt, UnexpectedEnd},
transport_parameters::TransportParameters,
};
use bytes::{Buf, BufMut};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
/// Transport parameter ID looked for by [`peer_supports_rfc_nat`].
pub const TRANSPORT_PARAM_RFC_NAT_TRAVERSAL: u64 = 0x3d7e_9f0b_ca12_fea8;
/// Reports a `VarInt` overflow hit by one of the infallible `encode*` wrappers.
///
/// Always logs at error level through `tracing`; when debug assertions are
/// enabled it then panics with the same message so the overflow surfaces
/// during development.
fn log_encode_overflow(context: &'static str) {
    tracing::error!("VarInt overflow while encoding {context}");
    if cfg!(debug_assertions) {
        panic!("VarInt overflow while encoding {context}");
    }
}
/// ADD_ADDRESS frame carrying a sequence-numbered socket address.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AddAddress {
    /// Sequence number, written immediately after the frame type.
    pub sequence: VarInt,
    /// The socket address carried by the frame.
    pub address: SocketAddr,
    /// Priority derived by `calculate_priority` when built via `new`; only
    /// the legacy encoding puts it on the wire.
    pub(crate) priority: VarInt,
}
impl AddAddress {
    /// Creates a frame for `address`, deriving `priority` with
    /// `calculate_priority`.
    pub fn new(sequence: VarInt, address: SocketAddr) -> Self {
        let priority = calculate_priority(&address);
        Self {
            sequence,
            address,
            priority: VarInt::from_u32(priority),
        }
    }

    /// Encodes in the default format (see [`Self::try_encode`]); a `VarInt`
    /// overflow is reported via `log_encode_overflow` instead of returned.
    pub fn encode<W: BufMut>(&self, buf: &mut W) {
        if self.try_encode(buf).is_err() {
            log_encode_overflow("AddAddress");
        }
    }

    /// Encodes in the default format, which is currently the legacy format.
    pub fn try_encode<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
        self.try_encode_legacy(buf)
    }

    /// Infallible wrapper around [`Self::try_encode_rfc`].
    pub fn encode_rfc<W: BufMut>(&self, buf: &mut W) {
        if self.try_encode_rfc(buf).is_err() {
            log_encode_overflow("AddAddress::encode_rfc");
        }
    }

    /// Encodes the RFC format:
    /// `frame type | sequence | IP octets (4 or 16) | port (u16)`.
    ///
    /// Neither the priority nor the IPv6 flow info / scope id is transmitted.
    pub fn try_encode_rfc<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
        match self.address {
            SocketAddr::V4(_) => buf.write_var(FrameType::ADD_ADDRESS_IPV4.0)?,
            SocketAddr::V6(_) => buf.write_var(FrameType::ADD_ADDRESS_IPV6.0)?,
        }
        buf.write_var(self.sequence.into_inner())?;
        match self.address {
            SocketAddr::V4(addr) => {
                buf.put_slice(&addr.ip().octets());
                buf.put_u16(addr.port());
            }
            SocketAddr::V6(addr) => {
                buf.put_slice(&addr.ip().octets());
                buf.put_u16(addr.port());
            }
        }
        Ok(())
    }

    /// Infallible wrapper around [`Self::try_encode_legacy`].
    pub fn encode_legacy<W: BufMut>(&self, buf: &mut W) {
        if self.try_encode_legacy(buf).is_err() {
            log_encode_overflow("AddAddress::encode_legacy");
        }
    }

    /// Encodes the legacy format:
    /// `frame type | sequence | priority | IP version (4 or 6) | IP octets |
    /// port | [flowinfo | scope_id]`, where the last two are IPv6-only.
    pub fn try_encode_legacy<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
        match self.address {
            SocketAddr::V4(_) => buf.write_var(FrameType::ADD_ADDRESS_IPV4.0)?,
            SocketAddr::V6(_) => buf.write_var(FrameType::ADD_ADDRESS_IPV6.0)?,
        }
        buf.write_var(self.sequence.into_inner())?;
        buf.write_var(self.priority.into_inner())?;
        match self.address {
            SocketAddr::V4(addr) => {
                buf.put_u8(4);
                buf.put_slice(&addr.ip().octets());
                buf.put_u16(addr.port());
            }
            SocketAddr::V6(addr) => {
                buf.put_u8(6);
                buf.put_slice(&addr.ip().octets());
                buf.put_u16(addr.port());
                buf.put_u32(addr.flowinfo());
                buf.put_u32(addr.scope_id());
            }
        }
        Ok(())
    }

    /// Decodes an RFC-format payload (everything after the frame type).
    ///
    /// `is_ipv6` selects the address width. The priority is not on the wire,
    /// so it is recomputed locally via [`Self::new`]; IPv6 flow info and
    /// scope id decode as 0.
    pub fn decode_rfc<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
        let sequence = r.get()?;
        let address = if is_ipv6 {
            if r.remaining() < 16 + 2 {
                return Err(UnexpectedEnd);
            }
            let mut octets = [0u8; 16];
            r.copy_to_slice(&mut octets);
            let port = r.get_u16();
            SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
        } else {
            if r.remaining() < 4 + 2 {
                return Err(UnexpectedEnd);
            }
            let mut octets = [0u8; 4];
            r.copy_to_slice(&mut octets);
            let port = r.get_u16();
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
        };
        Ok(Self::new(sequence, address))
    }

    /// Decodes a legacy-format payload (everything after the frame type).
    ///
    /// The address family comes from the embedded IP version byte; an
    /// unknown version is reported as [`UnexpectedEnd`].
    pub fn decode_legacy<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
        let sequence = r.get()?;
        let priority = r.get()?;
        let ip_version = r.get::<u8>()?;
        let address = match ip_version {
            4 => {
                if r.remaining() < 4 + 2 {
                    return Err(UnexpectedEnd);
                }
                let mut octets = [0u8; 4];
                r.copy_to_slice(&mut octets);
                let port = r.get::<u16>()?;
                SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
            }
            6 => {
                if r.remaining() < 16 + 2 + 4 + 4 {
                    return Err(UnexpectedEnd);
                }
                let mut octets = [0u8; 16];
                r.copy_to_slice(&mut octets);
                let port = r.get::<u16>()?;
                let flowinfo = r.get::<u32>()?;
                let scope_id = r.get::<u32>()?;
                SocketAddr::V6(SocketAddrV6::new(
                    Ipv6Addr::from(octets),
                    port,
                    flowinfo,
                    scope_id,
                ))
            }
            _ => return Err(UnexpectedEnd),
        };
        Ok(Self {
            sequence,
            address,
            priority,
        })
    }

    /// Decodes the payload as RFC format, falling back to the legacy format
    /// if the RFC attempt fails.
    ///
    /// `Buf` cannot be rewound, so when the unread bytes are contiguous both
    /// attempts run against a borrowed view of them. `r` is only advanced past
    /// the attempt that succeeded, and on error it is left untouched. When
    /// the unread bytes are fragmented (`chunk()` shorter than `remaining()`),
    /// no snapshot is possible, and only the RFC format is tried. This avoids
    /// running the legacy decoder on a half-consumed buffer.
    ///
    /// NOTE(review): the two formats are not self-describing, so an RFC
    /// decode can succeed on a legacy payload; the fallback only helps when
    /// the RFC attempt runs out of bytes.
    pub fn decode_auto<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
        if r.chunk().len() < r.remaining() {
            return Self::decode_rfc(r, is_ipv6);
        }
        let unread = r.chunk();
        let mut probe = unread;
        let frame = match Self::decode_rfc(&mut probe, is_ipv6) {
            Ok(frame) => frame,
            Err(_) => {
                // Restart from the untouched bytes, not from where RFC stopped.
                probe = unread;
                Self::decode_legacy(&mut probe)?
            }
        };
        let consumed = unread.len() - probe.len();
        r.advance(consumed);
        Ok(frame)
    }
}
/// PUNCH_ME_NOW frame: a round number, the sequence number it is paired
/// with, and a socket address.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PunchMeNow {
    /// Round counter, written first after the frame type.
    pub round: VarInt,
    /// Sequence number this frame is paired with.
    pub paired_with_sequence_number: VarInt,
    /// Socket address carried by the frame.
    pub address: SocketAddr,
    /// Optional 32-byte peer id; only the legacy encoding carries it.
    pub(crate) target_peer_id: Option<[u8; 32]>,
}
impl PunchMeNow {
pub fn new(round: VarInt, paired_with_sequence_number: VarInt, address: SocketAddr) -> Self {
Self {
round,
paired_with_sequence_number,
address,
target_peer_id: None,
}
}
pub fn encode<W: BufMut>(&self, buf: &mut W) {
if self.try_encode(buf).is_err() {
log_encode_overflow("PunchMeNow");
}
}
pub fn try_encode<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
self.try_encode_legacy(buf)
}
pub fn encode_rfc<W: BufMut>(&self, buf: &mut W) {
if self.try_encode_rfc(buf).is_err() {
log_encode_overflow("PunchMeNow::encode_rfc");
}
}
pub fn try_encode_rfc<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
match self.address {
SocketAddr::V4(_) => buf.write_var(FrameType::PUNCH_ME_NOW_IPV4.0)?,
SocketAddr::V6(_) => buf.write_var(FrameType::PUNCH_ME_NOW_IPV6.0)?,
}
buf.write_var(self.round.into_inner())?;
buf.write_var(self.paired_with_sequence_number.into_inner())?;
match self.address {
SocketAddr::V4(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
SocketAddr::V6(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
}
Ok(())
}
pub fn encode_legacy<W: BufMut>(&self, buf: &mut W) {
if self.try_encode_legacy(buf).is_err() {
log_encode_overflow("PunchMeNow::encode_legacy");
}
}
pub fn try_encode_legacy<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
match self.address {
SocketAddr::V4(_) => buf.write_var(FrameType::PUNCH_ME_NOW_IPV4.0)?,
SocketAddr::V6(_) => buf.write_var(FrameType::PUNCH_ME_NOW_IPV6.0)?,
}
buf.write_var(self.round.into_inner())?;
buf.write_var(self.paired_with_sequence_number.into_inner())?;
match self.address {
SocketAddr::V4(addr) => {
buf.put_u8(4); buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
SocketAddr::V6(addr) => {
buf.put_u8(6); buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
buf.put_u32(addr.flowinfo());
buf.put_u32(addr.scope_id());
}
}
match &self.target_peer_id {
Some(peer_id) => {
buf.put_u8(1); buf.put_slice(peer_id);
}
None => {
buf.put_u8(0); }
}
Ok(())
}
pub fn decode_rfc<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
let round = r.get()?;
let paired_with_sequence_number = r.get()?;
let address = if is_ipv6 {
if r.remaining() < 16 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 16];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
} else {
if r.remaining() < 4 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 4];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
};
Ok(Self::new(round, paired_with_sequence_number, address))
}
pub fn decode_auto<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
match Self::decode_rfc(r, is_ipv6) {
Ok(frame) => Ok(frame),
Err(_) => {
Self::decode_legacy(r)
}
}
}
pub fn decode_legacy<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
let round = r.get()?;
let target_sequence = r.get()?; let ip_version = r.get::<u8>()?;
let address = match ip_version {
4 => {
if r.remaining() < 4 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 4];
r.copy_to_slice(&mut octets);
let port = r.get::<u16>()?;
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
}
6 => {
if r.remaining() < 16 + 2 + 4 + 4 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 16];
r.copy_to_slice(&mut octets);
let port = r.get::<u16>()?;
let flowinfo = r.get::<u32>()?;
let scope_id = r.get::<u32>()?;
SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::from(octets),
port,
flowinfo,
scope_id,
))
}
_ => return Err(UnexpectedEnd),
};
let target_peer_id = if r.remaining() > 0 {
let has_peer_id = r.get::<u8>()?;
if has_peer_id == 1 && r.remaining() >= 32 {
let mut peer_id = [0u8; 32];
r.copy_to_slice(&mut peer_id);
Some(peer_id)
} else {
None
}
} else {
None
};
Ok(Self {
round,
paired_with_sequence_number: target_sequence,
address,
target_peer_id,
})
}
}
impl FrameStruct for AddAddress {
    // Sized for the largest (legacy IPv6) encoding.
    const SIZE_BOUND: usize = 4 // frame type
        + 9 // sequence varint allowance
        + 9 // priority varint allowance
        + 1 // IP version tag
        + 16 // IPv6 octets
        + 2 // port
        + 4 // flow info
        + 4; // scope id
}
impl FrameStruct for PunchMeNow {
    // Sized for the largest (legacy IPv6 with peer id) encoding.
    const SIZE_BOUND: usize = 4 // frame type
        + 9 // round varint allowance
        + 9 // paired sequence varint allowance
        + 1 // IP version tag
        + 16 // IPv6 octets
        + 2 // port
        + 4 // flow info
        + 4 // scope id
        + 1 // peer-id flag
        + 32; // peer id
}
impl FrameStruct for RemoveAddress {
    const SIZE_BOUND: usize = 4 // frame type
        + 9; // sequence varint allowance
}
/// Computes the priority advertised for `addr`.
///
/// The value uses the same bit layout as the ICE candidate-priority formula:
/// an address-type preference in bits 24..32 (loopback 0; IPv4 private 100;
/// IPv6 link-local 90; other IPv4 126; other IPv6 120), a local preference in
/// bits 8..24 (65535 for IPv4, 65534 for IPv6), and a fixed 255 in the low byte.
fn calculate_priority(addr: &SocketAddr) -> u32 {
    let (type_pref, local_pref): (u32, u32) = match addr {
        SocketAddr::V4(v4) if v4.ip().is_loopback() => (0, 65_535),
        SocketAddr::V4(v4) if v4.ip().is_private() => (100, 65_535),
        SocketAddr::V4(_) => (126, 65_535),
        SocketAddr::V6(v6) if v6.ip().is_loopback() => (0, 65_534),
        SocketAddr::V6(v6) if v6.ip().is_unicast_link_local() => (90, 65_534),
        SocketAddr::V6(_) => (120, 65_534),
    };
    (type_pref << 24) + (local_pref << 8) + 255
}
/// REMOVE_ADDRESS frame: a single sequence number.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RemoveAddress {
    /// Sequence number; presumably that of an earlier ADD_ADDRESS frame.
    pub sequence: VarInt,
}
impl RemoveAddress {
pub fn new(sequence: VarInt) -> Self {
Self { sequence }
}
pub fn encode<W: BufMut>(&self, buf: &mut W) {
if self.try_encode(buf).is_err() {
log_encode_overflow("RemoveAddress");
}
}
pub fn try_encode<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
buf.write_var(FrameType::REMOVE_ADDRESS.0)?;
buf.write_var(self.sequence.into_inner())?;
Ok(())
}
pub fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
let sequence = r.get()?;
Ok(Self { sequence })
}
}
/// Failure reasons carried in a TRY_CONNECT_TO response.
///
/// Each discriminant is the byte used on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum TryConnectError {
    Timeout = 0,
    ConnectionRefused = 1,
    NetworkUnreachable = 2,
    HostUnreachable = 3,
    RateLimited = 4,
    InvalidAddress = 5,
    InternalError = 255,
}

impl TryConnectError {
    /// Maps a wire byte back to an error; any unrecognised byte becomes
    /// [`TryConnectError::InternalError`].
    pub fn from_u8(value: u8) -> Self {
        // Indexed by wire value: position i holds the variant whose discriminant is i.
        let by_wire_value = [
            Self::Timeout,
            Self::ConnectionRefused,
            Self::NetworkUnreachable,
            Self::HostUnreachable,
            Self::RateLimited,
            Self::InvalidAddress,
        ];
        by_wire_value
            .get(usize::from(value))
            .copied()
            .unwrap_or(Self::InternalError)
    }
}
/// TRY_CONNECT_TO request: a request id, a target address and a timeout.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TryConnectTo {
    /// Identifier written right after the frame type.
    pub request_id: VarInt,
    /// Address to connect to; IPv6 flow info / scope id are not encoded.
    pub target_address: SocketAddr,
    /// Timeout (milliseconds, per the name), encoded as a big-endian `u16`.
    pub timeout_ms: u16,
}
impl TryConnectTo {
pub fn new(request_id: VarInt, target_address: SocketAddr, timeout_ms: u16) -> Self {
Self {
request_id,
target_address,
timeout_ms,
}
}
pub fn encode<W: BufMut>(&self, buf: &mut W) {
if self.try_encode(buf).is_err() {
log_encode_overflow("TryConnectTo");
}
}
pub fn try_encode<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
match self.target_address {
SocketAddr::V4(_) => buf.write_var(FrameType::TRY_CONNECT_TO_IPV4.0)?,
SocketAddr::V6(_) => buf.write_var(FrameType::TRY_CONNECT_TO_IPV6.0)?,
}
buf.write_var(self.request_id.into_inner())?;
match self.target_address {
SocketAddr::V4(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
SocketAddr::V6(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
}
buf.put_u16(self.timeout_ms);
Ok(())
}
pub fn decode<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
let request_id = r.get()?;
let target_address = if is_ipv6 {
if r.remaining() < 16 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 16];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
} else {
if r.remaining() < 4 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 4];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
};
if r.remaining() < 2 {
return Err(UnexpectedEnd);
}
let timeout_ms = r.get_u16();
Ok(Self {
request_id,
target_address,
timeout_ms,
})
}
}
impl FrameStruct for TryConnectTo {
    // Sized for the IPv6 encoding.
    const SIZE_BOUND: usize = 4 // frame type
        + 9 // request id varint allowance
        + 16 // IPv6 octets
        + 2 // port
        + 2; // timeout
}
/// Reply frame for TRY_CONNECT_TO: a request id, the outcome and a source
/// address.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TryConnectToResponse {
    /// Identifier written right after the frame type.
    pub request_id: VarInt,
    /// Outcome flag, encoded as one byte (1 or 0).
    pub success: bool,
    /// Failure reason; after decoding it is `Some` exactly when `success` is false.
    pub error_code: Option<TryConnectError>,
    /// Address carried in the reply; IPv6 flow info / scope id are not encoded.
    pub source_address: SocketAddr,
}
impl TryConnectToResponse {
pub fn success(request_id: VarInt, source_address: SocketAddr) -> Self {
Self {
request_id,
success: true,
error_code: None,
source_address,
}
}
pub fn failure(
request_id: VarInt,
error_code: TryConnectError,
source_address: SocketAddr,
) -> Self {
Self {
request_id,
success: false,
error_code: Some(error_code),
source_address,
}
}
pub fn encode<W: BufMut>(&self, buf: &mut W) {
if self.try_encode(buf).is_err() {
log_encode_overflow("TryConnectToResponse");
}
}
pub fn try_encode<W: BufMut>(&self, buf: &mut W) -> Result<(), VarIntBoundsExceeded> {
match self.source_address {
SocketAddr::V4(_) => buf.write_var(FrameType::TRY_CONNECT_TO_RESPONSE_IPV4.0)?,
SocketAddr::V6(_) => buf.write_var(FrameType::TRY_CONNECT_TO_RESPONSE_IPV6.0)?,
}
buf.write_var(self.request_id.into_inner())?;
buf.put_u8(if self.success { 1 } else { 0 });
if let Some(error) = self.error_code {
buf.put_u8(error as u8);
} else {
buf.put_u8(0);
}
match self.source_address {
SocketAddr::V4(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
SocketAddr::V6(addr) => {
buf.put_slice(&addr.ip().octets());
buf.put_u16(addr.port());
}
}
Ok(())
}
pub fn decode<R: Buf>(r: &mut R, is_ipv6: bool) -> Result<Self, UnexpectedEnd> {
let request_id = r.get()?;
if r.remaining() < 2 {
return Err(UnexpectedEnd);
}
let success = r.get_u8() != 0;
let error_byte = r.get_u8();
let error_code = if success {
None
} else {
Some(TryConnectError::from_u8(error_byte))
};
let source_address = if is_ipv6 {
if r.remaining() < 16 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 16];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
} else {
if r.remaining() < 4 + 2 {
return Err(UnexpectedEnd);
}
let mut octets = [0u8; 4];
r.copy_to_slice(&mut octets);
let port = r.get_u16();
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
};
Ok(Self {
request_id,
success,
error_code,
source_address,
})
}
}
impl FrameStruct for TryConnectToResponse {
    // Sized for the IPv6 encoding.
    const SIZE_BOUND: usize = 4 // frame type
        + 9 // request id varint allowance
        + 1 // success flag
        + 1 // error code
        + 16 // IPv6 octets
        + 2; // port
}
/// Selects which NAT traversal frame format(s) are in use.
#[derive(Clone, Debug)]
pub struct NatTraversalFrameConfig {
    /// Flag selecting the RFC frame encoding.
    pub use_rfc_format: bool,
    /// Flag allowing legacy-format frames to be accepted.
    pub accept_legacy: bool,
}

impl Default for NatTraversalFrameConfig {
    /// Selects the RFC encoding while still accepting legacy frames.
    fn default() -> Self {
        let (use_rfc_format, accept_legacy) = (true, true);
        NatTraversalFrameConfig {
            use_rfc_format,
            accept_legacy,
        }
    }
}
impl NatTraversalFrameConfig {
pub fn from_transport_params(local: &TransportParameters, peer: &TransportParameters) -> Self {
Self {
use_rfc_format: local.supports_rfc_nat_traversal() && peer.supports_rfc_nat_traversal(),
accept_legacy: true,
}
}
pub fn rfc_only() -> Self {
Self {
use_rfc_format: true,
accept_legacy: false,
}
}
}
/// Reports whether an encoded QUIC transport-parameter list contains
/// [`TRANSPORT_PARAM_RFC_NAT_TRAVERSAL`].
///
/// `transport_params` is assumed to be the wire encoding of the list: a
/// sequence of `ID (varint) | Length (varint) | Value (Length bytes)` entries.
/// IDs are QUIC variable-length integers. This 62-bit ID is therefore
/// encoded as an 8-byte varint (`0xfd7e9f0bca12fea8`), never as its raw
/// big-endian bytes. The list is walked entry by entry so that bytes inside a
/// parameter *value* are never mistaken for an ID.
///
/// Returns `false` when the parameter is absent or the list is malformed
/// before it is reached.
pub fn peer_supports_rfc_nat(transport_params: &[u8]) -> bool {
    /// Decodes one QUIC varint from the front of `buf`, advancing past it;
    /// `None` if `buf` is too short.
    fn read_varint(buf: &mut &[u8]) -> Option<u64> {
        let first = *buf.first()?;
        // The top two bits of the first byte give the encoded length: 1, 2, 4 or 8.
        let len = 1usize << (first >> 6);
        if buf.len() < len {
            return None;
        }
        let (encoded, rest) = buf.split_at(len);
        let value = encoded[1..]
            .iter()
            .fold(u64::from(first & 0x3f), |acc, &byte| (acc << 8) | u64::from(byte));
        *buf = rest;
        Some(value)
    }

    /// Walks the entries; `None` signals a malformed list.
    fn contains_param(mut rest: &[u8]) -> Option<bool> {
        while !rest.is_empty() {
            let id = read_varint(&mut rest)?;
            let len = usize::try_from(read_varint(&mut rest)?).ok()?;
            if len > rest.len() {
                return None;
            }
            if id == TRANSPORT_PARAM_RFC_NAT_TRAVERSAL {
                return Some(true);
            }
            rest = &rest[len..];
        }
        Some(false)
    }

    contains_param(transport_params).unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Consumes the leading frame-type varint and returns its value.
    fn take_frame_type(buf: &mut BytesMut) -> u64 {
        buf.get::<VarInt>().unwrap().into_inner()
    }

    #[test]
    fn test_add_address_rfc_encoding() {
        let frame = AddAddress::new(VarInt::from_u32(42), "192.168.1.100:8080".parse().unwrap());
        let mut buf = BytesMut::new();
        frame.encode_rfc(&mut buf);
        assert_eq!(buf[0..4], [0x80, 0x3d, 0x7e, 0x90]);
        buf.advance(4);
        let decoded = AddAddress::decode_rfc(&mut buf, false).unwrap();
        assert_eq!(decoded.sequence, frame.sequence);
        assert_eq!(decoded.address, frame.address);
    }

    #[test]
    fn test_add_address_legacy_compatibility() {
        let frame = AddAddress {
            sequence: VarInt::from_u32(100),
            address: "10.0.0.1:1234".parse().unwrap(),
            priority: VarInt::from_u32(12345),
        };
        let mut buf = BytesMut::new();
        frame.encode_legacy(&mut buf);
        buf.advance(4);
        let decoded = AddAddress::decode_legacy(&mut buf).unwrap();
        assert_eq!(decoded.sequence, frame.sequence);
        assert_eq!(decoded.address, frame.address);
        assert_eq!(decoded.priority, frame.priority);
    }

    #[test]
    fn test_add_address_legacy_ipv6_roundtrip() {
        let address = SocketAddr::V6(SocketAddrV6::new("2001:db8::1".parse().unwrap(), 443, 7, 3));
        let frame = AddAddress::new(VarInt::from_u32(9), address);
        let mut buf = BytesMut::new();
        frame.encode_legacy(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::ADD_ADDRESS_IPV6.0);
        let decoded = AddAddress::decode_legacy(&mut buf).unwrap();
        // Legacy IPv6 keeps flow info, scope id and priority.
        assert_eq!(decoded, frame);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_add_address_rfc_ipv6_drops_flowinfo_and_scope() {
        let address = SocketAddr::V6(SocketAddrV6::new("2001:db8::1".parse().unwrap(), 443, 7, 3));
        let frame = AddAddress::new(VarInt::from_u32(9), address);
        let mut buf = BytesMut::new();
        frame.encode_rfc(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::ADD_ADDRESS_IPV6.0);
        let decoded = AddAddress::decode_rfc(&mut buf, true).unwrap();
        assert_eq!(decoded.address, SocketAddr::new(address.ip(), 443));
        assert_eq!(decoded.priority, frame.priority);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_punch_me_now_rfc_encoding() {
        let frame = PunchMeNow::new(
            VarInt::from_u32(1),
            VarInt::from_u32(42),
            "192.168.1.100:8080".parse().unwrap(),
        );
        let mut buf = BytesMut::new();
        frame.encode_rfc(&mut buf);
        assert_eq!(buf[0..4], [0x80, 0x3d, 0x7e, 0x92]);
        buf.advance(4);
        let decoded = PunchMeNow::decode_rfc(&mut buf, false).unwrap();
        assert_eq!(decoded.round, frame.round);
        assert_eq!(
            decoded.paired_with_sequence_number,
            frame.paired_with_sequence_number
        );
        assert_eq!(decoded.address, frame.address);
    }

    #[test]
    fn test_punch_me_now_legacy_compatibility() {
        let frame = PunchMeNow {
            round: VarInt::from_u32(5),
            paired_with_sequence_number: VarInt::from_u32(100),
            address: "10.0.0.1:1234".parse().unwrap(),
            target_peer_id: Some([0xAB; 32]),
        };
        let mut buf = BytesMut::new();
        frame.encode_legacy(&mut buf);
        buf.advance(4);
        let decoded = PunchMeNow::decode_legacy(&mut buf).unwrap();
        assert_eq!(decoded.round, frame.round);
        assert_eq!(
            decoded.paired_with_sequence_number,
            frame.paired_with_sequence_number
        );
        assert_eq!(decoded.address, frame.address);
        assert_eq!(decoded.target_peer_id, frame.target_peer_id);
    }

    #[test]
    fn test_punch_me_now_legacy_ipv6_without_peer_id() {
        let address = SocketAddr::V6(SocketAddrV6::new("2001:db8::5".parse().unwrap(), 5000, 1, 2));
        let frame = PunchMeNow::new(VarInt::from_u32(3), VarInt::from_u32(8), address);
        let mut buf = BytesMut::new();
        frame.encode_legacy(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::PUNCH_ME_NOW_IPV6.0);
        let decoded = PunchMeNow::decode_legacy(&mut buf).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_remove_address_encoding() {
        let frame = RemoveAddress::new(VarInt::from_u32(42));
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        buf.advance(4);
        let decoded = RemoveAddress::decode(&mut buf).unwrap();
        assert_eq!(decoded.sequence, frame.sequence);
    }

    #[test]
    fn test_try_connect_to_roundtrip() {
        let frame = TryConnectTo::new(
            VarInt::from_u32(7),
            "[2001:db8::2]:9000".parse().unwrap(),
            1500,
        );
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(take_frame_type(&mut buf), FrameType::TRY_CONNECT_TO_IPV6.0);
        let decoded = TryConnectTo::decode(&mut buf, true).unwrap();
        assert_eq!(decoded, frame);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_try_connect_to_response_roundtrip() {
        let cases = [
            TryConnectToResponse::success(VarInt::from_u32(1), "203.0.113.5:4000".parse().unwrap()),
            TryConnectToResponse::failure(
                VarInt::from_u32(2),
                TryConnectError::HostUnreachable,
                "10.0.0.2:4001".parse().unwrap(),
            ),
        ];
        for frame in cases.iter() {
            let mut buf = BytesMut::new();
            frame.encode(&mut buf);
            assert_eq!(
                take_frame_type(&mut buf),
                FrameType::TRY_CONNECT_TO_RESPONSE_IPV4.0
            );
            let decoded = TryConnectToResponse::decode(&mut buf, false).unwrap();
            assert_eq!(&decoded, frame);
            assert!(buf.is_empty());
        }
    }

    #[test]
    fn test_try_connect_error_wire_values() {
        let codes = [
            TryConnectError::Timeout,
            TryConnectError::ConnectionRefused,
            TryConnectError::NetworkUnreachable,
            TryConnectError::HostUnreachable,
            TryConnectError::RateLimited,
            TryConnectError::InvalidAddress,
            TryConnectError::InternalError,
        ];
        for &code in codes.iter() {
            assert_eq!(TryConnectError::from_u8(code as u8), code);
        }
        assert_eq!(TryConnectError::from_u8(42), TryConnectError::InternalError);
    }

    #[test]
    fn test_priority_ordering() {
        let prio = |s: &str| calculate_priority(&s.parse::<SocketAddr>().unwrap());
        assert!(prio("8.8.8.8:1") > prio("192.168.0.1:1"));
        assert!(prio("192.168.0.1:1") > prio("127.0.0.1:1"));
        assert!(prio("[2001:db8::1]:1") > prio("[fe80::1]:1"));
        assert!(prio("[fe80::1]:1") > prio("[::1]:1"));
    }
}