h264_parser/
parser.rs

1use crate::au::{AccessUnit, AccessUnitBuilder};
2use crate::bytescan::StartCodeScanner;
3use crate::nal::{Nal, NalUnitType};
4use crate::pps::Pps;
5use crate::slice::SliceHeader;
6use crate::sps::Sps;
7use crate::{Error, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub struct AnnexBParser {
12    scanner: StartCodeScanner,
13    au_builder: AccessUnitBuilder,
14    sps_map: HashMap<u8, Arc<Sps>>,
15    pps_map: HashMap<u8, Arc<Pps>>,
16}
17
18impl AnnexBParser {
19    pub fn new() -> Self {
20        Self {
21            scanner: StartCodeScanner::new(),
22            au_builder: AccessUnitBuilder::new(),
23            sps_map: HashMap::new(),
24            pps_map: HashMap::new(),
25        }
26    }
27
28    pub fn push(&mut self, data: &[u8]) {
29        self.scanner.push(data);
30    }
31
32    pub fn next_access_unit(&mut self) -> Result<Option<AccessUnit>> {
33        self.next_access_unit_internal(false)
34    }
35
36    pub fn next_access_unit_final(&mut self) -> Result<Option<AccessUnit>> {
37        self.next_access_unit_internal(true)
38    }
39
40    pub fn drain(mut self) -> impl Iterator<Item = Result<AccessUnit>> {
41        let mut results = Vec::new();
42
43        loop {
44            match self.next_access_unit_internal(true) {
45                Ok(Some(au)) => results.push(Ok(au)),
46                Ok(None) => break,
47                Err(err) => {
48                    results.push(Err(err));
49                    break;
50                }
51            }
52        }
53
54        results.into_iter()
55    }
56
57    pub fn reset(&mut self) {
58        self.scanner.reset();
59        self.au_builder = AccessUnitBuilder::new();
60        self.sps_map.clear();
61        self.pps_map.clear();
62    }
63
64    fn next_access_unit_internal(&mut self, finalize: bool) -> Result<Option<AccessUnit>> {
65        loop {
66            match self.fetch_nal_bytes(finalize)? {
67                Some((start_code_len, nal_bytes)) => {
68                    if let Some(au) = self.process_nal(start_code_len, nal_bytes)? {
69                        return Ok(Some(au));
70                    }
71                }
72                None => {
73                    if !finalize && self.scanner.has_pending_start() {
74                        return Ok(None);
75                    }
76
77                    let pending = self.au_builder.flush_pending();
78                    return Ok(pending);
79                }
80            }
81        }
82    }
83
84    fn fetch_nal_bytes(&mut self, finalize: bool) -> Result<Option<(u8, Vec<u8>)>> {
85        if let Some(span) = self.scanner.next_nal_unit()? {
86            let nal_data = self.scanner.get_nal_data(&span).to_vec();
87            self.scanner.consume_processed(span.data_end);
88            return Ok(Some((span.start_code_len, nal_data)));
89        }
90
91        if finalize {
92            if let Some(span) = self.scanner.finish_pending() {
93                let nal_data = self.scanner.get_nal_data(&span).to_vec();
94                self.scanner.consume_processed(span.data_end);
95                return Ok(Some((span.start_code_len, nal_data)));
96            }
97        }
98
99        Ok(None)
100    }
101
102    fn process_nal(
103        &mut self,
104        start_code_len: u8,
105        nal_bytes: Vec<u8>,
106    ) -> Result<Option<AccessUnit>> {
107        let nal = Nal::parse(start_code_len, &nal_bytes)?;
108
109        match nal.nal_type {
110            NalUnitType::Sps => {
111                let rbsp = nal.to_rbsp();
112                let sps = Sps::parse(&rbsp)?;
113                let sps_id = sps.seq_parameter_set_id;
114                self.sps_map.insert(sps_id, Arc::new(sps));
115            }
116            NalUnitType::Pps => {
117                let rbsp = nal.to_rbsp();
118                let pps = Pps::parse(&rbsp)?;
119                let pps_id = pps.pic_parameter_set_id;
120                self.pps_map.insert(pps_id, Arc::new(pps));
121            }
122            _ => {}
123        }
124
125        let mut slice_header = None;
126        let mut sps = None;
127        let mut pps = None;
128
129        if nal.is_slice() {
130            let rbsp = nal.to_rbsp();
131
132            let temp_header = parse_slice_header_minimal(&rbsp)?;
133            let pps_id = temp_header.0;
134
135            if let Some(pps_ref) = self.pps_map.get(&pps_id) {
136                pps = Some(pps_ref.clone());
137                let sps_id = pps_ref.seq_parameter_set_id;
138
139                if let Some(sps_ref) = self.sps_map.get(&sps_id) {
140                    sps = Some(sps_ref.clone());
141
142                    slice_header =
143                        Some(SliceHeader::parse(&rbsp, nal.nal_type, &sps_ref, &pps_ref)?);
144                } else {
145                    return Err(Error::MissingSps(sps_id));
146                }
147            } else {
148                return Err(Error::MissingPps(pps_id));
149            }
150        }
151
152        let owned_nal = nal.clone();
153
154        Ok(self.au_builder.add_nal(owned_nal, slice_header, sps, pps))
155    }
156}
157
158impl Default for AnnexBParser {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
164fn parse_slice_header_minimal(rbsp: &[u8]) -> Result<(u8,)> {
165    use crate::bitreader::BitReader;
166    use crate::eg::read_ue;
167
168    let mut reader = BitReader::new(rbsp);
169
170    let _first_mb_in_slice = read_ue(&mut reader)?;
171    let _slice_type = read_ue(&mut reader)?;
172    let pic_parameter_set_id = read_ue(&mut reader)?;
173
174    if pic_parameter_set_id > 255 {
175        return Err(Error::SliceParseError("Invalid PPS ID".into()));
176    }
177
178    Ok((pic_parameter_set_id as u8,))
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parser_creation() {
        let parser = AnnexBParser::new();
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }

    #[test]
    fn test_parser_with_simple_stream() {
        let mut parser = AnnexBParser::new();

        let sps_data: &[u8] = &[
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0xac, 0x34, 0xc8, 0x14, 0x00, 0x00,
            0x03, 0x00, 0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60, 0xc6, 0x58,
        ];
        let pps_data: &[u8] = &[0x00, 0x00, 0x00, 0x01, 0x68, 0xee, 0x3c, 0x80];

        parser.push(sps_data);
        parser.push(pps_data);

        // `push` only buffers bytes; parameter sets are parsed when NAL units are pulled.
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());

        // Drain at end-of-stream. Errors are tolerated because this test only checks
        // parameter-set bookkeeping. The iteration bound guards against looping forever.
        for _ in 0..16 {
            if let Ok(None) = parser.next_access_unit_final() {
                break;
            }
        }

        // `68 ee 3c 80` is a well-formed PPS: pic_parameter_set_id 0, referencing SPS 0.
        let pps = parser
            .pps_map
            .get(&0)
            .expect("PPS with id 0 should have been registered");
        assert_eq!(pps.seq_parameter_set_id, 0);
    }
}