stft_rs/
lib.rs

1/*MIT License
2
3Copyright (c) 2025 David Maseda Neira
4
5Permission is hereby granted, free of charge, to any person obtaining a copy
6of this software and associated documentation files (the "Software"), to deal
7in the Software without restriction, including without limitation the rights
8to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9copies of the Software, and to permit persons to whom the Software is
10furnished to do so, subject to the following conditions:
11
12The above copyright notice and this permission notice shall be included in all
13copies or substantial portions of the Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22*/
23
24#![cfg_attr(not(feature = "std"), no_std)]
25
26#[cfg(not(feature = "std"))]
27extern crate alloc;
28
29#[cfg(not(feature = "std"))]
30use alloc::{collections::VecDeque, sync::Arc, vec, vec::Vec};
31
32#[cfg(feature = "std")]
33use std::{collections::VecDeque, sync::Arc, vec};
34
35use core::fmt;
36use core::marker::PhantomData;
37use num_traits::{Float, FromPrimitive};
38
39pub mod fft_backend;
40use fft_backend::{Complex, FftBackend, FftNum, FftPlanner, FftPlannerTrait};
41
42mod utils;
43pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
44
45pub mod mel;
46
47pub mod prelude {
48    pub use crate::fft_backend::Complex;
49    pub use crate::mel::{
50        BatchMelSpectrogram, BatchMelSpectrogramF32, BatchMelSpectrogramF64, MelConfig,
51        MelConfigF32, MelConfigF64, MelFilterbank, MelFilterbankF32, MelFilterbankF64, MelNorm,
52        MelScale, MelSpectrum, MelSpectrumF32, MelSpectrumF64, StreamingMelSpectrogram,
53        StreamingMelSpectrogramF32, StreamingMelSpectrogramF64,
54    };
55    pub use crate::utils::{
56        apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
57    };
58    pub use crate::{
59        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
60        MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
61        MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
62        PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
63        SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
64        StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
65        StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
66    };
67}
68
/// Strategy the inverse transform uses to combine and normalize overlapped frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReconstructionMode {
    /// Overlap-Add (OLA): inverse frames are summed without a synthesis window
    /// and the sum is divided by `sum(w)`. Requires the COLA condition.
    Ola,

    /// Weighted Overlap-Add (WOLA): inverse frames are multiplied by the window
    /// before summing and the sum is divided by `sum(w^2)`. Requires the NOLA
    /// condition.
    Wola,
}
77
/// Window shape applied to each frame.
///
/// All shapes are generated in periodic form: with `x = 2*pi*n/N`, the
/// denominator is the frame length `N`, not `N - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// Hann: `0.5 * (1 - cos x)`.
    Hann,
    /// Hamming: `0.54 - 0.46 * cos x`.
    Hamming,
    /// Blackman: `0.42 - 0.5 * cos x + 0.08 * cos 2x`.
    Blackman,
}
84
85#[derive(Debug, Clone)]
86pub enum ConfigError<T: Float + fmt::Debug> {
87    NolaViolation { min_energy: T, threshold: T },
88    ColaViolation { max_deviation: T, threshold: T },
89    InvalidHopSize,
90    InvalidFftSize,
91}
92
93impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            ConfigError::NolaViolation {
97                min_energy,
98                threshold,
99            } => {
100                write!(
101                    f,
102                    "NOLA condition violated: min_energy={} < threshold={}",
103                    min_energy, threshold
104                )
105            }
106            ConfigError::ColaViolation {
107                max_deviation,
108                threshold,
109            } => {
110                write!(
111                    f,
112                    "COLA condition violated: max_deviation={} > threshold={}",
113                    max_deviation, threshold
114                )
115            }
116            ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
117            ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
118        }
119    }
120}
121
/// Lets `ConfigError` be used with `std` error handling (e.g. `Box<dyn Error>`)
/// when the `std` feature is enabled; the message comes from its `Display` impl.
#[cfg(feature = "std")]
impl<T> std::error::Error for ConfigError<T>
where
    T: Float + fmt::Display + fmt::Debug,
{
}
124
/// How a signal is extended beyond its boundaries before framing.
///
/// `BatchStft` pads `fft_size / 2` samples on each side using this mode; the
/// exact extension rule for each variant lives in `apply_padding`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Mirror the signal about its edges (the precise convention is defined by
    /// `apply_padding` — confirm there). Used by `BatchStft::process`.
    Reflect,
    /// Extend with zeros (assumed from the name — confirm in `apply_padding`).
    Zero,
    /// Repeat the edge samples (assumed from the name — confirm in `apply_padding`).
    Edge,
}
131
132#[derive(Debug, Clone, PartialEq)]
133pub struct StftConfig<T: Float> {
134    pub fft_size: usize,
135    pub hop_size: usize,
136    pub window: WindowType,
137    pub reconstruction_mode: ReconstructionMode,
138    _phantom: PhantomData<T>,
139}
140
141impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
142    fn nola_threshold() -> T {
143        T::from(1e-8).unwrap()
144    }
145
146    fn cola_relative_tolerance() -> T {
147        T::from(1e-4).unwrap()
148    }
149
150    #[deprecated(
151        since = "0.4.0",
152        note = "Use `StftConfig::builder()` instead for a more flexible API"
153    )]
154    pub fn new(
155        fft_size: usize,
156        hop_size: usize,
157        window: WindowType,
158        reconstruction_mode: ReconstructionMode,
159    ) -> Result<Self, ConfigError<T>> {
160        if fft_size == 0 || !(cfg!(feature = "rustfft-backend") || fft_size.is_power_of_two()) {
161            return Err(ConfigError::InvalidFftSize);
162        }
163        if hop_size == 0 || hop_size > fft_size {
164            return Err(ConfigError::InvalidHopSize);
165        }
166
167        let config = Self {
168            fft_size,
169            hop_size,
170            window,
171            reconstruction_mode,
172            _phantom: PhantomData,
173        };
174
175        // Validate appropriate condition based on reconstruction mode
176        match reconstruction_mode {
177            ReconstructionMode::Ola => config.validate_cola()?,
178            ReconstructionMode::Wola => config.validate_nola()?,
179        }
180
181        Ok(config)
182    }
183
184    /// Create a new builder for StftConfig
185    pub fn builder() -> StftConfigBuilder<T> {
186        StftConfigBuilder::new()
187    }
188
189    /// Default: 4096 FFT, 1024 hop, Hann window, OLA mode
190    #[allow(deprecated)]
191    pub fn default_4096() -> Self {
192        Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
193            .expect("Default config should always be valid")
194    }
195
196    pub fn freq_bins(&self) -> usize {
197        self.fft_size / 2 + 1
198    }
199
200    pub fn overlap_percent(&self) -> T {
201        let one = T::one();
202        let hundred = T::from(100.0).unwrap();
203        (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
204    }
205
206    fn generate_window(&self) -> Vec<T> {
207        generate_window(self.window, self.fft_size)
208    }
209
210    /// Validate NOLA condition: sum(w^2) > threshold everywhere
211    pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
212        let window = self.generate_window();
213        let num_overlaps = self.fft_size.div_ceil(self.hop_size);
214        let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
215        let mut energy = vec![T::zero(); test_len];
216
217        for i in 0..num_overlaps {
218            let offset = i * self.hop_size;
219            for j in 0..self.fft_size {
220                if offset + j < test_len {
221                    energy[offset + j] = energy[offset + j] + window[j] * window[j];
222                }
223            }
224        }
225
226        // Check the steady-state region (skip edges)
227        let start = self.fft_size / 2;
228        let end = test_len - self.fft_size / 2;
229        let min_energy = energy[start..end]
230            .iter()
231            .copied()
232            .min_by(|a, b| a.partial_cmp(b).unwrap())
233            .unwrap_or_else(T::zero);
234
235        if min_energy < Self::nola_threshold() {
236            return Err(ConfigError::NolaViolation {
237                min_energy,
238                threshold: Self::nola_threshold(),
239            });
240        }
241
242        Ok(())
243    }
244
245    /// Validate weak COLA condition: sum(w) is constant (within relative tolerance)
246    pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
247        let window = self.generate_window();
248        let window_len = window.len();
249
250        let mut cola_sum_period = vec![T::zero(); self.hop_size];
251        (0..window_len).for_each(|i| {
252            let idx = i % self.hop_size;
253            cola_sum_period[idx] = cola_sum_period[idx] + window[i];
254        });
255
256        let zero = T::zero();
257        let min_sum = cola_sum_period
258            .iter()
259            .min_by(|a, b| a.partial_cmp(b).unwrap())
260            .unwrap_or(&zero);
261        let max_sum = cola_sum_period
262            .iter()
263            .max_by(|a, b| a.partial_cmp(b).unwrap())
264            .unwrap_or(&zero);
265
266        let epsilon = T::from(1e-9).unwrap();
267        if *max_sum < epsilon {
268            return Err(ConfigError::ColaViolation {
269                max_deviation: T::infinity(),
270                threshold: Self::cola_relative_tolerance(),
271            });
272        }
273
274        let ripple = (*max_sum - *min_sum) / *max_sum;
275
276        let is_compliant = ripple < Self::cola_relative_tolerance();
277
278        if !is_compliant {
279            return Err(ConfigError::ColaViolation {
280                max_deviation: ripple,
281                threshold: Self::cola_relative_tolerance(),
282            });
283        }
284        Ok(())
285    }
286}
287
288/// Builder for StftConfig with fluent API
289#[derive(Debug, Clone, PartialEq)]
290pub struct StftConfigBuilder<T: Float> {
291    fft_size: Option<usize>,
292    hop_size: Option<usize>,
293    window: WindowType,
294    reconstruction_mode: ReconstructionMode,
295    _phantom: PhantomData<T>,
296}
297
298impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
299    /// Create a new builder with default values (Hann window, OLA mode)
300    pub fn new() -> Self {
301        Self {
302            fft_size: None,
303            hop_size: None,
304            window: WindowType::Hann,
305            reconstruction_mode: ReconstructionMode::Ola,
306            _phantom: PhantomData,
307        }
308    }
309
310    /// Set the FFT size (must be a power of two)
311    pub fn fft_size(mut self, fft_size: usize) -> Self {
312        self.fft_size = Some(fft_size);
313        self
314    }
315
316    /// Set the hop size (must be > 0 and <= fft_size)
317    pub fn hop_size(mut self, hop_size: usize) -> Self {
318        self.hop_size = Some(hop_size);
319        self
320    }
321
322    /// Set the window type (default: Hann)
323    pub fn window(mut self, window: WindowType) -> Self {
324        self.window = window;
325        self
326    }
327
328    /// Set the reconstruction mode (default: OLA)
329    pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
330        self.reconstruction_mode = mode;
331        self
332    }
333
334    /// Build the StftConfig, validating all parameters
335    ///
336    /// Returns an error if:
337    /// - fft_size is not set or not a power of two
338    /// - hop_size is not set, zero, or greater than fft_size
339    /// - COLA/NOLA conditions are violated
340    #[allow(deprecated)]
341    pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
342        let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
343        let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
344
345        StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
346    }
347}
348
349impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
350    fn default() -> Self {
351        Self::new()
352    }
353}
354
355fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
356    let pi = T::from(core::f64::consts::PI).unwrap();
357    let two = T::from(2.0).unwrap();
358
359    match window_type {
360        WindowType::Hann => (0..size)
361            .map(|i| {
362                let half = T::from(0.5).unwrap();
363                let one = T::one();
364                let i_t = T::from(i).unwrap();
365                let size_t = T::from(size).unwrap(); // Use N, not N-1 for periodic window
366                half * (one - (two * pi * i_t / size_t).cos())
367            })
368            .collect(),
369        WindowType::Hamming => (0..size)
370            .map(|i| {
371                let i_t = T::from(i).unwrap();
372                let size_t = T::from(size).unwrap(); // Use N, not N-1 for periodic window
373                T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_t).cos()
374            })
375            .collect(),
376        WindowType::Blackman => (0..size)
377            .map(|i| {
378                let i_t = T::from(i).unwrap();
379                let size_t = T::from(size).unwrap(); // Use N, not N-1 for periodic window
380                let angle = two * pi * i_t / size_t;
381                T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
382                    + T::from(0.08).unwrap() * (two * angle).cos()
383            })
384            .collect(),
385    }
386}
387
388#[derive(Debug, Clone, PartialEq)]
389pub struct SpectrumFrame<T: Float> {
390    pub freq_bins: usize,
391    pub data: Vec<Complex<T>>,
392}
393
394impl<T: Float> SpectrumFrame<T> {
395    pub fn new(freq_bins: usize) -> Self {
396        Self {
397            freq_bins,
398            data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
399        }
400    }
401
402    pub fn from_data(data: Vec<Complex<T>>) -> Self {
403        let freq_bins = data.len();
404        Self { freq_bins, data }
405    }
406
407    /// Prepare frame for reuse by clearing data (keeps capacity)
408    pub fn clear(&mut self) {
409        for val in &mut self.data {
410            *val = Complex::new(T::zero(), T::zero());
411        }
412    }
413
414    /// Resize frame if needed to match freq_bins
415    pub fn resize_if_needed(&mut self, freq_bins: usize) {
416        if self.freq_bins != freq_bins {
417            self.freq_bins = freq_bins;
418            self.data
419                .resize(freq_bins, Complex::new(T::zero(), T::zero()));
420        }
421    }
422
423    /// Write data from a slice into this frame
424    pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
425        self.resize_if_needed(data.len());
426        self.data.copy_from_slice(data);
427    }
428
429    /// Get the magnitude of a frequency bin
430    #[inline]
431    pub fn magnitude(&self, bin: usize) -> T {
432        let c = &self.data[bin];
433        (c.re * c.re + c.im * c.im).sqrt()
434    }
435
436    /// Get the phase of a frequency bin in radians
437    #[inline]
438    pub fn phase(&self, bin: usize) -> T {
439        let c = &self.data[bin];
440        c.im.atan2(c.re)
441    }
442
443    /// Set a frequency bin from magnitude and phase
444    pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
445        self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
446    }
447
448    /// Create a SpectrumFrame from magnitude and phase arrays
449    pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
450        assert_eq!(
451            magnitudes.len(),
452            phases.len(),
453            "Magnitude and phase arrays must have same length"
454        );
455        let freq_bins = magnitudes.len();
456        let data: Vec<Complex<T>> = magnitudes
457            .iter()
458            .zip(phases.iter())
459            .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
460            .collect();
461        Self { freq_bins, data }
462    }
463
464    /// Get all magnitudes as a Vec
465    pub fn magnitudes(&self) -> Vec<T> {
466        self.data
467            .iter()
468            .map(|c| (c.re * c.re + c.im * c.im).sqrt())
469            .collect()
470    }
471
472    /// Get all phases as a Vec
473    pub fn phases(&self) -> Vec<T> {
474        self.data.iter().map(|c| c.im.atan2(c.re)).collect()
475    }
476}
477
478#[derive(Debug, Clone, PartialEq)]
479pub struct Spectrum<T: Float> {
480    pub num_frames: usize,
481    pub freq_bins: usize,
482    pub data: Vec<T>,
483}
484
485impl<T: Float> Spectrum<T> {
486    pub fn new(num_frames: usize, freq_bins: usize) -> Self {
487        Self {
488            num_frames,
489            freq_bins,
490            data: vec![T::zero(); 2 * num_frames * freq_bins],
491        }
492    }
493
494    #[inline]
495    pub fn real(&self, frame: usize, bin: usize) -> T {
496        self.data[frame * self.freq_bins + bin]
497    }
498
499    #[inline]
500    pub fn imag(&self, frame: usize, bin: usize) -> T {
501        let offset = self.num_frames * self.freq_bins;
502        self.data[offset + frame * self.freq_bins + bin]
503    }
504
505    #[inline]
506    pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
507        Complex::new(self.real(frame, bin), self.imag(frame, bin))
508    }
509
510    pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
511        (0..self.num_frames).map(move |frame_idx| {
512            let data: Vec<Complex<T>> = (0..self.freq_bins)
513                .map(|bin| self.get_complex(frame_idx, bin))
514                .collect();
515            SpectrumFrame::from_data(data)
516        })
517    }
518
519    /// Set the real part of a bin
520    #[inline]
521    pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
522        self.data[frame * self.freq_bins + bin] = value;
523    }
524
525    /// Set the imaginary part of a bin
526    #[inline]
527    pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
528        let offset = self.num_frames * self.freq_bins;
529        self.data[offset + frame * self.freq_bins + bin] = value;
530    }
531
532    /// Set a bin from a complex value
533    #[inline]
534    pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
535        self.set_real(frame, bin, value.re);
536        self.set_imag(frame, bin, value.im);
537    }
538
539    /// Get the magnitude of a frequency bin
540    #[inline]
541    pub fn magnitude(&self, frame: usize, bin: usize) -> T {
542        let re = self.real(frame, bin);
543        let im = self.imag(frame, bin);
544        (re * re + im * im).sqrt()
545    }
546
547    /// Get the phase of a frequency bin in radians
548    #[inline]
549    pub fn phase(&self, frame: usize, bin: usize) -> T {
550        let re = self.real(frame, bin);
551        let im = self.imag(frame, bin);
552        im.atan2(re)
553    }
554
555    /// Set a frequency bin from magnitude and phase
556    pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
557        self.set_real(frame, bin, magnitude * phase.cos());
558        self.set_imag(frame, bin, magnitude * phase.sin());
559    }
560
561    /// Get all magnitudes for a frame
562    pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
563        (0..self.freq_bins)
564            .map(|bin| self.magnitude(frame, bin))
565            .collect()
566    }
567
568    /// Get all phases for a frame
569    pub fn frame_phases(&self, frame: usize) -> Vec<T> {
570        (0..self.freq_bins)
571            .map(|bin| self.phase(frame, bin))
572            .collect()
573    }
574
575    /// Apply a function to all bins
576    pub fn apply<F>(&mut self, mut f: F)
577    where
578        F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
579    {
580        for frame in 0..self.num_frames {
581            for bin in 0..self.freq_bins {
582                let c = self.get_complex(frame, bin);
583                let new_c = f(frame, bin, c);
584                self.set_complex(frame, bin, new_c);
585            }
586        }
587    }
588
589    /// Apply a gain to a range of bins across all frames
590    pub fn apply_gain(&mut self, bin_range: core::ops::Range<usize>, gain: T) {
591        for frame in 0..self.num_frames {
592            for bin in bin_range.clone() {
593                if bin < self.freq_bins {
594                    let c = self.get_complex(frame, bin);
595                    self.set_complex(frame, bin, c * gain);
596                }
597            }
598        }
599    }
600
601    /// Zero out a range of bins across all frames
602    pub fn zero_bins(&mut self, bin_range: core::ops::Range<usize>) {
603        for frame in 0..self.num_frames {
604            for bin in bin_range.clone() {
605                if bin < self.freq_bins {
606                    self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
607                }
608            }
609        }
610    }
611}
612
613#[derive(Debug, Clone)]
614pub struct BatchStft<T: Float + FftNum> {
615    config: StftConfig<T>,
616    window: Vec<T>,
617    fft: Arc<dyn FftBackend<T>>,
618}
619
620impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
621    pub fn new(config: StftConfig<T>) -> Self
622    where
623        FftPlanner<T>: FftPlannerTrait<T>,
624    {
625        let window = config.generate_window();
626        let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
627        let fft = planner.plan_fft_forward(config.fft_size);
628
629        Self {
630            config,
631            window,
632            fft,
633        }
634    }
635
636    pub fn process(&self, signal: &[T]) -> Spectrum<T> {
637        self.process_padded(signal, PadMode::Reflect)
638    }
639
640    pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
641        let pad_amount = self.config.fft_size / 2;
642        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
643
644        let num_frames = if padded.len() >= self.config.fft_size {
645            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
646        } else {
647            0
648        };
649
650        let freq_bins = self.config.freq_bins();
651        let mut result = Spectrum::new(num_frames, freq_bins);
652
653        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
654
655        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
656            .step_by(self.config.hop_size)
657            .enumerate()
658        {
659            // Apply window and prepare FFT input
660            for i in 0..self.config.fft_size {
661                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
662            }
663
664            // Compute FFT
665            self.fft.process(&mut fft_buffer);
666
667            // Store positive frequencies in flat layout
668            (0..freq_bins).for_each(|bin| {
669                let idx = frame_idx * freq_bins + bin;
670                result.data[idx] = fft_buffer[bin].re;
671                result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
672            });
673        }
674
675        result
676    }
677
678    /// Process signal and write into a pre-allocated Spectrum.
679    /// The spectrum must have the correct dimensions (num_frames x freq_bins).
680    /// Returns true if successful, false if dimensions don't match.
681    pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
682        self.process_padded_into(signal, PadMode::Reflect, spectrum)
683    }
684
685    /// Process signal with padding and write into a pre-allocated Spectrum.
686    pub fn process_padded_into(
687        &self,
688        signal: &[T],
689        pad_mode: PadMode,
690        spectrum: &mut Spectrum<T>,
691    ) -> bool {
692        let pad_amount = self.config.fft_size / 2;
693        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
694
695        let num_frames = if padded.len() >= self.config.fft_size {
696            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
697        } else {
698            0
699        };
700
701        let freq_bins = self.config.freq_bins();
702
703        // Check dimensions
704        if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
705            return false;
706        }
707
708        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
709
710        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
711            .step_by(self.config.hop_size)
712            .enumerate()
713        {
714            // Apply window and prepare FFT input
715            for i in 0..self.config.fft_size {
716                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
717            }
718
719            // Compute FFT
720            self.fft.process(&mut fft_buffer);
721
722            // Store positive frequencies in flat layout
723            (0..freq_bins).for_each(|bin| {
724                let idx = frame_idx * freq_bins + bin;
725                spectrum.data[idx] = fft_buffer[bin].re;
726                spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
727            });
728        }
729
730        true
731    }
732
733    /// Process multiple channels independently.
734    /// Returns one Spectrum per channel.
735    ///
736    /// # Arguments
737    ///
738    /// * `channels` - Slice of audio channels, each as a separate Vec
739    ///
740    /// # Panics
741    ///
742    /// Panics if channels is empty or if channels have different lengths.
743    ///
744    /// # Example
745    ///
746    /// ```
747    /// use stft_rs::prelude::*;
748    ///
749    /// let config = StftConfigF32::default_4096();
750    /// let stft = BatchStftF32::new(config);
751    ///
752    /// let left = vec![0.0; 44100];
753    /// let right = vec![0.0; 44100];
754    /// let channels = vec![left, right];
755    ///
756    /// let spectra = stft.process_multichannel(&channels);
757    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
758    /// ```
759    pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
760        assert!(!channels.is_empty(), "channels must not be empty");
761
762        // Validate all channels have same length
763        let expected_len = channels[0].len();
764        for (i, channel) in channels.iter().enumerate() {
765            assert_eq!(
766                channel.len(),
767                expected_len,
768                "Channel {} has length {}, expected {}",
769                i,
770                channel.len(),
771                expected_len
772            );
773        }
774
775        // Process each channel independently
776        #[cfg(feature = "rayon")]
777        {
778            use rayon::prelude::*;
779            channels
780                .par_iter()
781                .map(|channel| self.process(channel))
782                .collect()
783        }
784        #[cfg(not(feature = "rayon"))]
785        {
786            channels
787                .iter()
788                .map(|channel| self.process(channel))
789                .collect()
790        }
791    }
792
793    /// Process interleaved multi-channel audio.
794    /// Converts interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo)
795    /// into separate Spectrum for each channel.
796    ///
797    /// # Arguments
798    ///
799    /// * `data` - Interleaved audio data
800    /// * `num_channels` - Number of channels
801    ///
802    /// # Panics
803    ///
804    /// Panics if `num_channels` is 0 or if `data.len()` is not divisible by `num_channels`.
805    ///
806    /// # Example
807    ///
808    /// ```
809    /// use stft_rs::prelude::*;
810    ///
811    /// let config = StftConfigF32::default_4096();
812    /// let stft = BatchStftF32::new(config);
813    ///
814    /// // Stereo interleaved: L,R,L,R,L,R,...
815    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
816    ///
817    /// let spectra = stft.process_interleaved(&interleaved, 2);
818    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
819    /// ```
820    pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
821        let channels = utils::deinterleave(data, num_channels);
822        self.process_multichannel(&channels)
823    }
824}
825
826#[derive(Debug, Clone)]
827pub struct BatchIstft<T: Float + FftNum> {
828    config: StftConfig<T>,
829    window: Vec<T>,
830    ifft: Arc<dyn FftBackend<T>>,
831}
832
833impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
834    pub fn new(config: StftConfig<T>) -> Self
835    where
836        FftPlanner<T>: FftPlannerTrait<T>,
837    {
838        let window = config.generate_window();
839        let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
840        let ifft = planner.plan_fft_inverse(config.fft_size);
841
842        Self {
843            config,
844            window,
845            ifft,
846        }
847    }
848
849    pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
850        assert_eq!(
851            spectrum.freq_bins,
852            self.config.freq_bins(),
853            "Frequency bins mismatch"
854        );
855
856        let num_frames = spectrum.num_frames;
857        let original_time_len = (num_frames - 1) * self.config.hop_size;
858        let pad_amount = self.config.fft_size / 2;
859        let padded_len = original_time_len + 2 * pad_amount;
860
861        let mut overlap_buffer = vec![T::zero(); padded_len];
862        let mut window_energy = vec![T::zero(); padded_len];
863        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
864
865        // Precompute window energy normalization
866        for frame_idx in 0..num_frames {
867            let pos = frame_idx * self.config.hop_size;
868            for i in 0..self.config.fft_size {
869                match self.config.reconstruction_mode {
870                    ReconstructionMode::Ola => {
871                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
872                    }
873                    ReconstructionMode::Wola => {
874                        window_energy[pos + i] =
875                            window_energy[pos + i] + self.window[i] * self.window[i];
876                    }
877                }
878            }
879        }
880
881        // Process each frame
882        for frame_idx in 0..num_frames {
883            // Build full spectrum with conjugate symmetry
884            (0..spectrum.freq_bins).for_each(|bin| {
885                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
886            });
887
888            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
889            for bin in 1..(spectrum.freq_bins - 1) {
890                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
891            }
892
893            // Compute IFFT
894            self.ifft.process(&mut ifft_buffer);
895
896            // Overlap-add
897            let pos = frame_idx * self.config.hop_size;
898            for i in 0..self.config.fft_size {
899                let fft_size_t = T::from(self.config.fft_size).unwrap();
900                let sample = ifft_buffer[i].re / fft_size_t;
901
902                match self.config.reconstruction_mode {
903                    ReconstructionMode::Ola => {
904                        // OLA: no windowing on inverse
905                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
906                    }
907                    ReconstructionMode::Wola => {
908                        // WOLA: apply window on inverse
909                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
910                    }
911                }
912            }
913        }
914
915        // Normalize by window energy
916        let threshold = T::from(1e-8).unwrap();
917        for i in 0..padded_len {
918            if window_energy[i] > threshold {
919                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
920            }
921        }
922
923        // Remove padding
924        overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
925    }
926
927    /// Process spectrum and write into a pre-allocated output buffer.
928    /// The output buffer will be resized if needed.
929    pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
930        assert_eq!(
931            spectrum.freq_bins,
932            self.config.freq_bins(),
933            "Frequency bins mismatch"
934        );
935
936        let num_frames = spectrum.num_frames;
937        let original_time_len = (num_frames - 1) * self.config.hop_size;
938        let pad_amount = self.config.fft_size / 2;
939        let padded_len = original_time_len + 2 * pad_amount;
940
941        let mut overlap_buffer = vec![T::zero(); padded_len];
942        let mut window_energy = vec![T::zero(); padded_len];
943        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
944
945        // Precompute window energy normalization
946        for frame_idx in 0..num_frames {
947            let pos = frame_idx * self.config.hop_size;
948            for i in 0..self.config.fft_size {
949                match self.config.reconstruction_mode {
950                    ReconstructionMode::Ola => {
951                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
952                    }
953                    ReconstructionMode::Wola => {
954                        window_energy[pos + i] =
955                            window_energy[pos + i] + self.window[i] * self.window[i];
956                    }
957                }
958            }
959        }
960
961        // Process each frame
962        for frame_idx in 0..num_frames {
963            // Build full spectrum with conjugate symmetry
964            (0..spectrum.freq_bins).for_each(|bin| {
965                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
966            });
967
968            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
969            for bin in 1..(spectrum.freq_bins - 1) {
970                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
971            }
972
973            // Compute IFFT
974            self.ifft.process(&mut ifft_buffer);
975
976            // Overlap-add
977            let pos = frame_idx * self.config.hop_size;
978            for i in 0..self.config.fft_size {
979                let fft_size_t = T::from(self.config.fft_size).unwrap();
980                let sample = ifft_buffer[i].re / fft_size_t;
981
982                match self.config.reconstruction_mode {
983                    ReconstructionMode::Ola => {
984                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
985                    }
986                    ReconstructionMode::Wola => {
987                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
988                    }
989                }
990            }
991        }
992
993        // Normalize by window energy
994        let threshold = T::from(1e-8).unwrap();
995        for i in 0..padded_len {
996            if window_energy[i] > threshold {
997                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
998            }
999        }
1000
1001        // Copy to output (resize if needed)
1002        output.clear();
1003        output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
1004    }
1005
1006    /// Reconstruct multiple channels from their spectra.
1007    /// Returns one Vec per channel.
1008    ///
1009    /// # Arguments
1010    ///
1011    /// * `spectra` - Slice of Spectrum, one per channel
1012    ///
1013    /// # Panics
1014    ///
1015    /// Panics if spectra is empty.
1016    ///
1017    /// # Example
1018    ///
1019    /// ```
1020    /// use stft_rs::prelude::*;
1021    ///
1022    /// let config = StftConfigF32::default_4096();
1023    /// let stft = BatchStftF32::new(config.clone());
1024    /// let istft = BatchIstftF32::new(config);
1025    ///
1026    /// let left = vec![0.0; 44100];
1027    /// let right = vec![0.0; 44100];
1028    /// let channels = vec![left, right];
1029    ///
1030    /// let spectra = stft.process_multichannel(&channels);
1031    /// let reconstructed = istft.process_multichannel(&spectra);
1032    ///
1033    /// assert_eq!(reconstructed.len(), 2); // One channel per spectrum
1034    /// ```
1035    pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1036        assert!(!spectra.is_empty(), "spectra must not be empty");
1037
1038        // Process each spectrum independently
1039        #[cfg(feature = "rayon")]
1040        {
1041            use rayon::prelude::*;
1042            spectra
1043                .par_iter()
1044                .map(|spectrum| self.process(spectrum))
1045                .collect()
1046        }
1047        #[cfg(not(feature = "rayon"))]
1048        {
1049            spectra
1050                .iter()
1051                .map(|spectrum| self.process(spectrum))
1052                .collect()
1053        }
1054    }
1055
1056    /// Reconstruct multiple channels and interleave them into a single buffer.
1057    /// Converts separate channels back to interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo).
1058    ///
1059    /// # Arguments
1060    ///
1061    /// * `spectra` - Slice of Spectrum, one per channel
1062    ///
1063    /// # Panics
1064    ///
1065    /// Panics if spectra is empty or if channels have different lengths.
1066    ///
1067    /// # Example
1068    ///
1069    /// ```
1070    /// use stft_rs::prelude::*;
1071    ///
1072    /// let config = StftConfigF32::default_4096();
1073    /// let stft = BatchStftF32::new(config.clone());
1074    /// let istft = BatchIstftF32::new(config);
1075    ///
1076    /// // Process interleaved stereo
1077    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
1078    /// let spectra = stft.process_interleaved(&interleaved, 2);
1079    ///
1080    /// // Reconstruct back to interleaved
1081    /// let output = istft.process_multichannel_interleaved(&spectra);
1082    /// // Output length may differ slightly due to padding/framing
1083    /// assert_eq!(output.len() / 2, 44032); // samples per channel after reconstruction
1084    /// ```
1085    pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1086        let channels = self.process_multichannel(spectra);
1087        utils::interleave(&channels)
1088    }
1089}
1090
1091#[derive(Debug, Clone)]
1092pub struct StreamingStft<T: Float + FftNum> {
1093    config: StftConfig<T>,
1094    window: Vec<T>,
1095    fft: Arc<dyn FftBackend<T>>,
1096    input_buffer: VecDeque<T>,
1097    frame_index: usize,
1098    fft_buffer: Vec<Complex<T>>,
1099}
1100
1101impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1102    pub fn new(config: StftConfig<T>) -> Self
1103    where
1104        FftPlanner<T>: FftPlannerTrait<T>,
1105    {
1106        let window = config.generate_window();
1107        let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
1108        let fft = planner.plan_fft_forward(config.fft_size);
1109        let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1110
1111        Self {
1112            config,
1113            window,
1114            fft,
1115            input_buffer: VecDeque::new(),
1116            frame_index: 0,
1117            fft_buffer,
1118        }
1119    }
1120
1121    pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1122        self.input_buffer.extend(samples.iter().copied());
1123
1124        let mut frames = Vec::new();
1125
1126        while self.input_buffer.len() >= self.config.fft_size {
1127            // Process one frame
1128            for i in 0..self.config.fft_size {
1129                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1130            }
1131
1132            self.fft.process(&mut self.fft_buffer);
1133
1134            let freq_bins = self.config.freq_bins();
1135            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1136            frames.push(SpectrumFrame::from_data(data));
1137
1138            // Advance by hop size
1139            self.input_buffer.drain(..self.config.hop_size);
1140            self.frame_index += 1;
1141        }
1142
1143        frames
1144    }
1145
1146    /// Push samples and write frames into a pre-allocated buffer.
1147    /// Returns the number of frames written.
1148    pub fn push_samples_into(
1149        &mut self,
1150        samples: &[T],
1151        output: &mut Vec<SpectrumFrame<T>>,
1152    ) -> usize {
1153        self.input_buffer.extend(samples.iter().copied());
1154
1155        let initial_len = output.len();
1156
1157        while self.input_buffer.len() >= self.config.fft_size {
1158            // Process one frame
1159            for i in 0..self.config.fft_size {
1160                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1161            }
1162
1163            self.fft.process(&mut self.fft_buffer);
1164
1165            let freq_bins = self.config.freq_bins();
1166            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1167            output.push(SpectrumFrame::from_data(data));
1168
1169            // Advance by hop size
1170            self.input_buffer.drain(..self.config.hop_size);
1171            self.frame_index += 1;
1172        }
1173
1174        output.len() - initial_len
1175    }
1176
1177    /// Push samples and write directly into pre-existing SpectrumFrame buffers.
1178    /// This is a zero-allocation method - frames must be pre-allocated with correct size.
1179    /// Returns the number of frames written.
1180    ///
1181    /// # Example
1182    /// ```ignore
1183    /// let mut frame_pool = vec![SpectrumFrame::new(config.freq_bins()); 16];
1184    /// let mut frame_index = 0;
1185    ///
1186    /// let frames_written = stft.push_samples_write(chunk, &mut frame_pool, &mut frame_index);
1187    /// // Process frames 0..frames_written
1188    /// ```
1189    pub fn push_samples_write(
1190        &mut self,
1191        samples: &[T],
1192        frame_pool: &mut [SpectrumFrame<T>],
1193        pool_index: &mut usize,
1194    ) -> usize {
1195        self.input_buffer.extend(samples.iter().copied());
1196
1197        let initial_index = *pool_index;
1198        let freq_bins = self.config.freq_bins();
1199
1200        while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1201            // Process one frame
1202            for i in 0..self.config.fft_size {
1203                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1204            }
1205
1206            self.fft.process(&mut self.fft_buffer);
1207
1208            // Write directly into the pre-allocated frame
1209            let frame = &mut frame_pool[*pool_index];
1210            debug_assert_eq!(
1211                frame.freq_bins, freq_bins,
1212                "Frame pool frames must match freq_bins"
1213            );
1214            frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1215
1216            // Advance by hop size
1217            self.input_buffer.drain(..self.config.hop_size);
1218            self.frame_index += 1;
1219            *pool_index += 1;
1220        }
1221
1222        *pool_index - initial_index
1223    }
1224
1225    pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1226        // For streaming, we typically don't process partial frames
1227        // Could zero-pad if needed, but that changes the signal
1228        Vec::new()
1229    }
1230
1231    pub fn reset(&mut self) {
1232        self.input_buffer.clear();
1233        self.frame_index = 0;
1234    }
1235
1236    pub fn buffered_samples(&self) -> usize {
1237        self.input_buffer.len()
1238    }
1239}
1240
1241/// Multi-channel streaming STFT processor with independent state per channel.
1242#[derive(Debug, Clone)]
1243pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1244    processors: Vec<StreamingStft<T>>,
1245}
1246
1247impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T>
1248where
1249    FftPlanner<T>: FftPlannerTrait<T>,
1250{
1251    /// Create a new multi-channel streaming STFT processor.
1252    ///
1253    /// # Arguments
1254    ///
1255    /// * `config` - STFT configuration
1256    /// * `num_channels` - Number of channels
1257    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1258        assert!(num_channels > 0, "num_channels must be > 0");
1259        let processors = (0..num_channels)
1260            .map(|_| StreamingStft::new(config.clone()))
1261            .collect();
1262        Self { processors }
1263    }
1264
1265    /// Push samples for all channels and get frames for each channel.
1266    /// Returns Vec<Vec<SpectrumFrame>>, outer Vec = channels, inner Vec = frames.
1267    ///
1268    /// # Arguments
1269    ///
1270    /// * `channels` - Slice of sample slices, one per channel
1271    ///
1272    /// # Panics
1273    ///
1274    /// Panics if channels.len() doesn't match num_channels.
1275    pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1276        assert_eq!(
1277            channels.len(),
1278            self.processors.len(),
1279            "Expected {} channels, got {}",
1280            self.processors.len(),
1281            channels.len()
1282        );
1283
1284        #[cfg(feature = "rayon")]
1285        {
1286            use rayon::prelude::*;
1287            self.processors
1288                .par_iter_mut()
1289                .zip(channels.par_iter())
1290                .map(|(stft, channel)| stft.push_samples(channel))
1291                .collect()
1292        }
1293        #[cfg(not(feature = "rayon"))]
1294        {
1295            self.processors
1296                .iter_mut()
1297                .zip(channels.iter())
1298                .map(|(stft, channel)| stft.push_samples(channel))
1299                .collect()
1300        }
1301    }
1302
1303    /// Flush all channels and return remaining frames.
1304    pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1305        #[cfg(feature = "rayon")]
1306        {
1307            use rayon::prelude::*;
1308            self.processors
1309                .par_iter_mut()
1310                .map(|stft| stft.flush())
1311                .collect()
1312        }
1313        #[cfg(not(feature = "rayon"))]
1314        {
1315            self.processors
1316                .iter_mut()
1317                .map(|stft| stft.flush())
1318                .collect()
1319        }
1320    }
1321
1322    /// Reset all channels.
1323    pub fn reset(&mut self) {
1324        #[cfg(feature = "rayon")]
1325        {
1326            use rayon::prelude::*;
1327            self.processors.par_iter_mut().for_each(|stft| stft.reset());
1328        }
1329        #[cfg(not(feature = "rayon"))]
1330        {
1331            self.processors.iter_mut().for_each(|stft| stft.reset());
1332        }
1333    }
1334
1335    /// Get the number of channels.
1336    pub fn num_channels(&self) -> usize {
1337        self.processors.len()
1338    }
1339}
1340
1341#[derive(Debug, Clone)]
1342pub struct StreamingIstft<T: Float + FftNum> {
1343    config: StftConfig<T>,
1344    window: Vec<T>,
1345    ifft: Arc<dyn FftBackend<T>>,
1346    overlap_buffer: Vec<T>,
1347    window_energy: Vec<T>,
1348    output_position: usize,
1349    frames_processed: usize,
1350    ifft_buffer: Vec<Complex<T>>,
1351}
1352
1353impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1354    pub fn new(config: StftConfig<T>) -> Self
1355    where
1356        FftPlanner<T>: FftPlannerTrait<T>,
1357    {
1358        let window = config.generate_window();
1359        let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
1360        let ifft = planner.plan_fft_inverse(config.fft_size);
1361
1362        // Buffer needs to hold enough samples for full overlap
1363        // For proper reconstruction, need at least fft_size samples
1364        let buffer_size = config.fft_size * 2;
1365        let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1366
1367        Self {
1368            config,
1369            window,
1370            ifft,
1371            overlap_buffer: vec![T::zero(); buffer_size],
1372            window_energy: vec![T::zero(); buffer_size],
1373            output_position: 0,
1374            frames_processed: 0,
1375            ifft_buffer,
1376        }
1377    }
1378
1379    pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1380        assert_eq!(
1381            frame.freq_bins,
1382            self.config.freq_bins(),
1383            "Frequency bins mismatch"
1384        );
1385
1386        // Build full spectrum with conjugate symmetry
1387        for bin in 0..frame.freq_bins {
1388            self.ifft_buffer[bin] = frame.data[bin];
1389        }
1390
1391        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1392        for bin in 1..(frame.freq_bins - 1) {
1393            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1394        }
1395
1396        // Compute IFFT
1397        self.ifft.process(&mut self.ifft_buffer);
1398
1399        // Overlap-add into buffer at the current write position
1400        let write_pos = self.frames_processed * self.config.hop_size;
1401        for i in 0..self.config.fft_size {
1402            let fft_size_t = T::from(self.config.fft_size).unwrap();
1403            let sample = self.ifft_buffer[i].re / fft_size_t;
1404            let buf_idx = write_pos + i;
1405
1406            // Extend buffers if needed
1407            if buf_idx >= self.overlap_buffer.len() {
1408                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1409                self.window_energy.resize(buf_idx + 1, T::zero());
1410            }
1411
1412            match self.config.reconstruction_mode {
1413                ReconstructionMode::Ola => {
1414                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1415                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1416                }
1417                ReconstructionMode::Wola => {
1418                    self.overlap_buffer[buf_idx] =
1419                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1420                    self.window_energy[buf_idx] =
1421                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1422                }
1423            }
1424        }
1425
1426        self.frames_processed += 1;
1427
1428        // Calculate how many samples are "ready" (have full window energy)
1429        // Samples are ready when no future frames will contribute to them
1430        let ready_until = if self.frames_processed == 1 {
1431            0 // First frame: no output yet, need overlap
1432        } else {
1433            // Samples before the current frame's start position are complete
1434            (self.frames_processed - 1) * self.config.hop_size
1435        };
1436
1437        // Extract ready samples
1438        let output_start = self.output_position;
1439        let output_end = ready_until;
1440        let mut output = Vec::new();
1441
1442        let threshold = T::from(1e-8).unwrap();
1443        if output_end > output_start {
1444            for i in output_start..output_end {
1445                let normalized = if self.window_energy[i] > threshold {
1446                    self.overlap_buffer[i] / self.window_energy[i]
1447                } else {
1448                    T::zero()
1449                };
1450                output.push(normalized);
1451            }
1452            self.output_position = output_end;
1453        }
1454
1455        output
1456    }
1457
1458    /// Push a frame and write output samples into a pre-allocated buffer.
1459    /// Returns the number of samples written.
1460    pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1461        assert_eq!(
1462            frame.freq_bins,
1463            self.config.freq_bins(),
1464            "Frequency bins mismatch"
1465        );
1466
1467        // Build full spectrum with conjugate symmetry
1468        for bin in 0..frame.freq_bins {
1469            self.ifft_buffer[bin] = frame.data[bin];
1470        }
1471
1472        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1473        for bin in 1..(frame.freq_bins - 1) {
1474            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1475        }
1476
1477        // Compute IFFT
1478        self.ifft.process(&mut self.ifft_buffer);
1479
1480        // Overlap-add into buffer at the current write position
1481        let write_pos = self.frames_processed * self.config.hop_size;
1482        for i in 0..self.config.fft_size {
1483            let fft_size_t = T::from(self.config.fft_size).unwrap();
1484            let sample = self.ifft_buffer[i].re / fft_size_t;
1485            let buf_idx = write_pos + i;
1486
1487            // Extend buffers if needed
1488            if buf_idx >= self.overlap_buffer.len() {
1489                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1490                self.window_energy.resize(buf_idx + 1, T::zero());
1491            }
1492
1493            match self.config.reconstruction_mode {
1494                ReconstructionMode::Ola => {
1495                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1496                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1497                }
1498                ReconstructionMode::Wola => {
1499                    self.overlap_buffer[buf_idx] =
1500                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1501                    self.window_energy[buf_idx] =
1502                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1503                }
1504            }
1505        }
1506
1507        self.frames_processed += 1;
1508
1509        // Calculate how many samples are "ready" (have full window energy)
1510        // Samples are ready when no future frames will contribute to them
1511        let ready_until = if self.frames_processed == 1 {
1512            0 // First frame: no output yet, need overlap
1513        } else {
1514            // Samples before the current frame's start position are complete
1515            (self.frames_processed - 1) * self.config.hop_size
1516        };
1517
1518        // Extract ready samples
1519        let output_start = self.output_position;
1520        let output_end = ready_until;
1521        let initial_len = output.len();
1522
1523        let threshold = T::from(1e-8).unwrap();
1524        if output_end > output_start {
1525            for i in output_start..output_end {
1526                let normalized = if self.window_energy[i] > threshold {
1527                    self.overlap_buffer[i] / self.window_energy[i]
1528                } else {
1529                    T::zero()
1530                };
1531                output.push(normalized);
1532            }
1533            self.output_position = output_end;
1534        }
1535
1536        output.len() - initial_len
1537    }
1538
1539    pub fn flush(&mut self) -> Vec<T> {
1540        // Return all remaining samples in buffer
1541        let mut output = Vec::new();
1542        let threshold = T::from(1e-8).unwrap();
1543        for i in self.output_position..self.overlap_buffer.len() {
1544            if self.window_energy[i] > threshold {
1545                output.push(self.overlap_buffer[i] / self.window_energy[i]);
1546            } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1547                output.push(T::zero()); // Sample in valid range but no window energy
1548            } else {
1549                break; // Past the end of valid data
1550            }
1551        }
1552
1553        // Determine the actual end of valid data
1554        let valid_end =
1555            (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1556        if output.len() > valid_end - self.output_position {
1557            output.truncate(valid_end - self.output_position);
1558        }
1559
1560        self.reset();
1561        output
1562    }
1563
1564    pub fn reset(&mut self) {
1565        self.overlap_buffer.clear();
1566        self.overlap_buffer
1567            .resize(self.config.fft_size * 2, T::zero());
1568        self.window_energy.clear();
1569        self.window_energy
1570            .resize(self.config.fft_size * 2, T::zero());
1571        self.output_position = 0;
1572        self.frames_processed = 0;
1573    }
1574}
1575
1576/// Multi-channel streaming iSTFT processor with independent state per channel.
1577#[derive(Debug, Clone)]
1578pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1579    processors: Vec<StreamingIstft<T>>,
1580}
1581
1582impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T>
1583where
1584    FftPlanner<T>: FftPlannerTrait<T>,
1585{
1586    /// Create a new multi-channel streaming iSTFT processor.
1587    ///
1588    /// # Arguments
1589    ///
1590    /// * `config` - STFT configuration
1591    /// * `num_channels` - Number of channels
1592    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1593        assert!(num_channels > 0, "num_channels must be > 0");
1594        let processors = (0..num_channels)
1595            .map(|_| StreamingIstft::new(config.clone()))
1596            .collect();
1597        Self { processors }
1598    }
1599
1600    /// Push frames for all channels and get samples for each channel.
1601    /// Returns Vec<Vec<T>>, outer Vec = channels, inner Vec = samples.
1602    ///
1603    /// # Arguments
1604    ///
1605    /// * `frames` - Slice of frames, one per channel
1606    ///
1607    /// # Panics
1608    ///
1609    /// Panics if frames.len() doesn't match num_channels.
1610    pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1611        assert_eq!(
1612            frames.len(),
1613            self.processors.len(),
1614            "Expected {} channels, got {}",
1615            self.processors.len(),
1616            frames.len()
1617        );
1618
1619        #[cfg(feature = "rayon")]
1620        {
1621            use rayon::prelude::*;
1622            self.processors
1623                .par_iter_mut()
1624                .zip(frames.par_iter())
1625                .map(|(istft, frame)| istft.push_frame(frame))
1626                .collect()
1627        }
1628        #[cfg(not(feature = "rayon"))]
1629        {
1630            self.processors
1631                .iter_mut()
1632                .zip(frames.iter())
1633                .map(|(istft, frame)| istft.push_frame(frame))
1634                .collect()
1635        }
1636    }
1637
1638    /// Flush all channels and return remaining samples.
1639    pub fn flush(&mut self) -> Vec<Vec<T>> {
1640        #[cfg(feature = "rayon")]
1641        {
1642            use rayon::prelude::*;
1643            self.processors
1644                .par_iter_mut()
1645                .map(|istft| istft.flush())
1646                .collect()
1647        }
1648        #[cfg(not(feature = "rayon"))]
1649        {
1650            self.processors
1651                .iter_mut()
1652                .map(|istft| istft.flush())
1653                .collect()
1654        }
1655    }
1656
1657    /// Reset all channels.
1658    pub fn reset(&mut self) {
1659        #[cfg(feature = "rayon")]
1660        {
1661            use rayon::prelude::*;
1662            self.processors
1663                .par_iter_mut()
1664                .for_each(|istft| istft.reset());
1665        }
1666        #[cfg(not(feature = "rayon"))]
1667        {
1668            self.processors.iter_mut().for_each(|istft| istft.reset());
1669        }
1670    }
1671
1672    /// Get the number of channels.
1673    pub fn num_channels(&self) -> usize {
1674        self.processors.len()
1675    }
1676}
1677
1678// Type aliases for common float types
1679pub type StftConfigF32 = StftConfig<f32>;
1680pub type StftConfigF64 = StftConfig<f64>;
1681
1682pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
1683pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;
1684
1685pub type BatchStftF32 = BatchStft<f32>;
1686pub type BatchStftF64 = BatchStft<f64>;
1687
1688pub type BatchIstftF32 = BatchIstft<f32>;
1689pub type BatchIstftF64 = BatchIstft<f64>;
1690
1691pub type StreamingStftF32 = StreamingStft<f32>;
1692pub type StreamingStftF64 = StreamingStft<f64>;
1693
1694pub type StreamingIstftF32 = StreamingIstft<f32>;
1695pub type StreamingIstftF64 = StreamingIstft<f64>;
1696
1697pub type SpectrumF32 = Spectrum<f32>;
1698pub type SpectrumF64 = Spectrum<f64>;
1699
1700pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1701pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1702
1703pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1704pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1705
1706pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1707pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;