use crate::butterfly::{b, brev, cos128, h};
use crate::math::round2;
/// Fixed-point inverse DCT over the first `2^n` entries of `t`, computed in place.
///
/// The entries are first reordered by `permute` and then pushed through a fixed sequence
/// of `b` and `h` calls from `crate::butterfly`. Each stage is gated on `n` so that no
/// call is ever handed an index at or beyond `2^n`: ungated stages use indices `0..4`,
/// `n >= 3` stages stay below 8, `n >= 4` below 16, `n >= 5` below 32 and `n == 6` below 64.
/// The stage order matters, because later stages read entries written by earlier ones.
///
/// NOTE(review): `b` and `h` are not defined in this file. The call pattern looks like the
/// rotation butterfly `B(a, b, angle, flip)` and Hadamard butterfly `H(a, b, flip)` of the
/// AV1 specification's inverse DCT process (stages 2-31 below would then be its steps
/// 2-31) — confirm against `crate::butterfly`.
///
/// # Parameters
///
/// * `t` - working array; only `t[..2^n]` is read and written here.
/// * `n` - log2 of the transform length; must be in `2..=6`.
/// * `r` - forwarded unchanged to every `b` and `h` call. NOTE(review): presumably a
///   bit-range / precision bound for intermediate values — confirm in `crate::butterfly`.
///
/// # Panics
///
/// Panics if `n` is outside `2..=6` or if `t` holds fewer than `2^n` entries.
pub fn inverse_dct(t: &mut [i64], n: u32, r: u32) {
assert!((2..=6).contains(&n), "inverse_dct: n must be in 2..=6");
let len = 1usize << n;
assert!(t.len() >= len, "inverse_dct: array shorter than 2^n");
// Shadow `n` as a signed value: `permute` takes an `i32`, and the stage gates below
// compare against plain integer literals.
let n = n as i32;
// Stage 1: reorder the first 2^n entries into bit-reversed index order.
permute(t, n);
// Stage 2 (n == 6): `b` on pairs (32 + i, 63 - i), i = 0..15.
if n == 6 {
for i in 0..16 {
b(t, 32 + i, 63 - i, 63 - 4 * brev(4, i), false, r);
}
}
// Stage 3 (n >= 5): `b` on pairs (16 + i, 31 - i), i = 0..7.
if n >= 5 {
for i in 0..8 {
b(t, 16 + i, 31 - i, 6 + (brev(3, 7 - i) << 3), false, r);
}
}
// Stage 4 (n == 6): `h` on adjacent pairs in 32..64; the flag alternates with `i`.
if n == 6 {
for i in 0..16 {
h(t, 32 + i * 2, 33 + i * 2, (i & 1) != 0, r);
}
}
// Stage 5 (n >= 4): `b` on pairs (8 + i, 15 - i), i = 0..3.
if n >= 4 {
for i in 0..4 {
b(t, 8 + i, 15 - i, 12 + (brev(2, 3 - i) << 4), false, r);
}
}
// Stage 6 (n >= 5): `h` on adjacent pairs in 16..32; the flag alternates with `i`.
if n >= 5 {
for i in 0..8 {
h(t, 16 + 2 * i, 17 + 2 * i, (i & 1) != 0, r);
}
}
// Stage 7 (n == 6): `b` on eight pairs inside 33..=62, flag set.
if n == 6 {
for i in 0..4 {
for j in 0..2 {
b(
t,
62 - i * 4 - j,
33 + i * 4 + j,
60 - 16 * brev(2, i) + 64 * j,
true,
r,
);
}
}
}
// Stage 8 (n >= 3): `b` on pairs (4, 7) and (5, 6).
if n >= 3 {
for i in 0..2 {
b(t, 4 + i, 7 - i, 56 - 32 * i, false, r);
}
}
// Stage 9 (n >= 4): `h` on adjacent pairs in 8..16; the flag alternates with `i`.
if n >= 4 {
for i in 0..4 {
h(t, 8 + 2 * i, 9 + 2 * i, (i & 1) != 0, r);
}
}
// Stage 10 (n >= 5): `b` on four pairs inside 17..=30, flag set.
if n >= 5 {
for i in 0..2 {
for j in 0..2 {
b(
t,
30 - 4 * i - j,
17 + 4 * i + j,
24 + (j << 6) + ((1 - i) << 5),
true,
r,
);
}
}
}
// Stage 11 (n == 6): `h` within each group of four entries in 32..64
// (outer pair, then inner pair); the flag alternates per group.
if n == 6 {
for i in 0..8 {
for j in 0..2 {
h(t, 32 + i * 4 + j, 35 + i * 4 - j, (i & 1) != 0, r);
}
}
}
// Stage 12 (always): `b` on pairs (0, 1) and (2, 3); the flag is set only for i == 0.
for i in 0..2 {
b(t, 2 * i, 2 * i + 1, 32 + 16 * i, (1 - i) != 0, r);
}
// Stage 13 (n >= 3): `h` on pairs (4, 5) and (6, 7); the flag is set for the second pair.
if n >= 3 {
for i in 0..2 {
h(t, 4 + 2 * i, 5 + 2 * i, i != 0, r);
}
}
// Stage 14 (n >= 4): `b` on pairs (14, 9) and (13, 10), flag set.
if n >= 4 {
for i in 0..2 {
b(t, 14 - i, 9 + i, 48 + 64 * i, true, r);
}
}
// Stage 15 (n >= 5): `h` within each group of four entries in 16..32;
// the flag alternates per group.
if n >= 5 {
for i in 0..4 {
for j in 0..2 {
h(t, 16 + 4 * i + j, 19 + 4 * i - j, (i & 1) != 0, r);
}
}
}
// Stage 16 (n == 6): `b` on eight pairs inside 34..=61, flag set.
if n == 6 {
for i in 0..2 {
for j in 0..4 {
b(
t,
61 - i * 8 - j,
34 + i * 8 + j,
56 - i * 32 + (j >> 1) * 64,
true,
r,
);
}
}
}
// Stage 17 (always): `h` on pairs (0, 3) and (1, 2).
for i in 0..2 {
h(t, i, 3 - i, false, r);
}
// Stage 18 (n >= 3): single `b` on the pair (6, 5).
if n >= 3 {
b(t, 6, 5, 32, true, r);
}
// Stage 19 (n >= 4): `h` within the two groups of four entries in 8..16;
// the flag is set for the second group.
if n >= 4 {
for i in 0..2 {
for j in 0..2 {
h(t, 8 + 4 * i + j, 11 + 4 * i - j, i != 0, r);
}
}
}
// Stage 20 (n >= 5): `b` on pairs (29 - i, 18 + i), i = 0..3, flag set.
if n >= 5 {
for i in 0..4 {
b(t, 29 - i, 18 + i, 48 + (i >> 1) * 64, true, r);
}
}
// Stage 21 (n == 6): `h` within each group of eight entries in 32..64;
// the flag alternates per group.
if n == 6 {
for i in 0..4 {
for j in 0..4 {
h(t, 32 + 8 * i + j, 39 + 8 * i - j, (i & 1) != 0, r);
}
}
}
// Stage 22 (n >= 3): `h` on mirrored pairs (i, 7 - i) across 0..8.
if n >= 3 {
for i in 0..4 {
h(t, i, 7 - i, false, r);
}
}
// Stage 23 (n >= 4): `b` on pairs (13, 10) and (12, 11), flag set.
if n >= 4 {
for i in 0..2 {
b(t, 13 - i, 10 + i, 32, true, r);
}
}
// Stage 24 (n >= 5): `h` within the two groups of eight entries in 16..32;
// the flag is set for the second group.
if n >= 5 {
for i in 0..2 {
for j in 0..4 {
h(t, 16 + i * 8 + j, 23 + i * 8 - j, i != 0, r);
}
}
}
// Stage 25 (n == 6): `b` on pairs (59 - i, 36 + i), i = 0..7, flag set;
// the third argument switches from 48 to 112 halfway through.
if n == 6 {
for i in 0..8 {
b(t, 59 - i, 36 + i, if i < 4 { 48 } else { 112 }, true, r);
}
}
// Stage 26 (n >= 4): `h` on mirrored pairs (i, 15 - i) across 0..16.
if n >= 4 {
for i in 0..8 {
h(t, i, 15 - i, false, r);
}
}
// Stage 27 (n >= 5): `b` on pairs (27 - i, 20 + i), i = 0..3, flag set.
if n >= 5 {
for i in 0..4 {
b(t, 27 - i, 20 + i, 32, true, r);
}
}
// Stage 28 (n == 6): mirrored `h` inside 32..=47 (flag clear) and inside 48..=63
// (flag set). The two calls per iteration name disjoint index ranges.
if n == 6 {
for i in 0..8 {
h(t, 32 + i, 47 - i, false, r);
h(t, 48 + i, 63 - i, true, r);
}
}
// Stage 29 (n >= 5): `h` on mirrored pairs (i, 31 - i) across 0..32.
if n >= 5 {
for i in 0..16 {
h(t, i, 31 - i, false, r);
}
}
// Stage 30 (n == 6): `b` on pairs (55 - i, 40 + i), i = 0..7, flag set.
if n == 6 {
for i in 0..8 {
b(t, 55 - i, 40 + i, 32, true, r);
}
}
// Stage 31 (n == 6): final `h` on mirrored pairs (i, 63 - i) across 0..64.
if n == 6 {
for i in 0..32 {
h(t, i, 63 - i, false, r);
}
}
}
/// Reorders the first `2^n` entries of `t` so that position `i` receives the value that
/// was previously stored at index `brev(n, i)`.
///
/// The old values are staged in a 64-entry stack buffer, so `2^n` must not exceed 64 and
/// `t` must hold at least `2^n` entries; otherwise the slicing below panics.
fn permute(t: &mut [i64], n: i32) {
    let bits = n as u32;
    let count = 1usize << n;

    // Snapshot the inputs first: the shuffle reads positions it may already have overwritten.
    let mut scratch = [0i64; 64];
    scratch[..count].copy_from_slice(&t[..count]);

    t[..count]
        .iter_mut()
        .zip(0i32..)
        .for_each(|(dst, pos)| *dst = scratch[brev(bits, pos) as usize]);
}
/// Forward DCT of the first `2^n` entries of `t`, computed in place by direct summation.
///
/// Output `k` is the dot product of the input with `cos128(step * (2m + 1) * k)` over all
/// input positions `m`, where `step = 64 >> n`. The `k == 0` term is multiplied by 2896 and
/// reduced with `round2(_, 24)`; every other term is reduced with `round2(_, 12)`.
///
/// NOTE(review): this reads as if `cos128` returns a cosine scaled by 2^12 with its argument
/// in units of π/128, and `round2(x, s)` is a rounding right shift by `s` — confirm in
/// `crate::butterfly` / `crate::math`. Under that reading 2896 ≈ 2^12 / √2, i.e. the DC
/// term carries a 1/√2 weight.
///
/// # Panics
///
/// Panics if `n` is outside `2..=6` or if `t` holds fewer than `2^n` entries.
pub fn forward_dct(t: &mut [i64], n: u32) {
    assert!((2..=6).contains(&n), "forward_dct: n must be in 2..=6");
    let count = 1usize << n;
    assert!(t.len() >= count, "forward_dct: array shorter than 2^n");

    let stride = 64i32 >> n;
    let input = &t[..count];

    // Every output depends on all inputs, so results are staged in a scratch buffer and
    // copied back only once the whole spectrum is known.
    let mut spectrum = [0i64; 64];
    for (freq, out) in (0i32..).zip(spectrum[..count].iter_mut()) {
        let dot = (0i32..).zip(input).fold(0i64, |sum, (m, &x)| {
            sum + x * cos128(stride * (2 * m + 1) * freq)
        });
        *out = if freq != 0 {
            round2(dot, 12)
        } else {
            round2(dot * 2896, 24)
        };
    }

    t[..count].copy_from_slice(&spectrum[..count]);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, PI};

    /// Small 64-bit linear congruential generator, so the tests are reproducible without
    /// pulling in an external RNG.
    struct Lcg(u64);

    impl Lcg {
        /// Advances the generator and returns the new raw state.
        fn next(&mut self) -> u64 {
            let stepped = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0 = stepped;
            stepped
        }

        /// Draws an integer in `[-range, range]` (for non-negative `range`) from the high
        /// bits of the next state.
        fn coeff(&mut self, range: i64) -> i64 {
            let raw = (self.next() >> 33) as i64;
            raw % (2 * range + 1) - range
        }
    }

    /// Weight applied to coefficient `k` by the reference transforms: 1/√2 for DC, 1 otherwise.
    fn dct_weight(k: usize) -> f64 {
        match k {
            0 => FRAC_1_SQRT_2,
            _ => 1.0,
        }
    }

    /// Cosine basis value for sample `m` and frequency `k` of a length-`n` transform.
    fn basis(m: usize, k: usize, n: usize) -> f64 {
        (PI * (2 * m + 1) as f64 * k as f64 / (2.0 * n as f64)).cos()
    }

    /// Converts integer samples to `f64` for the floating-point references.
    fn as_f64(values: &[i64]) -> Vec<f64> {
        values.iter().map(|&v| v as f64).collect()
    }

    /// Unnormalised floating-point reference inverse DCT.
    fn naive_idct(x: &[f64]) -> Vec<f64> {
        let n = x.len();
        let sample = |m: usize| -> f64 {
            x.iter()
                .enumerate()
                .map(|(k, &coef)| dct_weight(k) * coef * basis(m, k, n))
                .sum()
        };
        (0..n).map(sample).collect()
    }

    /// Unnormalised floating-point reference forward DCT.
    fn naive_dct(x: &[f64]) -> Vec<f64> {
        let n = x.len();
        (0..n)
            .map(|k| {
                let projection: f64 = x
                    .iter()
                    .enumerate()
                    .map(|(m, &v)| v * basis(m, k, n))
                    .sum();
                dct_weight(k) * projection
            })
            .collect()
    }

    /// Asserts that `got` equals `want` up to one common scale factor.
    ///
    /// The scale is taken from the entry where `want` has the largest magnitude; every
    /// entry of `got` must then lie within `tol` of `scale * want`. Returns the scale.
    fn assert_proportional(got: &[i64], want: &[f64], tol: f64, ctx: &str) -> f64 {
        let anchor = (0..want.len())
            .max_by(|&lhs, &rhs| want[lhs].abs().partial_cmp(&want[rhs].abs()).unwrap())
            .unwrap();
        let wmax = want[anchor];
        assert!(wmax.abs() > 1e-6, "{ctx}: degenerate reference");

        let scale = got[anchor] as f64 / wmax;
        for (m, (g, w)) in got.iter().copied().zip(want.iter().copied()).enumerate() {
            let predicted = scale * w;
            let deviation = (g as f64 - predicted).abs();
            assert!(
                deviation <= tol,
                "{ctx}: entry {m} = {g}, expected ≈ {predicted:.2} (scale {scale:.3})",
            );
        }
        scale
    }

    #[test]
    fn inverse_dct_matches_naive_idct() {
        let mut rng = Lcg(0x1234_5678_9abc_def0);
        for n in 2..=6u32 {
            let len = 1usize << n;
            let tol = 2.0 * len as f64;
            let ctx = format!("idct n={n}");
            for _ in 0..200 {
                let mut coeffs = Vec::with_capacity(len);
                for _ in 0..len {
                    coeffs.push(rng.coeff(64));
                }
                let want = naive_idct(&as_f64(&coeffs));

                let mut t = coeffs;
                inverse_dct(&mut t, n, 16);
                assert_proportional(&t[..len], &want, tol, &ctx);
            }
        }
    }

    #[test]
    fn forward_dct_matches_naive_dct() {
        let mut rng = Lcg(0x0fed_cba9_8765_4321);
        for n in 2..=6u32 {
            let len = 1usize << n;
            let tol = len as f64;
            let ctx = format!("dct n={n}");
            for _ in 0..200 {
                let mut resid = Vec::with_capacity(len);
                for _ in 0..len {
                    resid.push(rng.coeff(255));
                }
                let want = naive_dct(&as_f64(&resid));

                let mut t = resid;
                forward_dct(&mut t, n);
                assert_proportional(&t[..len], &want, tol, &ctx);
            }
        }
    }

    #[test]
    fn forward_then_inverse_is_proportional_identity() {
        let mut rng = Lcg(0xdead_beef_cafe_0001);
        for n in 2..=6u32 {
            let len = 1usize << n;
            let tol = 6.0 * len as f64;
            let ctx = format!("rt n={n}");
            for _ in 0..100 {
                let mut resid = Vec::with_capacity(len);
                for _ in 0..len {
                    resid.push(rng.coeff(120));
                }
                let want = as_f64(&resid);

                let mut t = resid;
                forward_dct(&mut t, n);
                inverse_dct(&mut t, n, 24);
                assert_proportional(&t[..len], &want, tol, &ctx);
            }
        }
    }

    #[test]
    fn dc_input_is_flat_after_inverse() {
        for n in 2..=6u32 {
            let len = 1usize << n;
            // Only the DC coefficient is non-zero.
            let mut t = vec![0i64; len];
            t[0] = 100;
            inverse_dct(&mut t, n, 16);

            let first = t[0];
            for (m, v) in t.iter().copied().take(len).enumerate() {
                assert_eq!(v, first, "n={n}: DC should be flat, entry {m} differs");
            }
        }
    }
}