h264_parser/
au.rs

1use crate::nal::{Nal, NalUnitType};
2use crate::pps::Pps;
3use crate::sei::{SeiMessage, SeiPayload};
4use crate::slice::{PictureId, SliceHeader, SliceType};
5use crate::sps::Sps;
6use std::borrow::Cow;
7use std::sync::Arc;
8
/// Classification of an access unit, assigned while it is assembled.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessUnitKind {
    /// The access unit contains an IDR slice.
    Idr,
    /// The access unit carries a recovery-point SEI message; the payload is the
    /// message's `recovery_frame_cnt`.
    RecoveryPoint(u32),
    /// Neither an IDR nor a recovery point.
    NonIdr,
}
15
16#[derive(Debug, Clone)]
17pub struct AccessUnit {
18    pub nals: Vec<Nal>,
19    pub is_keyframe: bool,
20    pub kind: AccessUnitKind,
21    pub sps: Option<Arc<Sps>>,
22    pub pps: Option<Arc<Pps>>,
23    pub picture_id: Option<PictureId>,
24    parameter_sets: Vec<Nal>,
25    first_slice_type: Option<SliceType>,
26}
27
28impl AccessUnit {
29    pub fn new() -> Self {
30        Self {
31            nals: Vec::new(),
32            is_keyframe: false,
33            kind: AccessUnitKind::NonIdr,
34            sps: None,
35            pps: None,
36            picture_id: None,
37            parameter_sets: Vec::new(),
38            first_slice_type: None,
39        }
40    }
41
42    pub fn is_keyframe(&self) -> bool {
43        self.is_keyframe
44    }
45
46    pub fn nals(&self) -> impl Iterator<Item = &Nal> {
47        self.nals.iter()
48    }
49
50    pub fn to_annexb_bytes(&self) -> Cow<'_, [u8]> {
51        let mut bytes = Vec::new();
52
53        for nal in &self.nals {
54            Self::push_annexb_nal(&mut bytes, nal);
55        }
56
57        Cow::Owned(bytes)
58    }
59
60    pub fn to_annexb_webcodec_bytes(&self) -> Cow<'_, [u8]> {
61        let mut bytes = Vec::new();
62
63        let include_parameter_sets =
64            self.is_keyframe || matches!(self.first_slice_type, Some(SliceType::I));
65
66        if include_parameter_sets {
67            for nal in &self.parameter_sets {
68                Self::push_annexb_nal_internal(&mut bytes, nal, true);
69            }
70        }
71
72        for nal in &self.nals {
73            let include_nal = if include_parameter_sets {
74                !matches!(nal.nal_type, NalUnitType::Sps | NalUnitType::Pps)
75            } else {
76                true
77            };
78
79            if include_nal {
80                Self::push_annexb_nal_internal(&mut bytes, nal, true);
81            }
82        }
83
84        Cow::Owned(bytes)
85    }
86
87    pub fn add_nal(&mut self, nal: Nal) {
88        if matches!(nal.nal_type, NalUnitType::Sps | NalUnitType::Pps) {
89            self.add_parameter_set(nal.clone());
90        }
91
92        if nal.nal_type == NalUnitType::IdrSlice {
93            self.kind = AccessUnitKind::Idr;
94            self.is_keyframe = true;
95        }
96
97        self.nals.push(nal);
98    }
99
100    pub fn add_parameter_set(&mut self, nal: Nal) {
101        if !self.parameter_sets.iter().any(|existing| existing == &nal) {
102            self.parameter_sets.push(nal);
103        }
104    }
105
106    pub(crate) fn note_slice_type(&mut self, slice_type: SliceType) {
107        if self.first_slice_type.is_none() {
108            self.first_slice_type = Some(slice_type);
109        }
110    }
111
112    fn push_annexb_nal(bytes: &mut Vec<u8>, nal: &Nal) {
113        Self::push_annexb_nal_internal(bytes, nal, false);
114    }
115
116    fn push_annexb_nal_internal(bytes: &mut Vec<u8>, nal: &Nal, force_long_start_code: bool) {
117        let start_code = if force_long_start_code || nal.start_code_len == 4 {
118            &[0x00, 0x00, 0x00, 0x01][..]
119        } else {
120            &[0x00, 0x00, 0x01][..]
121        };
122
123        bytes.extend_from_slice(start_code);
124
125        let header = ((nal.ref_idc & 0b11) << 5) | (nal.nal_type.as_u8() & 0b11111);
126        bytes.push(header);
127
128        bytes.extend_from_slice(&nal.ebsp);
129    }
130
131
132    pub fn set_sps(&mut self, sps: Arc<Sps>) {
133        self.sps = Some(sps);
134    }
135
136    pub fn set_pps(&mut self, pps: Arc<Pps>) {
137        self.pps = Some(pps);
138    }
139
140    pub fn check_recovery_point(&mut self) {
141        for nal in &self.nals {
142            if nal.nal_type == NalUnitType::Sei {
143                let rbsp = nal.to_rbsp();
144                if let Ok(messages) = SeiMessage::parse(&rbsp) {
145                    for msg in messages {
146                        if let SeiPayload::RecoveryPoint {
147                            recovery_frame_cnt, ..
148                        } = msg.payload
149                        {
150                            if recovery_frame_cnt == 0 {
151                                self.kind = AccessUnitKind::RecoveryPoint(0);
152                                self.is_keyframe = true;
153                            } else {
154                                self.kind = AccessUnitKind::RecoveryPoint(recovery_frame_cnt);
155                            }
156                        }
157                    }
158                }
159            }
160        }
161    }
162
163    pub fn set_picture_id_from_slice(
164        &mut self,
165        slice_header: &SliceHeader,
166        nal_type: NalUnitType,
167        sps: &Sps,
168    ) {
169        self.picture_id = Some(PictureId::from_slice_header(slice_header, nal_type, sps));
170    }
171}
172
173pub struct AccessUnitBuilder {
174    current_au: Option<AccessUnit>,
175    current_picture_id: Option<PictureId>,
176}
177
178impl AccessUnitBuilder {
179    pub fn new() -> Self {
180        Self {
181            current_au: None,
182            current_picture_id: None,
183        }
184    }
185
186    pub fn is_au_boundary(
187        &self,
188        nal: &Nal,
189        slice_header: Option<&SliceHeader>,
190        sps: Option<&Sps>,
191    ) -> bool {
192        if nal.nal_type == NalUnitType::Aud {
193            return true;
194        }
195
196        if !nal.is_vcl() {
197            return false;
198        }
199
200        if self.current_picture_id.is_none() {
201            return false;
202        }
203
204        if let (Some(header), Some(sps)) = (slice_header, sps) {
205            let new_picture_id = PictureId::from_slice_header(header, nal.nal_type, sps);
206
207            if let Some(ref current_id) = self.current_picture_id {
208                return &new_picture_id != current_id;
209            }
210        }
211
212        false
213    }
214
215    pub fn add_nal(
216        &mut self,
217        nal: Nal,
218        slice_header: Option<SliceHeader>,
219        sps: Option<Arc<Sps>>,
220        pps: Option<Arc<Pps>>,
221        mut extra_parameter_sets: Vec<Nal>,
222    ) -> Option<AccessUnit> {
223        let is_boundary = if let (Some(ref header), Some(ref sps_ref)) = (&slice_header, &sps) {
224            self.is_au_boundary(&nal, Some(header), Some(sps_ref))
225        } else {
226            self.is_au_boundary(&nal, None, None)
227        };
228
229        let mut completed_au = None;
230
231        if is_boundary && self.current_au.is_some() {
232            if let Some(mut au) = self.current_au.take() {
233                au.check_recovery_point();
234                completed_au = Some(au);
235            }
236            self.current_picture_id = None;
237        }
238
239        if self.current_au.is_none() {
240            self.current_au = Some(AccessUnit::new());
241        }
242
243        if let Some(ref mut au) = self.current_au {
244            if let Some(sps) = sps {
245                au.set_sps(sps);
246            }
247
248            if let Some(pps) = pps {
249                au.set_pps(pps);
250            }
251
252            for parameter_set in extra_parameter_sets.drain(..) {
253                au.add_parameter_set(parameter_set);
254            }
255
256            if let Some(header) = slice_header.as_ref() {
257                if nal.is_vcl() {
258                    au.note_slice_type(header.slice_type);
259                }
260            }
261
262            if let (Some(header), Some(ref sps_ref)) = (slice_header.as_ref(), &au.sps) {
263                let picture_id = PictureId::from_slice_header(header, nal.nal_type, sps_ref);
264                self.current_picture_id = Some(picture_id.clone());
265                au.picture_id = Some(picture_id);
266            }
267
268            au.add_nal(nal);
269        }
270
271        completed_au
272    }
273
274    pub fn flush(mut self) -> Option<AccessUnit> {
275        if let Some(mut au) = self.current_au.take() {
276            au.check_recovery_point();
277            Some(au)
278        } else {
279            None
280        }
281    }
282
283    pub fn flush_pending(&mut self) -> Option<AccessUnit> {
284        if let Some(mut au) = self.current_au.take() {
285            au.check_recovery_point();
286            self.current_picture_id = None;
287            Some(au)
288        } else {
289            None
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test NAL unit; `long_start_code` selects a 4-byte rather than a
    /// 3-byte start code.
    fn make_nal(nal_type: NalUnitType, ref_idc: u8, long_start_code: bool, ebsp: Vec<u8>) -> Nal {
        Nal {
            start_code_len: if long_start_code { 4 } else { 3 },
            ref_idc,
            nal_type,
            ebsp,
        }
    }

    #[test]
    fn test_access_unit_keyframe_detection() {
        let mut au = AccessUnit::new();
        assert!(!au.is_keyframe());

        au.add_nal(make_nal(NalUnitType::IdrSlice, 3, true, Vec::new()));

        assert!(au.is_keyframe());
        assert_eq!(au.kind, AccessUnitKind::Idr);
    }

    #[test]
    fn test_to_annexb_bytes() {
        let mut au = AccessUnit::new();
        au.add_nal(make_nal(NalUnitType::Sps, 2, false, vec![0x42, 0x00, 0x1f]));

        let encoded = au.to_annexb_bytes();

        // 3-byte start code, header 0x47 (ref_idc 2, type 7), then the payload.
        let expected: Vec<u8> = vec![0x00, 0x00, 0x01, 0x47, 0x42, 0x00, 0x1f];
        assert_eq!(&*encoded, expected.as_slice());
    }

    #[test]
    fn test_to_annexb_webcodec_bytes_includes_parameter_sets() {
        let sps = make_nal(NalUnitType::Sps, 3, true, vec![0x42, 0x00, 0x1f]);
        let pps = make_nal(NalUnitType::Pps, 3, true, vec![0xde, 0xad]);
        let idr = make_nal(NalUnitType::IdrSlice, 3, false, vec![0xaa, 0xbb]);

        let mut au = AccessUnit::new();
        au.add_parameter_set(sps.clone());
        au.add_parameter_set(pps);
        // An identical SPS must not be recorded twice.
        au.add_parameter_set(sps);
        au.add_nal(idr);

        let encoded = au.to_annexb_webcodec_bytes();

        let expected: Vec<u8> = vec![
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, // SPS
            0x00, 0x00, 0x00, 0x01, 0x68, 0xde, 0xad, // PPS
            0x00, 0x00, 0x00, 0x01, 0x65, 0xaa, 0xbb, // IDR slice
        ];
        assert_eq!(&*encoded, expected.as_slice());
    }

    #[test]
    fn test_to_annexb_webcodec_bytes_for_delta_frame_excludes_parameter_sets() {
        let mut au = AccessUnit::new();
        au.add_parameter_set(make_nal(NalUnitType::Sps, 3, true, vec![0x42, 0x00, 0x1f]));
        au.add_parameter_set(make_nal(NalUnitType::Pps, 3, true, vec![0xde, 0xad]));
        au.note_slice_type(SliceType::P);
        au.add_nal(make_nal(NalUnitType::NonIdrSlice, 2, true, vec![0x11, 0x22]));

        let encoded = au.to_annexb_webcodec_bytes();

        let expected: Vec<u8> = vec![0x00, 0x00, 0x00, 0x01, 0x41, 0x11, 0x22];
        assert_eq!(&*encoded, expected.as_slice());
    }
}
410}