Skip to main content

math_audio_dsp/
stft.rs

1// ============================================================================
2// Shared STFT Infrastructure
3// ============================================================================
4//
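// Reusable components for STFT-based plugins:
// - generate_hann_window (plus symmetric and sqrt variants): window generation for STFT analysis
// - RealFftProcessor: Thin wrapper around realfft for single-channel use
// - BatchedRealFftProcessor: Same-size real FFTs for several channels sharing plans and scratch
// - RingAccumulator: Sample accumulator with hop-based triggering
// - DualWindowStft / design_dual_windows: Dual-window STFT analysis and overlap-add synthesis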
9
10use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
11use rustfft::num_complex::Complex;
12use std::sync::Arc;
13
14// ============================================================================
15// Hann Window
16// ============================================================================
17
/// Build a periodic Hann window with `size` taps.
///
/// The cosine argument is divided by `size` (not `size - 1`), which is the
/// form that overlap-adds to a constant at 50% overlap. An empty vector is
/// returned for `size == 0`.
pub fn generate_hann_window(size: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    let period = size as f32;
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = (2.0 * PI * n as f32) / period;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
25
/// Build a symmetric Hann window with `size` taps.
///
/// The cosine argument is divided by `size - 1`, so the first and last taps
/// are both (numerically) zero — the usual choice for spectral analysis.
/// Sizes 0 and 1 would make that divisor zero or negative, so they are
/// answered directly with `size` copies of `1.0`.
pub fn generate_hann_window_symmetric(size: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    if size < 2 {
        return vec![1.0; size];
    }

    let last_index = (size as f32) - 1.0;
    let mut window = vec![0.0f32; size];
    for (n, tap) in window.iter_mut().enumerate() {
        let phase = (2.0 * PI * n as f32) / last_index;
        *tap = 0.5 * (1.0 - phase.cos());
    }
    window
}
37
/// Build a sqrt(Hann) window with `size` taps for WOLA (Weighted Overlap-Add).
///
/// Each tap is the square root of the periodic Hann tap (divisor `size`).
/// Applying this window at both analysis and synthesis therefore multiplies
/// the signal by a plain Hann window overall, which sums to a constant at
/// 50% overlap.
pub fn generate_sqrt_hann_window(size: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    let period = size as f32;
    (0..size)
        .map(|n| (2.0 * PI * n as f32) / period)
        .map(|phase| (0.5 * (1.0 - phase.cos())).sqrt())
        .collect()
}
49
50// ============================================================================
51// RealFftProcessor
52// ============================================================================
53
54/// Thin wrapper around `realfft` encapsulating planner + buffers for
55/// single-channel use. Provides forward (real→complex) and optional
56/// inverse (complex→real) FFT.
57pub struct RealFftProcessor {
58    #[allow(dead_code)]
59    pub fft_size: usize,
60    pub spectrum_size: usize,
61    fft_forward: Arc<dyn RealToComplex<f32>>,
62    fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
63    pub time_buffer: Vec<f32>,
64    pub freq_buffer: Vec<Complex<f32>>,
65}
66
67impl RealFftProcessor {
68    /// Create a forward-only FFT processor (no inverse).
69    pub fn new_forward_only(fft_size: usize) -> Self {
70        let spectrum_size = fft_size / 2 + 1;
71        let mut planner = RealFftPlanner::<f32>::new();
72        let fft_forward = planner.plan_fft_forward(fft_size);
73
74        Self {
75            fft_size,
76            spectrum_size,
77            fft_forward,
78            fft_inverse: None,
79            time_buffer: vec![0.0; fft_size],
80            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
81        }
82    }
83
84    /// Create a bidirectional FFT processor (forward + inverse).
85    #[allow(dead_code)]
86    pub fn new_bidirectional(fft_size: usize) -> Self {
87        let spectrum_size = fft_size / 2 + 1;
88        let mut planner = RealFftPlanner::<f32>::new();
89        let fft_forward = planner.plan_fft_forward(fft_size);
90        let fft_inverse = planner.plan_fft_inverse(fft_size);
91
92        Self {
93            fft_size,
94            spectrum_size,
95            fft_forward,
96            fft_inverse: Some(fft_inverse),
97            time_buffer: vec![0.0; fft_size],
98            freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
99        }
100    }
101
102    /// Perform forward FFT: time_buffer → freq_buffer.
103    /// The caller should fill `time_buffer` before calling this.
104    pub fn forward(&mut self) {
105        self.fft_forward
106            .process(&mut self.time_buffer, &mut self.freq_buffer)
107            .expect("FFT forward failed");
108    }
109
110    /// Perform inverse FFT: freq_buffer → time_buffer.
111    /// Panics if this processor was created with `new_forward_only`.
112    #[allow(dead_code)]
113    pub fn inverse(&mut self) {
114        self.fft_inverse
115            .as_ref()
116            .expect("Inverse FFT not available (forward-only processor)")
117            .process(&mut self.freq_buffer, &mut self.time_buffer)
118            .expect("FFT inverse failed");
119    }
120}
121
122// ============================================================================
123// BatchedRealFftProcessor
124// ============================================================================
125
126/// Batched wrapper around `realfft` for independent real FFTs that share the
127/// same FFT size.
128///
129/// Buffers are stored flat in channel-major order:
130/// - `time_buffers`: `[channel][sample]`
131/// - `freq_buffers`: `[channel][bin]`
132///
133/// This keeps each channel contiguous for `realfft`, avoids one allocation per
134/// channel, and reuses the same FFT plans and scratch buffers across all
135/// channels. The FFTs are still computed sequentially; this type is a portable
136/// baseline rather than a platform-specific batched FFT backend.
137pub struct BatchedRealFftProcessor {
138    channels: usize,
139    fft_size: usize,
140    spectrum_size: usize,
141    fft_forward: Arc<dyn RealToComplex<f32>>,
142    fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
143    forward_scratch: Vec<Complex<f32>>,
144    inverse_scratch: Vec<Complex<f32>>,
145    time_buffers: Vec<f32>,
146    freq_buffers: Vec<Complex<f32>>,
147}
148
149impl BatchedRealFftProcessor {
150    /// Create a forward-only batched FFT processor.
151    pub fn new_forward_only(channels: usize, fft_size: usize) -> Self {
152        Self::new(channels, fft_size, false)
153    }
154
155    /// Create a bidirectional batched FFT processor.
156    pub fn new_bidirectional(channels: usize, fft_size: usize) -> Self {
157        Self::new(channels, fft_size, true)
158    }
159
160    fn new(channels: usize, fft_size: usize, include_inverse: bool) -> Self {
161        assert!(
162            channels > 0,
163            "BatchedRealFftProcessor requires at least one channel"
164        );
165
166        let spectrum_size = fft_size / 2 + 1;
167        let mut planner = RealFftPlanner::<f32>::new();
168        let fft_forward = planner.plan_fft_forward(fft_size);
169        let fft_inverse = if include_inverse {
170            Some(planner.plan_fft_inverse(fft_size))
171        } else {
172            None
173        };
174
175        let forward_scratch = vec![Complex::new(0.0, 0.0); fft_forward.get_scratch_len()];
176        let inverse_scratch = fft_inverse
177            .as_ref()
178            .map(|fft| vec![Complex::new(0.0, 0.0); fft.get_scratch_len()])
179            .unwrap_or_default();
180
181        Self {
182            channels,
183            fft_size,
184            spectrum_size,
185            fft_forward,
186            fft_inverse,
187            forward_scratch,
188            inverse_scratch,
189            time_buffers: vec![0.0; channels * fft_size],
190            freq_buffers: vec![Complex::new(0.0, 0.0); channels * spectrum_size],
191        }
192    }
193
194    pub fn channels(&self) -> usize {
195        self.channels
196    }
197
198    pub fn fft_size(&self) -> usize {
199        self.fft_size
200    }
201
202    pub fn spectrum_size(&self) -> usize {
203        self.spectrum_size
204    }
205
206    pub fn time_buffers(&self) -> &[f32] {
207        &self.time_buffers
208    }
209
210    pub fn time_buffers_mut(&mut self) -> &mut [f32] {
211        &mut self.time_buffers
212    }
213
214    pub fn freq_buffers(&self) -> &[Complex<f32>] {
215        &self.freq_buffers
216    }
217
218    pub fn freq_buffers_mut(&mut self) -> &mut [Complex<f32>] {
219        &mut self.freq_buffers
220    }
221
222    pub fn time_channel(&self, ch: usize) -> &[f32] {
223        debug_assert!(ch < self.channels);
224        let range = self.time_range(ch);
225        &self.time_buffers[range]
226    }
227
228    pub fn time_channel_mut(&mut self, ch: usize) -> &mut [f32] {
229        debug_assert!(ch < self.channels);
230        let range = self.time_range(ch);
231        &mut self.time_buffers[range]
232    }
233
234    pub fn freq_channel(&self, ch: usize) -> &[Complex<f32>] {
235        debug_assert!(ch < self.channels);
236        let range = self.freq_range(ch);
237        &self.freq_buffers[range]
238    }
239
240    pub fn freq_channel_mut(&mut self, ch: usize) -> &mut [Complex<f32>] {
241        debug_assert!(ch < self.channels);
242        let range = self.freq_range(ch);
243        &mut self.freq_buffers[range]
244    }
245
246    /// Perform forward FFTs for all channels.
247    /// The caller should fill each time-domain channel buffer before calling this.
248    pub fn forward_all(&mut self) {
249        for ch in 0..self.channels {
250            let time_range = self.time_range(ch);
251            let freq_range = self.freq_range(ch);
252            self.fft_forward
253                .process_with_scratch(
254                    &mut self.time_buffers[time_range],
255                    &mut self.freq_buffers[freq_range],
256                    &mut self.forward_scratch,
257                )
258                .expect("FFT forward failed");
259        }
260    }
261
262    /// Perform inverse FFTs for all channels.
263    /// Panics if this processor was created with `new_forward_only`.
264    pub fn inverse_all(&mut self) {
265        let fft_inverse = self
266            .fft_inverse
267            .as_ref()
268            .expect("Inverse FFT not available (forward-only processor)");
269
270        for ch in 0..self.channels {
271            let time_range = self.time_range(ch);
272            let freq_range = self.freq_range(ch);
273            fft_inverse
274                .process_with_scratch(
275                    &mut self.freq_buffers[freq_range],
276                    &mut self.time_buffers[time_range],
277                    &mut self.inverse_scratch,
278                )
279                .expect("FFT inverse failed");
280        }
281    }
282
283    fn time_range(&self, ch: usize) -> std::ops::Range<usize> {
284        ch * self.fft_size..(ch + 1) * self.fft_size
285    }
286
287    fn freq_range(&self, ch: usize) -> std::ops::Range<usize> {
288        ch * self.spectrum_size..(ch + 1) * self.spectrum_size
289    }
290}
291
292// ============================================================================
293// RingAccumulator
294// ============================================================================
295
/// Circular sample store that fires once per hop.
///
/// Samples are written into a ring of `window_size` slots. After the ring has
/// been filled for the first time, `push` reports `true` each time `hop_size`
/// further samples have arrived.
pub struct RingAccumulator {
    /// Ring storage, `window_size` long.
    buffer: Vec<f32>,
    /// Slot the next sample goes into; also the position of the oldest sample.
    write_pos: usize,
    /// Samples pushed since the last trigger (or since construction/reset).
    samples_since_trigger: usize,
    /// Set once `window_size` samples have been pushed.
    filled: bool,
    window_size: usize,
    hop_size: usize,
}

impl RingAccumulator {
    /// Create an empty accumulator for windows of `window_size` samples that
    /// triggers every `hop_size` samples.
    pub fn new(window_size: usize, hop_size: usize) -> Self {
        let buffer = vec![0.0; window_size];
        Self {
            buffer,
            write_pos: 0,
            samples_since_trigger: 0,
            filled: false,
            window_size,
            hop_size,
        }
    }

    /// Store one sample. The result is `true` when the ring has been filled
    /// and at least `hop_size` samples have arrived since the previous
    /// trigger; the hop counter restarts at that point.
    pub fn push(&mut self, sample: f32) -> bool {
        self.buffer[self.write_pos] = sample;

        // Advance and wrap the write cursor.
        self.write_pos += 1;
        if self.write_pos == self.window_size {
            self.write_pos = 0;
        }

        self.samples_since_trigger += 1;
        self.filled = self.filled || self.samples_since_trigger >= self.window_size;

        let fire = self.filled && self.samples_since_trigger >= self.hop_size;
        if fire {
            self.samples_since_trigger = 0;
        }
        fire
    }

    /// Write the ring contents into `dest`, oldest sample first.
    /// `dest` needs room for at least `window_size` samples.
    /// The ring is unrolled with two slice copies rather than per-sample
    /// index wrapping.
    pub fn read_window(&self, dest: &mut [f32]) {
        debug_assert!(dest.len() >= self.window_size);

        // The oldest sample sits at the write cursor; everything before the
        // cursor is newer.
        let older = &self.buffer[self.write_pos..];
        let newer = &self.buffer[..self.write_pos];
        let split = self.window_size - self.write_pos;

        dest[..split].copy_from_slice(older);
        if !newer.is_empty() {
            dest[split..self.window_size].copy_from_slice(newer);
        }
    }

    /// Return to the freshly-constructed state: zeroed ring, cursor and hop
    /// counter at zero, not yet filled.
    pub fn reset(&mut self) {
        for slot in self.buffer.iter_mut() {
            *slot = 0.0;
        }
        self.filled = false;
        self.samples_since_trigger = 0;
        self.write_pos = 0;
    }
}
359
360// ============================================================================
361// Dual-Window STFT Framework
362// ============================================================================
363//
364// Decouples frequency resolution from latency by using separate analysis
365// (long) and synthesis (short) windows. The analysis window provides high
366// frequency resolution while the synthesis window determines the output latency.
367
368/// Dual-window STFT processor.
369///
370/// Uses a long analysis window for frequency resolution and a shorter
371/// synthesis window for low-latency output. The output latency equals
372/// the synthesis window size, not the analysis window size.
373pub struct DualWindowStft {
374    analysis_window: Vec<f32>,
375    synthesis_window: Vec<f32>,
376    analysis_size: usize,
377    /// Input ring buffer sized to analysis window
378    input_ring: RingAccumulator,
379    /// Overlap-add output accumulator
380    output_accum: Vec<f32>,
381    output_read_pos: usize,
382    /// FFT processor (analysis size)
383    fft: RealFftProcessor,
384    /// Window read buffer
385    window_buf: Vec<f32>,
386    /// COLA normalization factor
387    #[allow(dead_code)]
388    cola_norm: Vec<f32>,
389}
390
391/// Design a dual-window pair satisfying the COLA (Constant Overlap-Add) condition.
392///
393/// # Arguments
394/// * `analysis_size` - Analysis window length (long, e.g. 1024)
395/// * `synthesis_size` - Synthesis window length (short, e.g. 256)
396/// * `hop_size` - Hop size in samples
397///
398/// # Returns
399/// (analysis_window, synthesis_window) pair
400pub fn design_dual_windows(
401    analysis_size: usize,
402    synthesis_size: usize,
403    hop_size: usize,
404) -> (Vec<f32>, Vec<f32>) {
405    // Analysis window: Hann
406    let w_a = generate_hann_window(analysis_size);
407
408    // Synthesis window: truncated Hann centered in the analysis window,
409    // normalized to satisfy COLA
410    let offset = (analysis_size - synthesis_size) / 2;
411
412    // Start with a Hann window of synthesis_size
413    let w_s_raw = generate_hann_window(synthesis_size);
414
415    // Compute the COLA sum: Σ_k w_a(n - k*hop) * w_s(n - k*hop)
416    // across all hop-shifted positions. We need this to be constant.
417    // Normalize w_s so the sum equals 1.
418    let num_overlaps = analysis_size.div_ceil(hop_size);
419
420    let mut cola_sum = vec![0.0f32; hop_size];
421    for k in 0..num_overlaps {
422        let shift = k * hop_size;
423        for (n, cola_val) in cola_sum.iter_mut().enumerate() {
424            let ana_idx = n + shift;
425            if ana_idx < analysis_size {
426                // Check if this falls within the synthesis window support
427                let syn_idx = ana_idx.wrapping_sub(offset);
428                if syn_idx < synthesis_size {
429                    *cola_val += w_a[ana_idx] * w_s_raw[syn_idx];
430                }
431            }
432        }
433    }
434
435    // Normalize synthesis window
436    let avg_cola: f32 = cola_sum.iter().sum::<f32>() / cola_sum.len() as f32;
437    let norm_factor = if avg_cola > 1e-10 {
438        1.0 / avg_cola
439    } else {
440        1.0
441    };
442
443    let mut w_s = vec![0.0f32; analysis_size];
444    for i in 0..synthesis_size {
445        w_s[offset + i] = w_s_raw[i] * norm_factor;
446    }
447
448    (w_a, w_s)
449}
450
451impl DualWindowStft {
452    /// Create a new dual-window STFT processor.
453    ///
454    /// # Arguments
455    /// * `analysis_size` - Analysis window size (determines frequency resolution)
456    /// * `synthesis_size` - Synthesis window size (determines output latency)
457    /// * `hop_size` - Hop size in samples
458    pub fn new(analysis_size: usize, synthesis_size: usize, hop_size: usize) -> Self {
459        let (analysis_window, synthesis_window) =
460            design_dual_windows(analysis_size, synthesis_size, hop_size);
461
462        let fft = RealFftProcessor::new_bidirectional(analysis_size);
463
464        Self {
465            analysis_window,
466            synthesis_window,
467            analysis_size,
468            input_ring: RingAccumulator::new(analysis_size, hop_size),
469            output_accum: vec![0.0; analysis_size * 3],
470            output_read_pos: 0,
471            fft,
472            window_buf: vec![0.0; analysis_size],
473            cola_norm: vec![1.0; analysis_size],
474        }
475    }
476
477    /// Push a single sample. Returns `true` when a hop boundary is reached.
478    ///
479    /// When `true`, the spectrum is available in `freq_buffer_mut()` for
480    /// in-place modification. Call `synthesize_in_place()` after modifying.
481    pub fn analyze(&mut self, sample: f32) -> bool {
482        if !self.input_ring.push(sample) {
483            return false;
484        }
485
486        // Read the analysis window worth of samples
487        self.input_ring.read_window(&mut self.window_buf);
488
489        // Apply analysis window
490        for i in 0..self.analysis_size {
491            self.fft.time_buffer[i] = self.window_buf[i] * self.analysis_window[i];
492        }
493
494        // Forward FFT
495        self.fft.forward();
496
497        true
498    }
499
500    /// Access the frequency buffer for in-place modification after `analyze()` returns `true`.
501    pub fn freq_buffer_mut(&mut self) -> &mut [Complex<f32>] {
502        &mut self.fft.freq_buffer
503    }
504
505    /// Synthesize output from the current frequency buffer (after in-place modification).
506    ///
507    /// Call this after `analyze()` returns `true` and the spectrum has been modified
508    /// via `freq_buffer_mut()`. The output samples accumulate in the internal buffer
509    /// and can be read via `read_output()`.
510    pub fn synthesize_in_place(&mut self) {
511        // Inverse FFT (operates on self.fft.freq_buffer directly)
512        self.fft.inverse();
513
514        // Apply synthesis window and overlap-add
515        let scale = 1.0 / self.analysis_size as f32;
516        for i in 0..self.analysis_size {
517            let pos = (self.output_read_pos + i) % self.output_accum.len();
518            self.output_accum[pos] += self.fft.time_buffer[i] * self.synthesis_window[i] * scale;
519        }
520    }
521
522    /// Read one output sample. Returns 0.0 if no output is ready yet.
523    pub fn read_output(&mut self) -> f32 {
524        let sample = self.output_accum[self.output_read_pos];
525        self.output_accum[self.output_read_pos] = 0.0;
526        self.output_read_pos = (self.output_read_pos + 1) % self.output_accum.len();
527        sample
528    }
529
530    /// Process a block: analyze, apply user function, synthesize.
531    ///
532    /// # Arguments
533    /// * `input` - Input samples
534    /// * `output` - Output buffer (same length as input)
535    /// * `process_fn` - Function to modify the spectrum (called at each hop boundary)
536    pub fn process_block<F>(&mut self, input: &[f32], output: &mut [f32], mut process_fn: F)
537    where
538        F: FnMut(&mut [Complex<f32>]),
539    {
540        for (i, &sample) in input.iter().enumerate() {
541            if self.analyze(sample) {
542                process_fn(&mut self.fft.freq_buffer);
543                self.synthesize_in_place();
544            }
545            output[i] = self.read_output();
546        }
547    }
548
549    /// Get the output latency in samples.
550    pub fn latency_samples(&self) -> usize {
551        self.analysis_size
552    }
553
554    /// Reset all internal state.
555    pub fn reset(&mut self) {
556        self.input_ring.reset();
557        self.output_accum.fill(0.0);
558        self.output_read_pos = 0;
559    }
560}
561
562// ============================================================================
563// Tests
564// ============================================================================
565
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;
    use std::f32::consts::PI;

    #[test]
    fn test_hann_window_size_and_symmetry() {
        const N: usize = 8;
        let window = generate_hann_window(N);
        assert_eq!(window.len(), N);

        // Starts near zero, peaks in the middle.
        let (first, middle) = (window[0], window[4]);
        assert!((first - 0.0).abs() < 0.01);
        assert!((middle - 1.0).abs() < 0.01);

        // Periodic Hann mirrors around the centre: w[i] == w[N-i].
        for i in 1..4 {
            let (lhs, rhs) = (window[i], window[N - i]);
            assert!(
                (lhs - rhs).abs() < 1e-6,
                "Window not symmetric at i={}: {} vs {}",
                i,
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_sqrt_hann_cola_property() {
        // Squaring sqrt(Hann) gives Hann, which sums to 1.0 with its
        // half-window-shifted copy.
        let n = 256;
        let hop = n / 2;
        let squared: Vec<f32> = generate_sqrt_hann_window(n)
            .iter()
            .map(|&w| w * w)
            .collect();

        for i in 0..hop {
            let sum = squared[i] + squared[i + hop];
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "sqrt(Hann) COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_hann_window_cola_property() {
        // 50% overlap: each tap plus its half-window-shifted partner is 1.0.
        let n = 256;
        let hop = n / 2;
        let window = generate_hann_window(n);
        let (front, back) = window.split_at(hop);

        for (i, (&a, &b)) in front.iter().zip(back).enumerate() {
            let sum = a + b;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    #[test]
    fn test_symmetric_hann_endpoints_are_zero() {
        let window = generate_hann_window_symmetric(256);
        let first = window[0];
        let last = window[255];
        assert!(first.abs() < 1e-7, "First sample should be 0");
        assert!(last.abs() < 1e-7, "Last sample should be 0");
        // Close to 1.0 around the centre.
        assert!((window[128] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_symmetric_hann_no_nan_for_small_sizes() {
        // size=0 yields nothing at all.
        assert!(generate_hann_window_symmetric(0).is_empty());

        // size=1 must be the single tap 1.0 rather than NaN.
        let w1 = generate_hann_window_symmetric(1);
        assert_eq!(w1.len(), 1);
        assert!(w1[0].is_finite(), "size=1 produced non-finite: {}", w1[0]);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // size=2 is just the two end points; both must be finite.
        let w2 = generate_hann_window_symmetric(2);
        assert_eq!(w2.len(), 2);
        assert!(w2[0].is_finite());
        assert!(w2[1].is_finite());
    }

    #[test]
    fn test_fft_roundtrip() {
        let fft_size = 256;
        let mut fft = RealFftProcessor::new_bidirectional(fft_size);

        // Ten cycles of a sine across the frame.
        let original: Vec<f32> = (0..fft_size)
            .map(|i| (2.0 * PI * 10.0 * i as f32 / fft_size as f32).sin())
            .collect();
        fft.time_buffer.copy_from_slice(&original);

        fft.forward();
        fft.inverse();

        // The round trip is unnormalised, so undo the factor of fft_size.
        let scale = 1.0 / fft_size as f32;
        for (i, (&raw, &expected)) in fft.time_buffer.iter().zip(&original).enumerate() {
            let recovered = raw * scale;
            assert!(
                (recovered - expected).abs() < 1e-4,
                "FFT roundtrip mismatch at i={}: expected {}, got {}",
                i,
                expected,
                recovered,
            );
        }
    }

    #[test]
    fn test_ring_accumulator_trigger_timing() {
        let window_size = 8;
        let hop_size = 4;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        let triggers: Vec<i32> = (0..24).filter(|&i| ring.push(i as f32)).collect();

        // The 8th sample (index 7) completes the first window; afterwards
        // one trigger per hop of 4 samples.
        assert_eq!(triggers, vec![7, 11, 15, 19, 23]);
    }

    #[test]
    fn test_ring_accumulator_window_readout() {
        let window_size = 4;
        let hop_size = 2;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        // Samples 0..=5: the ring wraps after four, so storage ends up as
        // [4, 5, 2, 3] with the write cursor at 2. Oldest-first that reads
        // [2, 3, 4, 5].
        for value in [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
            ring.push(value);
        }

        let mut dest = vec![0.0; 4];
        ring.read_window(&mut dest);
        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_ring_accumulator_reset() {
        let mut ring = RingAccumulator::new(8, 4);

        // Enough samples to fill the ring and trigger.
        for i in 0..12 {
            ring.push(i as f32);
        }
        assert!(ring.filled);

        ring.reset();
        assert!(!ring.filled);
        assert_eq!(ring.write_pos, 0);
        assert_eq!(ring.samples_since_trigger, 0);

        // Half a window after the reset must stay silent.
        let triggered = (0..4).fold(false, |seen, _| ring.push(1.0) || seen);
        assert!(!triggered, "Should not trigger before ring is filled again");
    }

    #[test]
    fn test_dual_window_design() {
        let analysis_size = 1024;
        let synthesis_size = 256;
        let hop_size = 128;

        let (w_a, w_s) = design_dual_windows(analysis_size, synthesis_size, hop_size);
        assert_eq!(w_a.len(), analysis_size);
        assert_eq!(w_s.len(), analysis_size);

        // Everything outside the centred support must be exactly zero.
        let offset = (analysis_size - synthesis_size) / 2;
        let (leading, rest) = w_s.split_at(offset);
        let trailing = &rest[synthesis_size..];
        for &tap in leading {
            assert_eq!(tap, 0.0, "Synthesis window should be zero before offset");
        }
        for &tap in trailing {
            assert_eq!(tap, 0.0, "Synthesis window should be zero after support");
        }
    }

    #[test]
    fn test_dual_window_stft_passthrough() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        // 440 Hz tone at 48 kHz.
        let num_samples = 4096;
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();
        let mut output = vec![0.0f32; num_samples];

        // Identity callback: the spectrum is left untouched.
        stft.process_block(&signal, &mut output, |_spectrum| {});

        // Compare the output against the input delayed by the reported
        // latency, skipping the start-up transient and the tail.
        let latency = stft.latency_samples();
        let check_start = latency + 512;
        let check_end = num_samples - 512;

        if check_end > check_start {
            let produced = &output[check_start..check_end];
            let reference = &signal[check_start - latency..check_end - latency];
            let rms_error: f32 = produced
                .iter()
                .zip(reference)
                .map(|(o, s)| (o - s).powi(2))
                .sum::<f32>()
                / (check_end - check_start) as f32;

            // Windowing introduces some error; only require it to be bounded.
            assert!(
                rms_error < 1.0,
                "Dual-window STFT passthrough RMS error too high: {rms_error:.6}"
            );
        }
    }

    #[test]
    fn test_dual_window_stft_reset() {
        let mut stft = DualWindowStft::new(512, 128, 64);

        // Run some non-trivial audio through first.
        let signal: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0; 2048];
        stft.process_block(&signal, &mut output, |_| {});

        stft.reset();

        // Silence in must now give (near) silence out.
        let silence = vec![0.0f32; 1024];
        let mut output2 = vec![0.0; 1024];
        stft.process_block(&silence, &mut output2, |_| {});

        let max_output: f32 = output2.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_output < 0.01,
            "After reset + silence, max output should be ~0, got {max_output}"
        );
    }
}
835
#[cfg(test)]
mod batched_real_fft_processor_tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const EPSILON: f32 = 1e-3;

    /// Fill `buffer` with a deterministic two-component test signal whose
    /// phase depends on the channel index.
    fn fill_signal(buffer: &mut [f32], ch: usize) {
        let channel_phase = ch as f32 * 0.37;
        for (i, sample) in buffer.iter_mut().enumerate() {
            let phase = i as f32 * 0.13 + channel_phase;
            *sample = phase.sin() + 0.25 * (phase * 2.7).cos();
        }
    }

    /// Assert that real and imaginary parts both agree within `EPSILON`.
    fn assert_complex_close(actual: Complex<f32>, expected: Complex<f32>) {
        let re_error = (actual.re - expected.re).abs();
        let im_error = (actual.im - expected.im).abs();
        assert!(re_error <= EPSILON);
        assert!(im_error <= EPSILON);
    }

    /// Assert that two slices have equal length and agree element-wise
    /// within `EPSILON`.
    fn assert_slice_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() <= EPSILON);
        }
    }

    #[test]
    fn forward_matches_independent_processors_for_representative_channel_counts() {
        let fft_size = 64;

        for channels in [1, 2, 8, 16, 24] {
            let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);
            for ch in 0..channels {
                fill_signal(batched.time_channel_mut(ch), ch);
            }

            // Snapshot the inputs before the transform runs.
            let inputs = batched.time_buffers().to_vec();
            batched.forward_all();

            // Each channel must match a stand-alone single-channel FFT.
            for (ch, channel_input) in inputs.chunks_exact(fft_size).enumerate() {
                let mut independent = RealFftProcessor::new_forward_only(fft_size);
                independent.time_buffer.copy_from_slice(channel_input);
                independent.forward();

                let batched_bins = batched.freq_channel(ch).iter();
                for (actual, expected) in batched_bins.zip(&independent.freq_buffer) {
                    assert_complex_close(*actual, *expected);
                }
            }
        }
    }

    #[test]
    fn bidirectional_round_trip_restores_each_channel_after_scaling() {
        let channels = 8;
        let fft_size = 128;
        let mut batched = BatchedRealFftProcessor::new_bidirectional(channels, fft_size);

        for ch in 0..channels {
            fill_signal(batched.time_channel_mut(ch), ch);
        }

        let original = batched.time_buffers().to_vec();
        batched.forward_all();
        batched.inverse_all();

        // The round trip is unnormalised: expect the input times fft_size.
        for ch in 0..channels {
            let expected: Vec<f32> = original[ch * fft_size..(ch + 1) * fft_size]
                .iter()
                .map(|&sample| sample * fft_size as f32)
                .collect();
            assert_slice_close(batched.time_channel(ch), &expected);
        }
    }

    #[test]
    fn channel_slices_use_flat_channel_major_layout() {
        let channels = 3;
        let fft_size = 4;
        let spectrum_size = fft_size / 2 + 1;
        let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);

        // Tag every slot with `channel * 10 + index` through the per-channel views.
        for ch in 0..channels {
            let time = batched.time_channel_mut(ch);
            for (i, sample) in time.iter_mut().enumerate() {
                *sample = (ch * 10 + i) as f32;
            }
            let freq = batched.freq_channel_mut(ch);
            for (i, bin) in freq.iter_mut().enumerate() {
                *bin = Complex::new((ch * 10 + i) as f32, ch as f32);
            }
        }

        // The flat view must show the channels back to back.
        assert_eq!(
            batched.time_buffers(),
            &[
                0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0, 20.0, 21.0, 22.0, 23.0
            ]
        );
        assert_eq!(batched.freq_buffers().len(), channels * spectrum_size);
        assert_eq!(batched.freq_channel(2)[1], Complex::new(21.0, 2.0));
    }
}