use crate::events::{StunTransactionError, StuntClientEvent, TransactionEventHandler};
use crate::fingerprint::{add_fingerprint_attribute, validate_fingerprint};
use crate::integrity::IntegrityError;
use crate::lt_cred_mech::LongTermCredentialClient;
use crate::message::{create_stun_message, StunAttributes};
use crate::rtt::{RttCalcuator, DEFAULT_GRANULARITY};
use crate::st_cred_mech::ShortTermCredentialClient;
use crate::timeout::{RtoManager, StunMessageTimeout, DEFAULT_RC, DEFAULT_RM, DEFAULT_RTO};
use crate::{CredentialMechanism, StunAgentError, StunPacket};
use log::{debug, info, warn};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use stun_rs::attributes::stun::UserName;
use stun_rs::error::StunEncodeError;
use stun_rs::{
HMACKey, MessageClass, MessageDecoder, MessageEncoder, MessageMethod, StunMessage,
TransactionId,
};
/// Number of outstanding request transactions a client allows when the
/// builder is not given an explicit limit via `with_max_transactions`.
pub const DEFAULT_MAX_TRANSACTIONS: usize = 10;
/// Delivery guarantees of the transport the client sends over; this selects
/// the retransmission strategy used for requests.
#[derive(Debug)]
pub enum TransportReliability {
    /// Reliable transport: the duration is used as the single transaction
    /// timeout (the client builds its `RtoManager` with `rm`/`rc` fixed to 1).
    Reliable(Duration),
    /// Unreliable transport: retransmission timeouts come from an adaptive
    /// RTT estimator configured by the given [`RttConfig`].
    Unreliable(RttConfig),
}
/// Tuning knobs for the RTT estimator used on unreliable transports.
#[derive(Debug)]
pub struct RttConfig {
    /// Initial retransmission timeout handed to the RTT calculator.
    pub rto: Duration,
    /// Clock granularity handed to the RTT calculator together with `rto`.
    pub granularity: Duration,
    /// Passed to `RtoManager::new` as `rm` (presumably the RFC 8489 `Rm`
    /// multiplier for the final wait — confirm against `timeout` module).
    pub rm: u32,
    /// Passed to `RtoManager::new` as `rc` (presumably the RFC 8489 `Rc`
    /// request count — confirm against `timeout` module).
    pub rc: u32,
}
impl Default for RttConfig {
    /// Builds a configuration from the crate-wide `DEFAULT_RTO`,
    /// `DEFAULT_GRANULARITY`, `DEFAULT_RM` and `DEFAULT_RC` constants.
    fn default() -> Self {
        let rto = DEFAULT_RTO;
        let granularity = DEFAULT_GRANULARITY;
        Self {
            rm: DEFAULT_RM,
            rc: DEFAULT_RC,
            rto,
            granularity,
        }
    }
}
/// The two message classes a client may originate; selects which
/// credential-preparation path `prepare_stun_message` takes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StunClientMessageClass {
    /// A request that opens a transaction and expects a response.
    Request,
    /// A fire-and-forget indication.
    Indication,
}
/// Settings accumulated by `StunClienteBuilder` and consumed by
/// `StunClient::new`.
#[derive(Debug)]
struct StunClientParameters {
    /// User name; required when `mechanism` is set.
    user_name: Option<String>,
    /// Password; required when `mechanism` is set.
    password: Option<String>,
    /// Credential mechanism, or `None` for unauthenticated messages.
    mechanism: Option<CredentialMechanism>,
    /// Transport reliability, which decides the retransmission strategy.
    reliability: TransportReliability,
    /// Whether outgoing messages carry, and incoming ones must pass, FINGERPRINT.
    fingerprint: bool,
    /// Upper bound on simultaneously outstanding request transactions.
    max_transactions: usize,
}
/// Builder for [`StunClient`]: start with [`StunClienteBuilder::new`], chain
/// the `with_*` options, then call [`StunClienteBuilder::build`].
#[derive(Debug)]
pub struct StunClienteBuilder(StunClientParameters);
impl StunClienteBuilder {
pub fn new(reliability: TransportReliability) -> StunClienteBuilder {
Self(StunClientParameters {
user_name: None,
password: None,
mechanism: None,
reliability,
fingerprint: false,
max_transactions: DEFAULT_MAX_TRANSACTIONS,
})
}
pub fn with_max_transactions(mut self, max_transactions: usize) -> Self {
self.0.max_transactions = max_transactions;
self
}
pub fn with_mechanism<U, P>(
mut self,
user_name: U,
password: P,
mechanism: CredentialMechanism,
) -> Self
where
U: Into<String>,
P: Into<String>,
{
self.0.user_name = Some(user_name.into());
self.0.password = Some(password.into());
self.0.mechanism = Some(mechanism);
self
}
pub fn with_fingerprint(mut self) -> Self {
self.0.fingerprint = true;
self
}
pub fn build(self) -> Result<StunClient, StunAgentError> {
StunClient::new(self.0)
}
}
/// Client-side state of whichever credential mechanism is configured.
#[derive(Debug)]
enum CredentialMechanismClient {
    /// Short-term credential client.
    ShortTerm(ShortTermCredentialClient),
    /// Long-term credential client.
    LongTerm(LongTermCredentialClient),
}
impl CredentialMechanismClient {
fn prepare_request(&mut self, attributes: &mut StunAttributes) -> Result<(), StunAgentError> {
match self {
CredentialMechanismClient::ShortTerm(mechanism) => {
mechanism.add_attributes(attributes);
Ok(())
}
CredentialMechanismClient::LongTerm(mechanism) => mechanism.prepare_request(attributes),
}
}
fn prepare_indication(
&mut self,
attributes: &mut StunAttributes,
) -> Result<(), StunAgentError> {
match self {
CredentialMechanismClient::ShortTerm(mechanism) => {
mechanism.add_attributes(attributes);
Ok(())
}
CredentialMechanismClient::LongTerm(mechanism) => {
mechanism.prepare_indication(attributes)
}
}
}
fn recv_message(
&mut self,
raw_data: &[u8],
message: &StunMessage,
) -> Result<(), IntegrityError> {
match self {
CredentialMechanismClient::ShortTerm(mechanism) => {
mechanism.recv_message(raw_data, message)
}
CredentialMechanismClient::LongTerm(mechanism) => {
mechanism.recv_message(raw_data, message)
}
}
}
fn signal_protection_violated_on_timeout(&mut self, transaction_id: &TransactionId) -> bool {
match self {
CredentialMechanismClient::ShortTerm(mechanism) => {
mechanism.signal_protection_violated_on_timeout(transaction_id)
}
CredentialMechanismClient::LongTerm(mechanism) => {
mechanism.signal_protection_violated_on_timeout(transaction_id)
}
}
}
}
/// Bookkeeping for one outstanding request.
#[derive(Debug)]
struct StunTransaction {
    /// When the request was first sent. `StunClient::on_timeout` clears it on
    /// retransmission, so no RTT sample is taken from an ambiguous response.
    instant: Option<Instant>,
    /// Encoded request, re-sent on every retransmission.
    packet: StunPacket,
    /// Source of the successive retransmission timeouts.
    rtos: RtoManager,
}
/// Adaptive RTT state kept for unreliable transports.
#[derive(Debug)]
struct RttHandler {
    /// Running RTT estimator; supplies the initial RTO of each request.
    rtt: RttCalcuator,
    /// `rm` value forwarded to every new `RtoManager`.
    rm: u32,
    /// `rc` value forwarded to every new `RtoManager`.
    rc: u32,
    /// When the previous request was sent; used to detect long inactivity.
    last_request: Option<Instant>,
}
/// Retransmission-timing state derived from [`TransportReliability`].
#[derive(Debug)]
enum StunRttCalcuator {
    /// Fixed transaction timeout for reliable transports.
    Reliable(Duration),
    /// Adaptive RTT tracking for unreliable transports.
    Unreliable(RttHandler),
}
impl From<TransportReliability> for StunRttCalcuator {
fn from(reliability: TransportReliability) -> Self {
match reliability {
TransportReliability::Reliable(timeout) => StunRttCalcuator::Reliable(timeout),
TransportReliability::Unreliable(config) => StunRttCalcuator::Unreliable(RttHandler {
rtt: RttCalcuator::new(config.rto, config.granularity),
rm: config.rm,
rc: config.rc,
last_request: None,
}),
}
}
}
/// Sans-I/O STUN client.
///
/// The caller feeds it buffers and clock instants, and drains the resulting
/// events with `events()`.
#[derive(Debug)]
pub struct StunClient {
    /// Configured credential mechanism, if any.
    mechanism: Option<CredentialMechanismClient>,
    /// Encoder for outgoing messages.
    encoder: MessageEncoder,
    /// Decoder for incoming buffers.
    decoder: MessageDecoder,
    /// Whether FINGERPRINT is added to outgoing and required on incoming messages.
    use_fingerprint: bool,
    /// Pending retransmission timers, keyed by transaction id.
    timeouts: StunMessageTimeout,
    /// Retransmission timing strategy (fixed or adaptive).
    rtt: StunRttCalcuator,
    /// Outstanding request transactions.
    transactions: HashMap<TransactionId, StunTransaction>,
    /// Queue of events produced for the caller.
    transaction_events: TransactionEventHandler,
    /// Upper bound on `transactions.len()`.
    max_transactions: usize,
}
impl StunClient {
    /// Creates a client from the parameters gathered by [`StunClienteBuilder`].
    ///
    /// When a credential mechanism is configured, both a user name and a
    /// password must be present:
    /// - The user name is validated with `UserName::new`.
    /// - For the short-term mechanism, the password is turned into an HMAC key
    ///   with `HMACKey::new_short_term`.
    /// - The long-term mechanism receives the password string unchanged.
    ///
    /// # Errors
    ///
    /// [`StunAgentError::InternalError`] if a required credential is missing
    /// or rejected.
    fn new(params: StunClientParameters) -> Result<Self, StunAgentError> {
        let rtt = StunRttCalcuator::from(params.reliability);
        // Both credential mechanisms are told whether the transport is reliable.
        let is_reliable = matches!(rtt, StunRttCalcuator::Reliable(_));
        let mechanism = match params.mechanism {
            Some(value) => {
                let user_name = params.user_name.ok_or_else(|| {
                    StunAgentError::InternalError(String::from("User name is required"))
                })?;
                let password = params.password.ok_or_else(|| {
                    StunAgentError::InternalError(String::from("Password is required"))
                })?;
                let user_name = UserName::new(user_name).map_err(|e| {
                    StunAgentError::InternalError(format!("Failed to create user name: {}", e))
                })?;
                match value {
                    CredentialMechanism::ShortTerm(integrity) => {
                        let key = HMACKey::new_short_term(password).map_err(|e| {
                            StunAgentError::InternalError(format!(
                                "Failed to create HMAC key: {}",
                                e
                            ))
                        })?;
                        Some(CredentialMechanismClient::ShortTerm(
                            ShortTermCredentialClient::new(user_name, key, integrity, is_reliable),
                        ))
                    }
                    CredentialMechanism::LongTerm => Some(CredentialMechanismClient::LongTerm(
                        LongTermCredentialClient::new(user_name, password, is_reliable),
                    )),
                }
            }
            None => None,
        };
        Ok(Self {
            mechanism,
            encoder: Default::default(),
            decoder: Default::default(),
            use_fingerprint: params.fingerprint,
            timeouts: StunMessageTimeout::default(),
            rtt,
            transactions: Default::default(),
            transaction_events: Default::default(),
            max_transactions: params.max_transactions,
        })
    }

    /// Adds credential and FINGERPRINT attributes to an outgoing request.
    fn prepare_request(&mut self, attributes: &mut StunAttributes) -> Result<(), StunAgentError> {
        prepare_stun_message(
            StunClientMessageClass::Request,
            attributes,
            self.mechanism.as_mut(),
            self.use_fingerprint,
        )
    }

    /// Adds credential and FINGERPRINT attributes to an outgoing indication.
    fn prepare_indication(
        &mut self,
        attributes: &mut StunAttributes,
    ) -> Result<(), StunAgentError> {
        prepare_stun_message(
            StunClientMessageClass::Indication,
            attributes,
            self.mechanism.as_mut(),
            self.use_fingerprint,
        )
    }

    /// Arms the first retransmission timer for `transaction_id`.
    ///
    /// Returns the `RtoManager` that yields the following timeouts.
    ///
    /// On unreliable transports, the RTT estimate is reset when more than
    /// 600 s (10 minutes) passed since the previous request, so a stale RTO
    /// is not reused.
    ///
    /// # Errors
    ///
    /// [`StunAgentError::InternalError`] if the manager yields no initial RTO.
    fn set_timeout(
        &mut self,
        transaction_id: TransactionId,
        instant: Instant,
    ) -> Result<RtoManager, StunAgentError> {
        let mut rto_manager = match self.rtt {
            StunRttCalcuator::Reliable(timeout) => RtoManager::new(timeout, 1, 1),
            StunRttCalcuator::Unreliable(ref mut handler) => {
                if let Some(last_request) = handler.last_request {
                    // `instant` is caller-supplied and may be out of order.
                    // Saturate explicitly instead of relying on `Instant - Instant`,
                    // whose saturation std does not guarantee for future versions.
                    if instant.saturating_duration_since(last_request) > Duration::from_secs(600) {
                        debug!(
                            "Current RTT value {}ms staled caused by inactivity. Resetting.",
                            handler.rtt.rto().as_millis()
                        );
                        handler.rtt.reset();
                    }
                }
                handler.last_request = Some(instant);
                RtoManager::new(handler.rtt.rto(), handler.rm, handler.rc)
            }
        };
        let timeout = rto_manager.next_rto(instant).ok_or_else(|| {
            StunAgentError::InternalError(String::from("Can not calculate next RTO"))
        })?;
        self.timeouts.add(instant, timeout, transaction_id);
        debug!("[{:?}] Set timeout {:?}", transaction_id, timeout);
        Ok(rto_manager)
    }

    /// Closes a transaction after its response arrived at `instant`.
    ///
    /// Cancels its timer and drops it. On unreliable transports, the RTT
    /// estimator is also fed a sample, but only if the request was never
    /// retransmitted (`instant` still set).
    fn transaction_finished(&mut self, transaction_id: &TransactionId, instant: Instant) {
        self.timeouts.remove(transaction_id);
        let Some(transaction) = self.transactions.remove(transaction_id) else {
            debug!("[{:?}] Not found", transaction_id);
            return;
        };
        let Some(sent_instant) = transaction.instant else {
            return;
        };
        if let StunRttCalcuator::Unreliable(handler) = &mut self.rtt {
            // Saturate so a caller-supplied out-of-order instant yields a zero
            // sample instead of depending on `Instant - Instant` semantics.
            let new_rtt = instant.saturating_duration_since(sent_instant);
            debug!(
                "[{:?}] RTT calculation: sent={:?}, recv={:?}, rtt={:?}",
                transaction_id, sent_instant, instant, new_rtt
            );
            handler.rtt.update(new_rtt);
        }
    }

    /// Encodes and queues a request, opening a new transaction.
    ///
    /// Emits an `OutputPacket` event, followed by a `RestransmissionTimeOut`
    /// event for the nearest pending timer.
    ///
    /// # Errors
    ///
    /// - [`StunAgentError::MaxOutstandingRequestsReached`] when the
    ///   transaction limit is hit.
    /// - [`StunAgentError::InternalError`] on encoding or timer failures.
    /// - Any error from the credential mechanism.
    pub fn send_request(
        &mut self,
        method: MessageMethod,
        mut attributes: StunAttributes,
        buffer: Vec<u8>,
        instant: Instant,
    ) -> Result<TransactionId, StunAgentError> {
        if self.transactions.len() >= self.max_transactions {
            return Err(StunAgentError::MaxOutstandingRequestsReached);
        }
        self.prepare_request(&mut attributes)?;
        let msg = create_stun_message(method, MessageClass::Request, None, attributes);
        let transaction_id = *msg.transaction_id();
        let packet = encode_buffer(&self.encoder, &msg, buffer).map_err(|e| {
            StunAgentError::InternalError(format!("Failed to encode request message: {}", e))
        })?;
        let transaction = StunTransaction {
            instant: Some(instant),
            packet: packet.clone(),
            rtos: self.set_timeout(transaction_id, instant)?,
        };
        self.transactions.insert(transaction_id, transaction);
        let mut events = self.transaction_events.init();
        events.push(StuntClientEvent::OutputPacket(packet));
        if let Some((id, left)) = self.timeouts.next_timeout(instant) {
            events.push(StuntClientEvent::RestransmissionTimeOut((id, left)));
        }
        Ok(transaction_id)
    }

    /// Encodes and queues an indication.
    ///
    /// No transaction or timer is created. Emits a single `OutputPacket` event.
    ///
    /// # Errors
    ///
    /// [`StunAgentError::InternalError`] on encoding failure, or any error from
    /// the credential mechanism.
    pub fn send_indication(
        &mut self,
        method: MessageMethod,
        mut attributes: StunAttributes,
        buffer: Vec<u8>,
    ) -> Result<TransactionId, StunAgentError> {
        self.prepare_indication(&mut attributes)?;
        let msg = create_stun_message(method, MessageClass::Indication, None, attributes);
        let packet = encode_buffer(&self.encoder, &msg, buffer).map_err(|e| {
            StunAgentError::InternalError(format!("Failed to encode indication message: {}", e))
        })?;
        let mut events = self.transaction_events.init();
        events.push(StuntClientEvent::OutputPacket(packet));
        Ok(*msg.transaction_id())
    }

    /// Processes a buffer received from the transport.
    ///
    /// Requests are rejected, and responses must match an outstanding
    /// transaction. FINGERPRINT (if enabled) and the credential mechanism
    /// (if any) are then checked.
    ///
    /// A response finishes its transaction. Then either the integrity-derived
    /// event or `StunMessageReceived` is emitted.
    ///
    /// # Errors
    ///
    /// - [`StunAgentError::InternalError`] if decoding fails.
    /// - [`StunAgentError::Discarded`] for requests, unmatched responses,
    ///   fingerprint mismatches, or messages the credential mechanism discards.
    pub fn on_buffer_recv(
        &mut self,
        buffer: &[u8],
        instant: Instant,
    ) -> Result<(), StunAgentError> {
        let (msg, _) = self.decoder.decode(buffer).map_err(|e| {
            StunAgentError::InternalError(format!("Failed to decode message: {}", e))
        })?;
        match msg.class() {
            MessageClass::Request => {
                debug!(
                    "Received STUN request with {:?}. Discarding.",
                    msg.transaction_id()
                );
                return Err(StunAgentError::Discarded);
            }
            MessageClass::Indication => {
                debug!("Received STUN indication with {:?}", msg.transaction_id());
            }
            MessageClass::SuccessResponse | MessageClass::ErrorResponse => {
                if !self.transactions.contains_key(msg.transaction_id()) {
                    debug!(
                        "Received response with no matching {:?}. Discarding.",
                        msg.transaction_id()
                    );
                    return Err(StunAgentError::Discarded);
                }
            }
        }
        if self.use_fingerprint && !validate_fingerprint(buffer, &msg)? {
            debug!(
                "[{:?}] Fingerprint validation failed. Discarding.",
                msg.transaction_id()
            );
            return Err(StunAgentError::Discarded);
        }
        let mut integrity_event = None;
        if let Some(mechanism) = &mut self.mechanism {
            if let Err(e) = mechanism.recv_message(buffer, &msg) {
                // A `Discarded` integrity error returns early here, leaving the
                // transaction open for a genuine response.
                integrity_event = process_integrity_error(e, msg.transaction_id())?;
            }
        }
        if msg.class() != MessageClass::Indication {
            self.transaction_finished(msg.transaction_id(), instant);
        }
        let mut events = self.transaction_events.init();
        match integrity_event {
            Some(event) => events.push(event),
            None => events.push(StuntClientEvent::StunMessageReceived(msg)),
        }
        Ok(())
    }

    /// Handles expired retransmission timers at `instant`.
    ///
    /// A transaction with timeouts left is re-sent, and its RTT sample is
    /// dropped because the response would be ambiguous. A transaction that
    /// ran out is removed and reported with `TransactionFailed`, either as
    /// `ProtectionViolated` (if the credential mechanism says so) or
    /// `TimedOut`. Finally, the next pending timer, if any, is announced.
    pub fn on_timeout(&mut self, instant: Instant) {
        let timed_out = self.timeouts.check(instant);
        let mut events = self.transaction_events.init();
        for transaction_id in timed_out {
            let Some(transaction) = self.transactions.get_mut(&transaction_id) else {
                warn!("Transaction {:?} not found", transaction_id);
                continue;
            };
            if let Some(rto) = transaction.rtos.next_rto(instant) {
                transaction.instant = None;
                self.timeouts.add(instant, rto, transaction_id);
                debug!("set timeout {:?} for transaction {:?}", rto, transaction_id);
                events.push(StuntClientEvent::OutputPacket(transaction.packet.clone()));
                continue;
            }
            // Retransmissions exhausted: the transaction has failed. Drop it so
            // it no longer counts against `max_transactions`, and so that a late
            // response is discarded rather than delivered after the failure event.
            self.transactions.remove(&transaction_id);
            let protection_violated = self.mechanism.as_mut().map_or(false, |m| {
                m.signal_protection_violated_on_timeout(&transaction_id)
            });
            let error = if protection_violated {
                StunTransactionError::ProtectionViolated
            } else {
                StunTransactionError::TimedOut
            };
            let event = StuntClientEvent::TransactionFailed((transaction_id, error));
            info!(
                "Transaction {:?} timed out. Event: {:?}",
                transaction_id, event
            );
            events.push(event);
        }
        if let Some((id, left)) = self.timeouts.next_timeout(instant) {
            events.push(StuntClientEvent::RestransmissionTimeOut((id, left)));
        }
    }

    /// Returns the events held by the internal `TransactionEventHandler`
    /// (presumably those produced by the latest call — confirm in `events`).
    pub fn events(&mut self) -> Vec<StuntClientEvent> {
        self.transaction_events.events()
    }
}
/// Maps a credential-mechanism failure onto the event reported to the caller.
///
/// # Errors
///
/// `IntegrityError::Discarded` becomes [`StunAgentError::Discarded`]; every
/// other variant yields `Ok(Some(event))`.
fn process_integrity_error(
    error: IntegrityError,
    transaction_id: &TransactionId,
) -> Result<Option<StuntClientEvent>, StunAgentError> {
    let id = *transaction_id;
    let event = match error {
        IntegrityError::Discarded => return Err(StunAgentError::Discarded),
        IntegrityError::Retry => StuntClientEvent::Retry(id),
        IntegrityError::ProtectionViolated => {
            StuntClientEvent::TransactionFailed((id, StunTransactionError::ProtectionViolated))
        }
        IntegrityError::NotRetryable => {
            StuntClientEvent::TransactionFailed((id, StunTransactionError::DoNotRetry))
        }
    };
    Ok(Some(event))
}
/// Adds the credential attributes for `class` (when a mechanism is present),
/// then FINGERPRINT if enabled.
///
/// FINGERPRINT goes in after the credentials, presumably because STUN
/// requires it to be the last attribute.
fn prepare_stun_message(
    class: StunClientMessageClass,
    attributes: &mut StunAttributes,
    mechanism: Option<&mut CredentialMechanismClient>,
    use_fingerprint: bool,
) -> Result<(), StunAgentError> {
    match (mechanism, class) {
        (None, _) => {}
        (Some(client), StunClientMessageClass::Request) => client.prepare_request(attributes)?,
        (Some(client), StunClientMessageClass::Indication) => {
            client.prepare_indication(attributes)?
        }
    }
    if use_fingerprint {
        add_fingerprint_attribute(attributes);
    }
    Ok(())
}
/// Encodes `msg` into `buffer` and wraps the buffer together with the
/// number of bytes written.
fn encode_buffer(
    encoder: &MessageEncoder,
    msg: &StunMessage,
    mut buffer: Vec<u8>,
) -> Result<StunPacket, StunEncodeError> {
    encoder
        .encode(&mut buffer, msg)
        .map(|written| StunPacket::new(buffer, written))
}
#[cfg(test)]
mod stun_client_tests {
    use super::*;

    /// Installs the test logger; repeated calls are ignored.
    fn init_logging() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Reliable transport with a 5 s timeout, used by every test.
    fn reliable() -> TransportReliability {
        TransportReliability::Reliable(Duration::from_secs(5))
    }

    /// Builder with 5 transactions, short-term credentials and FINGERPRINT.
    fn short_term_builder(user: &str, password: &str) -> StunClienteBuilder {
        StunClienteBuilder::new(reliable())
            .with_max_transactions(5)
            .with_mechanism(user, password, CredentialMechanism::ShortTerm(None))
            .with_fingerprint()
    }

    /// Raw constructor parameters mirroring `short_term_builder`.
    fn short_term_params(user: Option<&str>, password: Option<&str>) -> StunClientParameters {
        StunClientParameters {
            user_name: user.map(String::from),
            password: password.map(String::from),
            mechanism: Some(CredentialMechanism::ShortTerm(None)),
            reliability: reliable(),
            fingerprint: true,
            max_transactions: 5,
        }
    }

    /// Checks that a successfully built client kept the requested settings.
    fn assert_configured(client: &StunClient) {
        assert_eq!(client.max_transactions, 5);
        assert!(matches!(
            client.mechanism,
            Some(CredentialMechanismClient::ShortTerm(_))
        ));
        assert!(client.use_fingerprint);
    }

    #[test]
    fn test_stun_client_builder() {
        init_logging();
        let client = short_term_builder("user", "password")
            .build()
            .expect("Could not create STUN client");
        assert_configured(&client);

        // A tab character is rejected in both the user name and the password.
        for (user, password) in [("bad\u{0009}user", "password"), ("user", "bad\u{0009}password")] {
            let error = short_term_builder(user, password)
                .build()
                .expect_err("Should not create STUN client");
            assert!(matches!(error, StunAgentError::InternalError(_)));
        }
    }

    #[test]
    fn test_stun_client_constructor() {
        init_logging();
        let client = StunClient::new(short_term_params(Some("user"), Some("password")))
            .expect("Could not create STUN client");
        assert_configured(&client);

        // Invalid characters and missing credentials must all be rejected.
        let invalid = [
            (Some("bad\u{0009}user"), Some("password")),
            (Some("user"), Some("bad\u{0009}password")),
            (None, Some("password")),
            (Some("user"), None),
        ];
        for (user, password) in invalid {
            let error = StunClient::new(short_term_params(user, password))
                .expect_err("Should not create STUN client");
            assert!(matches!(error, StunAgentError::InternalError(_)));
        }
    }

    #[test]
    fn test_stun_client_transaction_finished_unknown_transaction_id() {
        init_logging();
        let mut client = short_term_builder("user", "password")
            .build()
            .expect("Could not create STUN client");
        assert!(client.transactions.is_empty());
        client.transaction_finished(&TransactionId::default(), Instant::now());
        assert!(client.transactions.is_empty());
    }
}