1use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
11use rustfft::num_complex::Complex;
12use std::sync::Arc;
13
/// Builds a periodic Hann window of `size` samples.
///
/// Sample `i` is `0.5 * (1 - cos(2π·i / size))`: the cosine period is divided
/// by `size` (not `size - 1`), so the first sample is 0 and the window would
/// return to 0 at index `size`, one past the end.
///
/// Returns an empty vector when `size` is 0.
pub fn generate_hann_window(size: usize) -> Vec<f32> {
    let two_pi = 2.0 * std::f32::consts::PI;
    let period = size as f32;

    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = (two_pi * n as f32) / period;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
25
/// Builds a symmetric Hann window of `size` samples.
///
/// Sample `i` is `0.5 * (1 - cos(2π·i / (size - 1)))`, so both the first and
/// the last sample sit on a zero of the raised cosine.
///
/// Sizes 0 and 1 are special-cased to avoid dividing by zero: they yield an
/// empty vector and `[1.0]` respectively.
pub fn generate_hann_window_symmetric(size: usize) -> Vec<f32> {
    match size {
        0 | 1 => vec![1.0; size],
        _ => {
            let last_index = (size as f32) - 1.0;
            let two_pi = 2.0 * std::f32::consts::PI;

            let mut window = Vec::with_capacity(size);
            for n in 0..size {
                let phase = (two_pi * n as f32) / last_index;
                window.push(0.5 * (1.0 - phase.cos()));
            }
            window
        }
    }
}
37
/// Builds the element-wise square root of a periodic Hann window.
///
/// Each sample is `sqrt(0.5 * (1 - cos(2π·i / size)))`. Squaring the result
/// gives back the periodic Hann window, which is why this shape suits setups
/// that apply the same window once at analysis and once at synthesis.
///
/// Returns an empty vector when `size` is 0.
pub fn generate_sqrt_hann_window(size: usize) -> Vec<f32> {
    let two_pi = 2.0 * std::f32::consts::PI;
    let period = size as f32;

    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = (two_pi * n as f32) / period;
        let hann = 0.5 * (1.0 - phase.cos());
        window.push(hann.sqrt());
    }
    window
}
49
50pub struct RealFftProcessor {
58 #[allow(dead_code)]
59 pub fft_size: usize,
60 pub spectrum_size: usize,
61 fft_forward: Arc<dyn RealToComplex<f32>>,
62 fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
63 pub time_buffer: Vec<f32>,
64 pub freq_buffer: Vec<Complex<f32>>,
65}
66
67impl RealFftProcessor {
68 pub fn new_forward_only(fft_size: usize) -> Self {
70 let spectrum_size = fft_size / 2 + 1;
71 let mut planner = RealFftPlanner::<f32>::new();
72 let fft_forward = planner.plan_fft_forward(fft_size);
73
74 Self {
75 fft_size,
76 spectrum_size,
77 fft_forward,
78 fft_inverse: None,
79 time_buffer: vec![0.0; fft_size],
80 freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
81 }
82 }
83
84 #[allow(dead_code)]
86 pub fn new_bidirectional(fft_size: usize) -> Self {
87 let spectrum_size = fft_size / 2 + 1;
88 let mut planner = RealFftPlanner::<f32>::new();
89 let fft_forward = planner.plan_fft_forward(fft_size);
90 let fft_inverse = planner.plan_fft_inverse(fft_size);
91
92 Self {
93 fft_size,
94 spectrum_size,
95 fft_forward,
96 fft_inverse: Some(fft_inverse),
97 time_buffer: vec![0.0; fft_size],
98 freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
99 }
100 }
101
102 pub fn forward(&mut self) {
105 self.fft_forward
106 .process(&mut self.time_buffer, &mut self.freq_buffer)
107 .expect("FFT forward failed");
108 }
109
110 #[allow(dead_code)]
113 pub fn inverse(&mut self) {
114 self.fft_inverse
115 .as_ref()
116 .expect("Inverse FFT not available (forward-only processor)")
117 .process(&mut self.freq_buffer, &mut self.time_buffer)
118 .expect("FFT inverse failed");
119 }
120}
121
122pub struct BatchedRealFftProcessor {
138 channels: usize,
139 fft_size: usize,
140 spectrum_size: usize,
141 fft_forward: Arc<dyn RealToComplex<f32>>,
142 fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
143 forward_scratch: Vec<Complex<f32>>,
144 inverse_scratch: Vec<Complex<f32>>,
145 time_buffers: Vec<f32>,
146 freq_buffers: Vec<Complex<f32>>,
147}
148
149impl BatchedRealFftProcessor {
150 pub fn new_forward_only(channels: usize, fft_size: usize) -> Self {
152 Self::new(channels, fft_size, false)
153 }
154
155 pub fn new_bidirectional(channels: usize, fft_size: usize) -> Self {
157 Self::new(channels, fft_size, true)
158 }
159
160 fn new(channels: usize, fft_size: usize, include_inverse: bool) -> Self {
161 assert!(
162 channels > 0,
163 "BatchedRealFftProcessor requires at least one channel"
164 );
165
166 let spectrum_size = fft_size / 2 + 1;
167 let mut planner = RealFftPlanner::<f32>::new();
168 let fft_forward = planner.plan_fft_forward(fft_size);
169 let fft_inverse = if include_inverse {
170 Some(planner.plan_fft_inverse(fft_size))
171 } else {
172 None
173 };
174
175 let forward_scratch = vec![Complex::new(0.0, 0.0); fft_forward.get_scratch_len()];
176 let inverse_scratch = fft_inverse
177 .as_ref()
178 .map(|fft| vec![Complex::new(0.0, 0.0); fft.get_scratch_len()])
179 .unwrap_or_default();
180
181 Self {
182 channels,
183 fft_size,
184 spectrum_size,
185 fft_forward,
186 fft_inverse,
187 forward_scratch,
188 inverse_scratch,
189 time_buffers: vec![0.0; channels * fft_size],
190 freq_buffers: vec![Complex::new(0.0, 0.0); channels * spectrum_size],
191 }
192 }
193
194 pub fn channels(&self) -> usize {
195 self.channels
196 }
197
198 pub fn fft_size(&self) -> usize {
199 self.fft_size
200 }
201
202 pub fn spectrum_size(&self) -> usize {
203 self.spectrum_size
204 }
205
206 pub fn time_buffers(&self) -> &[f32] {
207 &self.time_buffers
208 }
209
210 pub fn time_buffers_mut(&mut self) -> &mut [f32] {
211 &mut self.time_buffers
212 }
213
214 pub fn freq_buffers(&self) -> &[Complex<f32>] {
215 &self.freq_buffers
216 }
217
218 pub fn freq_buffers_mut(&mut self) -> &mut [Complex<f32>] {
219 &mut self.freq_buffers
220 }
221
222 pub fn time_channel(&self, ch: usize) -> &[f32] {
223 debug_assert!(ch < self.channels);
224 let range = self.time_range(ch);
225 &self.time_buffers[range]
226 }
227
228 pub fn time_channel_mut(&mut self, ch: usize) -> &mut [f32] {
229 debug_assert!(ch < self.channels);
230 let range = self.time_range(ch);
231 &mut self.time_buffers[range]
232 }
233
234 pub fn freq_channel(&self, ch: usize) -> &[Complex<f32>] {
235 debug_assert!(ch < self.channels);
236 let range = self.freq_range(ch);
237 &self.freq_buffers[range]
238 }
239
240 pub fn freq_channel_mut(&mut self, ch: usize) -> &mut [Complex<f32>] {
241 debug_assert!(ch < self.channels);
242 let range = self.freq_range(ch);
243 &mut self.freq_buffers[range]
244 }
245
246 pub fn forward_all(&mut self) {
249 for ch in 0..self.channels {
250 let time_range = self.time_range(ch);
251 let freq_range = self.freq_range(ch);
252 self.fft_forward
253 .process_with_scratch(
254 &mut self.time_buffers[time_range],
255 &mut self.freq_buffers[freq_range],
256 &mut self.forward_scratch,
257 )
258 .expect("FFT forward failed");
259 }
260 }
261
262 pub fn inverse_all(&mut self) {
265 let fft_inverse = self
266 .fft_inverse
267 .as_ref()
268 .expect("Inverse FFT not available (forward-only processor)");
269
270 for ch in 0..self.channels {
271 let time_range = self.time_range(ch);
272 let freq_range = self.freq_range(ch);
273 fft_inverse
274 .process_with_scratch(
275 &mut self.freq_buffers[freq_range],
276 &mut self.time_buffers[time_range],
277 &mut self.inverse_scratch,
278 )
279 .expect("FFT inverse failed");
280 }
281 }
282
283 fn time_range(&self, ch: usize) -> std::ops::Range<usize> {
284 ch * self.fft_size..(ch + 1) * self.fft_size
285 }
286
287 fn freq_range(&self, ch: usize) -> std::ops::Range<usize> {
288 ch * self.spectrum_size..(ch + 1) * self.spectrum_size
289 }
290}
291
/// Circular sample buffer that reports when a new analysis frame is due.
///
/// Samples are pushed one at a time. Once `window_size` samples have been
/// collected the accumulator counts as filled; from then on `push` returns
/// `true` whenever at least `hop_size` samples have arrived since the previous
/// trigger (or since the start, for the first one).
pub struct RingAccumulator {
    /// Backing storage, `window_size` samples long.
    buffer: Vec<f32>,
    /// Index the next sample is written to; also the index of the oldest
    /// sample once the buffer has wrapped.
    write_pos: usize,
    /// Samples pushed since the last trigger (or since construction/reset).
    samples_since_trigger: usize,
    /// Set once `window_size` samples have been pushed; cleared by `reset`.
    filled: bool,
    window_size: usize,
    hop_size: usize,
}

impl RingAccumulator {
    /// Creates an empty accumulator holding `window_size` samples and
    /// triggering every `hop_size` samples once filled.
    pub fn new(window_size: usize, hop_size: usize) -> Self {
        let buffer = vec![0.0; window_size];
        Self {
            buffer,
            write_pos: 0,
            samples_since_trigger: 0,
            filled: false,
            window_size,
            hop_size,
        }
    }

    /// Stores `sample` and returns `true` when a full window is available and
    /// a hop has elapsed, i.e. when the caller should read and process a frame.
    ///
    /// # Panics
    ///
    /// Panics if the accumulator was created with `window_size == 0`.
    pub fn push(&mut self, sample: f32) -> bool {
        self.buffer[self.write_pos] = sample;

        // Advance the write cursor, wrapping at the end of the buffer.
        self.write_pos += 1;
        if self.write_pos == self.window_size {
            self.write_pos = 0;
        }
        self.samples_since_trigger += 1;

        // The counter is only cleared after a trigger, and triggers require
        // `filled`, so before the first fill it counts samples since start.
        self.filled = self.filled || self.samples_since_trigger >= self.window_size;

        let trigger = self.filled && self.samples_since_trigger >= self.hop_size;
        if trigger {
            self.samples_since_trigger = 0;
        }
        trigger
    }

    /// Copies the buffered window into `dest[..window_size]` in chronological
    /// order, oldest sample first.
    ///
    /// # Panics
    ///
    /// Panics if `dest` is shorter than `window_size`.
    pub fn read_window(&self, dest: &mut [f32]) {
        debug_assert!(dest.len() >= self.window_size);

        // The slot about to be overwritten holds the oldest sample, so the
        // tail starting at `write_pos` comes first, then the wrapped head.
        let (head, tail) = self.buffer.split_at(self.write_pos);
        let tail_len = tail.len();
        dest[..tail_len].copy_from_slice(tail);
        dest[tail_len..self.window_size].copy_from_slice(head);
    }

    /// Clears the buffer and returns to the initial, unfilled state.
    pub fn reset(&mut self) {
        for slot in self.buffer.iter_mut() {
            *slot = 0.0;
        }
        self.filled = false;
        self.samples_since_trigger = 0;
        self.write_pos = 0;
    }
}
359
360pub struct DualWindowStft {
375 analysis_window: Vec<f32>,
376 synthesis_window: Vec<f32>,
377 analysis_size: usize,
378 input_ring: RingAccumulator,
380 output_accum: Vec<f32>,
382 output_read_pos: usize,
383 fft: RealFftProcessor,
385 window_buf: Vec<f32>,
387}
388
389pub fn design_dual_windows(
399 analysis_size: usize,
400 synthesis_size: usize,
401 hop_size: usize,
402) -> (Vec<f32>, Vec<f32>) {
403 let w_a = generate_hann_window(analysis_size);
405
406 let offset = (analysis_size - synthesis_size) / 2;
409
410 let w_s_raw = generate_hann_window(synthesis_size);
412
413 let num_overlaps = analysis_size.div_ceil(hop_size);
417
418 let mut cola_sum = vec![0.0f32; hop_size];
419 for k in 0..num_overlaps {
420 let shift = k * hop_size;
421 for (n, cola_val) in cola_sum.iter_mut().enumerate() {
422 let ana_idx = n + shift;
423 if ana_idx < analysis_size {
424 let syn_idx = ana_idx.wrapping_sub(offset);
426 if syn_idx < synthesis_size {
427 *cola_val += w_a[ana_idx] * w_s_raw[syn_idx];
428 }
429 }
430 }
431 }
432
433 let avg_cola: f32 = cola_sum.iter().sum::<f32>() / cola_sum.len() as f32;
435 let norm_factor = if avg_cola > 1e-10 {
436 1.0 / avg_cola
437 } else {
438 1.0
439 };
440
441 let mut w_s = vec![0.0f32; analysis_size];
442 for i in 0..synthesis_size {
443 w_s[offset + i] = w_s_raw[i] * norm_factor;
444 }
445
446 (w_a, w_s)
447}
448
449impl DualWindowStft {
450 pub fn new(analysis_size: usize, synthesis_size: usize, hop_size: usize) -> Self {
457 let (analysis_window, synthesis_window) =
458 design_dual_windows(analysis_size, synthesis_size, hop_size);
459
460 let fft = RealFftProcessor::new_bidirectional(analysis_size);
461
462 Self {
463 analysis_window,
464 synthesis_window,
465 analysis_size,
466 input_ring: RingAccumulator::new(analysis_size, hop_size),
467 output_accum: vec![0.0; analysis_size * 3],
468 output_read_pos: 0,
469 fft,
470 window_buf: vec![0.0; analysis_size],
471 }
472 }
473
474 pub fn analyze(&mut self, sample: f32) -> bool {
479 if !self.input_ring.push(sample) {
480 return false;
481 }
482
483 self.input_ring.read_window(&mut self.window_buf);
485
486 for i in 0..self.analysis_size {
488 self.fft.time_buffer[i] = self.window_buf[i] * self.analysis_window[i];
489 }
490
491 self.fft.forward();
493
494 true
495 }
496
497 pub fn freq_buffer_mut(&mut self) -> &mut [Complex<f32>] {
499 &mut self.fft.freq_buffer
500 }
501
502 pub fn synthesize_in_place(&mut self) {
516 self.fft.inverse();
518
519 let scale = 1.0 / self.analysis_size as f32;
522 for i in 0..self.analysis_size {
523 let pos = (self.output_read_pos + i) % self.output_accum.len();
524 self.output_accum[pos] += self.fft.time_buffer[i] * self.synthesis_window[i] * scale;
525 }
526 }
527
528 pub fn read_output(&mut self) -> f32 {
530 let sample = self.output_accum[self.output_read_pos];
531 self.output_accum[self.output_read_pos] = 0.0;
532 self.output_read_pos = (self.output_read_pos + 1) % self.output_accum.len();
533 sample
534 }
535
536 pub fn process_block<F>(&mut self, input: &[f32], output: &mut [f32], mut process_fn: F)
543 where
544 F: FnMut(&mut [Complex<f32>]),
545 {
546 for (i, &sample) in input.iter().enumerate() {
547 if self.analyze(sample) {
548 process_fn(&mut self.fft.freq_buffer);
549 self.synthesize_in_place();
550 }
551 output[i] = self.read_output();
552 }
553 }
554
555 pub fn latency_samples(&self) -> usize {
561 self.analysis_size
562 }
563
564 pub fn reset(&mut self) {
566 self.input_ring.reset();
567 self.output_accum.fill(0.0);
568 self.output_read_pos = 0;
569 }
570}
571
#[cfg(test)]
// The tests index several slices with one shared loop variable, which is the
// pattern this clippy lint flags.
#[allow(clippy::needless_range_loop)]
mod tests {
    //! Unit tests for the window generators, `RealFftProcessor`,
    //! `RingAccumulator`, `design_dual_windows` and `DualWindowStft`.

    use super::*;

    /// The periodic Hann window has the requested length, starts at 0, peaks
    /// at the centre index and satisfies `w[i] == w[size - i]`.
    #[test]
    fn test_hann_window_size_and_symmetry() {
        let window = generate_hann_window(8);
        assert_eq!(window.len(), 8);

        // First sample is 0, centre sample (index size / 2) is 1.
        assert!((window[0] - 0.0).abs() < 0.01);
        assert!((window[4] - 1.0).abs() < 0.01);

        // Periodic symmetry: index i mirrors index size - i.
        for i in 1..4 {
            assert!(
                (window[i] - window[8 - i]).abs() < 1e-6,
                "Window not symmetric at i={}: {} vs {}",
                i,
                window[i],
                window[8 - i]
            );
        }
    }

    /// Squaring the sqrt-Hann window gives a Hann window, so at a hop of half
    /// the window the squared samples must sum to 1.
    #[test]
    fn test_sqrt_hann_cola_property() {
        let n = 256;
        let sqrt_window = generate_sqrt_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            // Square each sqrt-Hann sample to recover the Hann value.
            let hann_i = sqrt_window[i] * sqrt_window[i];
            let hann_shifted = sqrt_window[i + hop] * sqrt_window[i + hop];
            let sum = hann_i + hann_shifted;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "sqrt(Hann) COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    /// The periodic Hann window overlap-adds to exactly 1 at a hop of half
    /// the window.
    #[test]
    fn test_hann_window_cola_property() {
        let n = 256;
        let window = generate_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            let sum = window[i] + window[i + hop];
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    /// The symmetric variant is zero at both ends and close to 1 near the
    /// middle.
    #[test]
    fn test_symmetric_hann_endpoints_are_zero() {
        let window = generate_hann_window_symmetric(256);
        assert!(window[0].abs() < 1e-7, "First sample should be 0");
        assert!(window[255].abs() < 1e-7, "Last sample should be 0");
        assert!((window[128] - 1.0).abs() < 0.01);
    }

    /// Sizes 0, 1 and 2 must not divide by zero or produce NaN.
    #[test]
    fn test_symmetric_hann_no_nan_for_small_sizes() {
        // size 0: empty window.
        let w0 = generate_hann_window_symmetric(0);
        assert!(w0.is_empty());

        // size 1: special-cased to a single 1.0 (size - 1 would be 0).
        let w1 = generate_hann_window_symmetric(1);
        assert_eq!(w1.len(), 1);
        assert!(w1[0].is_finite(), "size=1 produced non-finite: {}", w1[0]);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        // size 2: smallest size that goes through the general formula.
        let w2 = generate_hann_window_symmetric(2);
        assert_eq!(w2.len(), 2);
        assert!(w2[0].is_finite());
        assert!(w2[1].is_finite());
    }

    /// Forward followed by inverse reproduces the input once the result is
    /// divided by the FFT size.
    #[test]
    fn test_fft_roundtrip() {
        let fft_size = 256;
        let mut fft = RealFftProcessor::new_bidirectional(fft_size);

        // Sine with exactly 10 cycles per frame.
        let original: Vec<f32> = (0..fft_size)
            .map(|i| (2.0 * std::f32::consts::PI * 10.0 * i as f32 / fft_size as f32).sin())
            .collect();
        fft.time_buffer.copy_from_slice(&original);

        fft.forward();
        fft.inverse();

        // The inverse transform is unnormalized, hence the 1 / N scale.
        let scale = 1.0 / fft_size as f32;
        for i in 0..fft_size {
            let recovered = fft.time_buffer[i] * scale;
            assert!(
                (recovered - original[i]).abs() < 1e-4,
                "FFT roundtrip mismatch at i={}: expected {}, got {}",
                i,
                original[i],
                recovered,
            );
        }
    }

    /// The first trigger fires when the window is full (8th sample, index 7),
    /// later triggers follow every `hop_size` samples.
    #[test]
    fn test_ring_accumulator_trigger_timing() {
        let window_size = 8;
        let hop_size = 4;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        let mut triggers = Vec::new();
        for i in 0..24 {
            if ring.push(i as f32) {
                triggers.push(i);
            }
        }

        assert_eq!(triggers, vec![7, 11, 15, 19, 23]);
    }

    /// `read_window` returns the most recent `window_size` samples, oldest
    /// first, even after the ring has wrapped.
    #[test]
    fn test_ring_accumulator_window_readout() {
        let window_size = 4;
        let hop_size = 2;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        // Push 0..=5; the ring keeps the last four samples: 2, 3, 4, 5.
        for i in 0..6 {
            ring.push(i as f32);
        }

        let mut dest = vec![0.0; 4];
        ring.read_window(&mut dest);
        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    /// `reset` clears the fill state and counters, so no trigger can fire
    /// until a whole window has been pushed again.
    #[test]
    fn test_ring_accumulator_reset() {
        let mut ring = RingAccumulator::new(8, 4);

        // Fill the ring and go past the first trigger.
        for i in 0..12 {
            ring.push(i as f32);
        }
        assert!(ring.filled);

        ring.reset();
        assert!(!ring.filled);
        assert_eq!(ring.write_pos, 0);
        assert_eq!(ring.samples_since_trigger, 0);

        // One hop (4 samples) is less than a window (8), so nothing may fire.
        let mut triggered = false;
        for _ in 0..4 {
            triggered |= ring.push(1.0);
        }
        assert!(!triggered, "Should not trigger before ring is filled again");
    }

    /// Both designed windows have the analysis length, and the synthesis
    /// window is zero outside its centred support.
    #[test]
    fn test_dual_window_design() {
        let analysis_size = 1024;
        let synthesis_size = 256;
        let hop_size = 128;

        let (w_a, w_s) = design_dual_windows(analysis_size, synthesis_size, hop_size);
        assert_eq!(w_a.len(), analysis_size);
        assert_eq!(w_s.len(), analysis_size);

        // The synthesis support starts at this offset and spans synthesis_size.
        let offset = (analysis_size - synthesis_size) / 2;
        for i in 0..offset {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero before offset");
        }
        for i in (offset + synthesis_size)..analysis_size {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero after support");
        }
    }

    /// Passing a sine through the STFT with an identity spectral callback
    /// yields output that stays close to the delayed input.
    ///
    /// NOTE(review): `rms_error` is the mean squared error (no square root is
    /// taken) and the bound of 1.0 is loose for a unit-amplitude sine, so this
    /// is mainly a smoke test.
    #[test]
    fn test_dual_window_stft_passthrough() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        let num_samples = 4096;
        // 440 Hz sine, assuming a 48 kHz sample rate.
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();

        let mut output = vec![0.0f32; num_samples];

        stft.process_block(&signal, &mut output, |_spectrum| {
            // Identity: leave the spectrum untouched.
        });

        let latency = stft.latency_samples();
        // Skip 512 samples after the reported latency and the last 512 samples.
        let check_start = latency + 512;
        let check_end = num_samples - 512;

        if check_end > check_start {
            let rms_error: f32 = output[check_start..check_end]
                .iter()
                .zip(&signal[check_start - latency..check_end - latency])
                .map(|(o, s)| (o - s).powi(2))
                .sum::<f32>()
                / (check_end - check_start) as f32;

            assert!(
                rms_error < 1.0,
                "Dual-window STFT passthrough RMS error too high: {rms_error:.6}"
            );
        }
    }

    /// A constant input must come out at the same level, which checks the
    /// inverse-FFT scale together with the synthesis-window normalization.
    /// With a constant signal the comparison does not depend on the exact
    /// latency.
    #[test]
    fn test_dual_window_stft_roundtrip_unity_gain() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        let num_samples = 6144;
        // DC input at 0.5.
        let signal = vec![0.5_f32; num_samples];
        let mut output = vec![0.0_f32; num_samples];

        stft.process_block(&signal, &mut output, |_spectrum| {});

        let latency = stft.latency_samples();
        // Skip the start-up region and the tail of the block.
        let check_start = latency + 2 * analysis_size;
        let check_end = num_samples - analysis_size;

        if check_end > check_start {
            // Mean squared error over the checked region (no square root).
            let rms_error: f32 = output[check_start..check_end]
                .iter()
                .zip(&signal[check_start - latency..check_end - latency])
                .map(|(o, s)| (o - s).powi(2))
                .sum::<f32>()
                / (check_end - check_start) as f32;

            assert!(
                rms_error < 1e-4,
                "DualWindowStft round-trip RMS error too high ({rms_error:.6}); \
                 IFFT scale or synthesis-window normalization may be wrong"
            );
        }
    }

    /// The reported latency equals the analysis window length.
    #[test]
    fn test_dual_window_stft_latency_reports_analysis_fill_delay() {
        let stft = DualWindowStft::new(512, 128, 64);
        assert_eq!(stft.latency_samples(), 512);
    }

    /// After `reset`, feeding silence must produce (near) silence: no audio
    /// from before the reset may leak into the output.
    #[test]
    fn test_dual_window_stft_reset() {
        let mut stft = DualWindowStft::new(512, 128, 64);

        // Run some signal through so the internal buffers hold non-zero data.
        let signal: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0; 2048];
        stft.process_block(&signal, &mut output, |_| {});

        stft.reset();

        // Silence in, silence out.
        let silence = vec![0.0f32; 1024];
        let mut output2 = vec![0.0; 1024];
        stft.process_block(&silence, &mut output2, |_| {});

        let max_output: f32 = output2.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_output < 0.01,
            "After reset + silence, max output should be ~0, got {max_output}"
        );
    }
}
892
#[cfg(test)]
mod batched_real_fft_processor_tests {
    //! Tests for `BatchedRealFftProcessor`: agreement with the single-channel
    //! processor, round-trip behaviour and buffer layout.

    use super::*;

    /// Absolute tolerance for all floating-point comparisons in this module.
    const EPSILON: f32 = 1e-3;

    /// Fills `buffer` with a deterministic two-component test signal whose
    /// phase depends on the channel index, so every channel gets different
    /// data.
    fn fill_signal(buffer: &mut [f32], ch: usize) {
        for (i, sample) in buffer.iter_mut().enumerate() {
            let phase = i as f32 * 0.13 + ch as f32 * 0.37;
            *sample = phase.sin() + 0.25 * (phase * 2.7).cos();
        }
    }

    /// Asserts that real and imaginary parts each agree within `EPSILON`.
    fn assert_complex_close(actual: Complex<f32>, expected: Complex<f32>) {
        assert!((actual.re - expected.re).abs() <= EPSILON);
        assert!((actual.im - expected.im).abs() <= EPSILON);
    }

    /// Asserts equal length and element-wise agreement within `EPSILON`.
    fn assert_slice_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (actual, expected) in actual.iter().zip(expected) {
            assert!((actual - expected).abs() <= EPSILON);
        }
    }

    /// For several channel counts, each channel of the batched forward
    /// transform matches a standalone `RealFftProcessor` fed the same input.
    #[test]
    fn forward_matches_independent_processors_for_representative_channel_counts() {
        for channels in [1, 2, 8, 16, 24] {
            let fft_size = 64;
            let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);

            for ch in 0..channels {
                fill_signal(batched.time_channel_mut(ch), ch);
            }

            // Snapshot the inputs before transforming: the time buffers are
            // passed to the FFT mutably and are not read again afterwards.
            let inputs = batched.time_buffers().to_vec();
            batched.forward_all();

            for ch in 0..channels {
                let mut independent = RealFftProcessor::new_forward_only(fft_size);
                independent
                    .time_buffer
                    .copy_from_slice(&inputs[ch * fft_size..(ch + 1) * fft_size]);
                independent.forward();

                for (actual, expected) in batched
                    .freq_channel(ch)
                    .iter()
                    .zip(&independent.freq_buffer)
                {
                    assert_complex_close(*actual, *expected);
                }
            }
        }
    }

    /// Forward followed by inverse returns every channel scaled by `fft_size`
    /// (the inverse transform is unnormalized).
    #[test]
    fn bidirectional_round_trip_restores_each_channel_after_scaling() {
        let channels = 8;
        let fft_size = 128;
        let mut batched = BatchedRealFftProcessor::new_bidirectional(channels, fft_size);

        for ch in 0..channels {
            fill_signal(batched.time_channel_mut(ch), ch);
        }

        let original = batched.time_buffers().to_vec();
        batched.forward_all();
        batched.inverse_all();

        for ch in 0..channels {
            // Expected output is the original input times fft_size.
            let mut expected = original[ch * fft_size..(ch + 1) * fft_size].to_vec();
            for sample in &mut expected {
                *sample *= fft_size as f32;
            }
            assert_slice_close(batched.time_channel(ch), &expected);
        }
    }

    /// The per-channel accessors address contiguous, channel-major slices of
    /// the flat buffers.
    #[test]
    fn channel_slices_use_flat_channel_major_layout() {
        let channels = 3;
        let fft_size = 4;
        let spectrum_size = fft_size / 2 + 1;
        let mut batched = BatchedRealFftProcessor::new_forward_only(channels, fft_size);

        // Tag every element with its channel (tens) and its index (units).
        for ch in 0..channels {
            for (i, sample) in batched.time_channel_mut(ch).iter_mut().enumerate() {
                *sample = (ch * 10 + i) as f32;
            }
            for (i, bin) in batched.freq_channel_mut(ch).iter_mut().enumerate() {
                *bin = Complex::new((ch * 10 + i) as f32, ch as f32);
            }
        }

        assert_eq!(
            batched.time_buffers(),
            &[
                0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0, 20.0, 21.0, 22.0, 23.0
            ]
        );
        assert_eq!(batched.freq_buffers().len(), channels * spectrum_size);
        assert_eq!(batched.freq_channel(2)[1], Complex::new(21.0, 2.0));
    }
}