Skip to main content

math_audio_dsp/
stft.rs

1// ============================================================================
2// Shared STFT Infrastructure
3// ============================================================================
4//
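// Reusable components for STFT-based plugins:
// - generate_hann_window / generate_hann_window_symmetric /
//   generate_sqrt_hann_window: window generation for STFT analysis and WOLA
// - RealFftProcessor: Thin wrapper around realfft for single-channel use
// - BatchedRealFftProcessor: Multi-channel variant sharing one plan and scratch
// - RingAccumulator: Sample accumulator with hop-based triggering
// - DualWindowStft / design_dual_windows: Dual-window analysis/synthesis STFT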
9
10use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
11use rustfft::num_complex::Complex;
12use std::sync::Arc;
13
14// ============================================================================
15// Hann Window
16// ============================================================================
17
/// Build a periodic Hann window with `size` taps.
///
/// The cosine argument is divided by `size` (N) rather than `size - 1`, so the
/// window sums to a constant when overlapped by 50% (COLA).
/// An empty vector is returned for `size == 0`.
pub fn generate_hann_window(size: usize) -> Vec<f32> {
    let two_pi = 2.0 * std::f32::consts::PI;
    let len = size as f32;

    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = (two_pi * n as f32) / len;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
25
/// Build a symmetric Hann window with `size` taps.
///
/// The cosine argument is divided by `size - 1`, so the first and last taps
/// are (numerically close to) zero, which suits spectral analysis.
/// Sizes 0 and 1 are special-cased to avoid a zero divisor: they yield an
/// empty vector and `[1.0]` respectively.
pub fn generate_hann_window_symmetric(size: usize) -> Vec<f32> {
    match size {
        0 | 1 => vec![1.0; size],
        _ => {
            let denom = (size as f32) - 1.0;
            (0..size)
                .map(|n| {
                    let phase = (2.0 * std::f32::consts::PI * n as f32) / denom;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
37
/// Build a square-root periodic Hann window for WOLA (Weighted Overlap-Add).
///
/// Each tap is the square root of the periodic (N-divisor) Hann value, so
/// applying this window at both analysis and synthesis multiplies out to a
/// plain Hann window, which is COLA at 50% overlap.
pub fn generate_sqrt_hann_window(size: usize) -> Vec<f32> {
    let len = size as f32;
    let mut window = vec![0.0f32; size];
    for (n, tap) in window.iter_mut().enumerate() {
        let phase = (2.0 * std::f32::consts::PI * n as f32) / len;
        *tap = (0.5 * (1.0 - phase.cos())).sqrt();
    }
    window
}
49
50// ============================================================================
51// RealFftProcessor
52// ============================================================================
53
54/// Thin wrapper around `realfft` encapsulating planner + buffers for
55/// single-channel use. Provides forward (real→complex) and optional
56/// inverse (complex→real) FFT.
57pub struct RealFftProcessor {
58    #[allow(dead_code)]
59    pub fft_size: usize,
60    pub spectrum_size: usize,
61    fft_forward: Arc<dyn RealToComplex<f32>>,
62    fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
63    pub time_buffer: Vec<f32>,
64    pub freq_buffer: Vec<Complex<f32>>,
65}
66
67impl RealFftProcessor {
68    /// Create a forward-only FFT processor (no inverse).
69    pub fn new_forward_only(fft_size: usize) -> Self {
70        let spectrum_size = fft_size / 2 + 1;
71        let mut planner = RealFftPlanner::<f32>::new();
72        let fft_forward = planner.plan_fft_forward(fft_size);
73
74        Self {
75            fft_size,
76            spectrum_size,
77            fft_forward,
78            fft_inverse: None,
79            time_buffer: vec![0.0; fft_size],
80            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
81        }
82    }
83
84    /// Create a bidirectional FFT processor (forward + inverse).
85    #[allow(dead_code)]
86    pub fn new_bidirectional(fft_size: usize) -> Self {
87        let spectrum_size = fft_size / 2 + 1;
88        let mut planner = RealFftPlanner::<f32>::new();
89        let fft_forward = planner.plan_fft_forward(fft_size);
90        let fft_inverse = planner.plan_fft_inverse(fft_size);
91
92        Self {
93            fft_size,
94            spectrum_size,
95            fft_forward,
96            fft_inverse: Some(fft_inverse),
97            time_buffer: vec![0.0; fft_size],
98            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
99        }
100    }
101
102    /// Perform forward FFT: time_buffer → freq_buffer.
103    /// The caller should fill `time_buffer` before calling this.
104    pub fn forward(&mut self) {
105        self.fft_forward
106            .process(&mut self.time_buffer, &mut self.freq_buffer)
107            .expect("FFT forward failed");
108    }
109
110    /// Perform inverse FFT: freq_buffer → time_buffer.
111    /// Panics if this processor was created with `new_forward_only`.
112    #[allow(dead_code)]
113    pub fn inverse(&mut self) {
114        self.fft_inverse
115            .as_ref()
116            .expect("Inverse FFT not available (forward-only processor)")
117            .process(&mut self.freq_buffer, &mut self.time_buffer)
118            .expect("FFT inverse failed");
119    }
120}
121
122// ============================================================================
123// BatchedRealFftProcessor
124// ============================================================================
125
126/// Batched wrapper around `realfft` for independent real FFTs that share the
127/// same FFT size.
128///
129/// Buffers are stored flat in channel-major order:
130/// - `time_buffers`: `[channel][sample]`
131/// - `freq_buffers`: `[channel][bin]`
132///
133/// This keeps each channel contiguous for `realfft`, avoids one allocation per
134/// channel, and reuses the same FFT plans and scratch buffers across all
135/// channels. The FFTs are still computed sequentially; this type is a portable
136/// baseline rather than a platform-specific batched FFT backend.
137pub struct BatchedRealFftProcessor {
138    channels: usize,
139    fft_size: usize,
140    spectrum_size: usize,
141    fft_forward: Arc<dyn RealToComplex<f32>>,
142    fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
143    forward_scratch: Vec<Complex<f32>>,
144    inverse_scratch: Vec<Complex<f32>>,
145    time_buffers: Vec<f32>,
146    freq_buffers: Vec<Complex<f32>>,
147}
148
149impl BatchedRealFftProcessor {
150    /// Create a forward-only batched FFT processor.
151    pub fn new_forward_only(channels: usize, fft_size: usize) -> Self {
152        Self::new(channels, fft_size, false)
153    }
154
155    /// Create a bidirectional batched FFT processor.
156    pub fn new_bidirectional(channels: usize, fft_size: usize) -> Self {
157        Self::new(channels, fft_size, true)
158    }
159
160    fn new(channels: usize, fft_size: usize, include_inverse: bool) -> Self {
161        assert!(
162            channels > 0,
163            "BatchedRealFftProcessor requires at least one channel"
164        );
165
166        let spectrum_size = fft_size / 2 + 1;
167        let mut planner = RealFftPlanner::<f32>::new();
168        let fft_forward = planner.plan_fft_forward(fft_size);
169        let fft_inverse = if include_inverse {
170            Some(planner.plan_fft_inverse(fft_size))
171        } else {
172            None
173        };
174
175        let forward_scratch = vec![Complex::new(0.0, 0.0); fft_forward.get_scratch_len()];
176        let inverse_scratch = fft_inverse
177            .as_ref()
178            .map(|fft| vec![Complex::new(0.0, 0.0); fft.get_scratch_len()])
179            .unwrap_or_default();
180
181        Self {
182            channels,
183            fft_size,
184            spectrum_size,
185            fft_forward,
186            fft_inverse,
187            forward_scratch,
188            inverse_scratch,
189            time_buffers: vec![0.0; channels * fft_size],
190            freq_buffers: vec![Complex::new(0.0, 0.0); channels * spectrum_size],
191        }
192    }
193
194    pub fn channels(&self) -> usize {
195        self.channels
196    }
197
198    pub fn fft_size(&self) -> usize {
199        self.fft_size
200    }
201
202    pub fn spectrum_size(&self) -> usize {
203        self.spectrum_size
204    }
205
206    pub fn time_buffers(&self) -> &[f32] {
207        &self.time_buffers
208    }
209
210    pub fn time_buffers_mut(&mut self) -> &mut [f32] {
211        &mut self.time_buffers
212    }
213
214    pub fn freq_buffers(&self) -> &[Complex<f32>] {
215        &self.freq_buffers
216    }
217
218    pub fn freq_buffers_mut(&mut self) -> &mut [Complex<f32>] {
219        &mut self.freq_buffers
220    }
221
222    pub fn time_channel(&self, ch: usize) -> &[f32] {
223        debug_assert!(ch < self.channels);
224        let range = self.time_range(ch);
225        &self.time_buffers[range]
226    }
227
228    pub fn time_channel_mut(&mut self, ch: usize) -> &mut [f32] {
229        debug_assert!(ch < self.channels);
230        let range = self.time_range(ch);
231        &mut self.time_buffers[range]
232    }
233
234    pub fn freq_channel(&self, ch: usize) -> &[Complex<f32>] {
235        debug_assert!(ch < self.channels);
236        let range = self.freq_range(ch);
237        &self.freq_buffers[range]
238    }
239
240    pub fn freq_channel_mut(&mut self, ch: usize) -> &mut [Complex<f32>] {
241        debug_assert!(ch < self.channels);
242        let range = self.freq_range(ch);
243        &mut self.freq_buffers[range]
244    }
245
246    /// Perform forward FFTs for all channels.
247    /// The caller should fill each time-domain channel buffer before calling this.
248    pub fn forward_all(&mut self) {
249        for ch in 0..self.channels {
250            let time_range = self.time_range(ch);
251            let freq_range = self.freq_range(ch);
252            self.fft_forward
253                .process_with_scratch(
254                    &mut self.time_buffers[time_range],
255                    &mut self.freq_buffers[freq_range],
256                    &mut self.forward_scratch,
257                )
258                .expect("FFT forward failed");
259        }
260    }
261
262    /// Perform inverse FFTs for all channels.
263    /// Panics if this processor was created with `new_forward_only`.
264    pub fn inverse_all(&mut self) {
265        let fft_inverse = self
266            .fft_inverse
267            .as_ref()
268            .expect("Inverse FFT not available (forward-only processor)");
269
270        for ch in 0..self.channels {
271            let time_range = self.time_range(ch);
272            let freq_range = self.freq_range(ch);
273            fft_inverse
274                .process_with_scratch(
275                    &mut self.freq_buffers[freq_range],
276                    &mut self.time_buffers[time_range],
277                    &mut self.inverse_scratch,
278                )
279                .expect("FFT inverse failed");
280        }
281    }
282
283    fn time_range(&self, ch: usize) -> std::ops::Range<usize> {
284        ch * self.fft_size..(ch + 1) * self.fft_size
285    }
286
287    fn freq_range(&self, ch: usize) -> std::ops::Range<usize> {
288        ch * self.spectrum_size..(ch + 1) * self.spectrum_size
289    }
290}
291
292// ============================================================================
293// RingAccumulator
294// ============================================================================
295
/// Circular sample store that fires a trigger every `hop_size` samples.
///
/// Samples are written into a ring of `window_size` slots. No trigger is
/// emitted until `window_size` samples have been pushed; from then on a
/// trigger is emitted each time `hop_size` further samples have arrived.
pub struct RingAccumulator {
    /// Ring storage (`window_size` slots).
    buffer: Vec<f32>,
    /// Slot the next sample will be written to; also the oldest sample.
    write_pos: usize,
    /// Samples pushed since the last trigger (or since creation/reset).
    samples_since_trigger: usize,
    /// Set once `window_size` samples have been pushed.
    filled: bool,
    window_size: usize,
    hop_size: usize,
}

impl RingAccumulator {
    /// Create an empty accumulator holding `window_size` samples and
    /// triggering every `hop_size` samples once full.
    pub fn new(window_size: usize, hop_size: usize) -> Self {
        let buffer = vec![0.0; window_size];
        Self {
            buffer,
            write_pos: 0,
            samples_since_trigger: 0,
            filled: false,
            window_size,
            hop_size,
        }
    }

    /// Store one sample and report whether a hop boundary was reached.
    ///
    /// Returns `true` when the ring has been filled and at least `hop_size`
    /// samples have been pushed since the previous trigger; the hop counter
    /// restarts whenever `true` is returned.
    pub fn push(&mut self, sample: f32) -> bool {
        self.buffer[self.write_pos] = sample;

        // Advance with wrap-around.
        self.write_pos += 1;
        if self.write_pos == self.window_size {
            self.write_pos = 0;
        }
        self.samples_since_trigger += 1;

        // `filled` latches on the first time a whole window has been written.
        self.filled = self.filled || self.samples_since_trigger >= self.window_size;

        let fire = self.filled && self.samples_since_trigger >= self.hop_size;
        if fire {
            self.samples_since_trigger = 0;
        }
        fire
    }

    /// Copy the stored window into `dest`, oldest sample first.
    ///
    /// `dest` must hold at least `window_size` elements; only the first
    /// `window_size` are written. The ring is unrolled with two contiguous
    /// slice copies rather than a per-element modulo.
    pub fn read_window(&self, dest: &mut [f32]) {
        debug_assert!(dest.len() >= self.window_size);
        let split = self.write_pos; // oldest sample lives here
        let older = &self.buffer[split..];
        let newer = &self.buffer[..split];
        let older_len = self.window_size - split;

        dest[..older_len].copy_from_slice(older);
        if !newer.is_empty() {
            dest[older_len..self.window_size].copy_from_slice(newer);
        }
    }

    /// Clear the stored samples and restart fill/hop tracking from scratch.
    pub fn reset(&mut self) {
        self.filled = false;
        self.samples_since_trigger = 0;
        self.write_pos = 0;
        self.buffer.fill(0.0);
    }
}
359
360// ============================================================================
361// Dual-Window STFT Framework
362// ============================================================================
363//
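// Uses separate analysis (long) and synthesis (short) windows. The analysis
// window provides high frequency resolution, while the synthesis window
// restricts overlap-add to the central part of each frame. Note that the
// latency reported by `DualWindowStft::latency_samples` is still the analysis
// window size, because a full analysis window must be buffered first.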
367
368/// Dual-window STFT processor.
369///
370/// Uses a long analysis window for frequency resolution and a shorter
371/// synthesis window for low-latency output. The output latency equals
372/// the **analysis** window size: the ring buffer must fill `analysis_size`
373/// samples before the first hop fires and produces output.
374pub struct DualWindowStft {
375    analysis_window: Vec<f32>,
376    synthesis_window: Vec<f32>,
377    analysis_size: usize,
378    /// Input ring buffer sized to analysis window
379    input_ring: RingAccumulator,
380    /// Overlap-add output accumulator
381    output_accum: Vec<f32>,
382    output_read_pos: usize,
383    /// FFT processor (analysis size)
384    fft: RealFftProcessor,
385    /// Window read buffer
386    window_buf: Vec<f32>,
387}
388
389/// Design a dual-window pair satisfying the COLA (Constant Overlap-Add) condition.
390///
391/// # Arguments
392/// * `analysis_size` - Analysis window length (long, e.g. 1024)
393/// * `synthesis_size` - Synthesis window length (short, e.g. 256)
394/// * `hop_size` - Hop size in samples
395///
396/// # Returns
397/// (analysis_window, synthesis_window) pair
398pub fn design_dual_windows(
399    analysis_size: usize,
400    synthesis_size: usize,
401    hop_size: usize,
402) -> (Vec<f32>, Vec<f32>) {
403    // Analysis window: Hann
404    let w_a = generate_hann_window(analysis_size);
405
406    // Synthesis window: truncated Hann centered in the analysis window,
407    // normalized to satisfy COLA
408    let offset = (analysis_size - synthesis_size) / 2;
409
410    // Start with a Hann window of synthesis_size
411    let w_s_raw = generate_hann_window(synthesis_size);
412
413    // Compute the COLA sum: Σ_k w_a(n - k*hop) * w_s(n - k*hop)
414    // across all hop-shifted positions. We need this to be constant.
415    // Normalize w_s so the sum equals 1.
416    let num_overlaps = analysis_size.div_ceil(hop_size);
417
418    let mut cola_sum = vec![0.0f32; hop_size];
419    for k in 0..num_overlaps {
420        let shift = k * hop_size;
421        for (n, cola_val) in cola_sum.iter_mut().enumerate() {
422            let ana_idx = n + shift;
423            if ana_idx < analysis_size {
424                // Check if this falls within the synthesis window support
425                let syn_idx = ana_idx.wrapping_sub(offset);
426                if syn_idx < synthesis_size {
427                    *cola_val += w_a[ana_idx] * w_s_raw[syn_idx];
428                }
429            }
430        }
431    }
432
433    // Normalize synthesis window
434    let avg_cola: f32 = cola_sum.iter().sum::<f32>() / cola_sum.len() as f32;
435    let norm_factor = if avg_cola > 1e-10 {
436        1.0 / avg_cola
437    } else {
438        1.0
439    };
440
441    let mut w_s = vec![0.0f32; analysis_size];
442    for i in 0..synthesis_size {
443        w_s[offset + i] = w_s_raw[i] * norm_factor;
444    }
445
446    (w_a, w_s)
447}
448
449impl DualWindowStft {
450    /// Create a new dual-window STFT processor.
451    ///
452    /// # Arguments
453    /// * `analysis_size` - Analysis window size (determines frequency resolution)
454    /// * `synthesis_size` - Synthesis window size (determines output latency)
455    /// * `hop_size` - Hop size in samples
456    pub fn new(analysis_size: usize, synthesis_size: usize, hop_size: usize) -> Self {
457        let (analysis_window, synthesis_window) =
458            design_dual_windows(analysis_size, synthesis_size, hop_size);
459
460        let fft = RealFftProcessor::new_bidirectional(analysis_size);
461
462        Self {
463            analysis_window,
464            synthesis_window,
465            analysis_size,
466            input_ring: RingAccumulator::new(analysis_size, hop_size),
467            output_accum: vec![0.0; analysis_size * 3],
468            output_read_pos: 0,
469            fft,
470            window_buf: vec![0.0; analysis_size],
471        }
472    }
473
474    /// Push a single sample. Returns `true` when a hop boundary is reached.
475    ///
476    /// When `true`, the spectrum is available in `freq_buffer_mut()` for
477    /// in-place modification. Call `synthesize_in_place()` after modifying.
478    pub fn analyze(&mut self, sample: f32) -> bool {
479        if !self.input_ring.push(sample) {
480            return false;
481        }
482
483        // Read the analysis window worth of samples
484        self.input_ring.read_window(&mut self.window_buf);
485
486        // Apply analysis window
487        for i in 0..self.analysis_size {
488            self.fft.time_buffer[i] = self.window_buf[i] * self.analysis_window[i];
489        }
490
491        // Forward FFT
492        self.fft.forward();
493
494        true
495    }
496
497    /// Access the frequency buffer for in-place modification after `analyze()` returns `true`.
498    pub fn freq_buffer_mut(&mut self) -> &mut [Complex<f32>] {
499        &mut self.fft.freq_buffer
500    }
501
502    /// Synthesize output from the current frequency buffer (after in-place modification).
503    ///
504    /// Call this after `analyze()` returns `true` and the spectrum has been modified
505    /// via `freq_buffer_mut()`. The output samples accumulate in the internal buffer
506    /// and can be read via `read_output()`.
507    ///
508    /// # Scaling convention
509    /// `realfft`'s IFFT output is unnormalized: `IFFT(FFT(x))[n] = N * x[n]`.
510    /// `design_dual_windows` normalizes the synthesis window so that the COLA
511    /// overlap-add sum equals 1 assuming a hypothetical normalized IFFT (output = 1).
512    /// The `1/N` factor corrects for `realfft`'s unnormalized output.
513    /// The two normalizations serve distinct roles: synthesis window → COLA unity;
514    /// `1/N` → IFFT scale convention. Neither makes the other redundant.
515    pub fn synthesize_in_place(&mut self) {
516        // Inverse FFT (operates on self.fft.freq_buffer directly)
517        self.fft.inverse();
518
519        // Apply synthesis window and overlap-add.
520        // 1/analysis_size compensates for realfft's unnormalized IFFT (output × N).
521        let scale = 1.0 / self.analysis_size as f32;
522        for i in 0..self.analysis_size {
523            let pos = (self.output_read_pos + i) % self.output_accum.len();
524            self.output_accum[pos] += self.fft.time_buffer[i] * self.synthesis_window[i] * scale;
525        }
526    }
527
528    /// Read one output sample. Returns 0.0 if no output is ready yet.
529    pub fn read_output(&mut self) -> f32 {
530        let sample = self.output_accum[self.output_read_pos];
531        self.output_accum[self.output_read_pos] = 0.0;
532        self.output_read_pos = (self.output_read_pos + 1) % self.output_accum.len();
533        sample
534    }
535
536    /// Process a block: analyze, apply user function, synthesize.
537    ///
538    /// # Arguments
539    /// * `input` - Input samples
540    /// * `output` - Output buffer (same length as input)
541    /// * `process_fn` - Function to modify the spectrum (called at each hop boundary)
542    pub fn process_block<F>(&mut self, input: &[f32], output: &mut [f32], mut process_fn: F)
543    where
544        F: FnMut(&mut [Complex<f32>]),
545    {
546        for (i, &sample) in input.iter().enumerate() {
547            if self.analyze(sample) {
548                process_fn(&mut self.fft.freq_buffer);
549                self.synthesize_in_place();
550            }
551            output[i] = self.read_output();
552        }
553    }
554
555    /// Get the output latency in samples.
556    ///
557    /// Returns the analysis window size. The ring accumulator must fill
558    /// `analysis_size` samples before the first hop triggers, so valid
559    /// output first appears at sample index `analysis_size`.
560    pub fn latency_samples(&self) -> usize {
561        self.analysis_size
562    }
563
564    /// Reset all internal state.
565    pub fn reset(&mut self) {
566        self.input_ring.reset();
567        self.output_accum.fill(0.0);
568        self.output_read_pos = 0;
569    }
570}
571
572// ============================================================================
573// Tests
574// ============================================================================
575
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;

    use std::f32::consts::PI;

    /// Mean of the squared differences between `output[start..end]` and the
    /// input delayed by `latency` samples.
    fn delayed_mean_square_error(
        output: &[f32],
        signal: &[f32],
        latency: usize,
        start: usize,
        end: usize,
    ) -> f32 {
        let squared_sum: f32 = output[start..end]
            .iter()
            .zip(&signal[start - latency..end - latency])
            .map(|(o, s)| (o - s).powi(2))
            .sum();
        squared_sum / (end - start) as f32
    }

    #[test]
    fn test_hann_window_size_and_symmetry() {
        const N: usize = 8;
        let window = generate_hann_window(N);
        assert_eq!(window.len(), N);

        // Periodic Hann: (near) zero at the first tap, unity at N/2.
        assert!((window[0] - 0.0).abs() < 0.01);
        assert!((window[N / 2] - 1.0).abs() < 0.01);

        // Periodic symmetry: w[i] == w[N-i].
        for i in 1..N / 2 {
            let mirrored = window[N - i];
            assert!(
                (window[i] - mirrored).abs() < 1e-6,
                "Window not symmetric at i={}: {} vs {}",
                i,
                window[i],
                mirrored
            );
        }
    }

    #[test]
    fn test_sqrt_hann_cola_property() {
        // Squaring sqrt(Hann) gives Hann, and Hann halves that are one hop
        // apart (50% overlap) must add up to 1.
        let n = 256;
        let hop = n / 2;
        let sqrt_window = generate_sqrt_hann_window(n);
        let (front, back) = sqrt_window.split_at(hop);

        for (i, (a, b)) in front.iter().zip(back).enumerate() {
            let sum = a * a + b * b;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "sqrt(Hann) COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_hann_window_cola_property() {
        // At 50% overlap, taps one hop apart must add up to 1.
        let n = 256;
        let hop = n / 2;
        let window = generate_hann_window(n);
        let (front, back) = window.split_at(hop);

        for (i, (a, b)) in front.iter().zip(back).enumerate() {
            let sum = a + b;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_symmetric_hann_endpoints_are_zero() {
        let window = generate_hann_window_symmetric(256);
        let first = window[0];
        let last = window[255];
        assert!(first.abs() < 1e-7, "First sample should be 0");
        assert!(last.abs() < 1e-7, "Last sample should be 0");
        // Close to unity around the middle.
        assert!((window[128] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_symmetric_hann_no_nan_for_small_sizes() {
        // size=0 yields nothing at all.
        assert!(generate_hann_window_symmetric(0).is_empty());

        // size=1 must be the special-cased [1.0], never NaN.
        let w1 = generate_hann_window_symmetric(1);
        assert_eq!(w1.len(), 1);
        assert!(w1[0].is_finite(), "size=1 produced non-finite: {}", w1[0]);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // size=2 is the smallest size going through the formula.
        let w2 = generate_hann_window_symmetric(2);
        assert_eq!(w2.len(), 2);
        assert!(w2[0].is_finite());
        assert!(w2[1].is_finite());
    }

    #[test]
    fn test_fft_roundtrip() {
        let fft_size = 256;
        let mut fft = RealFftProcessor::new_bidirectional(fft_size);

        // Ten cycles of a sine across the frame.
        let original: Vec<f32> = (0..fft_size)
            .map(|i| (2.0 * PI * 10.0 * i as f32 / fft_size as f32).sin())
            .collect();
        fft.time_buffer.copy_from_slice(&original);

        fft.forward();
        fft.inverse();

        // The round trip scales by fft_size; undo that before comparing.
        let scale = 1.0 / fft_size as f32;
        for (i, (&raw, &expected)) in fft.time_buffer.iter().zip(&original).enumerate() {
            let recovered = raw * scale;
            assert!(
                (recovered - expected).abs() < 1e-4,
                "FFT roundtrip mismatch at i={}: expected {}, got {}",
                i,
                expected,
                recovered,
            );
        }
    }

    #[test]
    fn test_ring_accumulator_trigger_timing() {
        let window_size = 8;
        let hop_size = 4;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        let triggers: Vec<i32> = (0..24).filter(|&i| ring.push(i as f32)).collect();

        // The 8th sample (index 7) completes the window; afterwards one
        // trigger per hop of 4 samples.
        assert_eq!(triggers, vec![7, 11, 15, 19, 23]);
    }

    #[test]
    fn test_ring_accumulator_window_readout() {
        let window_size = 4;
        let hop_size = 2;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        // Six samples 0..=5 leave the ring as [4, 5, 2, 3] with the write
        // position at 2, so the oldest-first view is [2, 3, 4, 5].
        (0..6).for_each(|i| {
            ring.push(i as f32);
        });

        let mut dest = vec![0.0; 4];
        ring.read_window(&mut dest);
        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_ring_accumulator_reset() {
        let mut ring = RingAccumulator::new(8, 4);

        // Enough samples to fill the ring and trigger.
        for i in 0..12 {
            ring.push(i as f32);
        }
        assert!(ring.filled);

        ring.reset();
        assert!(!ring.filled);
        assert_eq!(ring.write_pos, 0);
        assert_eq!(ring.samples_since_trigger, 0);

        // Half a window after the reset must not be enough to trigger.
        let triggered = (0..4).fold(false, |acc, _| acc | ring.push(1.0));
        assert!(!triggered, "Should not trigger before ring is filled again");
    }

    #[test]
    fn test_dual_window_design() {
        let analysis_size = 1024;
        let synthesis_size = 256;
        let hop_size = 128;

        let (w_a, w_s) = design_dual_windows(analysis_size, synthesis_size, hop_size);
        assert_eq!(w_a.len(), analysis_size);
        assert_eq!(w_s.len(), analysis_size);

        // Everything outside the centered synthesis support must be zero.
        let offset = (analysis_size - synthesis_size) / 2;
        for &tap in &w_s[..offset] {
            assert_eq!(tap, 0.0, "Synthesis window should be zero before offset");
        }
        for &tap in &w_s[offset + synthesis_size..] {
            assert_eq!(tap, 0.0, "Synthesis window should be zero after support");
        }
    }

    #[test]
    fn test_dual_window_stft_passthrough() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        // 440 Hz tone at a 48 kHz sample rate.
        let num_samples = 4096;
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();

        let mut output = vec![0.0f32; num_samples];

        // Identity processing: the spectrum is left untouched.
        stft.process_block(&signal, &mut output, |_spectrum| {});

        // Compare against the delayed input, skipping the start-up transient
        // and the tail.
        let latency = stft.latency_samples();
        let check_start = latency + 512;
        let check_end = num_samples - 512;

        if check_end > check_start {
            let rms_error =
                delayed_mean_square_error(&output, &signal, latency, check_start, check_end);

            // Windowing introduces some error; only require it to be bounded.
            assert!(
                rms_error < 1.0,
                "Dual-window STFT passthrough RMS error too high: {rms_error:.6}"
            );
        }
    }

    /// Round-trip identity test for DualWindowStft.
    ///
    /// Passing a DC signal through unmodified must give back the same
    /// amplitude once the start-up period is over, which checks that the
    /// synthesis-window normalization and the 1/N IFFT correction combine to
    /// unity gain.
    #[test]
    fn test_dual_window_stft_roundtrip_unity_gain() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        // Constant 0.5 input makes the expected output trivial to state.
        let num_samples = 6144;
        let signal = vec![0.5_f32; num_samples];
        let mut output = vec![0.0_f32; num_samples];

        stft.process_block(&signal, &mut output, |_spectrum| {});

        // Allow the latency plus two further analysis windows to settle.
        let latency = stft.latency_samples();
        let check_start = latency + 2 * analysis_size;
        let check_end = num_samples - analysis_size;

        if check_end > check_start {
            let rms_error =
                delayed_mean_square_error(&output, &signal, latency, check_start, check_end);

            assert!(
                rms_error < 1e-4,
                "DualWindowStft round-trip RMS error too high ({rms_error:.6}); \
                 IFFT scale or synthesis-window normalization may be wrong"
            );
        }
    }

    #[test]
    fn test_dual_window_stft_latency_reports_analysis_fill_delay() {
        let analysis_size = 512;
        let stft = DualWindowStft::new(analysis_size, 128, 64);
        assert_eq!(stft.latency_samples(), 512);
    }

    #[test]
    fn test_dual_window_stft_reset() {
        let mut stft = DualWindowStft::new(512, 128, 64);

        // Run a non-trivial signal through first so there is state to clear.
        let signal: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0; 2048];
        stft.process_block(&signal, &mut output, |_| {});

        stft.reset();

        // After the reset, silence in must give (near) silence out.
        let silence = vec![0.0f32; 1024];
        let mut output2 = vec![0.0; 1024];
        stft.process_block(&silence, &mut output2, |_| {});

        let max_output: f32 = output2.iter().fold(0.0f32, |peak, x| peak.max(x.abs()));
        assert!(
            max_output < 0.01,
            "After reset + silence, max output should be ~0, got {max_output}"
        );
    }
}
892
/// Tests for `BatchedRealFftProcessor`: agreement with per-channel
/// `RealFftProcessor` instances, forward/inverse round trip, and the flat
/// channel-major buffer layout.
#[cfg(test)]
mod batched_real_fft_processor_tests {
    use super::*;

    /// Absolute tolerance applied to every floating-point comparison below.
    const EPSILON: f32 = 1e-3;

    /// Fill `buffer` with a deterministic test signal (a sine plus a weaker
    /// cosine). `ch` shifts the phase so that each channel gets distinct data.
    fn fill_signal(buffer: &mut [f32], ch: usize) {
        for (i, sample) in buffer.iter_mut().enumerate() {
            let phase = i as f32 * 0.13 + ch as f32 * 0.37;
            *sample = phase.sin() + 0.25 * (phase * 2.7).cos();
        }
    }

    /// Assert that the real and imaginary parts of `actual` are each within
    /// `EPSILON` of `expected`, reporting the offending component on failure.
    fn assert_complex_close(actual: Complex<f32>, expected: Complex<f32>) {
        assert!(
            (actual.re - expected.re).abs() <= EPSILON,
            "real part mismatch: got {}, expected {} (tolerance {})",
            actual.re,
            expected.re,
            EPSILON
        );
        assert!(
            (actual.im - expected.im).abs() <= EPSILON,
            "imaginary part mismatch: got {}, expected {} (tolerance {})",
            actual.im,
            expected.im,
            EPSILON
        );
    }

    /// Assert that two slices have equal length and that every pair of samples
    /// differs by at most `EPSILON`, reporting the first offending index.
    fn assert_slice_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "slice length mismatch");
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= EPSILON,
                "sample {index} mismatch: got {got}, expected {want} (tolerance {EPSILON})"
            );
        }
    }

    /// The batched forward transform must produce, for every channel, the same
    /// spectrum as an independent single-channel processor fed the same input.
    #[test]
    fn forward_matches_independent_processors_for_representative_channel_counts() {
        let fft_size = 64;
        // A real FFT of length N yields N/2 + 1 complex bins.
        let spectrum_size = fft_size / 2 + 1;

        for channels in [1, 2, 8, 16, 24] {
            let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);

            for ch in 0..channels {
                fill_signal(batched.time_channel_mut(ch), ch);
            }

            // Snapshot the inputs before transforming so the reference
            // processors see exactly the same samples.
            let inputs = batched.time_buffers().to_vec();
            batched.forward_all();

            for ch in 0..channels {
                let mut independent = RealFftProcessor::new_forward_only(fft_size);
                independent
                    .time_buffer
                    .copy_from_slice(&inputs[ch * fft_size..(ch + 1) * fft_size]);
                independent.forward();

                // `zip` stops at the shorter side, so count the compared bins
                // and verify the full spectrum was covered; otherwise a short
                // or empty spectrum would let this test pass vacuously.
                let mut compared_bins = 0;
                for (actual, expected) in batched
                    .freq_channel(ch)
                    .iter()
                    .zip(&independent.freq_buffer)
                {
                    assert_complex_close(*actual, *expected);
                    compared_bins += 1;
                }
                assert_eq!(
                    compared_bins, spectrum_size,
                    "channel {ch} of {channels}: not every spectrum bin was compared"
                );
            }
        }
    }

    /// Forward followed by inverse must return each channel's original samples
    /// multiplied by `fft_size` (the round trip is expected to be unscaled).
    #[test]
    fn bidirectional_round_trip_restores_each_channel_after_scaling() {
        let channels = 8;
        let fft_size = 128;
        let mut batched = BatchedRealFftProcessor::new_bidirectional(channels, fft_size);

        for ch in 0..channels {
            fill_signal(batched.time_channel_mut(ch), ch);
        }

        let original = batched.time_buffers().to_vec();
        batched.forward_all();
        batched.inverse_all();

        for ch in 0..channels {
            let expected: Vec<f32> = original[ch * fft_size..(ch + 1) * fft_size]
                .iter()
                .map(|sample| sample * fft_size as f32)
                .collect();
            assert_slice_close(batched.time_channel(ch), &expected);
        }
    }

    /// Per-channel accessors must be views into flat buffers laid out channel
    /// after channel (all of channel 0, then all of channel 1, ...).
    #[test]
    fn channel_slices_use_flat_channel_major_layout() {
        let channels = 3;
        let fft_size = 4;
        let spectrum_size = fft_size / 2 + 1;
        let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);

        // Tag every element with its channel (tens digit) and index (ones digit)
        // so its position in the flat buffer is unambiguous.
        for ch in 0..channels {
            for (i, sample) in batched.time_channel_mut(ch).iter_mut().enumerate() {
                *sample = (ch * 10 + i) as f32;
            }
            for (i, bin) in batched.freq_channel_mut(ch).iter_mut().enumerate() {
                *bin = Complex::new((ch * 10 + i) as f32, ch as f32);
            }
        }

        assert_eq!(
            batched.time_buffers(),
            &[
                0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0, 20.0, 21.0, 22.0, 23.0
            ]
        );
        assert_eq!(batched.freq_buffers().len(), channels * spectrum_size);
        assert_eq!(batched.freq_channel(2)[1], Complex::new(21.0, 2.0));
    }
}