use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use byteorder::{BigEndian, ByteOrder};
use crate::message::{StunParseError, TransactionId, MAGIC_COOKIE};
use crate::AddressFamily;
use super::{check_len, AttributeType, RawAttribute};
/// A socket address in the plain (non-XOR) STUN address-attribute encoding.
///
/// The encoded value is 8 bytes for IPv4 and 20 bytes for IPv6: a reserved
/// byte, an address-family byte, a big-endian port, then the address octets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MappedSocketAddr {
    // The address exactly as it is (or will be) written in the attribute value.
    addr: SocketAddr,
}
impl MappedSocketAddr {
    /// Wraps `addr` so it can be encoded as an address attribute value.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }

    /// Length in bytes of the encoded attribute value: 8 for IPv4, 20 for IPv6.
    pub fn length(&self) -> u16 {
        match self.addr {
            SocketAddr::V4(_) => 8,
            SocketAddr::V6(_) => 20,
        }
    }

    /// Encodes this address into an owned [`RawAttribute`] of type `atype`.
    pub fn to_raw(&self, atype: AttributeType) -> RawAttribute<'_> {
        // Sized for the larger (IPv6) encoding; only the first `length()` bytes
        // are written and handed to the attribute.
        let mut buf = [0; 20];
        let len = usize::from(self.length());
        self.write_into_unchecked(&mut buf[..len]);
        RawAttribute::new(atype, &buf[..len]).into_owned()
    }

    /// Decodes an address attribute value.
    ///
    /// Expected layout: one reserved byte (not inspected), one address-family
    /// byte, a big-endian 16-bit port, then 4 (IPv4) or 16 (IPv6) address bytes.
    ///
    /// # Errors
    ///
    /// - [`StunParseError::Truncated`] if the value is shorter than the 4-byte
    ///   reserved/family/port header.
    /// - Whatever `AddressFamily::from_byte` returns for an unrecognised family byte.
    /// - Whatever `check_len` returns when the value is not exactly 8 bytes
    ///   (IPv4) or 20 bytes (IPv6).
    pub fn from_raw(raw: &RawAttribute) -> Result<Self, StunParseError> {
        if raw.value.len() < 4 {
            return Err(StunParseError::Truncated {
                expected: 4,
                actual: raw.value.len(),
            });
        }
        let port = BigEndian::read_u16(&raw.value[2..4]);
        let family = AddressFamily::from_byte(raw.value[1])?;
        let addr = match family {
            AddressFamily::IPV4 => {
                check_len(raw.value.len(), 8..=8)?;
                IpAddr::V4(Ipv4Addr::from(BigEndian::read_u32(&raw.value[4..8])))
            }
            AddressFamily::IPV6 => {
                check_len(raw.value.len(), 20..=20)?;
                let mut octets = [0; 16];
                // `u8` is `Copy`, so a plain memcpy is all that is needed here.
                octets.copy_from_slice(&raw.value[4..20]);
                IpAddr::V6(Ipv6Addr::from(octets))
            }
        };
        Ok(Self {
            addr: SocketAddr::new(addr, port),
        })
    }

    /// The wrapped socket address.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Writes the encoded value into the start of `dest`.
    ///
    /// # Panics
    ///
    /// Panics if `dest` is shorter than [`Self::length`] bytes.
    pub fn write_into_unchecked(&self, dest: &mut [u8]) {
        match self.addr {
            SocketAddr::V4(addr) => {
                dest[0] = 0x0;
                dest[1] = AddressFamily::IPV4.to_byte();
                BigEndian::write_u16(&mut dest[2..4], addr.port());
                let octets = u32::from(*addr.ip());
                BigEndian::write_u32(&mut dest[4..8], octets);
            }
            SocketAddr::V6(addr) => {
                dest[0] = 0x0;
                dest[1] = AddressFamily::IPV6.to_byte();
                BigEndian::write_u16(&mut dest[2..4], addr.port());
                let octets = u128::from(*addr.ip());
                BigEndian::write_u128(&mut dest[4..20], octets);
            }
        }
    }
}
impl core::fmt::Display for MappedSocketAddr {
    /// Renders the wrapped address using the `Debug` form of its concrete
    /// (`SocketAddrV4` / `SocketAddrV6`) variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let inner = &self.addr;
        match inner {
            SocketAddr::V4(v4) => f.write_fmt(format_args!("{:?}", v4)),
            SocketAddr::V6(v6) => f.write_fmt(format_args!("{:?}", v6)),
        }
    }
}
/// A socket address in the XOR-obfuscated STUN address-attribute encoding.
///
/// The wrapped [`MappedSocketAddr`] holds the address *after* XOR-ing; call
/// [`XorSocketAddr::addr`] with the transaction ID to recover the real address.
#[derive(Debug, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct XorSocketAddr {
    /// The XOR-ed address, exactly as it appears in the attribute value.
    pub addr: MappedSocketAddr,
}
impl XorSocketAddr {
    /// Builds the XOR-encoded form of the real address `addr`.
    pub fn new(addr: SocketAddr, transaction: TransactionId) -> Self {
        let xored = Self::xor_addr(addr, transaction);
        Self {
            addr: MappedSocketAddr::new(xored),
        }
    }

    /// Length in bytes of the encoded attribute value (same as the plain form).
    pub fn length(&self) -> u16 {
        self.addr.length()
    }

    /// Encodes the (already XOR-ed) address into a [`RawAttribute`] of type `atype`.
    pub fn to_raw(&self, atype: AttributeType) -> RawAttribute<'_> {
        self.addr.to_raw(atype)
    }

    /// Decodes an attribute value, keeping the address in its XOR-ed form.
    ///
    /// Use [`XorSocketAddr::addr`] with the transaction ID to get the real address.
    pub fn from_raw(raw: &RawAttribute) -> Result<Self, StunParseError> {
        MappedSocketAddr::from_raw(raw).map(|addr| Self { addr })
    }

    /// XORs `addr` with the magic cookie (IPv4) or with the magic cookie followed
    /// by the 96-bit transaction ID (IPv6).
    ///
    /// XOR with a fixed mask is its own inverse, so this both encodes and decodes.
    pub fn xor_addr(addr: SocketAddr, transaction: TransactionId) -> SocketAddr {
        // The port is masked with the upper 16 bits of the magic cookie for both families.
        let port = addr.port() ^ (MAGIC_COOKIE >> 16) as u16;
        let ip = match addr {
            SocketAddr::V4(v4) => {
                let mask = MAGIC_COOKIE.to_be_bytes();
                let octets = v4.ip().octets();
                let xored = bytewise_xor!(4, mask, octets, 0);
                IpAddr::V4(Ipv4Addr::from(xored))
            }
            SocketAddr::V6(v6) => {
                let transaction: u128 = transaction.into();
                // Magic cookie in the top 32 bits, the low 96 bits of the ID below it.
                let mask = ((MAGIC_COOKIE as u128) << 96
                    | (transaction & 0x0000_0000_ffff_ffff_ffff_ffff_ffff_ffff))
                    .to_be_bytes();
                let octets = v6.ip().octets();
                let xored = bytewise_xor!(16, mask, octets, 0);
                IpAddr::V6(Ipv6Addr::from(xored))
            }
        };
        SocketAddr::new(ip, port)
    }

    /// Recovers the real address by undoing the XOR with `transaction`.
    pub fn addr(&self, transaction: TransactionId) -> SocketAddr {
        Self::xor_addr(self.addr.addr(), transaction)
    }

    /// Writes the encoded (XOR-ed) value into the start of `dest`.
    ///
    /// Panics if `dest` is shorter than [`Self::length`] bytes.
    pub fn write_into_unchecked(&self, dest: &mut [u8]) {
        self.addr.write_into_unchecked(dest)
    }
}
impl core::fmt::Display for XorSocketAddr {
    /// Formats the address for diagnostics.
    ///
    /// IPv4 is shown decoded: its XOR mask does not involve the transaction ID,
    /// so a zero ID is passed. IPv6 needs the real transaction ID to decode,
    /// which is not available here, so the XOR-ed form is shown as `XOR(...)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        if let SocketAddr::V6(xored) = self.addr.addr() {
            return write!(f, "XOR({xored:?})");
        }
        let decoded = self.addr(0x0.into());
        write!(f, "{decoded:?}")
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use tracing::trace;

    /// Parses `bytes` as a complete attribute (header + value) and decodes the
    /// value as a [`MappedSocketAddr`].
    fn decode_mapped(bytes: &[u8]) -> Result<MappedSocketAddr, StunParseError> {
        let raw = RawAttribute::from_bytes(bytes).unwrap();
        MappedSocketAddr::from_raw(&raw)
    }

    #[test]
    fn mapped_address_ipv4() {
        let _log = crate::tests::test_init_log();
        let bytes = [
            0x00, 0x01, 0x00, 0x08, // attribute header: type, value length 8
            0x00, 0x01, 0x0C, 0x6A, // reserved, family (IPv4), port 3178
            0xC0, 0xA8, 0x00, 0x01, // 192.168.0.1
        ];
        let expected: SocketAddr = "192.168.0.1:3178".parse().unwrap();
        let mapped = decode_mapped(&bytes).unwrap();
        trace!("mapped: {mapped}");
        assert_eq!(mapped.addr(), expected);
    }

    #[test]
    fn mapped_address_short() {
        let _log = crate::tests::test_init_log();
        // A 2-byte value cannot even hold the reserved/family/port header.
        let bytes = [0x00, 0x01, 0x00, 0x02, 0x00, 0x00];
        let result = decode_mapped(&bytes);
        assert!(matches!(
            result,
            Err(StunParseError::Truncated {
                expected: 4,
                actual: 2
            })
        ));
    }

    #[test]
    fn mapped_address_unknown_family() {
        let _log = crate::tests::test_init_log();
        // Family byte 0x99 is neither IPv4 nor IPv6.
        let bytes = [0x00, 0x01, 0x00, 0x04, 0x00, 0x99, 0x00, 0x00];
        let result = decode_mapped(&bytes);
        assert!(matches!(result, Err(StunParseError::InvalidAttributeData)));
    }
}