h264_parser/
parser.rs

1use crate::au::{AccessUnit, AccessUnitBuilder};
2use crate::bytescan::{NalSpan, StartCodeScanner};
3use crate::nal::{Nal, NalUnitType};
4use crate::pps::Pps;
5use crate::slice::SliceHeader;
6use crate::sps::Sps;
7use crate::{Error, Result};
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Incremental parser that turns pushed bytes into [`AccessUnit`]s.
///
/// Bytes handed to `push` are buffered by the scanner; `next_access_unit`
/// then pulls NAL units out of it one at a time, records any SPS/PPS it
/// sees, and feeds every NAL to the access-unit builder.
pub struct AnnexBParser {
    /// Buffers pushed bytes and yields the span of each NAL unit.
    scanner: StartCodeScanner,
    /// Groups NAL units into access units.
    au_builder: AccessUnitBuilder,
    /// Parsed SPS NALs, keyed by `seq_parameter_set_id`.
    sps_map: HashMap<u8, Arc<Sps>>,
    /// Parsed PPS NALs, keyed by `pic_parameter_set_id`.
    pps_map: HashMap<u8, Arc<Pps>>,
    /// NOTE(review): never read in the visible code; only created empty and
    /// cleared by `reset`. Possibly vestigial — confirm before relying on it.
    pending_nals: Vec<(NalSpan, Vec<u8>)>,
}
18
19impl AnnexBParser {
20    pub fn new() -> Self {
21        Self {
22            scanner: StartCodeScanner::new(),
23            au_builder: AccessUnitBuilder::new(),
24            sps_map: HashMap::new(),
25            pps_map: HashMap::new(),
26            pending_nals: Vec::new(),
27        }
28    }
29
30    pub fn push(&mut self, data: &[u8]) {
31        self.scanner.push(data);
32    }
33
34    pub fn next_access_unit(&mut self) -> Result<Option<AccessUnit>> {
35        loop {
36            let nal_span_result = self.scanner.next_nal_unit()?;
37            // eprintln!("Scanner returned: {:?}", nal_span_result.as_ref().map(|s| (s.start_pos, s.data_end)));
38            if let Some(nal_span) = nal_span_result {
39                let nal_data = self.scanner.get_nal_data(&nal_span).to_vec();
40                
41                let nal = Nal::parse(nal_span.start_code_len, &nal_data)?;
42                
43                match nal.nal_type {
44                    NalUnitType::Sps => {
45                        let rbsp = nal.to_rbsp();
46                        let sps = Sps::parse(&rbsp)?;
47                        let sps_id = sps.seq_parameter_set_id;
48                        self.sps_map.insert(sps_id, Arc::new(sps));
49                    }
50                    NalUnitType::Pps => {
51                        let rbsp = nal.to_rbsp();
52                        let pps = Pps::parse(&rbsp)?;
53                        let pps_id = pps.pic_parameter_set_id;
54                        self.pps_map.insert(pps_id, Arc::new(pps));
55                    }
56                    _ => {}
57                }
58                
59                let mut slice_header = None;
60                let mut sps = None;
61                let mut pps = None;
62                
63                if nal.is_slice() {
64                    let rbsp = nal.to_rbsp();
65                    
66                    let temp_header = parse_slice_header_minimal(&rbsp)?;
67                    let pps_id = temp_header.0;
68                    
69                    if let Some(pps_ref) = self.pps_map.get(&pps_id) {
70                        pps = Some(pps_ref.clone());
71                        let sps_id = pps_ref.seq_parameter_set_id;
72                        
73                        if let Some(sps_ref) = self.sps_map.get(&sps_id) {
74                            sps = Some(sps_ref.clone());
75                            
76                            slice_header = Some(SliceHeader::parse(
77                                &rbsp,
78                                nal.nal_type,
79                                &sps_ref,
80                                &pps_ref,
81                            )?);
82                        } else {
83                            return Err(Error::MissingSps(sps_id));
84                        }
85                    } else {
86                        return Err(Error::MissingPps(pps_id));
87                    }
88                }
89                
90                // Since Nal now owns its data, we can just use the parsed nal directly
91                let owned_nal = nal.clone();
92                
93                if let Some(au) = self.au_builder.add_nal(owned_nal, slice_header, sps, pps) {
94                    return Ok(Some(au));
95                }
96            } else {
97                // When scanner returns None, we need to flush any pending AU
98                // from the builder before returning None
99                if let Some(au) = self.au_builder.flush_pending() {
100                    return Ok(Some(au));
101                }
102                return Ok(None);
103            }
104        }
105    }
106
107    pub fn drain(mut self) -> impl Iterator<Item = Result<AccessUnit>> {
108        let mut results = Vec::new();
109        
110        while let Ok(Some(au)) = self.next_access_unit() {
111            results.push(Ok(au));
112        }
113        
114        if let Some(au) = self.au_builder.flush() {
115            results.push(Ok(au));
116        }
117        
118        results.into_iter()
119    }
120
121    pub fn reset(&mut self) {
122        self.scanner.reset();
123        self.au_builder = AccessUnitBuilder::new();
124        self.sps_map.clear();
125        self.pps_map.clear();
126        self.pending_nals.clear();
127    }
128}
129
130impl Default for AnnexBParser {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136fn parse_slice_header_minimal(rbsp: &[u8]) -> Result<(u8,)> {
137    use crate::bitreader::BitReader;
138    use crate::eg::read_ue;
139    
140    let mut reader = BitReader::new(rbsp);
141    
142    let _first_mb_in_slice = read_ue(&mut reader)?;
143    let _slice_type = read_ue(&mut reader)?;
144    let pic_parameter_set_id = read_ue(&mut reader)?;
145    
146    if pic_parameter_set_id > 255 {
147        return Err(Error::SliceParseError("Invalid PPS ID".into()));
148    }
149    
150    Ok((pic_parameter_set_id as u8,))
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// A new parser starts with no stored parameter sets.
    #[test]
    fn test_parser_creation() {
        let parser = AnnexBParser::new();
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());
    }

    /// Pushing bytes only buffers them; parameter sets are registered by
    /// `next_access_unit`, never by `push`.
    #[test]
    fn test_parser_with_simple_stream() {
        let mut parser = AnnexBParser::new();

        let sps_data = [
            0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f,
            0xac, 0x34, 0xc8, 0x14, 0x00, 0x00, 0x03, 0x00,
            0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, 0x60,
            0xc6, 0x58,
        ];
        parser.push(&sps_data);

        let pps_data = [0x00, 0x00, 0x00, 0x01, 0x68, 0xee, 0x3c, 0x80];
        parser.push(&pps_data);

        // The previous assertion ended in `|| true` and could never fail.
        // These checks hold because `push` only forwards to the scanner.
        assert!(parser.sps_map.is_empty());
        assert!(parser.pps_map.is_empty());

        // NOTE(review): the fixture bytes have not been validated against
        // `Sps::parse`/`Pps::parse`, so this test deliberately stops short of
        // asserting on `next_access_unit`. Add that once the fixture is
        // confirmed to decode.
    }
}