/// Transform size: every block handled by this module is `N`×`N`.
const N: usize = 8;

/// Cosine basis table shared by every IDCT call; built on first use.
static COS_TAB: once_cell_once_lock::CosTab = once_cell_once_lock::CosTab::new();

/// Lazily-initialised cosine table backed by [`std::sync::OnceLock`].
mod once_cell_once_lock {
    use std::sync::OnceLock;

    use super::N;

    /// Holds `cos(π·(2n+1)·k / (2N))` for `k, n` in `0..N`.
    ///
    /// The table dimensions and the `2N` denominator are derived from `N`,
    /// so they cannot drift out of sync with the transform size.
    pub struct CosTab {
        inner: OnceLock<[[f32; N]; N]>,
    }

    impl CosTab {
        /// Creates an empty table; values are computed on the first [`CosTab::get`].
        pub const fn new() -> Self {
            Self {
                inner: OnceLock::new(),
            }
        }

        /// Returns the table, computing it exactly once (`OnceLock::get_or_init`).
        ///
        /// Indexed as `table[k][n]`: `k` is the frequency index, `n` the
        /// sample index.
        pub fn get(&self) -> &[[f32; N]; N] {
            self.inner.get_or_init(|| {
                // 2N; equals 16.0 for the 8-point transform.
                let denom = (2 * N) as f64;
                std::array::from_fn(|k| {
                    std::array::from_fn(|n| {
                        // Evaluate in f64 and narrow once, so each entry is the
                        // f32 rounding of a double-precision cosine.
                        let theta =
                            std::f64::consts::PI * (2.0 * n as f64 + 1.0) * k as f64 / denom;
                        theta.cos() as f32
                    })
                })
            })
        }
    }
}
/// One-dimensional `N`-point inverse DCT without the final normalisation.
///
/// Computes `out_row[n] = Σ_k C(k) · in_row[k] · cos(π(2n+1)k / 2N)`, where
/// `C(0) = 1/√2` and `C(k) = 1` otherwise. The caller applies any remaining
/// scale factor.
fn idct_1d(in_row: &[f32; N], out_row: &mut [f32; N]) {
    let basis = COS_TAB.get();
    let dc_weight = 1.0 / std::f32::consts::SQRT_2;
    for (n, dst) in out_row.iter_mut().enumerate() {
        // Accumulate in ascending k order, matching the direct summation.
        *dst = in_row
            .iter()
            .zip(basis.iter())
            .enumerate()
            .fold(0.0f32, |acc, (k, (&coef, row))| {
                let weight = if k == 0 { dc_weight } else { 1.0 };
                acc + weight * coef * row[n]
            });
    }
}
/// In-place separable 2-D inverse DCT of a row-major `N`×`N` block.
///
/// Rows are transformed into a scratch buffer, then columns are transformed
/// back into `block`. Each 1-D pass is scaled by 1/2, giving the overall 1/4
/// normalisation of the 2-D IDCT.
pub fn idct8x8(block: &mut [f32; 64]) {
    let mut scratch = [0.0f32; 64];
    // Line buffers reused by both passes; each use fully overwrites them.
    let mut line = [0.0f32; N];
    let mut transformed = [0.0f32; N];

    // Horizontal pass: block rows -> scratch rows.
    for (src_row, dst_row) in block.chunks_exact(N).zip(scratch.chunks_exact_mut(N)) {
        line.copy_from_slice(src_row);
        idct_1d(&line, &mut transformed);
        for (d, &v) in dst_row.iter_mut().zip(transformed.iter()) {
            *d = v * 0.5;
        }
    }

    // Vertical pass: scratch columns -> block columns.
    for col in 0..N {
        for (row, slot) in line.iter_mut().enumerate() {
            *slot = scratch[row * N + col];
        }
        idct_1d(&line, &mut transformed);
        for (row, &v) in transformed.iter().enumerate() {
            block[row * N + col] = v * 0.5;
        }
    }
}
/// Inverse-transforms integer coefficients into a signed integer block.
///
/// Each sample is rounded to nearest (ties away from zero, per `f32::round`)
/// and clamped to `-256..=255`.
pub fn idct_signed(coeffs: &[i32; 64], out: &mut [i32; 64]) {
    let mut samples = coeffs.map(|c| c as f32);
    idct8x8(&mut samples);
    for (dst, &s) in out.iter_mut().zip(samples.iter()) {
        *dst = (s.round() as i32).clamp(-256, 255);
    }
}
/// Inverse-transforms integer coefficients directly into unsigned 8-bit samples.
///
/// Each sample is rounded to nearest (ties away from zero, per `f32::round`)
/// and clamped to `0..=255` before narrowing to `u8`.
pub fn idct_intra(coeffs: &[i32; 64], out: &mut [u8; 64]) {
    let mut samples = coeffs.map(|c| c as f32);
    idct8x8(&mut samples);
    out.iter_mut()
        .zip(samples.iter())
        .for_each(|(px, &s)| *px = (s.round() as i32).clamp(0, 255) as u8);
}
#[cfg(test)]
mod tests {
// Unit tests for the IDCT entry points. `drift_stress_zero_residual_chain`
// also depends on sibling crate modules (`block`, `fdct`, `quant`).
use super::*;
/// An all-zero coefficient block must reconstruct to (numerically) zero.
#[test]
fn zeros_give_zeros() {
let mut b = [0.0f32; 64];
idct8x8(&mut b);
for v in b.iter() {
assert!(v.abs() < 1e-4);
}
}
/// With only the DC coefficient set, the output is flat at DC/8: the 1/4
/// 2-D scale times (1/√2)² for the DC basis. So DC = 8.0 gives 1.0 everywhere.
#[test]
fn dc_only_flat_block() {
let mut b = [0.0f32; 64];
b[0] = 8.0;
idct8x8(&mut b);
for v in b.iter() {
assert!((*v - 1.0).abs() < 1e-3, "{v}");
}
}
/// Pins the `f32::round` tie-breaking (half away from zero) that
/// `idct_signed` and `idct_intra` use when converting to integers.
#[test]
fn idct_rounding_is_round_half_away_from_zero() {
assert_eq!(0.5f32.round() as i32, 1);
assert_eq!((-0.5f32).round() as i32, -1);
assert_eq!(1.5f32.round() as i32, 2);
assert_eq!((-1.5f32).round() as i32, -2);
}
/// Repeatedly codes `src` against a prediction equal to the previous
/// reconstruction, with q = 8. Each pass runs residual -> `fdct_signed`
/// -> `quant_ac` -> `dequant_ac` -> `idct_signed` -> add to prediction.
/// From iteration 2 onward the reconstruction must equal the previous one;
/// any change means the transform/quantisation loop drifts.
/// NOTE(review): iterations 0 and 1 are exempt from the comparison,
/// presumably to let the chain settle — confirm intent.
#[test]
fn drift_stress_zero_residual_chain() {
use crate::block::dequant_ac;
use crate::fdct::fdct_signed;
use crate::quant::quant_ac;
// Deterministic test pattern with values in 40..140.
let mut src = [0u8; 64];
for (i, s) in src.iter_mut().enumerate() {
*s = (40 + ((i * 17) % 100)) as u8;
}
let mut pred = [128u8; 64];
let q = 8u32;
let mut prev_recon = [0u8; 64];
for iter in 0..50 {
let mut resid = [0i32; 64];
for i in 0..64 {
resid[i] = src[i] as i32 - pred[i] as i32;
}
let mut coeffs = [0i32; 64];
fdct_signed(&resid, &mut coeffs);
let mut levels = [0i32; 64];
for i in 0..64 {
levels[i] = quant_ac(coeffs[i], q);
}
let mut dequant = [0i32; 64];
for i in 0..64 {
dequant[i] = dequant_ac(levels[i], q);
}
let mut rec_resid = [0i32; 64];
idct_signed(&dequant, &mut rec_resid);
let mut recon = [0u8; 64];
for i in 0..64 {
recon[i] = (pred[i] as i32 + rec_resid[i]).clamp(0, 255) as u8;
}
if iter >= 2 {
assert_eq!(
recon, prev_recon,
"drift at iter {iter}: recon != prev_recon — IDCT/quant chain is not idempotent"
);
}
prev_recon = recon;
// Feed the reconstruction back as the next prediction.
pred = recon;
}
}
/// A DC-only intra block with DC = 1024 reconstructs to 128 in every
/// sample (1024 / 8).
#[test]
fn intra_dc_level_matches_128() {
let mut b = [0i32; 64];
b[0] = 1024;
let mut out = [0u8; 64];
idct_intra(&b, &mut out);
for v in out.iter() {
assert_eq!(*v, 128);
}
}
}