use core::net::SocketAddr;
use core::sync::atomic::{AtomicUsize, Ordering};
use alloc::collections::{BTreeMap, BTreeSet};
use alloc::vec;
use alloc::vec::Vec;
use core::time::Duration;
use crate::Instant;
use stun_types::attribute::*;
use stun_types::data::Data;
use stun_types::message::*;
use stun_types::TransportType;
use tracing::{debug, trace, warn};
/// Process-wide counter; `StunAgentBuilder::build` takes the next value as the new agent's `id`.
static STUN_AGENT_COUNT: AtomicUsize = AtomicUsize::new(0);
/// State for STUN messaging from one local address over one transport.
///
/// The agent tracks outstanding requests (with their retransmission schedule) and the set of
/// peers that have been validated. The methods in this file do not perform I/O themselves:
/// they hand back [`Transmit`] values and take the current time as a `now` argument.
#[derive(Debug)]
pub struct StunAgent {
/// Value taken from `STUN_AGENT_COUNT` when the agent was built; recorded in tracing spans.
id: usize,
/// Transport stamped on every [`Transmit`] this agent produces.
transport: TransportType,
/// Source address stamped on every [`Transmit`] this agent produces.
local_addr: SocketAddr,
/// Optional remote address supplied through the builder; only read back by `remote_addr()`.
remote_addr: Option<SocketAddr>,
/// Addresses recorded through `validated_peer()`.
validated_peers: BTreeSet<SocketAddr>,
/// In-flight requests keyed by their transaction id.
outstanding_requests: BTreeMap<TransactionId, StunRequestState>,
/// Default retransmission intervals cloned into every new request.
request_timeouts: Vec<Duration>,
/// Default wait after the final transmission, copied into every new request.
last_retransmit_timeout: Duration,
}
/// Builder for a [`StunAgent`], created with [`StunAgent::builder`].
#[derive(Debug)]
pub struct StunAgentBuilder {
/// Transport the built agent will use.
transport: TransportType,
/// Local address the built agent will send from.
local_addr: SocketAddr,
/// Optional remote address to store on the built agent.
remote_addr: Option<SocketAddr>,
/// Retransmission parameters turned into concrete timeouts by `build()`.
rto: RequestRto,
}
impl StunAgentBuilder {
pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
self.remote_addr = Some(addr);
self
}
pub fn request_retransmits(
mut self,
initial: Duration,
max: Duration,
retransmits: u32,
final_retransmit_timeout: Duration,
) -> Self {
self.rto.initial = initial;
self.rto.max = max;
self.rto.retransmits = retransmits;
self.rto.last_retransmit = final_retransmit_timeout;
self
}
pub fn build(self) -> StunAgent {
let id = STUN_AGENT_COUNT.fetch_add(1, Ordering::SeqCst);
let (request_timeouts, last_retransmit_timeout) =
self.rto.calculate_timeouts(self.transport);
StunAgent {
id,
transport: self.transport,
local_addr: self.local_addr,
remote_addr: self.remote_addr,
validated_peers: Default::default(),
outstanding_requests: Default::default(),
request_timeouts,
last_retransmit_timeout,
}
}
}
impl StunAgent {
pub fn builder(transport: TransportType, local_addr: SocketAddr) -> StunAgentBuilder {
StunAgentBuilder {
transport,
local_addr,
remote_addr: None,
rto: Default::default(),
}
}
/// The transport this agent was built with.
pub fn transport(&self) -> TransportType {
self.transport
}
/// The local address this agent was built with.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
/// The remote address given to the builder, or `None` if none was set.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
}
/// Wrap arbitrary `bytes` in a [`Transmit`] from this agent's local address to `to`, using
/// this agent's transport. No agent state is read beyond those two fields, and none is changed.
pub fn send_data<T: AsRef<[u8]>>(&self, bytes: T, to: SocketAddr) -> Transmit<T> {
send_data(self.transport, bytes, self.local_addr, to)
}
#[tracing::instrument(name = "stun_agent_send",
skip(self, msg),
fields(
transport = %self.transport,
from = %self.local_addr,
transaction_id,
)
)]
pub fn send<T: AsRef<[u8]>>(
&mut self,
msg: T,
to: SocketAddr,
now: Instant,
) -> Result<Transmit<T>, StunError> {
let data = msg.as_ref();
let hdr = MessageHeader::from_bytes(data)?;
tracing::Span::current().record(
"transaction_id",
tracing::field::display(hdr.transaction_id()),
);
assert!(!hdr.get_type().has_class(MessageClass::Request));
trace!("Sending {} to {to}", hdr.get_type());
Ok(Transmit::new(msg, self.transport, self.local_addr, to))
}
/// Register a serialised STUN request as an outstanding transaction and return its first
/// transmission.
///
/// The request bytes are copied into the agent's per-transaction state together with the
/// agent's default retransmission schedule. The returned [`Transmit`] borrows those stored
/// bytes, so it keeps the agent mutably borrowed until dropped or converted with
/// `into_owned()`.
///
/// # Errors
///
/// * The parse error if the message header cannot be read from `msg`.
/// * [`StunError::AlreadyInProgress`] if a request with the same transaction id is already
///   outstanding.
///
/// # Panics
///
/// Panics if the message does not have class `Request`.
#[tracing::instrument(name = "stun_agent_send_request",
skip(self, msg),
fields(
transport = %self.transport,
from = %self.local_addr,
transaction_id,
)
)]
pub fn send_request<'a, T: AsRef<[u8]>>(
&'a mut self,
msg: T,
to: SocketAddr,
now: Instant,
) -> Result<Transmit<Data<'a>>, StunError> {
let data = msg.as_ref();
let hdr = MessageHeader::from_bytes(data)?;
assert!(hdr.get_type().has_class(MessageClass::Request));
let transaction_id = hdr.transaction_id();
tracing::Span::current().record("transaction_id", tracing::field::display(transaction_id));
let state = match self.outstanding_requests.entry(transaction_id) {
alloc::collections::btree_map::Entry::Vacant(entry) => {
// Remember which integrity attribute the request carried; if both kinds are present
// the one appearing last in the message wins.
let integrity_algorithm = MessageAttributesIter::new(data)
.filter_map(|(_offset, attr)| match attr.get_type() {
MessageIntegrity::TYPE => Some(IntegrityAlgorithm::Sha1),
MessageIntegritySha256::TYPE => Some(IntegrityAlgorithm::Sha256),
_ => None,
})
.last();
trace!("Adding request to {to} with integrity algorithm: {integrity_algorithm:?}");
entry.insert(StunRequestState::new(
msg,
self.transport,
self.local_addr,
to,
transaction_id,
integrity_algorithm,
self.request_timeouts.clone(),
self.last_retransmit_timeout,
))
}
alloc::collections::btree_map::Entry::Occupied(_entry) => {
return Err(StunError::AlreadyInProgress);
}
};
// A freshly created state has never been sent, so its first poll always yields a transmit.
let Some(transmit) = state.poll_transmit(now) else {
unreachable!();
};
Ok(Transmit::new(
Data::from(transmit.data),
transmit.transport,
transmit.from,
transmit.to,
))
}
/// Whether `remote_addr` has been recorded through `validated_peer()`.
pub fn is_validated_peer(&self, remote_addr: SocketAddr) -> bool {
self.validated_peers.contains(&remote_addr)
}
/// Record `addr` as a validated peer.
///
/// Calling this again for an address that is already recorded does nothing (and logs
/// nothing).
#[tracing::instrument(
    name = "stun_validated_peer",
    skip(self),
    fields(stun_id = self.id)
)]
pub fn validated_peer(&mut self, addr: SocketAddr) {
    // `insert` reports whether the address was new, which avoids a separate `contains` lookup.
    if self.validated_peers.insert(addr) {
        debug!("validated peer {:?}", addr);
    }
}
/// Process an incoming, already-parsed STUN message received from `from`.
///
/// For a response, the matching outstanding request is removed. If there is none (never
/// sent, already answered, timed out or cancelled) the message is rejected and `false` is
/// returned without touching the validated-peer set.
///
/// Every message that is not rejected marks `from` as a validated peer and yields `true`.
/// No integrity or source-address checks are performed here.
#[tracing::instrument(
    name = "stun_handle_message",
    skip(self, msg, from),
    fields(
        transaction_id = %msg.transaction_id(),
    )
)]
pub fn handle_stun_message(&mut self, msg: &Message<'_>, from: SocketAddr) -> bool {
    if msg.is_response() {
        let original = self.take_outstanding_request(&msg.transaction_id());
        if original.is_none() {
            trace!("original request disappeared");
            return false;
        }
    }
    self.validated_peer(from);
    true
}
/// Remove and return the state of the outstanding request with `transaction_id`, if any,
/// logging which of the two outcomes happened.
#[tracing::instrument(
skip(self, transaction_id),
fields(transaction_id = %transaction_id)
)]
fn take_outstanding_request(
    &mut self,
    transaction_id: &TransactionId,
) -> Option<StunRequestState> {
    let removed = self.outstanding_requests.remove(transaction_id);
    match removed {
        Some(_) => trace!("removing request"),
        None => trace!("no outstanding request"),
    }
    removed
}
/// Read-only handle to the outstanding request with `transaction_id`, or `None` when no such
/// request is outstanding.
pub fn request_transaction(&self, transaction_id: TransactionId) -> Option<StunRequest<'_>> {
    self.outstanding_requests
        .contains_key(&transaction_id)
        .then_some(StunRequest {
            agent: self,
            transaction_id,
        })
}
/// Mutable handle to the outstanding request with `transaction_id`, or `None` when no such
/// request is outstanding.
pub fn mut_request_transaction(
    &mut self,
    transaction_id: TransactionId,
) -> Option<StunRequestMut<'_>> {
    if !self.outstanding_requests.contains_key(&transaction_id) {
        return None;
    }
    Some(StunRequestMut {
        agent: self,
        transaction_id,
    })
}
/// Mutable access to the stored state of an outstanding request, if it exists.
fn mut_request_state(
&mut self,
transaction_id: TransactionId,
) -> Option<&mut StunRequestState> {
self.outstanding_requests.get_mut(&transaction_id)
}
/// Shared access to the stored state of an outstanding request, if it exists.
fn request_state(&self, transaction_id: TransactionId) -> Option<&StunRequestState> {
self.outstanding_requests.get(&transaction_id)
}
/// Advance the agent's timers to `now`.
///
/// Outstanding requests are visited in map order. The first request found to be timed out or
/// cancelled is removed and reported, so at most one finished transaction is returned per
/// call; call again to discover further ones. Otherwise the earliest instant at which any
/// request wants attention is returned, defaulting to one hour after `now` when nothing is
/// pending.
#[tracing::instrument(
    name = "stun_agent_poll",
    level = "debug",
    skip(self),
)]
pub fn poll(&mut self, now: Instant) -> StunAgentPollRet {
    let mut lowest_wait = now + Duration::from_secs(3600);
    let mut timeout = None;
    let mut cancelled = None;
    for (transaction_id, request) in self.outstanding_requests.iter_mut() {
        debug_assert_eq!(transaction_id, &request.transaction_id);
        match request.poll(now) {
            StunRequestPollRet::Cancelled => {
                cancelled = Some(*transaction_id);
                break;
            }
            StunRequestPollRet::WaitUntil(wait_until) => {
                if wait_until < lowest_wait {
                    lowest_wait = wait_until;
                }
            }
            StunRequestPollRet::TimedOut => {
                timeout = Some(*transaction_id);
                break;
            }
        }
    }
    // Removal happens after the loop because the map is mutably borrowed while iterating.
    if let Some(transaction) = timeout {
        if self.outstanding_requests.remove(&transaction).is_some() {
            return StunAgentPollRet::TransactionTimedOut(transaction);
        }
    }
    if let Some(transaction) = cancelled {
        if self.outstanding_requests.remove(&transaction).is_some() {
            return StunAgentPollRet::TransactionCancelled(transaction);
        }
    }
    StunAgentPollRet::WaitUntil(lowest_wait)
}
/// Return the next (re)transmission that is due at `now`, if any.
///
/// Requests are visited in map order and the first one with data to send wins; call again to
/// drain further due transmissions. Visiting a request may advance its retransmission state
/// even when it yields nothing to send (for example when its sending has been cancelled).
#[tracing::instrument(
    name = "stun_agent_poll_transmit",
    level = "debug",
    skip(self),
)]
pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
    self.outstanding_requests
        .values_mut()
        .find_map(|request| request.poll_transmit(now))
}
}
/// Outcome of [`StunAgent::poll`].
#[derive(Debug)]
pub enum StunAgentPollRet {
/// The request with this transaction id got no response in time and has been removed.
TransactionTimedOut(TransactionId),
/// The request with this transaction id was cancelled and has been removed.
TransactionCancelled(TransactionId),
/// Nothing finished; poll again at this instant.
WaitUntil(Instant),
}
fn send_data<T: AsRef<[u8]>>(
transport: TransportType,
bytes: T,
from: SocketAddr,
to: SocketAddr,
) -> Transmit<T> {
Transmit::new(bytes, transport, from, to)
}
/// A piece of data together with the transport and addresses it should be sent with.
#[derive(Debug)]
pub struct Transmit<T: AsRef<[u8]>> {
/// The bytes to send.
pub data: T,
/// The transport to send over.
pub transport: TransportType,
/// The source address.
pub from: SocketAddr,
/// The destination address.
pub to: SocketAddr,
}
/// Human-readable one-line summary: transport, endpoints and payload length (not contents).
impl<T: AsRef<[u8]>> core::fmt::Display for Transmit<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            data,
            transport,
            from,
            to,
        } = self;
        let len = data.as_ref().len();
        write!(
            f,
            "Transmit({}: {} -> {} of {} bytes)",
            transport, from, to, len
        )
    }
}
impl<T: AsRef<[u8]>> Transmit<T> {
pub fn new(data: T, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self {
Self {
data,
transport,
from,
to,
}
}
pub fn reinterpret_data<O: AsRef<[u8]>, F: FnOnce(T) -> O>(self, f: F) -> Transmit<O> {
Transmit {
data: f(self.data),
transport: self.transport,
from: self.from,
to: self.to,
}
}
}
impl Transmit<Data<'_>> {
/// Convert the payload with `Data::into_owned`, yielding a `Transmit` whose payload lifetime
/// is chosen by the caller instead of being tied to the original borrow.
pub fn into_owned<'b>(self) -> Transmit<Data<'b>> {
self.reinterpret_data(|data| data.into_owned())
}
}
/// Outcome of polling a single request's state.
#[derive(Debug)]
enum StunRequestPollRet {
/// The request is still alive; look at it again at this instant.
WaitUntil(Instant),
/// The request was cancelled and should be removed.
Cancelled,
/// The final wait elapsed without the request being removed; it should be dropped.
TimedOut,
}
/// Retransmission parameters from which a concrete timeout schedule is derived.
#[derive(Debug)]
struct RequestRto {
    /// First retransmission interval; doubled for each subsequent interval.
    initial: Duration,
    /// Upper bound applied to every interval.
    max: Duration,
    /// Count fed into the schedule computation.
    retransmits: u32,
    /// Wait after the last transmission before the request is considered timed out.
    last_retransmit: Duration,
}

impl Default for RequestRto {
    /// Defaults following RFC 8489 section 6.2.1: RTO = 500 ms, Rc = 7, and a final wait of
    /// Rm * RTO = 16 * 500 ms = 8 s.
    fn default() -> Self {
        Self {
            initial: Duration::from_millis(500),
            max: Duration::MAX,
            retransmits: 7,
            // Was `from_millis(8)`: an 8 ms final wait gives the last transmission essentially
            // no chance to be answered, whereas the RFC's Rm * RTO is 8 seconds.
            last_retransmit: Duration::from_secs(8),
        }
    }
}
impl RequestRto {
    /// Interval number `step` (0-based) of the doubling schedule: `initial * 2^step`, capped at
    /// `max`.
    ///
    /// Both the power of two and the multiplication saturate, so very large retransmit counts
    /// or initial values clamp to `max` instead of panicking (or, in builds without overflow
    /// checks, wrapping the factor to zero and producing zero-length intervals).
    fn backoff(&self, step: u32) -> Duration {
        let factor = 2u32.checked_pow(step).unwrap_or(u32::MAX);
        self.initial.saturating_mul(factor).min(self.max)
    }

    /// Turn the parameters into `(intervals between transmissions, wait after the last one)`.
    ///
    /// `retransmits - 1` intervals are generated (none when `retransmits` is 0 or 1).
    ///
    /// * UDP: the intervals are returned as-is together with `last_retransmit`.
    /// * TCP: no retransmission intervals are returned; instead the single final wait is
    ///   `last_retransmit` plus the sum of all intervals.
    fn calculate_timeouts(&self, transport: TransportType) -> (Vec<Duration>, Duration) {
        let intervals = (0..self.retransmits.max(1) - 1).map(|step| self.backoff(step));
        match transport {
            TransportType::Udp => (intervals.collect(), self.last_retransmit),
            TransportType::Tcp => {
                let total_wait = intervals.fold(self.last_retransmit, Duration::saturating_add);
                (vec![], total_wait)
            }
        }
    }
}
/// Everything the agent remembers about one outstanding request.
#[derive(Debug)]
struct StunRequestState {
/// Transaction id of the request; equals its key in the agent's map.
transaction_id: TransactionId,
/// Integrity algorithm detected in the request when it was sent, if any.
request_integrity: Option<IntegrityAlgorithm>,
/// Owned copy of the serialised request, re-sent on every retransmission.
bytes: Vec<u8>,
/// Transport stamped on each transmission.
transport: TransportType,
/// Source address stamped on each transmission.
from: SocketAddr,
/// Destination address of the request.
to: SocketAddr,
/// Waits between successive transmissions, indexed by `timeout_i`.
timeouts: Vec<Duration>,
/// Wait after the final transmission before the request times out.
last_retransmit_timeout: Duration,
/// Set by `cancel()`: the request is reported as cancelled on the next poll.
recv_cancelled: bool,
/// Set by `cancel()` and `cancel_retransmissions()`: no further data is handed out.
send_cancelled: bool,
/// Index into `timeouts` of the wait currently in effect; past the end means final wait.
timeout_i: usize,
/// When the request was last (re)sent; `None` until the first transmission.
last_send_time: Option<Instant>,
}
impl StunRequestState {
#[allow(clippy::too_many_arguments)]
fn new<T: AsRef<[u8]>>(
request: T,
transport: TransportType,
from: SocketAddr,
to: SocketAddr,
transaction_id: TransactionId,
integrity_algorithm: Option<IntegrityAlgorithm>,
timeouts: Vec<Duration>,
last_retransmit_timeout: Duration,
) -> Self {
let data = request.as_ref();
Self {
transaction_id,
bytes: data.to_vec(),
transport,
from,
to,
request_integrity: integrity_algorithm,
timeouts,
timeout_i: 0,
last_retransmit_timeout,
recv_cancelled: false,
send_cancelled: false,
last_send_time: None,
}
}
/// When this request next needs attention.
///
/// * Never sent: `now`.
/// * Retransmission intervals remaining: last send time plus the current interval (this may
///   lie in the past if the caller is late).
/// * All intervals used: the end of the final wait while it is still in the future, and
///   `None` once it has been reached, meaning the request has timed out.
#[tracing::instrument(skip(self, now), level = "trace")]
fn next_send_time(&self, now: Instant) -> Option<Instant> {
    let last_send = match self.last_send_time {
        Some(sent_at) => sent_at,
        None => {
            trace!("not sent yet -> send immediately");
            return Some(now);
        }
    };
    match self.timeouts.get(self.timeout_i) {
        Some(&interval) => Some(last_send + interval),
        None => {
            let next_send = last_send + self.last_retransmit_timeout;
            trace!("final retransmission, final timeout ends at {next_send:?}");
            (next_send > now).then_some(next_send)
        }
    }
}
/// Report what should happen with this request at `now` without changing any state.
///
/// * `Cancelled` if the request was fully cancelled, or if only sending was cancelled and
///   the retransmission intervals are exhausted.
/// * `TimedOut` once the final wait has elapsed.
/// * `WaitUntil(t)` otherwise, where `t` is the next point of interest, or `now` itself when
///   a transmission is already overdue.
#[tracing::instrument(
    name = "stun_request_poll",
    level = "debug",
    ret,
    skip(self, now),
    fields(transaction_id = %self.transaction_id),
)]
fn poll(&mut self, now: Instant) -> StunRequestPollRet {
    if self.recv_cancelled {
        return StunRequestPollRet::Cancelled;
    }
    let Some(next_send) = self.next_send_time(now) else {
        return StunRequestPollRet::TimedOut;
    };
    if next_send >= now {
        // With sending cancelled there is nothing left to wait for once only the final
        // wait remains, so the request is finished early.
        if self.send_cancelled && self.timeout_i >= self.timeouts.len() {
            return StunRequestPollRet::Cancelled;
        }
        return StunRequestPollRet::WaitUntil(next_send);
    }
    // A transmission is overdue: ask to be serviced right away.
    StunRequestPollRet::WaitUntil(now)
}
/// Hand out the request bytes if a (re)transmission is due at `now`.
///
/// When a transmission is due the send time is updated and, for every transmission after the
/// first, the schedule advances by one step. This bookkeeping also happens when sending has
/// been cancelled, in which case no data is returned. Returns `None` as well when the
/// request is fully cancelled, not yet due, or past its final wait.
#[tracing::instrument(
name = "stun_request_poll_transmit",
skip(self, now),
fields(transaction_id = %self.transaction_id)
)]
fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
    if self.recv_cancelled {
        return None;
    }
    let due_at = self.next_send_time(now)?;
    if due_at > now {
        return None;
    }
    // `replace` stores the new send time and tells us whether this is a retransmission.
    if self.last_send_time.replace(now).is_some() {
        self.timeout_i += 1;
    }
    if self.send_cancelled {
        return None;
    }
    trace!(
        "sending {} bytes over {:?} from {:?} to {:?}",
        self.bytes.len(),
        self.transport,
        self.from,
        self.to
    );
    let payload = self.bytes.as_slice();
    Some(send_data(self.transport, payload, self.from, self.to))
}
}
/// Read-only handle to an outstanding request, obtained from
/// [`StunAgent::request_transaction`].
#[derive(Debug, Clone)]
pub struct StunRequest<'a> {
/// The agent owning the request state.
agent: &'a StunAgent,
/// Identifies the request within the agent.
transaction_id: TransactionId,
}
impl StunRequest<'_> {
pub fn peer_address(&self) -> SocketAddr {
let state = self.agent.request_state(self.transaction_id).unwrap();
state.to
}
pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
let state = self.agent.request_state(self.transaction_id).unwrap();
state.request_integrity
}
}
/// Mutable handle to an outstanding request, obtained from
/// [`StunAgent::mut_request_transaction`].
#[derive(Debug)]
pub struct StunRequestMut<'a> {
/// The agent owning the request state.
agent: &'a mut StunAgent,
/// Identifies the request within the agent.
transaction_id: TransactionId,
}
impl StunRequestMut<'_> {
    /// The address the request was sent to.
    ///
    /// # Panics
    ///
    /// Panics if the request is no longer present in the agent.
    pub fn peer_address(&self) -> SocketAddr {
        self.agent
            .request_state(self.transaction_id)
            .unwrap()
            .to
    }

    /// The integrity algorithm detected in the request when it was sent, if any.
    ///
    /// # Panics
    ///
    /// Panics if the request is no longer present in the agent.
    pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
        self.agent
            .request_state(self.transaction_id)
            .unwrap()
            .request_integrity
    }

    /// Stop handing out further transmissions for this request while leaving it outstanding,
    /// so a response can still be matched against it.
    pub fn cancel_retransmissions(&mut self) {
        let Some(state) = self.agent.mut_request_state(self.transaction_id) else {
            return;
        };
        state.send_cancelled = true;
    }

    /// Cancel the request entirely; the agent reports and removes it on a later `poll()`.
    pub fn cancel(&mut self) {
        let Some(state) = self.agent.mut_request_state(self.transaction_id) else {
            return;
        };
        state.recv_cancelled = true;
        state.send_cancelled = true;
    }

    /// Shared access to the agent owning this request.
    pub fn agent(&self) -> &StunAgent {
        &*self.agent
    }

    /// Mutable access to the agent owning this request.
    pub fn mut_agent(&mut self) -> &mut StunAgent {
        &mut *self.agent
    }

    /// Replace this request's retransmission schedule, with no cap on individual intervals.
    ///
    /// Equivalent to `configure_timeout_with_max` with `max_rto` set to `Duration::MAX`.
    pub fn configure_timeout(
        &mut self,
        initial_rto: Duration,
        retransmits: u32,
        last_retransmit_timeout: Duration,
    ) {
        let max_rto = Duration::MAX;
        self.configure_timeout_with_max(initial_rto, retransmits, last_retransmit_timeout, max_rto);
    }

    /// Replace this request's retransmission schedule.
    ///
    /// The new schedule is computed for the request's transport and overwrites the stored
    /// intervals and final wait; the position already reached in the schedule is not reset.
    /// Does nothing if the request is no longer present.
    pub fn configure_timeout_with_max(
        &mut self,
        initial_rto: Duration,
        retransmits: u32,
        last_retransmit_timeout: Duration,
        max_rto: Duration,
    ) {
        let Some(state) = self.agent.mut_request_state(self.transaction_id) else {
            return;
        };
        let rto = RequestRto {
            initial: initial_rto,
            max: max_rto,
            retransmits,
            last_retransmit: last_retransmit_timeout,
        };
        (state.timeouts, state.last_retransmit_timeout) = rto.calculate_timeouts(state.transport);
    }
}
/// Errors produced by the STUN agent.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StunError {
/// Returned by `StunAgent::send_request` when the transaction id is already outstanding.
#[error("The operation is already in progress")]
AlreadyInProgress,
/// A required resource could not be found (not produced by the code in this file).
#[error("A required resource could not be found")]
ResourceNotFound,
/// An operation timed out (not produced by the code in this file).
#[error("An operation timed out")]
TimedOut,
/// Unexpected data was received (not produced by the code in this file).
#[error("Unexpected data was received")]
ProtocolViolation,
/// The operation was aborted (not produced by the code in this file).
#[error("Operation was aborted")]
Aborted,
/// Parsing a STUN message failed; displayed as the wrapped error.
#[error("{}", .0)]
ParseError(StunParseError),
/// Writing a STUN message failed; displayed as the wrapped error.
#[error("{}", .0)]
WriteError(StunWriteError),
}
impl From<StunParseError> for StunError {
fn from(e: StunParseError) -> Self {
StunError::ParseError(e)
}
}
impl From<StunWriteError> for StunError {
fn from(e: StunWriteError) -> Self {
StunError::WriteError(e)
}
}
#[cfg(test)]
pub(crate) mod tests {
use alloc::string::String;
use tracing::error;
use crate::auth::ShortTermAuth;
use super::*;
// Values given to the builder are returned unchanged by the agent's getters.
#[test]
fn agent_getters_setters() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
assert_eq!(agent.transport(), TransportType::Udp);
assert_eq!(agent.local_addr(), local_addr);
assert_eq!(agent.remote_addr(), Some(remote_addr));
}
// Happy path: a request is sent, an (error) response arrives, and the transaction is gone.
#[test]
fn request() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let now = Instant::ZERO;
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
// `into_owned` releases the mutable borrow of the agent held by the returned transmit.
let transmit = agent
.send_request(msg.finish(), remote_addr, now)
.unwrap()
.into_owned();
let request = agent.request_transaction(transaction_id).unwrap();
assert!(request.integrity().is_none());
assert_eq!(transmit.transport, TransportType::Udp);
assert_eq!(transmit.from, local_addr);
assert_eq!(transmit.to, remote_addr);
let request = Message::from_bytes(&transmit.data).unwrap();
let response = Message::builder_error(&request, MessageWriteVec::new());
let resp_data = response.finish();
let response = Message::from_bytes(&resp_data).unwrap();
assert!(agent.handle_stun_message(&response, remote_addr));
// The response consumed the outstanding request.
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
// With nothing outstanding, poll only reports a wait.
let ret = agent.poll(now);
assert!(matches!(ret, StunAgentPollRet::WaitUntil(_)));
}
// Indications create no transaction, so a "response" carrying the same id is rejected.
#[test]
fn indication_with_invalid_response() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let transaction_id = TransactionId::generate();
let msg = Message::builder(
MessageType::from_class_method(MessageClass::Indication, BINDING),
transaction_id,
MessageWriteVec::new(),
);
let transmit = agent
.send(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
assert_eq!(transmit.transport, TransportType::Udp);
assert_eq!(transmit.from, local_addr);
assert_eq!(transmit.to, remote_addr);
let _indication = Message::from_bytes(&transmit.data).unwrap();
// `send` does not register anything as outstanding.
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
let response = Message::builder(
MessageType::from_class_method(MessageClass::Error, BINDING),
transaction_id,
MessageWriteVec::new(),
);
let resp_data = response.finish();
let response = Message::from_bytes(&resp_data).unwrap();
assert!(!agent.handle_stun_message(&response, remote_addr))
}
// A request carrying MESSAGE-INTEGRITY is remembered as SHA-1 protected, and its correctly
// signed response both passes external validation and validates the peer.
#[test]
fn request_with_credentials() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut auth = ShortTermAuth::new();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
let credentials = ShortTermCredentials::new(String::from("local_password"));
auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);
assert!(!agent.is_validated_peer(remote_addr));
let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
.unwrap();
error!("send");
let transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
error!("sent");
let request = Message::from_bytes(&transmit.data).unwrap();
error!("generate response");
let mut response = Message::builder_success(&request, MessageWriteVec::new());
let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
response.add_attribute(&xor_addr).unwrap();
response
.add_message_integrity(&credentials.into(), IntegrityAlgorithm::Sha1)
.unwrap();
error!("{response:?}");
let data = response.finish();
error!("{data:?}");
let response = Message::from_bytes(&data).unwrap();
error!("{response}");
// Integrity checking is done by the auth helper, not by the agent.
assert_eq!(
auth.validate_incoming_message(&response).unwrap(),
Some(IntegrityAlgorithm::Sha1)
);
let request = agent
.request_transaction(response.transaction_id())
.unwrap();
assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
assert!(agent.handle_stun_message(&response, remote_addr));
assert_eq!(response.transaction_id(), transaction_id);
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(agent.is_validated_peer(remote_addr));
}
// Driving time forward with the default schedule and no response ends in a timeout.
#[test]
fn request_unanswered() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
let mut now = Instant::ZERO;
// Alternate between sending whatever is due and jumping to the next wake-up time.
loop {
let _ = agent.poll_transmit(now);
match agent.poll(now) {
StunAgentPollRet::WaitUntil(new_now) => {
now = new_now;
}
StunAgentPollRet::TransactionTimedOut(_) => break,
_ => unreachable!(),
}
}
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(!agent.is_validated_peer(remote_addr));
}
// Per-request schedule over UDP: initial 1s, capped at 2s, 4 "retransmits", final wait 10s
// gives waits of 1s, 2s, 2s and then 10s before the timeout.
#[test]
fn request_custom_timeout() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let mut now = Instant::ZERO;
agent.send_request(msg.finish(), remote_addr, now).unwrap();
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
transaction.configure_timeout_with_max(
Duration::from_secs(1),
4,
Duration::from_secs(10),
Duration::from_secs(2),
);
// First interval: 1s.
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(1));
now = wait;
// The retransmission is due now, so poll asks to be serviced immediately.
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait, now);
let Some(_) = agent.poll_transmit(now) else {
unreachable!();
};
// Second interval: doubled to 2s.
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(2));
now = wait;
let Some(_) = agent.poll_transmit(now) else {
unreachable!();
};
// Third interval: would be 4s but is capped at 2s.
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(2));
now = wait;
let Some(_) = agent.poll_transmit(now) else {
unreachable!();
};
// Final wait after the last transmission: 10s.
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(10));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(!agent.is_validated_peer(remote_addr));
}
// With zero retransmits only the final wait (10s) separates the initial send from the timeout.
#[test]
fn request_no_retransmit() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let mut now = Instant::ZERO;
agent.send_request(msg.finish(), remote_addr, now).unwrap();
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(10));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(!agent.is_validated_peer(remote_addr));
}
// Over TCP nothing is retransmitted: the intervals (1s + 2s + 2s) and the final wait (3s) are
// folded into one single wait before the timeout.
#[test]
fn request_tcp_custom_timeout() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
.remote_addr(remote_addr)
.request_retransmits(
Duration::from_secs(1),
Duration::from_secs(2),
4,
Duration::from_secs(3),
)
.build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let mut now = Instant::ZERO;
agent.send_request(msg.finish(), remote_addr, now).unwrap();
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(1 + 2 + 2 + 3));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(!agent.is_validated_peer(remote_addr));
}
// A request without integrity is recorded with no algorithm; its response is still accepted
// and validates the peer.
#[test]
fn request_without_credentials() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
assert!(!agent.is_validated_peer(remote_addr));
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
let request = Message::from_bytes(&transmit.data).unwrap();
let mut response = Message::builder_success(&request, MessageWriteVec::new());
let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
response.add_attribute(&xor_addr).unwrap();
let data = response.finish();
// Copy the address out so `transmit` (which borrows the agent) is no longer needed.
let to = transmit.to;
trace!("data: {data:?}");
let response = Message::from_bytes(&data).unwrap();
let request = agent
.request_transaction(response.transaction_id())
.unwrap();
assert_eq!(request.integrity(), None);
assert!(agent.handle_stun_message(&response, to));
assert_eq!(response.transaction_id(), transaction_id);
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(agent.is_validated_peer(remote_addr));
}
// A response signed with the wrong password fails validation in the auth helper. The agent
// itself performs no integrity check: it accepts the response once and validates the peer.
#[test]
fn response_with_incorrect_credentials() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut auth = ShortTermAuth::new();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
let credentials = ShortTermCredentials::new(String::from("local_password"));
let wrong_credentials = ShortTermCredentials::new(String::from("wrong_password"));
auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);
let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
.unwrap();
let transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
let data = transmit.data;
let request = Message::from_bytes(&data).unwrap();
let mut response = Message::builder_success(&request, MessageWriteVec::new());
let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
response.add_attribute(&xor_addr).unwrap();
response
.add_message_integrity(&wrong_credentials.into(), IntegrityAlgorithm::Sha1)
.unwrap();
let data = response.finish();
let response = Message::from_bytes(&data).unwrap();
let request = agent
.request_transaction(response.transaction_id())
.unwrap();
assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
assert!(matches!(
auth.validate_incoming_message(&response),
Err(ValidateError::IntegrityFailed)
));
assert!(!agent.is_validated_peer(remote_addr));
// First delivery matches the outstanding request; the second finds nothing to match.
assert!(agent.handle_stun_message(&response, remote_addr));
assert!(!agent.handle_stun_message(&response, remote_addr));
assert!(agent.is_validated_peer(remote_addr));
}
// Delivering the same response twice: the first is accepted, the second is rejected because
// the request has already been consumed.
#[test]
fn duplicate_response_ignored() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
assert!(!agent.is_validated_peer(remote_addr));
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
let data = transmit.data;
let request = Message::from_bytes(&data).unwrap();
let mut response = Message::builder_success(&request, MessageWriteVec::new());
let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
response.add_attribute(&xor_addr).unwrap();
let data = response.finish();
let to = transmit.to;
let response = Message::from_bytes(&data).unwrap();
assert!(agent.handle_stun_message(&response, to));
let response = Message::from_bytes(&data).unwrap();
assert!(!agent.handle_stun_message(&response, to));
}
// Fully cancelling a request makes the next poll report it as cancelled and remove it.
#[test]
fn request_cancel() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();
    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
    let msg = Message::builder_request(BINDING, MessageWriteVec::new());
    let transaction_id = msg.transaction_id();
    let _transmit = agent
        .send_request(msg.finish(), remote_addr, Instant::ZERO)
        .unwrap();
    let mut request = agent.mut_request_transaction(transaction_id).unwrap();
    assert_eq!(request.integrity(), None);
    assert_eq!(request.agent().local_addr(), local_addr);
    assert_eq!(request.mut_agent().local_addr(), local_addr);
    assert_eq!(request.peer_address(), remote_addr);
    request.cancel();
    let ret = agent.poll(Instant::ZERO);
    let StunAgentPollRet::TransactionCancelled(cancelled_id) = ret else {
        unreachable!();
    };
    // Compare the id reported by the agent (the previous assertion compared
    // `transaction_id` with itself and could never fail).
    assert_eq!(cancelled_id, transaction_id);
    assert!(agent.request_transaction(transaction_id).is_none());
    assert!(agent.mut_request_transaction(transaction_id).is_none());
    assert!(!agent.is_validated_peer(remote_addr));
}
// Cancelling only retransmissions keeps the request alive for the whole retransmission
// schedule (so a late response could still match) before it is reported as cancelled.
#[test]
fn request_cancel_send() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let _transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
let mut request = agent.mut_request_transaction(transaction_id).unwrap();
assert_eq!(request.integrity(), None);
assert_eq!(request.agent().local_addr(), local_addr);
assert_eq!(request.mut_agent().local_addr(), local_addr);
assert_eq!(request.peer_address(), remote_addr);
request.cancel_retransmissions();
let mut now = Instant::ZERO;
let start = now;
loop {
match agent.poll(now) {
StunAgentPollRet::WaitUntil(new_now) => {
// Time must always move forward, otherwise this loop would spin.
assert_ne!(new_now, now);
now = new_now;
}
StunAgentPollRet::TransactionCancelled(_) => break,
_ => unreachable!(),
}
// Advances the schedule; no data is produced because sending is cancelled.
let _ = agent.poll_transmit(now);
}
assert!(now - start > Duration::from_secs(20));
assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());
assert!(!agent.is_validated_peer(remote_addr));
}
// Sending the same request twice fails with `AlreadyInProgress` and leaves the first
// transaction intact, so its response is still accepted.
#[test]
fn request_duplicate() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let msg = msg.finish();
let transmit = agent
.send_request(msg.clone(), remote_addr, Instant::ZERO)
.unwrap();
let to = transmit.to;
let request = Message::from_bytes(&transmit.data).unwrap();
let mut response = Message::builder_success(&request, MessageWriteVec::new());
let xor_addr = XorMappedAddress::new(transmit.from, transaction_id);
response.add_attribute(&xor_addr).unwrap();
assert!(matches!(
agent.send_request(msg, remote_addr, Instant::ZERO),
Err(StunError::AlreadyInProgress)
));
let request = agent.request_transaction(transaction_id).unwrap();
assert_eq!(request.peer_address(), remote_addr);
let data = response.finish();
let response = Message::from_bytes(&data).unwrap();
assert!(agent.handle_stun_message(&response, to));
assert!(agent.is_validated_peer(to));
}
// An incoming request (not a response) needs no outstanding transaction to be accepted.
#[test]
fn incoming_request() {
let _log = crate::tests::test_init_log();
let local_addr = "10.0.0.1:12345".parse().unwrap();
let remote_addr = "10.0.0.2:3478".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let data = msg.finish();
let stun = Message::from_bytes(&data).unwrap();
error!("{stun:?}");
assert!(agent.handle_stun_message(&stun, remote_addr));
agent.validated_peer(remote_addr);
assert!(agent.is_validated_peer(remote_addr));
}
// A request sent by a TCP agent is stamped with the TCP transport and the right addresses.
#[test]
fn tcp_request() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING, MessageWriteVec::new());
let transaction_id = msg.transaction_id();
let transmit = agent
.send_request(msg.finish(), remote_addr, Instant::ZERO)
.unwrap();
assert_eq!(transmit.transport, TransportType::Tcp);
assert_eq!(transmit.from, local_addr);
assert_eq!(transmit.to, remote_addr);
let request = Message::from_bytes(&transmit.data).unwrap();
assert_eq!(request.transaction_id(), transaction_id);
}
// `into_owned` keeps payload bytes, transport and addresses unchanged.
#[test]
fn transmit_into_owned() {
let data = [0x10, 0x20];
let transport = TransportType::Udp;
let from = "127.0.0.1:1000".parse().unwrap();
let to = "127.0.0.1:2000".parse().unwrap();
let transmit = Transmit::new(Data::from(data.as_ref()), TransportType::Udp, from, to);
let owned = transmit.into_owned();
assert_eq!(owned.data.as_ref(), data.as_ref());
assert_eq!(owned.transport, transport);
assert_eq!(owned.from, from);
assert_eq!(owned.to, to);
error!("{owned}");
}
// Pins the exact `Display` output of a `Transmit`.
#[test]
fn transmit_display() {
let data = [0x10, 0x20];
let from = "127.0.0.1:1000".parse().unwrap();
let to = "127.0.0.1:2000".parse().unwrap();
assert_eq!(
alloc::format!(
"{}",
Transmit::new(Data::from(data.as_ref()), TransportType::Udp, from, to)
),
String::from("Transmit(UDP: 127.0.0.1:1000 -> 127.0.0.1:2000 of 2 bytes)")
);
}
// Schedule computation for small retransmit counts on both transports.
#[test]
fn request_retransmits() {
let _log = crate::tests::test_init_log();
// 0 retransmits: no intervals, only the final wait.
let rto = RequestRto {
initial: Duration::from_millis(1),
max: Duration::MAX,
retransmits: 0,
last_retransmit: Duration::from_secs(1),
};
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
assert_eq!(timeouts, vec![]);
assert_eq!(last_transmit_timeout, Duration::from_secs(1));
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
assert_eq!(timeouts, vec![]);
assert_eq!(last_transmit_timeout, Duration::from_secs(1));
// 1 retransmit behaves the same as 0.
let rto = RequestRto {
initial: Duration::from_millis(1),
max: Duration::MAX,
retransmits: 1,
last_retransmit: Duration::from_secs(1),
};
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
assert_eq!(timeouts, vec![]);
assert_eq!(last_transmit_timeout, Duration::from_secs(1));
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
assert_eq!(timeouts, vec![]);
assert_eq!(last_transmit_timeout, Duration::from_secs(1));
// 2 retransmits: one interval for UDP; for TCP that interval is added to the final wait.
let rto = RequestRto {
initial: Duration::from_millis(1),
max: Duration::MAX,
retransmits: 2,
last_retransmit: Duration::from_secs(1),
};
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
assert_eq!(timeouts, vec![Duration::from_millis(1)]);
assert_eq!(last_transmit_timeout, Duration::from_secs(1));
let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
assert_eq!(timeouts, vec![]);
assert_eq!(
last_transmit_timeout,
Duration::from_secs(1) + Duration::from_millis(1)
);
}
}