h264_parser/
parser.rs

1use crate::au::{AccessUnit, AccessUnitBuilder};
2use crate::bytescan::StartCodeScanner;
3use crate::nal::{Nal, NalUnitType};
4use crate::pps::Pps;
5use crate::slice::SliceHeader;
6use crate::sps::Sps;
7use crate::{Error, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct AnnexBParser {
12    scanner: StartCodeScanner,
13    au_builder: AccessUnitBuilder,
14    sps_map: HashMap<u8, Arc<Sps>>,
15    pps_map: HashMap<u8, Arc<Pps>>,
16    sps_nal_map: HashMap<u8, Nal>,
17    pps_nal_map: HashMap<u8, Nal>,
18}
19
20impl AnnexBParser {
21    pub fn new() -> Self {
22        Self {
23            scanner: StartCodeScanner::new(),
24            au_builder: AccessUnitBuilder::new(),
25            sps_map: HashMap::new(),
26            pps_map: HashMap::new(),
27            sps_nal_map: HashMap::new(),
28            pps_nal_map: HashMap::new(),
29        }
30    }
31
32    pub fn push(&mut self, data: &[u8]) {
33        self.scanner.push(data);
34    }
35
36    pub fn next_access_unit(&mut self) -> Result<Option<AccessUnit>> {
37        self.next_access_unit_internal(false)
38    }
39
40    pub fn next_access_unit_final(&mut self) -> Result<Option<AccessUnit>> {
41        self.next_access_unit_internal(true)
42    }
43
44    pub fn drain(mut self) -> impl Iterator<Item = Result<AccessUnit>> {
45        let mut results = Vec::new();
46
47        loop {
48            match self.next_access_unit_internal(true) {
49                Ok(Some(au)) => results.push(Ok(au)),
50                Ok(None) => break,
51                Err(err) => {
52                    results.push(Err(err));
53                    break;
54                }
55            }
56        }
57
58        results.into_iter()
59    }
60
61    pub fn reset(&mut self) {
62        self.scanner.reset();
63        self.au_builder = AccessUnitBuilder::new();
64        self.sps_map.clear();
65        self.pps_map.clear();
66        self.sps_nal_map.clear();
67        self.pps_nal_map.clear();
68    }
69
70    fn next_access_unit_internal(&mut self, finalize: bool) -> Result<Option<AccessUnit>> {
71        loop {
72            match self.fetch_nal_bytes(finalize)? {
73                Some((start_code_len, nal_bytes)) => {
74                    if let Some(au) = self.process_nal(start_code_len, nal_bytes)? {
75                        return Ok(Some(au));
76                    }
77                }
78                None => {
79                    if !finalize && self.scanner.has_pending_start() {
80                        return Ok(None);
81                    }
82
83                    let pending = self.au_builder.flush_pending();
84                    return Ok(pending);
85                }
86            }
87        }
88    }
89
90    fn fetch_nal_bytes(&mut self, finalize: bool) -> Result<Option<(u8, Vec<u8>)>> {
91        if let Some(span) = self.scanner.next_nal_unit()? {
92            let nal_data = self.scanner.get_nal_data(&span).to_vec();
93            self.scanner.consume_processed(span.data_end);
94            return Ok(Some((span.start_code_len, nal_data)));
95        }
96
97        if finalize {
98            if let Some(span) = self.scanner.finish_pending() {
99                let nal_data = self.scanner.get_nal_data(&span).to_vec();
100                self.scanner.consume_processed(span.data_end);
101                return Ok(Some((span.start_code_len, nal_data)));
102            }
103        }
104
105        Ok(None)
106    }
107
108    fn process_nal(
109        &mut self,
110        start_code_len: u8,
111        nal_bytes: Vec<u8>,
112    ) -> Result<Option<AccessUnit>> {
113        let nal = Nal::parse(start_code_len, &nal_bytes)?;
114
115        match nal.nal_type {
116            NalUnitType::Sps => {
117                let rbsp = nal.to_rbsp();
118                let sps = Sps::parse(&rbsp)?;
119                let sps_id = sps.seq_parameter_set_id;
120                self.sps_map.insert(sps_id, Arc::new(sps));
121                self.sps_nal_map.insert(sps_id, nal.clone());
122            }
123            NalUnitType::Pps => {
124                let rbsp = nal.to_rbsp();
125                let pps = Pps::parse(&rbsp)?;
126                let pps_id = pps.pic_parameter_set_id;
127                self.pps_map.insert(pps_id, Arc::new(pps));
128                self.pps_nal_map.insert(pps_id, nal.clone());
129            }
130            _ => {}
131        }
132
133        let mut slice_header = None;
134        let mut sps = None;
135        let mut pps = None;
136        let mut extra_parameter_sets = Vec::new();
137
138        if nal.is_slice() {
139            let rbsp = nal.to_rbsp();
140
141            let temp_header = parse_slice_header_minimal(&rbsp)?;
142            let pps_id = temp_header.0;
143
144            if let Some(pps_ref) = self.pps_map.get(&pps_id) {
145                pps = Some(pps_ref.clone());
146                let sps_id = pps_ref.seq_parameter_set_id;
147
148                if let Some(sps_ref) = self.sps_map.get(&sps_id) {
149                    sps = Some(sps_ref.clone());
150
151                    if let Some(sps_nal) = self.sps_nal_map.get(&sps_id) {
152                        extra_parameter_sets.push(sps_nal.clone());
153                    }
154
155                    if let Some(pps_nal) = self.pps_nal_map.get(&pps_id) {
156                        extra_parameter_sets.push(pps_nal.clone());
157                    }
158
159                    slice_header =
160                        Some(SliceHeader::parse(&rbsp, nal.nal_type, &sps_ref, &pps_ref)?);
161                } else {
162                    return Err(Error::MissingSps(sps_id));
163                }
164            } else {
165                return Err(Error::MissingPps(pps_id));
166            }
167        }
168
169        let owned_nal = nal.clone();
170
171        Ok(self
172            .au_builder
173            .add_nal(owned_nal, slice_header, sps, pps, extra_parameter_sets))
174    }
175}
176
177impl Default for AnnexBParser {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183fn parse_slice_header_minimal(rbsp: &[u8]) -> Result<(u8,)> {
184    use crate::bitreader::BitReader;
185    use crate::eg::read_ue;
186
187    let mut reader = BitReader::new(rbsp);
188
189    let _first_mb_in_slice = read_ue(&mut reader)?;
190    let _slice_type = read_ue(&mut reader)?;
191    let pic_parameter_set_id = read_ue(&mut reader)?;
192
193    if pic_parameter_set_id > 255 {
194        return Err(Error::SliceParseError("Invalid PPS ID".into()));
195    }
196
197    Ok((pic_parameter_set_id as u8,))
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parser_creation() {
        let parser = AnnexBParser::new();
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
        assert!(parser.sps_nal_map.is_empty());
        assert!(parser.pps_nal_map.is_empty());
    }

    #[test]
    fn test_parser_with_simple_stream() {
        let mut parser = AnnexBParser::new();

        let sps_data = [
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0xac, 0x34, 0xc8, 0x14, 0x00, 0x00,
            0x03, 0x00, 0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc6, 0x58,
        ];
        parser.push(&sps_data);

        let pps_data = [0x00, 0x00, 0x00, 0x01, 0x68, 0xee, 0x3c, 0x80];
        parser.push(&pps_data);

        // `push` only buffers bytes; parameter sets are parsed when access
        // units are pulled, so nothing is registered yet.
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
        assert!(parser.sps_nal_map.is_empty());
        assert!(parser.pps_nal_map.is_empty());
    }

    #[test]
    fn test_minimal_slice_header_reads_pps_id() {
        // ue(v) fields: first_mb_in_slice "1" (0), slice_type "1" (0), pps_id "1" (0).
        assert!(matches!(
            parse_slice_header_minimal(&[0b1110_0000]),
            Ok((0,))
        ));
        // ue(v) fields: "1" (0), "1" (0), pps_id "010" (1).
        assert!(matches!(
            parse_slice_header_minimal(&[0b1101_0000]),
            Ok((1,))
        ));
    }

    #[test]
    fn test_minimal_slice_header_rejects_out_of_range_pps_id() {
        // "1", "1", then pps_id = 256 encoded as 8 zero bits + "100000001".
        assert!(parse_slice_header_minimal(&[0xC0, 0x20, 0x20]).is_err());
    }
}