Skip to main content

stun_proto/
agent.rs

1// Copyright (C) 2020 Matthew Waters <matthew@centricular.com>
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// SPDX-License-Identifier: MIT OR Apache-2.0
10
11//! # STUN agent
12//!
13//! A STUN Agent that follows the procedures of [RFC5389] and [RFC8489] and is implemented with the
14//! sans-IO pattern. This agent does no IO processing and operates solely on inputs it is
15//! provided.
16//!
17//! [RFC8489]: https://tools.ietf.org/html/rfc8489
18//! [RFC5389]: https://tools.ietf.org/html/rfc5389
19
20use core::net::SocketAddr;
21use core::sync::atomic::{AtomicUsize, Ordering};
22
23use alloc::collections::{BTreeMap, BTreeSet};
24use alloc::vec;
25use alloc::vec::Vec;
26use core::time::Duration;
27
28use crate::Instant;
29
30use stun_types::attribute::*;
31use stun_types::data::Data;
32use stun_types::message::*;
33
34use stun_types::TransportType;
35
36use tracing::{debug, trace, warn};
37
/// Process-wide counter that hands out a unique `id` to every [`StunAgent`] that gets built.
static STUN_AGENT_COUNT: AtomicUsize = AtomicUsize::new(0);
39
40/// Implementation of a STUN agent.
41#[derive(Debug)]
42pub struct StunAgent {
43    id: usize,
44    transport: TransportType,
45    local_addr: SocketAddr,
46    remote_addr: Option<SocketAddr>,
47    validated_peers: BTreeSet<SocketAddr>,
48    outstanding_requests: BTreeMap<TransactionId, StunRequestState>,
49    request_timeouts: Vec<Duration>,
50    last_retransmit_timeout: Duration,
51}
52
53/// Builder struct for a [`StunAgent`].
54#[derive(Debug)]
55pub struct StunAgentBuilder {
56    transport: TransportType,
57    local_addr: SocketAddr,
58    remote_addr: Option<SocketAddr>,
59    rto: RequestRto,
60}
61
62impl StunAgentBuilder {
63    /// Set the remote address the [`StunAgent`] will be configured to only send data to.
64    pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
65        self.remote_addr = Some(addr);
66        self
67    }
68
69    /// Configure the default timeouts and retransmissions for each STUN request.
70    ///
71    /// - `initial` - the initial time between consecutive transmissions. If 0, or 1, then only a
72    ///   single request will be performed.
73    /// - `max` - the maximum amount of time between consecutive retransmits.
74    /// - `retransmits` - the total number of transmissions of the request.
75    /// - `final_retransmit_timeout` - the amount of time after the final transmission to wait
76    ///   for a response before considering the request as having timed out.
77    ///
78    /// As specified in RFC 8489, `initial_rto` should be >= 500ms (unless specific information is
79    /// available on the RTT, `max` is `Duration::MAX`, `retransmits` has a default value of 7,
80    /// and `last_retransmit_timeout` should be `16 * initial_rto`.
81    ///
82    /// STUN transactions over TCP will only send a single request and have a timeout of the sum of
83    /// the timeouts of a UDP transaction.
84    pub fn request_retransmits(
85        mut self,
86        initial: Duration,
87        max: Duration,
88        retransmits: u32,
89        final_retransmit_timeout: Duration,
90    ) -> Self {
91        self.rto.initial = initial;
92        self.rto.max = max;
93        self.rto.retransmits = retransmits;
94        self.rto.last_retransmit = final_retransmit_timeout;
95        self
96    }
97
98    /// Build the [`StunAgent`].
99    pub fn build(self) -> StunAgent {
100        let id = STUN_AGENT_COUNT.fetch_add(1, Ordering::SeqCst);
101        let (request_timeouts, last_retransmit_timeout) =
102            self.rto.calculate_timeouts(self.transport);
103        StunAgent {
104            id,
105            transport: self.transport,
106            local_addr: self.local_addr,
107            remote_addr: self.remote_addr,
108            validated_peers: Default::default(),
109            outstanding_requests: Default::default(),
110            request_timeouts,
111            last_retransmit_timeout,
112        }
113    }
114}
115
116impl StunAgent {
117    /// Create a new [`StunAgentBuilder`].
118    pub fn builder(transport: TransportType, local_addr: SocketAddr) -> StunAgentBuilder {
119        StunAgentBuilder {
120            transport,
121            local_addr,
122            remote_addr: None,
123            rto: Default::default(),
124        }
125    }
126
127    /// The [`TransportType`] of this [`StunAgent`].
128    pub fn transport(&self) -> TransportType {
129        self.transport
130    }
131
132    /// The local address of this [`StunAgent`].
133    pub fn local_addr(&self) -> SocketAddr {
134        self.local_addr
135    }
136
137    /// The remote address of this [`StunAgent`].
138    pub fn remote_addr(&self) -> Option<SocketAddr> {
139        self.remote_addr
140    }
141
142    /// Perform any operations needed to be able to send data to a peer.
143    pub fn send_data<T: AsRef<[u8]>>(&self, bytes: T, to: SocketAddr) -> Transmit<T> {
144        send_data(self.transport, bytes, self.local_addr, to)
145    }
146
147    /// Perform any operations needed to be able to send a [`Message`] to a peer.
148    ///
149    /// The returned [`Transmit`] must be sent to the respective peer after this call.
150    ///
151    /// # Panics
152    ///
153    /// - If the STUN Message is a request. Use [`send_request()`](StunAgent::send_request) instead.
154    #[tracing::instrument(name = "stun_agent_send",
155        skip(self, msg),
156        fields(
157            transport = %self.transport,
158            from = %self.local_addr,
159            transaction_id,
160        )
161    )]
162    pub fn send<T: AsRef<[u8]>>(
163        &mut self,
164        msg: T,
165        to: SocketAddr,
166        now: Instant,
167    ) -> Result<Transmit<T>, StunError> {
168        let data = msg.as_ref();
169        let hdr = MessageHeader::from_bytes(data)?;
170        tracing::Span::current().record(
171            "transaction_id",
172            tracing::field::display(hdr.transaction_id()),
173        );
174        assert!(!hdr.get_type().has_class(MessageClass::Request));
175        trace!("Sending {} to {to}", hdr.get_type());
176        Ok(Transmit::new(msg, self.transport, self.local_addr, to))
177    }
178
179    /// Perform any operations needed to be able to send a request [`Message`] to a peer.
180    ///
181    /// The returned [`Transmit`] must be sent to the respective peer after this call.
182    ///
183    /// # Panics
184    ///
185    /// - If the STUN Message is not a request. Use [`send()`](StunAgent::send) instead.
186    #[tracing::instrument(name = "stun_agent_send_request",
187        skip(self, msg),
188        fields(
189            transport = %self.transport,
190            from = %self.local_addr,
191            transaction_id,
192        )
193    )]
194    pub fn send_request<'a, T: AsRef<[u8]>>(
195        &'a mut self,
196        msg: T,
197        to: SocketAddr,
198        now: Instant,
199    ) -> Result<Transmit<Data<'a>>, StunError> {
200        let data = msg.as_ref();
201        let hdr = MessageHeader::from_bytes(data)?;
202        assert!(hdr.get_type().has_class(MessageClass::Request));
203        let transaction_id = hdr.transaction_id();
204        tracing::Span::current().record("transaction_id", tracing::field::display(transaction_id));
205        let state = match self.outstanding_requests.entry(transaction_id) {
206            alloc::collections::btree_map::Entry::Vacant(entry) => {
207                let integrity_algorithm = MessageAttributesIter::new(data)
208                    .filter_map(|(_offset, attr)| match attr.get_type() {
209                        MessageIntegrity::TYPE => Some(IntegrityAlgorithm::Sha1),
210                        MessageIntegritySha256::TYPE => Some(IntegrityAlgorithm::Sha256),
211                        _ => None,
212                    })
213                    .last();
214                trace!("Adding request to {to} with integrity algorithm: {integrity_algorithm:?}");
215                entry.insert(StunRequestState::new(
216                    msg,
217                    self.transport,
218                    self.local_addr,
219                    to,
220                    transaction_id,
221                    integrity_algorithm,
222                    self.request_timeouts.clone(),
223                    self.last_retransmit_timeout,
224                ))
225            }
226            alloc::collections::btree_map::Entry::Occupied(_entry) => {
227                return Err(StunError::AlreadyInProgress);
228            }
229        };
230        let Some(transmit) = state.poll_transmit(now) else {
231            unreachable!();
232        };
233        Ok(Transmit::new(
234            Data::from(transmit.data),
235            transmit.transport,
236            transmit.from,
237            transmit.to,
238        ))
239    }
240
241    /// Returns whether this agent has received or sent a STUN message with this peer. Failure may
242    /// be the result of an attacker and the caller must drop any non-STUN data received before this
243    /// functions returns `true`.
244    ///
245    /// If non-STUN data is received over a TCP connection from an unvalidated peer, the caller
246    /// must immediately close the TCP connection.
247    pub fn is_validated_peer(&self, remote_addr: SocketAddr) -> bool {
248        self.validated_peers.contains(&remote_addr)
249    }
250
251    /// Indicate to the STUN agent that STUN messages have been sent/received to/from a peer.
252    #[tracing::instrument(
253        name = "stun_validated_peer"
254        skip(self),
255        fields(stun_id = self.id)
256    )]
257    pub fn validated_peer(&mut self, addr: SocketAddr) {
258        if !self.validated_peers.contains(&addr) {
259            debug!("validated peer {:?}", addr);
260            self.validated_peers.insert(addr);
261        }
262    }
263
264    /// Provide data received on a socket from a peer for handling by the [`StunAgent`] after it
265    /// has successfully passed authentication.
266    ///
267    /// For responses, this will cause the associated request to be removed from the agent if it
268    /// exists.
269    ///
270    /// The return value indicates whether the message passes internal checks and should be acted
271    /// upon.
272    #[tracing::instrument(
273        name = "stun_handle_message"
274        skip(self, msg, from),
275        fields(
276            transaction_id = %msg.transaction_id(),
277        )
278    )]
279    pub fn handle_stun_message(&mut self, msg: &Message<'_>, from: SocketAddr) -> bool {
280        if msg.is_response()
281            && self
282                .take_outstanding_request(&msg.transaction_id())
283                .is_none()
284        {
285            trace!("original request disappeared");
286            return false;
287        }
288        self.validated_peer(from);
289        true
290    }
291
292    #[tracing::instrument(
293        skip(self, transaction_id),
294        fields(transaction_id = %transaction_id)
295    )]
296    fn take_outstanding_request(
297        &mut self,
298        transaction_id: &TransactionId,
299    ) -> Option<StunRequestState> {
300        if let Some(request) = self.outstanding_requests.remove(transaction_id) {
301            trace!("removing request");
302            Some(request)
303        } else {
304            trace!("no outstanding request");
305            None
306        }
307    }
308
309    /// Retrieve a reference to an outstanding STUN request. Outstanding requests are kept until
310    /// either:
311    /// - [`handle_stun_message()`](StunAgent::handle_stun_message) is called, or
312    /// - [`poll()`](StunAgent::poll) returns [`StunAgentPollRet::TransactionCancelled`] or
313    ///   [`StunAgentPollRet::TransactionTimedOut`] for the request.
314    pub fn request_transaction(&self, transaction_id: TransactionId) -> Option<StunRequest<'_>> {
315        if self.outstanding_requests.contains_key(&transaction_id) {
316            Some(StunRequest {
317                agent: self,
318                transaction_id,
319            })
320        } else {
321            None
322        }
323    }
324
325    /// Retrieve a mutable reference to an outstanding STUN request. Outstanding requests are kept
326    /// until either:
327    /// - [`handle_stun_message()`](StunAgent::handle_stun_message) is called, or
328    /// - [`poll()`](StunAgent::poll) returns [`StunAgentPollRet::TransactionCancelled`] or
329    ///   [`StunAgentPollRet::TransactionTimedOut`] for the request.
330    pub fn mut_request_transaction(
331        &mut self,
332        transaction_id: TransactionId,
333    ) -> Option<StunRequestMut<'_>> {
334        if self.outstanding_requests.contains_key(&transaction_id) {
335            Some(StunRequestMut {
336                agent: self,
337                transaction_id,
338            })
339        } else {
340            None
341        }
342    }
343
344    fn mut_request_state(
345        &mut self,
346        transaction_id: TransactionId,
347    ) -> Option<&mut StunRequestState> {
348        self.outstanding_requests.get_mut(&transaction_id)
349    }
350
351    fn request_state(&self, transaction_id: TransactionId) -> Option<&StunRequestState> {
352        self.outstanding_requests.get(&transaction_id)
353    }
354
355    /// Poll the agent for making further progress on any outstanding requests. The returned value
356    /// indicates the current state and anything the caller needs to perform.
357    ///
358    /// Upon expiry of the timer from [`StunAgentPollRet::WaitUntil`],
359    /// [`poll_transmit()`](StunAgent::poll_transmit) must be called.
360    #[tracing::instrument(
361        name = "stun_agent_poll"
362        level = "debug",
363        skip(self),
364    )]
365    pub fn poll(&mut self, now: Instant) -> StunAgentPollRet {
366        let mut lowest_wait = now + Duration::from_secs(3600);
367        let mut timeout = None;
368        let mut cancelled = None;
369        for (transaction_id, request) in self.outstanding_requests.iter_mut() {
370            debug_assert_eq!(transaction_id, &request.transaction_id);
371            match request.poll(now) {
372                StunRequestPollRet::Cancelled => {
373                    cancelled = Some(*transaction_id);
374                    break;
375                }
376                StunRequestPollRet::WaitUntil(wait_until) => {
377                    if wait_until < lowest_wait {
378                        lowest_wait = wait_until;
379                    }
380                }
381                StunRequestPollRet::TimedOut => {
382                    timeout = Some(*transaction_id);
383                    break;
384                }
385            }
386        }
387        if let Some(transaction) = timeout {
388            if let Some(_state) = self.outstanding_requests.remove(&transaction) {
389                return StunAgentPollRet::TransactionTimedOut(transaction);
390            }
391        }
392        if let Some(transaction) = cancelled {
393            if let Some(_state) = self.outstanding_requests.remove(&transaction) {
394                return StunAgentPollRet::TransactionCancelled(transaction);
395            }
396        }
397        StunAgentPollRet::WaitUntil(lowest_wait)
398    }
399
400    /// Poll for any transmissions that may need to be performed.
401    #[tracing::instrument(
402        name = "stun_agent_poll_transmit"
403        level = "debug",
404        skip(self),
405    )]
406    pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
407        self.outstanding_requests
408            .values_mut()
409            .filter_map(|request| request.poll_transmit(now))
410            .next()
411    }
412}
413
414/// Return value for [`StunAgent::poll`].
415#[derive(Debug)]
416pub enum StunAgentPollRet {
417    /// An oustanding transaction timed out and has been removed from the agent.
418    TransactionTimedOut(TransactionId),
419    /// An oustanding transaction was cancelled and has been removed from the agent.
420    TransactionCancelled(TransactionId),
421    /// Wait until the specified time has passed.
422    WaitUntil(Instant),
423}
424
425fn send_data<T: AsRef<[u8]>>(
426    transport: TransportType,
427    bytes: T,
428    from: SocketAddr,
429    to: SocketAddr,
430) -> Transmit<T> {
431    Transmit::new(bytes, transport, from, to)
432}
433
434/// A piece of data that needs to, or has been transmitted.
435#[derive(Debug)]
436pub struct Transmit<T: AsRef<[u8]>> {
437    /// The data blob.
438    pub data: T,
439    /// The transport for the transmission.
440    pub transport: TransportType,
441    /// The source address of the transmission.
442    pub from: SocketAddr,
443    /// The destination address of the transmission.
444    pub to: SocketAddr,
445}
446
447impl<T: AsRef<[u8]>> core::fmt::Display for Transmit<T> {
448    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
449        write!(
450            f,
451            "Transmit({}: {} -> {} of {} bytes)",
452            self.transport,
453            self.from,
454            self.to,
455            self.data.as_ref().len()
456        )
457    }
458}
459
460impl<T: AsRef<[u8]>> Transmit<T> {
461    /// Construct a new [`Transmit`] with the specifid data and 5-tuple.
462    pub fn new(data: T, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self {
463        Self {
464            data,
465            transport,
466            from,
467            to,
468        }
469    }
470
471    /// Reinterpret the data of a [`Transmit`] into a different type.
472    ///
473    /// # Examples
474    ///
475    /// ```
476    /// # use stun_proto::agent::Transmit;
477    /// # use stun_proto::types::TransportType;
478    /// # use core::net::SocketAddr;
479    /// let local_addr = "10.0.0.1:1000".parse().unwrap();
480    /// let remote_addr = "10.0.0.2:2000".parse().unwrap();
481    /// let slice = [42; 8];
482    /// let transmit = Transmit::new(slice.clone(), TransportType::Udp, local_addr, remote_addr);
483    /// // change the data type of the `Transmit` into a `Vec<u8>`.
484    /// let transmit = transmit.reinterpret_data(|data| data.to_vec());
485    /// # assert_eq!(transmit.transport, TransportType::Udp);
486    /// # assert_eq!(transmit.from, local_addr);
487    /// # assert_eq!(transmit.to, remote_addr);
488    /// assert_eq!(transmit.data, slice.as_slice());
489    /// ```
490    pub fn reinterpret_data<O: AsRef<[u8]>, F: FnOnce(T) -> O>(self, f: F) -> Transmit<O> {
491        Transmit {
492            data: f(self.data),
493            transport: self.transport,
494            from: self.from,
495            to: self.to,
496        }
497    }
498}
499
500impl Transmit<Data<'_>> {
501    /// Construct a new owned [`Transmit`] from a provided [`Transmit`].
502    pub fn into_owned<'b>(self) -> Transmit<Data<'b>> {
503        self.reinterpret_data(|data| data.into_owned())
504    }
505}
506
507/// Return value for [`StunRequest::poll`].
508#[derive(Debug)]
509enum StunRequestPollRet {
510    /// Wait until the specified time has passed.
511    WaitUntil(Instant),
512    /// The request has been cancelled and will not make further progress.
513    Cancelled,
514    /// The request timed out.
515    TimedOut,
516}
517
/// Retransmission configuration from which a request's timeout schedule is derived.
#[derive(Debug)]
struct RequestRto {
    /// Gap between the first and second transmission; later gaps double.
    initial: Duration,
    /// Upper bound on any single gap between transmissions.
    max: Duration,
    /// Total number of transmissions of a request.
    retransmits: u32,
    /// Wait after the final transmission before the request times out.
    last_retransmit: Duration,
}
525
526impl Default for RequestRto {
527    fn default() -> Self {
528        Self {
529            initial: Duration::from_millis(500),
530            max: Duration::MAX,
531            retransmits: 7,
532            last_retransmit: Duration::from_millis(8),
533        }
534    }
535}
536
537impl RequestRto {
538    fn calculate_timeouts(&self, transport: TransportType) -> (Vec<Duration>, Duration) {
539        match transport {
540            TransportType::Udp => {
541                let timeouts = (0..self.retransmits.max(1) - 1)
542                    .map(|i| (self.initial * 2u32.pow(i)).min(self.max))
543                    .collect::<Vec<_>>();
544                (timeouts, self.last_retransmit)
545            }
546            TransportType::Tcp => {
547                let timeouts = vec![];
548                let last_retransmit_timeout = self.last_retransmit
549                    + (0..self.retransmits.max(1) - 1).fold(Duration::ZERO, |acc, i| {
550                        acc + (self.initial * 2u32.pow(i)).min(self.max)
551                    });
552                (timeouts, last_retransmit_timeout)
553            }
554        }
555    }
556}
557
558#[derive(Debug)]
559struct StunRequestState {
560    transaction_id: TransactionId,
561    request_integrity: Option<IntegrityAlgorithm>,
562    bytes: Vec<u8>,
563    transport: TransportType,
564    from: SocketAddr,
565    to: SocketAddr,
566    timeouts: Vec<Duration>,
567    last_retransmit_timeout: Duration,
568    recv_cancelled: bool,
569    send_cancelled: bool,
570    timeout_i: usize,
571    last_send_time: Option<Instant>,
572}
573
574impl StunRequestState {
575    #[allow(clippy::too_many_arguments)]
576    fn new<T: AsRef<[u8]>>(
577        request: T,
578        transport: TransportType,
579        from: SocketAddr,
580        to: SocketAddr,
581        transaction_id: TransactionId,
582        integrity_algorithm: Option<IntegrityAlgorithm>,
583        timeouts: Vec<Duration>,
584        last_retransmit_timeout: Duration,
585    ) -> Self {
586        let data = request.as_ref();
587        /*
588        let (timeouts, last_retransmit_timeout) = if transport == TransportType::Tcp {
589            (vec![], Duration::from_millis(39500))
590        } else {
591            (
592                [500, 1000, 2000, 4000, 8000, 16000]
593                    .into_iter()
594                    .map(Duration::from_millis)
595                    .collect(),
596                Duration::from_millis(8000),
597            )
598        };*/
599        Self {
600            transaction_id,
601            bytes: data.to_vec(),
602            transport,
603            from,
604            to,
605            request_integrity: integrity_algorithm,
606            timeouts,
607            timeout_i: 0,
608            last_retransmit_timeout,
609            recv_cancelled: false,
610            send_cancelled: false,
611            last_send_time: None,
612        }
613    }
614
615    #[tracing::instrument(skip(self, now), level = "trace")]
616    fn next_send_time(&self, now: Instant) -> Option<Instant> {
617        let Some(last_send) = self.last_send_time else {
618            trace!("not sent yet -> send immediately");
619            return Some(now);
620        };
621        if self.timeout_i >= self.timeouts.len() {
622            let next_send = last_send + self.last_retransmit_timeout;
623            trace!("final retransmission, final timeout ends at {next_send:?}");
624            if next_send > now {
625                return Some(next_send);
626            }
627            return None;
628        }
629        let next_send = last_send + self.timeouts[self.timeout_i];
630        Some(next_send)
631    }
632
633    #[tracing::instrument(
634        name = "stun_request_poll"
635        level = "debug",
636        ret,
637        skip(self, now),
638        fields(transaction_id = %self.transaction_id),
639    )]
640    fn poll(&mut self, now: Instant) -> StunRequestPollRet {
641        if self.recv_cancelled {
642            return StunRequestPollRet::Cancelled;
643        }
644        // TODO: account for TCP connect in timeout
645        let Some(next_send) = self.next_send_time(now) else {
646            return StunRequestPollRet::TimedOut;
647        };
648        if next_send >= now {
649            if self.send_cancelled && self.timeout_i >= self.timeouts.len() {
650                // this cancellation may need a different value
651                return StunRequestPollRet::Cancelled;
652            }
653            return StunRequestPollRet::WaitUntil(next_send);
654        }
655        StunRequestPollRet::WaitUntil(now)
656    }
657
658    #[tracing::instrument(
659        name = "stun_request_poll_transmit",
660        skip(self, now),
661        fields(transaction_id = %self.transaction_id)
662    )]
663    fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
664        if self.recv_cancelled {
665            return None;
666        };
667        let next_send = self.next_send_time(now)?;
668
669        if next_send > now {
670            return None;
671        }
672        if self.last_send_time.is_some() {
673            self.timeout_i += 1;
674        }
675        self.last_send_time = Some(now);
676        if self.send_cancelled {
677            return None;
678        };
679        trace!(
680            "sending {} bytes over {:?} from {:?} to {:?}",
681            self.bytes.len(),
682            self.transport,
683            self.from,
684            self.to
685        );
686        Some(send_data(
687            self.transport,
688            self.bytes.as_slice(),
689            self.from,
690            self.to,
691        ))
692    }
693}
694
695/// A STUN Request.
696#[derive(Debug, Clone)]
697pub struct StunRequest<'a> {
698    agent: &'a StunAgent,
699    transaction_id: TransactionId,
700}
701
702impl StunRequest<'_> {
703    /// The remote address the request is sent to.
704    pub fn peer_address(&self) -> SocketAddr {
705        let state = self.agent.request_state(self.transaction_id).unwrap();
706        state.to
707    }
708
709    /// The integrity algorithm present on the request.
710    pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
711        let state = self.agent.request_state(self.transaction_id).unwrap();
712        state.request_integrity
713    }
714}
715
716/// A STUN Request.
717#[derive(Debug)]
718pub struct StunRequestMut<'a> {
719    agent: &'a mut StunAgent,
720    transaction_id: TransactionId,
721}
722
723impl StunRequestMut<'_> {
724    /// The remote address the request is sent to.
725    pub fn peer_address(&self) -> SocketAddr {
726        let state = self.agent.request_state(self.transaction_id).unwrap();
727        state.to
728    }
729
730    /// The integrity algorithm present on the request.
731    pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
732        let state = self.agent.request_state(self.transaction_id).unwrap();
733        state.request_integrity
734    }
735
736    /// Do not retransmit further. This will still allow for a reply to occur within the configured
737    /// timeouts, but will never send a retransmission. If no response is received, this will cause
738    /// [`StunAgent::poll()`] to return [`StunAgentPollRet::TransactionCancelled`] for this request.
739    pub fn cancel_retransmissions(&mut self) {
740        if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
741            state.send_cancelled = true;
742        }
743    }
744
745    /// Do not wait for any kind of response. This will cause [`StunAgent::poll()`] to return
746    /// [`StunAgentPollRet::TransactionCancelled`] for this request.
747    pub fn cancel(&mut self) {
748        if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
749            state.send_cancelled = true;
750            state.recv_cancelled = true;
751        }
752    }
753
754    /// The [`StunAgent`] this request is being sent with.
755    pub fn agent(&self) -> &StunAgent {
756        self.agent
757    }
758
759    /// The mutable [`StunAgent`] this request is being sent with.
760    pub fn mut_agent(&mut self) -> &mut StunAgent {
761        self.agent
762    }
763
764    /// Configure the timeouts and retransmissions for the STUN request.
765    ///
766    /// This the same as calling `[configure_timeout_with_max`] with a `max` of `Duration::MAX`.
767    pub fn configure_timeout(
768        &mut self,
769        initial_rto: Duration,
770        retransmits: u32,
771        last_retransmit_timeout: Duration,
772    ) {
773        self.configure_timeout_with_max(
774            initial_rto,
775            retransmits,
776            last_retransmit_timeout,
777            Duration::MAX,
778        );
779    }
780
781    /// Configure the timeouts and retransmissions for the STUN request.
782    ///
783    /// - `initial` - the initial time between consecutive transmissions. If 0, or 1, then only a
784    ///   single request will be performed.
785    /// - `max` - the maximum amount of time between consecutive retransmits.
786    /// - `retransmits` - the total number of transmissions of the request.
787    /// - `final_retransmit_timeout` - the amount of time after the final transmission to wait
788    ///   for a response before considering the request as having timed out.
789    ///
790    /// As specified in RFC 8489, `initial_rto` should be >= 500ms (unless specific information is
791    /// available on the RTT, `max` is `Duration::MAX`, `retransmits` has a default value of 7,
792    /// and `last_retransmit_timeout` should be `16 * initial_rto`.
793    ///
794    /// STUN transactions over TCP will only send a single request and have a timeout of the sum of
795    /// the timeouts of a UDP transaction.
796    pub fn configure_timeout_with_max(
797        &mut self,
798        initial_rto: Duration,
799        retransmits: u32,
800        last_retransmit_timeout: Duration,
801        max_rto: Duration,
802    ) {
803        if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
804            let (timeouts, final_wait) = RequestRto {
805                initial: initial_rto,
806                max: max_rto,
807                retransmits,
808                last_retransmit: last_retransmit_timeout,
809            }
810            .calculate_timeouts(state.transport);
811            state.timeouts = timeouts;
812            state.last_retransmit_timeout = final_wait;
813        }
814    }
815}
816
817/// STUN errors.
818#[derive(Debug, thiserror::Error)]
819#[non_exhaustive]
820pub enum StunError {
821    /// The operation is already in progress.
822    #[error("The operation is already in progress")]
823    AlreadyInProgress,
824    /// A resource was not found.
825    #[error("A required resource could not be found")]
826    ResourceNotFound,
827    /// An operation timed out without a response.
828    #[error("An operation timed out")]
829    TimedOut,
830    /// Unexpected data was received or an operation is not allowed at this time.
831    #[error("Unexpected data was received")]
832    ProtocolViolation,
833    /// An operation was cancelled.
834    #[error("Operation was aborted")]
835    Aborted,
836    /// A parsing error. The contained error contains more details.
837    #[error("{}", .0)]
838    ParseError(StunParseError),
839    /// A writing error. The contained error contains more details.
840    #[error("{}", .0)]
841    WriteError(StunWriteError),
842}
843
844impl From<StunParseError> for StunError {
845    fn from(e: StunParseError) -> Self {
846        StunError::ParseError(e)
847    }
848}
849
850impl From<StunWriteError> for StunError {
851    fn from(e: StunWriteError) -> Self {
852        StunError::WriteError(e)
853    }
854}
855
#[cfg(test)]
pub(crate) mod tests {
    use alloc::string::String;
    use tracing::error;

    use crate::auth::ShortTermAuth;

    use super::*;

    /// The accessors return the transport, local address and remote address the agent was
    /// built with.
    #[test]
    fn agent_getters_setters() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();
        let agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();

        assert_eq!(agent.transport(), TransportType::Udp);
        assert_eq!(agent.local_addr(), local_addr);
        assert_eq!(agent.remote_addr(), Some(remote_addr));
    }

    /// A request without integrity is tracked until a matching error response is handled,
    /// after which the transaction is removed and polling only reports a wait.
    #[test]
    fn request() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();
        let now = Instant::ZERO;

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let transmit = agent
            .send_request(msg.finish(), remote_addr, now)
            .unwrap()
            .into_owned();
        let request = agent.request_transaction(transaction_id).unwrap();
        assert!(request.integrity().is_none());
        assert_eq!(transmit.transport, TransportType::Udp);
        assert_eq!(transmit.from, local_addr);
        assert_eq!(transmit.to, remote_addr);
        let request = Message::from_bytes(&transmit.data).unwrap();
        let response = Message::builder_error(&request, MessageWriteVec::new());
        let resp_data = response.finish();
        let response = Message::from_bytes(&resp_data).unwrap();
        assert!(agent.handle_stun_message(&response, remote_addr));
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());

        let ret = agent.poll(now);
        assert!(matches!(ret, StunAgentPollRet::WaitUntil(_)));
    }

    /// Indications are not tracked as transactions, so an error response carrying the
    /// indication's transaction ID is not handled.
    #[test]
    fn indication_with_invalid_response() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();
        let transaction_id = TransactionId::generate();
        let msg = Message::builder(
            MessageType::from_class_method(MessageClass::Indication, BINDING),
            transaction_id,
            MessageWriteVec::new(),
        );
        let transmit = agent
            .send(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        assert_eq!(transmit.transport, TransportType::Udp);
        assert_eq!(transmit.from, local_addr);
        assert_eq!(transmit.to, remote_addr);
        let _indication = Message::from_bytes(&transmit.data).unwrap();
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());
        // you should definitely never do this ;). Indications should never get replies.
        let response = Message::builder(
            MessageType::from_class_method(MessageClass::Error, BINDING),
            transaction_id,
            MessageWriteVec::new(),
        );
        let resp_data = response.finish();
        let response = Message::from_bytes(&resp_data).unwrap();
        // response without a request is dropped.
        assert!(!agent.handle_stun_message(&response, remote_addr));
    }

    /// A request signed with short-term credentials records its integrity algorithm, and a
    /// signed success response completes the transaction and validates the peer.
    #[test]
    fn request_with_credentials() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut auth = ShortTermAuth::new();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
        let credentials = ShortTermCredentials::new(String::from("local_password"));
        auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);

        // nothing has been received yet, so the peer is not validated
        assert!(!agent.is_validated_peer(remote_addr));

        let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
            .unwrap();
        error!("send");
        let transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        error!("sent");

        let request = Message::from_bytes(&transmit.data).unwrap();

        error!("generate response");
        let mut response = Message::builder_success(&request, MessageWriteVec::new());
        let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
        response.add_attribute(&xor_addr).unwrap();
        response
            .add_message_integrity(&credentials.into(), IntegrityAlgorithm::Sha1)
            .unwrap();
        error!("{response:?}");

        let data = response.finish();
        error!("{data:?}");
        let response = Message::from_bytes(&data).unwrap();
        error!("{response}");
        assert_eq!(
            auth.validate_incoming_message(&response).unwrap(),
            Some(IntegrityAlgorithm::Sha1)
        );
        let request = agent
            .request_transaction(response.transaction_id())
            .unwrap();
        assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
        assert!(agent.handle_stun_message(&response, remote_addr));

        assert_eq!(response.transaction_id(), transaction_id);
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());
        assert!(agent.is_validated_peer(remote_addr));
    }

    /// A request that never receives a response eventually times out, is removed from the
    /// agent, and leaves the peer unvalidated.
    #[test]
    fn request_unanswered() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();
        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        let mut now = Instant::ZERO;
        loop {
            let _ = agent.poll_transmit(now);
            match agent.poll(now) {
                StunAgentPollRet::WaitUntil(new_now) => {
                    now = new_now;
                }
                StunAgentPollRet::TransactionTimedOut(_) => break,
                _ => unreachable!(),
            }
        }
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());

        // a timed out request must not validate the peer
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// Per-transaction timeouts: the agent waits 1s, then 2s twice (the interval is capped by
    /// the configured maximum), then the 10s final wait before reporting a timeout. Polling
    /// again at the same instant must not advance the schedule.
    #[test]
    fn request_custom_timeout() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();
        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let mut now = Instant::ZERO;
        agent.send_request(msg.finish(), remote_addr, now).unwrap();
        let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
        transaction.configure_timeout_with_max(
            Duration::from_secs(1),
            4,
            Duration::from_secs(10),
            Duration::from_secs(2),
        );
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(1));
        now = wait;
        // a poll with the same instant should not busy loop
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait, now);
        let Some(_) = agent.poll_transmit(now) else {
            unreachable!();
        };
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(2));
        now = wait;
        let Some(_) = agent.poll_transmit(now) else {
            unreachable!();
        };
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(2));
        now = wait;
        let Some(_) = agent.poll_transmit(now) else {
            unreachable!();
        };
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(10));
        now = wait;
        let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(timed_out, transaction_id);

        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());

        // a timed out request must not validate the peer
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// With zero retransmits configured, the agent only waits for the final timeout before
    /// reporting the transaction as timed out.
    #[test]
    fn request_no_retransmit() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
            .remote_addr(remote_addr)
            .build();
        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let mut now = Instant::ZERO;
        agent.send_request(msg.finish(), remote_addr, now).unwrap();
        let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
        transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(10));
        now = wait;
        let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(timed_out, transaction_id);

        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());

        // a timed out request must not validate the peer
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// Over TCP no retransmissions are scheduled. The agent waits once, for the sum of the
    /// configured retransmission timeouts, before reporting the transaction as timed out.
    #[test]
    fn request_tcp_custom_timeout() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
            .remote_addr(remote_addr)
            .request_retransmits(
                Duration::from_secs(1),
                Duration::from_secs(2),
                4,
                Duration::from_secs(3),
            )
            .build();
        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let mut now = Instant::ZERO;
        agent.send_request(msg.finish(), remote_addr, now).unwrap();
        let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(wait - now, Duration::from_secs(1 + 2 + 2 + 3));
        now = wait;
        let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
            unreachable!();
        };
        assert_eq!(timed_out, transaction_id);

        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());

        // a timed out request must not validate the peer
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// A request without integrity is completed by an unsigned success response, which
    /// removes the transaction and validates the peer.
    #[test]
    fn request_without_credentials() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

        // nothing has been received yet, so the peer is not validated
        assert!(!agent.is_validated_peer(remote_addr));

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();

        let request = Message::from_bytes(&transmit.data).unwrap();

        let mut response = Message::builder_success(&request, MessageWriteVec::new());
        let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
        response.add_attribute(&xor_addr).unwrap();

        let data = response.finish();
        let to = transmit.to;
        trace!("data: {data:?}");
        let response = Message::from_bytes(&data).unwrap();
        let request = agent
            .request_transaction(response.transaction_id())
            .unwrap();
        assert_eq!(request.integrity(), None);
        assert!(agent.handle_stun_message(&response, to));
        assert_eq!(response.transaction_id(), transaction_id);
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());
        assert!(agent.is_validated_peer(remote_addr));
    }

    /// A response signed with the wrong credentials fails `ShortTermAuth` validation. The
    /// agent does not check integrity itself, so handing it the response still completes the
    /// transaction and validates the peer.
    #[test]
    fn response_with_incorrect_credentials() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut auth = ShortTermAuth::new();
        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
        let credentials = ShortTermCredentials::new(String::from("local_password"));
        let wrong_credentials = ShortTermCredentials::new(String::from("wrong_password"));
        auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);

        let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
        msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
            .unwrap();
        let transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        let data = transmit.data;

        let request = Message::from_bytes(&data).unwrap();

        let mut response = Message::builder_success(&request, MessageWriteVec::new());
        let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
        response.add_attribute(&xor_addr).unwrap();
        // deliberately sign with the wrong credentials; a valid response would use
        // `credentials`
        response
            .add_message_integrity(&wrong_credentials.into(), IntegrityAlgorithm::Sha1)
            .unwrap();

        let data = response.finish();
        let response = Message::from_bytes(&data).unwrap();
        // the request is still outstanding and the response fails integrity validation
        // against the expected credentials
        let request = agent
            .request_transaction(response.transaction_id())
            .unwrap();
        assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
        assert!(matches!(
            auth.validate_incoming_message(&response),
            Err(ValidateError::IntegrityFailed)
        ));

        // no response has been handled yet, so the peer is not validated
        assert!(!agent.is_validated_peer(remote_addr));

        // the agent does not check integrity itself: handing it the response completes the
        // transaction and validates the peer, and a second copy is not handled again
        assert!(agent.handle_stun_message(&response, remote_addr));
        assert!(!agent.handle_stun_message(&response, remote_addr));
        assert!(agent.is_validated_peer(remote_addr));
    }

    /// A second copy of an already handled response is not handled again.
    #[test]
    fn duplicate_response_ignored() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
        assert!(!agent.is_validated_peer(remote_addr));

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        let data = transmit.data;

        let request = Message::from_bytes(&data).unwrap();

        let mut response = Message::builder_success(&request, MessageWriteVec::new());
        let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
        response.add_attribute(&xor_addr).unwrap();

        let data = response.finish();
        let to = transmit.to;
        let response = Message::from_bytes(&data).unwrap();
        assert!(agent.handle_stun_message(&response, to));

        let response = Message::from_bytes(&data).unwrap();
        assert!(!agent.handle_stun_message(&response, to));
    }

    /// Cancelling a transaction makes the next poll report that same transaction as
    /// cancelled and removes it from the agent.
    #[test]
    fn request_cancel() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let _transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();

        let mut request = agent.mut_request_transaction(transaction_id).unwrap();
        assert_eq!(request.integrity(), None);
        assert_eq!(request.agent().local_addr(), local_addr);
        assert_eq!(request.mut_agent().local_addr(), local_addr);
        assert_eq!(request.peer_address(), remote_addr);
        request.cancel();

        let ret = agent.poll(Instant::ZERO);
        let StunAgentPollRet::TransactionCancelled(cancelled) = ret else {
            unreachable!();
        };
        // the reported transaction must be the one that was cancelled above
        assert_eq!(cancelled, transaction_id);
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// Cancelling only the retransmissions keeps the transaction alive until its timeout
    /// elapses. It is then reported as cancelled rather than timed out.
    #[test]
    fn request_cancel_send() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let _transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();

        let mut request = agent.mut_request_transaction(transaction_id).unwrap();
        assert_eq!(request.integrity(), None);
        assert_eq!(request.agent().local_addr(), local_addr);
        assert_eq!(request.mut_agent().local_addr(), local_addr);
        assert_eq!(request.peer_address(), remote_addr);
        request.cancel_retransmissions();

        let mut now = Instant::ZERO;
        let start = now;
        loop {
            match agent.poll(now) {
                StunAgentPollRet::WaitUntil(new_now) => {
                    // every wait must move time forward, otherwise this would busy loop
                    assert_ne!(new_now, now);
                    now = new_now;
                }
                StunAgentPollRet::TransactionCancelled(_) => break,
                _ => unreachable!(),
            }
            let _ = agent.poll_transmit(now);
        }
        assert!(now - start > Duration::from_secs(20));
        assert!(agent.request_transaction(transaction_id).is_none());
        assert!(agent.mut_request_transaction(transaction_id).is_none());
        assert!(!agent.is_validated_peer(remote_addr));
    }

    /// Sending a request that reuses the transaction ID of an outstanding request fails with
    /// `AlreadyInProgress` and leaves the original transaction intact.
    #[test]
    fn request_duplicate() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let msg = msg.finish();
        let transmit = agent
            .send_request(msg.clone(), remote_addr, Instant::ZERO)
            .unwrap();
        let to = transmit.to;
        let request = Message::from_bytes(&transmit.data).unwrap();

        let mut response = Message::builder_success(&request, MessageWriteVec::new());
        let xor_addr = XorMappedAddress::new(transmit.from, transaction_id);
        response.add_attribute(&xor_addr).unwrap();

        assert!(matches!(
            agent.send_request(msg, remote_addr, Instant::ZERO),
            Err(StunError::AlreadyInProgress)
        ));

        // the original transaction should still exist
        let request = agent.request_transaction(transaction_id).unwrap();
        assert_eq!(request.peer_address(), remote_addr);

        let data = response.finish();
        let response = Message::from_bytes(&data).unwrap();
        assert!(agent.handle_stun_message(&response, to));

        assert!(agent.is_validated_peer(to));
    }

    /// An incoming request is handled, and a peer can be marked as validated explicitly.
    #[test]
    fn incoming_request() {
        let _log = crate::tests::test_init_log();
        let local_addr = "10.0.0.1:12345".parse().unwrap();
        let remote_addr = "10.0.0.2:3478".parse().unwrap();

        let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let data = msg.finish();
        let stun = Message::from_bytes(&data).unwrap();
        error!("{stun:?}");
        assert!(agent.handle_stun_message(&stun, remote_addr));
        agent.validated_peer(remote_addr);
        assert!(agent.is_validated_peer(remote_addr));
    }

    /// A request sent by a TCP agent produces a TCP transmit between the configured
    /// addresses carrying the request.
    #[test]
    fn tcp_request() {
        let _log = crate::tests::test_init_log();
        let local_addr = "127.0.0.1:2000".parse().unwrap();
        let remote_addr = "127.0.0.1:1000".parse().unwrap();
        let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
            .remote_addr(remote_addr)
            .build();

        let msg = Message::builder_request(BINDING, MessageWriteVec::new());
        let transaction_id = msg.transaction_id();
        let transmit = agent
            .send_request(msg.finish(), remote_addr, Instant::ZERO)
            .unwrap();
        assert_eq!(transmit.transport, TransportType::Tcp);
        assert_eq!(transmit.from, local_addr);
        assert_eq!(transmit.to, remote_addr);

        let request = Message::from_bytes(&transmit.data).unwrap();
        assert_eq!(request.transaction_id(), transaction_id);
    }

    /// `Transmit::into_owned` preserves the data, transport and addresses.
    #[test]
    fn transmit_into_owned() {
        let data = [0x10, 0x20];
        let transport = TransportType::Udp;
        let from = "127.0.0.1:1000".parse().unwrap();
        let to = "127.0.0.1:2000".parse().unwrap();
        let transmit = Transmit::new(Data::from(data.as_ref()), transport, from, to);
        let owned = transmit.into_owned();
        assert_eq!(owned.data.as_ref(), data.as_ref());
        assert_eq!(owned.transport, transport);
        assert_eq!(owned.from, from);
        assert_eq!(owned.to, to);
        error!("{owned}");
    }

    /// The `Display` output of `Transmit` shows the transport, both addresses and the data
    /// length.
    #[test]
    fn transmit_display() {
        let data = [0x10, 0x20];
        let from = "127.0.0.1:1000".parse().unwrap();
        let to = "127.0.0.1:2000".parse().unwrap();
        assert_eq!(
            alloc::format!(
                "{}",
                Transmit::new(Data::from(data.as_ref()), TransportType::Udp, from, to)
            ),
            String::from("Transmit(UDP: 127.0.0.1:1000 -> 127.0.0.1:2000 of 2 bytes)")
        );
    }

    /// Tests `RequestRto::calculate_timeouts` with 0, 1 and 2 transmissions over UDP and TCP.
    /// With fewer than two transmissions there are no intermediate timeouts. With two, UDP
    /// gets one initial-RTO wait, while TCP folds it into the final wait.
    #[test]
    fn request_retransmits() {
        let _log = crate::tests::test_init_log();
        let rto = RequestRto {
            initial: Duration::from_millis(1),
            max: Duration::MAX,
            retransmits: 0,
            last_retransmit: Duration::from_secs(1),
        };
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
        assert_eq!(timeouts, vec![]);
        assert_eq!(last_transmit_timeout, Duration::from_secs(1));
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
        assert_eq!(timeouts, vec![]);
        assert_eq!(last_transmit_timeout, Duration::from_secs(1));

        let rto = RequestRto {
            initial: Duration::from_millis(1),
            max: Duration::MAX,
            retransmits: 1,
            last_retransmit: Duration::from_secs(1),
        };
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
        assert_eq!(timeouts, vec![]);
        assert_eq!(last_transmit_timeout, Duration::from_secs(1));
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
        assert_eq!(timeouts, vec![]);
        assert_eq!(last_transmit_timeout, Duration::from_secs(1));

        let rto = RequestRto {
            initial: Duration::from_millis(1),
            max: Duration::MAX,
            retransmits: 2,
            last_retransmit: Duration::from_secs(1),
        };
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Udp);
        assert_eq!(timeouts, vec![Duration::from_millis(1)]);
        assert_eq!(last_transmit_timeout, Duration::from_secs(1));
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(TransportType::Tcp);
        assert_eq!(timeouts, vec![]);
        assert_eq!(
            last_transmit_timeout,
            Duration::from_secs(1) + Duration::from_millis(1)
        );
    }
}