1use crate::nal::{Nal, NalUnitType};
2use crate::pps::Pps;
3use crate::sei::{SeiMessage, SeiPayload};
4use crate::slice::{PictureId, SliceHeader, SliceType};
5use crate::sps::Sps;
6use std::borrow::Cow;
7use std::sync::Arc;
8
/// How an [`AccessUnit`] was classified from the NAL units it contains.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessUnitKind {
    /// The access unit contains an IDR slice.
    Idr,
    /// The access unit contains a recovery point SEI message; the value is its
    /// `recovery_frame_cnt`.
    RecoveryPoint(u32),
    /// Neither of the above; this is the initial kind of a freshly created access unit.
    NonIdr,
}
15
16#[derive(Debug, Clone)]
17pub struct AccessUnit {
18 pub nals: Vec<Nal>,
19 pub is_keyframe: bool,
20 pub kind: AccessUnitKind,
21 pub sps: Option<Arc<Sps>>,
22 pub pps: Option<Arc<Pps>>,
23 pub picture_id: Option<PictureId>,
24 parameter_sets: Vec<Nal>,
25 first_slice_type: Option<SliceType>,
26}
27
28impl AccessUnit {
29 pub fn new() -> Self {
30 Self {
31 nals: Vec::new(),
32 is_keyframe: false,
33 kind: AccessUnitKind::NonIdr,
34 sps: None,
35 pps: None,
36 picture_id: None,
37 parameter_sets: Vec::new(),
38 first_slice_type: None,
39 }
40 }
41
42 pub fn is_keyframe(&self) -> bool {
43 self.is_keyframe
44 }
45
46 pub fn nals(&self) -> impl Iterator<Item = &Nal> {
47 self.nals.iter()
48 }
49
50 pub fn to_annexb_bytes(&self) -> Cow<'_, [u8]> {
51 let mut bytes = Vec::new();
52
53 for nal in &self.nals {
54 Self::push_annexb_nal(&mut bytes, nal);
55 }
56
57 Cow::Owned(bytes)
58 }
59
60 pub fn to_annexb_webcodec_bytes(&self) -> Cow<'_, [u8]> {
61 let mut bytes = Vec::new();
62
63 let include_parameter_sets =
64 self.is_keyframe || matches!(self.first_slice_type, Some(SliceType::I));
65
66 if include_parameter_sets {
67 for nal in &self.parameter_sets {
68 Self::push_annexb_nal_internal(&mut bytes, nal, true);
69 }
70 }
71
72 for nal in &self.nals {
73 let include_nal = if include_parameter_sets {
74 !matches!(nal.nal_type, NalUnitType::Sps | NalUnitType::Pps)
75 } else {
76 true
77 };
78
79 if include_nal {
80 Self::push_annexb_nal_internal(&mut bytes, nal, true);
81 }
82 }
83
84 Cow::Owned(bytes)
85 }
86
87 pub fn add_nal(&mut self, nal: Nal) {
88 if matches!(nal.nal_type, NalUnitType::Sps | NalUnitType::Pps) {
89 self.add_parameter_set(nal.clone());
90 }
91
92 if nal.nal_type == NalUnitType::IdrSlice {
93 self.kind = AccessUnitKind::Idr;
94 self.is_keyframe = true;
95 }
96
97 self.nals.push(nal);
98 }
99
100 pub fn add_parameter_set(&mut self, nal: Nal) {
101 if !self.parameter_sets.iter().any(|existing| existing == &nal) {
102 self.parameter_sets.push(nal);
103 }
104 }
105
106 pub(crate) fn note_slice_type(&mut self, slice_type: SliceType) {
107 if self.first_slice_type.is_none() {
108 self.first_slice_type = Some(slice_type);
109 }
110 }
111
112 fn push_annexb_nal(bytes: &mut Vec<u8>, nal: &Nal) {
113 Self::push_annexb_nal_internal(bytes, nal, false);
114 }
115
116 fn push_annexb_nal_internal(bytes: &mut Vec<u8>, nal: &Nal, force_long_start_code: bool) {
117 let start_code = if force_long_start_code || nal.start_code_len == 4 {
118 &[0x00, 0x00, 0x00, 0x01][..]
119 } else {
120 &[0x00, 0x00, 0x01][..]
121 };
122
123 bytes.extend_from_slice(start_code);
124
125 let header = ((nal.ref_idc & 0b11) << 5) | (nal.nal_type.as_u8() & 0b11111);
126 bytes.push(header);
127
128 bytes.extend_from_slice(&nal.ebsp);
129 }
130
131
132 pub fn set_sps(&mut self, sps: Arc<Sps>) {
133 self.sps = Some(sps);
134 }
135
136 pub fn set_pps(&mut self, pps: Arc<Pps>) {
137 self.pps = Some(pps);
138 }
139
140 pub fn check_recovery_point(&mut self) {
141 for nal in &self.nals {
142 if nal.nal_type == NalUnitType::Sei {
143 let rbsp = nal.to_rbsp();
144 if let Ok(messages) = SeiMessage::parse(&rbsp) {
145 for msg in messages {
146 if let SeiPayload::RecoveryPoint {
147 recovery_frame_cnt, ..
148 } = msg.payload
149 {
150 if recovery_frame_cnt == 0 {
151 self.kind = AccessUnitKind::RecoveryPoint(0);
152 self.is_keyframe = true;
153 } else {
154 self.kind = AccessUnitKind::RecoveryPoint(recovery_frame_cnt);
155 }
156 }
157 }
158 }
159 }
160 }
161 }
162
163 pub fn set_picture_id_from_slice(
164 &mut self,
165 slice_header: &SliceHeader,
166 nal_type: NalUnitType,
167 sps: &Sps,
168 ) {
169 self.picture_id = Some(PictureId::from_slice_header(slice_header, nal_type, sps));
170 }
171}
172
173pub struct AccessUnitBuilder {
174 current_au: Option<AccessUnit>,
175 current_picture_id: Option<PictureId>,
176}
177
178impl AccessUnitBuilder {
179 pub fn new() -> Self {
180 Self {
181 current_au: None,
182 current_picture_id: None,
183 }
184 }
185
186 pub fn is_au_boundary(
187 &self,
188 nal: &Nal,
189 slice_header: Option<&SliceHeader>,
190 sps: Option<&Sps>,
191 ) -> bool {
192 if nal.nal_type == NalUnitType::Aud {
193 return true;
194 }
195
196 if !nal.is_vcl() {
197 return false;
198 }
199
200 if self.current_picture_id.is_none() {
201 return false;
202 }
203
204 if let (Some(header), Some(sps)) = (slice_header, sps) {
205 let new_picture_id = PictureId::from_slice_header(header, nal.nal_type, sps);
206
207 if let Some(ref current_id) = self.current_picture_id {
208 return &new_picture_id != current_id;
209 }
210 }
211
212 false
213 }
214
215 pub fn add_nal(
216 &mut self,
217 nal: Nal,
218 slice_header: Option<SliceHeader>,
219 sps: Option<Arc<Sps>>,
220 pps: Option<Arc<Pps>>,
221 mut extra_parameter_sets: Vec<Nal>,
222 ) -> Option<AccessUnit> {
223 let is_boundary = if let (Some(ref header), Some(ref sps_ref)) = (&slice_header, &sps) {
224 self.is_au_boundary(&nal, Some(header), Some(sps_ref))
225 } else {
226 self.is_au_boundary(&nal, None, None)
227 };
228
229 let mut completed_au = None;
230
231 if is_boundary && self.current_au.is_some() {
232 if let Some(mut au) = self.current_au.take() {
233 au.check_recovery_point();
234 completed_au = Some(au);
235 }
236 self.current_picture_id = None;
237 }
238
239 if self.current_au.is_none() {
240 self.current_au = Some(AccessUnit::new());
241 }
242
243 if let Some(ref mut au) = self.current_au {
244 if let Some(sps) = sps {
245 au.set_sps(sps);
246 }
247
248 if let Some(pps) = pps {
249 au.set_pps(pps);
250 }
251
252 for parameter_set in extra_parameter_sets.drain(..) {
253 au.add_parameter_set(parameter_set);
254 }
255
256 if let Some(header) = slice_header.as_ref() {
257 if nal.is_vcl() {
258 au.note_slice_type(header.slice_type);
259 }
260 }
261
262 if let (Some(header), Some(ref sps_ref)) = (slice_header.as_ref(), &au.sps) {
263 let picture_id = PictureId::from_slice_header(header, nal.nal_type, sps_ref);
264 self.current_picture_id = Some(picture_id.clone());
265 au.picture_id = Some(picture_id);
266 }
267
268 au.add_nal(nal);
269 }
270
271 completed_au
272 }
273
274 pub fn flush(mut self) -> Option<AccessUnit> {
275 if let Some(mut au) = self.current_au.take() {
276 au.check_recovery_point();
277 Some(au)
278 } else {
279 None
280 }
281 }
282
283 pub fn flush_pending(&mut self) -> Option<AccessUnit> {
284 if let Some(mut au) = self.current_au.take() {
285 au.check_recovery_point();
286 self.current_picture_id = None;
287 Some(au)
288 } else {
289 None
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_access_unit_keyframe_detection() {
        let mut unit = AccessUnit::new();
        assert!(!unit.is_keyframe());

        unit.add_nal(Nal {
            nal_type: NalUnitType::IdrSlice,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: Vec::new(),
        });

        assert!(unit.is_keyframe());
        assert_eq!(unit.kind, AccessUnitKind::Idr);
    }

    #[test]
    fn test_to_annexb_bytes() {
        let mut unit = AccessUnit::new();
        unit.add_nal(Nal {
            nal_type: NalUnitType::Sps,
            ref_idc: 2,
            start_code_len: 3,
            ebsp: vec![0x42, 0x00, 0x1f],
        });

        let encoded = unit.to_annexb_bytes();
        let (start_code, rest) = encoded.split_at(3);

        // The 3-byte start code is preserved; header = (ref_idc 2 << 5) | type 7 = 0x47.
        assert_eq!(start_code, &[0x00, 0x00, 0x01]);
        assert_eq!(rest[0], 0x47);
        assert_eq!(&rest[1..], &[0x42, 0x00, 0x1f]);
    }

    #[test]
    fn test_to_annexb_webcodec_bytes_includes_parameter_sets() {
        let sps = Nal {
            nal_type: NalUnitType::Sps,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: vec![0x42, 0x00, 0x1f],
        };
        let pps = Nal {
            nal_type: NalUnitType::Pps,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: vec![0xde, 0xad],
        };
        let idr_slice = Nal {
            nal_type: NalUnitType::IdrSlice,
            ref_idc: 3,
            start_code_len: 3,
            ebsp: vec![0xaa, 0xbb],
        };

        let mut unit = AccessUnit::new();
        // The SPS is registered twice; only one copy may be emitted.
        unit.add_parameter_set(sps.clone());
        unit.add_parameter_set(pps);
        unit.add_parameter_set(sps);
        unit.add_nal(idr_slice);

        let bytes = unit.to_annexb_webcodec_bytes();

        let expected: &[u8] = &[
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, // SPS
            0x00, 0x00, 0x00, 0x01, 0x68, 0xde, 0xad, // PPS
            0x00, 0x00, 0x00, 0x01, 0x65, 0xaa, 0xbb, // IDR slice, 4-byte start code
        ];
        assert_eq!(&*bytes, expected);
    }

    #[test]
    fn test_to_annexb_webcodec_bytes_for_delta_frame_excludes_parameter_sets() {
        let mut unit = AccessUnit::new();

        unit.add_parameter_set(Nal {
            nal_type: NalUnitType::Sps,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: vec![0x42, 0x00, 0x1f],
        });
        unit.add_parameter_set(Nal {
            nal_type: NalUnitType::Pps,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: vec![0xde, 0xad],
        });
        unit.note_slice_type(SliceType::P);
        unit.add_nal(Nal {
            nal_type: NalUnitType::NonIdrSlice,
            ref_idc: 2,
            start_code_len: 4,
            ebsp: vec![0x11, 0x22],
        });

        let bytes = unit.to_annexb_webcodec_bytes();

        // Only the P slice itself: header = (ref_idc 2 << 5) | type 1 = 0x41.
        let expected: &[u8] = &[0x00, 0x00, 0x00, 0x01, 0x41, 0x11, 0x22];
        assert_eq!(&*bytes, expected);
    }
}