/// Forward 8x8 integer transform matrix; row `k` is the `k`-th basis vector.
///
/// The rows are mutually orthogonal but not normalized: their squared norms are
/// 512, 578, 320, 578, 512, 578, 320, 578.
#[rustfmt::skip]
const H8: [[i32; 8]; 8] = [
    [ 8,   8,   8,   8,   8,   8,   8,   8],
    [12,  10,   6,   3,  -3,  -6, -10, -12],
    [ 8,   4,  -4,  -8,  -8,  -4,   4,   8],
    [10,  -3, -12,  -6,   6,  12,   3, -10],
    [ 8,  -8,  -8,   8,   8,  -8,  -8,   8],
    [ 6, -12,   3,  10, -10,  -3,  12,  -6],
    [ 4,  -8,   8,  -4,  -4,   8,  -8,   4],
    [ 3,  -6,  10, -12,  12, -10,   6,  -3],
];

/// Forward 8x8 integer transform, `Y = H8 · X · H8ᵀ`.
///
/// No normalization is applied. Coefficient `(i, j)` carries the gain of rows
/// `i` and `j` of `H8`; for example, a flat block of value `v` yields a DC of
/// `4096 * v` and no AC energy. Accumulation uses plain `i32` arithmetic with
/// Rust's default overflow behavior.
pub fn forward_dct_8x8(input: &[[i32; 8]; 8]) -> [[i32; 8]; 8] {
    // Vertical pass: stage = H8 · input.
    let mut stage = [[0i32; 8]; 8];
    for (stage_row, basis) in stage.iter_mut().zip(H8.iter()) {
        for (col, cell) in stage_row.iter_mut().enumerate() {
            *cell = basis
                .iter()
                .zip(input.iter())
                .map(|(&h, src_row)| h * src_row[col])
                .sum();
        }
    }

    // Horizontal pass: output = stage · H8ᵀ.
    let mut output = [[0i32; 8]; 8];
    for (out_row, stage_row) in output.iter_mut().zip(stage.iter()) {
        for (cell, basis) in out_row.iter_mut().zip(H8.iter()) {
            *cell = stage_row
                .iter()
                .zip(basis.iter())
                .map(|(&s, &h)| s * h)
                .sum();
        }
    }
    output
}
/// Eight-point inverse integer butterfly (the 1-D kernel of the H.264 8x8
/// inverse transform), applied to one row or one column.
///
/// When every input is a multiple of 8, all shifts are exact and the result is
/// `H8ᵀ · d / 8`. Otherwise, the arithmetic `>> 1` / `>> 2` shifts round toward
/// negative infinity and drop fractional bits.
#[inline]
fn inverse_1d_8(f: &[i32; 8]) -> [i32; 8] {
    let [d0, d1, d2, d3, d4, d5, d6, d7] = *f;

    // First stage, even half (inputs 0, 2, 4, 6).
    let a0 = d0 + d4;
    let a2 = d0 - d4;
    let a4 = (d2 >> 1) - d6;
    let a6 = d2 + (d6 >> 1);
    // First stage, odd half (inputs 1, 3, 5, 7).
    let a1 = -d3 + d5 - d7 - (d7 >> 1);
    let a3 = d1 + d7 - d3 - (d3 >> 1);
    let a5 = -d1 + d7 + d5 + (d5 >> 1);
    let a7 = d3 + d5 + d1 + (d1 >> 1);

    // Second stage: combine the even terms, and cross-couple the odd terms.
    let b0 = a0 + a6;
    let b6 = a0 - a6;
    let b2 = a2 + a4;
    let b4 = a2 - a4;
    let b1 = a1 + (a7 >> 2);
    let b7 = a7 - (a1 >> 2);
    let b3 = a3 + (a5 >> 2);
    let b5 = (a3 >> 2) - a5;

    // Output stage: mirror-symmetric sums and differences.
    [
        b0 + b7,
        b2 + b5,
        b4 + b3,
        b6 + b1,
        b6 - b1,
        b4 - b3,
        b2 - b5,
        b0 - b7,
    ]
}

/// Separable 8x8 inverse integer transform.
///
/// The butterfly is applied to every row first, then to every column. The result
/// still carries a factor of 64 relative to the pixel domain; callers scale it
/// back with `(x + 32) >> 6`.
pub fn inverse_dct_8x8(input: &[[i32; 8]; 8]) -> [[i32; 8]; 8] {
    // Horizontal pass.
    let mut rows = [[0i32; 8]; 8];
    for (dst, src) in rows.iter_mut().zip(input.iter()) {
        *dst = inverse_1d_8(src);
    }

    // Vertical pass: gather a column, transform it, scatter it back.
    let mut output = [[0i32; 8]; 8];
    for col in 0..8 {
        let mut column = [0i32; 8];
        for (slot, row) in column.iter_mut().zip(rows.iter()) {
            *slot = row[col];
        }
        let transformed = inverse_1d_8(&column);
        for (out_row, &value) in output.iter_mut().zip(transformed.iter()) {
            out_row[col] = value;
        }
    }
    output
}
/// Dequantization scale factors, indexed `[qp % 6][class]`.
///
/// `class` is the value returned by `class_of_8x8_pos` for the coefficient
/// position. `dequant_8x8_block` multiplies an entry by 16 (a flat weighting)
/// to form the per-coefficient level scale. The values match the H.264
/// `normAdjust8x8` table.
const NORM_ADJUST_8X8: [[i32; 6]; 6] = [
[20, 18, 32, 19, 25, 24],
[22, 19, 35, 21, 28, 26],
[26, 23, 42, 24, 33, 31],
[28, 25, 45, 26, 35, 33],
[32, 28, 51, 30, 40, 38],
[36, 32, 58, 34, 46, 43],
];
/// Scaling class of each position in a 4x4 tile, stored row-major.
///
/// The 8x8 class pattern repeats with period 4 in both directions, so only
/// `(row % 4, col % 4)` matters. A class selects the column of the
/// `NORM_ADJUST_8X8` and `MF_8X8` tables.
const CLASS_SCAN_8X8: [u8; 16] = [
    0, 3, 4, 3, // row % 4 == 0
    3, 1, 5, 1, // row % 4 == 1
    4, 5, 2, 5, // row % 4 == 2
    3, 1, 5, 1, // row % 4 == 3
];

/// Returns the scaling class (0..=5) of coefficient `(row, col)`.
///
/// Any indices are accepted; only their residues modulo 4 are used.
#[inline]
const fn class_of_8x8_pos(row: usize, col: usize) -> usize {
    let tile_index = (row % 4) * 4 + col % 4;
    CLASS_SCAN_8X8[tile_index] as usize
}
/// Forward quantization multipliers `MF`, indexed `[qp % 6][class]`, where
/// `class` is the value returned by `class_of_8x8_pos`.
///
/// Each entry is `round(2^30 / (NORM_ADJUST_8X8[m][c] * P_c))`. Here `P_c` is
/// the gain that `forward_dct_8x8` applies to a basis block of class `c`: the
/// product of the squared norms of the two `H8` rows involved. Those norms are
/// 512 for rows 0 and 4, 578 for odd rows, and 320 for rows 2 and 6.
///
/// | class | (row % 4, col % 4)     | `P_c`               |
/// |-------|------------------------|---------------------|
/// | 0     | (0, 0)                 | 512 * 512 = 262144  |
/// | 1     | both odd               | 578 * 578 = 334084  |
/// | 2     | (2, 2)                 | 320 * 320 = 102400  |
/// | 3     | one is 0, other odd    | 512 * 578 = 295936  |
/// | 4     | one is 0, other is 2   | 512 * 320 = 163840  |
/// | 5     | one odd, other is 2    | 578 * 320 = 184960  |
///
/// Why `2^30`:
/// - Quantization followed by dequantization scales a coefficient by
///   `(MF / 2^(16 + qp/6)) * (16 * v * 2^(qp/6) / 64) = MF * v / 2^18`.
/// - `inverse_dct_8x8` followed by `(x + 32) >> 6` computes `H8ᵀ · Y' · H8 / 4096`.
///   That returns the original block only when `Y'[i][j] = 4096 * Y[i][j] / P_c`.
/// - Equating the two gives `MF = 2^30 / (v * P_c)`.
///
/// A plain `4096 / v` ignores `P_c` and is correct only for class 0, where
/// `P_c = 2^18`. Every AC class would come back with the wrong amplitude.
///
/// Cross-check: these are the common H.264 8x8 quantization coefficients
/// (13107, 11428, 20972, 12222, 16777, 15481 for `qp % 6 == 0`) divided by 64.
/// The factor of 64 arises because `forward_dct_8x8` is 8x larger per dimension
/// than an `H8 / 8` transform.
const MF_8X8: [[i32; 6]; 6] = [
    [205, 179, 328, 191, 262, 242],
    [186, 169, 300, 173, 234, 223],
    [158, 140, 250, 151, 199, 187],
    [146, 129, 233, 140, 187, 176],
    [128, 115, 206, 121, 164, 153],
    [114, 100, 181, 107, 142, 135],
];

/// Slice type, which selects the quantizer's rounding offset (dead zone).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Slice8x8 {
    /// Rounding offset `2^qbits / 3`: smaller dead zone, more levels kept.
    Intra,
    /// Rounding offset `2^qbits / 6`: larger dead zone.
    Inter,
}

/// Quantizes an 8x8 block of coefficients produced by `forward_dct_8x8`.
///
/// Each coefficient `c` at `(i, j)` becomes `sign(c) * ((|c| * MF + f) >> qbits)`, where:
/// - `qbits = 16 + qp / 6`;
/// - `MF = MF_8X8[qp % 6][class_of_8x8_pos(i, j)]`;
/// - `f = 2^qbits / 3` for intra slices, or `2^qbits / 6` for inter slices.
///
/// `qp` is clamped to 51, and levels are saturated to the `i16` range.
pub fn quant_8x8_block(coeffs: &[[i32; 8]; 8], qp: u8, slice: Slice8x8) -> [[i16; 8]; 8] {
    let qp = qp.min(51) as i32;
    let m = (qp % 6) as usize;
    let qbits: i32 = 16 + qp / 6;
    let f_divisor: i64 = match slice {
        Slice8x8::Intra => 3,
        Slice8x8::Inter => 6,
    };
    // An offset below one half of a step rounds small magnitudes down to zero.
    let f: i64 = (1i64 << qbits) / f_divisor;

    let mut out = [[0i16; 8]; 8];
    for (i, (out_row, coeff_row)) in out.iter_mut().zip(coeffs.iter()).enumerate() {
        for (j, (level, &c)) in out_row.iter_mut().zip(coeff_row.iter()).enumerate() {
            // i64 keeps |c| * MF exact even for |c| = 2^31. After the shift, the
            // magnitude is below 2^24, so it fits in i32.
            let abs_c = i64::from(c.unsigned_abs());
            let mf = i64::from(MF_8X8[m][class_of_8x8_pos(i, j)]);
            let magnitude = ((abs_c * mf + f) >> qbits) as i32;
            let signed = if c < 0 { -magnitude } else { magnitude };
            *level = signed.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
        }
    }
    out
}

/// Dequantizes an 8x8 block of levels into coefficients for `inverse_dct_8x8`.
///
/// The level scale is `LevelScale = 16 * NORM_ADJUST_8X8[qp % 6][class]`, i.e. a
/// flat weighting of 16.
/// - For `qp >= 36`, `level * LevelScale` is multiplied by `2^(qp/6 - 6)`.
/// - Otherwise it is right-shifted by `6 - qp/6`, with rounding.
///
/// `qp` is clamped to 51.
pub fn dequant_8x8_block(levels: &[[i16; 8]; 8], qp: u8) -> [[i32; 8]; 8] {
    let qp = qp.min(51) as i32;
    let m = (qp % 6) as usize;
    let qp_div_6 = qp / 6;

    let mut out = [[0i32; 8]; 8];
    for (i, (out_row, level_row)) in out.iter_mut().zip(levels.iter()).enumerate() {
        for (j, (coef, &level)) in out_row.iter_mut().zip(level_row.iter()).enumerate() {
            let level = i32::from(level);
            let level_scale = 16 * NORM_ADJUST_8X8[m][class_of_8x8_pos(i, j)];
            // |level| <= 2^15 and level_scale <= 16 * 58. The largest scale-up
            // (x4 at qp 51) therefore stays far inside i32.
            *coef = if qp_div_6 >= 6 {
                level * level_scale * (1 << (qp_div_6 - 6))
            } else {
                let round = 1 << (5 - qp_div_6);
                let shift = 6 - qp_div_6;
                (level * level_scale + round) >> shift
            };
        }
    }
    out
}

#[cfg(test)]
mod tests_quant {
    use super::*;

    /// Representative `(row, col)` of each scaling class, in class order 0..=5.
    const CLASS_REPRESENTATIVES: [(usize, usize); 6] =
        [(0, 0), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2)];

    /// A block containing only basis function `(i, j)` of the transform pair.
    ///
    /// An impulse of 64 keeps every shift in the inverse butterflies exact. The
    /// result is therefore exactly `H8[i][r] * H8[j][c]`, with magnitudes up to 144.
    fn basis_block(i: usize, j: usize) -> [[i32; 8]; 8] {
        let mut impulse = [[0i32; 8]; 8];
        impulse[i][j] = 64;
        inverse_dct_8x8(&impulse)
    }

    #[test]
    fn mf_matches_derived_values() {
        for (c, &(i, j)) in CLASS_REPRESENTATIVES.iter().enumerate() {
            assert_eq!(class_of_8x8_pos(i, j), c, "representative of class {c}");
            // Forward gain of basis (i, j): the product of the squared norms of
            // H8 rows i and j.
            let gain = i64::from(forward_dct_8x8(&basis_block(i, j))[i][j]);
            for m in 0..6 {
                let den = i64::from(NORM_ADJUST_8X8[m][c]) * gain;
                // round(2^30 / (v * gain)); see the MF_8X8 docs for the derivation.
                let derived = ((1i64 << 30) + den / 2) / den;
                assert_eq!(
                    i64::from(MF_8X8[m][c]),
                    derived,
                    "MF_8X8[{m}][{c}] mismatch: table={}, derived={}",
                    MF_8X8[m][c],
                    derived
                );
            }
        }
    }

    #[test]
    fn class_assignment_matches_spec_scan() {
        assert_eq!(class_of_8x8_pos(0, 0), 0);
        assert_eq!(class_of_8x8_pos(2, 2), 2);
        for i in 0..8 {
            for j in 0..8 {
                assert_eq!(class_of_8x8_pos(i, j), class_of_8x8_pos(i + 4 * (i < 4) as usize, j));
                assert_eq!(class_of_8x8_pos(i, j), class_of_8x8_pos(i, j + 4 * (j < 4) as usize));
            }
        }
    }

    #[test]
    fn quant_zero_coefs_gives_zero_levels() {
        let coeffs = [[0i32; 8]; 8];
        let levels = quant_8x8_block(&coeffs, 24, Slice8x8::Intra);
        for row in &levels {
            for &v in row {
                assert_eq!(v, 0);
            }
        }
    }

    #[test]
    fn dequant_zero_levels_gives_zero_coefs() {
        let levels = [[0i16; 8]; 8];
        for qp in [0, 1, 12, 24, 36, 51] {
            let coefs = dequant_8x8_block(&levels, qp);
            for row in &coefs {
                for &v in row {
                    assert_eq!(v, 0, "qp {qp}: expected 0, got {v}");
                }
            }
        }
    }

    #[test]
    fn roundtrip_flat_recovers_input_at_qp0() {
        for v in [-64, -8, 0, 8, 16, 64, 128] {
            let input = [[v; 8]; 8];
            let fwd = forward_dct_8x8(&input);
            let levels = quant_8x8_block(&fwd, 0, Slice8x8::Intra);
            let deq = dequant_8x8_block(&levels, 0);
            let inv = inverse_dct_8x8(&deq);
            for row in &inv {
                for &scaled in row {
                    let pixel = (scaled + 32) >> 6;
                    let err = (pixel - v).abs();
                    assert!(
                        err <= 1,
                        "flat v={v} qp=0 recon={pixel} err={err}",
                    );
                }
            }
        }
    }

    #[test]
    fn roundtrip_ac_basis_recovers_input_at_qp0() {
        // One pure basis pattern per scaling class. The old `4096 / v` table
        // failed this for every AC class (e.g. 144 came back as about 183 for class 1).
        for &(i, j) in CLASS_REPRESENTATIVES.iter() {
            let input = basis_block(i, j);
            let fwd = forward_dct_8x8(&input);
            let levels = quant_8x8_block(&fwd, 0, Slice8x8::Intra);
            let deq = dequant_8x8_block(&levels, 0);
            let inv = inverse_dct_8x8(&deq);
            for (r, (inv_row, in_row)) in inv.iter().zip(input.iter()).enumerate() {
                for (c, (&scaled, &want)) in inv_row.iter().zip(in_row.iter()).enumerate() {
                    let pixel = (scaled + 32) >> 6;
                    let err = (pixel - want).abs();
                    assert!(
                        err <= 1,
                        "basis ({i},{j}) qp=0 at ({r},{c}): recon={pixel} want={want} err={err}",
                    );
                }
            }
        }
    }

    #[test]
    fn roundtrip_flat_tolerates_noise_at_high_qp() {
        for qp in [12, 18, 24, 30, 36, 42] {
            for v in [16, 64, 128] {
                let input = [[v; 8]; 8];
                let fwd = forward_dct_8x8(&input);
                let levels = quant_8x8_block(&fwd, qp, Slice8x8::Intra);
                let deq = dequant_8x8_block(&levels, qp);
                let inv = inverse_dct_8x8(&deq);
                let envelope = 1 << (qp / 6);
                for row in &inv {
                    for &scaled in row {
                        let pixel = (scaled + 32) >> 6;
                        let err = (pixel - v).abs();
                        assert!(
                            err <= envelope,
                            "qp={qp} v={v} recon={pixel} err={err} > envelope {envelope}",
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn intra_vs_inter_dead_zone_differs() {
        let mut coefs = [[0i32; 8]; 8];
        coefs[0][0] = 12288;
        let lv_i = quant_8x8_block(&coefs, 18, Slice8x8::Intra)[0][0];
        let lv_p = quant_8x8_block(&coefs, 18, Slice8x8::Inter)[0][0];
        assert!(
            lv_i >= lv_p,
            "intra should preserve more than inter at same QP: intra={lv_i} inter={lv_p}",
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Squared L2 norm of each row of `H8`. The rows are orthogonal, so
    /// `H8 · H8ᵀ = diag(ROW_NORM)`.
    const ROW_NORM: [i32; 8] = [512, 578, 320, 578, 512, 578, 320, 578];

    #[test]
    fn forward_zero_input_zero_output() {
        let zero = [[0i32; 8]; 8];
        let y = forward_dct_8x8(&zero);
        for r in &y {
            for &v in r {
                assert_eq!(v, 0);
            }
        }
    }

    #[test]
    fn forward_flat_input_dc_only() {
        let flat = [[1i32; 8]; 8];
        let y = forward_dct_8x8(&flat);
        assert_eq!(y[0][0], 4096, "DC coefficient for flat-1 input");
        for (i, row) in y.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                if i == 0 && j == 0 {
                    continue;
                }
                assert_eq!(v, 0, "expected 0 at ({i},{j}), got {v}");
            }
        }
    }

    #[test]
    fn forward_flat_input_dc_scales_with_amplitude() {
        for v in [-16, -4, 0, 4, 16, 64, 128] {
            let input = [[v; 8]; 8];
            let y = forward_dct_8x8(&input);
            assert_eq!(y[0][0], 4096 * v);
        }
    }

    #[test]
    fn forward_column_stripe_row0_energy_only() {
        let mut x = [[0i32; 8]; 8];
        for i in 0..8 {
            for j in 0..8 {
                x[i][j] = if j & 1 == 0 { 2 } else { 0 };
            }
        }
        let y = forward_dct_8x8(&x);
        for i in 1..8 {
            for j in 0..8 {
                assert_eq!(y[i][j], 0, "row {i} col {j} should be zero, got {}", y[i][j]);
            }
        }
        let row0_abs: i32 = y[0].iter().map(|v| v.abs()).sum();
        assert!(row0_abs > 0);
    }

    #[test]
    fn inverse_zero_input_zero_output() {
        let zero = [[0i32; 8]; 8];
        let x = inverse_dct_8x8(&zero);
        for r in &x {
            for &v in r {
                assert_eq!(v, 0);
            }
        }
    }

    #[test]
    fn inverse_dc_only_gives_flat() {
        let mut f = [[0i32; 8]; 8];
        f[0][0] = 64;
        let x = inverse_dct_8x8(&f);
        let v0 = x[0][0];
        for row in &x {
            for &v in row {
                assert_eq!(v, v0, "DC-only input must produce a flat block");
            }
        }
        assert!(v0 != 0, "DC-only non-zero input must produce non-zero output");
    }

    #[test]
    fn roundtrip_flat_scales_by_4096() {
        for v in [-32, -1, 0, 1, 3, 16, 64] {
            let input = [[v; 8]; 8];
            let fwd = forward_dct_8x8(&input);
            let back = inverse_dct_8x8(&fwd);
            for row in &back {
                for &got in row {
                    assert_eq!(got, 4096 * v, "flat {v} should scale to {} per pixel", 4096 * v);
                }
            }
        }
    }

    #[test]
    fn roundtrip_non_flat_not_identity() {
        let mut input = [[0i32; 8]; 8];
        input[0][0] = 10;
        input[3][5] = 20;
        let fwd = forward_dct_8x8(&input);
        let back = inverse_dct_8x8(&fwd);
        let total: i32 = back.iter().flatten().map(|v| v.abs()).sum();
        assert!(total > 0);
    }

    #[test]
    fn per_basis_scale_factors() {
        // A unit DC coefficient comes back with the DC gain of 4096.
        let mut unit_dc = [[0i32; 8]; 8];
        unit_dc[0][0] = 1;
        assert_eq!(forward_dct_8x8(&inverse_dct_8x8(&unit_dc))[0][0], 4096);

        // An amplitude of 64 keeps every shift in the inverse butterflies exact.
        // Inverse followed by forward must therefore scale basis (i, j) by
        // ROW_NORM[i] * ROW_NORM[j] and leave every other coefficient at zero.
        for i in 0..8 {
            for j in 0..8 {
                let mut e = [[0i32; 8]; 8];
                e[i][j] = 64;
                let y = forward_dct_8x8(&inverse_dct_8x8(&e));
                for (u, row) in y.iter().enumerate() {
                    for (v, &got) in row.iter().enumerate() {
                        let want = if (u, v) == (i, j) {
                            ROW_NORM[i] * ROW_NORM[j]
                        } else {
                            0
                        };
                        assert_eq!(got, want, "basis ({i},{j}): y[{u}][{v}]");
                    }
                }
            }
        }
    }
}