#![deny(missing_docs)]
use std::{ops::Deref, slice::Iter, sync::Arc};
use stun_rs::MessageHeader;
use stun_rs::StunAttribute;
use stun_rs::MESSAGE_HEADER_SIZE;
mod client;
mod events;
mod fingerprint;
mod integrity;
mod lt_cred_mech;
mod message;
mod rtt;
mod st_cred_mech;
mod timeout;
pub use crate::client::RttConfig;
pub use crate::client::StunClient;
pub use crate::client::StunClienteBuilder;
pub use crate::client::TransportReliability;
pub use crate::events::StunTransactionError;
pub use crate::events::StuntClientEvent;
pub use crate::message::StunAttributes;
/// Errors reported by the STUN agent.
///
/// NOTE(review): no variant is constructed in this file; the per-variant notes
/// below are read off the variant names and should be confirmed against the
/// code that produces them.
#[derive(Debug, PartialEq, Eq)]
pub enum StunAgentError {
    /// The message was discarded.
    Discarded,
    /// Validation of the FINGERPRINT attribute failed (presumably — confirm).
    FingerPrintValidationFailed,
    /// The message was ignored (NOTE(review): confirm how this differs from `Discarded`).
    Ignored,
    /// The limit on outstanding requests was reached (presumably — confirm).
    MaxOutstandingRequestsReached,
    /// A STUN check on the message failed (which check is defined by the producer).
    StunCheckFailed,
    /// Internal error; the `String` presumably holds a human-readable description.
    InternalError(String),
}
/// Identifies one of the two STUN message-integrity attribute kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Integrity {
    /// The MESSAGE-INTEGRITY attribute.
    MessageIntegrity,
    /// The MESSAGE-INTEGRITY-SHA256 attribute.
    MessageIntegritySha256,
}
/// Credential mechanism used to authenticate STUN messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CredentialMechanism {
    /// Short-term credentials. The optional [`Integrity`] names the integrity
    /// attribute to use; NOTE(review): what `None` means is decided by the
    /// short-term mechanism implementation — confirm there.
    ShortTerm(Option<Integrity>),
    /// Long-term credentials.
    LongTerm,
}
/// Backing storage shared by all clones of a [`StunPacket`]: the buffer the
/// packet was assembled in, plus how many of its leading bytes form the packet.
#[derive(Debug, Clone, PartialEq, Eq)]
struct StunPacketInternal {
    buffer: Vec<u8>,
    size: usize,
}

/// A complete STUN packet.
///
/// The bytes are held behind an [`Arc`], so cloning a packet only bumps a
/// reference count instead of copying the buffer. The packet dereferences to,
/// and compares by, the first `size` bytes of the buffer it was built from;
/// bytes of that buffer beyond `size` are not part of the packet.
#[derive(Debug, Clone)]
pub struct StunPacket(Arc<StunPacketInternal>);

impl StunPacket {
    /// Wraps `buffer`, of which only the first `size` bytes form the packet.
    ///
    /// Callers must guarantee `size <= buffer.len()`; otherwise dereferencing
    /// the packet panics. Debug builds check this eagerly.
    pub(crate) fn new(buffer: Vec<u8>, size: usize) -> Self {
        debug_assert!(size <= buffer.len(), "packet size exceeds buffer length");
        StunPacket(Arc::new(StunPacketInternal { buffer, size }))
    }
}

// Equality is defined on the packet bytes only. A derived impl would also
// compare the unused tail of the backing buffer (and its length), so two packets
// with identical contents decoded into differently sized buffers would differ.
impl PartialEq for StunPacket {
    fn eq(&self, other: &Self) -> bool {
        // Clones share the same allocation; skip the byte comparison for them.
        Arc::ptr_eq(&self.0, &other.0) || **self == **other
    }
}

impl Eq for StunPacket {}

impl Deref for StunPacket {
    type Target = [u8];

    /// Returns the packet bytes, i.e. the backing buffer truncated to `size`.
    fn deref(&self) -> &Self::Target {
        let StunPacketInternal { buffer, size } = &*self.0;
        &buffer[..*size]
    }
}

impl AsRef<[u8]> for StunPacket {
    /// Borrows the packet bytes; equivalent to dereferencing the packet.
    fn as_ref(&self) -> &[u8] {
        &**self
    }
}
/// Incremental decoder that assembles one STUN packet from a sequence of byte
/// chunks, copying them into a caller-supplied buffer.
///
/// Create it with [`StunPacketDecoder::new`] and feed it with
/// [`StunPacketDecoder::decode`], which consumes the decoder and hands it back
/// for as long as more bytes are needed.
#[derive(Debug)]
pub struct StunPacketDecoder {
    // Storage for the packet being assembled; its length caps the packet size.
    buffer: Vec<u8>,
    // Number of packet bytes copied into `buffer` so far.
    current_size: usize,
    // Total packet size (header + body); `None` until the header is complete.
    expected_size: Option<usize>,
}
/// Successful outcome of [`StunPacketDecoder::decode`].
#[derive(Debug)]
pub enum StunPacketDecodedValue {
    /// A complete packet, together with the number of bytes taken from the
    /// `data` slice passed to the call that completed it; bytes of `data` past
    /// that count were not consumed.
    Decoded((StunPacket, usize)),
    /// The packet is still incomplete. Carries the decoder to feed with more
    /// data and, once the header has been parsed, the number of bytes still
    /// missing (`None` while the header itself is incomplete).
    MoreBytesNeeded((StunPacketDecoder, Option<usize>)),
}
/// Reason a [`StunPacketDecoder`] operation failed.
#[derive(Debug)]
pub enum StunPacketErrorType {
    /// The buffer cannot hold a STUN message header (when creating the
    /// decoder) or the whole packet announced by the header (when decoding).
    SmallBuffer,
    /// The header bytes were rejected by `MessageHeader::try_from`.
    InvalidStunPacket,
}
/// Error returned by [`StunPacketDecoder`]; hands the buffer back to the caller.
#[derive(Debug)]
pub struct StunPacketDecodedError {
    /// What went wrong.
    pub error_type: StunPacketErrorType,
    /// The buffer that was given to [`StunPacketDecoder::new`], returned to the caller.
    pub buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` filled by the decoder: 0 when
    /// construction failed, the message-header size when decoding failed.
    pub size: usize,
    /// Number of bytes consumed from the `data` slice of the failing call.
    pub consumed: usize,
}
impl StunPacketDecoder {
pub fn new(buffer: Vec<u8>) -> Result<Self, StunPacketDecodedError> {
if buffer.len() < MESSAGE_HEADER_SIZE {
return Err(StunPacketDecodedError {
error_type: StunPacketErrorType::SmallBuffer,
buffer,
size: 0,
consumed: 0,
});
}
Ok(StunPacketDecoder {
buffer,
current_size: 0,
expected_size: None,
})
}
pub fn decode(mut self, data: &[u8]) -> Result<StunPacketDecodedValue, StunPacketDecodedError> {
match self.expected_size {
Some(size) => {
let first = self.current_size;
let remaining = size - first;
if data.len() >= remaining {
self.buffer[first..size].copy_from_slice(&data[..remaining]);
let packet = StunPacket::new(self.buffer, size);
Ok(StunPacketDecodedValue::Decoded((packet, remaining)))
} else {
self.buffer[first..first + data.len()].copy_from_slice(&data[..data.len()]);
self.current_size += data.len();
Ok(StunPacketDecodedValue::MoreBytesNeeded((
self,
Some(remaining - data.len()),
)))
}
}
None => {
let header_length = self.current_size + data.len();
if header_length >= MESSAGE_HEADER_SIZE {
let first = self.current_size;
let remaining = MESSAGE_HEADER_SIZE - first;
self.buffer[first..first + remaining].copy_from_slice(&data[..remaining]);
let slice: &[u8; MESSAGE_HEADER_SIZE] =
self.buffer[..MESSAGE_HEADER_SIZE].try_into().unwrap();
let Ok(header) = MessageHeader::try_from(slice) else {
return Err(StunPacketDecodedError {
error_type: StunPacketErrorType::InvalidStunPacket,
buffer: self.buffer,
size: MESSAGE_HEADER_SIZE,
consumed: remaining,
});
};
let msg_length = header.msg_length as usize;
if self.buffer.len() < msg_length + MESSAGE_HEADER_SIZE {
return Err(StunPacketDecodedError {
error_type: StunPacketErrorType::SmallBuffer,
buffer: self.buffer,
size: MESSAGE_HEADER_SIZE,
consumed: remaining,
});
}
self.expected_size = Some(msg_length + MESSAGE_HEADER_SIZE);
if data.len() >= msg_length + remaining {
self.buffer[MESSAGE_HEADER_SIZE..MESSAGE_HEADER_SIZE + msg_length]
.copy_from_slice(&data[remaining..remaining + msg_length]);
let packet = StunPacket::new(self.buffer, msg_length + MESSAGE_HEADER_SIZE);
Ok(StunPacketDecodedValue::Decoded((
packet,
remaining + msg_length,
)))
} else {
self.buffer
[MESSAGE_HEADER_SIZE..MESSAGE_HEADER_SIZE + data.len() - remaining]
.copy_from_slice(&data[remaining..data.len()]);
self.current_size += data.len();
let remaining = msg_length + MESSAGE_HEADER_SIZE - self.current_size;
Ok(StunPacketDecodedValue::MoreBytesNeeded((
self,
Some(remaining),
)))
}
} else {
let first = self.current_size;
let remaining = data.len();
self.buffer[first..first + remaining].copy_from_slice(&data[..remaining]);
self.current_size += data.len();
Ok(StunPacketDecodedValue::MoreBytesNeeded((self, None)))
}
}
}
}
}
/// Filtering iterator over a slice of STUN attributes, as created by
/// [`ProtectedAttributeIterator::protected_iter`].
///
/// Once a MESSAGE-INTEGRITY, MESSAGE-INTEGRITY-SHA256 or FINGERPRINT attribute
/// has been yielded, later attributes are skipped. The exceptions are that
/// MESSAGE-INTEGRITY-SHA256 may still follow MESSAGE-INTEGRITY, and FINGERPRINT
/// may follow either. Repeated integrity/fingerprint attributes are skipped.
#[derive(Debug)]
struct ProtectedAttributeIteratorObject<'a> {
    // Remaining attributes to inspect.
    iter: Iter<'a, StunAttribute>,
    // Set once a MESSAGE-INTEGRITY attribute has been yielded.
    integrity: bool,
    // Set once a MESSAGE-INTEGRITY-SHA256 attribute has been yielded.
    integrity_sha256: bool,
    // Set once a FINGERPRINT attribute has been yielded.
    fingerprint: bool,
}
/// Extension trait that produces a [`ProtectedAttributeIteratorObject`].
trait ProtectedAttributeIterator<'a> {
    /// Returns an iterator that yields only the attributes passing the
    /// integrity/fingerprint ordering filter.
    fn protected_iter(&self) -> ProtectedAttributeIteratorObject<'a>;
}
impl<'a> ProtectedAttributeIterator<'a> for &'a [StunAttribute] {
    /// Starts a fresh filtering pass over the whole attribute slice, with no
    /// integrity or fingerprint attribute seen yet.
    fn protected_iter(&self) -> ProtectedAttributeIteratorObject<'a> {
        let attributes: &'a [StunAttribute] = *self;
        ProtectedAttributeIteratorObject {
            fingerprint: false,
            integrity_sha256: false,
            integrity: false,
            iter: attributes.iter(),
        }
    }
}
impl<'a> Iterator for ProtectedAttributeIteratorObject<'a> {
    type Item = &'a StunAttribute;

    /// Yields the next attribute that passes the ordering filter.
    ///
    /// "Yielded before" refers to earlier calls on this iterator:
    /// * MESSAGE-INTEGRITY is yielded only if no MESSAGE-INTEGRITY,
    ///   MESSAGE-INTEGRITY-SHA256 or FINGERPRINT was yielded before;
    /// * MESSAGE-INTEGRITY-SHA256 is yielded only if neither
    ///   MESSAGE-INTEGRITY-SHA256 nor FINGERPRINT was yielded before;
    /// * FINGERPRINT is yielded only if no FINGERPRINT was yielded before;
    /// * any other attribute is yielded only if none of the three was yielded yet.
    ///
    /// Attributes failing their rule are skipped.
    fn next(&mut self) -> Option<Self::Item> {
        let Self {
            iter,
            integrity,
            integrity_sha256,
            fingerprint,
        } = self;

        iter.find(|attr| {
            if attr.is_message_integrity() {
                let accept = !(*integrity || *integrity_sha256 || *fingerprint);
                if accept {
                    *integrity = true;
                }
                accept
            } else if attr.is_message_integrity_sha256() {
                let accept = !(*integrity_sha256 || *fingerprint);
                if accept {
                    *integrity_sha256 = true;
                }
                accept
            } else if attr.is_fingerprint() {
                let accept = !*fingerprint;
                *fingerprint = true;
                accept
            } else {
                !(*integrity || *integrity_sha256 || *fingerprint)
            }
        })
    }
}
// Tests for `StunPacket`: a packet exposes only the first `size` bytes of its
// backing buffer, through both `AsRef<[u8]>` and `Deref<Target = [u8]>`.
#[cfg(test)]
mod tests_stun_packet {
    use super::*;
    #[test]
    fn test_stun_packet() {
        let buffer = vec![0; 10];
        assert_eq!(buffer.len(), 10);
        // Only 5 of the 10 buffer bytes belong to the packet.
        let packet = StunPacket::new(buffer, 5);
        assert_eq!(packet.as_ref().len(), 5);
        let buffer = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let packet = StunPacket::new(buffer, 5);
        // `len()` comes from deref to `[u8]`; the contents are the buffer prefix.
        assert_eq!(packet.len(), 5);
        assert_eq!(packet.as_ref(), &[0, 1, 2, 3, 4]);
    }
}
// Tests for the protected-attribute iterator. Each test builds a message with a
// particular ordering of MESSAGE-INTEGRITY, MESSAGE-INTEGRITY-SHA256 and
// FINGERPRINT attributes and checks exactly which attributes are yielded.
#[cfg(test)]
mod tests_protected_iterator {
    use super::*;
    use stun_rs::{
        attributes::stun::{
            Fingerprint, MessageIntegrity, MessageIntegritySha256, Nonce, Realm, UserName,
        },
        methods::BINDING,
        Algorithm, AlgorithmId, HMACKey, MessageClass, StunMessageBuilder,
    };
    const USERNAME: &str = "test-username";
    const NONCE: &str = "test-nonce";
    const REALM: &str = "test-realm";
    const PASSWORD: &str = "test-password";
    // Plain attributes first, then MI, MI-SHA256, FINGERPRINT: everything is yielded.
    #[test]
    fn test_protected_iterator() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(username)
            .with_attribute(nonce)
            .with_attribute(realm)
            .with_attribute(integrity)
            .with_attribute(integrity_sha256)
            .with_attribute(Fingerprint::default())
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute UserName");
        assert!(attr.is_user_name());
        let attr = iter.next().expect("Expected attribute Nonce");
        assert!(attr.is_nonce());
        let attr = iter.next().expect("Expected attribute Realm");
        assert!(attr.is_realm());
        let attr = iter.next().expect("Expected attribute MessageIntegrity");
        assert!(attr.is_message_integrity());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
    // MI comes first, so the plain attributes after it are skipped; MI-SHA256
    // is still allowed after MI.
    #[test]
    fn test_protected_iterator_only_message_integrity() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(integrity)
            .with_attribute(username)
            .with_attribute(nonce)
            .with_attribute(realm)
            .with_attribute(integrity_sha256)
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute MessageIntegrity");
        assert!(attr.is_message_integrity());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        assert!(iter.next().is_none());
    }
    // Plain attributes interleaved after MI / MI-SHA256 (nonce, realm) are skipped.
    #[test]
    fn test_protected_iterator_skip_non_protected() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(username)
            .with_attribute(integrity)
            .with_attribute(nonce)
            .with_attribute(integrity_sha256)
            .with_attribute(realm)
            .with_attribute(Fingerprint::default())
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute UserName");
        assert!(attr.is_user_name());
        let attr = iter.next().expect("Expected attribute MessageIntegrity");
        assert!(attr.is_message_integrity());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
    // MI appearing after MI-SHA256 is skipped.
    #[test]
    fn test_protected_iterator_skip_message_integrity() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(username)
            .with_attribute(integrity_sha256)
            .with_attribute(nonce)
            .with_attribute(integrity)
            .with_attribute(realm)
            .with_attribute(Fingerprint::default())
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute UserName");
        assert!(attr.is_user_name());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
    // An early FINGERPRINT ends the protected range: everything after it,
    // including the integrity attributes and a second FINGERPRINT, is skipped.
    #[test]
    fn test_protected_iterator_skip_message_integrity_sha256() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(username)
            .with_attribute(Fingerprint::default())
            .with_attribute(nonce)
            .with_attribute(integrity_sha256)
            .with_attribute(integrity)
            .with_attribute(realm)
            .with_attribute(Fingerprint::default())
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute UserName");
        assert!(attr.is_user_name());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
    // Each integrity/fingerprint attribute is yielded once; duplicates are skipped.
    #[test]
    fn test_protected_iterator_skip_duplicated_integrity_attrs() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(username)
            .with_attribute(integrity.clone())
            .with_attribute(integrity)
            .with_attribute(integrity_sha256.clone())
            .with_attribute(integrity_sha256)
            .with_attribute(Fingerprint::default())
            .with_attribute(Fingerprint::default())
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute UserName");
        assert!(attr.is_user_name());
        let attr = iter.next().expect("Expected attribute MessageIntegrity");
        assert!(attr.is_message_integrity());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
    // Interleaved duplicates: only the first MI, the first MI-SHA256 and the
    // first FINGERPRINT are yielded; the trailing username is skipped.
    #[test]
    fn test_protected_iterator_skip_corner_cases() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let key = HMACKey::new_short_term("test-password").expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(integrity.clone())
            .with_attribute(integrity.clone())
            .with_attribute(integrity_sha256.clone())
            .with_attribute(integrity.clone())
            .with_attribute(integrity_sha256.clone())
            .with_attribute(Fingerprint::default())
            .with_attribute(integrity)
            .with_attribute(integrity_sha256)
            .with_attribute(Fingerprint::default())
            .with_attribute(username)
            .build();
        let mut iter = msg.attributes().protected_iter();
        let attr = iter.next().expect("Expected attribute MessageIntegrity");
        assert!(attr.is_message_integrity());
        let attr = iter
            .next()
            .expect("Expected attribute MessageIntegritySha256");
        assert!(attr.is_message_integrity_sha256());
        let attr = iter.next().expect("Expected attribute FingerPrint");
        assert!(attr.is_fingerprint());
        assert!(iter.next().is_none());
    }
}
// Tests for `StunPacketDecoder`. As the assertions below imply,
// SAMPLE_IPV4_RESPONSE is 80 bytes: a 20-byte header plus a 60-byte body.
#[cfg(test)]
mod test_stun_packet_decoder {
    use super::*;
    use stun_vectors::SAMPLE_IPV4_RESPONSE;
    // Feeds the sample in chunks of 10, 5, 5, 30, 29 and 1 bytes, checking the
    // decoder state after each step.
    #[test]
    fn test_stun_packet_decoder_small_parts() {
        let buffer = vec![0; 1024];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let mut index = 0;
        let data = &SAMPLE_IPV4_RESPONSE[index..10];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        // Header incomplete: the total size is not known yet.
        assert_eq!(remaining, None);
        assert_eq!(decoder.current_size, 10);
        assert!(decoder.expected_size.is_none());
        index = 10;
        let data = &SAMPLE_IPV4_RESPONSE[index..15];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        assert_eq!(remaining, None);
        assert_eq!(decoder.current_size, 15);
        assert!(decoder.expected_size.is_none());
        assert_eq!(decoder.buffer[..15], SAMPLE_IPV4_RESPONSE[..15]);
        index = 15;
        let data = &SAMPLE_IPV4_RESPONSE[index..index + 5];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        // Header now complete: the body length is known.
        assert_eq!(remaining, Some(60));
        assert_eq!(decoder.current_size, 20);
        assert_eq!(decoder.expected_size, Some(60 + MESSAGE_HEADER_SIZE));
        assert_eq!(decoder.buffer[..20], SAMPLE_IPV4_RESPONSE[..20]);
        index = 20;
        let data = &SAMPLE_IPV4_RESPONSE[index..index + 30];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        assert_eq!(remaining, Some(30));
        assert_eq!(decoder.current_size, 50);
        assert_eq!(decoder.buffer[..50], SAMPLE_IPV4_RESPONSE[..50]);
        index = 50;
        let data = &SAMPLE_IPV4_RESPONSE[index..index + 29];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        assert_eq!(remaining, Some(1));
        assert_eq!(decoder.current_size, 79);
        assert_eq!(decoder.buffer[..79], SAMPLE_IPV4_RESPONSE[..79]);
        index = 79;
        let data = &SAMPLE_IPV4_RESPONSE[index..index + 1];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
            panic!("Stun packed not decoded");
        };
        assert_eq!(consumed, 1);
        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
    }
    // The whole packet in one call is decoded immediately.
    #[test]
    fn test_stun_packet_decoder_one_step() {
        let buffer = vec![0; 1024];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let decoded = decoder
            .decode(&SAMPLE_IPV4_RESPONSE)
            .expect("Failed to decode");
        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
            panic!("Stun packed not decoded");
        };
        assert_eq!(consumed, SAMPLE_IPV4_RESPONSE.len());
        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
    }
    // A partial header followed by all remaining bytes (header rest + body).
    #[test]
    fn test_stun_packet_decoder_two_step() {
        let buffer = vec![0; 1024];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let data = &SAMPLE_IPV4_RESPONSE[..15];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
            panic!("Expected more bytes needed");
        };
        assert_eq!(remaining, None);
        assert_eq!(decoder.current_size, 15);
        assert!(decoder.expected_size.is_none());
        let data = &SAMPLE_IPV4_RESPONSE[15..];
        let decoded = decoder.decode(data).expect("Failed to decode");
        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
            panic!("Stun packed not decoded");
        };
        assert_eq!(consumed, data.len());
        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
    }
    // One byte per call. `remaining` stays `None` until the header's last byte
    // arrives, then counts down to the final byte that completes the packet.
    #[test]
    fn test_stun_packet_decoder_byte_by_byte() {
        let buffer = vec![0; 1024];
        let mut decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let total = SAMPLE_IPV4_RESPONSE.len();
        for index in 0..total {
            let data = &SAMPLE_IPV4_RESPONSE[index..index + 1];
            let decoded = decoder.decode(data).expect("Failed to decode");
            if index < total - 1 {
                let StunPacketDecodedValue::MoreBytesNeeded((deco, remaining)) = decoded else {
                    panic!("Expected more bytes needed");
                };
                if index >= MESSAGE_HEADER_SIZE - 1 {
                    assert_eq!(remaining, Some(total - 1 - index));
                } else {
                    assert_eq!(remaining, None);
                }
                decoder = deco;
            } else {
                let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
                    panic!("Stun packed not decoded");
                };
                assert_eq!(consumed, 1);
                assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
                break;
            }
        }
    }
    // SmallBuffer cases: a 10-byte buffer cannot hold a header, and a 50-byte
    // buffer cannot hold the 80-byte packet, whether fed in two steps or one.
    #[test]
    fn test_stun_packet_decoder_small_buffer() {
        let buffer = vec![0; 10];
        let error = StunPacketDecoder::new(buffer).expect_err("Expected small buffer error");
        let StunPacketErrorType::SmallBuffer = error.error_type else {
            panic!("Expected small buffer error");
        };
        let buffer = vec![0; 50];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let result = decoder
            .decode(&SAMPLE_IPV4_RESPONSE[..10])
            .expect("Failed to decode");
        let StunPacketDecodedValue::MoreBytesNeeded((decoder, None)) = result else {
            panic!("Expected more bytes needed");
        };
        let error = decoder
            .decode(&SAMPLE_IPV4_RESPONSE[10..])
            .expect_err("Expected error");
        let StunPacketErrorType::SmallBuffer = error.error_type else {
            panic!("Expected small buffer error");
        };
        let buffer = vec![0; 50];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let error = decoder
            .decode(&SAMPLE_IPV4_RESPONSE)
            .expect_err("Expected error");
        let StunPacketErrorType::SmallBuffer = error.error_type else {
            panic!("Expected small buffer error");
        };
    }
    // An all-zero header is rejected by header parsing.
    #[test]
    fn test_stun_packet_decoder_invalid_stun_packet() {
        let buffer = vec![0; 1024];
        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
        let data = vec![0; 1024];
        let error = decoder.decode(&data).expect_err("Expected error");
        let StunPacketErrorType::InvalidStunPacket = error.error_type else {
            panic!("Expected invalid STUN packet error");
        };
    }
}