/// XOR two indexable byte sequences element by element into a new
/// `[_; $size]` array.
///
/// `$default` only seeds the array before every slot is overwritten; its
/// literal type fixes the element type of the result. `$a` and `$b` are each
/// indexed once for every index in `0..$size`.
macro_rules! bytewise_xor {
    ($size:literal, $a:expr, $b:expr, $default:literal) => {{
        let mut xored = [$default; $size];
        for (slot, idx) in xored.iter_mut().zip(0usize..) {
            *slot = $a[idx] ^ $b[idx];
        }
        xored
    }};
}
mod address;
pub use address::{MappedSocketAddr, XorSocketAddr};
mod alternate;
pub use alternate::{AlternateDomain, AlternateServer};
mod error;
pub use error::{ErrorCode, UnknownAttributes};
mod integrity;
pub use integrity::{MessageIntegrity, MessageIntegritySha256};
mod fingerprint;
pub use fingerprint::Fingerprint;
mod nonce;
pub use nonce::Nonce;
mod password_algorithm;
pub use password_algorithm::{PasswordAlgorithm, PasswordAlgorithmValue, PasswordAlgorithms};
mod realm;
pub use realm::Realm;
mod user;
pub use user::{Userhash, Username};
mod software;
pub use software::Software;
mod xor_addr;
pub use xor_addr::XorMappedAddress;
use crate::data::Data;
use crate::message::{StunParseError, StunWriteError};
use alloc::boxed::Box;
use alloc::vec::Vec;
use byteorder::{BigEndian, ByteOrder};
#[cfg(feature = "std")]
use alloc::collections::BTreeMap;
#[cfg(feature = "std")]
use std::sync::{Mutex, OnceLock};
/// Signature of an externally supplied `Display` implementation for a
/// [`RawAttribute`].
pub type AttributeDisplay =
    fn(&RawAttribute<'_>, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result;

/// Registry of external display implementations, keyed by attribute type.
///
/// Filled by [`add_display_impl`]; consulted when displaying a
/// [`RawAttribute`] whose type has no built-in display implementation.
#[cfg(feature = "std")]
static ATTRIBUTE_EXTERNAL_DISPLAY_IMPL: OnceLock<Mutex<BTreeMap<AttributeType, AttributeDisplay>>> =
    OnceLock::new();

/// Register `imp` as the display implementation for attributes of type `atype`.
///
/// Registering again for the same `atype` replaces the earlier implementation.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
#[cfg(feature = "std")]
pub fn add_display_impl(atype: AttributeType, imp: AttributeDisplay) {
    let registry = ATTRIBUTE_EXTERNAL_DISPLAY_IMPL.get_or_init(Default::default);
    registry.lock().unwrap().insert(atype, imp);
}
/// Register a display implementation for the typed attribute `$typ`.
///
/// The registered implementation parses the raw attribute as `$typ` and uses
/// its `Display` output. If parsing fails, it writes a "Malformed" description
/// with the type, length and raw data.
///
/// All trait items are referenced through fully qualified `$crate` paths.
/// Invoking the macro therefore does not require `AttributeFromRaw`,
/// `AttributeStaticType` or `Attribute` to be imported at the call site.
#[cfg(feature = "std")]
#[macro_export]
macro_rules! attribute_display {
    ($typ:ty) => {{
        let imp = |attr: &$crate::attribute::RawAttribute<'_>,
                   f: &mut ::core::fmt::Formatter<'_>|
         -> ::core::fmt::Result {
            match <$typ as $crate::attribute::AttributeFromRaw>::from_raw_ref(attr) {
                Ok(parsed) => ::core::write!(f, "{}", parsed),
                Err(_) => ::core::write!(
                    f,
                    "{}(Malformed): len: {}, data: {:?})",
                    $crate::attribute::Attribute::get_type(attr),
                    attr.header.length(),
                    attr.value
                ),
            }
        };
        $crate::attribute::add_display_impl(
            <$typ as $crate::attribute::AttributeStaticType>::TYPE,
            imp,
        );
    }};
}
/// Runtime registry of attribute type names added through
/// [`AttributeType::add_name`]. [`AttributeType::name`] consults it for types
/// that have no built-in name.
#[cfg(feature = "std")]
static ATTRIBUTE_TYPE_NAME_MAP: OnceLock<Mutex<BTreeMap<AttributeType, &'static str>>> =
OnceLock::new();
/// The 16-bit type identifier carried in an attribute header.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AttributeType(u16);

impl core::fmt::Display for AttributeType {
    /// Formats as `<decimal>(<hex>: <name>)`, e.g. `32(0x20: XOR-MAPPED-ADDRESS)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let raw = self.0;
        write!(f, "{raw}({raw:#x}: {})", self.name())
    }
}
impl AttributeType {
    /// Register `name` as the human-readable name of this attribute type.
    ///
    /// [`AttributeType::name`] returns the registered name only for types that
    /// have no built-in name. Registering again replaces the previous name.
    ///
    /// # Panics
    ///
    /// Panics if the name registry mutex has been poisoned.
    #[cfg(feature = "std")]
    pub fn add_name(self, name: &'static str) {
        let mut anames = ATTRIBUTE_TYPE_NAME_MAP
            .get_or_init(Default::default)
            .lock()
            .unwrap();
        anames.insert(self, name);
    }

    /// Create an attribute type from its 16-bit wire value.
    pub const fn new(val: u16) -> Self {
        Self(val)
    }

    /// The 16-bit wire value of this attribute type.
    pub fn value(&self) -> u16 {
        self.0
    }

    /// Human-readable name of this attribute type.
    ///
    /// Built-in types are matched first. Otherwise, with the `std` feature, a
    /// name registered through [`AttributeType::add_name`] is used. Anything
    /// else yields `"unknown"`.
    ///
    /// # Panics
    ///
    /// With `std`, panics if the name registry mutex has been poisoned.
    pub fn name(self) -> &'static str {
        match self {
            AttributeType(0x0001) => "MAPPED-ADDRESS",
            Username::TYPE => "USERNAME",
            MessageIntegrity::TYPE => "MESSAGE-INTEGRITY",
            ErrorCode::TYPE => "ERROR-CODE",
            UnknownAttributes::TYPE => "UNKNOWN-ATTRIBUTES",
            Realm::TYPE => "REALM",
            Nonce::TYPE => "NONCE",
            MessageIntegritySha256::TYPE => "MESSAGE-INTEGRITY-SHA256",
            PasswordAlgorithm::TYPE => "PASSWORD-ALGORITHM",
            Userhash::TYPE => "USERHASH",
            XorMappedAddress::TYPE => "XOR-MAPPED-ADDRESS",
            // Hyphenated like every other entry and like the RFC 8489 registry
            // name (was previously spelled with an underscore).
            PasswordAlgorithms::TYPE => "PASSWORD-ALGORITHMS",
            AlternateDomain::TYPE => "ALTERNATE-DOMAIN",
            Software::TYPE => "SOFTWARE",
            AlternateServer::TYPE => "ALTERNATE-SERVER",
            Fingerprint::TYPE => "FINGERPRINT",
            _ => {
                #[cfg(feature = "std")]
                {
                    let anames = ATTRIBUTE_TYPE_NAME_MAP
                        .get_or_init(Default::default)
                        .lock()
                        .unwrap();
                    if let Some(name) = anames.get(&self) {
                        return name;
                    }
                }
                "unknown"
            }
        }
    }

    /// Whether this attribute type lies in the comprehension-required range,
    /// i.e. its value is below `0x8000`.
    pub fn comprehension_required(self) -> bool {
        self.0 < 0x8000
    }
}
impl From<u16> for AttributeType {
fn from(f: u16) -> Self {
Self::new(f)
}
}
impl From<AttributeType> for u16 {
fn from(f: AttributeType) -> Self {
f.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct AttributeHeader {
atype: AttributeType,
length: u16,
}
impl AttributeHeader {
fn parse(data: &[u8]) -> Result<Self, StunParseError> {
if data.len() < 4 {
return Err(StunParseError::Truncated {
expected: 4,
actual: data.len(),
});
}
let ret = Self {
atype: BigEndian::read_u16(&data[0..2]).into(),
length: BigEndian::read_u16(&data[2..4]),
};
Ok(ret)
}
fn to_bytes(self) -> [u8; 4] {
let mut ret = [0; 4];
self.write_into(&mut ret);
ret
}
fn write_into(&self, ret: &mut [u8]) {
BigEndian::write_u16(&mut ret[0..2], self.atype.into());
BigEndian::write_u16(&mut ret[2..4], self.length);
}
pub fn get_type(&self) -> AttributeType {
self.atype
}
pub fn length(&self) -> u16 {
self.length
}
}
/// Serialize a header into its four wire bytes.
impl From<AttributeHeader> for [u8; 4] {
    fn from(header: AttributeHeader) -> Self {
        AttributeHeader::to_bytes(header)
    }
}

/// Parse a header from the start of a byte slice.
impl TryFrom<&[u8]> for AttributeHeader {
    type Error = StunParseError;

    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        Self::parse(bytes)
    }
}
/// Associates a compile-time [`AttributeType`] with an implementing type.
pub trait AttributeStaticType {
/// The attribute type identifier for this implementation.
const TYPE: AttributeType;
}
/// Common interface of every attribute: its type and its value length.
///
/// Implementors must also be `Debug`, `Sync` and `Send`.
pub trait Attribute: core::fmt::Debug + core::marker::Sync + core::marker::Send {
/// The attribute type carried in the header.
fn get_type(&self) -> AttributeType;
/// Length in bytes of the attribute value, excluding the 4-byte header and
/// any padding (see [`AttributeExt::padded_len`]).
fn length(&self) -> u16;
}
/// Conversion from a [`RawAttribute`] into a typed attribute.
pub trait AttributeFromRaw<'a>: Attribute {
/// Parse a typed attribute from a borrowed raw attribute.
fn from_raw_ref(raw: &RawAttribute) -> Result<Self, StunParseError>
where
Self: Sized;
/// Parse a typed attribute, consuming the raw attribute.
///
/// The default implementation delegates to [`AttributeFromRaw::from_raw_ref`].
fn from_raw(raw: RawAttribute<'a>) -> Result<Self, StunParseError>
where
Self: Sized,
{
Self::from_raw_ref(&raw)
}
}
/// Round `len` up to the next multiple of 4 bytes.
///
/// Values that are already a multiple of 4 (including 0) are returned
/// unchanged.
pub fn pad_attribute_len(len: usize) -> usize {
    match len % 4 {
        0 => len,
        rem => len + (4 - rem),
    }
}
/// Extension providing the on-the-wire size of any [`Attribute`].
pub trait AttributeExt {
    /// Bytes occupied when written: the 4-byte header plus the value length
    /// rounded up by [`pad_attribute_len`].
    fn padded_len(&self) -> usize;
}

impl<A: Attribute + ?Sized> AttributeExt for A {
    fn padded_len(&self) -> usize {
        let value_len = usize::from(self.length());
        pad_attribute_len(value_len) + 4
    }
}
/// Serialization of an attribute into bytes.
pub trait AttributeWrite: Attribute {
/// Write the attribute into `dest` without checking that `dest` is large
/// enough. NOTE(review): implementations presumably write header, value and
/// padding (as the `RawAttribute` implementation does) and may panic on a
/// short `dest`; [`AttributeWriteExt::write_into`] is the checked variant.
fn write_into_unchecked(&self, dest: &mut [u8]);
/// Produce a [`RawAttribute`] representation of this attribute.
fn to_raw(&self) -> RawAttribute<'_>;
}
/// Length-checked and header-only writing helpers for every [`AttributeWrite`].
pub trait AttributeWriteExt: AttributeWrite {
    /// Write the 4-byte header into `dest` without checking its length.
    /// Returns the number of bytes written (always 4).
    fn write_header_unchecked(&self, dest: &mut [u8]) -> usize;
    /// Write the 4-byte header into `dest`.
    ///
    /// # Errors
    ///
    /// [`StunWriteError::TooSmall`] if `dest` is shorter than 4 bytes.
    fn write_header(&self, dest: &mut [u8]) -> Result<usize, StunWriteError>;
    /// Write the full padded attribute into `dest`, returning the bytes written.
    ///
    /// # Errors
    ///
    /// [`StunWriteError::TooSmall`] if `dest` is shorter than
    /// [`AttributeExt::padded_len`].
    fn write_into(&self, dest: &mut [u8]) -> Result<usize, StunWriteError>;
}

impl<A: AttributeWrite + ?Sized> AttributeWriteExt for A {
    fn write_header(&self, dest: &mut [u8]) -> Result<usize, StunWriteError> {
        const HEADER_LEN: usize = 4;
        match dest.len() {
            available if available < HEADER_LEN => Err(StunWriteError::TooSmall {
                expected: HEADER_LEN,
                actual: available,
            }),
            _ => Ok(self.write_header_unchecked(dest)),
        }
    }

    fn write_header_unchecked(&self, dest: &mut [u8]) -> usize {
        let header = AttributeHeader {
            atype: self.get_type(),
            length: self.length(),
        };
        header.write_into(dest);
        4
    }

    fn write_into(&self, dest: &mut [u8]) -> Result<usize, StunWriteError> {
        let needed = self.padded_len();
        if dest.len() < needed {
            return Err(StunWriteError::TooSmall {
                expected: needed,
                actual: dest.len(),
            });
        }
        self.write_into_unchecked(dest);
        Ok(needed)
    }
}
/// An attribute header paired with its undecoded value bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawAttribute<'a> {
/// The attribute header (type and declared value length).
pub header: AttributeHeader,
/// The value bytes without padding, either borrowed or owned.
pub value: Data<'a>,
}
/// Display `$this` through the typed `$CamelType` implementation.
///
/// Falls back to a "Malformed" description (type, length, raw data) when the
/// raw attribute cannot be parsed as `$CamelType`.
macro_rules! display_attr {
    ($this:ident, $f:ident, $CamelType:ty) => {{
        match <$CamelType>::from_raw_ref($this) {
            Ok(attr) => write!($f, "{}", attr),
            Err(_) => write!(
                $f,
                "{}(Malformed): len: {}, data: {:?})",
                $this.get_type(),
                $this.header.length(),
                $this.value
            ),
        }
    }};
}
/// Displays known attribute types through their typed implementation.
///
/// For other types, a display implementation registered with
/// [`add_display_impl`] is used when the `std` feature is enabled. Otherwise a
/// generic `RawAttribute (type: …, len: …, data: …)` dump is written.
impl core::fmt::Display for RawAttribute<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.get_type() {
            Username::TYPE => display_attr!(self, f, Username),
            MessageIntegrity::TYPE => display_attr!(self, f, MessageIntegrity),
            ErrorCode::TYPE => display_attr!(self, f, ErrorCode),
            UnknownAttributes::TYPE => display_attr!(self, f, UnknownAttributes),
            Realm::TYPE => display_attr!(self, f, Realm),
            Nonce::TYPE => display_attr!(self, f, Nonce),
            MessageIntegritySha256::TYPE => {
                display_attr!(self, f, MessageIntegritySha256)
            }
            PasswordAlgorithm::TYPE => display_attr!(self, f, PasswordAlgorithm),
            Userhash::TYPE => display_attr!(self, f, Userhash),
            XorMappedAddress::TYPE => display_attr!(self, f, XorMappedAddress),
            PasswordAlgorithms::TYPE => display_attr!(self, f, PasswordAlgorithms),
            AlternateDomain::TYPE => display_attr!(self, f, AlternateDomain),
            Software::TYPE => display_attr!(self, f, Software),
            AlternateServer::TYPE => display_attr!(self, f, AlternateServer),
            Fingerprint::TYPE => display_attr!(self, f, Fingerprint),
            _ => {
                #[cfg(feature = "std")]
                {
                    // Copy the fn pointer out so the registry guard is dropped
                    // at the end of this statement, before the external impl
                    // runs. `std::sync::Mutex` is not reentrant: an impl that
                    // formats another externally-handled attribute would
                    // otherwise deadlock or panic re-locking the registry.
                    let external = ATTRIBUTE_EXTERNAL_DISPLAY_IMPL
                        .get_or_init(Default::default)
                        .lock()
                        .unwrap()
                        .get(&self.get_type())
                        .copied();
                    if let Some(imp) = external {
                        return imp(self, f);
                    }
                }
                write!(
                    f,
                    "RawAttribute (type: {:?}, len: {}, data: {:?})",
                    self.header.get_type(),
                    self.header.length(),
                    &self.value
                )
            }
        }
    }
}
impl<'a> RawAttribute<'a> {
    /// Create a raw attribute of type `atype` that borrows `data` as its value.
    ///
    /// NOTE(review): the header length is `data.len() as u16`, so a value longer
    /// than `u16::MAX` bytes would be silently truncated in the header.
    /// Callers presumably never pass such data; confirm.
    pub fn new(atype: AttributeType, data: &'a [u8]) -> Self {
        Self {
            header: AttributeHeader {
                atype,
                length: data.len() as u16,
            },
            value: data.into(),
        }
    }

    /// Like [`RawAttribute::new`], but takes ownership of `data`.
    ///
    /// The same `as u16` length conversion applies.
    pub fn new_owned(atype: AttributeType, data: Box<[u8]>) -> Self {
        Self {
            header: AttributeHeader {
                atype,
                length: data.len() as u16,
            },
            value: data.into(),
        }
    }

    /// Parse an attribute (header followed by value) from the start of `data`.
    ///
    /// The returned value borrows exactly the `header.length()` bytes after
    /// the header. Any trailing bytes (padding, further attributes) are
    /// ignored.
    ///
    /// # Errors
    ///
    /// [`StunParseError::Truncated`] if `data` is shorter than the 4-byte
    /// header, or shorter than the value length the header declares.
    pub fn from_bytes(data: &'a [u8]) -> Result<Self, StunParseError> {
        let header = AttributeHeader::parse(data)?;
        let value_len = usize::from(header.length());
        // `parse` succeeded, so `data` holds at least the 4 header bytes.
        // Compare in `usize`: narrowing the available length to `u16` would
        // wrap for long buffers and reject valid attributes.
        let available = data.len() - 4;
        if value_len > available {
            return Err(StunParseError::Truncated {
                expected: value_len,
                actual: available,
            });
        }
        Ok(Self {
            header,
            value: Data::Borrowed(data[4..4 + value_len].into()),
        })
    }

    /// Serialize header, value and zero padding into a new `Vec`.
    ///
    /// The result length is rounded up to a multiple of 4 bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(self.padded_len());
        let mut header_bytes = [0; 4];
        self.header.write_into(&mut header_bytes);
        vec.extend(&header_bytes);
        vec.extend(&*self.value);
        let len = vec.len();
        if len % 4 != 0 {
            vec.resize(len + 4 - (len % 4), 0);
        }
        vec
    }

    /// Check that this attribute has type `atype` and a value length within
    /// `allowed_range`.
    ///
    /// # Errors
    ///
    /// [`StunParseError::WrongAttributeImplementation`] on a type mismatch.
    /// Otherwise, any error from the length check (`Truncated` / `TooLarge`).
    pub fn check_type_and_len(
        &self,
        atype: AttributeType,
        allowed_range: impl core::ops::RangeBounds<usize>,
    ) -> Result<(), StunParseError> {
        if self.header.get_type() != atype {
            return Err(StunParseError::WrongAttributeImplementation);
        }
        check_len(self.value.len(), allowed_range)
    }

    /// Convert into a `RawAttribute` whose value no longer borrows from the
    /// original input.
    pub fn into_owned<'b>(self) -> RawAttribute<'b> {
        RawAttribute {
            header: self.header,
            value: self.value.into_owned(),
        }
    }
}
impl Attribute for RawAttribute<'_> {
    /// The type recorded in this attribute's header.
    fn get_type(&self) -> AttributeType {
        let header = &self.header;
        header.get_type()
    }

    /// Length of the stored value bytes.
    ///
    /// NOTE(review): `as u16` silently truncates values longer than
    /// `u16::MAX` bytes.
    fn length(&self) -> u16 {
        let value_len: usize = self.value.len();
        value_len as u16
    }
}
impl AttributeWrite for RawAttribute<'_> {
    /// Write header, value and zero padding into `dest`.
    ///
    /// # Panics
    ///
    /// Panics if `dest` is shorter than [`AttributeExt::padded_len`].
    fn write_into_unchecked(&self, dest: &mut [u8]) {
        let padded_end = self.padded_len();
        self.header.write_into(dest);
        let value_end = 4 + self.value.len();
        dest[4..value_end].copy_from_slice(&self.value);
        // Zero the alignment bytes; when there are none the range is empty.
        dest[value_end..padded_end].fill(0);
    }

    /// Returns a clone of this attribute.
    fn to_raw(&self) -> RawAttribute<'_> {
        self.clone()
    }
}
impl<'a, A: AttributeWrite> From<&'a A> for RawAttribute<'a> {
fn from(value: &'a A) -> Self {
value.to_raw()
}
}
/// Check that `len` lies within `allowed_range`.
///
/// # Errors
///
/// * `StunParseError::Truncated` when `len` is below the lower bound;
///   `expected` is the smallest accepted length.
/// * `StunParseError::TooLarge` when `len` is above the upper bound;
///   `expected` is the largest accepted length.
///
/// For ranges that accept nothing (e.g. `..0`, or an excluded lower bound of
/// `usize::MAX`), the reported `expected` saturates instead of overflowing.
fn check_len(
    len: usize,
    allowed_range: impl core::ops::RangeBounds<usize>,
) -> Result<(), StunParseError> {
    use core::ops::Bound;

    match allowed_range.start_bound() {
        Bound::Included(&start) if len < start => {
            return Err(StunParseError::Truncated {
                expected: start,
                actual: len,
            });
        }
        Bound::Excluded(&start) if len <= start => {
            return Err(StunParseError::Truncated {
                // `start + 1` would overflow for `start == usize::MAX`.
                expected: start.saturating_add(1),
                actual: len,
            });
        }
        _ => (),
    }
    match allowed_range.end_bound() {
        Bound::Included(&end) if len > end => {
            return Err(StunParseError::TooLarge {
                expected: end,
                actual: len,
            });
        }
        Bound::Excluded(&end) if len >= end => {
            return Err(StunParseError::TooLarge {
                // `end - 1` would underflow for an empty range such as `..0`.
                expected: end.saturating_sub(1),
                actual: len,
            });
        }
        _ => (),
    }
    Ok(())
}
impl From<RawAttribute<'_>> for Vec<u8> {
fn from(f: RawAttribute) -> Self {
f.to_bytes()
}
}
// Unit tests for attribute types, headers, raw attributes and length checks.
#[cfg(test)]
mod tests {
use super::*;
// AttributeType round-trips through u16.
#[test]
fn attribute_type() {
let _log = crate::tests::test_init_log();
let atype = ErrorCode::TYPE;
let anum: u16 = atype.into();
assert_eq!(atype, anum.into());
}
// A header cannot be parsed from fewer than four bytes.
#[test]
fn short_attribute_header() {
let _log = crate::tests::test_init_log();
let data = [0; 1];
let res: Result<AttributeHeader, _> = data.as_ref().try_into();
assert!(res.is_err());
}
// Construction, serialization (with padding) and re-parsing of a raw attribute.
#[test]
fn raw_attribute_construct() {
let _log = crate::tests::test_init_log();
let a = RawAttribute::new(1.into(), &[80, 160]);
assert_eq!(a.get_type(), 1.into());
let bytes: Vec<_> = a.into();
assert_eq!(bytes, &[0, 1, 0, 2, 80, 160, 0, 0]);
let b = RawAttribute::from_bytes(bytes.as_ref()).unwrap();
assert_eq!(b.get_type(), 1.into());
}
// Checked and unchecked writes agree; a header length larger than the
// available data is reported as Truncated.
#[test]
fn raw_attribute_encoding() {
let mut out = [0; 8];
let mut out2 = [0; 8];
let _log = crate::tests::test_init_log();
let orig = RawAttribute::new(1.into(), &[80, 160]);
assert_eq!(orig.get_type(), 1.into());
orig.write_into(&mut out).unwrap();
orig.write_into_unchecked(&mut out2);
assert_eq!(out, out2);
let mut data: Vec<_> = orig.into();
let len = data.len();
// Claim one more value byte than the buffer holds.
BigEndian::write_u16(&mut data[2..4], len as u16 - 4 + 1);
assert!(matches!(
RawAttribute::from_bytes(data.as_ref()),
Err(StunParseError::Truncated {
expected: 5,
actual: 4
})
));
}
// All header-writing paths produce the same four bytes.
#[test]
fn raw_attribute_header() {
let mut out = [0; 4];
let mut out2 = [0; 4];
let _log = crate::tests::test_init_log();
let orig = RawAttribute::new(1.into(), &[80, 160]);
assert!(matches!(orig.write_header(&mut out), Ok(4)));
assert_eq!(orig.write_header_unchecked(&mut out2), 4);
assert_eq!(out, out2);
assert_eq!(orig.header.to_bytes(), out);
assert_eq!(&orig.to_bytes()[..4], out);
let bytes: [_; 4] = orig.header.into();
assert_eq!(bytes, out);
}
// Checked writes report TooSmall for undersized destinations.
#[test]
fn raw_attribute_write_into_small() {
let mut out = [0; 8];
let _log = crate::tests::test_init_log();
let orig = RawAttribute::new(1.into(), &[80, 160]);
assert_eq!(orig.get_type(), 1.into());
assert!(matches!(
orig.write_header(&mut out[..3]),
Err(StunWriteError::TooSmall {
expected: 4,
actual: 3,
})
));
assert!(matches!(
orig.write_into(&mut out[..7]),
Err(StunWriteError::TooSmall {
expected: 8,
actual: 7,
})
));
}
// check_len accepts in-range lengths and reports the nearest accepted bound.
#[test]
fn test_check_len() {
let _log = crate::tests::test_init_log();
assert!(check_len(4, ..).is_ok());
assert!(check_len(4, 0..).is_ok());
assert!(check_len(4, 0..8).is_ok());
assert!(check_len(4, 0..=8).is_ok());
assert!(check_len(4, ..=8).is_ok());
assert!(matches!(
check_len(4, ..4),
Err(StunParseError::TooLarge {
expected: 3,
actual: 4
})
));
assert!(matches!(
check_len(4, 5..),
Err(StunParseError::Truncated {
expected: 5,
actual: 4
})
));
assert!(matches!(
check_len(4, ..=3),
Err(StunParseError::TooLarge {
expected: 3,
actual: 4
})
));
assert!(matches!(
check_len(
4,
(core::ops::Bound::Excluded(4), core::ops::Bound::Unbounded)
),
Err(StunParseError::Truncated {
expected: 5,
actual: 4
})
));
}
// Runtime registration of display implementations and type names, plus the
// "Malformed" fallback for an attribute that fails to parse.
// NOTE(review): registrations go into process-wide registries and persist
// after this test.
#[test]
#[cfg(feature = "std")]
fn test_external_display_impl() {
use crate::message::TransactionId;
let _log = crate::tests::test_init_log();
let atype = AttributeType::new(0xFFFF);
assert_eq!(atype.name(), "unknown");
let data = [4, 0];
let attr = RawAttribute::new(atype, &data);
assert_eq!(
alloc::format!("{attr}"),
"RawAttribute (type: AttributeType(65535), len: 2, data: Borrowed(DataSlice([4, 0])))"
);
let imp = |attr: &RawAttribute<'_>,
f: &mut core::fmt::Formatter<'_>|
-> core::fmt::Result { write!(f, "Custom {}", attr.value[0]) };
add_display_impl(atype, imp);
let display_str = alloc::format!("{}", attr);
assert_eq!(display_str, "Custom 4");
atype.add_name("SOME-NAME");
assert_eq!(atype.name(), "SOME-NAME");
attribute_display!(Fingerprint);
let id = TransactionId::generate();
let xor_addr = XorMappedAddress::new("127.0.0.1:10000".parse().unwrap(), id);
let raw = xor_addr.to_raw();
// Keep only 3 value bytes so parsing as XorMappedAddress fails.
let raw = RawAttribute::new(raw.get_type(), &raw.value[..3]);
assert_eq!(alloc::format!("{raw}"), "32(0x20: XOR-MAPPED-ADDRESS)(Malformed): len: 3, data: Borrowed(DataSlice([0, 1, 6])))");
}
}