use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use libm::sqrtf;
use num_complex::Complex;
use realfft::{RealFftPlanner, RealToComplex};
use crate::dsp::windows::{WindowKind, make_window};
/// Parameters of a short-time Fourier transform.
#[derive(Clone, Debug)]
pub struct StftConfig {
    /// FFT size in samples; the analysis window is built with this length too.
    pub n_fft: usize,
    /// Step between the starts (or centres) of consecutive frames, in samples.
    pub hop: usize,
    /// Analysis window applied to every frame before the FFT.
    pub window: WindowKind,
    /// When `true`, frame `f` is centred on sample `f * hop` and samples
    /// outside the signal are obtained by reflection; when `false`, frame `f`
    /// starts at sample `f * hop` and only full frames are produced.
    pub center: bool,
}

impl StftConfig {
    /// Creates a config for `n_fft` with a hop of `n_fft / 4`, a Hann window
    /// and centring enabled.
    ///
    /// The hop is clamped to at least 1: for `n_fft` of 1 or 2 the quarter-size
    /// hop would be 0, which [`ShortTimeFFT::new`] rejects.
    #[must_use]
    pub fn new(n_fft: usize) -> Self {
        Self {
            n_fft,
            hop: (n_fft / 4).max(1),
            window: WindowKind::Hann,
            center: true,
        }
    }
}
/// Short-time Fourier transform engine with a reusable FFT plan and buffers.
///
/// The FFT plan, the analysis window and all per-frame working buffers are
/// created once in [`ShortTimeFFT::new`] and reused by every call. The
/// spectrogram returned by the `*_flat` methods and by
/// [`magnitude`](Self::magnitude) is still freshly allocated per call.
pub struct ShortTimeFFT {
    cfg: StftConfig,
    fft: Arc<dyn RealToComplex<f32>>,
    /// Analysis window (built with `n_fft` coefficients).
    window: Vec<f32>,
    /// Windowed time-domain frame. realfft may clobber its input, so it is
    /// refilled before every transform.
    scratch_in: Vec<f32>,
    /// FFT output, `n_fft / 2 + 1` complex bins.
    scratch_out: Vec<Complex<f32>>,
    /// Internal FFT scratch, reused so a transform does not allocate per frame.
    scratch_fft: Vec<Complex<f32>>,
}

impl ShortTimeFFT {
    /// Plans the forward real FFT and allocates the window and buffers for `cfg`.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.n_fft` is zero or not a power of two, or if `cfg.hop`
    /// is not in `1..=cfg.n_fft`.
    #[must_use]
    pub fn new(cfg: StftConfig) -> Self {
        assert!(
            cfg.n_fft > 0 && cfg.n_fft.is_power_of_two(),
            "n_fft must be a non-zero power of two, got {}",
            cfg.n_fft
        );
        assert!(
            cfg.hop > 0 && cfg.hop <= cfg.n_fft,
            "hop must be in (0, n_fft], got hop={} n_fft={}",
            cfg.hop,
            cfg.n_fft
        );
        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(cfg.n_fft);
        let window = make_window(cfg.window, cfg.n_fft);
        debug_assert_eq!(window.len(), cfg.n_fft, "window length must equal n_fft");
        let scratch_in = fft.make_input_vec();
        let scratch_out = fft.make_output_vec();
        let scratch_fft = fft.make_scratch_vec();
        Self {
            cfg,
            fft,
            window,
            scratch_in,
            scratch_out,
            scratch_fft,
        }
    }

    /// Returns the configuration this transform was built with.
    #[must_use]
    pub fn config(&self) -> &StftConfig {
        &self.cfg
    }

    /// Number of frequency bins per frame: `n_fft / 2 + 1`.
    #[must_use]
    pub const fn n_bins(&self) -> usize {
        self.cfg.n_fft / 2 + 1
    }

    /// Number of frames produced for a signal of `n_samples` samples.
    ///
    /// * An empty signal yields 0 frames in every mode, matching what the
    ///   `*_flat` methods and [`magnitude`](Self::magnitude) return for it.
    /// * Centred: `1 + n_samples / hop`.
    /// * Not centred: the number of complete `n_fft`-sample windows that fit
    ///   when stepping by `hop` (0 if the signal is shorter than `n_fft`).
    #[must_use]
    pub const fn n_frames(&self, n_samples: usize) -> usize {
        if n_samples == 0 {
            0
        } else if self.cfg.center {
            1 + n_samples / self.cfg.hop
        } else if n_samples < self.cfg.n_fft {
            0
        } else {
            1 + (n_samples - self.cfg.n_fft) / self.cfg.hop
        }
    }

    /// Magnitude spectrogram as one `Vec` of `n_bins()` values per frame.
    ///
    /// Returns an empty `Vec` when no frames are produced.
    #[must_use]
    pub fn magnitude(&mut self, samples: &[f32]) -> Vec<Vec<f32>> {
        let (flat, n_frames, n_bins) = self.magnitude_flat(samples);
        if n_frames == 0 {
            // Also covers empty input, for which n_bins is reported as 0
            // (and chunks_exact(0) would panic).
            return Vec::new();
        }
        flat.chunks_exact(n_bins).map(<[f32]>::to_vec).collect()
    }

    /// Power spectrogram `|X|^2` as a row-major flat buffer.
    ///
    /// Returns `(data, n_frames, n_bins)`, where `data[f * n_bins + k]` is bin
    /// `k` of frame `f`. Empty input returns `(vec![], 0, 0)`.
    #[must_use]
    pub fn power_flat(&mut self, samples: &[f32]) -> (Vec<f32>, usize, usize) {
        self.spectrum_flat(samples, |c| c.norm_sqr())
    }

    /// Magnitude spectrogram `|X|` as a row-major flat buffer.
    ///
    /// Same layout and empty-input behaviour as [`power_flat`](Self::power_flat).
    #[must_use]
    pub fn magnitude_flat(&mut self, samples: &[f32]) -> (Vec<f32>, usize, usize) {
        self.spectrum_flat(samples, |c| sqrtf(c.norm_sqr()))
    }

    /// Windows one `n_fft`-sample frame and writes its power spectrum into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `frame.len() != n_fft` or `out.len() != n_bins()`.
    pub fn process_frame_power(&mut self, frame: &[f32], out: &mut [f32]) {
        self.frame_spectrum(frame, out, |c| c.norm_sqr());
    }

    /// Windows one `n_fft`-sample frame and writes its magnitude spectrum into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `frame.len() != n_fft` or `out.len() != n_bins()`.
    pub fn process_frame(&mut self, frame: &[f32], out: &mut [f32]) {
        self.frame_spectrum(frame, out, |c| sqrtf(c.norm_sqr()));
    }

    /// Shared driver for the `*_flat` methods: frames `samples`, transforms
    /// each frame and maps every complex bin through `bin_value`.
    fn spectrum_flat(
        &mut self,
        samples: &[f32],
        bin_value: impl Fn(&Complex<f32>) -> f32,
    ) -> (Vec<f32>, usize, usize) {
        if samples.is_empty() {
            return (Vec::new(), 0, 0);
        }
        let hop = self.cfg.hop;
        let n_frames = self.n_frames(samples.len());
        let n_bins = self.n_bins();
        // With centring, frame f is centred on sample f * hop, so its first
        // sample lies n_fft / 2 earlier (possibly before the signal start).
        let center_off = if self.cfg.center {
            (self.cfg.n_fft / 2) as isize
        } else {
            0
        };
        let mut out = vec![0.0_f32; n_frames * n_bins];
        for (f, row) in out.chunks_exact_mut(n_bins).enumerate() {
            let start = (f * hop) as isize - center_off;
            self.fill_windowed(samples, start);
            self.run_fft();
            for (dst, c) in row.iter_mut().zip(&self.scratch_out) {
                *dst = bin_value(c);
            }
        }
        (out, n_frames, n_bins)
    }

    /// Shared driver for the single-frame methods: windows `frame`,
    /// transforms it and maps every complex bin through `bin_value` into `out`.
    fn frame_spectrum(
        &mut self,
        frame: &[f32],
        out: &mut [f32],
        bin_value: impl Fn(&Complex<f32>) -> f32,
    ) {
        assert_eq!(frame.len(), self.cfg.n_fft, "frame length must equal n_fft");
        assert_eq!(out.len(), self.n_bins(), "out length must equal n_bins");
        for ((dst, &s), &w) in self.scratch_in.iter_mut().zip(frame).zip(&self.window) {
            *dst = s * w;
        }
        self.run_fft();
        for (o, c) in out.iter_mut().zip(&self.scratch_out) {
            *o = bin_value(c);
        }
    }

    /// Transforms `scratch_in` into `scratch_out`, reusing the plan's scratch
    /// buffer instead of letting realfft allocate one per call.
    fn run_fft(&mut self) {
        self.fft
            .process_with_scratch(
                &mut self.scratch_in,
                &mut self.scratch_out,
                &mut self.scratch_fft,
            )
            .expect("FFT process: buffer length mismatch");
    }

    /// Copies the `n_fft` samples starting at `start` into `scratch_in`,
    /// multiplied by the window.
    ///
    /// Positions outside `samples` are reflected back into the signal when
    /// centring is enabled and treated as zeros otherwise. Callers must not
    /// pass empty `samples` in centred mode (reflection needs at least one
    /// sample).
    fn fill_windowed(&mut self, samples: &[f32], start: isize) {
        let n_fft = self.cfg.n_fft;
        let len = samples.len();
        // Fast path: the whole frame lies inside the signal.
        if start >= 0 && (start as usize).saturating_add(n_fft) <= len {
            let s_off = start as usize;
            let src = &samples[s_off..s_off + n_fft];
            for ((dst, &s), &w) in self.scratch_in.iter_mut().zip(src).zip(&self.window) {
                *dst = s * w;
            }
            return;
        }
        // Slow path: the frame overhangs at least one end of the signal.
        let center = self.cfg.center;
        for (k, (dst, &w)) in self.scratch_in.iter_mut().zip(&self.window).enumerate() {
            let idx = start + k as isize;
            let s = if (0..len as isize).contains(&idx) {
                samples[idx as usize]
            } else if center {
                samples[reflect(idx, len)]
            } else {
                0.0
            };
            *dst = s * w;
        }
    }
}
/// Folds an arbitrary index into `0..len` by mirroring about the first and
/// last elements without repeating them (for `len == 5`: `-1 -> 1`,
/// `5 -> 3`), repeating with period `2 * (len - 1)`.
///
/// Lengths of 0 or 1 have nothing to mirror, so every index maps to 0.
fn reflect(i: isize, len: usize) -> usize {
    let last = len as isize - 1;
    if last <= 0 {
        return 0;
    }
    let period = 2 * last;
    let folded = match i.rem_euclid(period) {
        // Second half of the period walks back down from `last`.
        m if m > last => period - m,
        m => m,
    };
    folded as usize
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f32::consts::PI;

    /// Non-centred config used by several tests.
    fn uncentered(n_fft: usize, hop: usize) -> StftConfig {
        StftConfig {
            n_fft,
            hop,
            window: WindowKind::Hann,
            center: false,
        }
    }

    #[test]
    fn reflect_matches_numpy() {
        let want = [3, 2, 1, 0, 1, 2, 3, 4, 3, 2, 1];
        for (i, w) in (-3..8).zip(want) {
            assert_eq!(reflect(i, 5), w, "i={i}");
        }
    }

    #[test]
    fn n_bins_and_frames() {
        let s = ShortTimeFFT::new(StftConfig::new(1024));
        assert_eq!(s.n_bins(), 513);
        assert_eq!(s.n_frames(16_000), 63);
    }

    #[test]
    fn n_frames_without_centering_counts_full_windows() {
        let s = ShortTimeFFT::new(uncentered(256, 64));
        assert_eq!(s.n_frames(0), 0);
        assert_eq!(s.n_frames(255), 0);
        assert_eq!(s.n_frames(256), 1);
        assert_eq!(s.n_frames(319), 1);
        assert_eq!(s.n_frames(320), 2);
    }

    #[test]
    fn empty_input_produces_no_frames() {
        let mut s = ShortTimeFFT::new(StftConfig::new(1024));
        assert!(s.magnitude(&[]).is_empty());
    }

    #[test]
    fn empty_input_power_flat_is_empty() {
        let mut s = ShortTimeFFT::new(StftConfig::new(1024));
        let (data, n_frames, n_bins) = s.power_flat(&[]);
        assert!(data.is_empty());
        assert_eq!((n_frames, n_bins), (0, 0));
    }

    #[test]
    fn short_uncentered_input_produces_no_frames() {
        let mut s = ShortTimeFFT::new(uncentered(256, 64));
        let samples = alloc::vec![0.5_f32; 100];
        assert!(s.magnitude(&samples).is_empty());
    }

    #[test]
    fn dc_signal_concentrates_energy_in_bin_zero() {
        let mut s = ShortTimeFFT::new(StftConfig::new(1024));
        let samples = alloc::vec![1.0_f32; 4096];
        let spec = s.magnitude(&samples);
        let f = &spec[spec.len() / 2];
        assert!(f[0] > 0.0);
        // Bin 1 is skipped: the Hann main lobe leaks into it.
        for (k, &v) in f.iter().enumerate().skip(2) {
            assert!(
                f[0] > v * 1000.0,
                "bin {k} ({v}) not negligible vs DC ({})",
                f[0]
            );
        }
    }

    #[test]
    fn pure_sine_peaks_at_expected_bin() {
        let n_fft = 1024;
        let sr = 16_000.0_f32;
        let freq = 1000.0_f32;
        let mut s = ShortTimeFFT::new(StftConfig::new(n_fft));
        let samples: alloc::vec::Vec<f32> = (0..4096)
            .map(|n| libm::sinf(2.0 * PI * freq * n as f32 / sr))
            .collect();
        let spec = s.magnitude(&samples);
        let expected_bin = (freq * n_fft as f32 / sr) as usize;
        let f = &spec[spec.len() / 2];
        let (peak_bin, _) = f
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap();
        assert_eq!(peak_bin, expected_bin);
    }

    #[test]
    fn process_frame_matches_magnitude() {
        let cfg = uncentered(256, 256);
        let mut s = ShortTimeFFT::new(cfg.clone());
        let samples: alloc::vec::Vec<f32> = (0..256)
            .map(|n| libm::sinf(2.0 * PI * n as f32 / 32.0))
            .collect();
        let mut frame_out = alloc::vec![0.0_f32; s.n_bins()];
        s.process_frame(&samples, &mut frame_out);
        let mut s2 = ShortTimeFFT::new(cfg);
        let buf_out = s2.magnitude(&samples);
        assert_eq!(buf_out.len(), 1);
        for (a, b) in frame_out.iter().zip(buf_out[0].iter()) {
            assert_relative_eq!(a, b, max_relative = 1e-5);
        }
    }

    #[test]
    fn process_frame_power_matches_power_flat() {
        let cfg = uncentered(256, 256);
        let mut s = ShortTimeFFT::new(cfg.clone());
        let samples: alloc::vec::Vec<f32> = (0..256)
            .map(|n| libm::sinf(2.0 * PI * n as f32 / 32.0))
            .collect();
        let mut frame_out = alloc::vec![0.0_f32; s.n_bins()];
        s.process_frame_power(&samples, &mut frame_out);
        let mut s2 = ShortTimeFFT::new(cfg);
        let (flat, n_frames, n_bins) = s2.power_flat(&samples);
        assert_eq!((n_frames, n_bins), (1, frame_out.len()));
        for (&a, &b) in frame_out.iter().zip(flat.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-6, max_relative = 1e-5);
        }
    }

    #[test]
    fn power_is_squared_magnitude() {
        let mut s = ShortTimeFFT::new(StftConfig::new(256));
        let samples: alloc::vec::Vec<f32> = (0..1000)
            .map(|n| libm::sinf(2.0 * PI * n as f32 / 20.0) + 0.25)
            .collect();
        let (power, p_frames, p_bins) = s.power_flat(&samples);
        let (mag, m_frames, m_bins) = s.magnitude_flat(&samples);
        assert_eq!((p_frames, p_bins), (m_frames, m_bins));
        assert_eq!(p_frames, s.n_frames(samples.len()));
        assert_eq!(power.len(), p_frames * p_bins);
        for (&p, &m) in power.iter().zip(mag.iter()) {
            assert_relative_eq!(p, m * m, epsilon = 1e-6, max_relative = 1e-4);
        }
    }
}