msf_rtp/
rtp.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use crate::InvalidInput;
4
/// Helper struct: the fixed 12-byte part of the RTP header, laid out exactly
/// as it appears in the encoded header.
///
/// `repr(C, packed)` keeps the fields in declaration order with no padding
/// bytes, which is what allows the struct to be viewed as a plain byte slice
/// when encoding. Field values are stored in network (big-endian) byte order;
/// the conversion to and from host order is done by the code that builds or
/// reads the struct.
#[repr(C, packed)]
struct RawRtpHeader {
    /// Packed flag bits: V(2) | P(1) | X(1) | CC(4) | M(1) | PT(7).
    options: u16,
    sequence_number: u16,
    timestamp: u32,
    ssrc: u32,
}
13
14/// RTP header.
15#[derive(Clone)]
16pub struct RtpHeader {
17    options: u16,
18    sequence_number: u16,
19    timestamp: u32,
20    ssrc: u32,
21    csrcs: Vec<u32>,
22    extension: Option<RtpHeaderExtension>,
23}
24
25impl RtpHeader {
26    /// Create a new RTP header.
27    #[inline]
28    pub const fn new() -> Self {
29        Self {
30            options: 2 << 14,
31            sequence_number: 0,
32            timestamp: 0,
33            ssrc: 0,
34            csrcs: Vec::new(),
35            extension: None,
36        }
37    }
38
39    /// Decode an RTP header from given data.
40    pub fn decode(data: &mut Bytes) -> Result<Self, InvalidInput> {
41        let mut buffer = data.clone();
42
43        if buffer.len() < std::mem::size_of::<RawRtpHeader>() {
44            return Err(InvalidInput);
45        }
46
47        let ptr = buffer.as_ptr() as *const RawRtpHeader;
48
49        let raw = unsafe { ptr.read_unaligned() };
50
51        let mut res = Self {
52            options: u16::from_be(raw.options),
53            sequence_number: u16::from_be(raw.sequence_number),
54            timestamp: u32::from_be(raw.timestamp),
55            ssrc: u32::from_be(raw.ssrc),
56            csrcs: Vec::new(),
57            extension: None,
58        };
59
60        buffer.advance(std::mem::size_of::<RawRtpHeader>());
61
62        if (res.options >> 14) != 2 {
63            return Err(InvalidInput);
64        }
65
66        let csrc_count = ((res.options >> 8) & 0xf) as usize;
67
68        if buffer.len() < (csrc_count << 2) {
69            return Err(InvalidInput);
70        }
71
72        res.csrcs = Vec::with_capacity(csrc_count);
73
74        for _ in 0..csrc_count {
75            res.csrcs.push(buffer.get_u32());
76        }
77
78        if (res.options & 0x1000) != 0 {
79            res.extension = Some(RtpHeaderExtension::decode(&mut buffer)?);
80        }
81
82        *data = buffer;
83
84        Ok(res)
85    }
86
87    /// Encode the header.
88    pub fn encode(&self, buf: &mut BytesMut) {
89        buf.reserve(self.raw_size());
90
91        let raw = RawRtpHeader {
92            options: self.options.to_be(),
93            sequence_number: self.sequence_number.to_be(),
94            timestamp: self.timestamp.to_be(),
95            ssrc: self.ssrc.to_be(),
96        };
97
98        let ptr = &raw as *const _ as *const u8;
99
100        let data = unsafe { std::slice::from_raw_parts(ptr, std::mem::size_of::<RawRtpHeader>()) };
101
102        buf.extend_from_slice(data);
103
104        for csrc in &self.csrcs {
105            buf.put_u32(*csrc);
106        }
107
108        if let Some(extension) = self.extension.as_ref() {
109            extension.encode(buf);
110        }
111    }
112
113    /// Check if the RTP packet contains any padding.
114    #[inline]
115    pub fn padding(&self) -> bool {
116        (self.options & 0x2000) != 0
117    }
118
119    /// Set the padding bit.
120    #[inline]
121    pub fn with_padding(mut self, padding: bool) -> Self {
122        self.options &= !0x2000;
123        self.options |= (padding as u16) << 13;
124        self
125    }
126
127    /// Check if there is an RTP header extension.
128    #[inline]
129    pub fn extension(&self) -> Option<&RtpHeaderExtension> {
130        self.extension.as_ref()
131    }
132
133    /// Set the extension bit.
134    #[inline]
135    pub fn with_extension(mut self, extension: Option<RtpHeaderExtension>) -> Self {
136        self.options &= !0x1000;
137        self.options |= (extension.is_some() as u16) << 12;
138        self.extension = extension;
139        self
140    }
141
142    /// Check if the RTP marker bit is set.
143    #[inline]
144    pub fn marker(&self) -> bool {
145        (self.options & 0x0080) != 0
146    }
147
148    /// Set the marker bit.
149    #[inline]
150    pub fn with_marker(mut self, marker: bool) -> Self {
151        self.options &= !0x0080;
152        self.options |= (marker as u16) << 7;
153        self
154    }
155
156    /// Get RTP payload type.
157    ///
158    /// Note: Only the lower 7 bits are used.
159    #[inline]
160    pub fn payload_type(&self) -> u8 {
161        (self.options & 0x7f) as u8
162    }
163
164    /// Set the payload type.
165    ///
166    /// # Panics
167    /// The method panics if the payload type is greater than 127.
168    #[inline]
169    pub fn with_payload_type(mut self, payload_type: u8) -> Self {
170        assert!(payload_type < 128);
171
172        self.options &= !0x7f;
173        self.options |= (payload_type & 0x7f) as u16;
174        self
175    }
176
177    /// Get RTP sequence number.
178    #[inline]
179    pub fn sequence_number(&self) -> u16 {
180        self.sequence_number
181    }
182
183    /// Set the sequence number.
184    #[inline]
185    pub fn with_sequence_number(mut self, n: u16) -> Self {
186        self.sequence_number = n;
187        self
188    }
189
190    /// Get RTP timestamp.
191    #[inline]
192    pub fn timestamp(&self) -> u32 {
193        self.timestamp
194    }
195
196    /// Set RTP timestamp.
197    #[inline]
198    pub fn with_timestamp(mut self, timestamp: u32) -> Self {
199        self.timestamp = timestamp;
200        self
201    }
202
203    /// Get the SSRC identifier.
204    #[inline]
205    pub fn ssrc(&self) -> u32 {
206        self.ssrc
207    }
208
209    /// Set the SSRC identifier.
210    #[inline]
211    pub fn with_ssrc(mut self, ssrc: u32) -> Self {
212        self.ssrc = ssrc;
213        self
214    }
215
216    /// Get a list of CSRC identifiers.
217    #[inline]
218    pub fn csrcs(&self) -> &[u32] {
219        &self.csrcs
220    }
221
222    /// Set the CSRC identifiers.
223    ///
224    /// # Panics
225    /// The method panics if the number of identifiers is greater than 255.
226    #[inline]
227    pub fn with_csrcs<T>(mut self, csrcs: T) -> Self
228    where
229        T: Into<Vec<u32>>,
230    {
231        let csrcs = csrcs.into();
232
233        assert!(csrcs.len() <= 0xf);
234
235        self.csrcs = csrcs;
236        self.options &= !0xf00;
237        self.options |= (self.csrcs.len() as u16) << 8;
238        self
239    }
240
241    /// Get raw size of the header (i.e. byte length of the encoded header).
242    #[inline]
243    pub fn raw_size(&self) -> usize {
244        std::mem::size_of::<RawRtpHeader>()
245            + (self.csrcs.len() << 2)
246            + self.extension.as_ref().map(|e| e.raw_size()).unwrap_or(0)
247    }
248}
249
250impl Default for RtpHeader {
251    #[inline]
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
/// Helper struct: the fixed 4-byte prefix of an RTP header extension, laid
/// out exactly as it appears in the encoded header.
///
/// `repr(C, packed)` keeps the two fields in declaration order with no
/// padding bytes. Field values are stored in network (big-endian) byte order.
#[repr(C, packed)]
struct RawHeaderExtension {
    /// First 16 bits of the header extension (exposed as `misc`).
    misc: u16,
    /// Length of the extension data in 32-bit words, not counting this prefix.
    length: u16,
}
263
264/// RTP header extension.
265#[derive(Clone)]
266pub struct RtpHeaderExtension {
267    misc: u16,
268    data: Bytes,
269}
270
271impl RtpHeaderExtension {
272    /// Create a new header extension.
273    #[inline]
274    pub const fn new() -> Self {
275        Self {
276            misc: 0,
277            data: Bytes::new(),
278        }
279    }
280
281    /// Decode RTP header extension from given data.
282    pub fn decode(data: &mut Bytes) -> Result<Self, InvalidInput> {
283        let mut buffer = data.clone();
284
285        if buffer.len() < std::mem::size_of::<RawHeaderExtension>() {
286            return Err(InvalidInput);
287        }
288
289        let ptr = buffer.as_ptr() as *const RawHeaderExtension;
290
291        let raw = unsafe { ptr.read_unaligned() };
292
293        let extension_length = (u16::from_be(raw.length) as usize) << 2;
294        let misc = u16::from_be(raw.misc);
295
296        buffer.advance(std::mem::size_of::<RawHeaderExtension>());
297
298        if buffer.len() < extension_length {
299            return Err(InvalidInput);
300        }
301
302        let res = Self {
303            misc,
304            data: buffer.split_to(extension_length),
305        };
306
307        *data = buffer;
308
309        Ok(res)
310    }
311
312    /// Encode the header extension.
313    pub fn encode(&self, buf: &mut BytesMut) {
314        buf.reserve(self.raw_size());
315
316        let length = (self.data.len() >> 2) as u16;
317
318        let raw = RawHeaderExtension {
319            misc: self.misc.to_be(),
320            length: length.to_be(),
321        };
322
323        let ptr = &raw as *const _ as *const u8;
324
325        let header =
326            unsafe { std::slice::from_raw_parts(ptr, std::mem::size_of::<RawHeaderExtension>()) };
327
328        buf.extend_from_slice(header);
329        buf.extend_from_slice(&self.data);
330    }
331
332    /// Get the first 16 bits of the header extension.
333    #[inline]
334    pub fn misc(&self) -> u16 {
335        self.misc
336    }
337
338    /// Set the first 16 bits of the header extension.
339    #[inline]
340    pub fn with_misc(mut self, misc: u16) -> Self {
341        self.misc = misc;
342        self
343    }
344
345    /// Get header extension data.
346    #[inline]
347    pub fn data(&self) -> &Bytes {
348        &self.data
349    }
350
351    /// Set the extension data.
352    ///
353    /// # Panics
354    /// The method panics if the length of the data is not a multiple of four
355    /// or if the length is greater than 262140.
356    #[inline]
357    pub fn with_data(mut self, data: Bytes) -> Self {
358        assert_eq!(data.len() & 3, 0);
359
360        let words = data.len() >> 2;
361
362        assert!(words <= (u16::MAX as usize));
363
364        self.data = data;
365        self
366    }
367
368    /// Get raw size of the header extension (i.e. byte length of the encoded
369    /// header extension).
370    #[inline]
371    pub fn raw_size(&self) -> usize {
372        std::mem::size_of::<RawHeaderExtension>() + self.data.len()
373    }
374}
375
376impl Default for RtpHeaderExtension {
377    #[inline]
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383/// RTP packet.
384#[derive(Clone)]
385pub struct RtpPacket {
386    header: RtpHeader,
387    payload: Bytes,
388}
389
390impl RtpPacket {
391    /// Create a new RTP packet.
392    #[inline]
393    pub const fn new() -> Self {
394        Self {
395            header: RtpHeader::new(),
396            payload: Bytes::new(),
397        }
398    }
399
400    /// Create a new RTP packets from given parts.
401    pub fn from_parts(header: RtpHeader, payload: Bytes) -> Result<Self, InvalidInput> {
402        if header.padding() {
403            let padding_len = payload.last().copied().ok_or(InvalidInput)? as usize;
404
405            if padding_len == 0 || payload.len() < padding_len {
406                return Err(InvalidInput);
407            }
408        }
409
410        let res = Self { header, payload };
411
412        Ok(res)
413    }
414
415    /// Deconstruct the packet into its parts.
416    #[inline]
417    pub fn deconstruct(self) -> (RtpHeader, Bytes) {
418        (self.header, self.payload)
419    }
420
421    /// Decode RTP packet from given data frame.
422    pub fn decode(mut frame: Bytes) -> Result<Self, InvalidInput> {
423        let header = RtpHeader::decode(&mut frame)?;
424
425        let payload = frame;
426
427        Self::from_parts(header, payload)
428    }
429
430    /// Encode the packet.
431    pub fn encode(&self, buf: &mut BytesMut) {
432        buf.reserve(self.raw_size());
433
434        self.header.encode(buf);
435
436        buf.extend_from_slice(&self.payload);
437    }
438
439    /// Get the RTP header.
440    #[inline]
441    pub fn header(&self) -> &RtpHeader {
442        &self.header
443    }
444
445    /// Get the marker bit value.
446    #[inline]
447    pub fn marker(&self) -> bool {
448        self.header.marker()
449    }
450
451    /// Set the marker bit.
452    #[inline]
453    pub fn with_marker(mut self, marker: bool) -> Self {
454        self.header = self.header.with_marker(marker);
455        self
456    }
457
458    /// Get the payload type.
459    ///
460    /// Note: Only the lower 7 bits are used.
461    #[inline]
462    pub fn payload_type(&self) -> u8 {
463        self.header.payload_type()
464    }
465
466    /// Set the payload type.
467    ///
468    /// # Panics
469    /// The method panics if the payload type is greater than 127.
470    #[inline]
471    pub fn with_payload_type(mut self, payload_type: u8) -> Self {
472        self.header = self.header.with_payload_type(payload_type);
473        self
474    }
475
476    /// Get the RTP sequence number.
477    #[inline]
478    pub fn sequence_number(&self) -> u16 {
479        self.header.sequence_number()
480    }
481
482    /// Set the RTP sequence number.
483    #[inline]
484    pub fn with_sequence_number(mut self, sequence_number: u16) -> Self {
485        self.header = self.header.with_sequence_number(sequence_number);
486        self
487    }
488
489    /// Get the RTP timestamp.
490    #[inline]
491    pub fn timestamp(&self) -> u32 {
492        self.header.timestamp()
493    }
494
495    /// Set the RTP timestamp.
496    #[inline]
497    pub fn with_timestamp(mut self, timestamp: u32) -> Self {
498        self.header = self.header.with_timestamp(timestamp);
499        self
500    }
501
502    /// Get the SSRC identifier.
503    #[inline]
504    pub fn ssrc(&self) -> u32 {
505        self.header.ssrc()
506    }
507
508    /// Set the SSRC identifier.
509    #[inline]
510    pub fn with_ssrc(mut self, ssrc: u32) -> Self {
511        self.header = self.header.with_ssrc(ssrc);
512        self
513    }
514
515    /// Get the CSRC identifiers.
516    #[inline]
517    pub fn csrcs(&self) -> &[u32] {
518        self.header.csrcs()
519    }
520
521    /// Set the CSRC identifiers.
522    ///
523    /// # Panics
524    /// The method panics if the number of identifiers is greater than 255.
525    #[inline]
526    pub fn with_csrcs<T>(mut self, csrcs: T) -> Self
527    where
528        T: Into<Vec<u32>>,
529    {
530        self.header = self.header.with_csrcs(csrcs);
531        self
532    }
533
534    /// Get length of the optional padding.
535    ///
536    /// Zero means that the padding is not used at all.
537    #[inline]
538    pub fn padding(&self) -> u8 {
539        if self.header.padding() {
540            *self.payload.last().unwrap()
541        } else {
542            0
543        }
544    }
545
546    /// Get the packet payload including the optional padding.
547    #[inline]
548    pub fn payload(&self) -> &Bytes {
549        &self.payload
550    }
551
552    /// Get the packet payload without any padding.
553    #[inline]
554    pub fn stripped_payload(&self) -> Bytes {
555        let payload_len = self.payload.len();
556        let padding_len = self.padding() as usize;
557
558        let len = payload_len - padding_len;
559
560        self.payload.slice(..len)
561    }
562
563    /// Set the payload and add padding of a given length.
564    ///
565    /// If the padding is zero, no padding will be added and the padding bit in
566    /// the RTP header will be set to zero.
567    #[inline]
568    pub fn with_payload(mut self, payload: Bytes, padding: u8) -> Self {
569        if padding > 0 {
570            let len = payload.len() + (padding as usize);
571
572            let mut buffer = BytesMut::with_capacity(len);
573
574            buffer.extend_from_slice(&payload);
575            buffer.resize(len, 0);
576
577            buffer[len - 1] = padding;
578
579            self.header = self.header.with_padding(true);
580            self.payload = buffer.freeze();
581        } else {
582            self.header = self.header.with_padding(false);
583            self.payload = payload;
584        }
585
586        self
587    }
588
589    /// Set the payload that already includes padding.
590    ///
591    /// # Panics
592    /// The method panics if the given payload is empty, if the last byte is
593    /// zero or if the length of the padding is greater than the length of the
594    /// payload.
595    #[inline]
596    pub fn with_padded_payload(mut self, payload: Bytes) -> Self {
597        let padding_len = payload.last().copied().expect("empty payload") as usize;
598
599        assert!(padding_len > 0 && payload.len() >= padding_len);
600
601        self.header = self.header.with_padding(true);
602        self.payload = payload;
603        self
604    }
605
606    /// Get raw size of the packet (i.e. byte length of the encoded packet).
607    #[inline]
608    pub fn raw_size(&self) -> usize {
609        self.header.raw_size() + self.payload.len()
610    }
611}
612
613impl Default for RtpPacket {
614    #[inline]
615    fn default() -> Self {
616        Self::new()
617    }
618}