use super::bitreader::BitReader;
use super::vlc_tables::{build_ac_table, build_dc_table_8bit, unpack_ac_value, VlcTable};
use super::DecodeError;
/// Table-driven DC decoder stub: it is not implemented and always returns an error.
///
/// It reads (consumes, not peeks) up to `dc_table.index_bits` bits from `reader` and looks
/// them up in `dc_table`. It then reports that callers should use [`decode_dc_sequential`]
/// instead. `prev_dc` is currently unused.
///
/// # Errors
///
/// - [`DecodeError::Entropy`] ("DC: empty bitstream") when `reader` has no bits left.
/// - [`DecodeError::Entropy`] when the read bits match no entry of `dc_table`.
/// - Any error returned by `reader.read_bits_u32`.
/// - Otherwise always [`DecodeError::Entropy`], directing callers to [`decode_dc_sequential`].
pub fn decode_dc_coefficient(
    reader: &mut BitReader<'_>,
    dc_table: &VlcTable,
    prev_dc: i16,
) -> Result<i16, DecodeError> {
    let peek_bits = dc_table.index_bits as u8;
    let avail = reader.remaining_bits();
    if avail == 0 {
        return Err(DecodeError::Entropy("DC: empty bitstream".into()));
    }
    // Clamp in `usize` before narrowing. `avail as u8` would wrap when 256+ bits remain
    // (e.g. 256 -> 0), yielding a zero-width read and a 32-bit shift overflow below.
    let actual_peek = avail.min(usize::from(peek_bits)) as u8;
    let raw = reader.read_bits_u32(actual_peek)?;
    // Left-align the bits for the table lookup. `checked_shl` turns the (zero-width) shift
    // by 32 into 0 instead of panicking.
    let shift = 32u32.saturating_sub(u32::from(actual_peek));
    let shifted = raw.checked_shl(shift).unwrap_or(0);
    let (size_val, consumed) = dc_table
        .lookup(shifted)
        .ok_or_else(|| DecodeError::Entropy(format!("DC VLC not found, bits={shifted:032b}")))?;
    // The lookup result is intentionally discarded; real decoding lives in
    // `decode_dc_sequential`. Binding everything here keeps the unused-variable lint quiet.
    let _ = (size_val, consumed, shifted, prev_dc);
    Err(DecodeError::Entropy(
        "internal: use decode_dc_sequential instead".into(),
    ))
}
/// Decodes one DC coefficient: a variable-length size category followed by a differential value.
///
/// `dc_table_entries` holds `(code, len, size_category)` triples, with `code` right-aligned in
/// the low `len` bits (as produced by [`dc_table_entries_8bit`]). Bits are read MSB-first, one at
/// a time. After each bit the accumulated prefix is compared against every entry of that length.
///
/// A size category of 0 means a zero difference, and `prev_dc` is returned unchanged. Otherwise
/// `size` further bits are read:
/// - if their top bit is 1, they are the (positive) difference;
/// - if it is 0, the difference is `bits - (2^size - 1)`, i.e. the negated complement.
///
/// Returns `prev_dc + diff` using wrapping arithmetic.
///
/// # Errors
///
/// - [`DecodeError::Entropy`] if no entry matches within the longest code length in the table
///   (10 bits when the table is empty).
/// - [`DecodeError::Entropy`] if the size category exceeds 16, or the decoded difference does not
///   fit in an `i16`.
/// - Any error from the reader (e.g. running out of bits).
pub fn decode_dc_sequential(
    reader: &mut BitReader<'_>,
    dc_table_entries: &[(u32, u8, i16)],
    prev_dc: i16,
) -> Result<i16, DecodeError> {
    let max_len: u8 = dc_table_entries
        .iter()
        .map(|&(_, len, _)| len)
        .max()
        .unwrap_or(10);
    let mut shift_reg: u32 = 0;
    let mut bits_read: u8 = 0;
    let size_cat: u8 = loop {
        // Give up once every code length in the table has been tried. A `>` check here would
        // consume one extra bit past the longest code (and overflow `bits_read` at 255).
        if bits_read >= max_len {
            return Err(DecodeError::Entropy(format!(
                "DC VLC not found after {bits_read} bits"
            )));
        }
        let bit = reader.read_bit()? as u32;
        shift_reg = (shift_reg << 1) | bit;
        bits_read += 1;
        let hit = dc_table_entries
            .iter()
            .find(|&&(code, len, _)| len == bits_read && code == shift_reg);
        if let Some(&(_, _, value)) = hit {
            break value as u8;
        }
    };
    if size_cat == 0 {
        return Ok(prev_dc);
    }
    // Categories above 16 cannot produce an `i16` difference, and 32+ would overflow the
    // mask shift below; reject them instead of silently truncating.
    if size_cat > 16 {
        return Err(DecodeError::Entropy(format!(
            "DC size category {size_cat} exceeds 16 bits"
        )));
    }
    let mag_bits = reader.read_bits_u32(size_cat)?;
    // Compute in i32 so categories near 16 cannot wrap before the range check.
    let diff: i32 = if (mag_bits >> (size_cat - 1)) & 1 == 1 {
        // Leading 1: the bits are the positive difference itself.
        mag_bits as i32
    } else {
        // Leading 0: negative difference stored as the bitwise complement of its magnitude.
        let mask = (1u32 << size_cat) - 1;
        -(((!mag_bits) & mask) as i32)
    };
    if diff < i32::from(i16::MIN) || diff > i32::from(i16::MAX) {
        return Err(DecodeError::Entropy(format!(
            "DC difference {diff} out of i16 range"
        )));
    }
    Ok(prev_dc.wrapping_add(diff as i16))
}
/// Converts the 8-bit DC VLC table into `(code, len, size_category)` triples for
/// [`decode_dc_sequential`].
///
/// The right shift below assumes each entry's `code` is stored left-aligned in 16 bits. Shifting
/// moves the code into the low `len` bits. The size category of an entry is its index in
/// `DC_TABLE_8BIT`.
pub fn dc_table_entries_8bit() -> Vec<(u32, u8, i16)> {
    use super::vlc_tables::DC_TABLE_8BIT;
    let mut triples = Vec::new();
    for (category, entry) in DC_TABLE_8BIT.iter().enumerate() {
        let unused_low_bits = 16 - entry.len as u32;
        let right_aligned = (entry.code as u32) >> unused_low_bits;
        triples.push((right_aligned, entry.len, category as i16));
    }
    triples
}
/// Decodes the AC coefficients of one block into positions `1..64`; index 0 (DC) stays 0.
///
/// Each codeword is matched MSB-first, bit by bit, by [`read_ac_codeword_bitwise`]. It reads
/// at most `min(ac_table.index_bits, 12)` bits. The unpacked `(run, level, last)` triple
/// advances the position by `run`. A nonzero `level` is followed by one sign bit (set =
/// negative; treated as positive if the stream is exhausted) and is stored at the current
/// position.
///
/// Decoding stops without error when any of these happens:
/// - the block is full;
/// - the stream runs out;
/// - `index_bits` is 0;
/// - no codeword matches;
/// - a run moves past position 63;
/// - a `last` flag is seen (a zero-level `last` codeword ends the block immediately).
///
/// # Errors
///
/// Propagates errors from `reader.read_bit`.
pub fn decode_ac_coefficients(
    reader: &mut BitReader<'_>,
    ac_table: &VlcTable,
) -> Result<[i16; 64], DecodeError> {
    let mut coeffs = [0i16; 64];
    let mut pos: usize = 1;
    while pos < 64 && reader.remaining_bits() > 0 && ac_table.index_bits != 0 {
        let max_bits: u8 = ac_table.index_bits.min(12);
        let (run, level, last) = match read_ac_codeword_bitwise(reader, ac_table, max_bits)? {
            Some(triple) => triple,
            None => break,
        };
        pos += run as usize;
        // Either the run overflowed the block or this is the end-of-block codeword.
        if pos >= 64 || (level == 0 && last) {
            break;
        }
        if level > 0 {
            // Short-circuit: no sign bit is read when the stream is already exhausted.
            let negative = reader.remaining_bits() > 0 && reader.read_bit()?;
            let magnitude = level as i16;
            coeffs[pos] = if negative { -magnitude } else { magnitude };
            pos += 1;
        }
        if last {
            break;
        }
    }
    Ok(coeffs)
}

/// Reads one AC codeword bit by bit and returns its unpacked `(run, level, last)` triple.
///
/// A match requires `ac_table` to report a code length equal to the number of bits read so far.
/// Returns `Ok(None)` if `max_bits` bits, or the rest of the stream, were consumed without a match.
fn read_ac_codeword_bitwise(
    reader: &mut BitReader<'_>,
    ac_table: &VlcTable,
    max_bits: u8,
) -> Result<Option<(u8, u16, bool)>, DecodeError> {
    let mut prefix: u32 = 0;
    for width in 1..=max_bits {
        if reader.remaining_bits() == 0 {
            break;
        }
        prefix = (prefix << 1) | reader.read_bit()? as u32;
        // The table is keyed on left-aligned 32-bit values.
        let key = prefix << (32 - u32::from(width));
        if let Some((val, consumed)) = ac_table.lookup(key) {
            if consumed == width {
                return Ok(Some(unpack_ac_value(val)));
            }
        }
    }
    Ok(None)
}
/// Scales a block of quantized coefficients back to reconstruction values.
///
/// The DC term (`coeffs[0]`) is passed through unchanged. Each AC term `i` in `1..64` becomes
/// `coeffs[i] * quant_matrix[i] * qscale`, saturating at `i32::MIN` / `i32::MAX`.
///
/// Saturation guards against malformed input: `i16::MAX * 255 * u16::MAX` is about `5.5e11`.
/// That is far beyond `i32` and would panic in debug builds or silently wrap in release builds.
pub fn dequantize_block(coeffs: &[i16; 64], quant_matrix: &[u8; 64], qscale: u16) -> [i32; 64] {
    let mut out = [0i32; 64];
    out[0] = i32::from(coeffs[0]);
    let scale = i32::from(qscale);
    let ac_terms = out
        .iter_mut()
        .zip(coeffs.iter().zip(quant_matrix.iter()))
        .skip(1);
    for (dst, (&coeff, &weight)) in ac_terms {
        // |coeff * weight| <= 32768 * 255 < 2^24, so only the qscale step can overflow.
        *dst = (i32::from(coeff) * i32::from(weight)).saturating_mul(scale);
    }
    out
}
/// Flat quantization weight matrix: all 64 positions carry weight 1.
pub const QUANT_MATRIX_DEFAULT: [u8; 64] = [1u8; 64];
/// Luma weight matrix for 8-bit streams (per its name); currently identical to the flat default.
///
/// NOTE(review): this is an all-ones placeholder. Confirm whether profile-specific weight
/// tables should be used here.
pub const QUANT_MATRIX_LUMA_8BIT: [u8; 64] = QUANT_MATRIX_DEFAULT;
/// Chroma weight matrix for 8-bit streams; shares the luma values.
pub const QUANT_MATRIX_CHROMA_8BIT: [u8; 64] = QUANT_MATRIX_LUMA_8BIT;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dnxhd::bitreader::BitReader;
    use crate::dnxhd::vlc_tables::build_ac_table;

    /// Packs 0/1 values MSB-first into bytes; a trailing partial byte is zero-padded.
    fn bits_to_bytes(bits: &[u8]) -> Vec<u8> {
        bits.chunks(8)
            .map(|chunk| {
                chunk
                    .iter()
                    .chain(std::iter::repeat(&0u8))
                    .take(8)
                    .fold(0u8, |byte, &bit| (byte << 1) | bit)
            })
            .collect()
    }

    /// Decodes one DC value from `bits` using the 8-bit DC table and predictor `prev`.
    fn decode_dc_bits(bits: &[u8], prev: i16) -> i16 {
        let entries = dc_table_entries_8bit();
        let bytes = bits_to_bytes(bits);
        let mut reader = BitReader::new(&bytes);
        decode_dc_sequential(&mut reader, &entries, prev).unwrap()
    }

    #[test]
    fn decode_dc_zero_diff() {
        assert_eq!(decode_dc_bits(&[1, 0, 0, 0, 0, 0, 0, 0], 100), 100);
    }

    #[test]
    fn decode_dc_positive_diff() {
        assert_eq!(decode_dc_bits(&[0, 0, 1, 1, 1, 0, 0, 0], 50), 53);
    }

    #[test]
    fn decode_dc_negative_diff() {
        assert_eq!(decode_dc_bits(&[0, 0, 1, 0, 0, 0, 0, 0], 50), 47);
    }

    #[test]
    fn decode_ac_eob() {
        let ac_table = build_ac_table();
        let bytes = [0b1000_0000u8];
        let mut reader = BitReader::new(&bytes);
        let coeffs = decode_ac_coefficients(&mut reader, &ac_table).unwrap();
        assert!(coeffs.iter().skip(1).all(|&v| v == 0));
    }

    #[test]
    fn dequantize_identity_matrix() {
        let mut coeffs = [0i16; 64];
        coeffs[..3].copy_from_slice(&[128, 5, -3]);
        let result = dequantize_block(&coeffs, &QUANT_MATRIX_DEFAULT, 1);
        assert_eq!(&result[..3], &[128, 5, -3]);
    }

    #[test]
    fn dequantize_scales_ac() {
        let mut coeffs = [0i16; 64];
        coeffs[1] = 2;
        let mut matrix = QUANT_MATRIX_DEFAULT;
        matrix[1] = 3;
        assert_eq!(dequantize_block(&coeffs, &matrix, 2)[1], 12);
    }
}