Skip to main content

math_audio_dsp/
stft.rs

// ============================================================================
// Shared STFT Infrastructure
// ============================================================================
//
// Reusable components for STFT-based plugins:
// - generate_hann_window, generate_hann_window_symmetric,
//   generate_sqrt_hann_window: window generation for STFT analysis and WOLA
// - RealFftProcessor: Thin wrapper around realfft for single-channel use
// - RingAccumulator: Sample accumulator with hop-based triggering
// - design_dual_windows, DualWindowStft: dual-window analysis/synthesis STFT
9
10use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
11use rustfft::num_complex::Complex;
12use std::sync::Arc;
13
14// ============================================================================
15// Hann Window
16// ============================================================================
17
/// Build a periodic Hann window with `size` taps.
///
/// The cosine argument is divided by `size` (N), not `size - 1`, so the
/// window is the periodic variant: two copies offset by half the length sum
/// to one (COLA at 50% overlap). An input of `0` yields an empty vector.
pub fn generate_hann_window(size: usize) -> Vec<f32> {
    let period = size as f32;
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        // Same evaluation order as (2π · n) / N so the f32 rounding is unchanged.
        let phase = std::f32::consts::TAU * n as f32 / period;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
25
/// Build a symmetric Hann window with `size` taps.
///
/// The cosine argument is divided by `size - 1`, so the first and last taps
/// are both zero. Sizes `0` and `1` are special-cased to avoid dividing by
/// zero: they return an empty vector and `[1.0]` respectively.
pub fn generate_hann_window_symmetric(size: usize) -> Vec<f32> {
    match size {
        0 | 1 => vec![1.0; size],
        _ => {
            let last_index = (size as f32) - 1.0;
            let mut window = Vec::with_capacity(size);
            for n in 0..size {
                let phase = std::f32::consts::TAU * n as f32 / last_index;
                window.push(0.5 * (1.0 - phase.cos()));
            }
            window
        }
    }
}
37
/// Build a square-root periodic Hann window with `size` taps, for WOLA
/// (Weighted Overlap-Add) processing.
///
/// Applying this window at both analysis and synthesis multiplies each sample
/// by a plain periodic Hann window, which overlap-adds to a constant at 50%
/// overlap. An input of `0` yields an empty vector.
pub fn generate_sqrt_hann_window(size: usize) -> Vec<f32> {
    let period = size as f32;
    let mut window = vec![0.0f32; size];
    for (n, tap) in window.iter_mut().enumerate() {
        let phase = std::f32::consts::TAU * n as f32 / period;
        let hann = 0.5 * (1.0 - phase.cos());
        *tap = hann.sqrt();
    }
    window
}
49
50// ============================================================================
51// RealFftProcessor
52// ============================================================================
53
54/// Thin wrapper around `realfft` encapsulating planner + buffers for
55/// single-channel use. Provides forward (real→complex) and optional
56/// inverse (complex→real) FFT.
57pub struct RealFftProcessor {
58    #[allow(dead_code)]
59    pub fft_size: usize,
60    pub spectrum_size: usize,
61    fft_forward: Arc<dyn RealToComplex<f32>>,
62    fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
63    pub time_buffer: Vec<f32>,
64    pub freq_buffer: Vec<Complex<f32>>,
65}
66
67impl RealFftProcessor {
68    /// Create a forward-only FFT processor (no inverse).
69    pub fn new_forward_only(fft_size: usize) -> Self {
70        let spectrum_size = fft_size / 2 + 1;
71        let mut planner = RealFftPlanner::<f32>::new();
72        let fft_forward = planner.plan_fft_forward(fft_size);
73
74        Self {
75            fft_size,
76            spectrum_size,
77            fft_forward,
78            fft_inverse: None,
79            time_buffer: vec![0.0; fft_size],
80            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
81        }
82    }
83
84    /// Create a bidirectional FFT processor (forward + inverse).
85    #[allow(dead_code)]
86    pub fn new_bidirectional(fft_size: usize) -> Self {
87        let spectrum_size = fft_size / 2 + 1;
88        let mut planner = RealFftPlanner::<f32>::new();
89        let fft_forward = planner.plan_fft_forward(fft_size);
90        let fft_inverse = planner.plan_fft_inverse(fft_size);
91
92        Self {
93            fft_size,
94            spectrum_size,
95            fft_forward,
96            fft_inverse: Some(fft_inverse),
97            time_buffer: vec![0.0; fft_size],
98            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
99        }
100    }
101
102    /// Perform forward FFT: time_buffer → freq_buffer.
103    /// The caller should fill `time_buffer` before calling this.
104    pub fn forward(&mut self) {
105        self.fft_forward
106            .process(&mut self.time_buffer, &mut self.freq_buffer)
107            .expect("FFT forward failed");
108    }
109
110    /// Perform inverse FFT: freq_buffer → time_buffer.
111    /// Panics if this processor was created with `new_forward_only`.
112    #[allow(dead_code)]
113    pub fn inverse(&mut self) {
114        self.fft_inverse
115            .as_ref()
116            .expect("Inverse FFT not available (forward-only processor)")
117            .process(&mut self.freq_buffer, &mut self.time_buffer)
118            .expect("FFT inverse failed");
119    }
120}
121
122// ============================================================================
123// RingAccumulator
124// ============================================================================
125
/// Circular sample store that reports hop boundaries.
///
/// Samples are written one at a time into a ring of `window_size` entries.
/// `push` returns `true` once the ring has been filled at least once and at
/// least `hop_size` samples have arrived since the previous `true`.
pub struct RingAccumulator {
    buffer: Vec<f32>,
    write_pos: usize,
    samples_since_trigger: usize,
    filled: bool,
    window_size: usize,
    hop_size: usize,
}

impl RingAccumulator {
    /// Create an empty (zero-filled) accumulator for the given window and hop.
    pub fn new(window_size: usize, hop_size: usize) -> Self {
        Self {
            window_size,
            hop_size,
            buffer: vec![0.0; window_size],
            write_pos: 0,
            samples_since_trigger: 0,
            filled: false,
        }
    }

    /// Store one sample, overwriting the oldest entry.
    ///
    /// Returns `true` when the ring is full and `hop_size` or more samples
    /// have been pushed since the last time this returned `true`; the
    /// counter restarts from zero on every `true`.
    pub fn push(&mut self, sample: f32) -> bool {
        self.buffer[self.write_pos] = sample;

        // Advance the write cursor, wrapping back to the start of the ring.
        self.write_pos += 1;
        if self.write_pos == self.window_size {
            self.write_pos = 0;
        }

        self.samples_since_trigger += 1;
        let pending = self.samples_since_trigger;

        // The same counter doubles as the initial fill counter.
        self.filled = self.filled || pending >= self.window_size;

        let hop_due = self.filled && pending >= self.hop_size;
        if hop_due {
            self.samples_since_trigger = 0;
        }
        hop_due
    }

    /// Copy the current window into the front of `dest`, oldest sample first.
    ///
    /// `dest` must hold at least `window_size` elements; anything beyond that
    /// is left untouched. The ring is unrolled with two contiguous copies
    /// rather than per-element index wrapping.
    pub fn read_window(&self, dest: &mut [f32]) {
        debug_assert!(dest.len() >= self.window_size);
        // The write cursor points at the oldest sample, so everything from it
        // to the end of the ring comes first, followed by the wrapped part.
        let (wrapped, oldest_run) = self.buffer.split_at(self.write_pos);
        let (front, rest) = dest.split_at_mut(oldest_run.len());
        front.copy_from_slice(oldest_run);
        rest[..wrapped.len()].copy_from_slice(wrapped);
    }

    /// Zero the ring and return to the initial, not-yet-filled state.
    pub fn reset(&mut self) {
        self.filled = false;
        self.samples_since_trigger = 0;
        self.write_pos = 0;
        self.buffer.fill(0.0);
    }
}
189
// ============================================================================
// Dual-Window STFT Framework
// ============================================================================
//
// Aims to decouple frequency resolution from latency by using separate
// analysis (long) and synthesis (short) windows. The analysis window provides
// high frequency resolution, while the synthesis window is intended to
// determine the output latency (see `DualWindowStft::latency_samples` for the
// delay the processor actually reports).
197
/// Dual-window STFT processor.
///
/// Each frame spans `analysis_size` input samples and is weighted by the
/// analysis window before the forward FFT. After the inverse FFT the frame is
/// weighted by the synthesis window, which `design_dual_windows` builds as a
/// shorter window zero-padded to the analysis length, and is then
/// overlap-added into the output accumulator.
///
/// NOTE(review): this type was originally described as having an output
/// latency equal to the synthesis window size. `synthesize_in_place`
/// overlap-adds every full-length frame starting at the current read
/// position, so confirm the effective delay against `latency_samples()`
/// before relying on that description.
pub struct DualWindowStft {
    /// Analysis window, `analysis_size` taps.
    analysis_window: Vec<f32>,
    /// Synthesis window, stored at analysis length (zero outside its support).
    synthesis_window: Vec<f32>,
    /// Frame / FFT length in samples.
    analysis_size: usize,
    /// Input ring buffer sized to analysis window
    input_ring: RingAccumulator,
    /// Overlap-add output accumulator (circular, three analysis frames long)
    output_accum: Vec<f32>,
    /// Index in `output_accum` of the next sample `read_output` will return.
    output_read_pos: usize,
    /// FFT processor (analysis size)
    fft: RealFftProcessor,
    /// Window read buffer
    window_buf: Vec<f32>,
    /// COLA normalization factor.
    /// Initialised to all ones and not read anywhere in this file, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    cola_norm: Vec<f32>,
}
220
221/// Design a dual-window pair satisfying the COLA (Constant Overlap-Add) condition.
222///
223/// # Arguments
224/// * `analysis_size` - Analysis window length (long, e.g. 1024)
225/// * `synthesis_size` - Synthesis window length (short, e.g. 256)
226/// * `hop_size` - Hop size in samples
227///
228/// # Returns
229/// (analysis_window, synthesis_window) pair
230pub fn design_dual_windows(
231    analysis_size: usize,
232    synthesis_size: usize,
233    hop_size: usize,
234) -> (Vec<f32>, Vec<f32>) {
235    // Analysis window: Hann
236    let w_a = generate_hann_window(analysis_size);
237
238    // Synthesis window: truncated Hann centered in the analysis window,
239    // normalized to satisfy COLA
240    let offset = (analysis_size - synthesis_size) / 2;
241
242    // Start with a Hann window of synthesis_size
243    let w_s_raw = generate_hann_window(synthesis_size);
244
245    // Compute the COLA sum: Σ_k w_a(n - k*hop) * w_s(n - k*hop)
246    // across all hop-shifted positions. We need this to be constant.
247    // Normalize w_s so the sum equals 1.
248    let num_overlaps = analysis_size.div_ceil(hop_size);
249
250    let mut cola_sum = vec![0.0f32; hop_size];
251    for k in 0..num_overlaps {
252        let shift = k * hop_size;
253        for (n, cola_val) in cola_sum.iter_mut().enumerate() {
254            let ana_idx = n + shift;
255            if ana_idx < analysis_size {
256                // Check if this falls within the synthesis window support
257                let syn_idx = ana_idx.wrapping_sub(offset);
258                if syn_idx < synthesis_size {
259                    *cola_val += w_a[ana_idx] * w_s_raw[syn_idx];
260                }
261            }
262        }
263    }
264
265    // Normalize synthesis window
266    let avg_cola: f32 = cola_sum.iter().sum::<f32>() / cola_sum.len() as f32;
267    let norm_factor = if avg_cola > 1e-10 {
268        1.0 / avg_cola
269    } else {
270        1.0
271    };
272
273    let mut w_s = vec![0.0f32; analysis_size];
274    for i in 0..synthesis_size {
275        w_s[offset + i] = w_s_raw[i] * norm_factor;
276    }
277
278    (w_a, w_s)
279}
280
281impl DualWindowStft {
282    /// Create a new dual-window STFT processor.
283    ///
284    /// # Arguments
285    /// * `analysis_size` - Analysis window size (determines frequency resolution)
286    /// * `synthesis_size` - Synthesis window size (determines output latency)
287    /// * `hop_size` - Hop size in samples
288    pub fn new(analysis_size: usize, synthesis_size: usize, hop_size: usize) -> Self {
289        let (analysis_window, synthesis_window) =
290            design_dual_windows(analysis_size, synthesis_size, hop_size);
291
292        let fft = RealFftProcessor::new_bidirectional(analysis_size);
293
294        Self {
295            analysis_window,
296            synthesis_window,
297            analysis_size,
298            input_ring: RingAccumulator::new(analysis_size, hop_size),
299            output_accum: vec![0.0; analysis_size * 3],
300            output_read_pos: 0,
301            fft,
302            window_buf: vec![0.0; analysis_size],
303            cola_norm: vec![1.0; analysis_size],
304        }
305    }
306
307    /// Push a single sample. Returns `true` when a hop boundary is reached.
308    ///
309    /// When `true`, the spectrum is available in `freq_buffer_mut()` for
310    /// in-place modification. Call `synthesize_in_place()` after modifying.
311    pub fn analyze(&mut self, sample: f32) -> bool {
312        if !self.input_ring.push(sample) {
313            return false;
314        }
315
316        // Read the analysis window worth of samples
317        self.input_ring.read_window(&mut self.window_buf);
318
319        // Apply analysis window
320        for i in 0..self.analysis_size {
321            self.fft.time_buffer[i] = self.window_buf[i] * self.analysis_window[i];
322        }
323
324        // Forward FFT
325        self.fft.forward();
326
327        true
328    }
329
330    /// Access the frequency buffer for in-place modification after `analyze()` returns `true`.
331    pub fn freq_buffer_mut(&mut self) -> &mut [Complex<f32>] {
332        &mut self.fft.freq_buffer
333    }
334
335    /// Synthesize output from the current frequency buffer (after in-place modification).
336    ///
337    /// Call this after `analyze()` returns `true` and the spectrum has been modified
338    /// via `freq_buffer_mut()`. The output samples accumulate in the internal buffer
339    /// and can be read via `read_output()`.
340    pub fn synthesize_in_place(&mut self) {
341        // Inverse FFT (operates on self.fft.freq_buffer directly)
342        self.fft.inverse();
343
344        // Apply synthesis window and overlap-add
345        let scale = 1.0 / self.analysis_size as f32;
346        for i in 0..self.analysis_size {
347            let pos = (self.output_read_pos + i) % self.output_accum.len();
348            self.output_accum[pos] += self.fft.time_buffer[i] * self.synthesis_window[i] * scale;
349        }
350    }
351
352    /// Read one output sample. Returns 0.0 if no output is ready yet.
353    pub fn read_output(&mut self) -> f32 {
354        let sample = self.output_accum[self.output_read_pos];
355        self.output_accum[self.output_read_pos] = 0.0;
356        self.output_read_pos = (self.output_read_pos + 1) % self.output_accum.len();
357        sample
358    }
359
360    /// Process a block: analyze, apply user function, synthesize.
361    ///
362    /// # Arguments
363    /// * `input` - Input samples
364    /// * `output` - Output buffer (same length as input)
365    /// * `process_fn` - Function to modify the spectrum (called at each hop boundary)
366    pub fn process_block<F>(&mut self, input: &[f32], output: &mut [f32], mut process_fn: F)
367    where
368        F: FnMut(&mut [Complex<f32>]),
369    {
370        for (i, &sample) in input.iter().enumerate() {
371            if self.analyze(sample) {
372                process_fn(&mut self.fft.freq_buffer);
373                self.synthesize_in_place();
374            }
375            output[i] = self.read_output();
376        }
377    }
378
379    /// Get the output latency in samples.
380    pub fn latency_samples(&self) -> usize {
381        self.analysis_size
382    }
383
384    /// Reset all internal state.
385    pub fn reset(&mut self) {
386        self.input_ring.reset();
387        self.output_accum.fill(0.0);
388        self.output_read_pos = 0;
389    }
390}
391
392// ============================================================================
393// Tests
394// ============================================================================
395
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;

    #[test]
    fn test_hann_window_size_and_symmetry() {
        let window = generate_hann_window(8);
        assert_eq!(window.len(), 8);

        // Hann window should start near zero and peak at center
        assert!((window[0] - 0.0).abs() < 0.01);
        assert!((window[4] - 1.0).abs() < 0.01);

        // Symmetric: w[i] == w[N-i] for periodic Hann
        for i in 1..4 {
            assert!(
                (window[i] - window[8 - i]).abs() < 1e-6,
                "Window not symmetric at i={}: {} vs {}",
                i,
                window[i],
                window[8 - i]
            );
        }
    }

    #[test]
    fn test_sqrt_hann_cola_property() {
        // sqrt(Hann) analysis * sqrt(Hann) synthesis = Hann
        // Hann has perfect COLA at 50% overlap: w[i] + w[i+N/2] = 1.0
        let n = 256;
        let sqrt_window = generate_sqrt_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            // Product of analysis and synthesis = Hann
            let hann_i = sqrt_window[i] * sqrt_window[i];
            let hann_shifted = sqrt_window[i + hop] * sqrt_window[i + hop];
            let sum = hann_i + hann_shifted;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "sqrt(Hann) COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_hann_window_cola_property() {
        // With 50% overlap, w[i] + w[i + N/2] should equal 1.0 (COLA)
        let n = 256;
        let window = generate_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            let sum = window[i] + window[i + hop];
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_symmetric_hann_endpoints_are_zero() {
        let window = generate_hann_window_symmetric(256);
        assert!(window[0].abs() < 1e-7, "First sample should be 0");
        assert!(window[255].abs() < 1e-7, "Last sample should be 0");
        // Peak at center
        assert!((window[128] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_symmetric_hann_no_nan_for_small_sizes() {
        // size=0: empty
        let w0 = generate_hann_window_symmetric(0);
        assert!(w0.is_empty());

        // size=1: should be [1.0], not NaN
        let w1 = generate_hann_window_symmetric(1);
        assert_eq!(w1.len(), 1);
        assert!(w1[0].is_finite(), "size=1 produced non-finite: {}", w1[0]);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // size=2: endpoints [0, 0] by symmetric formula
        let w2 = generate_hann_window_symmetric(2);
        assert_eq!(w2.len(), 2);
        assert!(w2[0].is_finite());
        assert!(w2[1].is_finite());
    }

    #[test]
    fn test_fft_roundtrip() {
        let fft_size = 256;
        let mut fft = RealFftProcessor::new_bidirectional(fft_size);

        // Fill with a known signal
        let original: Vec<f32> = (0..fft_size)
            .map(|i| (2.0 * std::f32::consts::PI * 10.0 * i as f32 / fft_size as f32).sin())
            .collect();
        fft.time_buffer.copy_from_slice(&original);

        // Forward then inverse
        fft.forward();
        fft.inverse();

        // Inverse FFT scales by fft_size, so divide
        let scale = 1.0 / fft_size as f32;
        for i in 0..fft_size {
            let recovered = fft.time_buffer[i] * scale;
            assert!(
                (recovered - original[i]).abs() < 1e-4,
                "FFT roundtrip mismatch at i={}: expected {}, got {}",
                i,
                original[i],
                recovered,
            );
        }
    }

    #[test]
    fn test_ring_accumulator_trigger_timing() {
        let window_size = 8;
        let hop_size = 4;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        let mut triggers = Vec::new();
        for i in 0..24 {
            if ring.push(i as f32) {
                triggers.push(i);
            }
        }

        // First trigger at sample 7 (index 7 = 8th sample, filling window)
        // Then every hop_size (4) samples: 11, 15, 19, 23
        assert_eq!(triggers, vec![7, 11, 15, 19, 23]);
    }

    #[test]
    fn test_ring_accumulator_window_readout() {
        let window_size = 4;
        let hop_size = 2;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        // Push 6 samples: [0, 1, 2, 3, 4, 5]
        // After 4 samples, ring is filled. After 6 samples (2 more = hop), trigger.
        // Ring state: write_pos = 2, buffer = [4, 5, 2, 3]
        // oldest-first read: [2, 3, 4, 5]
        for i in 0..6 {
            ring.push(i as f32);
        }

        let mut dest = vec![0.0; 4];
        ring.read_window(&mut dest);
        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_ring_accumulator_reset() {
        let mut ring = RingAccumulator::new(8, 4);

        // Fill and trigger
        for i in 0..12 {
            ring.push(i as f32);
        }
        assert!(ring.filled);

        ring.reset();
        assert!(!ring.filled);
        assert_eq!(ring.write_pos, 0);
        assert_eq!(ring.samples_since_trigger, 0);

        // Should not trigger until filled again
        let mut triggered = false;
        for _ in 0..4 {
            triggered |= ring.push(1.0);
        }
        assert!(!triggered, "Should not trigger before ring is filled again");
    }

    #[test]
    fn test_dual_window_design() {
        let analysis_size = 1024;
        let synthesis_size = 256;
        let hop_size = 128;

        let (w_a, w_s) = design_dual_windows(analysis_size, synthesis_size, hop_size);
        assert_eq!(w_a.len(), analysis_size);
        assert_eq!(w_s.len(), analysis_size);

        // Synthesis window should be non-zero only in the center
        let offset = (analysis_size - synthesis_size) / 2;
        for i in 0..offset {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero before offset");
        }
        for i in (offset + synthesis_size)..analysis_size {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero after support");
        }
    }

    #[test]
    fn test_dual_window_stft_passthrough() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        // Generate a unit-amplitude tone
        let num_samples = 4096;
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();

        let mut output = vec![0.0f32; num_samples];

        // Pass-through (no spectral modification)
        stft.process_block(&signal, &mut output, |_spectrum| {
            // Identity: don't modify spectrum
        });

        // After latency, output should approximate input
        let latency = stft.latency_samples();
        let check_start = latency + 512; // skip transient
        let check_end = num_samples - 512;
        assert!(
            check_end > check_start,
            "Test signal too short: nothing left to compare after skipping transients"
        );

        let mean_sq_error: f32 = output[check_start..check_end]
            .iter()
            .zip(&signal[check_start - latency..check_end - latency])
            .map(|(o, s)| (o - s).powi(2))
            .sum::<f32>()
            / (check_end - check_start) as f32;
        let rms_error = mean_sq_error.sqrt();

        // The overlap sum is only normalized on average, so a small ripple is
        // expected. The bound must stay well below the tone's own RMS
        // (~0.707): an all-zero or badly misaligned output has to fail.
        assert!(
            rms_error < 0.1,
            "Dual-window STFT passthrough RMS error too high: {rms_error:.6}"
        );
    }

    #[test]
    fn test_dual_window_stft_reset() {
        let mut stft = DualWindowStft::new(512, 128, 64);

        // Process some data
        let signal: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0; 2048];
        stft.process_block(&signal, &mut output, |_| {});

        // Reset
        stft.reset();

        // Process silence — output should be near zero
        let silence = vec![0.0f32; 1024];
        let mut output2 = vec![0.0; 1024];
        stft.process_block(&silence, &mut output2, |_| {});

        let max_output: f32 = output2.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_output < 0.01,
            "After reset + silence, max output should be ~0, got {max_output}"
        );
    }
}
664}