use std::f64::consts::PI;
use std::num::NonZeroUsize;
use ndarray::Array2;
use num_complex::Complex;
use crate::{C2cPlan, C2cPlanF32, SpectrogramError, SpectrogramResult, WindowType, make_window};
/// Private module holding the sealing marker trait used as a supertrait of `MdctNum`.
mod private_seal {
    /// Method-less marker trait; the only implementors are the two float types below.
    pub trait Seal {}

    impl Seal for f64 {}
    impl Seal for f32 {}
}
/// Scalar type the MDCT plans in this file are generic over.
///
/// Implemented here for `f32` and `f64`; the `private_seal::Seal` supertrait marks
/// exactly those two types.
trait MdctNum:
    private_seal::Seal
    + 'static
    + Copy
    + Default
    + Send
    + Sync
    + std::ops::Neg<Output = Self>
    + std::ops::Add<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::AddAssign
    + std::ops::MulAssign
{
    /// Returns the additive identity of the scalar type.
    fn zero() -> Self;

    /// Converts an `f64` value into this scalar type.
    fn from_f64(x: f64) -> Self;

    /// Returns `2 / n` in this scalar type; the inverse plan multiplies every
    /// output sample by this factor.
    fn scale(n: usize) -> Self;
}
/// Single-precision implementation of `MdctNum`.
impl MdctNum for f32 {
    #[inline]
    fn zero() -> Self {
        0.0
    }

    /// Narrowing conversion: the value is rounded to the nearest `f32`.
    #[inline]
    fn from_f64(x: f64) -> Self {
        x as f32
    }

    /// `2 / n`, with `n` converted to `f32` before the division.
    #[inline]
    fn scale(n: usize) -> Self {
        let count = n as f32;
        2.0 / count
    }
}
/// Double-precision implementation of `MdctNum`.
impl MdctNum for f64 {
    #[inline]
    fn zero() -> Self {
        0.0
    }

    /// Identity conversion: the value is already an `f64`.
    #[inline]
    fn from_f64(x: f64) -> Self {
        x
    }

    /// `2 / n`, with `n` converted to `f64` before the division.
    #[inline]
    fn scale(n: usize) -> Self {
        let count = n as f64;
        2.0 / count
    }
}
/// Precision-generic view of a complex-to-complex FFT plan.
///
/// The MDCT plans hold one of these behind a `Box<dyn ...>` so the same plan code
/// can drive either the `f32` or the `f64` backend.
trait MdctC2cFft<T: MdctNum>: Send {
    /// Runs the forward transform on `buf`, which is both input and output.
    fn forward(&mut self, buf: &mut [Complex<T>]) -> SpectrogramResult<()>;
}
/// Adapter that exposes a double-precision `C2cPlan` through `MdctC2cFft<f64>`.
struct C2cWrapper<P: C2cPlan + Send>(P);

impl<P: C2cPlan + Send> MdctC2cFft<f64> for C2cWrapper<P> {
    /// Forwards the call unchanged to the wrapped plan.
    #[inline]
    fn forward(&mut self, buf: &mut [Complex<f64>]) -> SpectrogramResult<()> {
        let Self(plan) = self;
        plan.forward(buf)
    }
}
/// Adapter that exposes a single-precision `C2cPlanF32` through `MdctC2cFft<f32>`.
struct C2cF32Wrapper<P: C2cPlanF32 + Send>(P);

impl<P: C2cPlanF32 + Send> MdctC2cFft<f32> for C2cF32Wrapper<P> {
    /// Forwards the call unchanged to the wrapped plan.
    #[inline]
    fn forward(&mut self, buf: &mut [Complex<f32>]) -> SpectrogramResult<()> {
        let Self(plan) = self;
        plan.forward(buf)
    }
}
/// Configuration shared by the forward (`mdct*`) and inverse (`imdct*`) transforms.
#[derive(Debug, Clone)]
pub struct MdctParams {
    /// Frame length in samples. Each frame yields `window_size / 2` coefficients;
    /// the constructors on this type require the value to be even and at least 4.
    pub window_size: NonZeroUsize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_size: NonZeroUsize,
    /// Window handed to `make_window` when the transform plans are built.
    pub window: WindowType,
}
impl MdctParams {
    /// Shared window-length check used by both constructors.
    ///
    /// Parity is tested before the minimum length, so an odd length below 4 is
    /// reported with the "must be even" message.
    fn check_window_len(len: usize) -> SpectrogramResult<()> {
        if !len.is_multiple_of(2) {
            return Err(SpectrogramError::invalid_input(format!(
                "window_size must be even, got {len}"
            )));
        }
        if len < 4 {
            return Err(SpectrogramError::invalid_input(format!(
                "window_size must be >= 4, got {len}"
            )));
        }
        Ok(())
    }

    /// Creates parameters from an explicit window size, hop size and window type.
    ///
    /// `hop_size` and `window` are stored as given; only `window_size` is validated.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `window_size` is odd or smaller than 4.
    #[inline]
    pub fn new(
        window_size: NonZeroUsize,
        hop_size: NonZeroUsize,
        window: WindowType,
    ) -> SpectrogramResult<Self> {
        Self::check_window_len(window_size.get())?;
        Ok(Self {
            window_size,
            hop_size,
            window,
        })
    }

    /// Creates parameters using the sine window `sin(pi * (k + 0.5) / window_size)`
    /// and a hop of `window_size / 2`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `window_size` is odd or smaller than 4,
    /// and propagates any error from `WindowType::custom`.
    #[inline]
    pub fn sine_window(window_size: NonZeroUsize) -> SpectrogramResult<Self> {
        let len = window_size.get();
        Self::check_window_len(len)?;

        let len_f = len as f64;
        let coeffs: Vec<f64> = (0..len)
            .map(|k| (PI * (k as f64 + 0.5) / len_f).sin())
            .collect();
        let window = WindowType::custom(coeffs)?;

        // `len >= 4` was checked above, so `len / 2` is non-zero; the fallback
        // error is kept as a defensive measure.
        let Some(hop_size) = NonZeroUsize::new(len / 2) else {
            return Err(SpectrogramError::invalid_input("hop_size computed as zero"));
        };

        Ok(Self {
            window_size,
            hop_size,
            window,
        })
    }

    /// Number of MDCT coefficients produced per frame: `window_size / 2`.
    #[inline]
    #[must_use]
    pub const fn n_coefficients(&self) -> usize {
        self.window_size.get() / 2
    }
}
/// Precomputed tables and scratch space for the forward MDCT of one frame.
///
/// With `n = window_size / 2`, the per-sample tables have `2n` entries and the
/// per-coefficient tables have `n` entries.
struct MdctFwdPlan<T: MdctNum> {
    /// `window[m] * cos(pi * m / (2n))` for every input sample `m`.
    analysis_re: Vec<T>,
    /// `-window[m] * sin(pi * m / (2n))` for every input sample `m`.
    analysis_im: Vec<T>,
    /// `cos(pi * (2k + 1) * (n + 1) / (4n))` for every coefficient `k`.
    mdct_post_re: Vec<T>,
    /// `-sin(pi * (2k + 1) * (n + 1) / (4n))` for every coefficient `k`.
    mdct_post_im: Vec<T>,
    /// Complex FFT applied to the `n`-element scratch buffer.
    c2c: Box<dyn MdctC2cFft<T>>,
    /// Scratch buffer of `n` complex values, overwritten on every frame.
    fwd_z: Vec<Complex<T>>,
    /// Number of coefficients per frame (`window_size / 2`).
    n: usize,
}
impl<T: MdctNum> MdctFwdPlan<T> {
fn new(params: &MdctParams, c2c: Box<dyn MdctC2cFft<T>>) -> Self {
let two_n = params.window_size.get();
let n = two_n / 2;
let window_f64 = make_window(params.window.clone(), params.window_size);
let analysis_re: Vec<T> = (0..two_n)
.map(|m| T::from_f64(window_f64[m] * (PI * m as f64 / two_n as f64).cos()))
.collect();
let analysis_im: Vec<T> = (0..two_n)
.map(|m| T::from_f64(-window_f64[m] * (PI * m as f64 / two_n as f64).sin()))
.collect();
let (mdct_post_re, mdct_post_im): (Vec<T>, Vec<T>) = (0..n)
.map(|k| {
let angle = PI * (2 * k + 1) as f64 * (n + 1) as f64 / (4 * n) as f64;
(T::from_f64(angle.cos()), T::from_f64(-angle.sin()))
})
.unzip();
Self {
analysis_re,
analysis_im,
mdct_post_re,
mdct_post_im,
c2c,
fwd_z: vec![
Complex {
re: T::zero(),
im: T::zero()
};
n
],
n,
}
}
fn mdct_frame(&mut self, frame: &[T], out: &mut [T]) -> SpectrogramResult<()> {
let n = self.n;
for m in 0..n {
self.fwd_z[m] = Complex {
re: frame[m] * self.analysis_re[m] + frame[m + n] * self.analysis_re[m + n],
im: frame[m] * self.analysis_im[m] + frame[m + n] * self.analysis_im[m + n],
};
}
self.c2c.forward(&mut self.fwd_z)?;
let mut k = 0usize;
while k < n {
let p = k / 2;
let f_re = self.fwd_z[p].re;
let f_im = self.fwd_z[p].im;
out[k] = self.mdct_post_re[k] * f_re - self.mdct_post_im[k] * f_im;
k += 2;
}
let mut k = 1usize;
while k < n {
let p = (k - 1) / 2;
let f_re = self.fwd_z[n - 1 - p].re;
let f_im = self.fwd_z[n - 1 - p].im; out[k] = self.mdct_post_re[k] * f_re + self.mdct_post_im[k] * f_im;
k += 2;
}
Ok(())
}
}
/// Precomputed tables and scratch space for the inverse MDCT of one frame.
///
/// With `n = window_size / 2`, the per-coefficient tables have `n` entries and the
/// per-sample tables have `2n` entries.
struct MdctInvPlan<T: MdctNum> {
    /// Window samples converted to `T`; the `imdct*` functions multiply each
    /// reconstructed frame by these before overlap-adding.
    window_samples: Vec<T>,
    /// `cos(pi * (2m + 1 + n) / (4n))` for every output sample `m`.
    imdct_post_re: Vec<T>,
    /// `-sin(pi * (2m + 1 + n) / (4n))` for every output sample `m`.
    imdct_post_im: Vec<T>,
    /// Pre-rotation `e^{-i*pi*k*(n+1)/(2n)}` feeding the even output samples.
    pre_z: Vec<Complex<T>>,
    /// `pre_z[k]` multiplied by `e^{-i*pi*k/n}`, feeding the odd output samples.
    pre_zprime: Vec<Complex<T>>,
    /// Complex FFT applied to both `n`-element scratch buffers.
    c2c: Box<dyn MdctC2cFft<T>>,
    /// Scratch buffer for the even-sample FFT.
    z: Vec<Complex<T>>,
    /// Scratch buffer for the odd-sample FFT.
    zprime: Vec<Complex<T>>,
    /// Number of coefficients per frame (`window_size / 2`).
    n: usize,
}
impl<T: MdctNum> MdctInvPlan<T> {
fn new(params: &MdctParams, c2c: Box<dyn MdctC2cFft<T>>) -> Self {
let two_n = params.window_size.get();
let n = two_n / 2;
let window_f64 = make_window(params.window.clone(), params.window_size);
let window_samples: Vec<T> = window_f64.iter().map(|&w| T::from_f64(w)).collect();
let pre_z: Vec<Complex<T>> = (0..n)
.map(|k| {
let angle = PI * k as f64 * (n + 1) as f64 / two_n as f64;
Complex {
re: T::from_f64(angle.cos()),
im: T::from_f64(-angle.sin()),
}
})
.collect();
let pre_zprime: Vec<Complex<T>> = (0..n)
.map(|k| {
let w_re = T::from_f64((PI * k as f64 / n as f64).cos());
let w_im = T::from_f64(-(PI * k as f64 / n as f64).sin());
Complex {
re: pre_z[k].re * w_re - pre_z[k].im * w_im,
im: pre_z[k].re * w_im + pre_z[k].im * w_re,
}
})
.collect();
let (imdct_post_re, imdct_post_im): (Vec<T>, Vec<T>) = (0..two_n)
.map(|m| {
let angle = PI * (2 * m + 1 + n) as f64 / (4 * n) as f64;
(T::from_f64(angle.cos()), T::from_f64(-angle.sin()))
})
.unzip();
Self {
window_samples,
imdct_post_re,
imdct_post_im,
pre_z,
pre_zprime,
c2c,
z: vec![
Complex {
re: T::zero(),
im: T::zero()
};
n
],
zprime: vec![
Complex {
re: T::zero(),
im: T::zero()
};
n
],
n,
}
}
fn imdct_frame(&mut self, coeffs: &[T], out: &mut [T]) -> SpectrogramResult<()> {
let n = self.n;
let scale = T::scale(n);
for (k, &c) in coeffs.iter().enumerate() {
self.z[k] = Complex {
re: c * self.pre_z[k].re,
im: c * self.pre_z[k].im,
};
self.zprime[k] = Complex {
re: c * self.pre_zprime[k].re,
im: c * self.pre_zprime[k].im,
};
}
self.c2c.forward(&mut self.z)?;
self.c2c.forward(&mut self.zprime)?;
for j in 0..n {
let m = 2 * j;
out[m] = scale
* (self.imdct_post_re[m] * self.z[j].re - self.imdct_post_im[m] * self.z[j].im);
}
for j in 0..n {
let m = 2 * j + 1;
out[m] = scale
* (self.imdct_post_re[m] * self.zprime[j].re
- self.imdct_post_im[m] * self.zprime[j].im);
}
Ok(())
}
}
/// Computes the MDCT of `samples` in double precision.
///
/// Frames of `params.window_size` samples are taken every `params.hop_size`
/// samples. The result has `params.n_coefficients()` rows and
/// `(len - window_size) / hop_size + 1` columns, one column per frame; trailing
/// samples that do not fill a whole frame are not transformed.
///
/// # Errors
///
/// Returns an invalid-input error when `samples` is shorter than one window, and
/// propagates errors from FFT planning or execution.
#[inline]
pub fn mdct(
    samples: &non_empty_slice::NonEmptySlice<f64>,
    params: &MdctParams,
) -> SpectrogramResult<Array2<f64>> {
    let signal = samples.as_slice();
    let frame_len = params.window_size.get();
    let hop = params.hop_size.get();
    let n_bins = params.n_coefficients();

    if signal.len() < frame_len {
        return Err(SpectrogramError::invalid_input(format!(
            "samples length ({}) must be >= window_size ({})",
            signal.len(),
            frame_len
        )));
    }
    let n_frames = (signal.len() - frame_len) / hop + 1;

    #[cfg(feature = "realfft")]
    let c2c: Box<dyn MdctC2cFft<f64>> = Box::new(C2cWrapper(
        crate::fft_backend::realfft_backend::RealFftC2cPlan::new(n_bins),
    ));
    #[cfg(feature = "fftw")]
    let c2c: Box<dyn MdctC2cFft<f64>> = {
        let mut planner = crate::FftwPlanner::new();
        Box::new(C2cWrapper(planner.plan_c2c(n_bins)?))
    };

    let mut plan = MdctFwdPlan::<f64>::new(params, c2c);
    let mut output = Array2::<f64>::zeros((n_bins, n_frames));
    let mut frame_coeffs = vec![0.0f64; n_bins];

    // `windows(..).step_by(hop)` yields exactly `n_frames` frames, starting at
    // 0, hop, 2 * hop, ...
    for (f, frame) in signal.windows(frame_len).step_by(hop).enumerate() {
        plan.mdct_frame(frame, &mut frame_coeffs)?;
        let mut column = output.column_mut(f);
        for (dst, &src) in column.iter_mut().zip(frame_coeffs.iter()) {
            *dst = src;
        }
    }

    Ok(output)
}
/// Reconstructs a signal from MDCT coefficients in double precision.
///
/// Each column of `coefficients` is inverse-transformed, multiplied by the
/// window, and overlap-added at multiples of `params.hop_size`. The result holds
/// `hop_size * (n_frames - 1) + window_size` samples. When `original_length` is
/// given the result is truncated to at most that many samples; it is never
/// padded. An input with zero columns yields an empty vector.
///
/// # Errors
///
/// Returns an invalid-input error when the row count differs from
/// `params.n_coefficients()`, and propagates errors from FFT planning or
/// execution.
#[inline]
pub fn imdct(
    coefficients: &Array2<f64>,
    params: &MdctParams,
    original_length: Option<usize>,
) -> SpectrogramResult<Vec<f64>> {
    let n_bins = params.n_coefficients();
    let frame_len = params.window_size.get();
    let hop = params.hop_size.get();

    if coefficients.nrows() != n_bins {
        return Err(SpectrogramError::invalid_input(format!(
            "coefficients has {} rows but params.n_coefficients() = {}",
            coefficients.nrows(),
            n_bins
        )));
    }

    let n_frames = coefficients.ncols();
    if n_frames == 0 {
        return Ok(Vec::new());
    }

    #[cfg(feature = "realfft")]
    let c2c: Box<dyn MdctC2cFft<f64>> = Box::new(C2cWrapper(
        crate::fft_backend::realfft_backend::RealFftC2cPlan::new(n_bins),
    ));
    #[cfg(feature = "fftw")]
    let c2c: Box<dyn MdctC2cFft<f64>> = {
        let mut planner = crate::FftwPlanner::new();
        Box::new(C2cWrapper(planner.plan_c2c(n_bins)?))
    };

    let mut plan = MdctInvPlan::<f64>::new(params, c2c);
    let mut output = vec![0.0f64; hop * n_frames + frame_len - hop];
    let mut frame_out = vec![0.0f64; frame_len];
    let mut col_buf = vec![0.0f64; n_bins];

    for f in 0..n_frames {
        // Copy the column into a contiguous buffer for the plan.
        let column = coefficients.column(f);
        for (dst, &src) in col_buf.iter_mut().zip(column.iter()) {
            *dst = src;
        }

        plan.imdct_frame(&col_buf, &mut frame_out)?;

        // Windowed overlap-add into the output at this frame's offset.
        let start = f * hop;
        let window = &plan.window_samples[..frame_len];
        let target = &mut output[start..start + frame_len];
        for ((acc, &y), &w) in target.iter_mut().zip(frame_out.iter()).zip(window.iter()) {
            *acc += y * w;
        }
    }

    if let Some(len) = original_length {
        output.truncate(len);
    }
    Ok(output)
}
/// Computes the MDCT of `samples` in single precision.
///
/// Frames of `params.window_size` samples are taken every `params.hop_size`
/// samples. The result has `params.n_coefficients()` rows and
/// `(len - window_size) / hop_size + 1` columns, one column per frame; trailing
/// samples that do not fill a whole frame are not transformed.
///
/// # Errors
///
/// Returns an invalid-input error when `samples` is shorter than one window, or
/// when built with the `fftw` feature (not implemented for that backend).
/// Errors from FFT execution are propagated.
#[inline]
pub fn mdct_f32(
    samples: &non_empty_slice::NonEmptySlice<f32>,
    params: &MdctParams,
) -> SpectrogramResult<Array2<f32>> {
    let signal = samples.as_slice();
    let frame_len = params.window_size.get();
    let hop = params.hop_size.get();
    let n_bins = params.n_coefficients();

    if signal.len() < frame_len {
        return Err(SpectrogramError::invalid_input(format!(
            "samples length ({}) must be >= window_size ({})",
            signal.len(),
            frame_len
        )));
    }
    let n_frames = (signal.len() - frame_len) / hop + 1;

    #[cfg(feature = "realfft")]
    let c2c: Box<dyn MdctC2cFft<f32>> = Box::new(C2cF32Wrapper(
        crate::fft_backend::realfft_backend::RealFftC2cPlanF32::new(n_bins),
    ));
    #[cfg(feature = "fftw")]
    let c2c: Box<dyn MdctC2cFft<f32>> = {
        return Err(SpectrogramError::invalid_input(
            "mdct_f32 is not yet implemented for the fftw backend; use --features realfft",
        ));
    };

    let mut plan = MdctFwdPlan::<f32>::new(params, c2c);
    let mut output = Array2::<f32>::zeros((n_bins, n_frames));
    let mut frame_coeffs = vec![0.0f32; n_bins];

    // `windows(..).step_by(hop)` yields exactly `n_frames` frames, starting at
    // 0, hop, 2 * hop, ...
    for (f, frame) in signal.windows(frame_len).step_by(hop).enumerate() {
        plan.mdct_frame(frame, &mut frame_coeffs)?;
        let mut column = output.column_mut(f);
        for (dst, &src) in column.iter_mut().zip(frame_coeffs.iter()) {
            *dst = src;
        }
    }

    Ok(output)
}
/// Reconstructs a signal from MDCT coefficients in single precision.
///
/// Each column of `coefficients` is inverse-transformed, multiplied by the
/// window, and overlap-added at multiples of `params.hop_size`. The result holds
/// `hop_size * (n_frames - 1) + window_size` samples. When `original_length` is
/// given the result is truncated to at most that many samples; it is never
/// padded. An input with zero columns yields an empty vector.
///
/// # Errors
///
/// Returns an invalid-input error when the row count differs from
/// `params.n_coefficients()`, or when built with the `fftw` feature (not
/// implemented for that backend). Errors from FFT execution are propagated.
#[inline]
pub fn imdct_f32(
    coefficients: &Array2<f32>,
    params: &MdctParams,
    original_length: Option<usize>,
) -> SpectrogramResult<Vec<f32>> {
    let n_bins = params.n_coefficients();
    let frame_len = params.window_size.get();
    let hop = params.hop_size.get();

    if coefficients.nrows() != n_bins {
        return Err(SpectrogramError::invalid_input(format!(
            "coefficients has {} rows but params.n_coefficients() = {}",
            coefficients.nrows(),
            n_bins
        )));
    }

    let n_frames = coefficients.ncols();
    if n_frames == 0 {
        return Ok(Vec::new());
    }

    #[cfg(feature = "realfft")]
    let c2c: Box<dyn MdctC2cFft<f32>> = Box::new(C2cF32Wrapper(
        crate::fft_backend::realfft_backend::RealFftC2cPlanF32::new(n_bins),
    ));
    #[cfg(feature = "fftw")]
    let c2c: Box<dyn MdctC2cFft<f32>> = {
        return Err(SpectrogramError::invalid_input(
            "imdct_f32 is not yet implemented for the fftw backend; use --features realfft",
        ));
    };

    let mut plan = MdctInvPlan::<f32>::new(params, c2c);
    let mut output = vec![0.0f32; hop * n_frames + frame_len - hop];
    let mut frame_out = vec![0.0f32; frame_len];
    let mut col_buf = vec![0.0f32; n_bins];

    for f in 0..n_frames {
        // Copy the column into a contiguous buffer for the plan.
        let column = coefficients.column(f);
        for (dst, &src) in col_buf.iter_mut().zip(column.iter()) {
            *dst = src;
        }

        plan.imdct_frame(&col_buf, &mut frame_out)?;

        // Windowed overlap-add into the output at this frame's offset.
        let start = f * hop;
        let window = &plan.window_samples[..frame_len];
        let target = &mut output[start..start + frame_len];
        for ((acc, &y), &w) in target.iter_mut().zip(frame_out.iter()).zip(window.iter()) {
            *acc += y * w;
        }
    }

    if let Some(len) = original_length {
        output.truncate(len);
    }
    Ok(output)
}
#[cfg(all(test, feature = "realfft"))]
mod tests {
    use super::*;

    /// Generates `n` samples of a unit-amplitude sine at `freq` Hz for sample rate `sr`.
    fn make_sine(n: usize, freq: f64, sr: f64) -> Vec<f64> {
        (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / sr).sin())
            .collect()
    }

    /// The FFT-based forward transform must agree with the direct MDCT sum
    /// `X[k] = sum_m x[m] * cos(pi * (2m + 1 + N) * (2k + 1) / (4N))`.
    #[test]
    fn single_frame_matches_direct_formula() {
        let window_size = std::num::NonZeroUsize::new(16).unwrap();
        let hop = std::num::NonZeroUsize::new(8).unwrap();
        let params = MdctParams::new(window_size, hop, WindowType::Rectangular).unwrap();
        let two_n = 16usize;
        let n = 8usize;
        let x: Vec<f64> = (0..two_n).map(|i| (i as f64 + 1.0) * 0.1).collect();
        let x_ne = non_empty_slice::NonEmptyVec::new(x.clone()).unwrap();
        let coefs = mdct(x_ne.as_non_empty_slice(), &params).unwrap();
        for k in 0..n {
            let ref_val: f64 = (0..two_n)
                .map(|m| {
                    x[m] * (PI * (2 * m + 1 + n) as f64 * (2 * k + 1) as f64 / (4 * n) as f64).cos()
                })
                .sum();
            assert!(
                (coefs[(k, 0)] - ref_val).abs() < 1e-10,
                "k={k}: got {:.12}, ref {:.12}",
                coefs[(k, 0)],
                ref_val
            );
        }
    }

    /// Packing two real sequences `a` and `b` into one complex FFT input `a + i*b`
    /// must give `Z[k] = A[k] + i*B[k]`, where `A` and `B` are the real-input FFTs
    /// of `a` and `b`.
    #[test]
    fn c2c_packing_matches_two_r2c() {
        use crate::fft_backend::{C2cPlan, R2cPlan};
        let n = 8usize;
        let x: Vec<f64> = (0..16).map(|i| (i as f64 + 1.0) * 0.1).collect();
        let analysis_re: Vec<f64> = (0..16).map(|m| (PI * m as f64 / 16.0).cos()).collect();
        let analysis_im: Vec<f64> = (0..16).map(|m| -(PI * m as f64 / 16.0).sin()).collect();
        let mut a = vec![0.0f64; n];
        let mut b = vec![0.0f64; n];
        for m in 0..n {
            a[m] = x[m] * analysis_re[m] + x[m + n] * analysis_re[m + n];
            b[m] = x[m] * analysis_im[m] + x[m + n] * analysis_im[m + n];
        }
        let mut r2c_a_plan = {
            let mut planner = crate::RealFftPlanner::new();
            let p = planner.get_or_create(n);
            crate::fft_backend::realfft_backend::RealFftPlan::new(n, p)
        };
        let mut out_a = vec![
            Complex {
                re: 0.0f64,
                im: 0.0
            };
            n / 2 + 1
        ];
        let mut out_b = vec![
            Complex {
                re: 0.0f64,
                im: 0.0
            };
            n / 2 + 1
        ];
        r2c_a_plan.process(&a, &mut out_a).unwrap();
        let mut r2c_b_plan = {
            let mut planner = crate::RealFftPlanner::new();
            let p = planner.get_or_create(n);
            crate::fft_backend::realfft_backend::RealFftPlan::new(n, p)
        };
        r2c_b_plan.process(&b, &mut out_b).unwrap();
        let mut c2c_plan = crate::fft_backend::realfft_backend::RealFftC2cPlan::new(n);
        let mut z: Vec<Complex<f64>> = (0..n).map(|m| Complex { re: a[m], im: b[m] }).collect();
        c2c_plan.forward(&mut z).unwrap();
        for k in 0..=n / 2 {
            // A + i*B = (A.re - B.im) + i*(A.im + B.re)
            let expected_re = out_a[k].re - out_b[k].im;
            let expected_im = out_a[k].im + out_b[k].re;
            assert!(
                (z[k].re - expected_re).abs() < 1e-10 && (z[k].im - expected_im).abs() < 1e-10,
                "k={k}: A+iB=({:.6},{:.6}), Z=({:.6},{:.6})",
                expected_re,
                expected_im,
                z[k].re,
                z[k].im
            );
        }
    }

    /// Forward then inverse transform with the sine window must reproduce the
    /// input away from the first and last window, where frames do not fully overlap.
    #[test]
    fn perfect_reconstruction_f64() {
        let window_size = std::num::NonZeroUsize::new(256).unwrap();
        let params = MdctParams::sine_window(window_size).unwrap();
        let n = 2048usize;
        let x = make_sine(n, 440.0, 44100.0);
        let x_ne = non_empty_slice::NonEmptyVec::new(x.clone()).unwrap();
        let coefs = mdct(x_ne.as_non_empty_slice(), &params).unwrap();
        let x_rec = imdct(&coefs, &params, Some(n)).unwrap();
        let margin = 256;
        for i in margin..(n - margin) {
            assert!(
                (x_rec[i] - x[i]).abs() < 1e-10,
                "sample {i}: got {:.12}, expected {:.12}",
                x_rec[i],
                x[i]
            );
        }
    }

    /// Single-precision variant of the reconstruction test, with a looser tolerance.
    #[test]
    fn perfect_reconstruction_f32() {
        let window_size = std::num::NonZeroUsize::new(256).unwrap();
        let params = MdctParams::sine_window(window_size).unwrap();
        let n = 2048usize;
        let x: Vec<f32> = make_sine(n, 440.0, 44100.0)
            .into_iter()
            .map(|v| v as f32)
            .collect();
        let x_ne = non_empty_slice::NonEmptyVec::new(x.clone()).unwrap();
        let coefs = mdct_f32(x_ne.as_non_empty_slice(), &params).unwrap();
        let x_rec = imdct_f32(&coefs, &params, Some(n)).unwrap();
        let margin = 256;
        for i in margin..(n - margin) {
            assert!(
                (x_rec[i] - x[i]).abs() < 1e-5,
                "sample {i}: got {:.8}, expected {:.8}",
                x_rec[i],
                x[i]
            );
        }
    }
}