use std::f32::consts::PI;
/// Modified discrete cosine transform (MDCT) with `n` coefficients per block.
///
/// [`Mdct::forward`] maps a block of `2 * n` time-domain samples to `n`
/// coefficients and [`Mdct::inverse`] maps `n` coefficients back to `2 * n`
/// samples. Both directions multiply by the same window, which is symmetric
/// and power-complementary (`w[i]² + w[i + n]² = 1`). As a result,
/// overlap-adding consecutive inverse outputs taken with a hop of `n` samples
/// reconstructs the input signal (time-domain aliasing cancellation).
#[derive(Debug)]
pub struct Mdct {
    /// Number of coefficients per block (half the window length).
    n: usize,
    /// Number of time-domain samples per block; always `2 * n`.
    window_size: usize,
    // Precomputed cosine tables. They are not read by `forward`/`inverse`,
    // which is why the `dead_code` lint is allowed on them.
    #[allow(dead_code)]
    twiddle_fwd: Vec<f32>,
    #[allow(dead_code)]
    twiddle_inv: Vec<f32>,
    /// Analysis/synthesis window of length `window_size`.
    window: Vec<f32>,
}

impl Mdct {
    /// Creates a transform producing `n` coefficients from blocks of `2 * n` samples.
    ///
    /// The window (and the twiddle tables) are computed once here, so the
    /// transform methods do not allocate.
    pub fn new(n: usize) -> Self {
        let window_size = 2 * n;
        let twiddle_fwd = Self::compute_twiddle_factors(n, false);
        let twiddle_inv = Self::compute_twiddle_factors(n, true);
        let window = Self::compute_window(window_size);
        Self {
            n,
            window_size,
            twiddle_fwd,
            twiddle_inv,
            window,
        }
    }

    /// Returns `cos(sign · π · (k + 0.5) / n)` for `k in 0..n`.
    ///
    /// `sign` is `+1` when `inverse` is true and `-1` otherwise. Cosine is
    /// even, so both signs give mathematically identical tables.
    fn compute_twiddle_factors(n: usize, inverse: bool) -> Vec<f32> {
        let sign = if inverse { 1.0 } else { -1.0 };
        (0..n)
            .map(|k| (sign * PI * (k as f32 + 0.5) / (n as f32)).cos())
            .collect()
    }

    /// Builds the Vorbis window `w[i] = sin(π/2 · sin²(π (i + 0.5) / size))`.
    ///
    /// For even `size` this window has two properties:
    /// - it is symmetric: `w[i] == w[size - 1 - i]`;
    /// - it satisfies the Princen–Bradley condition
    ///   `w[i]² + w[i + size/2]² == 1`.
    ///
    /// Together these make windowed MDCT analysis/synthesis perfectly
    /// reconstructing under 50% overlap-add.
    fn compute_window(size: usize) -> Vec<f32> {
        (0..size)
            .map(|i| {
                let s = ((i as f32 + 0.5) / size as f32 * PI).sin();
                // The square is essential: sin(π/2 · sin θ) alone is not
                // power-complementary and breaks reconstruction.
                (PI / 2.0 * s * s).sin()
            })
            .collect()
    }

    /// MDCT kernel `cos(π/n · (j + 1/2 + n/2) · (k + 1/2))`.
    ///
    /// The angle equals `π · (2j + 1 + n)(2k + 1) / (4n)`. The integer
    /// numerator is reduced modulo `8n` (one full 2π period) before converting
    /// to f32. This keeps the angle in `[0, 2π)`, so accuracy does not degrade
    /// as `j` and `k` grow. Only called with `n > 0` (from loops over `0..n`).
    fn kernel(&self, j: usize, k: usize) -> f32 {
        let period = 8 * self.n;
        let numerator = (2 * j + 1 + self.n) * (2 * k + 1) % period;
        (PI * numerator as f32 / (4 * self.n) as f32).cos()
    }

    /// Computes the MDCT of one windowed block:
    /// `output[k] = Σ_{j<2n} w[j] · input[j] · cos(π/n · (j + 1/2 + n/2) · (k + 1/2))`.
    ///
    /// # Panics
    /// Panics if `input.len() != self.window_size()` or `output.len() != self.size()`.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.window_size);
        assert_eq!(output.len(), self.n);
        for (k, out) in output.iter_mut().enumerate() {
            // Windowing is folded into the sum, so no scratch buffer is needed.
            *out = input
                .iter()
                .zip(&self.window)
                .enumerate()
                .map(|(j, (&x, &w))| x * w * self.kernel(j, k))
                .sum();
        }
    }

    /// Computes the windowed inverse MDCT of one block of coefficients:
    /// `output[j] = w[j] · (2/n) · Σ_{k<n} input[k] · cos(π/n · (j + 1/2 + n/2) · (k + 1/2))`.
    ///
    /// The `2/n` scale pairs with the unscaled [`Mdct::forward`] and the
    /// power-complementary window. Overlap-adding successive outputs with a
    /// hop of `n` therefore reproduces the original samples.
    ///
    /// # Panics
    /// Panics if `input.len() != self.size()` or `output.len() != self.window_size()`.
    pub fn inverse(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.n);
        assert_eq!(output.len(), self.window_size);
        let scale = 2.0 / self.n as f32;
        for (j, (out, &w)) in output.iter_mut().zip(&self.window).enumerate() {
            let sum: f32 = input
                .iter()
                .enumerate()
                .map(|(k, &c)| c * self.kernel(j, k))
                .sum();
            *out = sum * scale * w;
        }
    }

    /// Number of coefficients produced by [`Mdct::forward`].
    #[must_use]
    pub const fn size(&self) -> usize {
        self.n
    }

    /// Number of time-domain samples per block (`2 * size()`).
    #[must_use]
    pub const fn window_size(&self) -> usize {
        self.window_size
    }
}
/// Overlap-add accumulator for blocks of `2 * size` samples advanced by a hop of `size`.
///
/// Each call to [`OverlapAdd::process`] emits `size` samples. These are the
/// first half of the incoming block plus the second half retained from the
/// previous call.
#[derive(Debug)]
pub struct OverlapAdd {
    /// Hop length: samples emitted per call and half the expected block length.
    size: usize,
    /// Second half of the most recent block (all zeros initially and after `reset`).
    overlap: Vec<f32>,
}

impl OverlapAdd {
    /// Creates an accumulator with hop length `size` and an all-zero tail.
    pub fn new(size: usize) -> Self {
        let overlap = vec![0.0; size];
        Self { size, overlap }
    }

    /// Adds the first half of `input` to the stored tail and writes the sum to
    /// `output`. The second half of `input` is then stored as the new tail.
    ///
    /// # Panics
    /// Panics if `input.len() != 2 * size` or `output.len() != size`.
    pub fn process(&mut self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), 2 * self.size);
        assert_eq!(output.len(), self.size);
        let (head, tail) = input.split_at(self.size);
        for ((dst, &previous), &current) in output.iter_mut().zip(&self.overlap).zip(head) {
            *dst = previous + current;
        }
        self.overlap.copy_from_slice(tail);
    }

    /// Clears the stored tail so the next call starts from silence.
    pub fn reset(&mut self) {
        for sample in &mut self.overlap {
            *sample = 0.0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new transform reports `n` coefficients and a `2 * n` sample window.
    #[test]
    fn test_mdct_creation() {
        let transform = Mdct::new(256);
        assert_eq!(transform.size(), 256);
        assert_eq!(transform.window_size(), 512);
    }

    /// Smoke test: a forward/inverse pass over a constant block runs, and the
    /// forward pass yields at least one coefficient with magnitude above 0.1.
    #[test]
    fn test_mdct_forward_inverse() {
        const N: usize = 64;
        let transform = Mdct::new(N);
        let block = [1.0f32; 2 * N];
        let mut spectrum = [0.0f32; N];
        let mut rebuilt = [0.0f32; 2 * N];
        transform.forward(&block, &mut spectrum);
        transform.inverse(&spectrum, &mut rebuilt);
        assert!(spectrum.iter().any(|c| c.abs() > 0.1));
    }

    /// The first call adds against a zero tail; the second adds the retained tail.
    #[test]
    fn test_overlap_add() {
        let mut accumulator = OverlapAdd::new(64);
        let block = [1.0f32; 128];
        let mut hop = [0.0f32; 64];
        accumulator.process(&block, &mut hop);
        assert_eq!(hop[0], 1.0);
        accumulator.process(&block, &mut hop);
        assert_eq!(hop[0], 2.0);
    }

    /// After `reset`, the next call behaves like the very first one.
    #[test]
    fn test_overlap_add_reset() {
        let mut accumulator = OverlapAdd::new(64);
        let block = [1.0f32; 128];
        let mut hop = [0.0f32; 64];
        accumulator.process(&block, &mut hop);
        accumulator.reset();
        accumulator.process(&block, &mut hop);
        assert_eq!(hop[0], 1.0);
    }
}