msf_rtp/
rtp.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use crate::InvalidInput;
4
/// Fixed 12-byte part of an RTP header, laid out exactly as on the wire.
///
/// Every field holds its value in network byte order. The struct is
/// reinterpreted to and from raw bytes, so its memory layout must match the
/// wire format byte for byte:
///
/// * `repr(C)` pins the declaration order of the fields (the default Rust
///   representation leaves field order unspecified, even when packed);
/// * `packed` removes all padding and lowers the alignment to 1, so the
///   struct can be read from an arbitrary byte address.
#[repr(C, packed)]
struct RawRtpHeader {
    /// Version, padding, extension, CSRC count, marker and payload type bits.
    options: u16,
    /// Sequence number.
    sequence_number: u16,
    /// RTP timestamp.
    timestamp: u32,
    /// Synchronization source identifier.
    ssrc: u32,
}
13
14/// RTP header.
15#[derive(Clone)]
16pub struct RtpHeader {
17    options: u16,
18    sequence_number: u16,
19    timestamp: u32,
20    ssrc: u32,
21    csrcs: Vec<u32>,
22    extension: Option<RtpHeaderExtension>,
23}
24
25impl RtpHeader {
26    /// Create a new RTP header.
27    #[inline]
28    pub const fn new() -> Self {
29        Self {
30            options: 2 << 14,
31            sequence_number: 0,
32            timestamp: 0,
33            ssrc: 0,
34            csrcs: Vec::new(),
35            extension: None,
36        }
37    }
38
39    /// Decode an RTP header from given data.
40    pub fn decode(data: &mut Bytes) -> Result<Self, InvalidInput> {
41        let mut buffer = data.clone();
42
43        if buffer.len() < std::mem::size_of::<RawRtpHeader>() {
44            return Err(InvalidInput);
45        }
46
47        let ptr = buffer.as_ptr() as *const RawRtpHeader;
48        let raw = unsafe { &*ptr };
49
50        let mut res = Self {
51            options: u16::from_be(raw.options),
52            sequence_number: u16::from_be(raw.sequence_number),
53            timestamp: u32::from_be(raw.timestamp),
54            ssrc: u32::from_be(raw.ssrc),
55            csrcs: Vec::new(),
56            extension: None,
57        };
58
59        buffer.advance(std::mem::size_of::<RawRtpHeader>());
60
61        if (res.options >> 14) != 2 {
62            return Err(InvalidInput);
63        }
64
65        let csrc_count = ((res.options >> 8) & 0xf) as usize;
66
67        if buffer.len() < (csrc_count << 2) {
68            return Err(InvalidInput);
69        }
70
71        res.csrcs = Vec::with_capacity(csrc_count);
72
73        for _ in 0..csrc_count {
74            res.csrcs.push(buffer.get_u32());
75        }
76
77        if (res.options & 0x1000) != 0 {
78            res.extension = Some(RtpHeaderExtension::decode(&mut buffer)?);
79        }
80
81        *data = buffer;
82
83        Ok(res)
84    }
85
86    /// Encode the header.
87    pub fn encode(&self, buf: &mut BytesMut) {
88        buf.reserve(self.raw_size());
89
90        let raw = RawRtpHeader {
91            options: self.options.to_be(),
92            sequence_number: self.sequence_number.to_be(),
93            timestamp: self.timestamp.to_be(),
94            ssrc: self.ssrc.to_be(),
95        };
96
97        let ptr = &raw as *const _ as *const u8;
98
99        let data = unsafe { std::slice::from_raw_parts(ptr, std::mem::size_of::<RawRtpHeader>()) };
100
101        buf.extend_from_slice(data);
102
103        for csrc in &self.csrcs {
104            buf.put_u32(*csrc);
105        }
106
107        if let Some(extension) = self.extension.as_ref() {
108            extension.encode(buf);
109        }
110    }
111
112    /// Check if the RTP packet contains any padding.
113    #[inline]
114    pub fn padding(&self) -> bool {
115        (self.options & 0x2000) != 0
116    }
117
118    /// Set the padding bit.
119    #[inline]
120    pub fn with_padding(mut self, padding: bool) -> Self {
121        self.options &= !0x2000;
122        self.options |= (padding as u16) << 13;
123        self
124    }
125
126    /// Check if there is an RTP header extension.
127    #[inline]
128    pub fn extension(&self) -> Option<&RtpHeaderExtension> {
129        self.extension.as_ref()
130    }
131
132    /// Set the extension bit.
133    #[inline]
134    pub fn with_extension(mut self, extension: Option<RtpHeaderExtension>) -> Self {
135        self.options &= !0x1000;
136        self.options |= (extension.is_some() as u16) << 12;
137        self.extension = extension;
138        self
139    }
140
141    /// Check if the RTP marker bit is set.
142    #[inline]
143    pub fn marker(&self) -> bool {
144        (self.options & 0x0080) != 0
145    }
146
147    /// Set the marker bit.
148    #[inline]
149    pub fn with_marker(mut self, marker: bool) -> Self {
150        self.options &= !0x0080;
151        self.options |= (marker as u16) << 7;
152        self
153    }
154
155    /// Get RTP payload type.
156    ///
157    /// Note: Only the lower 7 bits are used.
158    #[inline]
159    pub fn payload_type(&self) -> u8 {
160        (self.options & 0x7f) as u8
161    }
162
163    /// Set the payload type.
164    ///
165    /// # Panics
166    /// The method panics if the payload type is greater than 127.
167    #[inline]
168    pub fn with_payload_type(mut self, payload_type: u8) -> Self {
169        assert!(payload_type < 128);
170
171        self.options &= !0x7f;
172        self.options |= (payload_type & 0x7f) as u16;
173        self
174    }
175
176    /// Get RTP sequence number.
177    #[inline]
178    pub fn sequence_number(&self) -> u16 {
179        self.sequence_number
180    }
181
182    /// Set the sequence number.
183    #[inline]
184    pub fn with_sequence_number(mut self, n: u16) -> Self {
185        self.sequence_number = n;
186        self
187    }
188
189    /// Get RTP timestamp.
190    #[inline]
191    pub fn timestamp(&self) -> u32 {
192        self.timestamp
193    }
194
195    /// Set RTP timestamp.
196    #[inline]
197    pub fn with_timestamp(mut self, timestamp: u32) -> Self {
198        self.timestamp = timestamp;
199        self
200    }
201
202    /// Get the SSRC identifier.
203    #[inline]
204    pub fn ssrc(&self) -> u32 {
205        self.ssrc
206    }
207
208    /// Set the SSRC identifier.
209    #[inline]
210    pub fn with_ssrc(mut self, ssrc: u32) -> Self {
211        self.ssrc = ssrc;
212        self
213    }
214
215    /// Get a list of CSRC identifiers.
216    #[inline]
217    pub fn csrcs(&self) -> &[u32] {
218        &self.csrcs
219    }
220
221    /// Set the CSRC identifiers.
222    ///
223    /// # Panics
224    /// The method panics if the number of identifiers is greater than 255.
225    #[inline]
226    pub fn with_csrcs<T>(mut self, csrcs: T) -> Self
227    where
228        T: Into<Vec<u32>>,
229    {
230        let csrcs = csrcs.into();
231
232        assert!(csrcs.len() <= 0xf);
233
234        self.csrcs = csrcs;
235        self.options &= !0xf00;
236        self.options |= (self.csrcs.len() as u16) << 8;
237        self
238    }
239
240    /// Get raw size of the header (i.e. byte length of the encoded header).
241    #[inline]
242    pub fn raw_size(&self) -> usize {
243        std::mem::size_of::<RawRtpHeader>()
244            + (self.csrcs.len() << 2)
245            + self.extension.as_ref().map(|e| e.raw_size()).unwrap_or(0)
246    }
247}
248
249impl Default for RtpHeader {
250    #[inline]
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
/// Fixed 4-byte part of an RTP header extension, laid out exactly as on the
/// wire.
///
/// Both fields hold their value in network byte order. The struct is
/// reinterpreted to and from raw bytes, so its memory layout must match the
/// wire format byte for byte:
///
/// * `repr(C)` pins the declaration order of the fields (the default Rust
///   representation leaves field order unspecified, even when packed);
/// * `packed` removes all padding and lowers the alignment to 1, so the
///   struct can be read from an arbitrary byte address.
#[repr(C, packed)]
struct RawHeaderExtension {
    /// First 16 bits of the extension header.
    misc: u16,
    /// Length of the extension data in 32-bit words.
    length: u16,
}
262
263/// RTP header extension.
264#[derive(Clone)]
265pub struct RtpHeaderExtension {
266    misc: u16,
267    data: Bytes,
268}
269
270impl RtpHeaderExtension {
271    /// Create a new header extension.
272    #[inline]
273    pub const fn new() -> Self {
274        Self {
275            misc: 0,
276            data: Bytes::new(),
277        }
278    }
279
280    /// Decode RTP header extension from given data.
281    pub fn decode(data: &mut Bytes) -> Result<Self, InvalidInput> {
282        let mut buffer = data.clone();
283
284        if buffer.len() < std::mem::size_of::<RawHeaderExtension>() {
285            return Err(InvalidInput);
286        }
287
288        let ptr = buffer.as_ptr() as *const RawHeaderExtension;
289        let raw = unsafe { &*ptr };
290
291        let extension_length = (u16::from_be(raw.length) as usize) << 2;
292        let misc = u16::from_be(raw.misc);
293
294        buffer.advance(std::mem::size_of::<RawHeaderExtension>());
295
296        if buffer.len() < extension_length {
297            return Err(InvalidInput);
298        }
299
300        let res = Self {
301            misc,
302            data: buffer.split_to(extension_length),
303        };
304
305        *data = buffer;
306
307        Ok(res)
308    }
309
310    /// Encode the header extension.
311    pub fn encode(&self, buf: &mut BytesMut) {
312        buf.reserve(self.raw_size());
313
314        let length = (self.data.len() >> 2) as u16;
315
316        let raw = RawHeaderExtension {
317            misc: self.misc.to_be(),
318            length: length.to_be(),
319        };
320
321        let ptr = &raw as *const _ as *const u8;
322
323        let header =
324            unsafe { std::slice::from_raw_parts(ptr, std::mem::size_of::<RawHeaderExtension>()) };
325
326        buf.extend_from_slice(header);
327        buf.extend_from_slice(&self.data);
328    }
329
330    /// Get the first 16 bits of the header extension.
331    #[inline]
332    pub fn misc(&self) -> u16 {
333        self.misc
334    }
335
336    /// Set the first 16 bits of the header extension.
337    #[inline]
338    pub fn with_misc(mut self, misc: u16) -> Self {
339        self.misc = misc;
340        self
341    }
342
343    /// Get header extension data.
344    #[inline]
345    pub fn data(&self) -> &Bytes {
346        &self.data
347    }
348
349    /// Set the extension data.
350    ///
351    /// # Panics
352    /// The method panics if the length of the data is not a multiple of four
353    /// or if the length is greater than 262140.
354    #[inline]
355    pub fn with_data(mut self, data: Bytes) -> Self {
356        assert_eq!(data.len() & 3, 0);
357
358        let words = data.len() >> 2;
359
360        assert!(words <= (u16::MAX as usize));
361
362        self.data = data;
363        self
364    }
365
366    /// Get raw size of the header extension (i.e. byte length of the encoded
367    /// header extension).
368    #[inline]
369    pub fn raw_size(&self) -> usize {
370        std::mem::size_of::<RawHeaderExtension>() + self.data.len()
371    }
372}
373
374impl Default for RtpHeaderExtension {
375    #[inline]
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
381/// RTP packet.
382#[derive(Clone)]
383pub struct RtpPacket {
384    header: RtpHeader,
385    payload: Bytes,
386}
387
388impl RtpPacket {
389    /// Create a new RTP packet.
390    #[inline]
391    pub const fn new() -> Self {
392        Self {
393            header: RtpHeader::new(),
394            payload: Bytes::new(),
395        }
396    }
397
398    /// Create a new RTP packets from given parts.
399    pub fn from_parts(header: RtpHeader, payload: Bytes) -> Result<Self, InvalidInput> {
400        if header.padding() {
401            let padding_len = payload.last().copied().ok_or(InvalidInput)? as usize;
402
403            if padding_len == 0 || payload.len() < padding_len {
404                return Err(InvalidInput);
405            }
406        }
407
408        let res = Self { header, payload };
409
410        Ok(res)
411    }
412
413    /// Deconstruct the packet into its parts.
414    #[inline]
415    pub fn deconstruct(self) -> (RtpHeader, Bytes) {
416        (self.header, self.payload)
417    }
418
419    /// Decode RTP packet from given data frame.
420    pub fn decode(mut frame: Bytes) -> Result<Self, InvalidInput> {
421        let header = RtpHeader::decode(&mut frame)?;
422
423        let payload = frame;
424
425        Self::from_parts(header, payload)
426    }
427
428    /// Encode the packet.
429    pub fn encode(&self, buf: &mut BytesMut) {
430        buf.reserve(self.raw_size());
431
432        self.header.encode(buf);
433
434        buf.extend_from_slice(&self.payload);
435    }
436
437    /// Get the RTP header.
438    #[inline]
439    pub fn header(&self) -> &RtpHeader {
440        &self.header
441    }
442
443    /// Get the marker bit value.
444    #[inline]
445    pub fn marker(&self) -> bool {
446        self.header.marker()
447    }
448
449    /// Set the marker bit.
450    #[inline]
451    pub fn with_marker(mut self, marker: bool) -> Self {
452        self.header = self.header.with_marker(marker);
453        self
454    }
455
456    /// Get the payload type.
457    ///
458    /// Note: Only the lower 7 bits are used.
459    #[inline]
460    pub fn payload_type(&self) -> u8 {
461        self.header.payload_type()
462    }
463
464    /// Set the payload type.
465    ///
466    /// # Panics
467    /// The method panics if the payload type is greater than 127.
468    #[inline]
469    pub fn with_payload_type(mut self, payload_type: u8) -> Self {
470        self.header = self.header.with_payload_type(payload_type);
471        self
472    }
473
474    /// Get the RTP sequence number.
475    #[inline]
476    pub fn sequence_number(&self) -> u16 {
477        self.header.sequence_number()
478    }
479
480    /// Set the RTP sequence number.
481    #[inline]
482    pub fn with_sequence_number(mut self, sequence_number: u16) -> Self {
483        self.header = self.header.with_sequence_number(sequence_number);
484        self
485    }
486
487    /// Get the RTP timestamp.
488    #[inline]
489    pub fn timestamp(&self) -> u32 {
490        self.header.timestamp()
491    }
492
493    /// Set the RTP timestamp.
494    #[inline]
495    pub fn with_timestamp(mut self, timestamp: u32) -> Self {
496        self.header = self.header.with_timestamp(timestamp);
497        self
498    }
499
500    /// Get the SSRC identifier.
501    #[inline]
502    pub fn ssrc(&self) -> u32 {
503        self.header.ssrc()
504    }
505
506    /// Set the SSRC identifier.
507    #[inline]
508    pub fn with_ssrc(mut self, ssrc: u32) -> Self {
509        self.header = self.header.with_ssrc(ssrc);
510        self
511    }
512
513    /// Get the CSRC identifiers.
514    #[inline]
515    pub fn csrcs(&self) -> &[u32] {
516        self.header.csrcs()
517    }
518
519    /// Set the CSRC identifiers.
520    ///
521    /// # Panics
522    /// The method panics if the number of identifiers is greater than 255.
523    #[inline]
524    pub fn with_csrcs<T>(mut self, csrcs: T) -> Self
525    where
526        T: Into<Vec<u32>>,
527    {
528        self.header = self.header.with_csrcs(csrcs);
529        self
530    }
531
532    /// Get length of the optional padding.
533    ///
534    /// Zero means that the padding is not used at all.
535    #[inline]
536    pub fn padding(&self) -> u8 {
537        if self.header.padding() {
538            *self.payload.last().unwrap()
539        } else {
540            0
541        }
542    }
543
544    /// Get the packet payload including the optional padding.
545    #[inline]
546    pub fn payload(&self) -> &Bytes {
547        &self.payload
548    }
549
550    /// Get the packet payload without any padding.
551    #[inline]
552    pub fn stripped_payload(&self) -> Bytes {
553        let payload_len = self.payload.len();
554        let padding_len = self.padding() as usize;
555
556        let len = payload_len - padding_len;
557
558        self.payload.slice(..len)
559    }
560
561    /// Set the payload and add padding of a given length.
562    ///
563    /// If the padding is zero, no padding will be added and the padding bit in
564    /// the RTP header will be set to zero.
565    #[inline]
566    pub fn with_payload(mut self, payload: Bytes, padding: u8) -> Self {
567        if padding > 0 {
568            let len = payload.len() + (padding as usize);
569
570            let mut buffer = BytesMut::with_capacity(len);
571
572            buffer.extend_from_slice(&payload);
573            buffer.resize(len, 0);
574
575            buffer[len - 1] = padding;
576
577            self.header = self.header.with_padding(true);
578            self.payload = buffer.freeze();
579        } else {
580            self.header = self.header.with_padding(false);
581            self.payload = payload;
582        }
583
584        self
585    }
586
587    /// Set the payload that already includes padding.
588    ///
589    /// # Panics
590    /// The method panics if the given payload is empty, if the last byte is
591    /// zero or if the length of the padding is greater than the length of the
592    /// payload.
593    #[inline]
594    pub fn with_padded_payload(mut self, payload: Bytes) -> Self {
595        let padding_len = payload.last().copied().expect("empty payload") as usize;
596
597        assert!(padding_len > 0 && payload.len() >= padding_len);
598
599        self.header = self.header.with_padding(true);
600        self.payload = payload;
601        self
602    }
603
604    /// Get raw size of the packet (i.e. byte length of the encoded packet).
605    #[inline]
606    pub fn raw_size(&self) -> usize {
607        self.header.raw_size() + self.payload.len()
608    }
609}
610
611impl Default for RtpPacket {
612    #[inline]
613    fn default() -> Self {
614        Self::new()
615    }
616}