use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use crate::float::FloatExt;
/// Precomputed tables for an MDCT of size `n`, also usable at the reduced sizes
/// `n >> shift` (the forward transform supports `shift` in `0..4`).
#[derive(Debug, Clone)]
pub struct MdctLookup {
    /// Full transform size given to [`MdctLookup::new`].
    n: usize,
    /// `cos(2πi / n)` for `i` in `0..=n / 4`; shared by every shift via `i << shift` indexing.
    trig: Vec<f32>,
    /// Forward pre-rotation twiddles, one table per shift `s`, each holding
    /// `(n >> s) / 4` interleaved `(re, im)` pairs.
    fwd_pre: [Vec<f32>; 4],
}

impl MdctLookup {
    /// Builds the cosine table and the four per-shift forward pre-rotation tables
    /// for an MDCT of size `n`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        let quarter = n / 4;
        // trig[i] = cos(2πi / n), i = 0..=n/4 (computed in f64, stored as f32).
        let trig: Vec<f32> = (0..quarter + 1)
            .map(|i| (2.0 * core::f64::consts::PI * i as f64 / n as f64).cos() as f32)
            .collect();

        // Pre-rotation table for shift `s`: for each pair index `i`, with
        // a = cos(2πi/size) and b = cos(2π(pairs - i)/size) taken from `trig`,
        // store (-a + b·sine, a·sine + b).
        let pre_rotation = |s: usize| -> Vec<f32> {
            let size = n >> s;
            let pairs = size >> 2;
            let sine = (2.0 * core::f64::consts::PI * 0.125 / size as f64) as f32;
            let mut table = Vec::with_capacity(2 * pairs);
            table.extend((0..pairs).flat_map(|i| {
                let a = trig[i << s];
                let b = trig[(pairs - i) << s];
                [-a + b * sine, a * sine + b]
            }));
            table
        };
        let fwd_pre = [
            pre_rotation(0),
            pre_rotation(1),
            pre_rotation(2),
            pre_rotation(3),
        ];

        Self { n, trig, fwd_pre }
    }

    /// Full MDCT size this lookup was built for (the `n` passed to [`MdctLookup::new`]).
    #[must_use]
    pub const fn size(&self) -> usize {
        self.n
    }
}
#[cfg(feature = "spectrograms")]
mod accel {
    //! FFT backend built on the `spectrograms` crate, reusing plans and a work buffer per thread.

    use std::cell::RefCell;
    use std::collections::HashMap;

    use spectrograms::{C2cPlan, Complex, RealFftC2cPlan};

    std::thread_local! {
        /// This thread's FFT plans, keyed by transform length and built on first use.
        static PLANS: RefCell<HashMap<usize, RealFftC2cPlan<f32>>> = RefCell::new(HashMap::new());
        /// This thread's complex work buffer, reused across calls.
        static SCRATCH: RefCell<Vec<Complex<f32>>> = const { RefCell::new(Vec::new()) };
    }

    /// Runs a complex FFT of length `input.len()` over `input` and writes the result,
    /// multiplied by `scale`, into `output`.
    ///
    /// Uses the plan's `inverse` when `inverse` is true and `forward` otherwise. Only the
    /// first `min(input.len(), output.len())` entries of `output` are written.
    /// NOTE(review): whether `RealFftC2cPlan` normalizes either direction is not visible
    /// here — confirm before changing `scale` at the call sites.
    ///
    /// # Panics
    /// If the plan returns an error, or if called re-entrantly on the same thread
    /// (the thread-local `RefCell`s are already mutably borrowed).
    pub fn run(input: &[(f32, f32)], output: &mut [(f32, f32)], inverse: bool, scale: f32) {
        SCRATCH.with(|cell| {
            let mut work = cell.borrow_mut();
            work.clear();
            work.extend(input.iter().map(|&(re, im)| Complex::new(re, im)));
            let len = work.len();

            PLANS.with(|cache| {
                let mut cache = cache.borrow_mut();
                let plan = cache
                    .entry(len)
                    .or_insert_with(|| RealFftC2cPlan::<f32>::new(len));
                let data = &mut work[..];
                if inverse {
                    plan.inverse(data).expect("plan sized to buffer");
                } else {
                    plan.forward(data).expect("plan sized to buffer");
                }
            });

            output
                .iter_mut()
                .zip(work.iter())
                .for_each(|(dst, c)| *dst = (c.re * scale, c.im * scale));
        });
    }
}
/// Inverse complex FFT through `accel::run`; adds no scaling on top of whatever the
/// `spectrograms` inverse itself does.
#[cfg(feature = "spectrograms")]
fn fft_inverse(input: &[(f32, f32)], output: &mut [(f32, f32)]) {
    const NO_SCALE: f32 = 1.0;
    accel::run(input, output, true, NO_SCALE);
}

/// Forward complex FFT through `accel::run`, with the output scaled by `1 / input.len()`.
#[cfg(feature = "spectrograms")]
fn fft_forward(input: &[(f32, f32)], output: &mut [(f32, f32)]) {
    let norm = (input.len() as f32).recip();
    accel::run(input, output, false, norm);
}
/// Direct complex DFT shared by the portable `fft_inverse` / `fft_forward` fallbacks.
///
/// Writes `output[k] = scale · Σ_j input[j] · e^{±2πi·jk/n}` for every `k`, with
/// `n = input.len()`. The exponent sign is `+` when `inverse` is true and `-` otherwise.
/// Sums are accumulated in `f64` and rounded to `f32` once per bin; output bins with
/// `k >= n` wrap modulo `n`, as the DFT definition implies.
///
/// Only `n` distinct angles occur, so the twiddles are evaluated once up front. Each table
/// entry is the same `step * m` evaluation the per-term loop used to repeat, which turns
/// O(n²) `sin_cos` calls into O(n) without changing any result. The phase index
/// `(k·j) mod n` is advanced incrementally, so it cannot overflow.
///
/// An empty `input` zero-fills `output`; this avoids the `1 / 0` step and scale.
#[cfg(not(feature = "spectrograms"))]
fn naive_dft(input: &[(f32, f32)], output: &mut [(f32, f32)], inverse: bool, scale: f64) {
    let n = input.len();
    if n == 0 {
        output.fill((0.0, 0.0));
        return;
    }
    let step = 2.0 * core::f64::consts::PI / n as f64;
    // twiddle[m] = (cos θ, ±sin θ) with θ = step·m; the sine is negated for the forward
    // direction. Negation is exact, so this matches the original per-term formulas bit-for-bit.
    let twiddle: Vec<(f64, f64)> = (0..n)
        .map(|m| {
            let (s, c) = (step * m as f64).sin_cos();
            (c, if inverse { s } else { -s })
        })
        .collect();

    for (k, out) in output.iter_mut().enumerate() {
        let advance = k % n;
        let mut idx = 0usize; // == (k * j) % n for the current j
        let mut re = 0.0f64;
        let mut im = 0.0f64;
        for &(xr, xi) in input {
            let (c, s) = twiddle[idx];
            let (xr, xi) = (f64::from(xr), f64::from(xi));
            re += xr * c - xi * s;
            im += xr * s + xi * c;
            idx += advance;
            if idx >= n {
                idx -= n;
            }
        }
        *out = ((re * scale) as f32, (im * scale) as f32);
    }
}

/// Inverse complex DFT, `output[k] = Σ_j input[j]·e^{+2πi·jk/n}` (no normalization).
/// Portable O(n²) fallback used when the `spectrograms` feature is disabled.
#[cfg(not(feature = "spectrograms"))]
fn fft_inverse(input: &[(f32, f32)], output: &mut [(f32, f32)]) {
    naive_dft(input, output, true, 1.0);
}

/// Forward complex DFT scaled by `1 / n`, `output[k] = (1/n)·Σ_j input[j]·e^{-2πi·jk/n}`.
/// Portable O(n²) fallback used when the `spectrograms` feature is disabled.
#[cfg(not(feature = "spectrograms"))]
fn fft_forward(input: &[(f32, f32)], output: &mut [(f32, f32)]) {
    naive_dft(input, output, false, 1.0 / input.len() as f64);
}
/// Forward-MDCT pre-rotation. It complex-multiplies `n4 = tw.len() / 2` packed
/// `(re, im)` pairs of `f` by the packed `(wr, wi)` twiddles in `tw` and stores the
/// products in `fc[..n4]`:
///
/// `fc[i] = (f[2i]·tw[2i] − f[2i+1]·tw[2i+1], f[2i]·tw[2i+1] + f[2i+1]·tw[2i])`
///
/// Four products per iteration use SSE, which is part of the `x86_64` baseline. The
/// remaining `n4 % 4` are computed with scalar code. Entries of `fc` past `n4` are left
/// untouched.
///
/// # Panics
/// If `f.len() < 2 * n4` or `fc.len() < n4`. This is a release-mode `assert!`, not a
/// `debug_assert!`, because the raw-pointer SIMD loop depends on it for memory safety.
#[cfg(target_arch = "x86_64")]
// Unsafe is confined to the SSE loads/stores; their bounds are established by the assert.
#[allow(unsafe_code)]
fn prerotate(f: &[f32], tw: &[f32], fc: &mut [(f32, f32)]) {
    use core::arch::x86_64::*;
    let n4 = tw.len() / 2;
    assert!(
        f.len() >= 2 * n4 && fc.len() >= n4,
        "prerotate: buffers shorter than the twiddle table"
    );

    let fp = f.as_ptr();
    let twp = tw.as_ptr();
    // NOTE(review): the stores below treat each `(f32, f32)` as `[.0, .1]` in memory.
    // The Rust reference does not guarantee tuple field order for the default repr;
    // current rustc keeps declaration order. Confirm if layout randomization is ever used.
    let fcp = fc.as_mut_ptr().cast::<f32>();

    let mut i = 0;
    while i + 4 <= n4 {
        // SAFETY: SSE is always available on x86_64. With `i + 4 <= n4`, every access
        // below covers f32 offsets `2i .. 2i + 8 <= 2 * n4`. That range lies inside `f` and
        // inside `fc` viewed as `2 * fc.len()` floats (both asserted above), and inside `tw`
        // (`tw.len() >= 2 * n4` by the definition of `n4`). Only unaligned load/store
        // intrinsics are used.
        unsafe {
            // Deinterleave four (re, im) inputs and four (wr, wi) twiddles.
            let lo = _mm_loadu_ps(fp.add(2 * i));
            let hi = _mm_loadu_ps(fp.add(2 * i + 4));
            let re = _mm_shuffle_ps::<0x88>(lo, hi);
            let im = _mm_shuffle_ps::<0xDD>(lo, hi);
            let twlo = _mm_loadu_ps(twp.add(2 * i));
            let twhi = _mm_loadu_ps(twp.add(2 * i + 4));
            let wr = _mm_shuffle_ps::<0x88>(twlo, twhi);
            let wi = _mm_shuffle_ps::<0xDD>(twlo, twhi);
            // (re + i·im) · (wr + i·wi)
            let cr = _mm_sub_ps(_mm_mul_ps(re, wr), _mm_mul_ps(im, wi));
            let ci = _mm_add_ps(_mm_mul_ps(re, wi), _mm_mul_ps(im, wr));
            // Re-interleave into (cr, ci) pairs.
            _mm_storeu_ps(fcp.add(2 * i), _mm_unpacklo_ps(cr, ci));
            _mm_storeu_ps(fcp.add(2 * i + 4), _mm_unpackhi_ps(cr, ci));
        }
        i += 4;
    }

    // Scalar tail for the last `n4 % 4` pairs, using safe bounds-checked indexing.
    for j in i..n4 {
        let (re, im) = (f[2 * j], f[2 * j + 1]);
        let (wr, wi) = (tw[2 * j], tw[2 * j + 1]);
        fc[j] = (re * wr - im * wi, re * wi + im * wr);
    }
}
/// Portable forward-MDCT pre-rotation. For each `i` below
/// `min(fc.len(), tw.len() / 2)`, it stores in `fc[i]` the complex product of
/// `(f[2i], f[2i+1])` and the twiddle `(tw[2i], tw[2i+1])`.
///
/// # Panics
/// If `f` is too short to supply the pair for some processed `i`.
#[cfg(not(target_arch = "x86_64"))]
fn prerotate(f: &[f32], tw: &[f32], fc: &mut [(f32, f32)]) {
    let twiddles = tw.chunks_exact(2);
    for (i, (dst, w)) in fc.iter_mut().zip(twiddles).enumerate() {
        let x = &f[2 * i..2 * i + 2];
        let (wr, wi) = (w[0], w[1]);
        *dst = (x[0] * wr - x[1] * wi, x[0] * wi + x[1] * wr);
    }
}
// Per-thread scratch vectors recycled by `MdctLookup::forward` / `backward`. Each `take_*`
// moves a vector out, leaving an empty one behind, and each `put_*` moves it back, so
// buffers keep their capacity between calls. Without the `std` feature there is no
// thread-local storage: `take_*` returns a fresh empty `Vec` and `put_*` drops its argument.
#[cfg(feature = "std")]
std::thread_local! {
    static POOL_REAL: core::cell::RefCell<Vec<f32>> =
        const { core::cell::RefCell::new(Vec::new()) };
    static POOL_CPLX_A: core::cell::RefCell<Vec<(f32, f32)>> =
        const { core::cell::RefCell::new(Vec::new()) };
    static POOL_CPLX_B: core::cell::RefCell<Vec<(f32, f32)>> =
        const { core::cell::RefCell::new(Vec::new()) };
}

/// Moves the vector out of `slot`, leaving an empty `Vec` in its place.
#[cfg(feature = "std")]
fn pool_take<T: 'static>(
    slot: &'static std::thread::LocalKey<core::cell::RefCell<Vec<T>>>,
) -> Vec<T> {
    slot.with(|cell| cell.replace(Vec::new()))
}

/// Stores `v` in `slot`, dropping the vector previously held there.
#[cfg(feature = "std")]
fn pool_put<T: 'static>(
    slot: &'static std::thread::LocalKey<core::cell::RefCell<Vec<T>>>,
    v: Vec<T>,
) {
    slot.with(|cell| drop(cell.replace(v)));
}

#[cfg(feature = "std")]
fn take_f() -> Vec<f32> {
    pool_take(&POOL_REAL)
}
#[cfg(feature = "std")]
fn put_f(v: Vec<f32>) {
    pool_put(&POOL_REAL, v);
}
#[cfg(feature = "std")]
fn take_c1() -> Vec<(f32, f32)> {
    pool_take(&POOL_CPLX_A)
}
#[cfg(feature = "std")]
fn put_c1(v: Vec<(f32, f32)>) {
    pool_put(&POOL_CPLX_A, v);
}
#[cfg(feature = "std")]
fn take_c2() -> Vec<(f32, f32)> {
    pool_take(&POOL_CPLX_B)
}
#[cfg(feature = "std")]
fn put_c2(v: Vec<(f32, f32)>) {
    pool_put(&POOL_CPLX_B, v);
}

#[cfg(not(feature = "std"))]
fn take_f() -> Vec<f32> {
    Vec::new()
}
#[cfg(not(feature = "std"))]
fn put_f(v: Vec<f32>) {
    drop(v);
}
#[cfg(not(feature = "std"))]
fn take_c1() -> Vec<(f32, f32)> {
    Vec::new()
}
#[cfg(not(feature = "std"))]
fn put_c1(v: Vec<(f32, f32)>) {
    drop(v);
}
#[cfg(not(feature = "std"))]
fn take_c2() -> Vec<(f32, f32)> {
    Vec::new()
}
#[cfg(not(feature = "std"))]
fn put_c2(v: Vec<(f32, f32)>) {
    drop(v);
}
impl MdctLookup {
/// Inverse MDCT of size `n = self.n >> shift` (`n2 = n / 2` coefficients,
/// `n4 = n / 4` complex FFT points), followed by the windowed overlap mirror.
///
/// Steps, in order:
/// 1. Pre-rotate the coefficients into `n4` complex values. For `i < n4`, this reads
///    `input[stride * 2i]` and `input[stride * (n2 - 1 - 2i)]`; `stride` lets several
///    transforms share one interleaved coefficient buffer.
/// 2. Run `fft_inverse` on those values. The result is stored as interleaved `re, im`
///    pairs in `out[overlap/2 .. overlap/2 + 2 * n4]`.
/// 3. Post-rotate that region of `out` in place.
/// 4. Mirror `out[..overlap]` against `window`. The entries `out[..overlap/2]` are *read*
///    here before being overwritten, and nothing earlier in this call writes them. They
///    therefore must already hold meaningful data — presumably the tail of the previous
///    frame, since the tests in this file call `backward` on consecutive `n2`-sample
///    offsets of one buffer. NOTE(review): confirm against production callers.
///
/// Scratch vectors are borrowed with `take_c1`/`take_c2` and returned before step 3.
///
/// # Panics
/// On out-of-bounds indexing if `input`, `out` or `window` are too short for the accesses
/// above. `out.len() >= overlap/2 + n2` and `window.len() == overlap` are additionally
/// checked with `debug_assert!`.
pub fn backward(
&self,
input: &[f32],
out: &mut [f32],
window: &[f32],
overlap: usize,
shift: usize,
stride: usize,
) {
// Sizes at this shift: n samples, n/2 coefficients, n/4 complex FFT points.
let n = self.n >> shift;
let n2 = n >> 1;
let n4 = n >> 2;
debug_assert!(out.len() >= (overlap >> 1) + n2);
debug_assert_eq!(window.len(), overlap);
// Small extra rotation folded into the pre/post twiddles; the same formula that
// `MdctLookup::new` uses for the forward pre-rotation table at this size.
let sine = (2.0 * core::f64::consts::PI * 0.125 / n as f64) as f32;
let mut f2 = take_c1();
f2.resize(n4, (0.0, 0.0));
// Step 1: pre-rotation. Even coefficients are read from the front, the mirrored
// (n2 - 1 - 2i) coefficients from the back; `t` is indexed with `<< shift` so the full
// size cosine table serves every shift.
{
let t = &self.trig;
for (i, y) in f2.iter_mut().enumerate() {
let x1 = input[stride * 2 * i]; let x2 = input[stride * (n2 - 1 - 2 * i)]; let yr = -x2 * t[i << shift] + x1 * t[(n4 - i) << shift];
let yi = -x2 * t[(n4 - i) << shift] - x1 * t[i << shift];
*y = (yr - yi * sine, yi + yr * sine);
}
}
// Step 2: inverse FFT of n4 points; copy the result into `out` as interleaved
// (re, im) pairs starting at overlap/2.
let mut time = take_c2();
time.resize(n4, (0.0, 0.0));
fft_inverse(&f2, &mut time);
for (i, &(re, im)) in time.iter().enumerate() {
out[(overlap >> 1) + 2 * i] = re;
out[(overlap >> 1) + 2 * i + 1] = im;
}
// Return the scratch vectors; the remaining steps work directly in `out`.
put_c1(f2);
put_c2(time);
// Step 3: in-place post-rotation. p0 walks forward from the start of the region and
// p1 walks backward from its end. All four inputs of an iteration (re, im, re2, im2)
// are read before any write, so the middle iteration is still correct when n4 is odd
// and p0 == p1.
{
let base = overlap >> 1;
let t = &self.trig;
for i in 0..((n4 + 1) >> 1) {
let p0 = base + 2 * i;
let p1 = base + n2 - 2 - 2 * i;
let re = out[p0];
let im = out[p0 + 1];
let t0 = t[i << shift];
let t1 = t[(n4 - i) << shift];
let yr = re * t0 - im * t1;
let yi = im * t0 + re * t1;
let re2 = out[p1];
let im2 = out[p1 + 1];
out[p0] = -(yr - yi * sine);
out[p1 + 1] = yi + yr * sine;
let t0 = t[(n4 - i - 1) << shift];
let t1 = t[(i + 1) << shift];
let yr = re2 * t0 - im2 * t1;
let yi = im2 * t0 + re2 * t1;
out[p1] = -(yr - yi * sine);
out[p0 + 1] = yi + yr * sine;
}
}
// Step 4: windowed mirror over out[..overlap]. For each i < overlap/2, out[a] (not
// written by this call) and out[b] (written in steps 2-3) are combined with the window
// weights from both ends of `window`.
{
for i in 0..overlap / 2 {
let a = i;
let b = overlap - 1 - i;
let x1 = out[b];
let x2 = out[a];
let w1 = window[i];
let w2 = window[overlap - 1 - i];
out[a] = w2 * x2 - w1 * x1;
out[b] = w1 * x2 + w2 * x1;
}
}
}
/// Forward MDCT of size `n = self.n >> shift`, producing `n2 = n / 2` coefficients.
///
/// Steps, in order:
/// 1. Fold and window the input into `n2` reals, consumed as `n4 = n / 4` `(re, im)`
///    pairs. The first and last `(overlap + 3) >> 2` pairs are weighted by `window`;
///    the middle pairs are copied unweighted.
/// 2. Pre-rotate with the precomputed table `self.fwd_pre[shift]` via `prerotate`.
/// 3. Run `fft_forward` on the `n4` complex values.
/// 4. Post-rotate, writing `out[stride * 2i]` and `out[stride * (n2 - 1 - 2i)]` for
///    `i < n4`. This mirrors the read pattern of `backward`, so `stride` interleaves
///    several transforms in one buffer.
///
/// `input.len() >= n2 + overlap` and `window.len() == overlap` are checked with
/// `debug_assert!`. Scratch vectors come from `take_f`/`take_c1`/`take_c2` and are
/// returned at the end.
///
/// # Panics
/// * If `shift >= 4`, because `self.fwd_pre` only holds four tables.
/// * NOTE(review): the fold assumes `2 * ((overlap + 3) >> 2) <= n4`, which is not
///   asserted. Otherwise `n4 - quarter` underflows or the fold writes past `f`, and the
///   call panics.
/// * On out-of-bounds indexing if `input` or `out` are too short.
pub fn forward(&self, input: &[f32], out: &mut [f32], window: &[f32], overlap: usize, shift: usize, stride: usize) {
let n = self.n >> shift;
let n2 = n >> 1;
let n4 = n >> 2;
debug_assert!(input.len() >= n2 + overlap);
debug_assert_eq!(window.len(), overlap);
// Same extra rotation as in `backward` and in `MdctLookup::new` for this size.
let sine = (2.0 * core::f64::consts::PI * 0.125 / n as f64) as f32;
let mut f = take_f();
f.resize(n2, 0.0);
// Step 1: fold + window into f, two reals (one complex pair) per iteration.
{
let half = overlap >> 1;
let quarter = (overlap + 3) >> 2;
let mut yp = 0usize;
// Leading pairs: weights taken from the centre of `window` outward
// (w1 increasing from `half`, w2 decreasing from `half - 1`).
for i in 0..quarter {
let xp1 = half + 2 * i;
let xp2 = n2 - 1 + half - 2 * i;
let w1 = window[half + 2 * i];
let w2 = window[half - 1 - 2 * i];
f[yp] = w2 * input[xp1 + n2] + w1 * input[xp2];
f[yp + 1] = w1 * input[xp1] - w2 * input[xp2 - n2];
yp += 2;
}
// Middle pairs: copied without any window weighting.
for i in quarter..(n4 - quarter) {
let xp1 = half + 2 * i;
let xp2 = n2 - 1 + half - 2 * i;
f[yp] = input[xp2];
f[yp + 1] = input[xp1];
yp += 2;
}
// Trailing pairs: weights taken from both ends of `window` inward
// (w1 from index 0 upward, w2 from `overlap - 1` downward).
for i in (n4 - quarter)..n4 {
let k = i - (n4 - quarter);
let xp1 = half + 2 * i;
let xp2 = n2 - 1 + half - 2 * i;
let w1 = window[2 * k];
let w2 = window[overlap - 1 - 2 * k];
f[yp] = -w1 * input[xp1 - n2] + w2 * input[xp2];
f[yp + 1] = w2 * input[xp1] + w1 * input[xp2 + n2];
yp += 2;
}
}
// Step 2: pre-rotation with the per-shift twiddle table built in `MdctLookup::new`.
let mut fc = take_c1();
fc.resize(n4, (0.0, 0.0));
prerotate(&f, &self.fwd_pre[shift], &mut fc);
// Step 3: forward FFT of n4 points.
let mut f2 = take_c2();
f2.resize(n4, (0.0, 0.0));
fft_forward(&fc, &mut f2);
// Step 4: post-rotation, writing stride-spaced coefficients from both ends of `out`.
{
let t = &self.trig;
for (i, &(re, im)) in f2.iter().enumerate() {
let yr = im * t[(n4 - i) << shift] + re * t[i << shift];
let yi = re * t[(n4 - i) << shift] - im * t[i << shift];
out[stride * 2 * i] = yr - yi * sine;
out[stride * (n2 - 1 - 2 * i)] = yi + yr * sine;
}
}
// Return the scratch vectors for reuse by later calls.
put_f(f);
put_c1(fc);
put_c2(f2);
}
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use alloc::vec;
    use alloc::vec::Vec;
    use super::*;
    use crate::celt::tables::WINDOW120;

    /// `WINDOW120` must be power-complementary (`w[i]² + w[119 - i]² = 1`, the TDAC
    /// condition) and match `sin(π/2 · sin²(π/2 · (i + 0.5) / 120))`.
    #[test]
    fn window_satisfies_tdac() {
        let half_pi = core::f64::consts::FRAC_PI_2;
        for (i, w1) in WINDOW120.iter().map(|&w| f64::from(w)).enumerate() {
            let w2 = f64::from(WINDOW120[119 - i]);
            let power = w1 * w1 + w2 * w2;
            assert!((power - 1.0).abs() < 1e-6, "i={i}");
            let t = half_pi * (i as f64 + 0.5) / 120.0;
            let expected = (half_pi * t.sin().powi(2)).sin();
            assert!((w1 - expected).abs() < 1e-5, "i={i}");
        }
    }

    /// Forward then backward on consecutive frames must reconstruct the input once the
    /// start-up region and the final frame are excluded.
    #[test]
    fn tdac_perfect_reconstruction() {
        let lookup = MdctLookup::new(1920);
        let (shift, overlap, frames) = (2usize, 120usize, 6usize);
        let n2 = (1920 >> shift) / 2;
        let total = n2 * frames + overlap + n2;
        let signal: Vec<f32> = (0..total)
            .map(|i| i as f32)
            .map(|t| (t * 0.1).sin() + 0.5 * (t * 0.037).cos() + 0.25 * (t * 0.41).sin())
            .collect();

        let mut synth = vec![0.0f32; total];
        for start in (0..frames).map(|frame| frame * n2) {
            let mut coeffs = vec![0.0f32; n2];
            lookup.forward(&signal[start..], &mut coeffs, &WINDOW120, overlap, shift, 1);
            lookup.backward(&coeffs, &mut synth[start..], &WINDOW120, overlap, shift, 1);
        }

        let (lo, hi) = (2 * overlap, (frames - 1) * n2);
        for (i, (&got, &want)) in synth.iter().zip(&signal).enumerate().take(hi).skip(lo) {
            let err = (got - want).abs();
            assert!(
                err < 1e-3,
                "sample {i}: got {got}, want {want} (err {})",
                err
            );
        }
    }

    /// Two short blocks interleaved with `stride = 2` in one coefficient buffer must
    /// still reconstruct the input, exercising the `stride` paths of both transforms.
    #[test]
    fn interleaved_short_blocks_round_trip() {
        let lookup = MdctLookup::new(1920);
        let shift = 3usize;
        let n2 = (1920 >> shift) / 2;
        let blocks = 2usize;
        let overlap = 120usize;
        let frames = 8usize;
        let total = n2 * frames + overlap + n2;
        let signal: Vec<f32> = (0..total)
            .map(|i| i as f32)
            .map(|x| (x * 0.21).sin() - 0.3 * (x * 0.05).cos())
            .collect();

        let mut synth = vec![0.0f32; total];
        for group in 0..frames / blocks {
            let base = group * blocks * n2;
            let mut coeffs = vec![0.0f32; n2 * blocks];
            for k in 0..blocks {
                let src = &signal[base + k * n2..];
                lookup.forward(src, &mut coeffs[k..], &WINDOW120, overlap, shift, blocks);
            }
            for k in 0..blocks {
                let dst = &mut synth[base + k * n2..];
                lookup.backward(&coeffs[k..], dst, &WINDOW120, overlap, shift, blocks);
            }
        }

        let (lo, hi) = (2 * overlap, (frames - 2) * n2);
        for (i, (&got, &want)) in synth.iter().zip(&signal).enumerate().take(hi).skip(lo) {
            assert!((got - want).abs() < 1e-3, "sample {i}: got {got}, want {want}");
        }
    }
}