#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
use crate::error::{NetError, NetResult};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use hmac::{Hmac, KeyInit, Mac};
use sha1::Sha1;
use std::net::SocketAddr;
/// STUN message types understood by this module (Binding method only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Binding request, wire code `0x0001`.
    BindingRequest,
    /// Binding response, wire code `0x0101`.
    BindingResponse,
    /// Binding error, wire code `0x0111`.
    BindingError,
    /// Binding indication, wire code `0x0011`.
    BindingIndication,
}

impl MessageType {
    /// Every variant, so the reverse lookup can reuse [`Self::value`] as the
    /// single source of truth for the wire codes.
    const ALL: [Self; 4] = [
        Self::BindingRequest,
        Self::BindingResponse,
        Self::BindingError,
        Self::BindingIndication,
    ];

    /// Returns the 16-bit code written into the message header for this type.
    #[must_use]
    pub const fn value(&self) -> u16 {
        match *self {
            Self::BindingRequest => 0x0001,
            Self::BindingIndication => 0x0011,
            Self::BindingResponse => 0x0101,
            Self::BindingError => 0x0111,
        }
    }

    /// Maps a 16-bit header code back to its message type.
    ///
    /// Returns `None` when `value` is not one of the codes produced by
    /// [`Self::value`].
    #[must_use]
    pub const fn from_value(value: u16) -> Option<Self> {
        // `const fn` cannot use iterators, hence the manual index loop.
        let mut idx = 0;
        while idx < Self::ALL.len() {
            let candidate = Self::ALL[idx];
            if candidate.value() == value {
                return Some(candidate);
            }
            idx += 1;
        }
        None
    }
}
/// STUN / ICE attribute types this module knows by name.
///
/// Attributes with other codes can still be carried around as raw `u16`
/// values; this enum only covers the ones listed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttributeType {
    /// `0x0001`
    MappedAddress,
    /// `0x0020`
    XorMappedAddress,
    /// `0x0006`
    Username,
    /// `0x0008`
    MessageIntegrity,
    /// `0x8028`
    Fingerprint,
    /// `0x0009`
    ErrorCode,
    /// `0x0014`
    Realm,
    /// `0x0015`
    Nonce,
    /// `0x000A`
    UnknownAttributes,
    /// `0x8022`
    Software,
    /// `0x8023`
    AlternateServer,
    /// `0x0024`
    Priority,
    /// `0x0025`
    UseCandidate,
    /// `0x8029`
    IceControlled,
    /// `0x802A`
    IceControlling,
}

impl AttributeType {
    /// Every variant, so the reverse lookup can reuse [`Self::value`] instead
    /// of keeping a second copy of the code table.
    const ALL: [Self; 15] = [
        Self::MappedAddress,
        Self::XorMappedAddress,
        Self::Username,
        Self::MessageIntegrity,
        Self::Fingerprint,
        Self::ErrorCode,
        Self::Realm,
        Self::Nonce,
        Self::UnknownAttributes,
        Self::Software,
        Self::AlternateServer,
        Self::Priority,
        Self::UseCandidate,
        Self::IceControlled,
        Self::IceControlling,
    ];

    /// Returns the 16-bit attribute type code used on the wire.
    #[must_use]
    pub const fn value(&self) -> u16 {
        match *self {
            Self::MappedAddress => 0x0001,
            Self::Username => 0x0006,
            Self::MessageIntegrity => 0x0008,
            Self::ErrorCode => 0x0009,
            Self::UnknownAttributes => 0x000A,
            Self::Realm => 0x0014,
            Self::Nonce => 0x0015,
            Self::XorMappedAddress => 0x0020,
            Self::Priority => 0x0024,
            Self::UseCandidate => 0x0025,
            Self::Software => 0x8022,
            Self::AlternateServer => 0x8023,
            Self::Fingerprint => 0x8028,
            Self::IceControlled => 0x8029,
            Self::IceControlling => 0x802A,
        }
    }

    /// Maps a 16-bit attribute code back to a known attribute type.
    ///
    /// Returns `None` for any code not produced by [`Self::value`].
    #[must_use]
    pub const fn from_value(value: u16) -> Option<Self> {
        // `const fn` cannot use iterators, hence the manual index loop.
        let mut idx = 0;
        while idx < Self::ALL.len() {
            let candidate = Self::ALL[idx];
            if candidate.value() == value {
                return Some(candidate);
            }
            idx += 1;
        }
        None
    }
}
/// One STUN attribute: a raw 16-bit type code plus its unpadded value bytes.
#[derive(Debug, Clone)]
pub struct Attribute {
    /// Raw type code as written on the wire (not restricted to [`AttributeType`]).
    pub attr_type: u16,
    /// Value bytes without the trailing alignment padding.
    pub value: Bytes,
}

impl Attribute {
    /// Builds an attribute of a known type from anything convertible to `Bytes`.
    #[must_use]
    pub fn new(attr_type: AttributeType, value: impl Into<Bytes>) -> Self {
        let attr_type = attr_type.value();
        let value = value.into();
        Self { attr_type, value }
    }

    /// Builds an XOR-MAPPED-ADDRESS attribute for `addr`.
    ///
    /// The value is laid out as: one zero byte, the family byte (`0x01` for
    /// IPv4, `0x02` for IPv6), the port XORed with `0x2112`, then the address
    /// bytes XORed with the magic cookie `0x2112A442` (IPv4) or with the magic
    /// cookie followed by `transaction_id` (IPv6).
    #[must_use]
    pub fn xor_mapped_address(addr: SocketAddr, transaction_id: &[u8; 12]) -> Self {
        let cookie = 0x2112A442u32.to_be_bytes();
        let masked_port = addr.port() ^ 0x2112;

        let mut payload = BytesMut::new();
        // Leading byte is always written as zero.
        payload.put_u8(0);
        match addr {
            SocketAddr::V4(v4) => {
                payload.put_u8(0x01);
                payload.put_u16(masked_port);
                let octets = v4.ip().octets();
                for (octet, mask) in octets.iter().zip(cookie.iter()) {
                    payload.put_u8(octet ^ mask);
                }
            }
            SocketAddr::V6(v6) => {
                payload.put_u8(0x02);
                payload.put_u16(masked_port);
                let octets = v6.ip().octets();
                // 4 cookie bytes + 12 transaction-id bytes cover all 16 octets.
                let mask_bytes = cookie.iter().chain(transaction_id.iter());
                for (octet, mask) in octets.iter().zip(mask_bytes) {
                    payload.put_u8(octet ^ mask);
                }
            }
        }

        Self {
            attr_type: AttributeType::XorMappedAddress.value(),
            value: payload.freeze(),
        }
    }

    /// Builds a USERNAME attribute holding the UTF-8 bytes of `username`.
    #[must_use]
    pub fn username(username: impl AsRef<str>) -> Self {
        let text = username.as_ref();
        Self::new(
            AttributeType::Username,
            Bytes::copy_from_slice(text.as_bytes()),
        )
    }

    /// Builds a PRIORITY attribute carrying `priority` as a big-endian `u32`.
    #[must_use]
    pub fn priority(priority: u32) -> Self {
        let raw = priority.to_be_bytes();
        Self::new(AttributeType::Priority, Bytes::copy_from_slice(&raw))
    }

    /// Builds a USE-CANDIDATE attribute, which has an empty value.
    #[must_use]
    pub fn use_candidate() -> Self {
        Self::new(AttributeType::UseCandidate, Bytes::new())
    }

    /// Builds an ICE-CONTROLLING attribute carrying `tie_breaker` as a
    /// big-endian `u64`.
    #[must_use]
    pub fn ice_controlling(tie_breaker: u64) -> Self {
        let raw = tie_breaker.to_be_bytes();
        Self::new(AttributeType::IceControlling, Bytes::copy_from_slice(&raw))
    }

    /// Builds an ICE-CONTROLLED attribute carrying `tie_breaker` as a
    /// big-endian `u64`.
    #[must_use]
    pub fn ice_controlled(tie_breaker: u64) -> Self {
        let raw = tie_breaker.to_be_bytes();
        Self::new(AttributeType::IceControlled, Bytes::copy_from_slice(&raw))
    }

    /// Serializes the attribute as type, length, value, then zero padding up
    /// to the next multiple of four bytes.
    ///
    /// The length field is the unpadded value length cast to `u16`; values
    /// longer than `u16::MAX` bytes therefore get a truncated length field.
    #[must_use]
    pub fn encode(&self) -> Bytes {
        const ZEROS: [u8; 3] = [0; 3];

        let value_len = self.value.len();
        let padding = match value_len % 4 {
            0 => 0,
            rem => 4 - rem,
        };

        let mut out = BytesMut::new();
        out.put_u16(self.attr_type);
        out.put_u16(value_len as u16);
        out.put_slice(&self.value);
        out.put_slice(&ZEROS[..padding]);
        out.freeze()
    }

    /// Decodes this attribute's value as an XOR-MAPPED-ADDRESS.
    ///
    /// `transaction_id` must be the ID of the enclosing message; it is part
    /// of the XOR mask for IPv6 addresses. The attribute's own `attr_type` is
    /// not inspected.
    ///
    /// # Errors
    ///
    /// Returns a parse error when the value is shorter than 4 bytes, when the
    /// address bytes are shorter than the family requires (4 for IPv4, 16 for
    /// IPv6), or when the family byte is neither `0x01` nor `0x02`.
    pub fn parse_xor_mapped_address(&self, transaction_id: &[u8; 12]) -> NetResult<SocketAddr> {
        let raw: &[u8] = &self.value;
        if raw.len() < 4 {
            return Err(NetError::parse(0, "XOR-MAPPED-ADDRESS too short"));
        }

        // raw[0] is skipped; raw[1] is the family, raw[2..4] the masked port.
        let family = raw[1];
        let port = u16::from_be_bytes([raw[2], raw[3]]) ^ 0x2112;
        let masked_ip = &raw[4..];
        let cookie = 0x2112A442u32.to_be_bytes();

        match family {
            0x01 => {
                if masked_ip.len() < 4 {
                    return Err(NetError::parse(0, "Invalid IPv4 address"));
                }
                let mut octets = [0u8; 4];
                let pairs = masked_ip.iter().zip(cookie.iter());
                for (slot, (byte, mask)) in octets.iter_mut().zip(pairs) {
                    *slot = byte ^ mask;
                }
                Ok(SocketAddr::new(std::net::IpAddr::V4(octets.into()), port))
            }
            0x02 => {
                if masked_ip.len() < 16 {
                    return Err(NetError::parse(0, "Invalid IPv6 address"));
                }
                let mut octets = [0u8; 16];
                let mask_bytes = cookie.iter().chain(transaction_id.iter());
                let pairs = masked_ip.iter().zip(mask_bytes);
                for (slot, (byte, mask)) in octets.iter_mut().zip(pairs) {
                    *slot = byte ^ mask;
                }
                Ok(SocketAddr::new(std::net::IpAddr::V6(octets.into()), port))
            }
            _ => Err(NetError::parse(0, "Unknown address family")),
        }
    }
}
#[derive(Debug, Clone)]
pub struct Message {
pub message_type: MessageType,
pub transaction_id: [u8; 12],
pub attributes: Vec<Attribute>,
}
impl Message {
pub const MAGIC_COOKIE: u32 = 0x2112A442;
#[must_use]
pub fn new(message_type: MessageType) -> Self {
let mut transaction_id = [0u8; 12];
use rand::RngExt;
rand::rng().fill(&mut transaction_id);
Self {
message_type,
transaction_id,
attributes: Vec::new(),
}
}
#[must_use]
pub fn binding_request() -> Self {
Self::new(MessageType::BindingRequest)
}
#[must_use]
pub fn binding_response(transaction_id: [u8; 12]) -> Self {
Self {
message_type: MessageType::BindingResponse,
transaction_id,
attributes: Vec::new(),
}
}
#[must_use]
pub fn with_attribute(mut self, attr: Attribute) -> Self {
self.attributes.push(attr);
self
}
#[must_use]
pub fn get_attribute(&self, attr_type: AttributeType) -> Option<&Attribute> {
self.attributes
.iter()
.find(|a| a.attr_type == attr_type.value())
}
#[must_use]
pub fn encode(&self) -> Bytes {
let mut buf = BytesMut::new();
let mut attrs_buf = BytesMut::new();
for attr in &self.attributes {
attrs_buf.put(attr.encode());
}
buf.put_u16(self.message_type.value());
buf.put_u16(attrs_buf.len() as u16);
buf.put_u32(Self::MAGIC_COOKIE);
buf.put_slice(&self.transaction_id);
buf.put(attrs_buf);
buf.freeze()
}
pub fn encode_with_integrity(&self, password: &str) -> Bytes {
let mut msg = self.clone();
msg.attributes
.retain(|a| a.attr_type != AttributeType::MessageIntegrity.value());
let encoded = msg.encode();
let Ok(mut mac) = Hmac::<Sha1>::new_from_slice(password.as_bytes()) else {
return Bytes::new();
};
mac.update(&encoded);
let result = mac.finalize();
let hmac_bytes = result.into_bytes();
let mut buf = BytesMut::from(encoded.as_ref());
let new_length = buf.len() - 20 + 4 + 20; buf[2..4].copy_from_slice(&(new_length as u16).to_be_bytes());
buf.put_u16(AttributeType::MessageIntegrity.value());
buf.put_u16(20); buf.put_slice(&hmac_bytes);
buf.freeze()
}
pub fn parse(data: &[u8]) -> NetResult<Self> {
if data.len() < 20 {
return Err(NetError::parse(0, "Message too short"));
}
let mut cursor = Bytes::copy_from_slice(data);
let message_type = cursor.get_u16();
let message_type = MessageType::from_value(message_type)
.ok_or_else(|| NetError::parse(0, "Unknown message type"))?;
let length = cursor.get_u16() as usize;
let magic = cursor.get_u32();
if magic != Self::MAGIC_COOKIE {
return Err(NetError::parse(4, "Invalid magic cookie"));
}
let mut transaction_id = [0u8; 12];
cursor.copy_to_slice(&mut transaction_id);
let mut attributes = Vec::new();
let mut parsed = 0;
while parsed < length && cursor.remaining() >= 4 {
let attr_type = cursor.get_u16();
let attr_length = cursor.get_u16() as usize;
if cursor.remaining() < attr_length {
return Err(NetError::parse(parsed as u64, "Attribute value too short"));
}
let value = cursor.copy_to_bytes(attr_length);
attributes.push(Attribute { attr_type, value });
let padding = (4 - (attr_length % 4)) % 4;
cursor.advance(padding.min(cursor.remaining()));
parsed += 4 + attr_length + padding;
}
Ok(Self {
message_type,
transaction_id,
attributes,
})
}
pub fn verify_integrity(&self, password: &str) -> bool {
let integrity_attr = self.get_attribute(AttributeType::MessageIntegrity);
if let Some(attr) = integrity_attr {
if attr.value.len() != 20 {
return false;
}
let mut msg = self.clone();
msg.attributes
.retain(|a| a.attr_type != AttributeType::MessageIntegrity.value());
let encoded = msg.encode();
let Ok(mut mac) = Hmac::<Sha1>::new_from_slice(password.as_bytes()) else {
return false;
};
mac.update(&encoded);
mac.verify_slice(&attr.value).is_ok()
} else {
false
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Binding request code maps to its variant and back.
    #[test]
    fn test_message_type() {
        let request = MessageType::BindingRequest;
        assert_eq!(request.value(), 0x0001);
        assert_eq!(MessageType::from_value(0x0001), Some(request));
    }

    /// `binding_request` produces a message of the Binding request type.
    #[test]
    fn test_binding_request() {
        let request = Message::binding_request();
        assert_eq!(request.message_type, MessageType::BindingRequest);
    }

    /// Encoding then parsing keeps the type, transaction ID and attribute count.
    #[test]
    fn test_encode_decode() {
        let username = Attribute::username("test:user");
        let original = Message::binding_request().with_attribute(username);

        let wire = original.encode();
        let round_tripped = Message::parse(&wire).expect("should succeed in test");

        assert_eq!(round_tripped.message_type, original.message_type);
        assert_eq!(round_tripped.transaction_id, original.transaction_id);
        assert_eq!(round_tripped.attributes.len(), 1);
    }

    /// An IPv4 XOR-MAPPED-ADDRESS encodes to at least 8 bytes.
    #[test]
    fn test_xor_mapped_address_v4() {
        let transaction_id = [0u8; 12];
        let addr: SocketAddr = "192.168.1.100:5000"
            .parse()
            .expect("should succeed in test");

        let wire = Attribute::xor_mapped_address(addr, &transaction_id).encode();
        assert!(wire.len() >= 8);
    }

    /// The PRIORITY constructor tags the attribute with the PRIORITY type code.
    #[test]
    fn test_priority_attribute() {
        let expected_type = AttributeType::Priority.value();
        let attr = Attribute::priority(12345);
        assert_eq!(attr.attr_type, expected_type);
    }
}