1use crate::au::{AccessUnit, AccessUnitBuilder};
2use crate::bytescan::{NalSpan, StartCodeScanner};
3use crate::nal::{Nal, NalUnitType};
4use crate::pps::Pps;
5use crate::slice::SliceHeader;
6use crate::sps::Sps;
7use crate::{Error, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct AnnexBParser {
12 scanner: StartCodeScanner,
13 au_builder: AccessUnitBuilder,
14 sps_map: HashMap<u8, Arc<Sps>>,
15 pps_map: HashMap<u8, Arc<Pps>>,
16 pending_nals: Vec<(NalSpan, Vec<u8>)>,
17}
18
19impl AnnexBParser {
20 pub fn new() -> Self {
21 Self {
22 scanner: StartCodeScanner::new(),
23 au_builder: AccessUnitBuilder::new(),
24 sps_map: HashMap::new(),
25 pps_map: HashMap::new(),
26 pending_nals: Vec::new(),
27 }
28 }
29
30 pub fn push(&mut self, data: &[u8]) {
31 self.scanner.push(data);
32 }
33
34 pub fn next_access_unit(&mut self) -> Result<Option<AccessUnit>> {
35 loop {
36 let nal_span_result = self.scanner.next_nal_unit()?;
37 if let Some(nal_span) = nal_span_result {
39 let nal_data = self.scanner.get_nal_data(&nal_span).to_vec();
40
41 let nal = Nal::parse(nal_span.start_code_len, &nal_data)?;
42
43 match nal.nal_type {
44 NalUnitType::Sps => {
45 let rbsp = nal.to_rbsp();
46 let sps = Sps::parse(&rbsp)?;
47 let sps_id = sps.seq_parameter_set_id;
48 self.sps_map.insert(sps_id, Arc::new(sps));
49 }
50 NalUnitType::Pps => {
51 let rbsp = nal.to_rbsp();
52 let pps = Pps::parse(&rbsp)?;
53 let pps_id = pps.pic_parameter_set_id;
54 self.pps_map.insert(pps_id, Arc::new(pps));
55 }
56 _ => {}
57 }
58
59 let mut slice_header = None;
60 let mut sps = None;
61 let mut pps = None;
62
63 if nal.is_slice() {
64 let rbsp = nal.to_rbsp();
65
66 let temp_header = parse_slice_header_minimal(&rbsp)?;
67 let pps_id = temp_header.0;
68
69 if let Some(pps_ref) = self.pps_map.get(&pps_id) {
70 pps = Some(pps_ref.clone());
71 let sps_id = pps_ref.seq_parameter_set_id;
72
73 if let Some(sps_ref) = self.sps_map.get(&sps_id) {
74 sps = Some(sps_ref.clone());
75
76 slice_header = Some(SliceHeader::parse(
77 &rbsp,
78 nal.nal_type,
79 &sps_ref,
80 &pps_ref,
81 )?);
82 } else {
83 return Err(Error::MissingSps(sps_id));
84 }
85 } else {
86 return Err(Error::MissingPps(pps_id));
87 }
88 }
89
90 let owned_nal = nal.clone();
92
93 if let Some(au) = self.au_builder.add_nal(owned_nal, slice_header, sps, pps) {
94 return Ok(Some(au));
95 }
96 } else {
97 if let Some(au) = self.au_builder.flush_pending() {
100 return Ok(Some(au));
101 }
102 return Ok(None);
103 }
104 }
105 }
106
107 pub fn drain(mut self) -> impl Iterator<Item = Result<AccessUnit>> {
108 let mut results = Vec::new();
109
110 while let Ok(Some(au)) = self.next_access_unit() {
111 results.push(Ok(au));
112 }
113
114 if let Some(au) = self.au_builder.flush() {
115 results.push(Ok(au));
116 }
117
118 results.into_iter()
119 }
120
121 pub fn reset(&mut self) {
122 self.scanner.reset();
123 self.au_builder = AccessUnitBuilder::new();
124 self.sps_map.clear();
125 self.pps_map.clear();
126 self.pending_nals.clear();
127 }
128}
129
130impl Default for AnnexBParser {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136fn parse_slice_header_minimal(rbsp: &[u8]) -> Result<(u8,)> {
137 use crate::bitreader::BitReader;
138 use crate::eg::read_ue;
139
140 let mut reader = BitReader::new(rbsp);
141
142 let _first_mb_in_slice = read_ue(&mut reader)?;
143 let _slice_type = read_ue(&mut reader)?;
144 let pic_parameter_set_id = read_ue(&mut reader)?;
145
146 if pic_parameter_set_id > 255 {
147 return Err(Error::SliceParseError("Invalid PPS ID".into()));
148 }
149
150 Ok((pic_parameter_set_id as u8,))
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parser_creation() {
        let parser = AnnexBParser::new();
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }

    #[test]
    fn test_parser_with_simple_stream() {
        let mut parser = AnnexBParser::new();

        // SPS NAL unit (header byte 0x67), including a 00 00 03 emulation-prevention run.
        let sps_data: &[u8] = &[
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f,
            0xac, 0x34, 0xc8, 0x14, 0x00, 0x00, 0x03, 0x00,
            0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60,
            0xc6, 0x58,
        ];
        parser.push(sps_data);

        // PPS NAL unit (header byte 0x68).
        let pps_data: &[u8] = &[0x00, 0x00, 0x00, 0x01, 0x68, 0xee, 0x3c, 0x80];
        parser.push(pps_data);

        // `push` only buffers bytes; parameter sets are parsed lazily by
        // `next_access_unit`, so nothing may be registered yet.
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }

    #[test]
    fn test_minimal_slice_header_reads_pps_id() {
        // first_mb_in_slice = 0 ('1'), slice_type = 0 ('1'), pic_parameter_set_id = 0 ('1').
        assert_eq!(parse_slice_header_minimal(&[0xE0]).ok(), Some((0,)));
        // first_mb_in_slice = 0 ('1'), slice_type = 0 ('1'), pic_parameter_set_id = 3 ('00100').
        assert_eq!(parse_slice_header_minimal(&[0xC9]).ok(), Some((3,)));
    }

    #[test]
    fn test_minimal_slice_header_rejects_out_of_range_pps_id() {
        // '1' '1' then pic_parameter_set_id = 256 ('00000000' '100000001'), which exceeds u8.
        let result = parse_slice_header_minimal(&[0xC0, 0x20, 0x20]);
        assert!(matches!(result, Err(Error::SliceParseError(_))));
    }
}