rtc_stun/
message.rs

1#[cfg(test)]
2mod message_test;
3
4use crate::attributes::*;
5use shared::error::*;
6
7use base64::prelude::*;
8use rand::Rng;
9use std::fmt;
10use std::io::{Read, Write};
11
/// Fixed value that helps distinguish STUN packets from packets of other
/// protocols when STUN is multiplexed with them on the same port.
///
/// The magic cookie field MUST contain the fixed value `0x2112A442` in
/// network byte order (see "STUN Message Structure", section 6).
pub const MAGIC_COOKIE: u32 = 0x2112_A442;

/// Size in bytes of an attribute header: a 2-byte type followed by a 2-byte length.
pub const ATTRIBUTE_HEADER_SIZE: usize = 4;

/// Size in bytes of the fixed STUN message header.
pub const MESSAGE_HEADER_SIZE: usize = 20;

/// Length of the transaction ID array in bytes (96 bits).
pub const TRANSACTION_ID_SIZE: usize = 12;
26
27#[derive(PartialEq, Eq, Hash, Copy, Clone, Default, Debug)]
28pub struct TransactionId(pub [u8; TRANSACTION_ID_SIZE]);
29
30impl TransactionId {
31    /// new returns new random transaction ID using crypto/rand
32    /// as source.
33    pub fn new() -> Self {
34        let mut b = TransactionId([0u8; TRANSACTION_ID_SIZE]);
35        rand::thread_rng().fill(&mut b.0);
36        b
37    }
38}
39
40impl Setter for TransactionId {
41    fn add_to(&self, m: &mut Message) -> Result<()> {
42        m.transaction_id = *self;
43        m.write_transaction_id();
44        Ok(())
45    }
46}
47
48// Interfaces that are implemented by message attributes, shorthands for them,
49// or helpers for message fields as type or transaction id.
50pub trait Setter {
51    // Setter sets *Message attribute.
52    fn add_to(&self, m: &mut Message) -> Result<()>;
53}
54
55// Getter parses attribute from *Message.
56pub trait Getter {
57    fn get_from(&mut self, m: &Message) -> Result<()>;
58}
59
60// Checker checks *Message attribute.
61pub trait Checker {
62    fn check(&self, m: &Message) -> Result<()>;
63}
64
65// is_message returns true if b looks like STUN message.
66// Useful for multiplexing. is_message does not guarantee
67// that decoding will be successful.
68pub fn is_message(b: &[u8]) -> bool {
69    b.len() >= MESSAGE_HEADER_SIZE && u32::from_be_bytes([b[4], b[5], b[6], b[7]]) == MAGIC_COOKIE
70}
71// Message represents a single STUN packet. It uses aggressive internal
72// buffering to enable zero-allocation encoding and decoding,
73// so there are some usage constraints:
74//
75// 	Message, its fields, results of m.Get or any attribute a.GetFrom
76//	are valid only until Message.Raw is not modified.
77#[derive(Default, Debug, Clone)]
78pub struct Message {
79    pub typ: MessageType,
80    pub length: u32, // len(Raw) not including header
81    pub transaction_id: TransactionId,
82    pub attributes: Attributes,
83    pub raw: Vec<u8>,
84}
85
86impl fmt::Display for Message {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        let t_id = BASE64_STANDARD.encode(self.transaction_id.0);
89        write!(
90            f,
91            "{} l={} attrs={} id={}",
92            self.typ,
93            self.length,
94            self.attributes.0.len(),
95            t_id
96        )
97    }
98}
99
100// Equal returns true if Message b equals to m.
101// Ignores m.Raw.
102impl PartialEq for Message {
103    fn eq(&self, other: &Self) -> bool {
104        if self.typ != other.typ {
105            return false;
106        }
107        if self.transaction_id != other.transaction_id {
108            return false;
109        }
110        if self.length != other.length {
111            return false;
112        }
113        if self.attributes != other.attributes {
114            return false;
115        }
116        true
117    }
118}
119
/// Capacity reserved for `Message::raw` by `Message::new`; also used to size
/// the read buffer in `Message::read_from`.
const DEFAULT_RAW_CAPACITY: usize = 120;
121
122impl Setter for Message {
123    // add_to sets b.TransactionID to m.TransactionID.
124    //
125    // Implements Setter to aid in crafting responses.
126    fn add_to(&self, b: &mut Message) -> Result<()> {
127        b.transaction_id = self.transaction_id;
128        b.write_transaction_id();
129        Ok(())
130    }
131}
132
133impl Message {
134    // New returns *Message with pre-allocated Raw.
135    pub fn new() -> Self {
136        Message {
137            raw: {
138                let mut raw = Vec::with_capacity(DEFAULT_RAW_CAPACITY);
139                raw.extend_from_slice(&[0; MESSAGE_HEADER_SIZE]);
140                raw
141            },
142            ..Default::default()
143        }
144    }
145
146    // marshal_binary implements the encoding.BinaryMarshaler interface.
147    pub fn marshal_binary(&self) -> Result<Vec<u8>> {
148        // We can't return m.Raw, allocation is expected by implicit interface
149        // contract induced by other implementations.
150        Ok(self.raw.clone())
151    }
152
153    // unmarshal_binary implements the encoding.BinaryUnmarshaler interface.
154    pub fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> {
155        // We can't retain data, copy is expected by interface contract.
156        self.raw.clear();
157        self.raw.extend_from_slice(data);
158        self.decode()
159    }
160
161    // NewTransactionID sets m.TransactionID to random value from crypto/rand
162    // and returns error if any.
163    pub fn new_transaction_id(&mut self) -> Result<()> {
164        rand::thread_rng().fill(&mut self.transaction_id.0);
165        self.write_transaction_id();
166        Ok(())
167    }
168
169    // Reset resets Message, attributes and underlying buffer length.
170    pub fn reset(&mut self) {
171        self.raw.clear();
172        self.length = 0;
173        self.attributes.0.clear();
174    }
175
176    // grow ensures that internal buffer has n length.
177    fn grow(&mut self, n: usize, resize: bool) {
178        if self.raw.len() >= n {
179            if resize {
180                self.raw.resize(n, 0);
181            }
182            return;
183        }
184        self.raw.extend_from_slice(&vec![0; n - self.raw.len()]);
185    }
186
187    // Add appends new attribute to message. Not goroutine-safe.
188    //
189    // Value of attribute is copied to internal buffer so
190    // it is safe to reuse v.
191    pub fn add(&mut self, t: AttrType, v: &[u8]) {
192        // Allocating buffer for TLV (type-length-value).
193        // T = t, L = len(v), V = v.
194        // m.Raw will look like:
195        // [0:20]                               <- message header
196        // [20:20+m.Length]                     <- existing message attributes
197        // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
198        // [first:last]                         <- same as previous
199        // [0 1|2 3|4    4 + len(v)]            <- mapping for allocated buffer
200        //   T   L        V
201        let alloc_size = ATTRIBUTE_HEADER_SIZE + v.len(); // ~ len(TLV) = len(TL) + len(V)
202        let first = MESSAGE_HEADER_SIZE + self.length as usize; // first byte number
203        let mut last = first + alloc_size; // last byte number
204        self.grow(last, true); // growing cap(Raw) to fit TLV
205        self.length += alloc_size as u32; // rendering length change
206
207        // Encoding attribute TLV to allocated buffer.
208        let buf = &mut self.raw[first..last];
209        buf[0..2].copy_from_slice(&t.value().to_be_bytes()); // T
210        buf[2..4].copy_from_slice(&(v.len() as u16).to_be_bytes()); // L
211
212        let value = &mut buf[ATTRIBUTE_HEADER_SIZE..];
213        value.copy_from_slice(v); // V
214
215        let attr = RawAttribute {
216            typ: t,                 // T
217            length: v.len() as u16, // L
218            value: value.to_vec(),  // V
219        };
220
221        // Checking that attribute value needs padding.
222        if attr.length as usize % PADDING != 0 {
223            // Performing padding.
224            let bytes_to_add = nearest_padded_value_length(v.len()) - v.len();
225            last += bytes_to_add;
226            self.grow(last, true);
227            // setting all padding bytes to zero
228            // to prevent data leak from previous
229            // data in next bytes_to_add bytes
230            let buf = &mut self.raw[last - bytes_to_add..last];
231            for b in buf {
232                *b = 0;
233            }
234            self.length += bytes_to_add as u32; // rendering length change
235        }
236        self.attributes.0.push(attr);
237        self.write_length();
238    }
239
240    // WriteLength writes m.Length to m.Raw.
241    pub fn write_length(&mut self) {
242        self.grow(4, false);
243        self.raw[2..4].copy_from_slice(&(self.length as u16).to_be_bytes());
244    }
245
246    // WriteHeader writes header to underlying buffer. Not goroutine-safe.
247    pub fn write_header(&mut self) {
248        self.grow(MESSAGE_HEADER_SIZE, false);
249
250        self.write_type();
251        self.write_length();
252        self.raw[4..8].copy_from_slice(&MAGIC_COOKIE.to_be_bytes()); // magic cookie
253        self.raw[8..MESSAGE_HEADER_SIZE].copy_from_slice(&self.transaction_id.0);
254        // transaction ID
255    }
256
257    // WriteTransactionID writes m.TransactionID to m.Raw.
258    pub fn write_transaction_id(&mut self) {
259        self.raw[8..MESSAGE_HEADER_SIZE].copy_from_slice(&self.transaction_id.0);
260        // transaction ID
261    }
262
263    // WriteAttributes encodes all m.Attributes to m.
264    pub fn write_attributes(&mut self) {
265        let attributes: Vec<RawAttribute> = self.attributes.0.drain(..).collect();
266        for a in &attributes {
267            self.add(a.typ, &a.value);
268        }
269        self.attributes = Attributes(attributes);
270    }
271
272    // WriteType writes m.Type to m.Raw.
273    pub fn write_type(&mut self) {
274        self.grow(2, false);
275        self.raw[..2].copy_from_slice(&self.typ.value().to_be_bytes()); // message type
276    }
277
278    // SetType sets m.Type and writes it to m.Raw.
279    pub fn set_type(&mut self, t: MessageType) {
280        self.typ = t;
281        self.write_type();
282    }
283
284    // Encode re-encodes message into m.Raw.
285    pub fn encode(&mut self) {
286        self.raw.clear();
287        self.write_header();
288        self.length = 0;
289        self.write_attributes();
290    }
291
292    // Decode decodes m.Raw into m.
293    pub fn decode(&mut self) -> Result<()> {
294        // decoding message header
295        let buf = &self.raw;
296        if buf.len() < MESSAGE_HEADER_SIZE {
297            return Err(Error::ErrUnexpectedHeaderEof);
298        }
299
300        let t = u16::from_be_bytes([buf[0], buf[1]]); // first 2 bytes
301        let size = u16::from_be_bytes([buf[2], buf[3]]) as usize; // second 2 bytes
302        let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); // last 4 bytes
303        let full_size = MESSAGE_HEADER_SIZE + size; // len(m.Raw)
304
305        if cookie != MAGIC_COOKIE {
306            return Err(Error::Other(format!(
307                "{cookie:x} is invalid magic cookie (should be {MAGIC_COOKIE:x})"
308            )));
309        }
310        if buf.len() < full_size {
311            return Err(Error::Other(format!(
312                "buffer length {} is less than {} (expected message size)",
313                buf.len(),
314                full_size
315            )));
316        }
317
318        // saving header data
319        self.typ.read_value(t);
320        self.length = size as u32;
321        self.transaction_id
322            .0
323            .copy_from_slice(&buf[8..MESSAGE_HEADER_SIZE]);
324
325        self.attributes.0.clear();
326        let mut offset = 0;
327        let mut b = &buf[MESSAGE_HEADER_SIZE..full_size];
328
329        while offset < size {
330            // checking that we have enough bytes to read header
331            if b.len() < ATTRIBUTE_HEADER_SIZE {
332                return Err(Error::Other(format!(
333                    "buffer length {} is less than {} (expected header size)",
334                    b.len(),
335                    ATTRIBUTE_HEADER_SIZE
336                )));
337            }
338
339            let mut a = RawAttribute {
340                typ: compat_attr_type(u16::from_be_bytes([b[0], b[1]])), // first 2 bytes
341                length: u16::from_be_bytes([b[2], b[3]]),                // second 2 bytes
342                ..Default::default()
343            };
344            let a_l = a.length as usize; // attribute length
345            let a_buff_l = nearest_padded_value_length(a_l); // expected buffer length (with padding)
346
347            b = &b[ATTRIBUTE_HEADER_SIZE..]; // slicing again to simplify value read
348            offset += ATTRIBUTE_HEADER_SIZE;
349            if b.len() < a_buff_l {
350                // checking size
351                return Err(Error::Other(format!(
352                    "buffer length {} is less than {} (expected value size for {})",
353                    b.len(),
354                    a_buff_l,
355                    a.typ
356                )));
357            }
358            a.value = b[..a_l].to_vec();
359            offset += a_buff_l;
360            b = &b[a_buff_l..];
361
362            self.attributes.0.push(a);
363        }
364
365        Ok(())
366    }
367
368    // WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
369    // call result.
370    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<usize> {
371        let n = writer.write(&self.raw)?;
372        Ok(n)
373    }
374
375    // ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
376    // Decodes it and return error if any. If m.Raw is too small, will return
377    // ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
378    //
379    // Can return *DecodeErr while decoding too.
380    pub fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<usize> {
381        let mut t_buf = vec![0; DEFAULT_RAW_CAPACITY];
382        let n = reader.read(&mut t_buf)?;
383        self.raw = t_buf[..n].to_vec();
384        self.decode()?;
385        Ok(n)
386    }
387
388    // Write decodes message and return error if any.
389    //
390    // Any error is unrecoverable, but message could be partially decoded.
391    pub fn write(&mut self, t_buf: &[u8]) -> Result<usize> {
392        self.raw.clear();
393        self.raw.extend_from_slice(t_buf);
394        self.decode()?;
395        Ok(t_buf.len())
396    }
397
398    // CloneTo clones m to b securing any further m mutations.
399    pub fn clone_to(&self, b: &mut Message) -> Result<()> {
400        b.raw.clear();
401        b.raw.extend_from_slice(&self.raw);
402        b.decode()
403    }
404
405    // Contains return true if message contain t attribute.
406    pub fn contains(&self, t: AttrType) -> bool {
407        for a in &self.attributes.0 {
408            if a.typ == t {
409                return true;
410            }
411        }
412        false
413    }
414
415    // get returns byte slice that represents attribute value,
416    // if there is no attribute with such type,
417    // ErrAttributeNotFound is returned.
418    pub fn get(&self, t: AttrType) -> Result<Vec<u8>> {
419        let (v, ok) = self.attributes.get(t);
420        if ok {
421            Ok(v.value)
422        } else {
423            Err(Error::ErrAttributeNotFound)
424        }
425    }
426
427    // Build resets message and applies setters to it in batch, returning on
428    // first error. To prevent allocations, pass pointers to values.
429    //
430    // Example:
431    //  var (
432    //  	t        = BindingRequest
433    //  	username = NewUsername("username")
434    //  	nonce    = NewNonce("nonce")
435    //  	realm    = NewRealm("example.org")
436    //  )
437    //  m := new(Message)
438    //  m.Build(t, username, nonce, realm)     // 4 allocations
439    //  m.Build(&t, &username, &nonce, &realm) // 0 allocations
440    //
441    // See BenchmarkBuildOverhead.
442    pub fn build(&mut self, setters: &[Box<dyn Setter>]) -> Result<()> {
443        self.reset();
444        self.write_header();
445        for s in setters {
446            s.add_to(self)?;
447        }
448        Ok(())
449    }
450
451    // Check applies checkers to message in batch, returning on first error.
452    pub fn check<C: Checker>(&self, checkers: &[C]) -> Result<()> {
453        for c in checkers {
454            c.check(self)?;
455        }
456        Ok(())
457    }
458
459    // Parse applies getters to message in batch, returning on first error.
460    pub fn parse<G: Getter>(&self, getters: &mut [G]) -> Result<()> {
461        for c in getters {
462            c.get_from(self)?;
463        }
464        Ok(())
465    }
466}
467
/// 8-bit holder for the 2-bit class field of a STUN message type.
#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)]
pub struct MessageClass(u8);

/// Request class (`0b00`).
pub const CLASS_REQUEST: MessageClass = MessageClass(0b00);
/// Indication class (`0b01`).
pub const CLASS_INDICATION: MessageClass = MessageClass(0b01);
/// Success response class (`0b10`).
pub const CLASS_SUCCESS_RESPONSE: MessageClass = MessageClass(0b10);
/// Error response class (`0b11`).
pub const CLASS_ERROR_RESPONSE: MessageClass = MessageClass(0b11);

impl fmt::Display for MessageClass {
    /// Writes a lowercase name for the class, or "unknown message class" for
    /// any value other than the four constants above.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            CLASS_REQUEST => "request",
            CLASS_INDICATION => "indication",
            CLASS_SUCCESS_RESPONSE => "success response",
            CLASS_ERROR_RESPONSE => "error response",
            _ => "unknown message class",
        };
        f.write_str(label)
    }
}
491
/// 12-bit STUN method, stored in a `u16`.
#[derive(Default, PartialEq, Eq, Debug, Copy, Clone)]
pub struct Method(u16);

// Possible methods for a STUN message.
pub const METHOD_BINDING: Method = Method(0x001);
pub const METHOD_ALLOCATE: Method = Method(0x003);
pub const METHOD_REFRESH: Method = Method(0x004);
pub const METHOD_SEND: Method = Method(0x006);
pub const METHOD_DATA: Method = Method(0x007);
pub const METHOD_CREATE_PERMISSION: Method = Method(0x008);
pub const METHOD_CHANNEL_BIND: Method = Method(0x009);

// Methods from RFC 6062.
pub const METHOD_CONNECT: Method = Method(0x000a);
pub const METHOD_CONNECTION_BIND: Method = Method(0x000b);
pub const METHOD_CONNECTION_ATTEMPT: Method = Method(0x000c);

impl fmt::Display for Method {
    /// Writes the method's name (e.g. "Binding"). A method that matches none
    /// of the constants above is written as lowercase hex with a `0x` prefix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            METHOD_BINDING => "Binding",
            METHOD_ALLOCATE => "Allocate",
            METHOD_REFRESH => "Refresh",
            METHOD_SEND => "Send",
            METHOD_DATA => "Data",
            METHOD_CREATE_PERMISSION => "CreatePermission",
            METHOD_CHANNEL_BIND => "ChannelBind",

            // RFC 6062.
            METHOD_CONNECT => "Connect",
            METHOD_CONNECTION_BIND => "ConnectionBind",
            METHOD_CONNECTION_ATTEMPT => "ConnectionAttempt",

            // Only unknown methods need formatting, so known names are
            // written without building a temporary String.
            _ => return write!(f, "0x{:x}", self.0),
        };
        f.write_str(name)
    }
}
533
534// MessageType is STUN Message Type Field.
535#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
536pub struct MessageType {
537    pub method: Method,      // e.g. binding
538    pub class: MessageClass, // e.g. request
539}
540
541// Common STUN message types.
542// Binding request message type.
543pub const BINDING_REQUEST: MessageType = MessageType {
544    method: METHOD_BINDING,
545    class: CLASS_REQUEST,
546};
547// Binding success response message type
548pub const BINDING_SUCCESS: MessageType = MessageType {
549    method: METHOD_BINDING,
550    class: CLASS_SUCCESS_RESPONSE,
551};
552// Binding error response message type.
553pub const BINDING_ERROR: MessageType = MessageType {
554    method: METHOD_BINDING,
555    class: CLASS_ERROR_RESPONSE,
556};
557
558impl fmt::Display for MessageType {
559    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
560        write!(f, "{} {}", self.method, self.class)
561    }
562}
563
// Masks that split the 12-bit method into the three groups interleaved with
// the class bits in the message type field.
const METHOD_ABITS: u16 = 0b0000_0000_0000_1111; // A: M0-M3
const METHOD_BBITS: u16 = 0b0000_0000_0111_0000; // B: M4-M6
const METHOD_DBITS: u16 = 0b0000_1111_1000_0000; // D: M7-M11

// Shift applied to groups B and D (left when encoding, right when decoding)
// to open the holes for the class bits.
const METHOD_BSHIFT: u16 = 1;
const METHOD_DSHIFT: u16 = 2;

const FIRST_BIT: u16 = 0b01;
const SECOND_BIT: u16 = 0b10;

// Masks for the two class bits C0 and C1.
const C0BIT: u16 = FIRST_BIT;
const C1BIT: u16 = SECOND_BIT;

// Shifts moving C0 (bit 0) to bit 4 and C1 (bit 1) to bit 8.
const CLASS_C0SHIFT: u16 = 4;
const CLASS_C1SHIFT: u16 = 7;
579
580impl Setter for MessageType {
581    // add_to sets m type to t.
582    fn add_to(&self, m: &mut Message) -> Result<()> {
583        m.set_type(*self);
584        Ok(())
585    }
586}
587
588impl MessageType {
589    // NewType returns new message type with provided method and class.
590    pub fn new(method: Method, class: MessageClass) -> Self {
591        MessageType { method, class }
592    }
593
594    // Value returns bit representation of messageType.
595    pub fn value(&self) -> u16 {
596        //	 0                 1
597        //	 2  3  4 5 6 7 8 9 0 1 2 3 4 5
598        //	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
599        //	|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
600        //	|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
601        //	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
602        // Figure 3: Format of STUN Message Type Field
603
604        // Warning: Abandon all hope ye who enter here.
605        // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
606        let method = self.method.0;
607        let a = method & METHOD_ABITS; // A = M * 0b0000000000001111 (right 4 bits)
608        let b = method & METHOD_BBITS; // B = M * 0b0000000001110000 (3 bits after A)
609        let d = method & METHOD_DBITS; // D = M * 0b0000111110000000 (5 bits after B)
610
611        // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
612        let method = a + (b << METHOD_BSHIFT) + (d << METHOD_DSHIFT);
613
614        // C0 is zero bit of C, C1 is first bit.
615        // C0 = C * 0b01, C1 = (C * 0b10) >> 1
616        // Ct = C0 << 4 + C1 << 8.
617        // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
618        // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
619        // (see figure 3).
620        let c = self.class.0 as u16;
621        let c0 = (c & C0BIT) << CLASS_C0SHIFT;
622        let c1 = (c & C1BIT) << CLASS_C1SHIFT;
623        let class = c0 + c1;
624
625        method + class
626    }
627
628    // ReadValue decodes uint16 into MessageType.
629    pub fn read_value(&mut self, value: u16) {
630        // Decoding class.
631        // We are taking first bit from v >> 4 and second from v >> 7.
632        let c0 = (value >> CLASS_C0SHIFT) & C0BIT;
633        let c1 = (value >> CLASS_C1SHIFT) & C1BIT;
634        let class = c0 + c1;
635        self.class = MessageClass(class as u8);
636
637        // Decoding method.
638        let a = value & METHOD_ABITS; // A(M0-M3)
639        let b = (value >> METHOD_BSHIFT) & METHOD_BBITS; // B(M4-M6)
640        let d = (value >> METHOD_DSHIFT) & METHOD_DBITS; // D(M7-M11)
641        let m = a + b + d;
642        self.method = Method(m);
643    }
644}