stun_agent/
lib.rs

1//! STUN Agent library for Rust.
2//!
3//! This crate provides a STUN I/O-free protocol implementation.
4//! An I/O-free protocol implementation, often referred to as a
//! [sans-IO](https://sans-io.readthedocs.io/index.html) implementation, is a
6//! network protocol implementation that contains no code for network I/O or
7//! asynchronous flow control. This means the protocol implementation is agnostic
8//! to the underlying networking stack and can be used in any environment that provides
9//! the necessary network I/O and asynchronous flow control.
10//!
11//! These STUN agents are designed for use in a client-server architecture where the
12//! client sends a request and the server responds.
13//!
14//! This sans-IO protocol implementation is defined entirely in terms of synchronous
15//! functions returning synchronous results, without blocking or waiting for any form
16//! of I/O. This makes it suitable for a wide range of environments, enhancing testing,
17//! flexibility, correctness, re-usability and simplicity.
18//!
19//! This library currently provides support for writing STUN clients. Support for
20//! writing servers is not yet implemented. The main element of this library is:
//! - [`StunClient`](crate::StunClient): The STUN client that sends STUN requests and indications to a STUN server.
22#![deny(missing_docs)]
23
24use std::{ops::Deref, slice::Iter, sync::Arc};
25
26use stun_rs::MessageHeader;
27use stun_rs::StunAttribute;
28use stun_rs::MESSAGE_HEADER_SIZE;
29
30mod client;
31mod events;
32mod fingerprint;
33mod integrity;
34mod lt_cred_mech;
35mod message;
36mod rtt;
37mod st_cred_mech;
38mod timeout;
39
40pub use crate::client::RttConfig;
41pub use crate::client::StunClient;
42pub use crate::client::StunClienteBuilder;
43pub use crate::client::TransportReliability;
44pub use crate::events::StunTransactionError;
45pub use crate::events::StuntClientEvent;
46pub use crate::message::StunAttributes;
47
/// Describes the errors that can occur while the STUN agent operates.
#[derive(Debug, PartialEq, Eq)]
pub enum StunAgentError {
    /// Indicates that the STUN agent has discarded the buffer
    Discarded,
    /// Indicates that the FINGERPRINT attribute of a received STUN packet failed validation
    FingerPrintValidationFailed,
    /// Indicates that the STUN agent has ignored the operation
    Ignored,
    /// Indicates that the STUN agent has reached the maximum number of outstanding requests
    MaxOutstandingRequestsReached,
    /// Indicates that a received STUN packet did not pass the agent's STUN checks
    StunCheckFailed,
    /// Indicates that the STUN agent has detected an internal error; the [`String`] contains
    /// the error message
    InternalError(String),
}
64
/// Describes the kind of integrity protection that can be used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Integrity {
    /// [`MessageIntegrity`](stun_rs::attributes::stun::MessageIntegrity) protection
    MessageIntegrity,
    /// [`MessageIntegritySha256`](stun_rs::attributes::stun::MessageIntegritySha256) protection
    MessageIntegritySha256,
}
73
74/// Describes the kind of credential mechanism that can be used by the STUN agent.
75#[derive(Debug, Clone, Copy, PartialEq, Eq)]
76pub enum CredentialMechanism {
77    /// [Short-term credential mechanism](https://datatracker.ietf.org/doc/html/rfc8489#section-9.1)
78    /// with the specified [`Integrity`] in case the agent knows from an external mechanism
79    /// which message integrity algorithm is supported by both agents.
80    ShortTerm(Option<Integrity>),
81    /// [Long-term credential mechanism](https://datatracker.ietf.org/doc/html/rfc8489#section-9.2)
82    LongTerm,
83}
84
/// Storage shared by every clone of a [`StunPacket`].
///
/// `buffer` may be longer than the packet; only its first `size` bytes belong to the packet.
#[derive(Debug, Clone)]
struct StunPacketInternal {
    buffer: Vec<u8>,
    size: usize,
}

impl StunPacketInternal {
    /// Returns the bytes that belong to the packet (the first `size` bytes of `buffer`).
    fn bytes(&self) -> &[u8] {
        &self.buffer[..self.size]
    }
}

// Equality only looks at the packet bytes. Whatever remains in the backing buffer past
// `size` (for example the unused tail of a decoder buffer) is not part of the packet, so
// two packets with the same bytes must compare equal regardless of that tail.
impl PartialEq for StunPacketInternal {
    fn eq(&self, other: &Self) -> bool {
        self.bytes() == other.bytes()
    }
}

impl Eq for StunPacketInternal {}

/// A chunk of bytes that represents a STUN packet that can be cloned.
///
/// Cloning is cheap: all clones share the same reference-counted storage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StunPacket(Arc<StunPacketInternal>);

impl StunPacket {
    /// Creates a STUN packet from a vector which is filled up to `size` bytes.
    ///
    /// `size` must not exceed `buffer.len()`; otherwise dereferencing the packet panics.
    pub(crate) fn new(buffer: Vec<u8>, size: usize) -> Self {
        debug_assert!(
            size <= buffer.len(),
            "packet size exceeds the length of its buffer"
        );
        let internal = StunPacketInternal { buffer, size };
        StunPacket(Arc::new(internal))
    }
}

impl Deref for StunPacket {
    type Target = [u8];

    /// Exposes only the bytes that belong to the packet.
    fn deref(&self) -> &Self::Target {
        self.0.bytes()
    }
}

impl AsRef<[u8]> for StunPacket {
    fn as_ref(&self) -> &[u8] {
        self
    }
}
116
/// A STUN packet decoder that reassembles a STUN packet from chunks of bytes.
///
/// The [`StunPacketDecoder`] is helpful when reading bytes from a stream-oriented connection,
/// such as a `TCP` stream, or from a datagram-oriented connection, such as a `UDP` socket,
/// when the STUN packet arrives fragmented.
///
/// ```rust
/// # use stun_agent::{StunPacketDecodedValue, StunPacketDecoder};
/// let decoder = StunPacketDecoder::new(vec![0; 1024]).expect("Failed to create decoder");
///
/// // Only part of the 20-byte STUN header has arrived so far.
/// let partial_header = [0x01, 0x01, 0x00, 0x00];
/// match decoder.decode(&partial_header).expect("Failed to decode") {
///     StunPacketDecodedValue::MoreBytesNeeded((_decoder, missing)) => {
///         // The message length is unknown until the header is complete.
///         assert_eq!(missing, None);
///     }
///     StunPacketDecodedValue::Decoded(_) => unreachable!("the header is incomplete"),
/// }
/// ```
#[derive(Debug)]
pub struct StunPacketDecoder {
    /// Storage the packet is reassembled into; its length bounds the packet size.
    buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` filled so far.
    current_size: usize,
    /// Total packet size (header plus body), known once the header has been decoded.
    expected_size: Option<usize>,
}
133
134/// Describes the possible outcomes of the STUN packet decoding.
135/// - If the STUN packet has been fully decoded, the method returns the decoded STUN packet
136/// and the number of bytes consumed.
137/// - If the STUN packet has not been fully decoded, the method returns the decoder and the
138/// number of bytes still needed to complete the STUN packet, if known.
139#[derive(Debug)]
140pub enum StunPacketDecodedValue {
141    /// Returns the decoded STUN packet and the number of bytes consumed from the input.
142    Decoded((StunPacket, usize)),
143    /// Returns the decoder and the number of bytes missing to complete the STUN packet if known.
144    MoreBytesNeeded((StunPacketDecoder, Option<usize>)),
145}
146
/// Kind of error that can occur while decoding a STUN packet.
///
/// The enum is fieldless, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StunPacketErrorType {
    /// The buffer is too small to hold the STUN packet. Either it is shorter than the STUN
    /// header, or it is shorter than the message length announced in the header.
    SmallBuffer,
    /// The buffer does not contain a valid STUN header.
    InvalidStunPacket,
}
155
156/// Describes the error that can occur during the STUN packet decoding.
157#[derive(Debug)]
158pub struct StunPacketDecodedError {
159    /// The type of error that occurred during the STUN packet decoding.
160    pub error_type: StunPacketErrorType,
161    /// The internal buffer filled with bytes.
162    pub buffer: Vec<u8>,
163    /// The size of the buffer that has been filled.
164    pub size: usize,
165    /// The number of bytes consumed from the input data.
166    pub consumed: usize,
167}
168
169impl StunPacketDecoder {
170    /// Creates a new STUN packet decoder using the provided buffer. The buffer must be
171    /// at least 20 bytes long to accommodate the STUN message header. If the buffer is
172    /// too small, an error is returned.
173    pub fn new(buffer: Vec<u8>) -> Result<Self, StunPacketDecodedError> {
174        if buffer.len() < MESSAGE_HEADER_SIZE {
175            return Err(StunPacketDecodedError {
176                error_type: StunPacketErrorType::SmallBuffer,
177                buffer,
178                size: 0,
179                consumed: 0,
180            });
181        }
182        Ok(StunPacketDecoder {
183            buffer,
184            current_size: 0,
185            expected_size: None,
186        })
187    }
188
189    /// Decodes the given data and returns the decoded STUN packet. This method takes the data
190    /// read so far as an argument and returns one of the following outcomes:
191    /// - If the STUN packet has been fully decoded, the method returns the decoded STUN packet
192    /// and the number of bytes consumed.
193    /// - If the STUN packet has not been fully decoded, the method returns the decoder and the
194    /// number of bytes still needed to complete the STUN packet, if known.
195    /// - If the buffer is too small or the header does not correspond to a STUN message, the
196    /// method returns an error.
197    /// Note: This method does not perform a full validation of the STUN message; it only checks
198    /// the header. Integrity checks and other validations will be performed by the STUN agent.
199    pub fn decode(mut self, data: &[u8]) -> Result<StunPacketDecodedValue, StunPacketDecodedError> {
200        match self.expected_size {
201            Some(size) => {
202                // At this point we know that the buffer is big enough to hold the message,
203                // so we do not need to check bounds.
204                let first = self.current_size;
205                let remaining = size - first;
206                if data.len() >= remaining {
207                    // Copy only up to the message length
208                    self.buffer[first..size].copy_from_slice(&data[..remaining]);
209                    let packet = StunPacket::new(self.buffer, size);
210                    Ok(StunPacketDecodedValue::Decoded((packet, remaining)))
211                } else {
212                    // Copy all the data
213                    self.buffer[first..first + data.len()].copy_from_slice(&data[..data.len()]);
214                    self.current_size += data.len();
215                    Ok(StunPacketDecodedValue::MoreBytesNeeded((
216                        self,
217                        Some(remaining - data.len()),
218                    )))
219                }
220            }
221            None => {
222                let header_length = self.current_size + data.len();
223                if header_length >= MESSAGE_HEADER_SIZE {
224                    let first = self.current_size;
225                    let remaining = MESSAGE_HEADER_SIZE - first;
226
227                    // Write the STUN message header
228                    self.buffer[first..first + remaining].copy_from_slice(&data[..remaining]);
229
230                    // We can decode the header now
231                    let slice: &[u8; MESSAGE_HEADER_SIZE] =
232                        self.buffer[..MESSAGE_HEADER_SIZE].try_into().unwrap();
233                    let Ok(header) = MessageHeader::try_from(slice) else {
234                        return Err(StunPacketDecodedError {
235                            error_type: StunPacketErrorType::InvalidStunPacket,
236                            buffer: self.buffer,
237                            size: MESSAGE_HEADER_SIZE,
238                            consumed: remaining,
239                        });
240                    };
241                    let msg_length = header.msg_length as usize;
242
243                    // Check if the buffer provided is big enough to hold the message
244                    if self.buffer.len() < msg_length + MESSAGE_HEADER_SIZE {
245                        return Err(StunPacketDecodedError {
246                            error_type: StunPacketErrorType::SmallBuffer,
247                            buffer: self.buffer,
248                            size: MESSAGE_HEADER_SIZE,
249                            consumed: remaining,
250                        });
251                    }
252
253                    self.expected_size = Some(msg_length + MESSAGE_HEADER_SIZE);
254
255                    if data.len() >= msg_length + remaining {
256                        // Copy only up to the message length
257                        self.buffer[MESSAGE_HEADER_SIZE..MESSAGE_HEADER_SIZE + msg_length]
258                            .copy_from_slice(&data[remaining..remaining + msg_length]);
259                        let packet = StunPacket::new(self.buffer, msg_length + MESSAGE_HEADER_SIZE);
260                        Ok(StunPacketDecodedValue::Decoded((
261                            packet,
262                            remaining + msg_length,
263                        )))
264                    } else {
265                        // Copy all the remaining data
266                        self.buffer
267                            [MESSAGE_HEADER_SIZE..MESSAGE_HEADER_SIZE + data.len() - remaining]
268                            .copy_from_slice(&data[remaining..data.len()]);
269                        self.current_size += data.len();
270                        let remaining = msg_length + MESSAGE_HEADER_SIZE - self.current_size;
271                        Ok(StunPacketDecodedValue::MoreBytesNeeded((
272                            self,
273                            Some(remaining),
274                        )))
275                    }
276                } else {
277                    // The number of bytes is less than the header size, so we can safety copy all
278                    // the data because the minimum size of the byte is 20 bytes.
279                    let first = self.current_size;
280                    let remaining = data.len();
281                    self.buffer[first..first + remaining].copy_from_slice(&data[..remaining]);
282                    self.current_size += data.len();
283
284                    // We still don't know the message length
285                    Ok(StunPacketDecodedValue::MoreBytesNeeded((self, None)))
286                }
287            }
288        }
289    }
290}
291
292#[derive(Debug)]
293struct ProtectedAttributeIteratorObject<'a> {
294    iter: Iter<'a, StunAttribute>,
295    integrity: bool,
296    integrity_sha256: bool,
297    fingerprint: bool,
298}
299
300trait ProtectedAttributeIterator<'a> {
301    fn protected_iter(&self) -> ProtectedAttributeIteratorObject<'a>;
302}
303
304impl<'a> ProtectedAttributeIterator<'a> for &'a [StunAttribute] {
305    fn protected_iter(&self) -> ProtectedAttributeIteratorObject<'a> {
306        ProtectedAttributeIteratorObject {
307            iter: self.iter(),
308            integrity: false,
309            integrity_sha256: false,
310            fingerprint: false,
311        }
312    }
313}
314
315impl<'a> Iterator for ProtectedAttributeIteratorObject<'a> {
316    type Item = &'a StunAttribute;
317
318    fn next(&mut self) -> Option<Self::Item> {
319        for attr in &mut self.iter {
320            if attr.is_message_integrity() {
321                if self.integrity || self.integrity_sha256 || self.fingerprint {
322                    continue;
323                }
324                self.integrity = true;
325            } else if attr.is_message_integrity_sha256() {
326                if self.integrity_sha256 || self.fingerprint {
327                    continue;
328                }
329                self.integrity_sha256 = true;
330            } else if attr.is_fingerprint() {
331                if self.fingerprint {
332                    continue;
333                }
334                self.fingerprint = true;
335            } else if self.integrity || self.integrity_sha256 || self.fingerprint {
336                continue;
337            }
338            return Some(attr);
339        }
340        None
341    }
342}
343
#[cfg(test)]
mod tests_stun_packet {
    use super::*;

    #[test]
    fn test_stun_packet() {
        // A ten-byte buffer of which only the first five bytes belong to the packet.
        let zeroes = vec![0; 10];
        assert_eq!(zeroes.len(), 10);
        let partial = StunPacket::new(zeroes, 5);
        assert_eq!(partial.as_ref().len(), 5);

        // The exposed bytes are exactly the first `size` bytes of the buffer.
        let counting = StunPacket::new((0..10).collect(), 5);
        assert_eq!(counting.len(), 5);
        assert_eq!(counting.as_ref(), &[0, 1, 2, 3, 4]);
    }
}
363
#[cfg(test)]
mod tests_protected_iterator {
    use super::*;
    use stun_rs::{
        attributes::stun::{
            Fingerprint, MessageIntegrity, MessageIntegritySha256, Nonce, Realm, UserName,
        },
        methods::BINDING,
        Algorithm, AlgorithmId, HMACKey, MessageClass, StunMessageBuilder,
    };

    const USERNAME: &str = "test-username";
    const NONCE: &str = "test-nonce";
    const REALM: &str = "test-realm";
    const PASSWORD: &str = "test-password";

    /// Expected attribute: a human-readable name plus a predicate recognising it.
    type Check = (&'static str, fn(&StunAttribute) -> bool);

    const IS_USER_NAME: Check = ("UserName", |a: &StunAttribute| a.is_user_name());
    const IS_NONCE: Check = ("Nonce", |a: &StunAttribute| a.is_nonce());
    const IS_REALM: Check = ("Realm", |a: &StunAttribute| a.is_realm());
    const IS_INTEGRITY: Check = ("MessageIntegrity", |a: &StunAttribute| {
        a.is_message_integrity()
    });
    const IS_INTEGRITY_SHA256: Check = ("MessageIntegritySha256", |a: &StunAttribute| {
        a.is_message_integrity_sha256()
    });
    const IS_FINGERPRINT: Check = ("FingerPrint", |a: &StunAttribute| a.is_fingerprint());

    /// Attributes shared by the long-term credential test cases.
    struct LongTermAttrs {
        username: UserName,
        nonce: Nonce,
        realm: Realm,
        integrity: MessageIntegrity,
        integrity_sha256: MessageIntegritySha256,
    }

    /// Builds the user name, nonce, realm and both integrity attributes from the test
    /// credentials, using an MD5 long-term key.
    fn long_term_attrs() -> LongTermAttrs {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let nonce = Nonce::new(NONCE).expect("Failed to create nonce");
        let realm = Realm::new(REALM).expect("Failed to create realm");
        let algorithm = Algorithm::from(AlgorithmId::MD5);
        let key = HMACKey::new_long_term(&username, &realm, PASSWORD, algorithm)
            .expect("Failed to create HMACKey");
        LongTermAttrs {
            integrity: MessageIntegrity::new(key.clone()),
            integrity_sha256: MessageIntegritySha256::new(key),
            username,
            nonce,
            realm,
        }
    }

    /// Asserts that the protected iterator over `attrs` yields exactly `expected`, in order.
    fn assert_protected_sequence(attrs: &[StunAttribute], expected: &[Check]) {
        let mut iter = attrs.protected_iter();
        for (name, is_expected) in expected {
            let attr = iter
                .next()
                .unwrap_or_else(|| panic!("Expected attribute {name}"));
            assert!(is_expected(attr), "Expected attribute {name}");
        }
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_protected_iterator() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.username)
            .with_attribute(attrs.nonce)
            .with_attribute(attrs.realm)
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.integrity_sha256)
            .with_attribute(Fingerprint::default())
            .build();

        assert_protected_sequence(
            msg.attributes(),
            &[
                IS_USER_NAME,
                IS_NONCE,
                IS_REALM,
                IS_INTEGRITY,
                IS_INTEGRITY_SHA256,
                IS_FINGERPRINT,
            ],
        );
    }

    #[test]
    fn test_protected_iterator_only_message_integrity() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.username)
            .with_attribute(attrs.nonce)
            .with_attribute(attrs.realm)
            .with_attribute(attrs.integrity_sha256)
            .build();

        // Regular attributes after MessageIntegrity are ignored.
        assert_protected_sequence(msg.attributes(), &[IS_INTEGRITY, IS_INTEGRITY_SHA256]);
    }

    #[test]
    fn test_protected_iterator_skip_non_protected() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.username)
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.nonce)
            .with_attribute(attrs.integrity_sha256)
            .with_attribute(attrs.realm)
            .with_attribute(Fingerprint::default())
            .build();

        assert_protected_sequence(
            msg.attributes(),
            &[
                IS_USER_NAME,
                IS_INTEGRITY,
                IS_INTEGRITY_SHA256,
                IS_FINGERPRINT,
            ],
        );
    }

    #[test]
    fn test_protected_iterator_skip_message_integrity() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.username)
            .with_attribute(attrs.integrity_sha256)
            .with_attribute(attrs.nonce)
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.realm)
            .with_attribute(Fingerprint::default())
            .build();

        // MessageIntegrity can not go after MessageIntegritySha256, so it must be skipped.
        assert_protected_sequence(
            msg.attributes(),
            &[IS_USER_NAME, IS_INTEGRITY_SHA256, IS_FINGERPRINT],
        );
    }

    #[test]
    fn test_protected_iterator_skip_message_integrity_sha256() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.username)
            .with_attribute(Fingerprint::default())
            .with_attribute(attrs.nonce)
            .with_attribute(attrs.integrity_sha256)
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.realm)
            .with_attribute(Fingerprint::default())
            .build();

        // All attributes after FingerPrint must be skipped.
        assert_protected_sequence(msg.attributes(), &[IS_USER_NAME, IS_FINGERPRINT]);
    }

    #[test]
    fn test_protected_iterator_skip_duplicated_integrity_attrs() {
        let attrs = long_term_attrs();
        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(attrs.username)
            .with_attribute(attrs.integrity.clone())
            .with_attribute(attrs.integrity)
            .with_attribute(attrs.integrity_sha256.clone())
            .with_attribute(attrs.integrity_sha256)
            .with_attribute(Fingerprint::default())
            .with_attribute(Fingerprint::default())
            .build();

        assert_protected_sequence(
            msg.attributes(),
            &[
                IS_USER_NAME,
                IS_INTEGRITY,
                IS_INTEGRITY_SHA256,
                IS_FINGERPRINT,
            ],
        );
    }

    #[test]
    fn test_protected_iterator_skip_corner_cases() {
        let username = UserName::new(USERNAME).expect("Failed to create username");
        let key = HMACKey::new_short_term(PASSWORD).expect("Failed to create HMACKey");
        let integrity = MessageIntegrity::new(key.clone());
        let integrity_sha256 = MessageIntegritySha256::new(key);

        let msg = StunMessageBuilder::new(BINDING, MessageClass::Request)
            .with_attribute(integrity.clone())
            .with_attribute(integrity.clone())
            .with_attribute(integrity_sha256.clone())
            .with_attribute(integrity.clone())
            .with_attribute(integrity_sha256.clone())
            .with_attribute(Fingerprint::default())
            .with_attribute(integrity)
            .with_attribute(integrity_sha256)
            .with_attribute(Fingerprint::default())
            .with_attribute(username)
            .build();

        assert_protected_sequence(
            msg.attributes(),
            &[IS_INTEGRITY, IS_INTEGRITY_SHA256, IS_FINGERPRINT],
        );
    }
}
633
634#[cfg(test)]
635mod test_stun_packet_decoder {
636    use super::*;
637    use stun_vectors::SAMPLE_IPV4_RESPONSE;
638
639    #[test]
640    fn test_stun_packet_decoder_small_parts() {
641        let buffer = vec![0; 1024];
642        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
643
644        let mut index = 0;
645        let data = &SAMPLE_IPV4_RESPONSE[index..10];
646        let decoded = decoder.decode(data).expect("Failed to decode");
647        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
648            panic!("Expected more bytes needed");
649        };
650        // Message header is not processed, so we have no information about remaining bytes
651        assert_eq!(remaining, None);
652        assert_eq!(decoder.current_size, 10);
653        assert!(decoder.expected_size.is_none());
654
655        index = 10;
656        let data = &SAMPLE_IPV4_RESPONSE[index..15];
657        let decoded = decoder.decode(data).expect("Failed to decode");
658        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
659            panic!("Expected more bytes needed");
660        };
661        // Message header is not processed, so we have no information about remaining bytes
662        assert_eq!(remaining, None);
663        assert_eq!(decoder.current_size, 15);
664        assert!(decoder.expected_size.is_none());
665        assert_eq!(decoder.buffer[..15], SAMPLE_IPV4_RESPONSE[..15]);
666
667        index = 15;
668        let data = &SAMPLE_IPV4_RESPONSE[index..index + 5];
669        let decoded = decoder.decode(data).expect("Failed to decode");
670        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
671            panic!("Expected more bytes needed");
672        };
673        // Header is processed and the msg length is 60 (0x3C)
674        assert_eq!(remaining, Some(60));
675        assert_eq!(decoder.current_size, 20);
676        assert_eq!(decoder.expected_size, Some(60 + MESSAGE_HEADER_SIZE));
677        assert_eq!(decoder.buffer[..20], SAMPLE_IPV4_RESPONSE[..20]);
678
679        index = 20;
680        let data = &SAMPLE_IPV4_RESPONSE[index..index + 30];
681        let decoded = decoder.decode(data).expect("Failed to decode");
682        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
683            panic!("Expected more bytes needed");
684        };
685        assert_eq!(remaining, Some(30));
686        assert_eq!(decoder.current_size, 50);
687        assert_eq!(decoder.buffer[..50], SAMPLE_IPV4_RESPONSE[..50]);
688
689        index = 50;
690        let data = &SAMPLE_IPV4_RESPONSE[index..index + 29];
691        let decoded = decoder.decode(data).expect("Failed to decode");
692        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
693            panic!("Expected more bytes needed");
694        };
695        assert_eq!(remaining, Some(1));
696        assert_eq!(decoder.current_size, 79);
697        assert_eq!(decoder.buffer[..79], SAMPLE_IPV4_RESPONSE[..79]);
698
699        // Complete the byte remaining to complete the STUN packet
700        index = 79;
701        let data = &SAMPLE_IPV4_RESPONSE[index..index + 1];
702        let decoded = decoder.decode(data).expect("Failed to decode");
703        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
704            panic!("Stun packed not decoded");
705        };
706        assert_eq!(consumed, 1);
707        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
708    }
709
710    #[test]
711    fn test_stun_packet_decoder_one_step() {
712        let buffer = vec![0; 1024];
713        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
714
715        // Read the buffer in one go
716        let decoded = decoder
717            .decode(&SAMPLE_IPV4_RESPONSE)
718            .expect("Failed to decode");
719        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
720            panic!("Stun packed not decoded");
721        };
722        assert_eq!(consumed, SAMPLE_IPV4_RESPONSE.len());
723        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
724    }
725
726    #[test]
727    fn test_stun_packet_decoder_two_step() {
728        let buffer = vec![0; 1024];
729        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
730
731        let data = &SAMPLE_IPV4_RESPONSE[..15];
732        let decoded = decoder.decode(data).expect("Failed to decode");
733        let StunPacketDecodedValue::MoreBytesNeeded((decoder, remaining)) = decoded else {
734            panic!("Expected more bytes needed");
735        };
736        // Message header is not processed, so we have no information about remaining bytes
737        assert_eq!(remaining, None);
738        assert_eq!(decoder.current_size, 15);
739        assert!(decoder.expected_size.is_none());
740
741        // Read the rest of the packet
742        let data = &SAMPLE_IPV4_RESPONSE[15..];
743        let decoded = decoder.decode(data).expect("Failed to decode");
744        let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
745            panic!("Stun packed not decoded");
746        };
747        assert_eq!(consumed, data.len());
748        assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
749    }
750
751    #[test]
752    fn test_stun_packet_decoder_byte_by_byte() {
753        let buffer = vec![0; 1024];
754        let mut decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
755
756        let total = SAMPLE_IPV4_RESPONSE.len();
757        for index in 0..total {
758            let data = &SAMPLE_IPV4_RESPONSE[index..index + 1];
759            let decoded = decoder.decode(data).expect("Failed to decode");
760            if index < total - 1 {
761                let StunPacketDecodedValue::MoreBytesNeeded((deco, remaining)) = decoded else {
762                    panic!("Expected more bytes needed");
763                };
764                if index >= MESSAGE_HEADER_SIZE - 1 {
765                    assert_eq!(remaining, Some(total - 1 - index));
766                } else {
767                    assert_eq!(remaining, None);
768                }
769                decoder = deco;
770            } else {
771                let StunPacketDecodedValue::Decoded((packet, consumed)) = decoded else {
772                    panic!("Stun packed not decoded");
773                };
774                assert_eq!(consumed, 1);
775                assert_eq!(&SAMPLE_IPV4_RESPONSE, packet.as_ref());
776                break;
777            }
778        }
779    }
780
781    #[test]
782    fn test_stun_packet_decoder_small_buffer() {
783        let buffer = vec![0; 10];
784        let error = StunPacketDecoder::new(buffer).expect_err("Expected small buffer error");
785        let StunPacketErrorType::SmallBuffer = error.error_type else {
786            panic!("Expected small buffer error");
787        };
788
789        let buffer = vec![0; 50];
790        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
791
792        let result = decoder
793            .decode(&SAMPLE_IPV4_RESPONSE[..10])
794            .expect("Failed to decode");
795        // We could not read the whole header, so it won't fail
796        let StunPacketDecodedValue::MoreBytesNeeded((decoder, None)) = result else {
797            panic!("Expected more bytes needed");
798        };
799
800        let error = decoder
801            .decode(&SAMPLE_IPV4_RESPONSE[10..])
802            .expect_err("Expected error");
803        // The header is read and the buffer is too small to hold the whole message
804        let StunPacketErrorType::SmallBuffer = error.error_type else {
805            panic!("Expected small buffer error");
806        };
807
808        // Test the same scenario but trying to decode the buffer in one go
809        let buffer = vec![0; 50];
810        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
811        let error = decoder
812            .decode(&SAMPLE_IPV4_RESPONSE)
813            .expect_err("Expected error");
814        let StunPacketErrorType::SmallBuffer = error.error_type else {
815            panic!("Expected small buffer error");
816        };
817    }
818
819    #[test]
820    fn test_stun_packet_decoder_invalid_stun_packet() {
821        let buffer = vec![0; 1024];
822        let decoder = StunPacketDecoder::new(buffer).expect("Failed to create decoder");
823
824        let data = vec![0; 1024];
825        let error = decoder.decode(&data).expect_err("Expected error");
826        let StunPacketErrorType::InvalidStunPacket = error.error_type else {
827            panic!("Expected invalid STUN packet error");
828        };
829    }
830}