1use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
11use rustfft::num_complex::Complex;
12use std::sync::Arc;
13
/// Builds a periodic Hann window of `size` samples.
///
/// Sample `i` is `0.5 * (1 - cos(2π·i / size))`. Dividing by `size` (rather
/// than `size - 1`) means the closing zero would land on index `size`, one
/// past the end, so for an even `size` copies shifted by `size / 2` sum to one.
///
/// Returns an empty vector when `size` is 0.
pub fn generate_hann_window(size: usize) -> Vec<f32> {
    let period = size as f32;
    let mut window = Vec::with_capacity(size);
    for i in 0..size {
        let phase = (2.0 * std::f32::consts::PI * i as f32) / period;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
25
/// Builds a symmetric Hann window of `size` samples.
///
/// Sample `i` is `0.5 * (1 - cos(2π·i / (size - 1)))`, so both the first and
/// the last sample are zero.
///
/// Sizes 0 and 1 are special-cased because the `size - 1` denominator is not
/// usable there: the result is `size` copies of `1.0` (empty for 0).
pub fn generate_hann_window_symmetric(size: usize) -> Vec<f32> {
    if size < 2 {
        return vec![1.0; size];
    }

    let last_index = (size as f32) - 1.0;
    let mut window = vec![0.0f32; size];
    for (i, slot) in window.iter_mut().enumerate() {
        let phase = (2.0 * std::f32::consts::PI * i as f32) / last_index;
        *slot = 0.5 * (1.0 - phase.cos());
    }
    window
}
37
/// Builds the square root of a periodic Hann window of `size` samples.
///
/// Each sample is `sqrt(0.5 * (1 - cos(2π·i / size)))`. Squaring the result
/// gives back the periodic Hann window, so applying this window twice (once
/// before and once after a transform) weights the signal by a plain Hann.
///
/// Returns an empty vector when `size` is 0.
pub fn generate_sqrt_hann_window(size: usize) -> Vec<f32> {
    let period = size as f32;
    let mut window = vec![0.0f32; size];
    for (i, slot) in window.iter_mut().enumerate() {
        let phase = (2.0 * std::f32::consts::PI * i as f32) / period;
        let hann = 0.5 * (1.0 - phase.cos());
        *slot = hann.sqrt();
    }
    window
}
49
50pub struct RealFftProcessor {
58 #[allow(dead_code)]
59 pub fft_size: usize,
60 pub spectrum_size: usize,
61 fft_forward: Arc<dyn RealToComplex<f32>>,
62 fft_inverse: Option<Arc<dyn ComplexToReal<f32>>>,
63 pub time_buffer: Vec<f32>,
64 pub freq_buffer: Vec<Complex<f32>>,
65}
66
67impl RealFftProcessor {
68 pub fn new_forward_only(fft_size: usize) -> Self {
70 let spectrum_size = fft_size / 2 + 1;
71 let mut planner = RealFftPlanner::<f32>::new();
72 let fft_forward = planner.plan_fft_forward(fft_size);
73
74 Self {
75 fft_size,
76 spectrum_size,
77 fft_forward,
78 fft_inverse: None,
79 time_buffer: vec![0.0; fft_size],
80 freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
81 }
82 }
83
84 #[allow(dead_code)]
86 pub fn new_bidirectional(fft_size: usize) -> Self {
87 let spectrum_size = fft_size / 2 + 1;
88 let mut planner = RealFftPlanner::<f32>::new();
89 let fft_forward = planner.plan_fft_forward(fft_size);
90 let fft_inverse = planner.plan_fft_inverse(fft_size);
91
92 Self {
93 fft_size,
94 spectrum_size,
95 fft_forward,
96 fft_inverse: Some(fft_inverse),
97 time_buffer: vec![0.0; fft_size],
98 freq_buffer: vec![Complex::new(0.0, 0.0); spectrum_size],
99 }
100 }
101
102 pub fn forward(&mut self) {
105 self.fft_forward
106 .process(&mut self.time_buffer, &mut self.freq_buffer)
107 .expect("FFT forward failed");
108 }
109
110 #[allow(dead_code)]
113 pub fn inverse(&mut self) {
114 self.fft_inverse
115 .as_ref()
116 .expect("Inverse FFT not available (forward-only processor)")
117 .process(&mut self.freq_buffer, &mut self.time_buffer)
118 .expect("FFT inverse failed");
119 }
120}
121
/// Fixed-size circular sample buffer that signals when a new analysis frame
/// is due.
///
/// Samples are written one at a time with [`push`](Self::push). Once
/// `window_size` samples have arrived the buffer counts as filled, and from
/// then on `push` returns `true` every time at least `hop_size` samples have
/// been written since the previous trigger.
pub struct RingAccumulator {
    /// Circular storage, `window_size` long.
    buffer: Vec<f32>,
    /// Index the next sample is written to; also the index of the oldest
    /// stored sample.
    write_pos: usize,
    /// Samples written since the last trigger (or since construction/reset).
    samples_since_trigger: usize,
    /// Latches to `true` once `window_size` samples have been written.
    filled: bool,
    window_size: usize,
    hop_size: usize,
}

impl RingAccumulator {
    /// Creates an empty, zero-filled accumulator.
    pub fn new(window_size: usize, hop_size: usize) -> Self {
        let buffer = vec![0.0; window_size];
        Self {
            buffer,
            write_pos: 0,
            samples_since_trigger: 0,
            filled: false,
            window_size,
            hop_size,
        }
    }

    /// Stores `sample`, overwriting the oldest one, and reports whether a
    /// frame should be processed now.
    ///
    /// Returns `true` when the buffer has been filled at least once and
    /// `hop_size` or more samples have arrived since the last `true`.
    pub fn push(&mut self, sample: f32) -> bool {
        self.buffer[self.write_pos] = sample;
        self.write_pos = (self.write_pos + 1) % self.window_size;
        self.samples_since_trigger += 1;

        // The flag only ever goes from false to true here; `reset` clears it.
        self.filled = self.filled || self.samples_since_trigger >= self.window_size;

        let frame_due = self.filled && self.samples_since_trigger >= self.hop_size;
        if frame_due {
            self.samples_since_trigger = 0;
        }
        frame_due
    }

    /// Copies the stored samples into `dest[..window_size]`, oldest first.
    ///
    /// `dest` must hold at least `window_size` elements.
    pub fn read_window(&self, dest: &mut [f32]) {
        debug_assert!(dest.len() >= self.window_size);

        // `write_pos` splits the ring: everything from it to the end is the
        // older half, everything before it is the newer half.
        let (newer, older) = self.buffer.split_at(self.write_pos);
        let older_len = older.len();
        dest[..older_len].copy_from_slice(older);
        if !newer.is_empty() {
            dest[older_len..self.window_size].copy_from_slice(newer);
        }
    }

    /// Zeroes the buffer and returns to the initial, unfilled state.
    pub fn reset(&mut self) {
        self.buffer.fill(0.0);
        self.write_pos = 0;
        self.samples_since_trigger = 0;
        self.filled = false;
    }
}
189
190pub struct DualWindowStft {
204 analysis_window: Vec<f32>,
205 synthesis_window: Vec<f32>,
206 analysis_size: usize,
207 input_ring: RingAccumulator,
209 output_accum: Vec<f32>,
211 output_read_pos: usize,
212 fft: RealFftProcessor,
214 window_buf: Vec<f32>,
216 #[allow(dead_code)]
218 cola_norm: Vec<f32>,
219}
220
221pub fn design_dual_windows(
231 analysis_size: usize,
232 synthesis_size: usize,
233 hop_size: usize,
234) -> (Vec<f32>, Vec<f32>) {
235 let w_a = generate_hann_window(analysis_size);
237
238 let offset = (analysis_size - synthesis_size) / 2;
241
242 let w_s_raw = generate_hann_window(synthesis_size);
244
245 let num_overlaps = analysis_size.div_ceil(hop_size);
249
250 let mut cola_sum = vec![0.0f32; hop_size];
251 for k in 0..num_overlaps {
252 let shift = k * hop_size;
253 for (n, cola_val) in cola_sum.iter_mut().enumerate() {
254 let ana_idx = n + shift;
255 if ana_idx < analysis_size {
256 let syn_idx = ana_idx.wrapping_sub(offset);
258 if syn_idx < synthesis_size {
259 *cola_val += w_a[ana_idx] * w_s_raw[syn_idx];
260 }
261 }
262 }
263 }
264
265 let avg_cola: f32 = cola_sum.iter().sum::<f32>() / cola_sum.len() as f32;
267 let norm_factor = if avg_cola > 1e-10 {
268 1.0 / avg_cola
269 } else {
270 1.0
271 };
272
273 let mut w_s = vec![0.0f32; analysis_size];
274 for i in 0..synthesis_size {
275 w_s[offset + i] = w_s_raw[i] * norm_factor;
276 }
277
278 (w_a, w_s)
279}
280
281impl DualWindowStft {
282 pub fn new(analysis_size: usize, synthesis_size: usize, hop_size: usize) -> Self {
289 let (analysis_window, synthesis_window) =
290 design_dual_windows(analysis_size, synthesis_size, hop_size);
291
292 let fft = RealFftProcessor::new_bidirectional(analysis_size);
293
294 Self {
295 analysis_window,
296 synthesis_window,
297 analysis_size,
298 input_ring: RingAccumulator::new(analysis_size, hop_size),
299 output_accum: vec![0.0; analysis_size * 3],
300 output_read_pos: 0,
301 fft,
302 window_buf: vec![0.0; analysis_size],
303 cola_norm: vec![1.0; analysis_size],
304 }
305 }
306
307 pub fn analyze(&mut self, sample: f32) -> bool {
312 if !self.input_ring.push(sample) {
313 return false;
314 }
315
316 self.input_ring.read_window(&mut self.window_buf);
318
319 for i in 0..self.analysis_size {
321 self.fft.time_buffer[i] = self.window_buf[i] * self.analysis_window[i];
322 }
323
324 self.fft.forward();
326
327 true
328 }
329
330 pub fn freq_buffer_mut(&mut self) -> &mut [Complex<f32>] {
332 &mut self.fft.freq_buffer
333 }
334
335 pub fn synthesize_in_place(&mut self) {
341 self.fft.inverse();
343
344 let scale = 1.0 / self.analysis_size as f32;
346 for i in 0..self.analysis_size {
347 let pos = (self.output_read_pos + i) % self.output_accum.len();
348 self.output_accum[pos] += self.fft.time_buffer[i] * self.synthesis_window[i] * scale;
349 }
350 }
351
352 pub fn read_output(&mut self) -> f32 {
354 let sample = self.output_accum[self.output_read_pos];
355 self.output_accum[self.output_read_pos] = 0.0;
356 self.output_read_pos = (self.output_read_pos + 1) % self.output_accum.len();
357 sample
358 }
359
360 pub fn process_block<F>(&mut self, input: &[f32], output: &mut [f32], mut process_fn: F)
367 where
368 F: FnMut(&mut [Complex<f32>]),
369 {
370 for (i, &sample) in input.iter().enumerate() {
371 if self.analyze(sample) {
372 process_fn(&mut self.fft.freq_buffer);
373 self.synthesize_in_place();
374 }
375 output[i] = self.read_output();
376 }
377 }
378
379 pub fn latency_samples(&self) -> usize {
381 self.analysis_size
382 }
383
384 pub fn reset(&mut self) {
386 self.input_ring.reset();
387 self.output_accum.fill(0.0);
388 self.output_read_pos = 0;
389 }
390}
391
/// Unit tests for the window generators, the FFT wrapper, the ring
/// accumulator and the dual-window STFT.
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;

    /// Periodic Hann: right length, zero at the start, peak in the middle,
    /// mirror-symmetric about the peak.
    #[test]
    fn test_hann_window_size_and_symmetry() {
        let window = generate_hann_window(8);
        assert_eq!(window.len(), 8);

        assert!((window[0] - 0.0).abs() < 0.01);
        assert!((window[4] - 1.0).abs() < 0.01);

        for i in 1..4 {
            let (left, right) = (window[i], window[8 - i]);
            assert!(
                (left - right).abs() < 1e-6,
                "Window not symmetric at i={}: {} vs {}",
                i,
                left,
                right
            );
        }
    }

    /// Squared sqrt-Hann windows at 50% overlap sum to one.
    #[test]
    fn test_sqrt_hann_cola_property() {
        let n = 256;
        let sqrt_window = generate_sqrt_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            let first = sqrt_window[i] * sqrt_window[i];
            let second = sqrt_window[i + hop] * sqrt_window[i + hop];
            let sum = first + second;
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "sqrt(Hann) COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    /// Periodic Hann windows at 50% overlap sum to one.
    #[test]
    fn test_hann_window_cola_property() {
        let n = 256;
        let window = generate_hann_window(n);
        let hop = n / 2;

        for i in 0..hop {
            let sum = window[i] + window[i + hop];
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "COLA violated at i={}: sum={}, expected 1.0",
                i,
                sum
            );
        }
    }

    /// Symmetric Hann is zero at both ends and close to one in the middle.
    #[test]
    fn test_symmetric_hann_endpoints_are_zero() {
        let window = generate_hann_window_symmetric(256);
        assert!(window[0].abs() < 1e-7, "First sample should be 0");
        assert!(window[255].abs() < 1e-7, "Last sample should be 0");
        assert!((window[128] - 1.0).abs() < 0.01);
    }

    /// Sizes 0, 1 and 2 must not divide by zero or produce NaN.
    #[test]
    fn test_symmetric_hann_no_nan_for_small_sizes() {
        let w0 = generate_hann_window_symmetric(0);
        assert!(w0.is_empty());

        let w1 = generate_hann_window_symmetric(1);
        assert_eq!(w1.len(), 1);
        assert!(w1[0].is_finite(), "size=1 produced non-finite: {}", w1[0]);
        assert!((w1[0] - 1.0).abs() < 1e-6);

        let w2 = generate_hann_window_symmetric(2);
        assert_eq!(w2.len(), 2);
        assert!(w2[0].is_finite());
        assert!(w2[1].is_finite());
    }

    /// Forward followed by inverse recovers the input after scaling by
    /// `1 / fft_size`.
    #[test]
    fn test_fft_roundtrip() {
        let fft_size = 256;
        let mut fft = RealFftProcessor::new_bidirectional(fft_size);

        // Ten full cycles of a sine across the frame.
        let original: Vec<f32> = (0..fft_size)
            .map(|i| (2.0 * std::f32::consts::PI * 10.0 * i as f32 / fft_size as f32).sin())
            .collect();
        fft.time_buffer.copy_from_slice(&original);

        fft.forward();
        fft.inverse();

        let scale = 1.0 / fft_size as f32;
        for i in 0..fft_size {
            let recovered = fft.time_buffer[i] * scale;
            assert!(
                (recovered - original[i]).abs() < 1e-4,
                "FFT roundtrip mismatch at i={}: expected {}, got {}",
                i,
                original[i],
                recovered,
            );
        }
    }

    /// First trigger when the window fills, then one every hop.
    #[test]
    fn test_ring_accumulator_trigger_timing() {
        let window_size = 8;
        let hop_size = 4;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        let mut triggers = Vec::new();
        for i in 0..24 {
            if ring.push(i as f32) {
                triggers.push(i);
            }
        }

        assert_eq!(triggers, vec![7, 11, 15, 19, 23]);
    }

    /// `read_window` returns the most recent samples, oldest first.
    #[test]
    fn test_ring_accumulator_window_readout() {
        let window_size = 4;
        let hop_size = 2;
        let mut ring = RingAccumulator::new(window_size, hop_size);

        for i in 0..6 {
            ring.push(i as f32);
        }

        let mut dest = vec![0.0; 4];
        ring.read_window(&mut dest);
        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    /// `reset` clears the fill state, so a hop's worth of samples is not
    /// enough to trigger again.
    #[test]
    fn test_ring_accumulator_reset() {
        let mut ring = RingAccumulator::new(8, 4);

        for i in 0..12 {
            ring.push(i as f32);
        }
        assert!(ring.filled);

        ring.reset();
        assert!(!ring.filled);
        assert_eq!(ring.write_pos, 0);
        assert_eq!(ring.samples_since_trigger, 0);

        let mut triggered = false;
        for _ in 0..4 {
            triggered |= ring.push(1.0);
        }
        assert!(!triggered, "Should not trigger before ring is filled again");
    }

    /// The synthesis window is as long as the analysis window and zero
    /// outside its centred support.
    #[test]
    fn test_dual_window_design() {
        let analysis_size = 1024;
        let synthesis_size = 256;
        let hop_size = 128;

        let (w_a, w_s) = design_dual_windows(analysis_size, synthesis_size, hop_size);
        assert_eq!(w_a.len(), analysis_size);
        assert_eq!(w_s.len(), analysis_size);

        let offset = (analysis_size - synthesis_size) / 2;
        for i in 0..offset {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero before offset");
        }
        for i in (offset + synthesis_size)..analysis_size {
            assert_eq!(w_s[i], 0.0, "Synthesis window should be zero after support");
        }
    }

    /// With an identity spectral callback the output stays close to the
    /// delayed input once the pipeline has warmed up.
    #[test]
    fn test_dual_window_stft_passthrough() {
        let analysis_size = 512;
        let synthesis_size = 128;
        let hop_size = 64;

        let mut stft = DualWindowStft::new(analysis_size, synthesis_size, hop_size);

        // 440 Hz sine at a nominal 48 kHz sample rate.
        let num_samples = 4096;
        let signal: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();

        let mut output = vec![0.0f32; num_samples];

        stft.process_block(&signal, &mut output, |_spectrum| {});

        // Skip the warm-up at the start and a margin at the end.
        let latency = stft.latency_samples();
        let check_start = latency + 512;
        let check_end = num_samples - 512;
        // Fail loudly instead of silently skipping the comparison if the
        // sizes above are ever changed so that nothing is left to check.
        assert!(
            check_end > check_start,
            "test signal too short: nothing left to compare after warm-up"
        );

        let mean_square_error: f32 = output[check_start..check_end]
            .iter()
            .zip(&signal[check_start - latency..check_end - latency])
            .map(|(o, s)| (o - s).powi(2))
            .sum::<f32>()
            / (check_end - check_start) as f32;
        let rms_error = mean_square_error.sqrt();

        // Deliberately loose bound: this is a sanity check that the output
        // is bounded and roughly tracks the input, not a reconstruction
        // accuracy test.
        assert!(
            rms_error < 1.0,
            "Dual-window STFT passthrough RMS error too high: {rms_error:.6}"
        );
    }

    /// After `reset`, feeding silence must produce silence.
    #[test]
    fn test_dual_window_stft_reset() {
        let mut stft = DualWindowStft::new(512, 128, 64);

        let signal: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0; 2048];
        stft.process_block(&signal, &mut output, |_| {});

        stft.reset();

        let silence = vec![0.0f32; 1024];
        let mut output2 = vec![0.0; 1024];
        stft.process_block(&silence, &mut output2, |_| {});

        let max_output: f32 = output2.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_output < 0.01,
            "After reset + silence, max output should be ~0, got {max_output}"
        );
    }
}