1use crate::au::{AccessUnit, AccessUnitBuilder};
2use crate::bytescan::StartCodeScanner;
3use crate::nal::{Nal, NalUnitType};
4use crate::pps::Pps;
5use crate::slice::SliceHeader;
6use crate::sps::Sps;
7use crate::{Error, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct AnnexBParser {
12 scanner: StartCodeScanner,
13 au_builder: AccessUnitBuilder,
14 sps_map: HashMap<u8, Arc<Sps>>,
15 pps_map: HashMap<u8, Arc<Pps>>,
16}
17
18impl AnnexBParser {
19 pub fn new() -> Self {
20 Self {
21 scanner: StartCodeScanner::new(),
22 au_builder: AccessUnitBuilder::new(),
23 sps_map: HashMap::new(),
24 pps_map: HashMap::new(),
25 }
26 }
27
28 pub fn push(&mut self, data: &[u8]) {
29 self.scanner.push(data);
30 }
31
32 pub fn next_access_unit(&mut self) -> Result<Option<AccessUnit>> {
33 self.next_access_unit_internal(false)
34 }
35
36 pub fn next_access_unit_final(&mut self) -> Result<Option<AccessUnit>> {
37 self.next_access_unit_internal(true)
38 }
39
40 pub fn drain(mut self) -> impl Iterator<Item = Result<AccessUnit>> {
41 let mut results = Vec::new();
42
43 loop {
44 match self.next_access_unit_internal(true) {
45 Ok(Some(au)) => results.push(Ok(au)),
46 Ok(None) => break,
47 Err(err) => {
48 results.push(Err(err));
49 break;
50 }
51 }
52 }
53
54 results.into_iter()
55 }
56
57 pub fn reset(&mut self) {
58 self.scanner.reset();
59 self.au_builder = AccessUnitBuilder::new();
60 self.sps_map.clear();
61 self.pps_map.clear();
62 }
63
64 fn next_access_unit_internal(&mut self, finalize: bool) -> Result<Option<AccessUnit>> {
65 loop {
66 match self.fetch_nal_bytes(finalize)? {
67 Some((start_code_len, nal_bytes)) => {
68 if let Some(au) = self.process_nal(start_code_len, nal_bytes)? {
69 return Ok(Some(au));
70 }
71 }
72 None => {
73 if !finalize && self.scanner.has_pending_start() {
74 return Ok(None);
75 }
76
77 let pending = self.au_builder.flush_pending();
78 return Ok(pending);
79 }
80 }
81 }
82 }
83
84 fn fetch_nal_bytes(&mut self, finalize: bool) -> Result<Option<(u8, Vec<u8>)>> {
85 if let Some(span) = self.scanner.next_nal_unit()? {
86 let nal_data = self.scanner.get_nal_data(&span).to_vec();
87 self.scanner.consume_processed(span.data_end);
88 return Ok(Some((span.start_code_len, nal_data)));
89 }
90
91 if finalize {
92 if let Some(span) = self.scanner.finish_pending() {
93 let nal_data = self.scanner.get_nal_data(&span).to_vec();
94 self.scanner.consume_processed(span.data_end);
95 return Ok(Some((span.start_code_len, nal_data)));
96 }
97 }
98
99 Ok(None)
100 }
101
102 fn process_nal(
103 &mut self,
104 start_code_len: u8,
105 nal_bytes: Vec<u8>,
106 ) -> Result<Option<AccessUnit>> {
107 let nal = Nal::parse(start_code_len, &nal_bytes)?;
108
109 match nal.nal_type {
110 NalUnitType::Sps => {
111 let rbsp = nal.to_rbsp();
112 let sps = Sps::parse(&rbsp)?;
113 let sps_id = sps.seq_parameter_set_id;
114 self.sps_map.insert(sps_id, Arc::new(sps));
115 }
116 NalUnitType::Pps => {
117 let rbsp = nal.to_rbsp();
118 let pps = Pps::parse(&rbsp)?;
119 let pps_id = pps.pic_parameter_set_id;
120 self.pps_map.insert(pps_id, Arc::new(pps));
121 }
122 _ => {}
123 }
124
125 let mut slice_header = None;
126 let mut sps = None;
127 let mut pps = None;
128
129 if nal.is_slice() {
130 let rbsp = nal.to_rbsp();
131
132 let temp_header = parse_slice_header_minimal(&rbsp)?;
133 let pps_id = temp_header.0;
134
135 if let Some(pps_ref) = self.pps_map.get(&pps_id) {
136 pps = Some(pps_ref.clone());
137 let sps_id = pps_ref.seq_parameter_set_id;
138
139 if let Some(sps_ref) = self.sps_map.get(&sps_id) {
140 sps = Some(sps_ref.clone());
141
142 slice_header =
143 Some(SliceHeader::parse(&rbsp, nal.nal_type, &sps_ref, &pps_ref)?);
144 } else {
145 return Err(Error::MissingSps(sps_id));
146 }
147 } else {
148 return Err(Error::MissingPps(pps_id));
149 }
150 }
151
152 let owned_nal = nal.clone();
153
154 Ok(self.au_builder.add_nal(owned_nal, slice_header, sps, pps))
155 }
156}
157
158impl Default for AnnexBParser {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
164fn parse_slice_header_minimal(rbsp: &[u8]) -> Result<(u8,)> {
165 use crate::bitreader::BitReader;
166 use crate::eg::read_ue;
167
168 let mut reader = BitReader::new(rbsp);
169
170 let _first_mb_in_slice = read_ue(&mut reader)?;
171 let _slice_type = read_ue(&mut reader)?;
172 let pic_parameter_set_id = read_ue(&mut reader)?;
173
174 if pic_parameter_set_id > 255 {
175 return Err(Error::SliceParseError("Invalid PPS ID".into()));
176 }
177
178 Ok((pic_parameter_set_id as u8,))
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed parser holds no parameter sets.
    #[test]
    fn test_parser_creation() {
        let parser = AnnexBParser::new();
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }

    /// Pushing bytes only buffers them; parameter sets are not registered
    /// until an access unit is requested.
    #[test]
    fn test_parser_with_simple_stream() {
        let mut parser = AnnexBParser::new();

        let sps_data: &[u8] = &[
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0xac, 0x34, 0xc8, 0x14, 0x00, 0x00,
            0x03, 0x00, 0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc6, 0x58,
        ];
        parser.push(sps_data);

        let pps_data: &[u8] = &[0x00, 0x00, 0x00, 0x01, 0x68, 0xee, 0x3c, 0x80];
        parser.push(pps_data);

        // The previous assertion ended in `|| true` and could never fail.
        // `push` only forwards bytes to the scanner, so both maps must still
        // be empty at this point.
        // NOTE(review): once the fixtures above are confirmed to parse, extend
        // this test to call `next_access_unit_final` and check that the SPS
        // and PPS were registered.
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }
}