// h264_parser/au.rs

1use crate::nal::{Nal, NalUnitType};
2use crate::pps::Pps;
3use crate::sei::{SeiMessage, SeiPayload};
4use crate::slice::{PictureId, SliceHeader};
5use crate::sps::Sps;
6use std::borrow::Cow;
7use std::sync::Arc;
8
/// Random-access classification of an [`AccessUnit`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessUnitKind {
    /// The access unit contains an IDR slice (`NalUnitType::IdrSlice`).
    Idr,
    /// The access unit carries a recovery point SEI message; the payload is that message's
    /// `recovery_frame_cnt`.
    RecoveryPoint(u32),
    /// Neither an IDR slice nor a recovery point was found; this is the initial kind of a
    /// freshly created access unit.
    NonIdr,
}
15
16#[derive(Debug, Clone)]
17pub struct AccessUnit {
18    pub nals: Vec<Nal>,
19    pub is_keyframe: bool,
20    pub kind: AccessUnitKind,
21    pub sps: Option<Arc<Sps>>,
22    pub pps: Option<Arc<Pps>>,
23    pub picture_id: Option<PictureId>,
24}
25
26impl AccessUnit {
27    pub fn new() -> Self {
28        Self {
29            nals: Vec::new(),
30            is_keyframe: false,
31            kind: AccessUnitKind::NonIdr,
32            sps: None,
33            pps: None,
34            picture_id: None,
35        }
36    }
37
38    pub fn is_keyframe(&self) -> bool {
39        self.is_keyframe
40    }
41
42    pub fn nals(&self) -> impl Iterator<Item = &Nal> {
43        self.nals.iter()
44    }
45
46    pub fn to_annexb_bytes(&self) -> Cow<'_, [u8]> {
47        let mut bytes = Vec::new();
48
49        for nal in &self.nals {
50            let start_code = if nal.start_code_len == 4 {
51                &[0x00, 0x00, 0x00, 0x01][..]
52            } else {
53                &[0x00, 0x00, 0x01][..]
54            };
55
56            bytes.extend_from_slice(start_code);
57
58            let header = ((nal.ref_idc & 0b11) << 5) | (nal.nal_type.as_u8() & 0b11111);
59            bytes.push(header);
60
61            bytes.extend_from_slice(&nal.ebsp);
62        }
63
64        Cow::Owned(bytes)
65    }
66
67    pub fn add_nal(&mut self, nal: Nal) {
68        if nal.nal_type == NalUnitType::IdrSlice {
69            self.kind = AccessUnitKind::Idr;
70            self.is_keyframe = true;
71        }
72
73        self.nals.push(nal);
74    }
75
76    pub fn set_sps(&mut self, sps: Arc<Sps>) {
77        self.sps = Some(sps);
78    }
79
80    pub fn set_pps(&mut self, pps: Arc<Pps>) {
81        self.pps = Some(pps);
82    }
83
84    pub fn check_recovery_point(&mut self) {
85        for nal in &self.nals {
86            if nal.nal_type == NalUnitType::Sei {
87                let rbsp = nal.to_rbsp();
88                if let Ok(messages) = SeiMessage::parse(&rbsp) {
89                    for msg in messages {
90                        if let SeiPayload::RecoveryPoint {
91                            recovery_frame_cnt, ..
92                        } = msg.payload
93                        {
94                            if recovery_frame_cnt == 0 {
95                                self.kind = AccessUnitKind::RecoveryPoint(0);
96                                self.is_keyframe = true;
97                            } else {
98                                self.kind = AccessUnitKind::RecoveryPoint(recovery_frame_cnt);
99                            }
100                        }
101                    }
102                }
103            }
104        }
105    }
106
107    pub fn set_picture_id_from_slice(
108        &mut self,
109        slice_header: &SliceHeader,
110        nal_type: NalUnitType,
111        sps: &Sps,
112    ) {
113        self.picture_id = Some(PictureId::from_slice_header(slice_header, nal_type, sps));
114    }
115}
116
117pub struct AccessUnitBuilder {
118    current_au: Option<AccessUnit>,
119    current_picture_id: Option<PictureId>,
120}
121
122impl AccessUnitBuilder {
123    pub fn new() -> Self {
124        Self {
125            current_au: None,
126            current_picture_id: None,
127        }
128    }
129
130    pub fn is_au_boundary(
131        &self,
132        nal: &Nal,
133        slice_header: Option<&SliceHeader>,
134        sps: Option<&Sps>,
135    ) -> bool {
136        if nal.nal_type == NalUnitType::Aud {
137            return true;
138        }
139
140        if !nal.is_vcl() {
141            return false;
142        }
143
144        if self.current_picture_id.is_none() {
145            return true;
146        }
147
148        if let (Some(header), Some(sps)) = (slice_header, sps) {
149            let new_picture_id = PictureId::from_slice_header(header, nal.nal_type, sps);
150
151            if let Some(ref current_id) = self.current_picture_id {
152                return &new_picture_id != current_id;
153            }
154        }
155
156        false
157    }
158
159    pub fn add_nal(
160        &mut self,
161        nal: Nal,
162        slice_header: Option<SliceHeader>,
163        sps: Option<Arc<Sps>>,
164        pps: Option<Arc<Pps>>,
165    ) -> Option<AccessUnit> {
166        let is_boundary = if let (Some(ref header), Some(ref sps_ref)) = (&slice_header, &sps) {
167            self.is_au_boundary(&nal, Some(header), Some(sps_ref))
168        } else {
169            self.is_au_boundary(&nal, None, None)
170        };
171
172        let mut completed_au = None;
173
174        if is_boundary && self.current_au.is_some() {
175            if let Some(mut au) = self.current_au.take() {
176                au.check_recovery_point();
177                completed_au = Some(au);
178            }
179            self.current_picture_id = None;
180        }
181
182        if self.current_au.is_none() {
183            self.current_au = Some(AccessUnit::new());
184        }
185
186        if let Some(ref mut au) = self.current_au {
187            if let Some(sps) = sps {
188                au.set_sps(sps);
189            }
190
191            if let Some(pps) = pps {
192                au.set_pps(pps);
193            }
194
195            if let (Some(header), Some(ref sps_ref)) = (slice_header, &au.sps) {
196                let picture_id = PictureId::from_slice_header(&header, nal.nal_type, sps_ref);
197                self.current_picture_id = Some(picture_id.clone());
198                au.picture_id = Some(picture_id);
199            }
200
201            au.add_nal(nal);
202        }
203
204        completed_au
205    }
206
207    pub fn flush(mut self) -> Option<AccessUnit> {
208        if let Some(mut au) = self.current_au.take() {
209            au.check_recovery_point();
210            Some(au)
211        } else {
212            None
213        }
214    }
215
216    pub fn flush_pending(&mut self) -> Option<AccessUnit> {
217        if let Some(mut au) = self.current_au.take() {
218            au.check_recovery_point();
219            self.current_picture_id = None;
220            Some(au)
221        } else {
222            None
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// An IDR slice must turn a fresh (non-keyframe) access unit into an `Idr` keyframe.
    #[test]
    fn test_access_unit_keyframe_detection() {
        let mut au = AccessUnit::new();
        assert!(!au.is_keyframe(), "a fresh access unit must not be a keyframe");

        au.add_nal(Nal {
            start_code_len: 4,
            ref_idc: 3,
            nal_type: NalUnitType::IdrSlice,
            ebsp: Vec::new(),
        });

        assert!(au.is_keyframe(), "an IDR slice must mark the unit as a keyframe");
        assert_eq!(au.kind, AccessUnitKind::Idr);
    }

    /// A 3-byte start code, a rebuilt header byte and the payload are emitted in order.
    #[test]
    fn test_to_annexb_bytes() {
        let mut au = AccessUnit::new();
        au.add_nal(Nal {
            start_code_len: 3,
            ref_idc: 2,
            nal_type: NalUnitType::Sps,
            ebsp: vec![0x42, 0x00, 0x1f],
        });

        let serialized = au.to_annexb_bytes();

        // 00 00 01 = start code; 0x47 = (ref_idc 2 << 5) | type 7; then the payload.
        let expected = [0x00, 0x00, 0x01, 0x47, 0x42, 0x00, 0x1f];
        assert_eq!(&serialized[..], &expected[..]);
    }
}