1use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31pub mod prelude {
32 pub use crate::{
33 BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64, PadMode,
34 ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame, SpectrumFrameF32,
35 SpectrumFrameF64, StftConfig, StftConfigF32, StftConfigF64, StreamingIstft,
36 StreamingIstftF32, StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64,
37 WindowType, apply_padding,
38 };
39}
40
/// Normalisation scheme used when inverse-transformed frames are overlap-added.
///
/// The mode also selects which window condition `StftConfig::new` validates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReconstructionMode {
    /// Overlap-add: frames are summed as-is and the result is divided by the
    /// overlap-added window. `StftConfig::new` runs `validate_cola` for this mode.
    Ola,

    /// Weighted overlap-add: frames are multiplied by the window again before
    /// summing and the result is divided by the overlap-added squared window.
    /// `StftConfig::new` runs `validate_nola` for this mode.
    Wola,
}
49
/// Window shapes understood by `generate_window`.
///
/// All three are evaluated at the angle `2π·i / (size − 1)` for sample `i`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowType {
    /// `0.5 · (1 − cos θ)`.
    Hann,
    /// `0.54 − 0.46 · cos θ`.
    Hamming,
    /// `0.42 − 0.5 · cos θ + 0.08 · cos 2θ`.
    Blackman,
}
56
57#[derive(Debug, Clone)]
58pub enum ConfigError<T: Float + fmt::Debug> {
59 NolaViolation { min_energy: T, threshold: T },
60 ColaViolation { max_deviation: T, threshold: T },
61 InvalidHopSize,
62 InvalidFftSize,
63}
64
65impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 match self {
68 ConfigError::NolaViolation {
69 min_energy,
70 threshold,
71 } => {
72 write!(
73 f,
74 "NOLA condition violated: min_energy={} < threshold={}",
75 min_energy, threshold
76 )
77 }
78 ConfigError::ColaViolation {
79 max_deviation,
80 threshold,
81 } => {
82 write!(
83 f,
84 "COLA condition violated: max_deviation={} > threshold={}",
85 max_deviation, threshold
86 )
87 }
88 ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
89 ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
90 }
91 }
92}
93
94impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
95
/// How `apply_padding` fills the samples it adds before and after a signal.
#[derive(Debug, Clone, Copy)]
pub enum PadMode {
    /// Mirror the signal about its first/last sample without repeating that
    /// sample. Padding positions that would need samples beyond the signal's
    /// length are left at zero.
    Reflect,
    /// Leave the padding at zero.
    Zero,
    /// Repeat the first sample on the left and the last sample on the right.
    /// An empty signal leaves the padding at zero.
    Edge,
}
102
103#[derive(Clone)]
104pub struct StftConfig<T: Float> {
105 pub fft_size: usize,
106 pub hop_size: usize,
107 pub window: WindowType,
108 pub reconstruction_mode: ReconstructionMode,
109 _phantom: std::marker::PhantomData<T>,
110}
111
112impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
113 fn nola_threshold() -> T {
114 T::from(1e-8).unwrap()
115 }
116
117 fn cola_relative_tolerance() -> T {
118 T::from(1e-4).unwrap()
119 }
120
121 pub fn new(
122 fft_size: usize,
123 hop_size: usize,
124 window: WindowType,
125 reconstruction_mode: ReconstructionMode,
126 ) -> Result<Self, ConfigError<T>> {
127 if fft_size == 0 || !fft_size.is_power_of_two() {
128 return Err(ConfigError::InvalidFftSize);
129 }
130 if hop_size == 0 || hop_size > fft_size {
131 return Err(ConfigError::InvalidHopSize);
132 }
133
134 let config = Self {
135 fft_size,
136 hop_size,
137 window,
138 reconstruction_mode,
139 _phantom: std::marker::PhantomData,
140 };
141
142 match reconstruction_mode {
144 ReconstructionMode::Ola => config.validate_cola()?,
145 ReconstructionMode::Wola => config.validate_nola()?,
146 }
147
148 Ok(config)
149 }
150
151 pub fn default_4096() -> Self {
153 Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
154 .expect("Default config should always be valid")
155 }
156
157 pub fn freq_bins(&self) -> usize {
158 self.fft_size / 2 + 1
159 }
160
161 pub fn overlap_percent(&self) -> T {
162 let one = T::one();
163 let hundred = T::from(100.0).unwrap();
164 (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
165 }
166
167 fn generate_window(&self) -> Vec<T> {
168 generate_window(self.window, self.fft_size)
169 }
170
171 pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
173 let window = self.generate_window();
174 let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
175 let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
176 let mut energy = vec![T::zero(); test_len];
177
178 for i in 0..num_overlaps {
179 let offset = i * self.hop_size;
180 for j in 0..self.fft_size {
181 if offset + j < test_len {
182 energy[offset + j] = energy[offset + j] + window[j] * window[j];
183 }
184 }
185 }
186
187 let start = self.fft_size / 2;
189 let end = test_len - self.fft_size / 2;
190 let min_energy = energy[start..end]
191 .iter()
192 .copied()
193 .min_by(|a, b| a.partial_cmp(b).unwrap())
194 .unwrap_or_else(T::zero);
195
196 if min_energy < Self::nola_threshold() {
197 return Err(ConfigError::NolaViolation {
198 min_energy,
199 threshold: Self::nola_threshold(),
200 });
201 }
202
203 Ok(())
204 }
205
206 pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
208 let window = self.generate_window();
209 let window_len = window.len();
210
211 let mut cola_sum_period = vec![T::zero(); self.hop_size];
212 for i in 0..window_len {
213 let idx = i % self.hop_size;
214 cola_sum_period[idx] = cola_sum_period[idx] + window[i];
215 }
216
217 let zero = T::zero();
218 let min_sum = cola_sum_period
219 .iter()
220 .min_by(|a, b| a.partial_cmp(b).unwrap())
221 .unwrap_or(&zero);
222 let max_sum = cola_sum_period
223 .iter()
224 .max_by(|a, b| a.partial_cmp(b).unwrap())
225 .unwrap_or(&zero);
226
227 let epsilon = T::from(1e-9).unwrap();
228 if *max_sum < epsilon {
229 return Err(ConfigError::ColaViolation {
230 max_deviation: T::infinity(),
231 threshold: Self::cola_relative_tolerance(),
232 });
233 }
234
235 let ripple = (*max_sum - *min_sum) / *max_sum;
236
237 let is_compliant = ripple < Self::cola_relative_tolerance();
238
239 if !is_compliant {
240 return Err(ConfigError::ColaViolation {
241 max_deviation: ripple,
242 threshold: Self::cola_relative_tolerance(),
243 });
244 }
245 Ok(())
246 }
247}
248
249fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
250 let pi = T::from(std::f64::consts::PI).unwrap();
251 let two = T::from(2.0).unwrap();
252
253 match window_type {
254 WindowType::Hann => (0..size)
255 .map(|i| {
256 let half = T::from(0.5).unwrap();
257 let one = T::one();
258 let i_t = T::from(i).unwrap();
259 let size_m1 = T::from(size - 1).unwrap();
260 half * (one - (two * pi * i_t / size_m1).cos())
261 })
262 .collect(),
263 WindowType::Hamming => (0..size)
264 .map(|i| {
265 let i_t = T::from(i).unwrap();
266 let size_m1 = T::from(size - 1).unwrap();
267 T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
268 })
269 .collect(),
270 WindowType::Blackman => (0..size)
271 .map(|i| {
272 let i_t = T::from(i).unwrap();
273 let size_m1 = T::from(size - 1).unwrap();
274 let angle = two * pi * i_t / size_m1;
275 T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
276 + T::from(0.08).unwrap() * (two * angle).cos()
277 })
278 .collect(),
279 }
280}
281
282#[derive(Clone)]
283pub struct SpectrumFrame<T: Float> {
284 pub freq_bins: usize,
285 pub data: Vec<Complex<T>>,
286}
287
288impl<T: Float> SpectrumFrame<T> {
289 pub fn new(freq_bins: usize) -> Self {
290 Self {
291 freq_bins,
292 data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
293 }
294 }
295
296 pub fn from_data(data: Vec<Complex<T>>) -> Self {
297 let freq_bins = data.len();
298 Self { freq_bins, data }
299 }
300
301 pub fn clear(&mut self) {
303 for val in &mut self.data {
304 *val = Complex::new(T::zero(), T::zero());
305 }
306 }
307
308 pub fn resize_if_needed(&mut self, freq_bins: usize) {
310 if self.freq_bins != freq_bins {
311 self.freq_bins = freq_bins;
312 self.data
313 .resize(freq_bins, Complex::new(T::zero(), T::zero()));
314 }
315 }
316
317 pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
319 self.resize_if_needed(data.len());
320 self.data.copy_from_slice(data);
321 }
322
323 #[inline]
325 pub fn magnitude(&self, bin: usize) -> T {
326 let c = &self.data[bin];
327 (c.re * c.re + c.im * c.im).sqrt()
328 }
329
330 #[inline]
332 pub fn phase(&self, bin: usize) -> T {
333 let c = &self.data[bin];
334 c.im.atan2(c.re)
335 }
336
337 pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
339 self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
340 }
341
342 pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
344 assert_eq!(
345 magnitudes.len(),
346 phases.len(),
347 "Magnitude and phase arrays must have same length"
348 );
349 let freq_bins = magnitudes.len();
350 let data: Vec<Complex<T>> = magnitudes
351 .iter()
352 .zip(phases.iter())
353 .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
354 .collect();
355 Self { freq_bins, data }
356 }
357
358 pub fn magnitudes(&self) -> Vec<T> {
360 self.data
361 .iter()
362 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
363 .collect()
364 }
365
366 pub fn phases(&self) -> Vec<T> {
368 self.data.iter().map(|c| c.im.atan2(c.re)).collect()
369 }
370}
371
372#[derive(Clone)]
373pub struct Spectrum<T: Float> {
374 pub num_frames: usize,
375 pub freq_bins: usize,
376 pub data: Vec<T>,
377}
378
379impl<T: Float> Spectrum<T> {
380 pub fn new(num_frames: usize, freq_bins: usize) -> Self {
381 Self {
382 num_frames,
383 freq_bins,
384 data: vec![T::zero(); 2 * num_frames * freq_bins],
385 }
386 }
387
388 #[inline]
389 pub fn real(&self, frame: usize, bin: usize) -> T {
390 self.data[frame * self.freq_bins + bin]
391 }
392
393 #[inline]
394 pub fn imag(&self, frame: usize, bin: usize) -> T {
395 let offset = self.num_frames * self.freq_bins;
396 self.data[offset + frame * self.freq_bins + bin]
397 }
398
399 #[inline]
400 pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
401 Complex::new(self.real(frame, bin), self.imag(frame, bin))
402 }
403
404 pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
405 (0..self.num_frames).map(move |frame_idx| {
406 let data: Vec<Complex<T>> = (0..self.freq_bins)
407 .map(|bin| self.get_complex(frame_idx, bin))
408 .collect();
409 SpectrumFrame::from_data(data)
410 })
411 }
412
413 #[inline]
415 pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
416 self.data[frame * self.freq_bins + bin] = value;
417 }
418
419 #[inline]
421 pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
422 let offset = self.num_frames * self.freq_bins;
423 self.data[offset + frame * self.freq_bins + bin] = value;
424 }
425
426 #[inline]
428 pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
429 self.set_real(frame, bin, value.re);
430 self.set_imag(frame, bin, value.im);
431 }
432
433 #[inline]
435 pub fn magnitude(&self, frame: usize, bin: usize) -> T {
436 let re = self.real(frame, bin);
437 let im = self.imag(frame, bin);
438 (re * re + im * im).sqrt()
439 }
440
441 #[inline]
443 pub fn phase(&self, frame: usize, bin: usize) -> T {
444 let re = self.real(frame, bin);
445 let im = self.imag(frame, bin);
446 im.atan2(re)
447 }
448
449 pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
451 self.set_real(frame, bin, magnitude * phase.cos());
452 self.set_imag(frame, bin, magnitude * phase.sin());
453 }
454
455 pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
457 (0..self.freq_bins)
458 .map(|bin| self.magnitude(frame, bin))
459 .collect()
460 }
461
462 pub fn frame_phases(&self, frame: usize) -> Vec<T> {
464 (0..self.freq_bins)
465 .map(|bin| self.phase(frame, bin))
466 .collect()
467 }
468
469 pub fn apply<F>(&mut self, mut f: F)
471 where
472 F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
473 {
474 for frame in 0..self.num_frames {
475 for bin in 0..self.freq_bins {
476 let c = self.get_complex(frame, bin);
477 let new_c = f(frame, bin, c);
478 self.set_complex(frame, bin, new_c);
479 }
480 }
481 }
482
483 pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
485 for frame in 0..self.num_frames {
486 for bin in bin_range.clone() {
487 if bin < self.freq_bins {
488 let c = self.get_complex(frame, bin);
489 self.set_complex(frame, bin, c * gain);
490 }
491 }
492 }
493 }
494
495 pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
497 for frame in 0..self.num_frames {
498 for bin in bin_range.clone() {
499 if bin < self.freq_bins {
500 self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
501 }
502 }
503 }
504 }
505}
506
507pub struct BatchStft<T: Float + FftNum> {
508 config: StftConfig<T>,
509 window: Vec<T>,
510 fft: Arc<dyn Fft<T>>,
511}
512
513impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
514 pub fn new(config: StftConfig<T>) -> Self {
515 let window = config.generate_window();
516 let mut planner = FftPlanner::new();
517 let fft = planner.plan_fft_forward(config.fft_size);
518
519 Self {
520 config,
521 window,
522 fft,
523 }
524 }
525
526 pub fn process(&self, signal: &[T]) -> Spectrum<T> {
527 self.process_padded(signal, PadMode::Reflect)
528 }
529
530 pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
531 let pad_amount = self.config.fft_size / 2;
532 let padded = apply_padding(signal, pad_amount, pad_mode);
533
534 let num_frames = if padded.len() >= self.config.fft_size {
535 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
536 } else {
537 0
538 };
539
540 let freq_bins = self.config.freq_bins();
541 let mut result = Spectrum::new(num_frames, freq_bins);
542
543 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
544
545 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
546 .step_by(self.config.hop_size)
547 .enumerate()
548 {
549 for i in 0..self.config.fft_size {
551 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
552 }
553
554 self.fft.process(&mut fft_buffer);
556
557 for bin in 0..freq_bins {
559 let idx = frame_idx * freq_bins + bin;
560 result.data[idx] = fft_buffer[bin].re;
561 result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
562 }
563 }
564
565 result
566 }
567
568 pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
572 self.process_padded_into(signal, PadMode::Reflect, spectrum)
573 }
574
575 pub fn process_padded_into(
577 &self,
578 signal: &[T],
579 pad_mode: PadMode,
580 spectrum: &mut Spectrum<T>,
581 ) -> bool {
582 let pad_amount = self.config.fft_size / 2;
583 let padded = apply_padding(signal, pad_amount, pad_mode);
584
585 let num_frames = if padded.len() >= self.config.fft_size {
586 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
587 } else {
588 0
589 };
590
591 let freq_bins = self.config.freq_bins();
592
593 if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
595 return false;
596 }
597
598 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
599
600 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
601 .step_by(self.config.hop_size)
602 .enumerate()
603 {
604 for i in 0..self.config.fft_size {
606 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
607 }
608
609 self.fft.process(&mut fft_buffer);
611
612 for bin in 0..freq_bins {
614 let idx = frame_idx * freq_bins + bin;
615 spectrum.data[idx] = fft_buffer[bin].re;
616 spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
617 }
618 }
619
620 true
621 }
622}
623
624pub struct BatchIstft<T: Float + FftNum> {
625 config: StftConfig<T>,
626 window: Vec<T>,
627 ifft: Arc<dyn Fft<T>>,
628}
629
630impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
631 pub fn new(config: StftConfig<T>) -> Self {
632 let window = config.generate_window();
633 let mut planner = FftPlanner::new();
634 let ifft = planner.plan_fft_inverse(config.fft_size);
635
636 Self {
637 config,
638 window,
639 ifft,
640 }
641 }
642
643 pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
644 assert_eq!(
645 spectrum.freq_bins,
646 self.config.freq_bins(),
647 "Frequency bins mismatch"
648 );
649
650 let num_frames = spectrum.num_frames;
651 let original_time_len = (num_frames - 1) * self.config.hop_size;
652 let pad_amount = self.config.fft_size / 2;
653 let padded_len = original_time_len + 2 * pad_amount;
654
655 let mut overlap_buffer = vec![T::zero(); padded_len];
656 let mut window_energy = vec![T::zero(); padded_len];
657 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
658
659 for frame_idx in 0..num_frames {
661 let pos = frame_idx * self.config.hop_size;
662 for i in 0..self.config.fft_size {
663 match self.config.reconstruction_mode {
664 ReconstructionMode::Ola => {
665 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
666 }
667 ReconstructionMode::Wola => {
668 window_energy[pos + i] =
669 window_energy[pos + i] + self.window[i] * self.window[i];
670 }
671 }
672 }
673 }
674
675 for frame_idx in 0..num_frames {
677 for bin in 0..spectrum.freq_bins {
679 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
680 }
681
682 for bin in 1..(spectrum.freq_bins - 1) {
684 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
685 }
686
687 self.ifft.process(&mut ifft_buffer);
689
690 let pos = frame_idx * self.config.hop_size;
692 for i in 0..self.config.fft_size {
693 let fft_size_t = T::from(self.config.fft_size).unwrap();
694 let sample = ifft_buffer[i].re / fft_size_t;
695
696 match self.config.reconstruction_mode {
697 ReconstructionMode::Ola => {
698 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
700 }
701 ReconstructionMode::Wola => {
702 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
704 }
705 }
706 }
707 }
708
709 let threshold = T::from(1e-8).unwrap();
711 for i in 0..padded_len {
712 if window_energy[i] > threshold {
713 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
714 }
715 }
716
717 overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
719 }
720
721 pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
724 assert_eq!(
725 spectrum.freq_bins,
726 self.config.freq_bins(),
727 "Frequency bins mismatch"
728 );
729
730 let num_frames = spectrum.num_frames;
731 let original_time_len = (num_frames - 1) * self.config.hop_size;
732 let pad_amount = self.config.fft_size / 2;
733 let padded_len = original_time_len + 2 * pad_amount;
734
735 let mut overlap_buffer = vec![T::zero(); padded_len];
736 let mut window_energy = vec![T::zero(); padded_len];
737 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
738
739 for frame_idx in 0..num_frames {
741 let pos = frame_idx * self.config.hop_size;
742 for i in 0..self.config.fft_size {
743 match self.config.reconstruction_mode {
744 ReconstructionMode::Ola => {
745 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
746 }
747 ReconstructionMode::Wola => {
748 window_energy[pos + i] =
749 window_energy[pos + i] + self.window[i] * self.window[i];
750 }
751 }
752 }
753 }
754
755 for frame_idx in 0..num_frames {
757 for bin in 0..spectrum.freq_bins {
759 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
760 }
761
762 for bin in 1..(spectrum.freq_bins - 1) {
764 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
765 }
766
767 self.ifft.process(&mut ifft_buffer);
769
770 let pos = frame_idx * self.config.hop_size;
772 for i in 0..self.config.fft_size {
773 let fft_size_t = T::from(self.config.fft_size).unwrap();
774 let sample = ifft_buffer[i].re / fft_size_t;
775
776 match self.config.reconstruction_mode {
777 ReconstructionMode::Ola => {
778 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
779 }
780 ReconstructionMode::Wola => {
781 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
782 }
783 }
784 }
785 }
786
787 let threshold = T::from(1e-8).unwrap();
789 for i in 0..padded_len {
790 if window_energy[i] > threshold {
791 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
792 }
793 }
794
795 output.clear();
797 output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
798 }
799}
800
801pub struct StreamingStft<T: Float + FftNum> {
802 config: StftConfig<T>,
803 window: Vec<T>,
804 fft: Arc<dyn Fft<T>>,
805 input_buffer: VecDeque<T>,
806 frame_index: usize,
807 fft_buffer: Vec<Complex<T>>,
808}
809
810impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
811 pub fn new(config: StftConfig<T>) -> Self {
812 let window = config.generate_window();
813 let mut planner = FftPlanner::new();
814 let fft = planner.plan_fft_forward(config.fft_size);
815 let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
816
817 Self {
818 config,
819 window,
820 fft,
821 input_buffer: VecDeque::new(),
822 frame_index: 0,
823 fft_buffer,
824 }
825 }
826
827 pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
828 self.input_buffer.extend(samples.iter().copied());
829
830 let mut frames = Vec::new();
831
832 while self.input_buffer.len() >= self.config.fft_size {
833 for i in 0..self.config.fft_size {
835 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
836 }
837
838 self.fft.process(&mut self.fft_buffer);
839
840 let freq_bins = self.config.freq_bins();
841 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
842 frames.push(SpectrumFrame::from_data(data));
843
844 self.input_buffer.drain(..self.config.hop_size);
846 self.frame_index += 1;
847 }
848
849 frames
850 }
851
852 pub fn push_samples_into(
855 &mut self,
856 samples: &[T],
857 output: &mut Vec<SpectrumFrame<T>>,
858 ) -> usize {
859 self.input_buffer.extend(samples.iter().copied());
860
861 let initial_len = output.len();
862
863 while self.input_buffer.len() >= self.config.fft_size {
864 for i in 0..self.config.fft_size {
866 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
867 }
868
869 self.fft.process(&mut self.fft_buffer);
870
871 let freq_bins = self.config.freq_bins();
872 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
873 output.push(SpectrumFrame::from_data(data));
874
875 self.input_buffer.drain(..self.config.hop_size);
877 self.frame_index += 1;
878 }
879
880 output.len() - initial_len
881 }
882
883 pub fn push_samples_write(
896 &mut self,
897 samples: &[T],
898 frame_pool: &mut [SpectrumFrame<T>],
899 pool_index: &mut usize,
900 ) -> usize {
901 self.input_buffer.extend(samples.iter().copied());
902
903 let initial_index = *pool_index;
904 let freq_bins = self.config.freq_bins();
905
906 while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
907 for i in 0..self.config.fft_size {
909 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
910 }
911
912 self.fft.process(&mut self.fft_buffer);
913
914 let frame = &mut frame_pool[*pool_index];
916 debug_assert_eq!(
917 frame.freq_bins, freq_bins,
918 "Frame pool frames must match freq_bins"
919 );
920 frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
921
922 self.input_buffer.drain(..self.config.hop_size);
924 self.frame_index += 1;
925 *pool_index += 1;
926 }
927
928 *pool_index - initial_index
929 }
930
931 pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
932 Vec::new()
935 }
936
937 pub fn reset(&mut self) {
938 self.input_buffer.clear();
939 self.frame_index = 0;
940 }
941
942 pub fn buffered_samples(&self) -> usize {
943 self.input_buffer.len()
944 }
945}
946
947pub struct StreamingIstft<T: Float + FftNum> {
948 config: StftConfig<T>,
949 window: Vec<T>,
950 ifft: Arc<dyn Fft<T>>,
951 overlap_buffer: Vec<T>,
952 window_energy: Vec<T>,
953 output_position: usize,
954 frames_processed: usize,
955 ifft_buffer: Vec<Complex<T>>,
956}
957
958impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
959 pub fn new(config: StftConfig<T>) -> Self {
960 let window = config.generate_window();
961 let mut planner = FftPlanner::new();
962 let ifft = planner.plan_fft_inverse(config.fft_size);
963
964 let buffer_size = config.fft_size * 2;
967 let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
968
969 Self {
970 config,
971 window,
972 ifft,
973 overlap_buffer: vec![T::zero(); buffer_size],
974 window_energy: vec![T::zero(); buffer_size],
975 output_position: 0,
976 frames_processed: 0,
977 ifft_buffer,
978 }
979 }
980
981 pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
982 assert_eq!(
983 frame.freq_bins,
984 self.config.freq_bins(),
985 "Frequency bins mismatch"
986 );
987
988 for bin in 0..frame.freq_bins {
990 self.ifft_buffer[bin] = frame.data[bin];
991 }
992
993 for bin in 1..(frame.freq_bins - 1) {
995 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
996 }
997
998 self.ifft.process(&mut self.ifft_buffer);
1000
1001 let write_pos = self.frames_processed * self.config.hop_size;
1003 for i in 0..self.config.fft_size {
1004 let fft_size_t = T::from(self.config.fft_size).unwrap();
1005 let sample = self.ifft_buffer[i].re / fft_size_t;
1006 let buf_idx = write_pos + i;
1007
1008 if buf_idx >= self.overlap_buffer.len() {
1010 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1011 self.window_energy.resize(buf_idx + 1, T::zero());
1012 }
1013
1014 match self.config.reconstruction_mode {
1015 ReconstructionMode::Ola => {
1016 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1017 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1018 }
1019 ReconstructionMode::Wola => {
1020 self.overlap_buffer[buf_idx] =
1021 self.overlap_buffer[buf_idx] + sample * self.window[i];
1022 self.window_energy[buf_idx] =
1023 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1024 }
1025 }
1026 }
1027
1028 self.frames_processed += 1;
1029
1030 let ready_until = if self.frames_processed == 1 {
1033 0 } else {
1035 (self.frames_processed - 1) * self.config.hop_size
1037 };
1038
1039 let output_start = self.output_position;
1041 let output_end = ready_until;
1042 let mut output = Vec::new();
1043
1044 let threshold = T::from(1e-8).unwrap();
1045 if output_end > output_start {
1046 for i in output_start..output_end {
1047 let normalized = if self.window_energy[i] > threshold {
1048 self.overlap_buffer[i] / self.window_energy[i]
1049 } else {
1050 T::zero()
1051 };
1052 output.push(normalized);
1053 }
1054 self.output_position = output_end;
1055 }
1056
1057 output
1058 }
1059
1060 pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1063 assert_eq!(
1064 frame.freq_bins,
1065 self.config.freq_bins(),
1066 "Frequency bins mismatch"
1067 );
1068
1069 for bin in 0..frame.freq_bins {
1071 self.ifft_buffer[bin] = frame.data[bin];
1072 }
1073
1074 for bin in 1..(frame.freq_bins - 1) {
1076 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1077 }
1078
1079 self.ifft.process(&mut self.ifft_buffer);
1081
1082 let write_pos = self.frames_processed * self.config.hop_size;
1084 for i in 0..self.config.fft_size {
1085 let fft_size_t = T::from(self.config.fft_size).unwrap();
1086 let sample = self.ifft_buffer[i].re / fft_size_t;
1087 let buf_idx = write_pos + i;
1088
1089 if buf_idx >= self.overlap_buffer.len() {
1091 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1092 self.window_energy.resize(buf_idx + 1, T::zero());
1093 }
1094
1095 match self.config.reconstruction_mode {
1096 ReconstructionMode::Ola => {
1097 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1098 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1099 }
1100 ReconstructionMode::Wola => {
1101 self.overlap_buffer[buf_idx] =
1102 self.overlap_buffer[buf_idx] + sample * self.window[i];
1103 self.window_energy[buf_idx] =
1104 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1105 }
1106 }
1107 }
1108
1109 self.frames_processed += 1;
1110
1111 let ready_until = if self.frames_processed == 1 {
1114 0 } else {
1116 (self.frames_processed - 1) * self.config.hop_size
1118 };
1119
1120 let output_start = self.output_position;
1122 let output_end = ready_until;
1123 let initial_len = output.len();
1124
1125 let threshold = T::from(1e-8).unwrap();
1126 if output_end > output_start {
1127 for i in output_start..output_end {
1128 let normalized = if self.window_energy[i] > threshold {
1129 self.overlap_buffer[i] / self.window_energy[i]
1130 } else {
1131 T::zero()
1132 };
1133 output.push(normalized);
1134 }
1135 self.output_position = output_end;
1136 }
1137
1138 output.len() - initial_len
1139 }
1140
1141 pub fn flush(&mut self) -> Vec<T> {
1142 let mut output = Vec::new();
1144 let threshold = T::from(1e-8).unwrap();
1145 for i in self.output_position..self.overlap_buffer.len() {
1146 if self.window_energy[i] > threshold {
1147 output.push(self.overlap_buffer[i] / self.window_energy[i]);
1148 } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1149 output.push(T::zero()); } else {
1151 break; }
1153 }
1154
1155 let valid_end =
1157 (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1158 if output.len() > valid_end - self.output_position {
1159 output.truncate(valid_end - self.output_position);
1160 }
1161
1162 self.reset();
1163 output
1164 }
1165
1166 pub fn reset(&mut self) {
1167 self.overlap_buffer.clear();
1168 self.overlap_buffer
1169 .resize(self.config.fft_size * 2, T::zero());
1170 self.window_energy.clear();
1171 self.window_energy
1172 .resize(self.config.fft_size * 2, T::zero());
1173 self.output_position = 0;
1174 self.frames_processed = 0;
1175 }
1176}
1177
1178pub fn apply_padding<T: Float>(signal: &[T], pad_amount: usize, mode: PadMode) -> Vec<T> {
1181 let total_len = signal.len() + 2 * pad_amount;
1182 let mut padded = vec![T::zero(); total_len];
1183
1184 padded[pad_amount..pad_amount + signal.len()].copy_from_slice(signal);
1185
1186 match mode {
1187 PadMode::Reflect => {
1188 for i in 0..pad_amount {
1189 if i + 1 < signal.len() {
1190 padded[pad_amount - 1 - i] = signal[i + 1];
1191 }
1192 }
1193
1194 let n = signal.len();
1195 for i in 0..pad_amount {
1196 if n >= 2 && n - 2 >= i {
1197 padded[pad_amount + n + i] = signal[n - 2 - i];
1198 }
1199 }
1200 }
1201 PadMode::Zero => {}
1202 PadMode::Edge => {
1203 if !signal.is_empty() {
1204 for i in 0..pad_amount {
1205 padded[i] = signal[0];
1206 }
1207 for i in 0..pad_amount {
1208 padded[pad_amount + signal.len() + i] = signal[signal.len() - 1];
1209 }
1210 }
1211 }
1212 }
1213
1214 padded
1215}
1216
/// [`StftConfig`] for `f32` samples.
pub type StftConfigF32 = StftConfig<f32>;
/// [`StftConfig`] for `f64` samples.
pub type StftConfigF64 = StftConfig<f64>;

/// [`BatchStft`] for `f32` samples.
pub type BatchStftF32 = BatchStft<f32>;
/// [`BatchStft`] for `f64` samples.
pub type BatchStftF64 = BatchStft<f64>;

/// [`BatchIstft`] for `f32` samples.
pub type BatchIstftF32 = BatchIstft<f32>;
/// [`BatchIstft`] for `f64` samples.
pub type BatchIstftF64 = BatchIstft<f64>;

/// [`StreamingStft`] for `f32` samples.
pub type StreamingStftF32 = StreamingStft<f32>;
/// [`StreamingStft`] for `f64` samples.
pub type StreamingStftF64 = StreamingStft<f64>;

/// [`StreamingIstft`] for `f32` samples.
pub type StreamingIstftF32 = StreamingIstft<f32>;
/// [`StreamingIstft`] for `f64` samples.
pub type StreamingIstftF64 = StreamingIstft<f64>;

/// [`Spectrum`] with `f32` components.
pub type SpectrumF32 = Spectrum<f32>;
/// [`Spectrum`] with `f64` components.
pub type SpectrumF64 = Spectrum<f64>;

/// [`SpectrumFrame`] with `f32` components.
pub type SpectrumFrameF32 = SpectrumFrame<f32>;
/// [`SpectrumFrame`] with `f64` components.
pub type SpectrumFrameF64 = SpectrumFrame<f64>;