use crate::butterfly::{b, h};
use crate::math::round2;
// Integer multipliers used only by `inverse_adst4`; every product built from
// them is finally passed through `round2(_, 12)`.
// NOTE(review): the names and the sine reference in the tests suggest these are
// sin(k·π/9) in fixed point; the exact scale factor is not stated in this file.
/// Multiplier for the k = 1 term (see the note above on the presumed meaning).
const SINPI_1_9: i64 = 1321;
/// Multiplier for the k = 2 term.
const SINPI_2_9: i64 = 2482;
/// Multiplier for the k = 3 term.
const SINPI_3_9: i64 = 3344;
/// Multiplier for the k = 4 term.
const SINPI_4_9: i64 = 3803;
pub fn inverse_adst(t: &mut [i64], n: u32, r: u32) {
match n {
2 => inverse_adst4(t, r),
3 => inverse_adst8(t, r),
4 => inverse_adst16(t, r),
_ => panic!("inverse_adst: n must be in 2..=4"),
}
}
/// 4-point inverse ADST, computed in place on `t[0..4]`.
///
/// The kernel is a direct multiply-accumulate with the `SINPI_*_9` constants;
/// each of the four sums is then passed through `round2(_, 12)`. The second
/// parameter is accepted only so the signature matches the larger kernels and
/// is never read.
///
/// # Panics
///
/// Panics when `t` holds fewer than four samples.
fn inverse_adst4(t: &mut [i64], _r: u32) {
    assert!(t.len() >= 4, "inverse_adst4: array shorter than 4");
    let (c0, c1, c2, c3) = (t[0], t[1], t[2], t[3]);

    // The seven distinct coefficient × constant products.
    let p1_c0 = SINPI_1_9 * c0;
    let p2_c0 = SINPI_2_9 * c0;
    let p3_c1 = SINPI_3_9 * c1;
    let p4_c2 = SINPI_4_9 * c2;
    let p1_c2 = SINPI_1_9 * c2;
    let p2_c3 = SINPI_2_9 * c3;
    let p4_c3 = SINPI_4_9 * c3;

    // c0 - c2 + c3 is shared by the third output, which needs one multiply only.
    let mixed = c0 - c2 + c3;

    // Partial sums that exclude the c1 contribution (p3_c1 is added last).
    let sum_a = p1_c0 + p4_c2 + p2_c3;
    let sum_b = p2_c0 - p1_c2 - p4_c3;

    let outputs = [
        sum_a + p3_c1,
        sum_b + p3_c1,
        SINPI_3_9 * mixed,
        sum_a + sum_b - p3_c1,
    ];
    for (slot, x) in t.iter_mut().zip(outputs) {
        *slot = round2(x, 12);
    }
}
/// 8-point inverse ADST, computed in place on `t[0..8]`.
///
/// The transform is an input permutation, five alternating stages of `b` and
/// `h` calls on fixed index pairs, and an output permutation. `b` and `h` come
/// from `crate::butterfly`; their exact arithmetic is not visible here
/// (presumably a butterfly rotation and a Hadamard add/subtract — confirm
/// there). `r` is handed to every one of those calls.
///
/// # Panics
///
/// Panics when `t` holds fewer than eight samples.
fn inverse_adst8(t: &mut [i64], r: u32) {
    assert!(t.len() >= 8, "inverse_adst8: array shorter than 8");
    adst_input_permute(t, 3);

    // Stage 1: adjacent pairs, angle dropping by 16 per pair.
    for (lo, angle) in [(0, 60), (2, 44), (4, 28), (6, 12)] {
        b(t, lo, lo + 1, angle, true, r);
    }

    // Stage 2: combine the lower half with the upper half.
    for (lo, hi) in (0..4).zip(4..8) {
        h(t, lo, hi, false, r);
    }

    // Stage 3: only the upper half is rotated; the second pair is index-swapped.
    b(t, 4, 5, 48, true, r);
    b(t, 7, 6, 16, true, r);

    // Stage 4: distance-2 combinations inside each half.
    for lo in [0, 1, 4, 5] {
        h(t, lo, lo + 2, false, r);
    }

    // Stage 5: final rotation of the last pair in each group of four.
    for lo in [2, 6] {
        b(t, lo, lo + 1, 32, true, r);
    }

    adst_output_permute(t, 3);
}
/// 16-point inverse ADST, computed in place on `t[0..16]`.
///
/// Same structure as the 8-point kernel with two more stages: an input
/// permutation, seven alternating stages of `b` and `h` calls on fixed index
/// pairs, and an output permutation. `b` and `h` come from `crate::butterfly`;
/// their exact arithmetic is not visible here (presumably a butterfly rotation
/// and a Hadamard add/subtract — confirm there). `r` is handed to every call.
///
/// # Panics
///
/// Panics when `t` holds fewer than sixteen samples.
fn inverse_adst16(t: &mut [i64], r: u32) {
    assert!(t.len() >= 16, "inverse_adst16: array shorter than 16");
    adst_input_permute(t, 4);

    // Stage 1: adjacent pairs, angle dropping by 8 per pair.
    for (lo, angle) in [
        (0, 62),
        (2, 54),
        (4, 46),
        (6, 38),
        (8, 30),
        (10, 22),
        (12, 14),
        (14, 6),
    ] {
        b(t, lo, lo + 1, angle, true, r);
    }

    // Stage 2: combine the lower half with the upper half.
    for (lo, hi) in (0..8).zip(8..16) {
        h(t, lo, hi, false, r);
    }

    // Stage 3: rotations on the upper half; every second call is index-swapped.
    for (first, second, angle) in [(8, 9, 56), (13, 12, 8), (10, 11, 24), (15, 14, 40)] {
        b(t, first, second, angle, true, r);
    }

    // Stage 4: distance-4 combinations inside each half.
    for lo in (0..4).chain(8..12) {
        h(t, lo, lo + 4, false, r);
    }

    // Stage 5: rotations on the upper quarter of each half.
    for (first, second, angle) in [(4, 5, 48), (7, 6, 16), (12, 13, 48), (15, 14, 16)] {
        b(t, first, second, angle, true, r);
    }

    // Stage 6: distance-2 combinations inside each group of four.
    for lo in [0, 1, 4, 5, 8, 9, 12, 13] {
        h(t, lo, lo + 2, false, r);
    }

    // Stage 7: final rotation of the last pair in each group of four.
    for lo in [2, 6, 10, 14] {
        b(t, lo, lo + 1, 32, true, r);
    }

    adst_output_permute(t, 4);
}
/// Reorders the first `2^n` samples of `t` into the order the butterfly stages
/// consume them.
///
/// With `n0 = 2^n`, every even slot `i` receives the old value at
/// `n0 - 1 - i` and every odd slot `i` receives the old value at `i - 1`.
/// Samples at index `n0` and beyond are left untouched.
///
/// # Panics
///
/// Panics when `2^n` exceeds 16 (the scratch buffer size) or `t.len()`.
fn adst_input_permute(t: &mut [i64], n: u32) {
    let n0 = 1usize << n;
    let mut snapshot = [0i64; 16];
    snapshot[..n0].copy_from_slice(&t[..n0]);

    // Walk the block two slots at a time: (even, odd) = (mirrored, same-pair even).
    for (pair, slots) in t[..n0].chunks_exact_mut(2).enumerate() {
        let even = 2 * pair;
        slots[0] = snapshot[n0 - even - 1];
        slots[1] = snapshot[even];
    }
}
/// Reorders and sign-adjusts the first `2^n` samples of `t` after the
/// butterfly stages.
///
/// For output slot `i` the source index is the 4-bit Gray code of `i`
/// (`i ^ (i >> 1)`), bit-reversed within four bits and then shifted right by
/// `4 - n` so that only the top `n` bits remain. Odd output slots are negated.
/// Samples at index `2^n` and beyond are left untouched.
///
/// # Panics
///
/// Panics when `2^n` exceeds 16 (the scratch buffer size) or `t.len()`.
fn adst_output_permute(t: &mut [i64], n: u32) {
    let n0 = 1usize << n;
    let mut snapshot = [0i64; 16];
    snapshot[..n0].copy_from_slice(&t[..n0]);

    for (i, slot) in t[..n0].iter_mut().enumerate() {
        // XOR of neighbouring bits, limited to the four bits that matter.
        let gray = (i ^ (i >> 1)) & 0xF;
        // Mirror those four bits: bit 0 ↔ bit 3, bit 1 ↔ bit 2.
        let mirrored = ((gray & 0b0001) << 3)
            | ((gray & 0b0010) << 1)
            | ((gray & 0b0100) >> 1)
            | ((gray & 0b1000) >> 3);
        let value = snapshot[mirrored >> (4 - n)];
        *slot = if i % 2 == 0 { value } else { -value };
    }
}
/// Forward ADST of size `2^n`, computed in place on the leading samples of `t`.
///
/// The transform is derived from `inverse_adst` rather than implemented
/// directly. Each unit vector (scaled by `1 << 12`) is pushed through the
/// inverse kernel with `r = 24`, which yields one column of the inverse
/// matrix. Output `k` is then the dot product of column `k` with the input
/// (i.e. the transposed inverse matrix is applied), passed through
/// `round2(_, 12)`. The round-trip test in this file expects forward followed
/// by inverse to be proportional to the identity.
///
/// # Panics
///
/// Panics when `n` is outside `2..=4` or `t` is shorter than `2^n`.
pub fn forward_adst(t: &mut [i64], n: u32) {
    assert!((2..=4).contains(&n), "forward_adst: n must be in 2..=4");
    let len = 1usize << n;
    assert!(t.len() >= len, "forward_adst: array shorter than 2^n");
    const UNIT: i64 = 1 << 12;

    // responses[k] holds the inverse transform of the k-th scaled unit vector,
    // i.e. column k of the inverse matrix stored as a row.
    let mut responses = [[0i64; 16]; 16];
    for (k, response) in responses[..len].iter_mut().enumerate() {
        response[k] = UNIT;
        inverse_adst(&mut response[..len], n, 24);
    }

    let mut result = [0i64; 16];
    for (dst, response) in result[..len].iter_mut().zip(&responses) {
        let dot: i64 = response[..len]
            .iter()
            .zip(&t[..len])
            .map(|(&weight, &x)| weight * x)
            .sum();
        *dst = round2(dot, 12);
    }
    t[..len].copy_from_slice(&result[..len]);
}
/// Reverses the order of the first `len` samples of `t` in place.
///
/// Samples at index `len` and beyond are left untouched.
///
/// # Panics
///
/// Panics when `len` is greater than `t.len()`.
pub fn flip_in_place(t: &mut [i64], len: usize) {
    // Name the function in the failure, as the transform entry points do,
    // instead of surfacing an anonymous slice-index panic.
    assert!(
        len <= t.len(),
        "flip_in_place: len exceeds array length"
    );
    t[..len].reverse();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Small linear congruential generator so the randomized tests are
    /// deterministic and need no external crate.
    struct Lcg(u64);

    impl Lcg {
        /// Advances the state and returns it.
        fn next(&mut self) -> u64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0
        }

        /// Returns a value in `-range..=range`, drawn from the high state bits.
        fn coeff(&mut self, range: i64) -> i64 {
            (self.next() >> 33) as i64 % (2 * range + 1) - range
        }
    }

    /// Floating-point reference for the inverse transform, unnormalized.
    ///
    /// Length 4 uses `sin(π(2k+1)(m+1) / (2N+1))`; every other length uses
    /// `sin(π(2m+1)(2k+1) / (4N))`.
    fn naive_iadst(x: &[f64]) -> Vec<f64> {
        let n = x.len();
        let nn = n as f64;
        (0..n)
            .map(|m| {
                (0..n)
                    .map(|k| {
                        let basis = if n == 4 {
                            (PI * (2 * k + 1) as f64 * (m + 1) as f64 / (2.0 * nn + 1.0)).sin()
                        } else {
                            (PI * (2 * m + 1) as f64 * (2 * k + 1) as f64 / (4.0 * nn)).sin()
                        };
                        x[k] * basis
                    })
                    .sum()
            })
            .collect()
    }

    /// Asserts that `got` equals `want` times a single scale factor, within
    /// `tol` per entry, and returns that factor.
    ///
    /// The scale is taken from the entry where `want` has the largest
    /// magnitude, so a reference that is numerically zero everywhere is
    /// rejected as degenerate.
    fn assert_proportional(got: &[i64], want: &[f64], tol: f64, ctx: &str) -> f64 {
        let (anchor, &wmax) = want
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
            .unwrap();
        assert!(wmax.abs() > 1e-6, "{ctx}: degenerate reference");
        let scale = got[anchor] as f64 / wmax;
        for (m, (&g, &w)) in got.iter().zip(want).enumerate() {
            let predicted = scale * w;
            assert!(
                (g as f64 - predicted).abs() <= tol,
                "{ctx}: entry {m} = {g}, expected ≈ {predicted:.2} (scale {scale:.3})",
            );
        }
        scale
    }

    /// The integer kernels must agree with the sine reference up to a scale.
    #[test]
    fn inverse_adst_matches_naive_reference() {
        let mut rng = Lcg(0x51a7_c0de_1234_5678);
        for n in 2..=4u32 {
            let len = 1usize << n;
            for _ in 0..200 {
                let coeffs: Vec<i64> = (0..len).map(|_| rng.coeff(64)).collect();
                let mut t = coeffs.clone();
                inverse_adst(&mut t, n, 16);
                let want = naive_iadst(&coeffs.iter().map(|&c| c as f64).collect::<Vec<_>>());
                assert_proportional(&t[..len], &want, 2.0 * len as f64, &format!("iadst n={n}"));
            }
        }
    }

    /// Forward followed by inverse must reproduce the input up to a scale.
    #[test]
    fn forward_then_inverse_is_proportional_identity() {
        let mut rng = Lcg(0x0bad_f00d_dead_0001);
        for n in 2..=4u32 {
            let len = 1usize << n;
            for _ in 0..200 {
                let resid: Vec<i64> = (0..len).map(|_| rng.coeff(200)).collect();
                let mut t = resid.clone();
                forward_adst(&mut t, n);
                inverse_adst(&mut t, n, 24);
                let want: Vec<f64> = resid.iter().map(|&c| c as f64).collect();
                assert_proportional(&t[..len], &want, 8.0 * len as f64, &format!("rt n={n}"));
            }
        }
    }

    /// Both permutations must be bijections on the block, and the output
    /// permutation must negate exactly the odd slots.
    #[test]
    fn input_then_output_permute_indices_are_valid() {
        for n in 3..=4u32 {
            let n0 = 1usize << n;
            // Strictly positive labels so a sign flip is always observable.
            let labels: Vec<i64> = (1..=n0 as i64).collect();
            let mut t = labels.clone();

            adst_input_permute(&mut t, n);
            let mut seen: Vec<i64> = t[..n0].to_vec();
            seen.sort_unstable();
            assert_eq!(seen, labels, "input permute n={n}");

            adst_output_permute(&mut t, n);
            for (i, &v) in t[..n0].iter().enumerate() {
                assert_eq!(
                    v < 0,
                    i & 1 == 1,
                    "output permute n={n}: wrong sign in slot {i}"
                );
            }
            let mut magnitudes: Vec<i64> = t[..n0].iter().map(|v| v.abs()).collect();
            magnitudes.sort_unstable();
            assert_eq!(magnitudes, labels, "output permute n={n}");
        }
    }

    /// Only the first `len` samples are reversed; the tail is untouched.
    #[test]
    fn flip_reverses_samples() {
        let mut t = [1i64, 2, 3, 4, 99];
        flip_in_place(&mut t, 4);
        assert_eq!(t, [4, 3, 2, 1, 99]);
    }

    /// A lone first coefficient must produce a positive, strictly increasing
    /// ramp, which also pins the overall sign of the kernels.
    #[test]
    fn dc_coefficient_basis_increases_across_block() {
        for n in 2..=4u32 {
            let len = 1usize << n;
            let mut t = vec![0i64; len];
            t[0] = 1000;
            inverse_adst(&mut t, n, 16);
            assert!(t[0] > 0, "n={n}: first sample should be positive");
            for w in t[..len].windows(2) {
                assert!(
                    w[1] > w[0],
                    "n={n}: DC basis should increase across the block"
                );
            }
        }
    }
}