1use crate::nal::{Nal, NalUnitType};
2use crate::pps::Pps;
3use crate::sei::{SeiMessage, SeiPayload};
4use crate::slice::{PictureId, SliceHeader};
5use crate::sps::Sps;
6use std::borrow::Cow;
7use std::sync::Arc;
8
/// Random-access classification of an [`AccessUnit`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessUnitKind {
    /// The access unit contains an IDR slice.
    Idr,
    /// The access unit carries a recovery point SEI; the payload is the SEI's
    /// `recovery_frame_cnt`.
    RecoveryPoint(u32),
    /// Neither an IDR slice nor a recovery point SEI was found (the default).
    NonIdr,
}
15
/// A group of NAL units assembled into one access unit (nominally one coded
/// picture), plus the metadata collected while assembling it.
#[derive(Debug, Clone)]
pub struct AccessUnit {
    /// NAL units in the order they were added.
    pub nals: Vec<Nal>,
    /// Set when an IDR slice is added, or when `check_recovery_point` finds a
    /// recovery point SEI with `recovery_frame_cnt == 0`.
    pub is_keyframe: bool,
    /// Random-access classification; starts out as `AccessUnitKind::NonIdr`.
    pub kind: AccessUnitKind,
    /// Sequence parameter set attached via `set_sps` (the last one set wins).
    pub sps: Option<Arc<Sps>>,
    /// Picture parameter set attached via `set_pps` (the last one set wins).
    pub pps: Option<Arc<Pps>>,
    /// Identifier of the coded picture, derived from a slice header when one
    /// was available.
    pub picture_id: Option<PictureId>,
}
25
26impl AccessUnit {
27 pub fn new() -> Self {
28 Self {
29 nals: Vec::new(),
30 is_keyframe: false,
31 kind: AccessUnitKind::NonIdr,
32 sps: None,
33 pps: None,
34 picture_id: None,
35 }
36 }
37
38 pub fn is_keyframe(&self) -> bool {
39 self.is_keyframe
40 }
41
42 pub fn nals(&self) -> impl Iterator<Item = &Nal> {
43 self.nals.iter()
44 }
45
46 pub fn to_annexb_bytes(&self) -> Cow<'_, [u8]> {
47 let mut bytes = Vec::new();
48
49 for nal in &self.nals {
50 let start_code = if nal.start_code_len == 4 {
51 &[0x00, 0x00, 0x00, 0x01][..]
52 } else {
53 &[0x00, 0x00, 0x01][..]
54 };
55
56 bytes.extend_from_slice(start_code);
57
58 let header = ((nal.ref_idc & 0b11) << 5) | (nal.nal_type.as_u8() & 0b11111);
59 bytes.push(header);
60
61 bytes.extend_from_slice(&nal.ebsp);
62 }
63
64 Cow::Owned(bytes)
65 }
66
67 pub fn add_nal(&mut self, nal: Nal) {
68 if nal.nal_type == NalUnitType::IdrSlice {
69 self.kind = AccessUnitKind::Idr;
70 self.is_keyframe = true;
71 }
72
73 self.nals.push(nal);
74 }
75
76 pub fn set_sps(&mut self, sps: Arc<Sps>) {
77 self.sps = Some(sps);
78 }
79
80 pub fn set_pps(&mut self, pps: Arc<Pps>) {
81 self.pps = Some(pps);
82 }
83
84 pub fn check_recovery_point(&mut self) {
85 for nal in &self.nals {
86 if nal.nal_type == NalUnitType::Sei {
87 let rbsp = nal.to_rbsp();
88 if let Ok(messages) = SeiMessage::parse(&rbsp) {
89 for msg in messages {
90 if let SeiPayload::RecoveryPoint {
91 recovery_frame_cnt, ..
92 } = msg.payload
93 {
94 if recovery_frame_cnt == 0 {
95 self.kind = AccessUnitKind::RecoveryPoint(0);
96 self.is_keyframe = true;
97 } else {
98 self.kind = AccessUnitKind::RecoveryPoint(recovery_frame_cnt);
99 }
100 }
101 }
102 }
103 }
104 }
105 }
106
107 pub fn set_picture_id_from_slice(
108 &mut self,
109 slice_header: &SliceHeader,
110 nal_type: NalUnitType,
111 sps: &Sps,
112 ) {
113 self.picture_id = Some(PictureId::from_slice_header(slice_header, nal_type, sps));
114 }
115}
116
/// Incrementally groups a stream of NAL units into [`AccessUnit`]s.
pub struct AccessUnitBuilder {
    /// Access unit currently being assembled, if any.
    current_au: Option<AccessUnit>,
    /// Picture id of the access unit being assembled. It is used to detect a
    /// slice that belongs to a different picture, and stays `None` until one
    /// has been computed.
    current_picture_id: Option<PictureId>,
}
121
122impl AccessUnitBuilder {
123 pub fn new() -> Self {
124 Self {
125 current_au: None,
126 current_picture_id: None,
127 }
128 }
129
130 pub fn is_au_boundary(
131 &self,
132 nal: &Nal,
133 slice_header: Option<&SliceHeader>,
134 sps: Option<&Sps>,
135 ) -> bool {
136 if nal.nal_type == NalUnitType::Aud {
137 return true;
138 }
139
140 if !nal.is_vcl() {
141 return false;
142 }
143
144 if self.current_picture_id.is_none() {
145 return true;
146 }
147
148 if let (Some(header), Some(sps)) = (slice_header, sps) {
149 let new_picture_id = PictureId::from_slice_header(header, nal.nal_type, sps);
150
151 if let Some(ref current_id) = self.current_picture_id {
152 return &new_picture_id != current_id;
153 }
154 }
155
156 false
157 }
158
159 pub fn add_nal(
160 &mut self,
161 nal: Nal,
162 slice_header: Option<SliceHeader>,
163 sps: Option<Arc<Sps>>,
164 pps: Option<Arc<Pps>>,
165 ) -> Option<AccessUnit> {
166 let is_boundary = if let (Some(ref header), Some(ref sps_ref)) = (&slice_header, &sps) {
167 self.is_au_boundary(&nal, Some(header), Some(sps_ref))
168 } else {
169 self.is_au_boundary(&nal, None, None)
170 };
171
172 let mut completed_au = None;
173
174 if is_boundary && self.current_au.is_some() {
175 if let Some(mut au) = self.current_au.take() {
176 au.check_recovery_point();
177 completed_au = Some(au);
178 }
179 self.current_picture_id = None;
180 }
181
182 if self.current_au.is_none() {
183 self.current_au = Some(AccessUnit::new());
184 }
185
186 if let Some(ref mut au) = self.current_au {
187 if let Some(sps) = sps {
188 au.set_sps(sps);
189 }
190
191 if let Some(pps) = pps {
192 au.set_pps(pps);
193 }
194
195 if let (Some(header), Some(ref sps_ref)) = (slice_header, &au.sps) {
196 let picture_id = PictureId::from_slice_header(&header, nal.nal_type, sps_ref);
197 self.current_picture_id = Some(picture_id.clone());
198 au.picture_id = Some(picture_id);
199 }
200
201 au.add_nal(nal);
202 }
203
204 completed_au
205 }
206
207 pub fn flush(mut self) -> Option<AccessUnit> {
208 if let Some(mut au) = self.current_au.take() {
209 au.check_recovery_point();
210 Some(au)
211 } else {
212 None
213 }
214 }
215
216 pub fn flush_pending(&mut self) -> Option<AccessUnit> {
217 if let Some(mut au) = self.current_au.take() {
218 au.check_recovery_point();
219 self.current_picture_id = None;
220 Some(au)
221 } else {
222 None
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding an IDR slice flips a fresh access unit into an IDR keyframe.
    #[test]
    fn test_access_unit_keyframe_detection() {
        let mut unit = AccessUnit::new();
        assert!(!unit.is_keyframe());

        unit.add_nal(Nal {
            nal_type: NalUnitType::IdrSlice,
            ref_idc: 3,
            start_code_len: 4,
            ebsp: vec![],
        });

        assert!(unit.is_keyframe());
        assert_eq!(unit.kind, AccessUnitKind::Idr);
    }

    /// A 3-byte start code, a rebuilt header and the verbatim payload are
    /// emitted back to back.
    #[test]
    fn test_to_annexb_bytes() {
        let mut unit = AccessUnit::new();
        unit.add_nal(Nal {
            nal_type: NalUnitType::Sps,
            ref_idc: 2,
            start_code_len: 3,
            ebsp: vec![0x42, 0x00, 0x1f],
        });

        // Header byte: (nal_ref_idc 2 << 5) | nal_unit_type 7 == 0x47.
        let expected: &[u8] = &[0x00, 0x00, 0x01, 0x47, 0x42, 0x00, 0x1f];
        let encoded = unit.to_annexb_bytes();
        assert_eq!(&encoded[..], expected);
    }
}