use super::bitwriter::BitWriter;
use super::vlc_tables::{match_vlc, AcSymbol, AcTablePtr, DcTablePtr, EOB_RUN, ESCAPE_RUN};
use super::Mpeg2Error;
use super::Mpeg2Result;
/// Returns the DC size category of a DC differential.
///
/// The category is the number of significant bits in `|diff|`; a zero
/// differential has category 0.
#[must_use]
pub fn dc_size_category(diff: i32) -> u8 {
    // `leading_zeros()` of 0 is 32, so a zero differential yields 0 without a
    // dedicated branch.
    let significant_bits = u32::BITS - diff.unsigned_abs().leading_zeros();
    significant_bits as u8
}
/// Writes the `size` differential bits that follow a DC size codeword.
///
/// Nothing is written when `size` is 0. Positive differentials are written
/// unchanged. Non-positive differentials are written as `diff + 2^size - 1`.
fn write_dc_differential_bits(writer: &mut BitWriter, diff: i32, size: u8) {
    if size != 0 {
        let encoded = if diff <= 0 {
            let offset = (1i32 << size) - 1;
            diff + offset
        } else {
            diff
        };
        writer.write_bits(encoded as u32, size);
    }
}
/// Looks up the `(code, len)` codeword that `table` assigns to size category
/// `size`. The first matching entry is used.
///
/// # Errors
///
/// Returns `Mpeg2Error::InvalidData` when no entry has that size category.
fn dc_size_codeword(table: DcTablePtr, size: u8) -> Mpeg2Result<(u16, u8)> {
    table
        .iter()
        .find(|&&(_, _, category)| category == size)
        .map(|&(code, len, _)| (code, len))
        .ok_or_else(|| {
            Mpeg2Error::InvalidData(format!("no DC size codeword for size category {size}"))
        })
}
/// Encodes one DC differential.
///
/// The output is the size-category codeword from `table`, followed by the
/// differential bits. Nothing is written if the lookup fails.
///
/// # Errors
///
/// Returns `Mpeg2Error::InvalidData` when `table` has no codeword for the
/// differential's size category.
pub fn encode_dc(writer: &mut BitWriter, table: DcTablePtr, diff: i32) -> Mpeg2Result<()> {
    let category = dc_size_category(diff);
    dc_size_codeword(table, category).map(|(code, len)| {
        writer.write_bits(u32::from(code), len);
        write_dc_differential_bits(writer, diff, category);
    })
}
/// Left-justifies the low `len` bits of `code` within a 32-bit word.
///
/// `code_decodes_to` uses this word as the peek value for `match_vlc`. A zero
/// length yields 0.
fn msb_aligned(code: u32, len: u8) -> u32 {
    match u32::from(len) {
        0 => 0,
        width => code << (u32::BITS - width),
    }
}
/// Checks whether a codeword decodes back to exactly the intended pair.
///
/// Runs `match_vlc` on the left-justified `code` (`len` bits). Returns `true`
/// only if the result is a run/level symbol equal to `(run, level)` whose
/// reported length is exactly `len` bits.
fn code_decodes_to(table: AcTablePtr, code: u32, len: u8, run: u8, level: u16) -> bool {
    let decoded = match_vlc(table, msb_aligned(code, len));
    matches!(
        decoded,
        Ok(AcSymbol::RunLevel { run: r, level: l, bits })
            if r == run && l == level && bits == len
    )
}
/// Returns the shortest table codeword for `(run, level)` that round-trips
/// through the decoder's matcher, or `None` if there is none.
///
/// Some entries decode to a different symbol or length under `match_vlc`.
/// These are ignored, so such pairs must be sent with an escape. When several
/// usable entries share the minimal length, the earliest one in `table` wins.
fn find_ac_codeword(table: AcTablePtr, run: u8, level: u16) -> Option<(u32, u8)> {
    table
        .iter()
        .filter(|&&(code, len, r, l)| {
            r == run && l == level && code_decodes_to(table, code, len, run, level)
        })
        .map(|&(code, len, _, _)| (code, len))
        // `min_by_key` returns the first of several equal minima, so ties keep
        // the earliest entry.
        .min_by_key(|&(_, len)| len)
}
fn special_codes(table: AcTablePtr) -> Mpeg2Result<((u32, u8), (u32, u8))> {
let mut escape: Option<(u32, u8)> = None;
let mut eob: Option<(u32, u8)> = None;
for &(code, len, run, _) in table {
match run {
ESCAPE_RUN => escape = Some((code, len)),
EOB_RUN => eob = Some((code, len)),
_ => {}
}
}
match (escape, eob) {
(Some(e), Some(b)) => Ok((e, b)),
_ => Err(Mpeg2Error::InvalidData(
"AC table missing escape or EOB codeword".into(),
)),
}
}
/// Writes an escape-coded AC coefficient.
///
/// The fields are, in order: the escape codeword, a 6-bit run, and the level
/// as its low 12 two's-complement bits.
fn write_escape(writer: &mut BitWriter, escape: (u32, u8), run: u8, signed_level: i32) {
    let (escape_code, escape_len) = escape;
    // Reinterpreting as u32 keeps the two's-complement bit pattern; mask to 12 bits.
    let level_field = (signed_level as u32) & 0xFFF;
    writer.write_bits(escape_code, escape_len);
    writer.write_bits(u32::from(run), 6);
    writer.write_bits(level_field, 12);
}
/// Encodes one AC `(run, signed_level)` pair.
///
/// If `table` has a codeword that round-trips, it is written followed by a
/// sign bit (1 = negative). Otherwise the pair is sent with the escape
/// sequence.
///
/// # Errors
///
/// Returns `Mpeg2Error::InvalidData` in any of these cases:
/// - `signed_level` is 0 or outside `[-2047, 2047]` (checked first);
/// - `run` exceeds 63;
/// - an escape is needed but `table` lacks its escape or EOB codeword.
pub fn encode_ac_run_level(
    writer: &mut BitWriter,
    table: AcTablePtr,
    run: u8,
    signed_level: i32,
) -> Mpeg2Result<()> {
    let level_is_legal = signed_level != 0 && (-2047..=2047).contains(&signed_level);
    if !level_is_legal {
        return Err(Mpeg2Error::InvalidData(format!(
            "AC level {signed_level} out of legal non-zero range [-2047, 2047]"
        )));
    }
    if run > 63 {
        return Err(Mpeg2Error::InvalidData(format!("AC run {run} exceeds 63")));
    }
    let direct = u16::try_from(signed_level.unsigned_abs())
        .ok()
        .and_then(|magnitude| find_ac_codeword(table, run, magnitude));
    match direct {
        Some((code, len)) => {
            writer.write_bits(code, len);
            writer.write_bit(signed_level < 0);
        }
        None => {
            let (escape, _eob) = special_codes(table)?;
            write_escape(writer, escape, run, signed_level);
        }
    }
    Ok(())
}
/// Writes the end-of-block codeword of `table`.
///
/// # Errors
///
/// Returns `Mpeg2Error::InvalidData` when `table` lacks its escape or EOB
/// codeword. Both are looked up together.
pub fn encode_eob(writer: &mut BitWriter, table: AcTablePtr) -> Mpeg2Result<()> {
    let (_, (eob_code, eob_len)) = special_codes(table)?;
    writer.write_bits(eob_code, eob_len);
    Ok(())
}
#[cfg(test)]
mod tests {
    // Round-trip tests: encoder output is fed to the entropy decoder and must
    // reproduce the encoded value. Error-path tests pin the input validation.
    use super::*;
    use crate::mpeg2::bitreader::BitReader;
    use crate::mpeg2::entropy::{decode_ac, decode_dc, BlockComponent, DcPredictors};
    use crate::mpeg2::vlc_tables::{AC_TABLE_B14, AC_TABLE_B15, DC_SIZE_CHROMA, DC_SIZE_LUMA};

    #[test]
    fn dc_size_category_values() {
        assert_eq!(dc_size_category(0), 0);
        assert_eq!(dc_size_category(1), 1);
        assert_eq!(dc_size_category(-1), 1);
        assert_eq!(dc_size_category(2), 2);
        assert_eq!(dc_size_category(3), 2);
        assert_eq!(dc_size_category(-3), 2);
        assert_eq!(dc_size_category(4), 3);
        assert_eq!(dc_size_category(255), 8);
        assert_eq!(dc_size_category(-256), 9);
    }

    /// Encodes `diff` with the luma or chroma DC table, then checks that
    /// `decode_dc` with zeroed predictors returns `diff`.
    fn dc_round_trip(table_is_luma: bool, diff: i32) {
        let component = if table_is_luma {
            BlockComponent::Luma
        } else {
            BlockComponent::Cb
        };
        let table = if table_is_luma {
            DC_SIZE_LUMA
        } else {
            DC_SIZE_CHROMA
        };
        let mut w = BitWriter::new();
        encode_dc(&mut w, table, diff).expect("encode dc");
        // Zero padding after the codeword; presumably gives the reader's
        // look-ahead enough bits (TODO confirm against BitReader).
        w.write_bits(0, 16);
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let mut preds = DcPredictors { y: 0, cb: 0, cr: 0 };
        let decoded = decode_dc(&mut r, &mut preds, component).expect("decode dc");
        assert_eq!(decoded, diff, "DC diff {diff} (luma={table_is_luma})");
    }

    #[test]
    fn dc_differentials_round_trip_luma() {
        for d in [-255, -16, -3, -1, 0, 1, 2, 3, 100, 255, 1024, -1024] {
            dc_round_trip(true, d);
        }
    }

    #[test]
    fn dc_differentials_round_trip_chroma() {
        for d in [-255, -3, -1, 0, 1, 3, 100, 255, 2047, -2047] {
            dc_round_trip(false, d);
        }
    }

    #[test]
    fn dc_rejects_size_category_missing_from_table() {
        // Categories 21 and 31 are assumed absent from both DC size tables
        // (MPEG-2 Tables B.12/B.13 stop at 11). `encode_dc` must therefore
        // return the lookup error.
        let mut w = BitWriter::new();
        assert!(encode_dc(&mut w, DC_SIZE_LUMA, 1 << 20).is_err());
        assert!(encode_dc(&mut w, DC_SIZE_CHROMA, i32::MAX).is_err());
    }

    /// Encodes one `(run, level)` pair plus EOB and decodes the block.
    ///
    /// The coefficient must land at zig-zag position `run + 1`, with position
    /// 0 reserved for DC. Every other coefficient must stay zero.
    fn ac_round_trip(table: AcTablePtr, intra_vlc_format: bool, run: u8, level: i32) {
        let mut w = BitWriter::new();
        encode_ac_run_level(&mut w, table, run, level).expect("encode ac");
        encode_eob(&mut w, table).expect("encode eob");
        w.write_bits(0, 24);
        let bytes = w.into_bytes();
        let mut r = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        decode_ac(&mut r, &mut block, intra_vlc_format, false).expect("decode ac");
        let scan_index = run as usize + 1;
        let raster = crate::mpeg2::zigzag::SCAN_PROGRESSIVE[scan_index];
        assert_eq!(
            block[raster], level,
            "AC (run={run}, level={level}) at raster {raster}"
        );
        for (i, &v) in block.iter().enumerate() {
            if i != raster {
                assert_eq!(v, 0, "stray coeff at {i}");
            }
        }
    }

    #[test]
    fn ac_b14_table_entries_round_trip() {
        for &(_, _, run, level) in AC_TABLE_B14 {
            if run == EOB_RUN || run == ESCAPE_RUN {
                continue;
            }
            ac_round_trip(AC_TABLE_B14, false, run, i32::from(level));
            ac_round_trip(AC_TABLE_B14, false, run, -i32::from(level));
        }
    }

    #[test]
    fn ac_b15_table_entries_round_trip() {
        for &(_, _, run, level) in AC_TABLE_B15 {
            if run == EOB_RUN || run == ESCAPE_RUN {
                continue;
            }
            ac_round_trip(AC_TABLE_B15, true, run, i32::from(level));
            ac_round_trip(AC_TABLE_B15, true, run, -i32::from(level));
        }
    }

    #[test]
    fn ac_escape_used_for_large_level() {
        ac_round_trip(AC_TABLE_B14, false, 0, 300);
        ac_round_trip(AC_TABLE_B14, false, 5, -777);
        ac_round_trip(AC_TABLE_B15, true, 2, 1000);
    }

    #[test]
    fn ac_escape_used_for_large_run() {
        ac_round_trip(AC_TABLE_B14, false, 40, 1);
        ac_round_trip(AC_TABLE_B14, false, 50, -2);
    }

    #[test]
    fn ac_rejects_zero_and_forbidden_level() {
        let mut w = BitWriter::new();
        assert!(encode_ac_run_level(&mut w, AC_TABLE_B14, 0, 0).is_err());
        assert!(encode_ac_run_level(&mut w, AC_TABLE_B14, 0, -2048).is_err());
        assert!(encode_ac_run_level(&mut w, AC_TABLE_B14, 0, 2048).is_err());
    }

    #[test]
    fn ac_rejects_run_above_63() {
        // The escape run field is 6 bits, so runs of 64 and above must be refused.
        let mut w = BitWriter::new();
        assert!(encode_ac_run_level(&mut w, AC_TABLE_B14, 64, 1).is_err());
        assert!(encode_ac_run_level(&mut w, AC_TABLE_B15, u8::MAX, -1).is_err());
    }

    #[test]
    fn colliding_b14_run16_level1_round_trips_via_escape() {
        ac_round_trip(AC_TABLE_B14, false, 16, 1);
        ac_round_trip(AC_TABLE_B14, false, 16, -1);
    }
}