/// One-dimensional inverse Walsh–Hadamard butterfly in integer lifting form.
///
/// Every input is first arithmetically right-shifted by `shift`. The four
/// shifted values are then combined with sum/difference lifting steps; the
/// only rounding is the single floor-halving `>> 1` of the shared term.
/// For `shift == 0` the result equals `H * x / 2` (with that one floor),
/// where `H` is the 4-point Hadamard matrix in this file's output ordering.
fn iwht_1d(t: [i64; 4], shift: u32) -> [i64; 4] {
    let [x0, x1, x2, x3] = t.map(|v| v >> shift);
    // Pairwise sum and difference feed one shared, halved term.
    let sum01 = x0 + x1;
    let diff23 = x2 - x3;
    let half = (sum01 - diff23) >> 1;
    // Lifting steps that reuse `half`; each one is individually invertible.
    let out1 = half - x3;
    let out2 = half - x1;
    let out0 = sum01 - out1;
    let out3 = diff23 + out2;
    [out0, out1, out2, out3]
}
/// Exact integer inverse of [`iwht_1d`] with `shift == 0`.
///
/// The function recomputes the two pairwise terms that the forward butterfly
/// halved, then peels its lifting steps back off. Because the halved term is
/// rebuilt from exactly the same values, its floor rounding cancels. As a
/// result, `iwht_1d_inverse(iwht_1d(x, 0)) == x` holds for every input that
/// does not overflow `i64`.
fn iwht_1d_inverse(o: [i64; 4]) -> [i64; 4] {
    let [y0, y1, y2, y3] = o;
    // Undo the final two lifting steps to recover the pairwise terms.
    let sum01 = y0 + y1;
    let diff23 = y3 - y2;
    let half = (sum01 - diff23) >> 1;
    // Undo the steps that subtracted the odd inputs from `half`.
    let x1 = half - y2;
    let x3 = half - y1;
    [sum01 - x1, x1, diff23 + x3, x3]
}
/// Forward 4x4 Walsh–Hadamard transform: the exact integer inverse of [`iwht4x4`].
///
/// `residual` is a row-major 4x4 block (`residual[row * 4 + col]`). The
/// returned coefficient block uses the same layout and satisfies
/// `iwht4x4(&fwht4x4(residual)) == *residual`.
///
/// `iwht4x4` runs a row pass and then a column pass. This function therefore
/// undoes them in the opposite order: every column is inverted first, then
/// every row. The `* 4` / `>> 2` pair inside `iwht4x4` cancels exactly, so no
/// scaling is applied here.
///
/// # Panics
///
/// Panics if a coefficient does not fit in `i32`. Each 1-D pass can roughly
/// double the largest magnitude, so this can only happen for residual
/// magnitudes on the order of `i32::MAX / 4`. Failing loudly is preferable to
/// the silent two's-complement wrap an unchecked `as` cast would produce,
/// which would break the round trip without any signal.
#[must_use]
pub fn fwht4x4(residual: &[i32; 16]) -> [i32; 16] {
    // Column pass: recover the intermediate block that `iwht4x4` holds after
    // its row pass, stored as m[row][col].
    let mut m = [[0i64; 4]; 4];
    for j in 0..4 {
        let col = [
            i64::from(residual[j]),
            i64::from(residual[4 + j]),
            i64::from(residual[8 + j]),
            i64::from(residual[12 + j]),
        ];
        let r = iwht_1d_inverse(col);
        for (row, &v) in m.iter_mut().zip(r.iter()) {
            row[j] = v;
        }
    }

    // Row pass: recover the coefficients and narrow back to i32 with a check.
    let mut quant = [0i32; 16];
    for (out, row) in quant.chunks_exact_mut(4).zip(m.iter()) {
        let r = iwht_1d_inverse(*row);
        for (dst, &v) in out.iter_mut().zip(r.iter()) {
            assert!(
                (i64::from(i32::MIN)..=i64::from(i32::MAX)).contains(&v),
                "fwht4x4: coefficient {v} does not fit in i32"
            );
            // Lossless: the assert above guarantees `v` is within i32 range.
            *dst = v as i32;
        }
    }
    quant
}
/// Inverse 4x4 Walsh–Hadamard transform of a row-major coefficient block.
///
/// The 1-D butterfly runs over each row of `quant` (`quant[row * 4 + col]`)
/// and then over each column of the intermediate result. The output is the
/// residual block in the same row-major layout. Before the row pass, each
/// coefficient is multiplied by 4, and the butterfly shifts it right by 2
/// again, so that scaling cancels exactly. `fwht4x4` computes the exact
/// inverse of this function.
///
/// NOTE(review): the pass order and lifting steps appear to mirror libvpx's
/// lossless `vpx_iwht4x4_16_add_c` (without its pixel add/clip). Confirm this
/// before relying on bit-exactness with that decoder.
///
/// # Panics
///
/// Panics if an output residual does not fit in `i32`, instead of silently
/// wrapping it. Each 1-D pass can roughly double the largest magnitude, so
/// this is only possible for coefficient magnitudes on the order of
/// `i32::MAX / 4`.
#[must_use]
pub fn iwht4x4(quant: &[i32; 16]) -> [i32; 16] {
    // Row pass.
    let mut m = [[0i64; 4]; 4];
    for (row, coeffs) in m.iter_mut().zip(quant.chunks_exact(4)) {
        // Widen to i64 before scaling so `* 4` cannot overflow.
        let t = [
            i64::from(coeffs[0]) * 4,
            i64::from(coeffs[1]) * 4,
            i64::from(coeffs[2]) * 4,
            i64::from(coeffs[3]) * 4,
        ];
        *row = iwht_1d(t, 2);
    }

    // Column pass, narrowing each residual back to i32 with a range check.
    let mut residual = [0i32; 16];
    for j in 0..4 {
        let t = [m[0][j], m[1][j], m[2][j], m[3][j]];
        let r = iwht_1d(t, 0);
        for (i, &v) in r.iter().enumerate() {
            assert!(
                (i64::from(i32::MIN)..=i64::from(i32::MAX)).contains(&v),
                "iwht4x4: residual {v} does not fit in i32"
            );
            // Lossless: the assert above guarantees `v` is within i32 range.
            residual[i * 4 + j] = v as i32;
        }
    }
    residual
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal deterministic 64-bit linear congruential generator. It lets the
    /// randomized round-trip test run without external crates and replay the
    /// same sequence on every run.
    struct Lcg(u64);

    impl Lcg {
        /// Advances the state one step and returns the new state.
        fn next_u64(&mut self) -> u64 {
            let state = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            self.0 = state;
            state
        }

        /// Returns a residual in `-255..=255` derived from the top 31 state bits.
        fn residual(&mut self) -> i32 {
            let top = (self.next_u64() >> 33) as i32;
            top % 511 - 255
        }
    }

    /// Row-major 4x4 checkerboard alternating between +255 and -255.
    fn checkerboard() -> [i32; 16] {
        let mut block = [0i32; 16];
        for (idx, v) in block.iter_mut().enumerate() {
            let (row, col) = (idx / 4, idx % 4);
            *v = if (row + col) % 2 == 0 { 255 } else { -255 };
        }
        block
    }

    #[test]
    fn roundtrip_zero() {
        let zeros = [0i32; 16];
        let coeffs = fwht4x4(&zeros);
        assert_eq!(iwht4x4(&coeffs), zeros);
    }

    #[test]
    fn roundtrip_extremes() {
        let r = checkerboard();
        let q = fwht4x4(&r);
        // Each coefficient scaled by 4 must stay within -32768..=32767.
        for &c in q.iter() {
            assert!(
                (-32768..=32767).contains(&(c * 4)),
                "coeff {c} would clamp on dequant"
            );
        }
        assert_eq!(iwht4x4(&q), r);
    }

    #[test]
    fn roundtrip_random() {
        let mut rng = Lcg(0xabcd_1234_5678_9999);
        for _ in 0..20_000 {
            let mut block = [0i32; 16];
            block.iter_mut().for_each(|v| *v = rng.residual());
            assert_eq!(iwht4x4(&fwht4x4(&block)), block);
        }
    }

    #[test]
    fn dc_only_residual() {
        let flat = [7i32; 16];
        let coeffs = fwht4x4(&flat);
        assert_eq!(iwht4x4(&coeffs), flat);
    }
}