use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
use rustfft::num_complex::Complex;
use std::sync::Arc;
use crate::ring_buffer::RingBuffer;
use crate::window::{self, WindowType};
/// One frame of a short-time Fourier transform: the complex spectrum of a
/// single windowed block of real samples.
///
/// Frames produced by `StftProcessor` hold `window_size / 2 + 1` bins (the
/// non-redundant half of a real FFT), ordered from the lowest bin upwards.
#[derive(Debug, Clone, PartialEq)]
pub struct FftFrame {
    /// Complex spectrum bins, in bin order.
    pub bins: Vec<Complex<f32>>,
    /// Length, in samples, of the time-domain window this frame describes.
    pub window_size: usize,
}
impl FftFrame {
    /// Returns the number of frequency bins stored in this frame.
    pub fn len(&self) -> usize {
        self.bins.len()
    }

    /// Returns `true` when the frame holds no bins.
    pub fn is_empty(&self) -> bool {
        self.bins.is_empty()
    }

    /// Returns the magnitude (complex norm) of every bin, in bin order.
    pub fn magnitudes(&self) -> Vec<f32> {
        self.bins.iter().map(|bin| bin.norm()).collect()
    }

    /// Returns the phase angle (complex argument) of every bin, in bin order.
    pub fn phases(&self) -> Vec<f32> {
        self.bins.iter().map(|bin| bin.arg()).collect()
    }

    /// Builds a frame from per-bin magnitudes and phases.
    ///
    /// `magnitudes` and `phases` are paired element by element; they are
    /// expected to have the same length (checked in debug builds only — in
    /// release builds the longer slice is silently truncated by the pairing).
    ///
    /// `window_size` is inferred as `(bins - 1) * 2`, i.e. it assumes the bins
    /// are the `n / 2 + 1` outputs of an even-length real FFT. Empty input
    /// yields an empty frame with `window_size == 0` instead of underflowing.
    pub fn from_polar(magnitudes: &[f32], phases: &[f32]) -> Self {
        debug_assert_eq!(magnitudes.len(), phases.len());
        let bins: Vec<Complex<f32>> = magnitudes
            .iter()
            .zip(phases)
            .map(|(&magnitude, &phase)| Complex::from_polar(magnitude, phase))
            .collect();
        // `saturating_sub` keeps the empty case at 0; a plain `len() - 1`
        // would overflow (panic in debug builds, wrap around in release).
        let window_size = bins.len().saturating_sub(1) * 2;
        Self { bins, window_size }
    }
}
/// Streaming short-time Fourier transform processor.
///
/// Samples are pushed one at a time; once enough input has been collected a
/// windowed frame is transformed every `hop_size` samples. Resynthesis is
/// done by overlap-adding inverse-transformed frames into `output_buf`.
pub struct StftProcessor {
    /// Frame length in samples; also the length of both FFT plans.
    window_size: usize,
    /// Number of input samples between consecutive frames, and the number of
    /// output samples released per frame.
    hop_size: usize,
    /// Window coefficients from `window::generate`; used on the analysis side
    /// (via `window::apply`) and multiplied in again on the synthesis side.
    window: Vec<f32>,
    /// Incoming samples, created with `RingBuffer::new(window_size)`.
    input_buf: RingBuffer,
    /// Circular overlap-add accumulator of `window_size * 2` samples.
    output_buf: Vec<f32>,
    /// Index into `output_buf` of the next sample to be released.
    output_pos: usize,
    /// Forward real-to-complex FFT plan.
    fft: Arc<dyn RealToComplex<f32>>,
    /// Inverse complex-to-real FFT plan.
    ifft: Arc<dyn ComplexToReal<f32>>,
    /// Scratch: time-domain input of the forward FFT (`window_size` samples).
    fft_input: Vec<f32>,
    /// Scratch: spectrum written by the forward FFT (`window_size / 2 + 1` bins).
    fft_output: Vec<Complex<f32>>,
    /// Scratch: spectrum handed to the inverse FFT (`window_size / 2 + 1` bins).
    ifft_input: Vec<Complex<f32>>,
    /// Scratch: time-domain result of the inverse FFT (`window_size` samples).
    ifft_output: Vec<f32>,
    /// Samples pushed since the last frame was taken.
    samples_since_fft: usize,
    /// Set once the input buffer has reported at least `window_size` samples;
    /// no frames are produced before that.
    primed: bool,
}
impl StftProcessor {
    /// Creates a processor that takes frames of `window_size` samples every
    /// `hop_size` samples, using a window of the given `window_type`.
    ///
    /// Forward and inverse real-FFT plans of length `window_size` are planned
    /// once here, and the FFT scratch buffers are allocated up front.
    ///
    /// NOTE(review): the arguments are not validated. With `hop_size == 0`
    /// a frame is taken on every sample once primed and
    /// [`process`](Self::process) never emits output (it releases `hop_size`
    /// samples per frame). Behaviour for `window_size == 0` depends on the
    /// FFT planner — confirm callers never pass it.
    pub fn new(window_size: usize, hop_size: usize, window_type: WindowType) -> Self {
        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(window_size);
        let ifft = planner.plan_fft_inverse(window_size);
        // A real FFT of length n yields n / 2 + 1 non-redundant bins.
        let bin_count = window_size / 2 + 1;
        Self {
            window_size,
            hop_size,
            window: window::generate(window_type, window_size),
            input_buf: RingBuffer::new(window_size),
            output_buf: vec![0.0; window_size * 2],
            output_pos: 0,
            fft,
            ifft,
            fft_input: vec![0.0; window_size],
            fft_output: vec![Complex::default(); bin_count],
            ifft_input: vec![Complex::default(); bin_count],
            ifft_output: vec![0.0; window_size],
            samples_since_fft: 0,
            primed: false,
        }
    }

    /// Returns the frame length in samples.
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// Returns the number of samples between consecutive frames.
    pub fn hop_size(&self) -> usize {
        self.hop_size
    }

    /// Returns the number of bins in every frame, `window_size / 2 + 1`.
    pub fn bin_count(&self) -> usize {
        self.window_size / 2 + 1
    }

    /// Pushes one input sample and reports whether a new frame is due.
    ///
    /// The first frame is due as soon as the input buffer reports at least
    /// `window_size` samples; after that, one frame is due every `hop_size`
    /// samples. The hop counter is reset whenever `true` is returned.
    fn advance(&mut self, sample: f32) -> bool {
        self.input_buf.push(sample);
        self.samples_since_fft += 1;
        if !self.primed && self.input_buf.len() >= self.window_size {
            self.primed = true;
            // Force the first frame to be taken immediately.
            self.samples_since_fft = self.hop_size;
        }
        let frame_due = self.primed && self.samples_since_fft >= self.hop_size;
        if frame_due {
            self.samples_since_fft = 0;
        }
        frame_due
    }

    /// Windows the buffered input and writes its spectrum into `fft_output`.
    fn compute_spectrum(&mut self) {
        self.input_buf.read_ordered(&mut self.fft_input);
        window::apply(&mut self.fft_input, &self.window);
        self.fft
            .process(&mut self.fft_input, &mut self.fft_output)
            .expect("FFT size mismatch");
    }

    /// Computes the spectrum of the current input window as an owned frame.
    fn forward_fft(&mut self) -> FftFrame {
        self.compute_spectrum();
        FftFrame {
            bins: self.fft_output.clone(),
            window_size: self.window_size,
        }
    }

    /// Inverse-transforms `frame` and overlap-adds the windowed result into
    /// the output accumulator, starting at the current output position.
    ///
    /// # Panics
    ///
    /// Panics if `frame` does not contain exactly `bin_count()` bins.
    fn inverse_fft(&mut self, frame: &FftFrame) {
        assert_eq!(
            frame.bins.len(),
            self.ifft_input.len(),
            "transformed frame must contain exactly bin_count() bins"
        );
        self.ifft_input.copy_from_slice(&frame.bins);

        // The spectrum of a real signal has a purely real DC bin and, for
        // even lengths, a purely real last (Nyquist) bin. The inverse real
        // FFT reports an error when those imaginary parts are non-zero, which
        // a caller's transform (e.g. a phase shift) can easily produce, so
        // discard them here instead of panicking below.
        if let Some(dc) = self.ifft_input.first_mut() {
            dc.im = 0.0;
        }
        if self.window_size % 2 == 0 {
            if let Some(nyquist) = self.ifft_input.last_mut() {
                nyquist.im = 0.0;
            }
        }

        self.ifft
            .process(&mut self.ifft_input, &mut self.ifft_output)
            .expect("IFFT size mismatch");

        // The inverse transform is unnormalised; scale by 1 / window_size.
        let norm = 1.0 / self.window_size as f32;
        let ring_len = self.output_buf.len();
        let start = self.output_pos;
        for (offset, (&sample, &weight)) in
            self.ifft_output.iter().zip(&self.window).enumerate()
        {
            self.output_buf[(start + offset) % ring_len] += sample * norm * weight;
        }
    }

    /// Moves the next `hop_size` finished samples from the accumulator to
    /// `output`, zeroing each slot so later frames can accumulate into it.
    fn emit_hop(&mut self, output: &mut Vec<f32>) {
        let ring_len = self.output_buf.len();
        output.reserve(self.hop_size);
        for _ in 0..self.hop_size {
            let idx = self.output_pos % ring_len;
            output.push(std::mem::take(&mut self.output_buf[idx]));
            self.output_pos = (idx + 1) % ring_len;
        }
    }

    /// Streams `input` through analysis, `transform`, and resynthesis,
    /// appending the resulting samples to `output`.
    ///
    /// For every frame that becomes due, the windowed spectrum is passed to
    /// `transform`; the returned frame is inverse-transformed, windowed again
    /// and overlap-added, and `hop_size` samples are appended to `output`.
    /// Nothing is appended until the first frame is due, so `output` always
    /// grows by a multiple of `hop_size`. State is kept between calls, so
    /// `input` may be fed in chunks of any size.
    ///
    /// `transform` may be stateful (`FnMut`). It must return a frame with
    /// `bin_count()` bins. Imaginary parts of the DC bin and, for even
    /// window sizes, the last bin of the returned frame are ignored.
    ///
    /// NOTE(review): apart from the `1 / window_size` FFT scaling and the
    /// synthesis window, no overlap-add gain compensation is applied, so the
    /// output level depends on the window and hop — confirm callers expect it.
    ///
    /// # Panics
    ///
    /// Panics if `transform` returns a frame whose bin count differs from
    /// `bin_count()`.
    pub fn process<F>(&mut self, input: &[f32], output: &mut Vec<f32>, mut transform: F)
    where
        F: FnMut(FftFrame) -> FftFrame,
    {
        for &sample in input {
            if !self.advance(sample) {
                continue;
            }
            let frame = self.forward_fft();
            let transformed = transform(frame);
            self.inverse_fft(&transformed);
            self.emit_hop(output);
        }
    }

    /// Streams `input` through the analysis stage only and returns the bin
    /// magnitudes of every frame that became due, in order.
    ///
    /// Each inner vector has `bin_count()` entries. This shares the input
    /// buffer, hop counter and priming state with
    /// [`process`](Self::process); it does not touch the output accumulator.
    pub fn analyze(&mut self, input: &[f32]) -> Vec<Vec<f32>> {
        let mut results = Vec::new();
        for &sample in input {
            if self.advance(sample) {
                self.compute_spectrum();
                // Read magnitudes straight from the scratch spectrum; no
                // intermediate frame copy is needed.
                results.push(self.fft_output.iter().map(|bin| bin.norm()).collect());
            }
        }
        results
    }

    /// Returns the processor to its initial streaming state.
    ///
    /// Clears the input buffer, zeroes the overlap-add accumulator, and
    /// resets the output position, hop counter and priming flag. The FFT
    /// plans, window and sizes are kept.
    pub fn reset(&mut self) {
        self.input_buf.clear();
        self.output_buf.fill(0.0);
        self.output_pos = 0;
        self.samples_since_fft = 0;
        self.primed = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` samples of a 440 Hz sine sampled at 44.1 kHz.
    fn sine_440(n: usize) -> Vec<f32> {
        (0..n)
            .map(|i| (2.0 * std::f64::consts::PI * 440.0 * i as f64 / 44100.0).sin() as f32)
            .collect()
    }

    /// Root-mean-square level of `samples`.
    fn rms(samples: &[f32]) -> f32 {
        (samples.iter().map(|s| s * s).sum::<f32>() / samples.len() as f32).sqrt()
    }

    #[test]
    fn test_fft_frame_polar_roundtrip() {
        let mags = vec![1.0, 0.5, 0.25, 0.0];
        let phases = vec![0.0, 1.0, -1.0, 0.0];
        let frame = FftFrame::from_polar(&mags, &phases);
        assert_eq!(frame.len(), mags.len());
        assert!(!frame.is_empty());
        // Four bins correspond to an even window of (4 - 1) * 2 samples.
        assert_eq!(frame.window_size, 6);
        let got_mags = frame.magnitudes();
        let got_phases = frame.phases();
        for (a, b) in mags.iter().zip(got_mags.iter()) {
            assert!((a - b).abs() < 1e-5, "{} vs {}", a, b);
        }
        for (a, b) in phases.iter().zip(got_phases.iter()) {
            assert!((a - b).abs() < 1e-5, "{} vs {}", a, b);
        }
    }

    #[test]
    fn test_stft_accessors() {
        let stft = StftProcessor::new(256, 64, WindowType::Hann);
        assert_eq!(stft.window_size(), 256);
        assert_eq!(stft.hop_size(), 64);
        assert_eq!(stft.bin_count(), 129);
    }

    #[test]
    fn test_stft_passthrough() {
        let window_size = 256;
        let hop_size = 64;
        let mut stft = StftProcessor::new(window_size, hop_size, WindowType::Hann);
        let n = 4096;
        let input = sine_440(n);
        let mut output = Vec::new();
        stft.process(&input, &mut output, |frame| frame);
        assert!(
            output.len() > n / 2,
            "Should have produced substantial output, got {}",
            output.len()
        );
        // Output is released one hop at a time.
        assert_eq!(output.len() % hop_size, 0);
        let level = rms(&output);
        assert!(
            level > 0.1,
            "STFT passthrough should preserve signal, rms={}",
            level
        );
    }

    #[test]
    fn test_stft_zero_bins_silences() {
        let window_size = 256;
        let hop_size = 64;
        let mut stft = StftProcessor::new(window_size, hop_size, WindowType::Hann);
        let input = sine_440(4096);
        let mut output = Vec::new();
        stft.process(&input, &mut output, |mut frame| {
            for bin in frame.bins.iter_mut() {
                *bin = Complex::default();
            }
            frame
        });
        // Without this the test would pass vacuously if nothing was emitted.
        assert!(
            !output.is_empty(),
            "Should have produced output to check for silence"
        );
        let level = rms(&output);
        assert!(
            level < 0.01,
            "Zeroed spectrum should produce silence, rms={}",
            level
        );
    }

    #[test]
    fn test_analyze_frame_shape() {
        let mut stft = StftProcessor::new(256, 64, WindowType::Hann);
        let frames = stft.analyze(&sine_440(1024));
        assert!(!frames.is_empty(), "Should have produced at least one frame");
        let bins = stft.bin_count();
        for magnitudes in &frames {
            assert_eq!(magnitudes.len(), bins);
            assert!(magnitudes.iter().all(|m| *m >= 0.0));
        }
    }
}