use crate::dsp::complex::Complex;
use crate::dsp::fft::*;
/// Inverse modified discrete cosine transform (IMDCT) evaluated with a half-length complex FFT.
///
/// Built by [`Imdct::new`] or [`Imdct::new_scaled`] for `n` spectral coefficients; each call to
/// [`Imdct::imdct`] turns `n` coefficients into `2 * n` output samples.
pub struct Imdct {
    /// Complex FFT of length `n / 2`, run in place over `scratch`.
    fft: Fft,
    /// Working buffer of `n / 2` complex values, allocated once at construction: it receives the
    /// pre-twiddled input and then holds the FFT result.
    scratch: Box<[Complex<f32>]>,
    /// `n / 2` twiddle factors, shared by the pre-FFT and post-FFT steps, with the user scale
    /// folded into their magnitude and angle.
    twiddle: Box<[Complex<f32>]>,
}
impl Imdct {
/// Creates an IMDCT for `n` spectral coefficients (producing `2 * n` output samples) with a
/// scale factor of `1.0`.
///
/// Equivalent to `Imdct::new_scaled(n, 1.0)`.
///
/// # Panics
///
/// Panics under the same conditions as [`Imdct::new_scaled`].
pub fn new(n: usize) -> Self {
Imdct::new_scaled(n, 1.0)
}
/// Creates an IMDCT for `n` spectral coefficients (producing `2 * n` output samples) whose
/// output is multiplied by `scale`.
///
/// The transform is evaluated with a complex FFT of length `n / 2`. The `n / 2` twiddle factors
/// built here are `sqrt(|scale|) * exp(i * theta_k)` with `theta_k = PI / n * (alpha + k)`.
/// `imdct` applies the same table both before and after the FFT, so the magnitude
/// `sqrt(|scale|)` contributes `|scale|` overall. For a negative `scale`, `alpha` is offset by
/// `n / 2`. That adds `PI / 2` (a quarter turn) to every twiddle, and applying it on both sides
/// of the FFT negates the output. This derivation assumes `Complex`'s `*` is the ordinary
/// complex product.
///
/// # Panics
///
/// Panics if `n` is not a power of two, or if `n > 2 * MAX_SIZE`. `MAX_SIZE` comes from the FFT
/// module and is presumably the largest supported FFT length; since the FFT here is only
/// `n / 2` long, `n` may be up to twice that.
///
/// NOTE(review): `n == 1` yields a zero-length FFT. The output index mapping in `imdct` appears
/// to assume `n / 2` is even (i.e. `n >= 4`). Sizes below 4 have not been verified — confirm
/// before relying on them.
pub fn new_scaled(n: usize, scale: f64) -> Self {
assert!(n.is_power_of_two(), "n must be a power of two");
assert!(n <= 2 * MAX_SIZE, "maximum size exceeded");
// Length of the complex FFT and of the twiddle table.
let n2 = n / 2;
let mut twiddle = Vec::with_capacity(n2);
// Angle offset (in units of PI / n): a base of 1/8, plus n2 when scale is negative. The extra
// n2 * PI / n = PI / 2 rotates every twiddle by a quarter turn, which flips the output sign.
let alpha = 1.0 / 8.0 + if scale.is_sign_positive() { 0.0 } else { n2 as f64 };
let pi_n = std::f64::consts::PI / n as f64;
// Each twiddle is applied twice (pre- and post-FFT), so each carries sqrt(|scale|).
let sqrt_scale = scale.abs().sqrt();
for k in 0..n2 {
let theta = pi_n * (alpha + k as f64);
// Computed in f64 and only narrowed to f32 when stored.
let re = sqrt_scale * theta.cos();
let im = sqrt_scale * theta.sin();
twiddle.push(Complex::new(re as f32, im as f32));
}
// Working buffer for the FFT. It is fully overwritten by the pre-twiddle step of every
// `imdct` call, so its initial `Default` contents are never read as input.
let scratch = vec![Default::default(); n2].into_boxed_slice();
Imdct { fft: Fft::new(n2), scratch, twiddle: twiddle.into_boxed_slice() }
}
/// Computes the IMDCT of `spec` and writes the `2 * n` resulting samples to `out`.
///
/// `n` is the size this `Imdct` was constructed with. The call reuses the scratch buffer and
/// twiddle table created at construction; `&mut self` is needed for the scratch buffer and the
/// in-place FFT.
///
/// The output has the symmetries expected of IMDCT time-domain aliasing. With `h = n / 2`:
/// `out[h - 1 - k] == -out[h + k]` and `out[3 * h - 1 - k] == out[3 * h + k]`.
///
/// # Panics
///
/// Panics if `spec.len() != n` or `out.len() != 2 * n`.
pub fn imdct(&mut self, spec: &[f32], out: &mut [f32]) {
// Spectral length: twice the FFT length, half the output length.
let n = self.fft.size() << 1;
// FFT length; also the length of each output quarter.
let n2 = n >> 1;
// Half the FFT length: each post-twiddle loop below handles this many FFT bins.
let n4 = n >> 2;
assert_eq!(spec.len(), n);
assert_eq!(out.len(), 2 * n);
// Pre-twiddle: pack the real spectrum into n2 complex values. Even-indexed coefficients are
// read forward from the start and odd-indexed ones backward from the end (negated). Each pair
// is combined with its twiddle factor.
for (i, (&w, t)) in self.twiddle.iter().zip(self.scratch.iter_mut()).enumerate() {
let even = spec[i * 2];
let odd = -spec[n - 1 - i * 2];
let re = odd * w.im - even * w.re;
let im = odd * w.re + even * w.im;
*t = Complex::new(re, im);
}
// Complex FFT of length n2, in place.
self.fft.fft_inplace(&mut self.scratch);
// Split the 2 * n output samples into four consecutive quarters of n2 samples each.
let (vec0, vec1) = out.split_at_mut(n2);
let (vec1, vec2) = vec1.split_at_mut(n2);
let (vec2, vec3) = vec2.split_at_mut(n2);
// Post-twiddle, first n4 FFT bins. Each twiddled bin yields four output samples:
// - vec0 and vec2 get values at indices counting down from n2 - 1 (the odd indices when n2
//   is even);
// - vec1 and vec3 get values at even indices counting up from 0.
for (i, (x, &w)) in self.scratch[..n4].iter().zip(self.twiddle[..n4].iter()).enumerate() {
let val = w * x.conj();
// Forward (even) and reverse (odd when n2 is even) indices within each quarter.
let fi = 2 * i;
let ri = n2 - 1 - 2 * i;
vec0[ri] = -val.im;
vec1[fi] = val.im;
vec2[ri] = val.re;
vec3[fi] = val.re;
}
// Post-twiddle, remaining n2 - n4 FFT bins: the complementary indices. vec0 and vec2 are
// filled at even indices counting up; vec1 and vec3 at indices counting down from n2 - 1.
// Together with the loop above, every output sample is written once (for even n2).
for (i, (x, &w)) in self.scratch[n4..].iter().zip(self.twiddle[n4..].iter()).enumerate() {
let val = w * x.conj();
let fi = 2 * i;
let ri = n2 - 1 - 2 * i;
vec0[fi] = -val.re;
vec1[ri] = val.re;
vec2[fi] = val.im;
vec3[ri] = val.im;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Spectral input shared by the tests: the ramp `1.0, 2.0, ..., 32.0`.
    #[rustfmt::skip]
    const TEST_VECTOR: [f32; 32] = [
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
        25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0,
    ];

    /// Largest absolute difference tolerated between the FFT-based and reference outputs.
    const TOLERANCE: f64 = 0.00001;

    /// Reference IMDCT evaluated directly from its defining sum in `f64`:
    ///
    /// `y[i] = scale * sum_j x[j] * cos(PI / (2 * N) * (2 * i + 1 + N / 2) * (2 * j + 1))`
    ///
    /// where `N = y.len() = 2 * x.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `y.len() != 2 * x.len()`.
    fn imdct_analytical(x: &[f32], y: &mut [f32], scale: f64) {
        assert_eq!(y.len(), 2 * x.len(), "output must be twice as long as the input");
        let n_in = x.len();
        let n_out = y.len();
        let pi_2n = PI / (2 * n_out) as f64;
        for (i, item) in y.iter_mut().enumerate() {
            let accum: f64 = x
                .iter()
                .map(|&v| f64::from(v))
                .enumerate()
                .map(|(j, v)| v * (pi_2n * ((2 * i + 1 + n_in) * (2 * j + 1)) as f64).cos())
                .sum();
            *item = (scale * accum) as f32;
        }
    }

    /// Runs `Imdct::new_scaled` on `TEST_VECTOR` and checks every output sample against
    /// `imdct_analytical` with the same `scale`. A failure reports the sample index and values.
    fn check_against_analytical(scale: f64) {
        let mut actual = [0f32; 64];
        let mut expected = [0f32; 64];
        imdct_analytical(&TEST_VECTOR, &mut expected, scale);
        let mut imdct = Imdct::new_scaled(TEST_VECTOR.len(), scale);
        imdct.imdct(&TEST_VECTOR, &mut actual);
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            let delta = f64::from(a) - f64::from(e);
            assert!(
                delta.abs() < TOLERANCE,
                "sample {}: got {}, expected {} (scale {})",
                i,
                a,
                e,
                scale
            );
        }
    }

    #[test]
    fn verify_imdct() {
        check_against_analytical((2.0f64 / 64.0).sqrt());
    }

    /// A negative scale takes a separate twiddle path in `new_scaled` (an angle offset rather
    /// than a negative magnitude). The output must still equal the reference scaled by `scale`.
    #[test]
    fn verify_imdct_negative_scale() {
        check_against_analytical(-(2.0f64 / 64.0).sqrt());
    }
}