1use core::net::SocketAddr;
21use core::sync::atomic::{AtomicUsize, Ordering};
22
23use alloc::collections::{BTreeMap, BTreeSet};
24use alloc::vec;
25use alloc::vec::Vec;
26use core::time::Duration;
27
28use crate::Instant;
29
30use stun_types::attribute::*;
31use stun_types::data::Data;
32use stun_types::message::*;
33
34use stun_types::TransportType;
35
36use tracing::{debug, trace, warn};
37
// Process-wide counter handing out a unique `id` to each `StunAgent` created by
// `StunAgentBuilder::build` (the id is used as the `stun_id` tracing field).
static STUN_AGENT_COUNT: AtomicUsize = AtomicUsize::new(0);
39
/// A STUN agent: tracks outstanding requests together with their retransmission schedule, and
/// the set of peer addresses that have been validated.
#[derive(Debug)]
pub struct StunAgent {
    // Unique id taken from `STUN_AGENT_COUNT` at build time; recorded in tracing spans.
    id: usize,
    // Transport used for every `Transmit` this agent produces.
    transport: TransportType,
    // Source address of every `Transmit` this agent produces.
    local_addr: SocketAddr,
    // Optional remote address supplied through the builder; only returned by `remote_addr()`.
    remote_addr: Option<SocketAddr>,
    // Addresses passed to `validated_peer` (e.g. from accepted messages in `handle_stun_message`).
    validated_peers: BTreeSet<SocketAddr>,
    // Requests still awaiting a response, keyed by their transaction ID.
    outstanding_requests: BTreeMap<TransactionId, StunRequestState>,
    // Retransmission intervals copied into each new request.
    request_timeouts: Vec<Duration>,
    // Wait after the final transmission before a request times out; copied into each new request.
    last_retransmit_timeout: Duration,
}
52
/// Builder for a [`StunAgent`], created with [`StunAgent::builder`].
#[derive(Debug)]
pub struct StunAgentBuilder {
    // Transport the built agent will use.
    transport: TransportType,
    // Local address the built agent will send from.
    local_addr: SocketAddr,
    // Optional remote address, set with `remote_addr()`.
    remote_addr: Option<SocketAddr>,
    // Retransmission parameters, turned into a schedule by `build()`.
    rto: RequestRto,
}
61
62impl StunAgentBuilder {
63 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
65 self.remote_addr = Some(addr);
66 self
67 }
68
69 pub fn request_retransmits(
85 mut self,
86 initial: Duration,
87 max: Duration,
88 retransmits: u32,
89 final_retransmit_timeout: Duration,
90 ) -> Self {
91 self.rto.initial = initial;
92 self.rto.max = max;
93 self.rto.retransmits = retransmits;
94 self.rto.last_retransmit = final_retransmit_timeout;
95 self
96 }
97
98 pub fn build(self) -> StunAgent {
100 let id = STUN_AGENT_COUNT.fetch_add(1, Ordering::SeqCst);
101 let (request_timeouts, last_retransmit_timeout) =
102 self.rto.calculate_timeouts(self.transport);
103 StunAgent {
104 id,
105 transport: self.transport,
106 local_addr: self.local_addr,
107 remote_addr: self.remote_addr,
108 validated_peers: Default::default(),
109 outstanding_requests: Default::default(),
110 request_timeouts,
111 last_retransmit_timeout,
112 }
113 }
114}
115
116impl StunAgent {
117 pub fn builder(transport: TransportType, local_addr: SocketAddr) -> StunAgentBuilder {
119 StunAgentBuilder {
120 transport,
121 local_addr,
122 remote_addr: None,
123 rto: Default::default(),
124 }
125 }
126
127 pub fn transport(&self) -> TransportType {
129 self.transport
130 }
131
132 pub fn local_addr(&self) -> SocketAddr {
134 self.local_addr
135 }
136
137 pub fn remote_addr(&self) -> Option<SocketAddr> {
139 self.remote_addr
140 }
141
142 pub fn send_data<T: AsRef<[u8]>>(&self, bytes: T, to: SocketAddr) -> Transmit<T> {
144 send_data(self.transport, bytes, self.local_addr, to)
145 }
146
147 #[tracing::instrument(name = "stun_agent_send",
155 skip(self, msg),
156 fields(
157 transport = %self.transport,
158 from = %self.local_addr,
159 transaction_id,
160 )
161 )]
162 pub fn send<T: AsRef<[u8]>>(
163 &mut self,
164 msg: T,
165 to: SocketAddr,
166 now: Instant,
167 ) -> Result<Transmit<T>, StunError> {
168 let data = msg.as_ref();
169 let hdr = MessageHeader::from_bytes(data)?;
170 tracing::Span::current().record(
171 "transaction_id",
172 tracing::field::display(hdr.transaction_id()),
173 );
174 assert!(!hdr.get_type().has_class(MessageClass::Request));
175 trace!("Sending {} to {to}", hdr.get_type());
176 Ok(Transmit::new(msg, self.transport, self.local_addr, to))
177 }
178
179 #[tracing::instrument(name = "stun_agent_send_request",
187 skip(self, msg),
188 fields(
189 transport = %self.transport,
190 from = %self.local_addr,
191 transaction_id,
192 )
193 )]
194 pub fn send_request<'a, T: AsRef<[u8]>>(
195 &'a mut self,
196 msg: T,
197 to: SocketAddr,
198 now: Instant,
199 ) -> Result<Transmit<Data<'a>>, StunError> {
200 let data = msg.as_ref();
201 let hdr = MessageHeader::from_bytes(data)?;
202 assert!(hdr.get_type().has_class(MessageClass::Request));
203 let transaction_id = hdr.transaction_id();
204 tracing::Span::current().record("transaction_id", tracing::field::display(transaction_id));
205 let state = match self.outstanding_requests.entry(transaction_id) {
206 alloc::collections::btree_map::Entry::Vacant(entry) => {
207 let integrity_algorithm = MessageAttributesIter::new(data)
208 .filter_map(|(_offset, attr)| match attr.get_type() {
209 MessageIntegrity::TYPE => Some(IntegrityAlgorithm::Sha1),
210 MessageIntegritySha256::TYPE => Some(IntegrityAlgorithm::Sha256),
211 _ => None,
212 })
213 .last();
214 trace!("Adding request to {to} with integrity algorithm: {integrity_algorithm:?}");
215 entry.insert(StunRequestState::new(
216 msg,
217 self.transport,
218 self.local_addr,
219 to,
220 transaction_id,
221 integrity_algorithm,
222 self.request_timeouts.clone(),
223 self.last_retransmit_timeout,
224 ))
225 }
226 alloc::collections::btree_map::Entry::Occupied(_entry) => {
227 return Err(StunError::AlreadyInProgress);
228 }
229 };
230 let Some(transmit) = state.poll_transmit(now) else {
231 unreachable!();
232 };
233 Ok(Transmit::new(
234 Data::from(transmit.data),
235 transmit.transport,
236 transmit.from,
237 transmit.to,
238 ))
239 }
240
241 pub fn is_validated_peer(&self, remote_addr: SocketAddr) -> bool {
248 self.validated_peers.contains(&remote_addr)
249 }
250
251 #[tracing::instrument(
253 name = "stun_validated_peer"
254 skip(self),
255 fields(stun_id = self.id)
256 )]
257 pub fn validated_peer(&mut self, addr: SocketAddr) {
258 if !self.validated_peers.contains(&addr) {
259 debug!("validated peer {:?}", addr);
260 self.validated_peers.insert(addr);
261 }
262 }
263
264 #[tracing::instrument(
273 name = "stun_handle_message"
274 skip(self, msg, from),
275 fields(
276 transaction_id = %msg.transaction_id(),
277 )
278 )]
279 pub fn handle_stun_message(&mut self, msg: &Message<'_>, from: SocketAddr) -> bool {
280 if msg.is_response()
281 && self
282 .take_outstanding_request(&msg.transaction_id())
283 .is_none()
284 {
285 trace!("original request disappeared");
286 return false;
287 }
288 self.validated_peer(from);
289 true
290 }
291
292 #[tracing::instrument(
293 skip(self, transaction_id),
294 fields(transaction_id = %transaction_id)
295 )]
296 fn take_outstanding_request(
297 &mut self,
298 transaction_id: &TransactionId,
299 ) -> Option<StunRequestState> {
300 if let Some(request) = self.outstanding_requests.remove(transaction_id) {
301 trace!("removing request");
302 Some(request)
303 } else {
304 trace!("no outstanding request");
305 None
306 }
307 }
308
309 pub fn request_transaction(&self, transaction_id: TransactionId) -> Option<StunRequest<'_>> {
315 if self.outstanding_requests.contains_key(&transaction_id) {
316 Some(StunRequest {
317 agent: self,
318 transaction_id,
319 })
320 } else {
321 None
322 }
323 }
324
325 pub fn mut_request_transaction(
331 &mut self,
332 transaction_id: TransactionId,
333 ) -> Option<StunRequestMut<'_>> {
334 if self.outstanding_requests.contains_key(&transaction_id) {
335 Some(StunRequestMut {
336 agent: self,
337 transaction_id,
338 })
339 } else {
340 None
341 }
342 }
343
344 fn mut_request_state(
345 &mut self,
346 transaction_id: TransactionId,
347 ) -> Option<&mut StunRequestState> {
348 self.outstanding_requests.get_mut(&transaction_id)
349 }
350
351 fn request_state(&self, transaction_id: TransactionId) -> Option<&StunRequestState> {
352 self.outstanding_requests.get(&transaction_id)
353 }
354
355 #[tracing::instrument(
361 name = "stun_agent_poll"
362 level = "debug",
363 skip(self),
364 )]
365 pub fn poll(&mut self, now: Instant) -> StunAgentPollRet {
366 let mut lowest_wait = now + Duration::from_secs(3600);
367 let mut timeout = None;
368 let mut cancelled = None;
369 for (transaction_id, request) in self.outstanding_requests.iter_mut() {
370 debug_assert_eq!(transaction_id, &request.transaction_id);
371 match request.poll(now) {
372 StunRequestPollRet::Cancelled => {
373 cancelled = Some(*transaction_id);
374 break;
375 }
376 StunRequestPollRet::WaitUntil(wait_until) => {
377 if wait_until < lowest_wait {
378 lowest_wait = wait_until;
379 }
380 }
381 StunRequestPollRet::TimedOut => {
382 timeout = Some(*transaction_id);
383 break;
384 }
385 }
386 }
387 if let Some(transaction) = timeout {
388 if let Some(_state) = self.outstanding_requests.remove(&transaction) {
389 return StunAgentPollRet::TransactionTimedOut(transaction);
390 }
391 }
392 if let Some(transaction) = cancelled {
393 if let Some(_state) = self.outstanding_requests.remove(&transaction) {
394 return StunAgentPollRet::TransactionCancelled(transaction);
395 }
396 }
397 StunAgentPollRet::WaitUntil(lowest_wait)
398 }
399
400 #[tracing::instrument(
402 name = "stun_agent_poll_transmit"
403 level = "debug",
404 skip(self),
405 )]
406 pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
407 self.outstanding_requests
408 .values_mut()
409 .filter_map(|request| request.poll_transmit(now))
410 .next()
411 }
412}
413
/// Result of [`StunAgent::poll`].
#[derive(Debug)]
pub enum StunAgentPollRet {
    /// The transaction timed out without a response; it has already been removed from the agent.
    TransactionTimedOut(TransactionId),
    /// The transaction was cancelled; it has already been removed from the agent.
    TransactionCancelled(TransactionId),
    /// The earliest instant at which an outstanding request needs attention again.
    WaitUntil(Instant),
}
424
425fn send_data<T: AsRef<[u8]>>(
426 transport: TransportType,
427 bytes: T,
428 from: SocketAddr,
429 to: SocketAddr,
430) -> Transmit<T> {
431 Transmit::new(bytes, transport, from, to)
432}
433
/// A piece of data that should be sent over `transport` from `from` to `to`.
#[derive(Debug)]
pub struct Transmit<T: AsRef<[u8]>> {
    /// The bytes to send.
    pub data: T,
    /// The transport to send the data over.
    pub transport: TransportType,
    /// The source address of the data.
    pub from: SocketAddr,
    /// The destination address of the data.
    pub to: SocketAddr,
}
446
447impl<T: AsRef<[u8]>> core::fmt::Display for Transmit<T> {
448 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
449 write!(
450 f,
451 "Transmit({}: {} -> {} of {} bytes)",
452 self.transport,
453 self.from,
454 self.to,
455 self.data.as_ref().len()
456 )
457 }
458}
459
460impl<T: AsRef<[u8]>> Transmit<T> {
461 pub fn new(data: T, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self {
463 Self {
464 data,
465 transport,
466 from,
467 to,
468 }
469 }
470
471 pub fn reinterpret_data<O: AsRef<[u8]>, F: FnOnce(T) -> O>(self, f: F) -> Transmit<O> {
491 Transmit {
492 data: f(self.data),
493 transport: self.transport,
494 from: self.from,
495 to: self.to,
496 }
497 }
498}
499
500impl Transmit<Data<'_>> {
501 pub fn into_owned<'b>(self) -> Transmit<Data<'b>> {
503 self.reinterpret_data(|data| data.into_owned())
504 }
505}
506
/// Result of polling a single request's timers (`StunRequestState::poll`).
#[derive(Debug)]
enum StunRequestPollRet {
    /// Nothing to do before this instant (equal to `now` when a transmission is overdue).
    WaitUntil(Instant),
    /// The request was cancelled and should be dropped.
    Cancelled,
    /// The final wait after the last transmission has elapsed without a response.
    TimedOut,
}
517
518#[derive(Debug)]
519struct RequestRto {
520 initial: Duration,
521 max: Duration,
522 retransmits: u32,
523 last_retransmit: Duration,
524}
525
526impl Default for RequestRto {
527 fn default() -> Self {
528 Self {
529 initial: Duration::from_millis(500),
530 max: Duration::MAX,
531 retransmits: 7,
532 last_retransmit: Duration::from_millis(8),
533 }
534 }
535}
536
537impl RequestRto {
538 fn calculate_timeouts(&self, transport: TransportType) -> (Vec<Duration>, Duration) {
539 match transport {
540 TransportType::Udp => {
541 let timeouts = (0..self.retransmits.max(1) - 1)
542 .map(|i| (self.initial * 2u32.pow(i)).min(self.max))
543 .collect::<Vec<_>>();
544 (timeouts, self.last_retransmit)
545 }
546 TransportType::Tcp => {
547 let timeouts = vec![];
548 let last_retransmit_timeout = self.last_retransmit
549 + (0..self.retransmits.max(1) - 1).fold(Duration::ZERO, |acc, i| {
550 acc + (self.initial * 2u32.pow(i)).min(self.max)
551 });
552 (timeouts, last_retransmit_timeout)
553 }
554 }
555 }
556}
557
/// Per-request bookkeeping for an outstanding STUN request.
#[derive(Debug)]
struct StunRequestState {
    // Transaction ID of the request (also its key in `StunAgent::outstanding_requests`).
    transaction_id: TransactionId,
    // Algorithm of the last integrity attribute found in the request, if any.
    request_integrity: Option<IntegrityAlgorithm>,
    // Copy of the encoded request, re-sent on every retransmission.
    bytes: Vec<u8>,
    transport: TransportType,
    from: SocketAddr,
    to: SocketAddr,
    // Intervals between consecutive transmissions; `timeout_i` indexes the next one to use.
    timeouts: Vec<Duration>,
    // Wait after the final transmission before `poll` reports `TimedOut`.
    last_retransmit_timeout: Duration,
    // Set by `StunRequestMut::cancel`: `poll` reports `Cancelled` and nothing more is sent.
    recv_cancelled: bool,
    // Set by `cancel_retransmissions`/`cancel`: transmissions are suppressed, but the schedule
    // still advances in `poll_transmit`.
    send_cancelled: bool,
    // Number of retransmission intervals already consumed.
    timeout_i: usize,
    // When the request was last (re)transmitted, or would have been if `send_cancelled`;
    // `None` before the first transmission.
    last_send_time: Option<Instant>,
}
573
574impl StunRequestState {
575 #[allow(clippy::too_many_arguments)]
576 fn new<T: AsRef<[u8]>>(
577 request: T,
578 transport: TransportType,
579 from: SocketAddr,
580 to: SocketAddr,
581 transaction_id: TransactionId,
582 integrity_algorithm: Option<IntegrityAlgorithm>,
583 timeouts: Vec<Duration>,
584 last_retransmit_timeout: Duration,
585 ) -> Self {
586 let data = request.as_ref();
587 Self {
600 transaction_id,
601 bytes: data.to_vec(),
602 transport,
603 from,
604 to,
605 request_integrity: integrity_algorithm,
606 timeouts,
607 timeout_i: 0,
608 last_retransmit_timeout,
609 recv_cancelled: false,
610 send_cancelled: false,
611 last_send_time: None,
612 }
613 }
614
615 #[tracing::instrument(skip(self, now), level = "trace")]
616 fn next_send_time(&self, now: Instant) -> Option<Instant> {
617 let Some(last_send) = self.last_send_time else {
618 trace!("not sent yet -> send immediately");
619 return Some(now);
620 };
621 if self.timeout_i >= self.timeouts.len() {
622 let next_send = last_send + self.last_retransmit_timeout;
623 trace!("final retransmission, final timeout ends at {next_send:?}");
624 if next_send > now {
625 return Some(next_send);
626 }
627 return None;
628 }
629 let next_send = last_send + self.timeouts[self.timeout_i];
630 Some(next_send)
631 }
632
633 #[tracing::instrument(
634 name = "stun_request_poll"
635 level = "debug",
636 ret,
637 skip(self, now),
638 fields(transaction_id = %self.transaction_id),
639 )]
640 fn poll(&mut self, now: Instant) -> StunRequestPollRet {
641 if self.recv_cancelled {
642 return StunRequestPollRet::Cancelled;
643 }
644 let Some(next_send) = self.next_send_time(now) else {
646 return StunRequestPollRet::TimedOut;
647 };
648 if next_send >= now {
649 if self.send_cancelled && self.timeout_i >= self.timeouts.len() {
650 return StunRequestPollRet::Cancelled;
652 }
653 return StunRequestPollRet::WaitUntil(next_send);
654 }
655 StunRequestPollRet::WaitUntil(now)
656 }
657
658 #[tracing::instrument(
659 name = "stun_request_poll_transmit",
660 skip(self, now),
661 fields(transaction_id = %self.transaction_id)
662 )]
663 fn poll_transmit(&mut self, now: Instant) -> Option<Transmit<&[u8]>> {
664 if self.recv_cancelled {
665 return None;
666 };
667 let next_send = self.next_send_time(now)?;
668
669 if next_send > now {
670 return None;
671 }
672 if self.last_send_time.is_some() {
673 self.timeout_i += 1;
674 }
675 self.last_send_time = Some(now);
676 if self.send_cancelled {
677 return None;
678 };
679 trace!(
680 "sending {} bytes over {:?} from {:?} to {:?}",
681 self.bytes.len(),
682 self.transport,
683 self.from,
684 self.to
685 );
686 Some(send_data(
687 self.transport,
688 self.bytes.as_slice(),
689 self.from,
690 self.to,
691 ))
692 }
693}
694
/// A shared handle to an outstanding request, obtained from [`StunAgent::request_transaction`].
#[derive(Debug, Clone)]
pub struct StunRequest<'a> {
    // Borrowing the agent immutably keeps the request from being removed while the handle exists.
    agent: &'a StunAgent,
    transaction_id: TransactionId,
}
701
702impl StunRequest<'_> {
703 pub fn peer_address(&self) -> SocketAddr {
705 let state = self.agent.request_state(self.transaction_id).unwrap();
706 state.to
707 }
708
709 pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
711 let state = self.agent.request_state(self.transaction_id).unwrap();
712 state.request_integrity
713 }
714}
715
/// A mutable handle to an outstanding request, obtained from [`StunAgent::mut_request_transaction`].
#[derive(Debug)]
pub struct StunRequestMut<'a> {
    agent: &'a mut StunAgent,
    transaction_id: TransactionId,
}
722
723impl StunRequestMut<'_> {
724 pub fn peer_address(&self) -> SocketAddr {
726 let state = self.agent.request_state(self.transaction_id).unwrap();
727 state.to
728 }
729
730 pub fn integrity(&self) -> Option<IntegrityAlgorithm> {
732 let state = self.agent.request_state(self.transaction_id).unwrap();
733 state.request_integrity
734 }
735
736 pub fn cancel_retransmissions(&mut self) {
740 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
741 state.send_cancelled = true;
742 }
743 }
744
745 pub fn cancel(&mut self) {
748 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
749 state.send_cancelled = true;
750 state.recv_cancelled = true;
751 }
752 }
753
754 pub fn agent(&self) -> &StunAgent {
756 self.agent
757 }
758
759 pub fn mut_agent(&mut self) -> &mut StunAgent {
761 self.agent
762 }
763
764 pub fn configure_timeout(
768 &mut self,
769 initial_rto: Duration,
770 retransmits: u32,
771 last_retransmit_timeout: Duration,
772 ) {
773 self.configure_timeout_with_max(
774 initial_rto,
775 retransmits,
776 last_retransmit_timeout,
777 Duration::MAX,
778 );
779 }
780
781 pub fn configure_timeout_with_max(
797 &mut self,
798 initial_rto: Duration,
799 retransmits: u32,
800 last_retransmit_timeout: Duration,
801 max_rto: Duration,
802 ) {
803 if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
804 let (timeouts, final_wait) = RequestRto {
805 initial: initial_rto,
806 max: max_rto,
807 retransmits,
808 last_retransmit: last_retransmit_timeout,
809 }
810 .calculate_timeouts(state.transport);
811 state.timeouts = timeouts;
812 state.last_retransmit_timeout = final_wait;
813 }
814 }
815}
816
/// Errors produced by the STUN agent.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StunError {
    /// The operation is already in progress, e.g. a request with the same transaction ID is
    /// already outstanding in [`StunAgent::send_request`].
    #[error("The operation is already in progress")]
    AlreadyInProgress,
    /// A required resource could not be found.
    #[error("A required resource could not be found")]
    ResourceNotFound,
    /// An operation timed out.
    #[error("An operation timed out")]
    TimedOut,
    /// Unexpected data was received.
    #[error("Unexpected data was received")]
    ProtocolViolation,
    /// The operation was aborted.
    #[error("Operation was aborted")]
    Aborted,
    /// Parsing STUN data failed; displays the inner error.
    #[error("{}", .0)]
    ParseError(StunParseError),
    /// Writing STUN data failed; displays the inner error.
    #[error("{}", .0)]
    WriteError(StunWriteError),
}
843
844impl From<StunParseError> for StunError {
845 fn from(e: StunParseError) -> Self {
846 StunError::ParseError(e)
847 }
848}
849
850impl From<StunWriteError> for StunError {
851 fn from(e: StunWriteError) -> Self {
852 StunError::WriteError(e)
853 }
854}
855
856#[cfg(test)]
857pub(crate) mod tests {
858 use alloc::string::String;
859 use tracing::error;
860
861 use crate::auth::ShortTermAuth;
862
863 use super::*;
864
865 #[test]
866 fn agent_getters_setters() {
867 let _log = crate::tests::test_init_log();
868 let local_addr = "10.0.0.1:12345".parse().unwrap();
869 let remote_addr = "10.0.0.2:3478".parse().unwrap();
870 let agent = StunAgent::builder(TransportType::Udp, local_addr)
871 .remote_addr(remote_addr)
872 .build();
873
874 assert_eq!(agent.transport(), TransportType::Udp);
875 assert_eq!(agent.local_addr(), local_addr);
876 assert_eq!(agent.remote_addr(), Some(remote_addr));
877 }
878
879 #[test]
880 fn request() {
881 let _log = crate::tests::test_init_log();
882 let local_addr = "127.0.0.1:2000".parse().unwrap();
883 let remote_addr = "127.0.0.1:1000".parse().unwrap();
884 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
885 .remote_addr(remote_addr)
886 .build();
887 let now = Instant::ZERO;
888
889 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
890 let transaction_id = msg.transaction_id();
891 let transmit = agent
892 .send_request(msg.finish(), remote_addr, now)
893 .unwrap()
894 .into_owned();
895 let request = agent.request_transaction(transaction_id).unwrap();
896 assert!(request.integrity().is_none());
897 assert_eq!(transmit.transport, TransportType::Udp);
898 assert_eq!(transmit.from, local_addr);
899 assert_eq!(transmit.to, remote_addr);
900 let request = Message::from_bytes(&transmit.data).unwrap();
901 let response = Message::builder_error(&request, MessageWriteVec::new());
902 let resp_data = response.finish();
903 let response = Message::from_bytes(&resp_data).unwrap();
904 assert!(agent.handle_stun_message(&response, remote_addr));
905 assert!(agent.request_transaction(transaction_id).is_none());
906 assert!(agent.mut_request_transaction(transaction_id).is_none());
907
908 let ret = agent.poll(now);
909 assert!(matches!(ret, StunAgentPollRet::WaitUntil(_)));
910 }
911
912 #[test]
913 fn indication_with_invalid_response() {
914 let _log = crate::tests::test_init_log();
915 let local_addr = "127.0.0.1:2000".parse().unwrap();
916 let remote_addr = "127.0.0.1:1000".parse().unwrap();
917 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
918 .remote_addr(remote_addr)
919 .build();
920 let transaction_id = TransactionId::generate();
921 let msg = Message::builder(
922 MessageType::from_class_method(MessageClass::Indication, BINDING),
923 transaction_id,
924 MessageWriteVec::new(),
925 );
926 let transmit = agent
927 .send(msg.finish(), remote_addr, Instant::ZERO)
928 .unwrap();
929 assert_eq!(transmit.transport, TransportType::Udp);
930 assert_eq!(transmit.from, local_addr);
931 assert_eq!(transmit.to, remote_addr);
932 let _indication = Message::from_bytes(&transmit.data).unwrap();
933 assert!(agent.request_transaction(transaction_id).is_none());
934 assert!(agent.mut_request_transaction(transaction_id).is_none());
935 let response = Message::builder(
937 MessageType::from_class_method(MessageClass::Error, BINDING),
938 transaction_id,
939 MessageWriteVec::new(),
940 );
941 let resp_data = response.finish();
942 let response = Message::from_bytes(&resp_data).unwrap();
943 assert!(!agent.handle_stun_message(&response, remote_addr))
945 }
946
947 #[test]
948 fn request_with_credentials() {
949 let _log = crate::tests::test_init_log();
950 let local_addr = "10.0.0.1:12345".parse().unwrap();
951 let remote_addr = "10.0.0.2:3478".parse().unwrap();
952
953 let mut auth = ShortTermAuth::new();
954 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
955 let credentials = ShortTermCredentials::new(String::from("local_password"));
956 auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);
957
958 assert!(!agent.is_validated_peer(remote_addr));
960
961 let mut msg = Message::builder_request(BINDING, MessageWriteVec::new());
962 let transaction_id = msg.transaction_id();
963 msg.add_message_integrity(&credentials.clone().into(), IntegrityAlgorithm::Sha1)
964 .unwrap();
965 error!("send");
966 let transmit = agent
967 .send_request(msg.finish(), remote_addr, Instant::ZERO)
968 .unwrap();
969 error!("sent");
970
971 let request = Message::from_bytes(&transmit.data).unwrap();
972
973 error!("generate response");
974 let mut response = Message::builder_success(&request, MessageWriteVec::new());
975 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
976 response.add_attribute(&xor_addr).unwrap();
977 response
978 .add_message_integrity(&credentials.into(), IntegrityAlgorithm::Sha1)
979 .unwrap();
980 error!("{response:?}");
981
982 let data = response.finish();
983 error!("{data:?}");
984 let response = Message::from_bytes(&data).unwrap();
985 error!("{response}");
986 assert_eq!(
987 auth.validate_incoming_message(&response).unwrap(),
988 Some(IntegrityAlgorithm::Sha1)
989 );
990 let request = agent
991 .request_transaction(response.transaction_id())
992 .unwrap();
993 assert_eq!(request.integrity(), Some(IntegrityAlgorithm::Sha1));
994 assert!(agent.handle_stun_message(&response, remote_addr));
995
996 assert_eq!(response.transaction_id(), transaction_id);
997 assert!(agent.request_transaction(transaction_id).is_none());
998 assert!(agent.mut_request_transaction(transaction_id).is_none());
999 assert!(agent.is_validated_peer(remote_addr));
1000 }
1001
1002 #[test]
1003 fn request_unanswered() {
1004 let _log = crate::tests::test_init_log();
1005 let local_addr = "127.0.0.1:2000".parse().unwrap();
1006 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1007 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
1008 .remote_addr(remote_addr)
1009 .build();
1010 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1011 let transaction_id = msg.transaction_id();
1012 agent
1013 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1014 .unwrap();
1015 let mut now = Instant::ZERO;
1016 loop {
1017 let _ = agent.poll_transmit(now);
1018 match agent.poll(now) {
1019 StunAgentPollRet::WaitUntil(new_now) => {
1020 now = new_now;
1021 }
1022 StunAgentPollRet::TransactionTimedOut(_) => break,
1023 _ => unreachable!(),
1024 }
1025 }
1026 assert!(agent.request_transaction(transaction_id).is_none());
1027 assert!(agent.mut_request_transaction(transaction_id).is_none());
1028
1029 assert!(!agent.is_validated_peer(remote_addr));
1031 }
1032
1033 #[test]
1034 fn request_custom_timeout() {
1035 let _log = crate::tests::test_init_log();
1036 let local_addr = "127.0.0.1:2000".parse().unwrap();
1037 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1038 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
1039 .remote_addr(remote_addr)
1040 .build();
1041 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1042 let transaction_id = msg.transaction_id();
1043 let mut now = Instant::ZERO;
1044 agent.send_request(msg.finish(), remote_addr, now).unwrap();
1045 let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
1046 transaction.configure_timeout_with_max(
1047 Duration::from_secs(1),
1048 4,
1049 Duration::from_secs(10),
1050 Duration::from_secs(2),
1051 );
1052 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1053 unreachable!();
1054 };
1055 assert_eq!(wait - now, Duration::from_secs(1));
1056 now = wait;
1057 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1059 unreachable!();
1060 };
1061 assert_eq!(wait, now);
1062 let Some(_) = agent.poll_transmit(now) else {
1063 unreachable!();
1064 };
1065 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1066 unreachable!();
1067 };
1068 assert_eq!(wait - now, Duration::from_secs(2));
1069 now = wait;
1070 let Some(_) = agent.poll_transmit(now) else {
1071 unreachable!();
1072 };
1073 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1074 unreachable!();
1075 };
1076 assert_eq!(wait - now, Duration::from_secs(2));
1077 now = wait;
1078 let Some(_) = agent.poll_transmit(now) else {
1079 unreachable!();
1080 };
1081 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1082 unreachable!();
1083 };
1084 assert_eq!(wait - now, Duration::from_secs(10));
1085 now = wait;
1086 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
1087 unreachable!();
1088 };
1089 assert_eq!(timed_out, transaction_id);
1090
1091 assert!(agent.request_transaction(transaction_id).is_none());
1092 assert!(agent.mut_request_transaction(transaction_id).is_none());
1093
1094 assert!(!agent.is_validated_peer(remote_addr));
1096 }
1097
1098 #[test]
1099 fn request_no_retransmit() {
1100 let _log = crate::tests::test_init_log();
1101 let local_addr = "127.0.0.1:2000".parse().unwrap();
1102 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1103 let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
1104 .remote_addr(remote_addr)
1105 .build();
1106 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1107 let transaction_id = msg.transaction_id();
1108 let mut now = Instant::ZERO;
1109 agent.send_request(msg.finish(), remote_addr, now).unwrap();
1110 let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
1111 transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
1112 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1113 unreachable!();
1114 };
1115 assert_eq!(wait - now, Duration::from_secs(10));
1116 now = wait;
1117 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
1118 unreachable!();
1119 };
1120 assert_eq!(timed_out, transaction_id);
1121
1122 assert!(agent.request_transaction(transaction_id).is_none());
1123 assert!(agent.mut_request_transaction(transaction_id).is_none());
1124
1125 assert!(!agent.is_validated_peer(remote_addr));
1127 }
1128
1129 #[test]
1130 fn request_tcp_custom_timeout() {
1131 let _log = crate::tests::test_init_log();
1132 let local_addr = "127.0.0.1:2000".parse().unwrap();
1133 let remote_addr = "127.0.0.1:1000".parse().unwrap();
1134 let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
1135 .remote_addr(remote_addr)
1136 .request_retransmits(
1137 Duration::from_secs(1),
1138 Duration::from_secs(2),
1139 4,
1140 Duration::from_secs(3),
1141 )
1142 .build();
1143 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1144 let transaction_id = msg.transaction_id();
1145 let mut now = Instant::ZERO;
1146 agent.send_request(msg.finish(), remote_addr, now).unwrap();
1147 let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
1148 unreachable!();
1149 };
1150 assert_eq!(wait - now, Duration::from_secs(1 + 2 + 2 + 3));
1151 now = wait;
1152 let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
1153 unreachable!();
1154 };
1155 assert_eq!(timed_out, transaction_id);
1156
1157 assert!(agent.request_transaction(transaction_id).is_none());
1158 assert!(agent.mut_request_transaction(transaction_id).is_none());
1159
1160 assert!(!agent.is_validated_peer(remote_addr));
1162 }
1163
1164 #[test]
1165 fn request_without_credentials() {
1166 let _log = crate::tests::test_init_log();
1167 let local_addr = "10.0.0.1:12345".parse().unwrap();
1168 let remote_addr = "10.0.0.2:3478".parse().unwrap();
1169
1170 let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
1171
1172 assert!(!agent.is_validated_peer(remote_addr));
1174
1175 let msg = Message::builder_request(BINDING, MessageWriteVec::new());
1176 let transaction_id = msg.transaction_id();
1177 let transmit = agent
1178 .send_request(msg.finish(), remote_addr, Instant::ZERO)
1179 .unwrap();
1180
1181 let request = Message::from_bytes(&transmit.data).unwrap();
1182
1183 let mut response = Message::builder_success(&request, MessageWriteVec::new());
1184 let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id());
1185 response.add_attribute(&xor_addr).unwrap();
1186
1187 let data = response.finish();
1188 let to = transmit.to;
1189 trace!("data: {data:?}");
1190 let response = Message::from_bytes(&data).unwrap();
1191 let request = agent
1192 .request_transaction(response.transaction_id())
1193 .unwrap();
1194 assert_eq!(request.integrity(), None);
1195 assert!(agent.handle_stun_message(&response, to));
1196 assert_eq!(response.transaction_id(), transaction_id);
1197 assert!(agent.request_transaction(transaction_id).is_none());
1198 assert!(agent.mut_request_transaction(transaction_id).is_none());
1199 assert!(agent.is_validated_peer(remote_addr));
1200 }
1201
/// A success response signed with the wrong password fails the separate
/// short-term authenticator. As exercised here, the agent still accepts it
/// (and validates the peer), while a second copy of the same response is not
/// handled again.
#[test]
fn response_with_incorrect_credentials() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let credentials = ShortTermCredentials::new(String::from("local_password"));
    let wrong_credentials = ShortTermCredentials::new(String::from("wrong_password"));

    // The authenticator only knows the correct password.
    let mut auth = ShortTermAuth::new();
    auth.set_credentials(credentials.clone(), IntegrityAlgorithm::Sha1);

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

    // Outgoing BINDING request protected with the correct credentials.
    let mut builder = Message::builder_request(BINDING, MessageWriteVec::new());
    builder
        .add_message_integrity(&credentials.into(), IntegrityAlgorithm::Sha1)
        .unwrap();
    let transmit = agent
        .send_request(builder.finish(), remote_addr, Instant::ZERO)
        .unwrap();
    let sent = transmit.data;
    let request = Message::from_bytes(&sent).unwrap();

    // Success response signed with the *wrong* password.
    let mut reply = Message::builder_success(&request, MessageWriteVec::new());
    let mapped = XorMappedAddress::new(transmit.from, request.transaction_id());
    reply.add_attribute(&mapped).unwrap();
    reply
        .add_message_integrity(&wrong_credentials.into(), IntegrityAlgorithm::Sha1)
        .unwrap();
    let reply_bytes = reply.finish();
    let response = Message::from_bytes(&reply_bytes).unwrap();

    // The outstanding transaction remembers that integrity was used...
    let outstanding = agent
        .request_transaction(response.transaction_id())
        .unwrap();
    assert_eq!(outstanding.integrity(), Some(IntegrityAlgorithm::Sha1));
    // ...and the authenticator rejects the mis-signed response.
    assert!(matches!(
        auth.validate_incoming_message(&response),
        Err(ValidateError::IntegrityFailed)
    ));

    assert!(!agent.is_validated_peer(remote_addr));

    // Despite the failed integrity check above, the agent handles the first
    // copy of the response; the repeated copy is reported as not handled.
    assert!(agent.handle_stun_message(&response, remote_addr));
    assert!(!agent.handle_stun_message(&response, remote_addr));
    assert!(agent.is_validated_peer(remote_addr));
}
1252
/// Delivering the same success response twice: the first delivery is handled,
/// the second one is reported as not handled.
#[test]
fn duplicate_response_ignored() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();
    assert!(!agent.is_validated_peer(remote_addr));

    let transmit = agent
        .send_request(
            Message::builder_request(BINDING, MessageWriteVec::new()).finish(),
            remote_addr,
            Instant::ZERO,
        )
        .unwrap();
    let peer = transmit.to;
    let request = Message::from_bytes(&transmit.data).unwrap();

    let mapped = XorMappedAddress::new(transmit.from, request.transaction_id());
    let mut builder = Message::builder_success(&request, MessageWriteVec::new());
    builder.add_attribute(&mapped).unwrap();
    let bytes = builder.finish();

    // First delivery completes the transaction.
    let first = Message::from_bytes(&bytes).unwrap();
    assert!(agent.handle_stun_message(&first, peer));

    // Re-delivering the identical bytes is not handled again.
    let second = Message::from_bytes(&bytes).unwrap();
    assert!(!agent.handle_stun_message(&second, peer));
}
1282
/// Cancelling an outstanding request makes the next `poll` report
/// `TransactionCancelled` for that very transaction. Afterwards the
/// transaction can no longer be looked up, and the peer stays unvalidated.
#[test]
fn request_cancel() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

    let msg = Message::builder_request(BINDING, MessageWriteVec::new());
    let transaction_id = msg.transaction_id();
    let _transmit = agent
        .send_request(msg.finish(), remote_addr, Instant::ZERO)
        .unwrap();

    let mut request = agent.mut_request_transaction(transaction_id).unwrap();
    assert_eq!(request.integrity(), None);
    assert_eq!(request.agent().local_addr(), local_addr);
    assert_eq!(request.mut_agent().local_addr(), local_addr);
    assert_eq!(request.peer_address(), remote_addr);
    request.cancel();

    let ret = agent.poll(Instant::ZERO);
    let StunAgentPollRet::TransactionCancelled(cancelled) = ret else {
        unreachable!();
    };
    // Previously `assert_eq!(transaction_id, transaction_id)`, which could
    // never fail; check the reported transaction is the one we cancelled.
    assert_eq!(cancelled, transaction_id);
    assert!(agent.request_transaction(transaction_id).is_none());
    assert!(agent.mut_request_transaction(transaction_id).is_none());
    assert!(!agent.is_validated_peer(remote_addr));
}
1313
/// After `cancel_retransmissions`, the agent keeps the transaction alive until
/// it is eventually reported as cancelled. This test requires that to happen
/// more than 20 seconds after the send.
#[test]
fn request_cancel_send() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

    let builder = Message::builder_request(BINDING, MessageWriteVec::new());
    let transaction_id = builder.transaction_id();
    let _transmit = agent
        .send_request(builder.finish(), remote_addr, Instant::ZERO)
        .unwrap();

    {
        let mut request = agent.mut_request_transaction(transaction_id).unwrap();
        assert_eq!(request.integrity(), None);
        assert_eq!(request.agent().local_addr(), local_addr);
        assert_eq!(request.mut_agent().local_addr(), local_addr);
        assert_eq!(request.peer_address(), remote_addr);
        request.cancel_retransmissions();
    }

    let start = Instant::ZERO;
    let mut now = start;
    loop {
        let wake = match agent.poll(now) {
            StunAgentPollRet::TransactionCancelled(_) => break,
            StunAgentPollRet::WaitUntil(wake) => wake,
            _ => unreachable!(),
        };
        // Time must always advance, otherwise this loop would never end.
        assert_ne!(wake, now);
        now = wake;
        let _ = agent.poll_transmit(now);
    }

    assert!(now - start > Duration::from_secs(20));
    assert!(agent.request_transaction(transaction_id).is_none());
    assert!(agent.mut_request_transaction(transaction_id).is_none());
    assert!(!agent.is_validated_peer(remote_addr));
}
1353
/// Sending the same request bytes again while the first copy is still
/// outstanding fails with `AlreadyInProgress`. The original transaction is
/// unaffected and completes normally.
#[test]
fn request_duplicate() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

    let builder = Message::builder_request(BINDING, MessageWriteVec::new());
    let transaction_id = builder.transaction_id();
    let request_bytes = builder.finish();

    let transmit = agent
        .send_request(request_bytes.clone(), remote_addr, Instant::ZERO)
        .unwrap();
    let peer = transmit.to;
    let sent = Message::from_bytes(&transmit.data).unwrap();

    let mapped = XorMappedAddress::new(transmit.from, transaction_id);
    let mut reply = Message::builder_success(&sent, MessageWriteVec::new());
    reply.add_attribute(&mapped).unwrap();

    // A second send of the identical request is refused.
    let duplicate = agent.send_request(request_bytes, remote_addr, Instant::ZERO);
    assert!(matches!(duplicate, Err(StunError::AlreadyInProgress)));

    // The original transaction is still tracked against the same peer.
    let outstanding = agent.request_transaction(transaction_id).unwrap();
    assert_eq!(outstanding.peer_address(), remote_addr);

    let reply_bytes = reply.finish();
    let response = Message::from_bytes(&reply_bytes).unwrap();
    assert!(agent.handle_stun_message(&response, peer));

    assert!(agent.is_validated_peer(peer));
}
1390
/// An incoming request from a peer is handled, but handling it does not by
/// itself mark the peer as validated. Only an explicit `validated_peer`
/// call does.
#[test]
fn incoming_request() {
    let _log = crate::tests::test_init_log();
    let local_addr = "10.0.0.1:12345".parse().unwrap();
    let remote_addr = "10.0.0.2:3478".parse().unwrap();

    let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build();

    let msg = Message::builder_request(BINDING, MessageWriteVec::new());
    let data = msg.finish();
    let stun = Message::from_bytes(&data).unwrap();
    // Routine diagnostic dump; this is not an error condition.
    tracing::debug!("{stun:?}");
    assert!(agent.handle_stun_message(&stun, remote_addr));
    // Without this check the final assertion would be trivially satisfied by
    // the explicit `validated_peer` call below.
    assert!(!agent.is_validated_peer(remote_addr));
    agent.validated_peer(remote_addr);
    assert!(agent.is_validated_peer(remote_addr));
}
1407
/// A TCP agent with a configured remote address emits its request over TCP,
/// from the local address to that remote. The request carries the
/// transaction id of the message that was built.
#[test]
fn tcp_request() {
    let _log = crate::tests::test_init_log();
    let local_addr = "127.0.0.1:2000".parse().unwrap();
    let remote_addr = "127.0.0.1:1000".parse().unwrap();
    let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
        .remote_addr(remote_addr)
        .build();

    let builder = Message::builder_request(BINDING, MessageWriteVec::new());
    let expected_id = builder.transaction_id();
    let transmit = agent
        .send_request(builder.finish(), remote_addr, Instant::ZERO)
        .unwrap();

    assert_eq!(
        (transmit.transport, transmit.from, transmit.to),
        (TransportType::Tcp, local_addr, remote_addr)
    );

    let parsed = Message::from_bytes(&transmit.data).unwrap();
    assert_eq!(parsed.transaction_id(), expected_id);
}
1429
/// `Transmit::into_owned` keeps the payload bytes, transport and both
/// endpoint addresses.
#[test]
fn transmit_into_owned() {
    let data = [0x10, 0x20];
    let transport = TransportType::Udp;
    let from = "127.0.0.1:1000".parse().unwrap();
    let to = "127.0.0.1:2000".parse().unwrap();

    let owned = Transmit::new(Data::from(data.as_ref()), transport, from, to).into_owned();

    assert_eq!(owned.data.as_ref(), data.as_ref());
    assert_eq!(owned.transport, transport);
    assert_eq!(owned.from, from);
    assert_eq!(owned.to, to);
    error!("{owned}");
}
1444
/// The `Display` output of a `Transmit` names the transport, both endpoints
/// and the payload length.
#[test]
fn transmit_display() {
    let payload = [0x10, 0x20];
    let transmit = Transmit::new(
        Data::from(payload.as_ref()),
        TransportType::Udp,
        "127.0.0.1:1000".parse().unwrap(),
        "127.0.0.1:2000".parse().unwrap(),
    );
    let expected = String::from("Transmit(UDP: 127.0.0.1:1000 -> 127.0.0.1:2000 of 2 bytes)");
    assert_eq!(alloc::format!("{}", transmit), expected);
}
1458
/// Table of expected `RequestRto::calculate_timeouts` results for 0, 1 and 2
/// retransmits over UDP and TCP. Every case uses a 1 ms initial RTO, an
/// unbounded maximum and a 1 s final retransmit timeout.
#[test]
fn request_retransmits() {
    let _log = crate::tests::test_init_log();
    let ms = Duration::from_millis(1);
    let sec = Duration::from_secs(1);

    // (retransmits, transport, expected per-retransmit timeouts, expected final timeout)
    let cases: [(u32, TransportType, Vec<Duration>, Duration); 6] = [
        (0, TransportType::Udp, vec![], sec),
        (0, TransportType::Tcp, vec![], sec),
        (1, TransportType::Udp, vec![], sec),
        (1, TransportType::Tcp, vec![], sec),
        (2, TransportType::Udp, vec![ms], sec),
        // TCP folds the would-be UDP retransmit wait into the final timeout.
        (2, TransportType::Tcp, vec![], sec + ms),
    ];

    for (retransmits, transport, expected_timeouts, expected_last) in cases {
        let rto = RequestRto {
            initial: ms,
            max: Duration::MAX,
            retransmits,
            last_retransmit: sec,
        };
        let (timeouts, last_transmit_timeout) = rto.calculate_timeouts(transport);
        assert_eq!(timeouts, expected_timeouts);
        assert_eq!(last_transmit_timeout, expected_last);
    }
}
1504}