use super::bitreader::BitReader;
use super::vlc_tables::{
match_dc_size, match_vlc, AcSymbol, AcTablePtr, DcTablePtr, AC_TABLE_B14, AC_TABLE_B15,
DC_SIZE_CHROMA, DC_SIZE_LUMA,
};
use super::zigzag::{place_in_raster, scan_table};
use super::Mpeg2Error;
use super::Mpeg2Result;
/// Running DC predictors, one per colour component.
///
/// Each field holds the most recently reconstructed DC value for its
/// component; a decoded differential is added to it to obtain the next one.
#[derive(Debug, Clone, Copy)]
pub struct DcPredictors {
    /// Predictor for luminance blocks.
    pub y: i32,
    /// Predictor for Cb chrominance blocks.
    pub cb: i32,
    /// Predictor for Cr chrominance blocks.
    pub cr: i32,
}

impl DcPredictors {
    /// Builds a predictor set with every component at its reset value.
    ///
    /// The reset value is `2^(7 + intra_dc_precision)`, i.e. 128, 256, 512 or
    /// 1024. Only the low two bits of `intra_dc_precision` are used; any higher
    /// bits are ignored.
    #[must_use]
    pub fn reset(intra_dc_precision: u8) -> Self {
        // 128 << p is the same quantity as 1 << (7 + p).
        let precision = intra_dc_precision & 0b11;
        let initial = 128i32 << precision;
        Self {
            y: initial,
            cb: initial,
            cr: initial,
        }
    }
}
/// Colour component a block belongs to.
///
/// `decode_dc` uses this to pick the DC size table (one for luma, one shared
/// by both chroma components) and the predictor to update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockComponent {
/// Luminance block; uses `DC_SIZE_LUMA` and the `y` predictor.
Luma,
/// Cb chrominance block; uses `DC_SIZE_CHROMA` and the `cb` predictor.
Cb,
/// Cr chrominance block; uses `DC_SIZE_CHROMA` and the `cr` predictor.
Cr,
}
/// Decodes the DC coefficient of one intra block and advances its predictor.
///
/// Reads a variable-length DC size code (luma or chroma table, chosen by
/// `component`), then `dc_size` further bits holding the differential. The
/// differential is added to the predictor for `component`, which is updated
/// in place, and the new predictor value is returned.
///
/// # Errors
///
/// Propagates any error from `match_dc_size` or from the bit reader
/// (`skip_bits`, `read_bits`).
pub fn decode_dc(
    reader: &mut BitReader<'_>,
    predictors: &mut DcPredictors,
    component: BlockComponent,
) -> Mpeg2Result<i32> {
    // Both chroma components share a single size table.
    let size_table: DcTablePtr = match component {
        BlockComponent::Cb | BlockComponent::Cr => DC_SIZE_CHROMA,
        BlockComponent::Luma => DC_SIZE_LUMA,
    };

    let window = reader.peek_bits_msb_aligned();
    let (dc_size, code_len) = match_dc_size(size_table, window)?;
    reader.skip_bits(code_len)?;

    let differential = if dc_size != 0 {
        let raw = reader.read_bits(dc_size)? as i32;
        // A set top bit means the value is non-negative and stored as-is;
        // otherwise it encodes a negative value offset by (2^dc_size - 1).
        let top_bit = 1i32 << (dc_size - 1);
        if (raw & top_bit) != 0 {
            raw
        } else {
            let all_ones = (1i32 << dc_size) - 1;
            raw - all_ones
        }
    } else {
        // Size 0 carries no differential bits: the predictor is reused as-is.
        0i32
    };

    let slot = match component {
        BlockComponent::Luma => &mut predictors.y,
        BlockComponent::Cb => &mut predictors.cb,
        BlockComponent::Cr => &mut predictors.cr,
    };
    *slot += differential;
    let reconstructed = *slot;
    Ok(reconstructed)
}
/// Decodes the AC coefficients of one intra block into `block`.
///
/// Symbols are matched against `AC_TABLE_B15` when `intra_vlc_format` is set
/// and `AC_TABLE_B14` otherwise. Each decoded coefficient is stored through
/// `place_in_raster` using the scan order returned by
/// `scan_table(alternate_scan)`. Scan position 0 (the DC coefficient) is never
/// written; positions that receive no coefficient keep whatever value `block`
/// already held.
///
/// Decoding runs until the End-of-block code has been read and consumed. The
/// block syntax always terminates the coefficient list with that code, even
/// when a coefficient occupies scan position 63, so stopping earlier would
/// leave the reader misaligned for whatever follows the block.
///
/// # Errors
///
/// Returns `Mpeg2Error::VlcDecode` when
/// * a run moves the scan position past the last coefficient (this includes
///   any coefficient that follows one already placed at position 63), or
/// * an escape-coded level is one of the forbidden values 0 or -2048.
///
/// Errors from `match_vlc` and from the bit reader are propagated unchanged.
pub fn decode_ac(
    reader: &mut BitReader<'_>,
    block: &mut [i32; 64],
    intra_vlc_format: bool,
    alternate_scan: bool,
) -> Mpeg2Result<()> {
    let ac_table: AcTablePtr = if intra_vlc_format {
        AC_TABLE_B15
    } else {
        AC_TABLE_B14
    };
    let scan = scan_table(alternate_scan);
    // Position 0 belongs to the DC coefficient, so AC decoding starts at 1.
    let mut scan_index: usize = 1;
    loop {
        // No early exit when scan_index reaches 64: the End-of-block code
        // still has to be consumed. Any other symbol at that point trips the
        // overflow checks below.
        let peek = reader.peek_bits_msb_aligned();
        match match_vlc(ac_table, peek)? {
            AcSymbol::EndOfBlock { bits } => {
                reader.skip_bits(bits)?;
                return Ok(());
            }
            AcSymbol::RunLevel { run, level, bits } => {
                reader.skip_bits(bits)?;
                // The sign bit directly follows the run/level code.
                let sign = reader.read_bit()?;
                scan_index += run as usize;
                if scan_index >= 64 {
                    return Err(Mpeg2Error::VlcDecode(format!(
                        "AC run overflowed block (scan_index {scan_index})"
                    )));
                }
                let value = if sign {
                    -i32::from(level)
                } else {
                    i32::from(level)
                };
                place_in_raster(block, scan, scan_index, value);
                scan_index += 1;
            }
            AcSymbol::Escape { bits } => {
                reader.skip_bits(bits)?;
                // Escape payload: 6-bit run, then a 12-bit two's-complement level.
                let run = reader.read_bits(6)? as usize;
                let level_raw = reader.read_bits(12)? as i32;
                if level_raw == 0x800 {
                    // Bit pattern 1000 0000 0000 (-2048) is forbidden, like 0.
                    return Err(Mpeg2Error::VlcDecode(
                        "escape level -2048 is forbidden".into(),
                    ));
                }
                let level = if level_raw & 0x800 != 0 {
                    level_raw - 0x1000
                } else {
                    level_raw
                };
                if level == 0 {
                    return Err(Mpeg2Error::VlcDecode("escape level 0 is forbidden".into()));
                }
                scan_index += run;
                if scan_index >= 64 {
                    return Err(Mpeg2Error::VlcDecode(format!(
                        "escape run overflowed block (scan_index {scan_index})"
                    )));
                }
                place_in_raster(block, scan, scan_index, level);
                scan_index += 1;
            }
        }
    }
}
/// Decodes one complete intra block: the DC coefficient followed by its AC
/// coefficients.
///
/// The DC value returned by `decode_dc` is stored at index 0 of an otherwise
/// zeroed 64-entry array, which `decode_ac` then fills in. `predictors` is
/// updated for `component` as a side effect of the DC decode.
///
/// # Errors
///
/// Propagates the first error returned by `decode_dc` or `decode_ac`.
pub fn decode_intra_block(
    reader: &mut BitReader<'_>,
    predictors: &mut DcPredictors,
    component: BlockComponent,
    intra_vlc_format: bool,
    alternate_scan: bool,
) -> Mpeg2Result<[i32; 64]> {
    let dc = decode_dc(reader, predictors, component)?;

    let mut coefficients = [0i32; 64];
    coefficients[0] = dc;

    decode_ac(reader, &mut coefficients, intra_vlc_format, alternate_scan)?;
    Ok(coefficients)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises `(value, width)` fields MSB-first into bytes.
    ///
    /// The low `width` bits of each `value` are emitted in order; the final
    /// byte is padded with zero bits on the right.
    fn pack(fields: &[(u32, u8)]) -> Vec<u8> {
        let mut bytes = Vec::new();
        let mut current = 0u8;
        let mut filled = 0u8;
        for &(value, width) in fields {
            for shift in (0..width).rev() {
                let bit = ((value >> shift) & 1) as u8;
                current = (current << 1) | bit;
                filled += 1;
                if filled == 8 {
                    bytes.push(current);
                    current = 0;
                    filled = 0;
                }
            }
        }
        if filled != 0 {
            bytes.push(current << (8 - filled));
        }
        bytes
    }

    /// Decodes one luma DC value from `fields`, starting from predictors
    /// reset with precision 0, and returns the value with the predictors.
    fn luma_dc(fields: &[(u32, u8)]) -> (i32, DcPredictors) {
        let bytes = pack(fields);
        let mut reader = BitReader::new(&bytes);
        let mut predictors = DcPredictors::reset(0);
        let dc = decode_dc(&mut reader, &mut predictors, BlockComponent::Luma).expect("dc");
        (dc, predictors)
    }

    /// Runs `decode_ac` over `fields` into a zeroed block using
    /// `AC_TABLE_B14` (`intra_vlc_format` off) and `alternate_scan` off.
    fn ac_block(fields: &[(u32, u8)]) -> [i32; 64] {
        let bytes = pack(fields);
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        decode_ac(&mut reader, &mut block, false, false).expect("ac");
        block
    }

    #[test]
    fn dc_predictor_reset_values() {
        let expected = [128, 256, 512, 1024];
        for (precision, want) in (0u8..4).zip(expected) {
            assert_eq!(DcPredictors::reset(precision).y, want);
        }
    }

    #[test]
    fn decode_dc_zero_size() {
        // `100` is expected to be the luma size code for dc_size 0.
        let (dc, predictors) = luma_dc(&[(0b100, 3)]);
        assert_eq!(dc, 128);
        assert_eq!(predictors.y, 128);
    }

    #[test]
    fn decode_dc_positive_diff() {
        // Size code `01`, then two differential bits `11` (+3).
        let (dc, _) = luma_dc(&[(0b01, 2), (0b11, 2)]);
        assert_eq!(dc, 131);
    }

    #[test]
    fn decode_dc_negative_diff() {
        // Size code `01`, then two differential bits `00` (-3).
        let (dc, _) = luma_dc(&[(0b01, 2), (0b00, 2)]);
        assert_eq!(dc, 125);
    }

    #[test]
    fn decode_dc_chroma_uses_b13() {
        // `00` is expected to be the chroma size code for dc_size 0.
        let bytes = pack(&[(0b00, 2)]);
        let mut reader = BitReader::new(&bytes);
        let mut predictors = DcPredictors::reset(0);
        let dc = decode_dc(&mut reader, &mut predictors, BlockComponent::Cb).expect("dc");
        assert_eq!(dc, 128);
    }

    #[test]
    fn decode_ac_immediate_eob() {
        // A lone End-of-block code (`10`) must leave the whole block untouched,
        // including a DC value that is already in place.
        let bytes = pack(&[(0b10, 2)]);
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        block[0] = 100;
        decode_ac(&mut reader, &mut block, false, false).expect("ac");
        assert_eq!(block[0], 100);
        assert!(block[1..].iter().all(|&v| v == 0));
    }

    #[test]
    fn decode_ac_one_run_level_then_eob() {
        // `11` + sign 0, then End-of-block.
        let block = ac_block(&[(0b11, 2), (0, 1), (0b10, 2)]);
        assert_eq!(block[1], 1);
        assert!(block[2..].iter().all(|&v| v == 0));
    }

    #[test]
    fn decode_ac_negative_level() {
        // Same code as above with the sign bit set.
        let block = ac_block(&[(0b11, 2), (1, 1), (0b10, 2)]);
        assert_eq!(block[1], -1);
    }

    #[test]
    fn decode_ac_run_skips_positions() {
        // `011` + sign 0: one position is skipped before the coefficient is
        // placed, so it ends up at raster index 8 rather than 1.
        let block = ac_block(&[(0b011, 3), (0, 1), (0b10, 2)]);
        assert_eq!(block[8], 1);
        assert_eq!(block[1], 0);
    }

    #[test]
    fn decode_ac_escape_sequence() {
        // Escape code, 6-bit run 0, 12-bit level 100, End-of-block.
        let block = ac_block(&[(0b000001, 6), (0, 6), (100, 12), (0b10, 2)]);
        assert_eq!(block[1], 100);
    }

    #[test]
    fn decode_ac_escape_negative_level() {
        // -50 as a 12-bit two's-complement field.
        let encoded_level = 0x1000u32 - 50;
        let block = ac_block(&[(0b000001, 6), (0, 6), (encoded_level, 12), (0b10, 2)]);
        assert_eq!(block[1], -50);
    }

    #[test]
    fn decode_intra_block_dc_only() {
        // Luma dc_size 0 followed directly by End-of-block.
        let bytes = pack(&[(0b100, 3), (0b10, 2)]);
        let mut reader = BitReader::new(&bytes);
        let mut predictors = DcPredictors::reset(0);
        let block =
            decode_intra_block(&mut reader, &mut predictors, BlockComponent::Luma, false, false)
                .expect("block");
        assert_eq!(block[0], 128);
        assert!(block[1..].iter().all(|&v| v == 0));
    }

    #[test]
    fn decode_ac_uses_b15_when_intra_vlc_set() {
        // With `intra_vlc_format` on, `0110` is expected to decode as
        // End-of-block.
        let bytes = pack(&[(0b0110, 4)]);
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        decode_ac(&mut reader, &mut block, true, false).expect("ac");
        assert!(block[1..].iter().all(|&v| v == 0));
    }
}