rtc_rtp/codec/av1/
depacketizer.rs

1//! AV1 RTP Depacketizer
2//!
3//! Reads AV1 RTP packets and outputs AV1 low overhead bitstream format.
4//! Based on <https://aomediacodec.github.io/av1-rtp-spec/>
5
6use bytes::{BufMut, Bytes, BytesMut};
7
8use crate::codec::av1::leb128::read_leb128;
9use crate::codec::av1::obu::{
10    OBU_HAS_SIZE_BIT, OBU_TYPE_MASK, OBU_TYPE_TEMPORAL_DELIMITER, OBU_TYPE_TILE_LIST,
11};
12use crate::packetizer::Depacketizer;
13use shared::error::{Error, Result};
14
// Bit masks for the one-byte AV1 aggregation header: |Z|Y| W |N|-|-|-|
/// Z bit: the first OBU element continues an OBU from the previous packet.
const AV1_Z_MASK: u8 = 0x80;
/// Y bit: the last OBU element will be continued in the next packet.
const AV1_Y_MASK: u8 = 0x40;
/// W field (two bits): number of OBU elements in the packet, 0 when not signalled.
const AV1_W_MASK: u8 = 0x30;
/// N bit: the packet starts a new coded video sequence.
const AV1_N_MASK: u8 = 0x08;
20
21/// AV1 RTP Depacketizer
22///
23/// Depacketizes AV1 RTP packets into low overhead bitstream format with obu_size fields.
24#[derive(Default, Debug, Clone)]
25pub struct Av1Depacketizer {
26    /// Buffer for fragmented OBU from previous packet
27    buffer: BytesMut,
28    /// Z flag from aggregation header - first OBU is continuation
29    pub z: bool,
30    /// Y flag from aggregation header - last OBU will continue
31    pub y: bool,
32    /// N flag from aggregation header - new coded video sequence
33    pub n: bool,
34}
35
36impl Av1Depacketizer {
37    pub fn new() -> Self {
38        Self::default()
39    }
40}
41
42impl Depacketizer for Av1Depacketizer {
43    /// Depacketize parses an AV1 RTP payload into OBU stream with obu_size_field.
44    ///
45    /// Reference: <https://aomediacodec.github.io/av1-rtp-spec/>
46    fn depacketize(&mut self, payload: &Bytes) -> Result<Bytes> {
47        if payload.len() <= 1 {
48            return Err(Error::ErrShortPacket);
49        }
50
51        // Parse aggregation header
52        // |Z|Y| W |N|-|-|-|
53        let obu_z = (payload[0] & AV1_Z_MASK) != 0;
54        let obu_y = (payload[0] & AV1_Y_MASK) != 0;
55        let obu_count = (payload[0] & AV1_W_MASK) >> 4;
56        let obu_n = (payload[0] & AV1_N_MASK) != 0;
57
58        self.z = obu_z;
59        self.y = obu_y;
60        self.n = obu_n;
61
62        // Clear buffer on new coded video sequence
63        if obu_n {
64            self.buffer.clear();
65        }
66
67        // Clear buffer if Z is not set but we have buffered data
68        if !obu_z && !self.buffer.is_empty() {
69            self.buffer.clear();
70        }
71
72        let mut result = BytesMut::new();
73        let mut offset = 1; // Skip aggregation header
74        let mut obu_offset = 0;
75
76        while offset < payload.len() {
77            let is_first = obu_offset == 0;
78            let is_last = obu_count != 0 && obu_offset == (obu_count - 1) as usize;
79
80            // Read OBU element length
81            let (length_field, is_last) = if obu_count == 0 || !is_last {
82                // W=0 or not last element: length field present
83                let payload_slice = payload.slice(offset..);
84                let (len, n) = read_leb128(&payload_slice);
85                if n == 0 {
86                    return Err(Error::ErrShortPacket);
87                }
88                offset += n;
89
90                // Check if this is actually the last element when W=0
91                let is_last_w0 = obu_count == 0 && offset + len as usize == payload.len();
92                (len as usize, is_last || is_last_w0)
93            } else {
94                // Last element when W != 0: no length field
95                (payload.len() - offset, true)
96            };
97
98            if offset + length_field > payload.len() {
99                return Err(Error::ErrShortPacket);
100            }
101
102            // Build OBU buffer
103            let obu_buffer = if is_first && obu_z {
104                // Continuation of previous packet's OBU
105                if self.buffer.is_empty() {
106                    // Lost first fragment, skip this OBU
107                    if is_last {
108                        break;
109                    }
110                    offset += length_field;
111                    obu_offset += 1;
112                    continue;
113                }
114
115                // Combine buffered data with current fragment
116                let mut combined = std::mem::take(&mut self.buffer);
117                combined.extend_from_slice(&payload[offset..offset + length_field]);
118                combined.freeze()
119            } else {
120                payload.slice(offset..offset + length_field)
121            };
122            offset += length_field;
123
124            // If this is the last OBU and Y flag is set, buffer it for next packet
125            if is_last && obu_y {
126                self.buffer = BytesMut::from(obu_buffer.as_ref());
127                break;
128            }
129
130            // Skip empty OBUs
131            if obu_buffer.is_empty() {
132                if is_last {
133                    break;
134                }
135                obu_offset += 1;
136                continue;
137            }
138
139            // Parse OBU header to check type
140            let obu_type = (obu_buffer[0] & OBU_TYPE_MASK) >> 3;
141
142            // Skip temporal delimiter and tile list OBUs
143            if obu_type == OBU_TYPE_TEMPORAL_DELIMITER || obu_type == OBU_TYPE_TILE_LIST {
144                if is_last {
145                    break;
146                }
147                obu_offset += 1;
148                continue;
149            }
150
151            // Check if OBU has size field
152            let has_size_field = (obu_buffer[0] & OBU_HAS_SIZE_BIT) != 0;
153            let has_extension = (obu_buffer[0] & 0x04) != 0;
154            let header_size = if has_extension { 2 } else { 1 };
155
156            if has_size_field {
157                // OBU already has size field, validate it
158                let payload_slice = obu_buffer.slice(header_size..);
159                let (obu_size, leb_size) = read_leb128(&payload_slice);
160                let expected_size = header_size + leb_size + obu_size as usize;
161                if length_field != expected_size {
162                    return Err(Error::ErrShortPacket);
163                }
164                result.extend_from_slice(&obu_buffer);
165            } else {
166                // Add size field to OBU
167                // Set obu_has_size_field bit
168                result.put_u8(obu_buffer[0] | OBU_HAS_SIZE_BIT);
169
170                // Copy extension header if present
171                if has_extension && obu_buffer.len() > 1 {
172                    result.put_u8(obu_buffer[1]);
173                }
174
175                // Write payload size as LEB128
176                let payload_size = obu_buffer.len() - header_size;
177                write_leb128(&mut result, payload_size as u32);
178
179                // Copy OBU payload
180                if header_size < obu_buffer.len() {
181                    result.extend_from_slice(&obu_buffer[header_size..]);
182                }
183            }
184
185            if is_last {
186                break;
187            }
188            obu_offset += 1;
189        }
190
191        // Validate OBU count if W field was set
192        if obu_count != 0 && obu_offset != (obu_count - 1) as usize && !self.y {
193            return Err(Error::ErrShortPacket);
194        }
195
196        Ok(result.freeze())
197    }
198
199    /// Returns true if Z flag is not set (first OBU is not a continuation)
200    fn is_partition_head(&self, payload: &Bytes) -> bool {
201        if payload.is_empty() {
202            return false;
203        }
204        (payload[0] & AV1_Z_MASK) == 0
205    }
206
207    /// Returns true if marker bit is set (end of frame)
208    fn is_partition_tail(&self, marker: bool, _payload: &Bytes) -> bool {
209        marker
210    }
211}
212
213/// Write LEB128 encoded value to buffer
214fn write_leb128(buf: &mut BytesMut, mut value: u32) {
215    loop {
216        let mut byte = (value & 0x7f) as u8;
217        value >>= 7;
218        if value != 0 {
219            byte |= 0x80;
220        }
221        buf.put_u8(byte);
222        if value == 0 {
223            break;
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a single payload to a fresh depacketizer and returns its output.
    fn depacketize_once(raw: &'static [u8]) -> Bytes {
        let mut depacketizer = Av1Depacketizer::new();
        depacketizer
            .depacketize(&Bytes::from_static(raw))
            .unwrap()
    }

    #[test]
    fn test_depacketizer_basic() {
        // One OBU element announced through W=1, so no length prefix is present.
        let result = depacketize_once(&[
            0x10, // Aggregation header: W=1, Z/Y/N clear
            0x30, // OBU header: type=6 (Frame), no extension, no size field
            0x01, 0x02, 0x03, // OBU payload
        ]);

        assert!(!result.is_empty());
        // The size bit must have been set on the OBU header ...
        assert_eq!(result[0] & OBU_HAS_SIZE_BIT, OBU_HAS_SIZE_BIT);
        // ... followed by the payload size (3 bytes) ...
        assert_eq!(result[1], 3);
        // ... and then the untouched payload.
        assert_eq!(&result[2..], &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_depacketizer_with_w_zero() {
        // W=0: every OBU element is preceded by a LEB128 length.
        let result = depacketize_once(&[
            0x00, // Aggregation header: W=0
            0x04, // Element length: 4 bytes
            0x30, // OBU header: type=6 (Frame), no extension, no size field
            0x01, 0x02, 0x03, // OBU payload (header + 3 bytes = 4)
        ]);

        assert!(!result.is_empty());
        // The size bit must have been set on the OBU header.
        assert_eq!(result[0] & OBU_HAS_SIZE_BIT, OBU_HAS_SIZE_BIT);
    }

    #[test]
    fn test_is_partition_head() {
        let depacketizer = Av1Depacketizer::new();

        // Z clear: the packet starts a new OBU.
        let head = Bytes::from_static(&[0x10, 0x30]);
        assert!(depacketizer.is_partition_head(&head));

        // Z set: the packet continues an OBU from the previous packet.
        let continuation = Bytes::from_static(&[0x90, 0x30]);
        assert!(!depacketizer.is_partition_head(&continuation));
    }

    #[test]
    fn test_write_leb128() {
        // (value, expected encoding) around the one-byte / two-byte boundaries.
        let cases: [(u32, &[u8]); 4] = [
            (0, &[0x00]),
            (127, &[0x7f]),
            (128, &[0x80, 0x01]),
            (16383, &[0xff, 0x7f]),
        ];

        let mut buf = BytesMut::new();
        for (value, expected) in cases.iter() {
            buf.clear();
            write_leb128(&mut buf, *value);
            assert_eq!(buf.as_ref(), *expected);
        }
    }

    #[test]
    fn test_skip_temporal_delimiter() {
        // A lone temporal delimiter OBU (type=2) must not appear in the output.
        let result = depacketize_once(&[
            0x10, // Aggregation header: W=1
            0x12, // OBU header: type=2 (Temporal Delimiter), no extension, with size
            0x00, // Size = 0
        ]);

        assert!(result.is_empty());
    }
}
326}