Skip to main content

rtp/
header.rs

1use bytes::{Buf, BufMut, Bytes};
2use util::marshal::{Marshal, MarshalSize, Unmarshal};
3
4use crate::error::Error;
5
/// Smallest buffer, in bytes, that `Header::unmarshal` accepts before it reads the first byte.
pub const HEADER_LENGTH: usize = 4;

// --- First header byte: V (2 bits) | P | X | CC (4 bits) ---

/// Right shift that moves the version field of the first byte down to bit 0.
pub const VERSION_SHIFT: u8 = 6;
/// Mask applied after `VERSION_SHIFT` to isolate the 2-bit version.
pub const VERSION_MASK: u8 = 0x3;
/// Right shift that moves the padding flag of the first byte down to bit 0.
pub const PADDING_SHIFT: u8 = 5;
/// Mask applied after `PADDING_SHIFT` to isolate the padding flag.
pub const PADDING_MASK: u8 = 0x1;
/// Right shift that moves the extension flag of the first byte down to bit 0.
pub const EXTENSION_SHIFT: u8 = 4;
/// Mask applied after `EXTENSION_SHIFT` to isolate the extension flag.
pub const EXTENSION_MASK: u8 = 0x1;

// --- Header extension profiles ---

/// Extension profile value identifying the RFC 8285 one-byte header extension format.
pub const EXTENSION_PROFILE_ONE_BYTE: u16 = 0xBEDE;
/// Extension profile value identifying the RFC 8285 two-byte header extension format.
pub const EXTENSION_PROFILE_TWO_BYTE: u16 = 0x1000;
/// One-byte extension id that makes the parser stop reading further extension elements.
pub const EXTENSION_ID_RESERVED: u8 = 0xF;

/// Mask isolating the 4-bit CSRC count in the low nibble of the first byte.
pub const CC_MASK: u8 = 0xF;

// --- Second header byte: M | PT (7 bits) ---

/// Right shift that moves the marker flag of the second byte down to bit 0.
pub const MARKER_SHIFT: u8 = 7;
/// Mask applied after `MARKER_SHIFT` to isolate the marker flag.
pub const MARKER_MASK: u8 = 0x1;
/// Mask isolating the 7-bit payload type in the second byte.
pub const PT_MASK: u8 = 0x7F;

// --- Byte offsets and widths of the remaining fixed fields ---

/// Byte offset of the sequence number within the header.
pub const SEQ_NUM_OFFSET: usize = 2;
/// Width of the sequence number in bytes.
pub const SEQ_NUM_LENGTH: usize = 2;
/// Byte offset of the timestamp within the header.
pub const TIMESTAMP_OFFSET: usize = 4;
/// Width of the timestamp in bytes.
pub const TIMESTAMP_LENGTH: usize = 4;
/// Byte offset of the SSRC identifier within the header.
pub const SSRC_OFFSET: usize = 8;
/// Width of the SSRC identifier in bytes.
pub const SSRC_LENGTH: usize = 4;
/// Byte offset at which the CSRC list starts (immediately after the SSRC).
pub const CSRC_OFFSET: usize = 12;
/// Width of a single CSRC identifier in bytes.
pub const CSRC_LENGTH: usize = 4;
28
29#[derive(Debug, Eq, PartialEq, Default, Clone)]
30pub struct Extension {
31    pub id: u8,
32    pub payload: Bytes,
33}
34
35/// Header represents an RTP packet header
36/// NOTE: PayloadOffset is populated by Marshal/Unmarshal and should not be modified
37#[derive(Debug, Eq, PartialEq, Default, Clone)]
38pub struct Header {
39    pub version: u8,
40    pub padding: bool,
41    pub extension: bool,
42    pub marker: bool,
43    pub payload_type: u8,
44    pub sequence_number: u16,
45    pub timestamp: u32,
46    pub ssrc: u32,
47    pub csrc: Vec<u32>,
48    pub extension_profile: u16,
49    pub extensions: Vec<Extension>,
50    pub extensions_padding: usize,
51}
52
53impl Unmarshal for Header {
54    /// Unmarshal parses the passed byte slice and stores the result in the Header this method is called upon
55    fn unmarshal<B>(raw_packet: &mut B) -> Result<Self, util::Error>
56    where
57        Self: Sized,
58        B: Buf,
59    {
60        let raw_packet_len = raw_packet.remaining();
61        if raw_packet_len < HEADER_LENGTH {
62            return Err(Error::ErrHeaderSizeInsufficient.into());
63        }
64        /*
65         *  0                   1                   2                   3
66         *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
67         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
68         * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
69         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
70         * |                           timestamp                           |
71         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
72         * |           synchronization source (SSRC) identifier            |
73         * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
74         * |            contributing source (CSRC) identifiers             |
75         * |                             ....                              |
76         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
77         */
78        let b0 = raw_packet.get_u8();
79        let version = b0 >> VERSION_SHIFT & VERSION_MASK;
80        let padding = (b0 >> PADDING_SHIFT & PADDING_MASK) > 0;
81        let extension = (b0 >> EXTENSION_SHIFT & EXTENSION_MASK) > 0;
82        let cc = (b0 & CC_MASK) as usize;
83
84        let mut curr_offset = CSRC_OFFSET + (cc * CSRC_LENGTH);
85        if raw_packet_len < curr_offset {
86            return Err(Error::ErrHeaderSizeInsufficient.into());
87        }
88
89        let b1 = raw_packet.get_u8();
90        let marker = (b1 >> MARKER_SHIFT & MARKER_MASK) > 0;
91        let payload_type = b1 & PT_MASK;
92
93        let sequence_number = raw_packet.get_u16();
94        let timestamp = raw_packet.get_u32();
95        let ssrc = raw_packet.get_u32();
96
97        let mut csrc = Vec::with_capacity(cc);
98        for _ in 0..cc {
99            csrc.push(raw_packet.get_u32());
100        }
101        let mut extensions_padding: usize = 0;
102        let (extension_profile, extensions) = if extension {
103            let expected = curr_offset + 4;
104            if raw_packet_len < expected {
105                return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
106            }
107            let extension_profile = raw_packet.get_u16();
108            curr_offset += 2;
109            let extension_length = raw_packet.get_u16() as usize * 4;
110            curr_offset += 2;
111
112            let expected = curr_offset + extension_length;
113            if raw_packet_len < expected {
114                return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
115            }
116
117            let mut extensions = vec![];
118            match extension_profile {
119                // RFC 8285 RTP One Byte Header Extension
120                EXTENSION_PROFILE_ONE_BYTE => {
121                    let end = curr_offset + extension_length;
122                    while curr_offset < end {
123                        let b = raw_packet.get_u8();
124                        if b == 0x00 {
125                            // padding
126                            curr_offset += 1;
127                            extensions_padding += 1;
128                            continue;
129                        }
130
131                        let extid = b >> 4;
132                        let len = ((b & (0xFF ^ 0xF0)) + 1) as usize;
133                        curr_offset += 1;
134
135                        if extid == EXTENSION_ID_RESERVED {
136                            break;
137                        }
138
139                        extensions.push(Extension {
140                            id: extid,
141                            payload: raw_packet.copy_to_bytes(len),
142                        });
143                        curr_offset += len;
144                    }
145                }
146                // RFC 8285 RTP Two Byte Header Extension
147                EXTENSION_PROFILE_TWO_BYTE => {
148                    let end = curr_offset + extension_length;
149                    while curr_offset < end {
150                        let b = raw_packet.get_u8();
151                        if b == 0x00 {
152                            // padding
153                            curr_offset += 1;
154                            extensions_padding += 1;
155                            continue;
156                        }
157
158                        let extid = b;
159                        curr_offset += 1;
160
161                        let len = raw_packet.get_u8() as usize;
162                        curr_offset += 1;
163
164                        extensions.push(Extension {
165                            id: extid,
166                            payload: raw_packet.copy_to_bytes(len),
167                        });
168                        curr_offset += len;
169                    }
170                }
171                // RFC3550 Extension
172                _ => {
173                    if raw_packet_len < curr_offset + extension_length {
174                        return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
175                    }
176                    extensions.push(Extension {
177                        id: 0,
178                        payload: raw_packet.copy_to_bytes(extension_length),
179                    });
180                }
181            };
182
183            (extension_profile, extensions)
184        } else {
185            (0, vec![])
186        };
187
188        Ok(Header {
189            version,
190            padding,
191            extension,
192            marker,
193            payload_type,
194            sequence_number,
195            timestamp,
196            ssrc,
197            csrc,
198            extension_profile,
199            extensions,
200            extensions_padding,
201        })
202    }
203}
204
205impl MarshalSize for Header {
206    /// MarshalSize returns the size of the packet once marshaled.
207    fn marshal_size(&self) -> usize {
208        let mut head_size = 12 + (self.csrc.len() * CSRC_LENGTH);
209        if self.extension {
210            let extension_payload_len = self.get_extension_payload_len() + self.extensions_padding;
211            let extension_payload_size = extension_payload_len.div_ceil(4);
212            head_size += 4 + extension_payload_size * 4;
213        }
214        head_size
215    }
216}
217
218impl Marshal for Header {
219    /// Marshal serializes the header and writes to the buffer.
220    fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize, util::Error> {
221        /*
222         *  0                   1                   2                   3
223         *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
224         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
225         * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
226         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
227         * |                           timestamp                           |
228         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
229         * |           synchronization source (SSRC) identifier            |
230         * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
231         * |            contributing source (CSRC) identifiers             |
232         * |                             ....                              |
233         * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
234         */
235        let remaining_before = buf.remaining_mut();
236        if remaining_before < self.marshal_size() {
237            return Err(Error::ErrBufferTooSmall.into());
238        }
239
240        // The first byte contains the version, padding bit, extension bit, and csrc size
241        let mut b0 = (self.version << VERSION_SHIFT) | self.csrc.len() as u8;
242        if self.padding {
243            b0 |= 1 << PADDING_SHIFT;
244        }
245
246        if self.extension {
247            b0 |= 1 << EXTENSION_SHIFT;
248        }
249        buf.put_u8(b0);
250
251        // The second byte contains the marker bit and payload type.
252        let mut b1 = self.payload_type;
253        if self.marker {
254            b1 |= 1 << MARKER_SHIFT;
255        }
256        buf.put_u8(b1);
257
258        buf.put_u16(self.sequence_number);
259        buf.put_u32(self.timestamp);
260        buf.put_u32(self.ssrc);
261
262        for csrc in &self.csrc {
263            buf.put_u32(*csrc);
264        }
265
266        if self.extension {
267            buf.put_u16(self.extension_profile);
268
269            // calculate extensions size and round to 4 bytes boundaries
270            let extension_payload_len = self.get_extension_payload_len();
271            if self.extension_profile != EXTENSION_PROFILE_ONE_BYTE
272                && self.extension_profile != EXTENSION_PROFILE_TWO_BYTE
273                && !extension_payload_len.is_multiple_of(4)
274            {
275                //the payload must be in 32-bit words.
276                return Err(Error::HeaderExtensionPayloadNot32BitWords.into());
277            }
278            let extension_payload_size = (extension_payload_len as u16).div_ceil(4);
279            buf.put_u16(extension_payload_size);
280
281            match self.extension_profile {
282                // RFC 8285 RTP One Byte Header Extension
283                EXTENSION_PROFILE_ONE_BYTE => {
284                    for extension in &self.extensions {
285                        buf.put_u8((extension.id << 4) | (extension.payload.len() as u8 - 1));
286                        buf.put(&*extension.payload);
287                    }
288                }
289                // RFC 8285 RTP Two Byte Header Extension
290                EXTENSION_PROFILE_TWO_BYTE => {
291                    for extension in &self.extensions {
292                        buf.put_u8(extension.id);
293                        buf.put_u8(extension.payload.len() as u8);
294                        buf.put(&*extension.payload);
295                    }
296                }
297                // RFC3550 Extension
298                _ => {
299                    if self.extensions.len() != 1 {
300                        return Err(Error::ErrRfc3550headerIdrange.into());
301                    }
302
303                    if let Some(extension) = self.extensions.first() {
304                        let ext_len = extension.payload.len();
305                        if ext_len % 4 != 0 {
306                            return Err(Error::HeaderExtensionPayloadNot32BitWords.into());
307                        }
308                        buf.put(&*extension.payload);
309                    }
310                }
311            };
312
313            // add padding to reach 4 bytes boundaries
314            for _ in extension_payload_len..extension_payload_size as usize * 4 {
315                buf.put_u8(0);
316            }
317        }
318
319        let remaining_after = buf.remaining_mut();
320        Ok(remaining_before - remaining_after)
321    }
322}
323
324impl Header {
325    pub fn get_extension_payload_len(&self) -> usize {
326        let payload_len: usize = self
327            .extensions
328            .iter()
329            .map(|extension| extension.payload.len())
330            .sum();
331
332        let profile_len = self.extensions.len()
333            * match self.extension_profile {
334                EXTENSION_PROFILE_ONE_BYTE => 1,
335                EXTENSION_PROFILE_TWO_BYTE => 2,
336                _ => 0,
337            };
338
339        payload_len + profile_len
340    }
341
342    /// SetExtension sets an RTP header extension
343    pub fn set_extension(&mut self, id: u8, payload: Bytes) -> Result<(), Error> {
344        let payload_len = payload.len() as isize;
345        if self.extension {
346            let extension_profile_len = match self.extension_profile {
347                EXTENSION_PROFILE_ONE_BYTE => {
348                    if !(1..=14).contains(&id) {
349                        return Err(Error::ErrRfc8285oneByteHeaderIdrange);
350                    }
351                    if payload_len > 16 {
352                        return Err(Error::ErrRfc8285oneByteHeaderSize);
353                    }
354                    1
355                }
356                EXTENSION_PROFILE_TWO_BYTE => {
357                    if id < 1 {
358                        return Err(Error::ErrRfc8285twoByteHeaderIdrange);
359                    }
360                    if payload_len > 255 {
361                        return Err(Error::ErrRfc8285twoByteHeaderSize);
362                    }
363                    2
364                }
365                _ => {
366                    if id != 0 {
367                        return Err(Error::ErrRfc3550headerIdrange);
368                    }
369                    0
370                }
371            };
372
373            let delta;
374            // Update existing if it exists else add new extension
375            if let Some(extension) = self
376                .extensions
377                .iter_mut()
378                .find(|extension| extension.id == id)
379            {
380                delta = payload_len - extension.payload.len() as isize;
381                extension.payload = payload;
382            } else {
383                delta = payload_len + extension_profile_len;
384                self.extensions.push(Extension { id, payload });
385            }
386
387            match delta.cmp(&0) {
388                std::cmp::Ordering::Less => {
389                    self.extensions_padding =
390                        ((self.extensions_padding as isize - delta) % 4) as usize;
391                }
392                std::cmp::Ordering::Greater => {
393                    let extension_padding = (delta % 4) as usize;
394                    if self.extensions_padding < extension_padding {
395                        self.extensions_padding = (self.extensions_padding + 4) - extension_padding;
396                    } else {
397                        self.extensions_padding -= extension_padding
398                    }
399                }
400                _ => {}
401            }
402        } else {
403            // No existing header extensions
404            self.extension = true;
405            let mut extension_profile_len = 0;
406            self.extension_profile = match payload_len {
407                0..=16 => {
408                    extension_profile_len = 1;
409                    EXTENSION_PROFILE_ONE_BYTE
410                }
411                17..=255 => {
412                    extension_profile_len = 2;
413                    EXTENSION_PROFILE_TWO_BYTE
414                }
415                _ => self.extension_profile,
416            };
417
418            let extension_padding = (payload.len() + extension_profile_len) % 4;
419            if self.extensions_padding < extension_padding {
420                self.extensions_padding = self.extensions_padding + 4 - extension_padding;
421            } else {
422                self.extensions_padding -= extension_padding
423            }
424            self.extensions.push(Extension { id, payload });
425        }
426        Ok(())
427    }
428
429    /// returns an extension id array
430    pub fn get_extension_ids(&self) -> Vec<u8> {
431        if self.extension {
432            self.extensions.iter().map(|e| e.id).collect()
433        } else {
434            vec![]
435        }
436    }
437
438    /// returns an RTP header extension
439    pub fn get_extension(&self, id: u8) -> Option<Bytes> {
440        if self.extension {
441            self.extensions
442                .iter()
443                .find(|extension| extension.id == id)
444                .map(|extension| extension.payload.clone())
445        } else {
446            None
447        }
448    }
449
450    /// Removes an RTP Header extension
451    pub fn del_extension(&mut self, id: u8) -> Result<(), Error> {
452        if self.extension {
453            if let Some(index) = self
454                .extensions
455                .iter()
456                .position(|extension| extension.id == id)
457            {
458                let extension = self.extensions.remove(index);
459
460                let extension_profile_len = match self.extension_profile {
461                    EXTENSION_PROFILE_ONE_BYTE => 1,
462                    EXTENSION_PROFILE_TWO_BYTE => 2,
463                    _ => 0,
464                };
465
466                let extension_padding = (extension.payload.len() + extension_profile_len) % 4;
467                self.extensions_padding = (self.extensions_padding + extension_padding) % 4;
468
469                Ok(())
470            } else {
471                Err(Error::ErrHeaderExtensionNotFound)
472            }
473        } else {
474            Err(Error::ErrHeaderExtensionsNotEnabled)
475        }
476    }
477}