h264_reader/nal/
pps.rs

1use super::sps::{self, ScalingList};
2use crate::nal::sps::SeqParameterSet;
3use crate::nal::sps::{SeqParamSetId, SeqParamSetIdError};
4use crate::rbsp::BitRead;
5use crate::{rbsp, Context};
6
/// Errors produced while parsing a picture parameter set.
#[derive(Debug)]
pub enum PpsError {
    /// The underlying RBSP bit reader returned an error.
    RbspReaderError(rbsp::BitReaderError),
    /// `slice_group_map_type` was not in `0..=6`.
    InvalidSliceGroupMapType(u32),
    /// `num_slice_groups_minus1` exceeded 7.
    InvalidNumSliceGroupsMinus1(u32),
    /// A `num_ref_idx_l*_default_active_minus1` field exceeded 31; holds the syntax element
    /// name and the rejected value.
    InvalidNumRefIdx(&'static str, u32),
    /// `SliceGroupChangeType::from_id` was given an id other than 3, 4 or 5.
    InvalidSliceGroupChangeType(u32),
    /// The PPS refers to an SPS id that is not present in the `Context`.
    UnknownSeqParamSetId(SeqParamSetId),
    /// `pic_parameter_set_id` was rejected by `PicParamSetId::from_u32`.
    BadPicParamSetId(PicParamSetIdError),
    /// `seq_parameter_set_id` was rejected by `SeqParamSetId::from_u32`.
    BadSeqParamSetId(SeqParamSetIdError),
    /// A scaling list in the PPS scaling matrix failed to parse.
    ScalingMatrix(sps::ScalingMatrixError),
    /// `second_chroma_qp_index_offset` was outside `-12..=12`.
    InvalidSecondChromaQpIndexOffset(i32),
    /// `pic_init_qp_minus26` was outside `-(26 + 6 * bit_depth_luma_minus8)..=25`.
    InvalidPicInitQpMinus26(i32),
    /// `pic_init_qs_minus26` was outside `-26..=25`.
    InvalidPicInitQsMinus26(i32),
    /// `chroma_qp_index_offset` was outside `-12..=12`.
    InvalidChromaQpIndexOffset(i32),
    /// A `run_length_minus1` value was too large for the picture's map unit count.
    InvalidRunLengthMinus1(u32),
    /// A slice rectangle's `top_left` was inconsistent with its `bottom_right` (numerically
    /// after it, or in a column to its right).
    InvalidTopLeft(u32),
    /// A slice rectangle's `bottom_right` lay outside the picture's map units.
    InvalidBottomRight(u32),
    /// `slice_group_change_rate_minus1` was too large for the picture's map unit count.
    InvalidSliceGroupChangeRateMinus1(u32),
}
27
28impl From<rbsp::BitReaderError> for PpsError {
29    fn from(e: rbsp::BitReaderError) -> Self {
30        PpsError::RbspReaderError(e)
31    }
32}
33
/// Evolving slice group pattern selected by `slice_group_map_type` values 3, 4 and 5.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable; callers can test it with
/// `==` instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SliceGroupChangeType {
    /// `slice_group_map_type == 3`: box-out.
    BoxOut,
    /// `slice_group_map_type == 4`: raster scan.
    RasterScan,
    /// `slice_group_map_type == 5`: wipe.
    WipeOut,
}
40impl SliceGroupChangeType {
41    fn from_id(id: u32) -> Result<SliceGroupChangeType, PpsError> {
42        match id {
43            3 => Ok(SliceGroupChangeType::BoxOut),
44            4 => Ok(SliceGroupChangeType::RasterScan),
45            5 => Ok(SliceGroupChangeType::WipeOut),
46            _ => Err(PpsError::InvalidSliceGroupChangeType(id)),
47        }
48    }
49}
50
/// One foreground rectangle of a `slice_group_map_type == 2` slice group map.
///
/// Both corners are slice group map unit indices in raster-scan order. The fields are public
/// so that users of `SliceGroup::ForegroundAndLeftover` can actually read the rectangle.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SliceRect {
    /// Index of the rectangle's top-left map unit.
    pub top_left: u32,
    /// Index of the rectangle's bottom-right map unit.
    pub bottom_right: u32,
}
56impl SliceRect {
57    fn read<R: BitRead>(r: &mut R, sps: &SeqParameterSet) -> Result<SliceRect, PpsError> {
58        let rect = SliceRect {
59            top_left: r.read_ue("top_left")?,
60            bottom_right: r.read_ue("bottom_right")?,
61        };
62        if rect.top_left > rect.bottom_right {
63            return Err(PpsError::InvalidTopLeft(rect.top_left));
64        }
65        if rect.bottom_right > sps.pic_size_in_map_units() {
66            return Err(PpsError::InvalidBottomRight(rect.bottom_right));
67        }
68        if rect.top_left % sps.pic_width_in_mbs() > rect.bottom_right % sps.pic_width_in_mbs() {
69            return Err(PpsError::InvalidTopLeft(rect.top_left));
70        }
71        Ok(rect)
72    }
73}
74
/// Slice group map description, present when a PPS has more than one slice group.
///
/// Each variant corresponds to one or more values of `slice_group_map_type`.
#[derive(Debug, Clone)]
pub enum SliceGroup {
    /// `slice_group_map_type == 0`.
    Interleaved {
        /// One `run_length_minus1` per slice group.
        run_length_minus1: Vec<u32>,
    },
    /// `slice_group_map_type == 1`; no further syntax elements are coded for this type.
    Dispersed {
        num_slice_groups_minus1: u32,
    },
    /// `slice_group_map_type == 2`.
    ForegroundAndLeftover {
        /// Rectangles as coded in the bitstream, each given by `top_left` / `bottom_right`
        /// map unit indices.
        rectangles: Vec<SliceRect>,
    },
    /// `slice_group_map_type` 3, 4 or 5, distinguished by `change_type`.
    Changing {
        change_type: SliceGroupChangeType,
        num_slice_groups_minus1: u32,
        slice_group_change_direction_flag: bool,
        slice_group_change_rate_minus1: u32,
    },
    /// `slice_group_map_type == 6`.
    ExplicitAssignment {
        num_slice_groups_minus1: u32,
        /// One `slice_group_id` per map unit, as coded in the bitstream.
        slice_group_id: Vec<u32>,
    },
}
97impl SliceGroup {
98    fn read<R: BitRead>(
99        r: &mut R,
100        num_slice_groups_minus1: u32,
101        sps: &SeqParameterSet,
102    ) -> Result<SliceGroup, PpsError> {
103        let slice_group_map_type = r.read_ue("slice_group_map_type")?;
104        match slice_group_map_type {
105            0 => Ok(SliceGroup::Interleaved {
106                run_length_minus1: Self::read_run_lengths(r, num_slice_groups_minus1, sps)?,
107            }),
108            1 => Ok(SliceGroup::Dispersed {
109                num_slice_groups_minus1,
110            }),
111            2 => Ok(SliceGroup::ForegroundAndLeftover {
112                rectangles: Self::read_rectangles(r, num_slice_groups_minus1, sps)?,
113            }),
114            3 | 4 | 5 => {
115                let slice_group_change_direction_flag =
116                    r.read_bool("slice_group_change_direction_flag")?;
117                let slice_group_change_rate_minus1 = r.read_ue("slice_group_change_rate_minus1")?;
118                if slice_group_change_rate_minus1 > sps.pic_size_in_map_units() - 1 {
119                    return Err(PpsError::InvalidSliceGroupChangeRateMinus1(
120                        slice_group_change_rate_minus1,
121                    ));
122                }
123                Ok(SliceGroup::Changing {
124                    change_type: SliceGroupChangeType::from_id(slice_group_map_type)?,
125                    num_slice_groups_minus1,
126                    slice_group_change_direction_flag,
127                    slice_group_change_rate_minus1,
128                })
129            }
130            6 => Ok(SliceGroup::ExplicitAssignment {
131                num_slice_groups_minus1,
132                slice_group_id: Self::read_group_ids(r, num_slice_groups_minus1)?,
133            }),
134            _ => Err(PpsError::InvalidSliceGroupMapType(slice_group_map_type)),
135        }
136    }
137
138    fn read_run_lengths<R: BitRead>(
139        r: &mut R,
140        num_slice_groups_minus1: u32,
141        sps: &SeqParameterSet,
142    ) -> Result<Vec<u32>, PpsError> {
143        let mut run_lengths = Vec::with_capacity(num_slice_groups_minus1 as usize + 1);
144        for _ in 0..num_slice_groups_minus1 + 1 {
145            let run_length_minus1 = r.read_ue("run_length_minus1")?;
146            if run_length_minus1 > sps.pic_size_in_map_units() - 1 {
147                return Err(PpsError::InvalidRunLengthMinus1(run_length_minus1));
148            }
149            run_lengths.push(run_length_minus1);
150        }
151        Ok(run_lengths)
152    }
153
154    fn read_rectangles<R: BitRead>(
155        r: &mut R,
156        num_slice_groups_minus1: u32,
157        seq_parameter_set: &SeqParameterSet,
158    ) -> Result<Vec<SliceRect>, PpsError> {
159        let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1);
160        for _ in 0..num_slice_groups_minus1 + 1 {
161            run_length_minus1.push(SliceRect::read(r, seq_parameter_set)?);
162        }
163        Ok(run_length_minus1)
164    }
165
166    fn read_group_ids<R: BitRead>(
167        r: &mut R,
168        num_slice_groups_minus1: u32,
169    ) -> Result<Vec<u32>, PpsError> {
170        let pic_size_in_map_units_minus1 = r.read_ue("pic_size_in_map_units_minus1")?;
171        // TODO: avoid any panics due to failed conversions
172        let size = (1f64 + f64::from(num_slice_groups_minus1)).log2().ceil() as u32;
173        let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1);
174        for _ in 0..pic_size_in_map_units_minus1 + 1 {
175            run_length_minus1.push(r.read(size, "slice_group_id")?);
176        }
177        Ok(run_length_minus1)
178    }
179}
180
/// Scaling lists carried in a PPS. `PicScalingMatrix::read` produces one only when
/// `pic_scaling_matrix_present_flag` is set.
#[derive(Debug, Clone)]
pub struct PicScalingMatrix {
    /// always has length 6
    pub scaling_list4x4: Vec<ScalingList<16>>,
    /// `Some` when `transform_8x8_mode_flag` is `true`, `None` otherwise; when `Some`, holds 6
    /// lists if the SPS chroma format is 4:4:4 and 2 lists otherwise
    pub scaling_list8x8: Option<Vec<ScalingList<64>>>,
}
188impl PicScalingMatrix {
189    fn read<R: BitRead>(
190        r: &mut R,
191        sps: &sps::SeqParameterSet,
192        transform_8x8_mode_flag: bool,
193    ) -> Result<Option<PicScalingMatrix>, PpsError> {
194        let pic_scaling_matrix_present_flag = r.read_bool("pic_scaling_matrix_present_flag")?;
195
196        if !pic_scaling_matrix_present_flag {
197            return Ok(None);
198        }
199
200        let count = if transform_8x8_mode_flag {
201            if sps.chroma_info.chroma_format == sps::ChromaFormat::YUV444 {
202                6
203            } else {
204                2
205            }
206        } else {
207            0
208        };
209
210        let mut scaling_list4x4 = Vec::with_capacity(6);
211        let mut scaling_list8x8 = Vec::with_capacity(count);
212
213        for i in 0..6 + count {
214            let seq_scaling_list_present_flag = r.read_bool("seq_scaling_list_present_flag")?;
215            if i < 6 {
216                scaling_list4x4.push(
217                    sps::ScalingList::<16>::read(r, seq_scaling_list_present_flag)
218                        .map_err(PpsError::ScalingMatrix)?,
219                );
220            } else {
221                scaling_list8x8.push(
222                    sps::ScalingList::<64>::read(r, seq_scaling_list_present_flag)
223                        .map_err(PpsError::ScalingMatrix)?,
224                );
225            }
226        }
227
228        let scaling_list8x8 = if scaling_list8x8.is_empty() {
229            None
230        } else {
231            Some(scaling_list8x8)
232        };
233
234        Ok(Some(PicScalingMatrix {
235            scaling_list4x4,
236            scaling_list8x8,
237        }))
238    }
239}
240
/// Optional trailing PPS fields. They are read only when more RBSP data follows
/// `redundant_pic_cnt_present_flag`.
#[derive(Debug, Clone)]
pub struct PicParameterSetExtra {
    pub transform_8x8_mode_flag: bool,
    /// `None` when `pic_scaling_matrix_present_flag` is 0.
    pub pic_scaling_matrix: Option<PicScalingMatrix>,
    /// The parser rejects values outside `-12..=12`.
    pub second_chroma_qp_index_offset: i32,
}
247impl PicParameterSetExtra {
248    fn read<R: BitRead>(
249        r: &mut R,
250        sps: &sps::SeqParameterSet,
251    ) -> Result<Option<PicParameterSetExtra>, PpsError> {
252        Ok(if r.has_more_rbsp_data("transform_8x8_mode_flag")? {
253            let transform_8x8_mode_flag = r.read_bool("transform_8x8_mode_flag")?;
254            let extra = PicParameterSetExtra {
255                transform_8x8_mode_flag,
256                pic_scaling_matrix: PicScalingMatrix::read(r, sps, transform_8x8_mode_flag)?,
257                second_chroma_qp_index_offset: r.read_se("second_chroma_qp_index_offset")?,
258            };
259            if extra.second_chroma_qp_index_offset < -12 || extra.second_chroma_qp_index_offset > 12
260            {
261                return Err(PpsError::InvalidSecondChromaQpIndexOffset(
262                    extra.second_chroma_qp_index_offset,
263                ));
264            }
265            Some(extra)
266        } else {
267            None
268        })
269    }
270}
271
/// Error returned by [`PicParamSetId::from_u32`] for an out-of-range id.
#[derive(Debug, PartialEq)]
pub enum PicParamSetIdError {
    /// The id exceeded 255; carries the rejected value.
    IdTooLarge(u32),
}

/// A `pic_parameter_set_id` known to fit in `0..=255`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PicParamSetId(u8);
impl PicParamSetId {
    /// Wraps `id` if it fits in a `u8`; otherwise reports it as
    /// [`PicParamSetIdError::IdTooLarge`].
    pub fn from_u32(id: u32) -> Result<PicParamSetId, PicParamSetIdError> {
        if id <= u32::from(u8::MAX) {
            // Lossless: the bound check above guarantees `id` fits in a u8.
            Ok(PicParamSetId(id as u8))
        } else {
            Err(PicParamSetIdError::IdTooLarge(id))
        }
    }

    /// The raw id value.
    pub fn id(self) -> u8 {
        let PicParamSetId(raw) = self;
        raw
    }
}
291
/// A parsed picture parameter set (H.264 section 7.3.2.2).
///
/// Fields are named after the syntax elements they hold. Range notes below describe the
/// checks applied by `PicParameterSet::from_bits`.
#[derive(Clone, Debug)]
pub struct PicParameterSet {
    pub pic_parameter_set_id: PicParamSetId,
    /// The SPS this PPS refers to; it must be present in the `Context` when parsing.
    pub seq_parameter_set_id: SeqParamSetId,
    pub entropy_coding_mode_flag: bool,
    pub bottom_field_pic_order_in_frame_present_flag: bool,
    /// `None` when `num_slice_groups_minus1` is 0, i.e. a single slice group.
    pub slice_groups: Option<SliceGroup>,
    /// `from_bits` rejects values above 31.
    pub num_ref_idx_l0_default_active_minus1: u32,
    /// `from_bits` rejects values above 31.
    pub num_ref_idx_l1_default_active_minus1: u32,
    pub weighted_pred_flag: bool,
    /// Read as a 2-bit field.
    pub weighted_bipred_idc: u8,
    /// `from_bits` requires `-(26 + 6 * bit_depth_luma_minus8)..=25`.
    pub pic_init_qp_minus26: i32,
    /// `from_bits` requires `-26..=25`.
    pub pic_init_qs_minus26: i32,
    /// `from_bits` requires `-12..=12`.
    pub chroma_qp_index_offset: i32,
    pub deblocking_filter_control_present_flag: bool,
    pub constrained_intra_pred_flag: bool,
    pub redundant_pic_cnt_present_flag: bool,
    /// Optional trailing fields; `None` when the RBSP ends after
    /// `redundant_pic_cnt_present_flag`.
    pub extension: Option<PicParameterSetExtra>,
}
311impl PicParameterSet {
312    pub fn from_bits<R: BitRead>(ctx: &Context, mut r: R) -> Result<PicParameterSet, PpsError> {
313        let pic_parameter_set_id = PicParamSetId::from_u32(r.read_ue("pic_parameter_set_id")?)
314            .map_err(PpsError::BadPicParamSetId)?;
315        let seq_parameter_set_id = SeqParamSetId::from_u32(r.read_ue("seq_parameter_set_id")?)
316            .map_err(PpsError::BadSeqParamSetId)?;
317        let seq_parameter_set = ctx
318            .sps_by_id(seq_parameter_set_id)
319            .ok_or_else(|| PpsError::UnknownSeqParamSetId(seq_parameter_set_id))?;
320        let pps = PicParameterSet {
321            pic_parameter_set_id,
322            seq_parameter_set_id,
323            entropy_coding_mode_flag: r.read_bool("entropy_coding_mode_flag")?,
324            bottom_field_pic_order_in_frame_present_flag: r
325                .read_bool("bottom_field_pic_order_in_frame_present_flag")?,
326            slice_groups: Self::read_slice_groups(&mut r, seq_parameter_set)?,
327            num_ref_idx_l0_default_active_minus1: read_num_ref_idx(
328                &mut r,
329                "num_ref_idx_l0_default_active_minus1",
330            )?,
331            num_ref_idx_l1_default_active_minus1: read_num_ref_idx(
332                &mut r,
333                "num_ref_idx_l1_default_active_minus1",
334            )?,
335            weighted_pred_flag: r.read_bool("weighted_pred_flag")?,
336            weighted_bipred_idc: r.read(2, "weighted_bipred_idc")?,
337            pic_init_qp_minus26: r.read_se("pic_init_qp_minus26")?,
338            pic_init_qs_minus26: r.read_se("pic_init_qs_minus26")?,
339            chroma_qp_index_offset: r.read_se("chroma_qp_index_offset")?,
340            deblocking_filter_control_present_flag: r
341                .read_bool("deblocking_filter_control_present_flag")?,
342            constrained_intra_pred_flag: r.read_bool("constrained_intra_pred_flag")?,
343            redundant_pic_cnt_present_flag: r.read_bool("redundant_pic_cnt_present_flag")?,
344            extension: PicParameterSetExtra::read(&mut r, seq_parameter_set)?,
345        };
346        let qp_bd_offset_y = 6 * seq_parameter_set.chroma_info.bit_depth_luma_minus8;
347        if pps.pic_init_qp_minus26 < -(26 + i32::from(qp_bd_offset_y))
348            || pps.pic_init_qp_minus26 > 25
349        {
350            return Err(PpsError::InvalidPicInitQpMinus26(pps.pic_init_qp_minus26));
351        }
352        if pps.pic_init_qs_minus26 < -26 || pps.pic_init_qs_minus26 > 25 {
353            return Err(PpsError::InvalidPicInitQsMinus26(pps.pic_init_qs_minus26));
354        }
355        if pps.chroma_qp_index_offset < -12 || pps.chroma_qp_index_offset > 12 {
356            return Err(PpsError::InvalidChromaQpIndexOffset(
357                pps.chroma_qp_index_offset,
358            ));
359        }
360        r.finish_rbsp()?;
361        Ok(pps)
362    }
363
364    fn read_slice_groups<R: BitRead>(
365        r: &mut R,
366        sps: &SeqParameterSet,
367    ) -> Result<Option<SliceGroup>, PpsError> {
368        let num_slice_groups_minus1 = r.read_ue("num_slice_groups_minus1")?;
369        if num_slice_groups_minus1 > 7 {
370            // 7 is the maximum allowed in any profile; some profiles restrict it to 0.
371            return Err(PpsError::InvalidNumSliceGroupsMinus1(
372                num_slice_groups_minus1,
373            ));
374        }
375        Ok(if num_slice_groups_minus1 > 0 {
376            Some(SliceGroup::read(r, num_slice_groups_minus1, sps)?)
377        } else {
378            None
379        })
380    }
381}
382
383fn read_num_ref_idx<R: BitRead>(r: &mut R, name: &'static str) -> Result<u32, PpsError> {
384    let val = r.read_ue(name)?;
385    if val > 31 {
386        return Err(PpsError::InvalidNumRefIdx(name, val));
387    }
388    Ok(val)
389}
390
// Unit tests for PPS parsing. Each test first parses an SPS and stores it in a `Context`,
// because `PicParameterSet::from_bits` looks up the referenced SPS there.
#[cfg(test)]
mod test {
    use super::*;
    use crate::nal::sps::SeqParameterSet;
    use hex_literal::*;

    // Parses a minimal SPS/PPS pair and checks that both ids decode as 0.
    #[test]
    fn test_it() {
        let data = hex!(
            "64 00 0A AC 72 84 44 26 84 00 00
            00 04 00 00 00 CA 3C 48 96 11 80"
        );
        let sps = super::sps::SeqParameterSet::from_bits(rbsp::BitReader::new(&data[..]))
            .expect("unexpected test data");
        let mut ctx = Context::default();
        ctx.put_seq_param_set(sps);
        let data = hex!("E8 43 8F 13 21 30");
        match PicParameterSet::from_bits(&ctx, rbsp::BitReader::new(&data[..])) {
            Err(e) => panic!("failed: {:?}", e),
            Ok(pps) => {
                println!("pps: {:#?}", pps);
                assert_eq!(pps.pic_parameter_set_id.id(), 0);
                assert_eq!(pps.seq_parameter_set_id.id(), 0);
            }
        }
    }

    // Regression test: a PPS with transform_8x8_mode_flag set and a pic scaling matrix present
    // must parse, yielding six 4x4 lists and two 8x8 lists.
    #[test]
    fn test_transform_8x8_mode_with_scaling_matrix() {
        let sps = hex!(
            "64 00 29 ac 1b 1a 50 1e 00 89 f9 70 11 00 00 03 e9 00 00 bb 80 e2 60 00 04 c3 7a 00 00
             72 70 e8 c4 b8 c4 c0 00 09 86 f4 00 00 e4 e1 d1 89 70 f8 e1 85 2c"
        );
        let pps = hex!(
            "ea 8d ce 50 94 8d 18 b2 5a 55 28 4a 46 8c 59 2d 2a 50 c9 1a 31 64 b4 aa 85 48 d2 75 d5
             25 1d 23 49 d2 7a 23 74 93 7a 49 be 95 da ad d5 3d 7a 6b 54 22 9a 4e 93 d6 ea 9f a4 ee
             aa fd 6e bf f5 f7"
        );
        let sps = super::sps::SeqParameterSet::from_bits(rbsp::BitReader::new(&sps[..]))
            .expect("unexpected test data");
        let mut ctx = Context::default();
        ctx.put_seq_param_set(sps);

        let pps = PicParameterSet::from_bits(&ctx, rbsp::BitReader::new(&pps[..]))
            .expect("we mis-parsed pic_scaling_matrix when transform_8x8_mode_flag is active");

        // if transform_8x8_mode_flag were false or pic_scaling_matrix were None then we wouldn't
        // be recreating the required conditions for the test
        assert!(matches!(
            pps.extension,
            Some(PicParameterSetExtra {
                transform_8x8_mode_flag: true,
                pic_scaling_matrix: Some(PicScalingMatrix { scaling_list4x4, scaling_list8x8: Some(scaling_list8x8) }),
                ..
            }) if scaling_list4x4.len() == 6 && scaling_list8x8.len() == 2
        ));
    }

    // Earlier versions of h264-reader incorrectly limited pic_parameter_set_id to at most 32,
    // while the spec allows up to 255.  Test that a value over 32 is accepted.
    #[test]
    fn pps_id_greater32() {
        // test SPS/PPS values courtesy of @astraw
        let sps = hex!("42c01643235010020b3cf00f08846a");
        let pps = hex!("0448e3c8");
        let sps = sps::SeqParameterSet::from_bits(rbsp::BitReader::new(&sps[..])).unwrap();
        let mut ctx = Context::default();
        ctx.put_seq_param_set(sps);

        let pps = PicParameterSet::from_bits(&ctx, rbsp::BitReader::new(&pps[..])).unwrap();

        assert_eq!(pps.pic_parameter_set_id, PicParamSetId(33));
    }

    // An out-of-range pic_init_qs_minus26 must be reported as a parse error carrying the
    // decoded value, rather than being accepted.
    #[test]
    fn invalid_pic_init_qs_minus26() {
        let mut ctx = Context::default();
        let sps = SeqParameterSet::from_bits(rbsp::BitReader::new(
            &hex!("64 00 0b ac d9 42 4d f8 84")[..],
        ))
        .expect("sps");
        println!("{:#?}", sps);
        ctx.put_seq_param_set(sps);
        let pps = PicParameterSet::from_bits(
            &mut ctx,
            rbsp::BitReader::new(&hex!("eb e8 02 3b 2c 8b")[..]),
        );
        // pic_init_qs_minus26 should be in the range [-26, 25]
        assert!(matches!(pps, Err(PpsError::InvalidPicInitQsMinus26(-285))));
    }
}