cute_dsp/
stft.rs

1//! Short-Time Fourier Transform implementation
2//!
3//! This module provides a self-normalizing STFT implementation with variable
4//! position/window for output blocks.
5
6#![allow(unused_imports)]
7
8#[cfg(feature = "std")]
9use std::{f32::consts::PI, vec::Vec, marker::PhantomData};
10
11#[cfg(not(feature = "std"))]
12use core::{f32::consts::PI, marker::PhantomData};
13
14#[cfg(all(not(feature = "std"), feature = "alloc"))]
15use alloc::vec::Vec;
16
17use num_complex::Complex;
18use num_traits::Float;
19use num_traits::FromPrimitive;
20use num_traits::NumCast;
21
22use crate::fft;
23use crate::windows;
24
/// Window shape used when building the STFT analysis/synthesis windows.
///
/// The type is `Copy`, comparable and hashable, so it can be stored in
/// configuration maps or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowShape {
    /// Ignore window shape (use rectangular window)
    Ignore,
    /// Approximate Confined Gaussian window
    ACG,
    /// Kaiser window
    Kaiser,
}
35
36/// A self-normalizing STFT, with variable position/window for output blocks
37pub struct STFT<T: Float> {
38    // FFT implementation
39    fft: fft::Pow2RealFFT<T>,
40
41    // Configuration
42    analysis_channels: usize,
43    synthesis_channels: usize,
44    block_samples: usize,
45    fft_samples: usize,
46    fft_bins: usize,
47    input_length_samples: usize,
48    default_interval: usize,
49
50    // Windows
51    analysis_window: Vec<T>,
52    synthesis_window: Vec<T>,
53    analysis_offset: usize,
54    synthesis_offset: usize,
55
56    // Buffers
57    input_buffer: Vec<T>,
58    input_pos: usize,
59    output_buffer: Vec<T>,
60    output_pos: usize,
61    window_products: Vec<T>,
62    spectrum_buffer: Vec<Complex<T>>,
63    time_buffer: Vec<T>,
64
65    // Constants
66    almost_zero: T,
67    modified: bool,
68}
69
70#[cfg(feature = "std")]
71use std::ops::AddAssign;
72
73#[cfg(not(feature = "std"))]
74use core::ops::AddAssign;
75
76impl<T: Float + FromPrimitive + NumCast + AddAssign> STFT<T> {
77    /// Create a new STFT instance
78    pub fn new(modified: bool) -> Self {
79        Self {
80            fft: fft::Pow2RealFFT::new(0),
81            analysis_channels: 0,
82            synthesis_channels: 0,
83            block_samples: 0,
84            fft_samples: 0,
85            fft_bins: 0,
86            input_length_samples: 0,
87            default_interval: 0,
88            analysis_window: Vec::new(),
89            synthesis_window: Vec::new(),
90            analysis_offset: 0,
91            synthesis_offset: 0,
92            input_buffer: Vec::new(),
93            input_pos: 0,
94            output_buffer: Vec::new(),
95            output_pos: 0,
96            window_products: Vec::new(),
97            spectrum_buffer: Vec::new(),
98            time_buffer: Vec::new(),
99            almost_zero: T::from_f32(1e-20).unwrap(),
100            modified,
101        }
102    }
103
104    /// Configure the STFT
105    pub fn configure(
106        &mut self,
107        in_channels: usize,
108        out_channels: usize,
109        block_samples: usize,
110        extra_input_history: usize,
111        interval_samples: usize,
112    ) {
113        self.analysis_channels = in_channels;
114        self.synthesis_channels = out_channels;
115        self.block_samples = block_samples;
116
117        // Calculate FFT size (power of 2 >= block_samples)
118        let mut fft_samples = 1;
119        while fft_samples < block_samples {
120            fft_samples *= 2;
121        }
122        self.fft_samples = fft_samples;
123        self.fft.resize(fft_samples);
124        self.fft_bins = fft_samples / 2 + 1; // For real FFT
125
126        self.input_length_samples = block_samples + extra_input_history;
127        self.input_buffer.resize(self.input_length_samples * in_channels, T::zero());
128
129        self.output_buffer.resize(block_samples * out_channels, T::zero());
130        self.window_products.resize(block_samples, T::zero());
131        self.spectrum_buffer.resize(self.fft_bins * in_channels.max(out_channels), Complex::new(T::zero(), T::zero()));
132        self.time_buffer.resize(fft_samples, T::zero());
133
134        self.analysis_window.resize(block_samples, T::zero());
135        self.synthesis_window.resize(block_samples, T::zero());
136
137        // Set default interval if not specified
138        let interval = if interval_samples > 0 {
139            interval_samples
140        } else {
141            block_samples / 4
142        };
143        self.set_interval(interval, WindowShape::ACG);
144
145        self.reset_default();
146    }
147
148    /// Get the block size in samples
149    pub fn block_samples(&self) -> usize {
150        self.block_samples
151    }
152
153    /// Get the FFT size in samples
154    pub fn fft_samples(&self) -> usize {
155        self.fft_samples
156    }
157
158    /// Get the default interval between blocks
159    pub fn default_interval(&self) -> usize {
160        self.default_interval
161    }
162
163    /// Get the number of frequency bands
164    pub fn bands(&self) -> usize {
165        self.fft_bins
166    }
167
168    /// Get the analysis latency
169    pub fn analysis_latency(&self) -> usize {
170        self.block_samples - self.analysis_offset
171    }
172
173    /// Get the synthesis latency
174    pub fn synthesis_latency(&self) -> usize {
175        self.synthesis_offset
176    }
177
178    /// Get the total latency
179    pub fn latency(&self) -> usize {
180        self.synthesis_latency() + self.analysis_latency()
181    }
182
183    /// Convert bin index to frequency
184    pub fn bin_to_freq(&self, bin: T) -> T {
185        if self.modified {
186            (bin + T::from_f32(0.5).unwrap()) / T::from_usize(self.fft_samples).unwrap()
187        } else {
188            bin / T::from_usize(self.fft_samples).unwrap()
189        }
190    }
191
192    /// Convert frequency to bin index
193    pub fn freq_to_bin(&self, freq: T) -> T {
194        if self.modified {
195            freq * T::from_usize(self.fft_samples).unwrap() - T::from_f32(0.5).unwrap()
196        } else {
197            freq * T::from_usize(self.fft_samples).unwrap()
198        }
199    }
200
201    /// Reset the STFT state
202    pub fn reset(&mut self, product_weight: T) {
203        self.input_pos = self.block_samples;
204        self.output_pos = 0;
205
206        // Clear buffers
207        for v in &mut self.input_buffer {
208            *v = T::zero();
209        }
210        for v in &mut self.output_buffer {
211            *v = T::zero();
212        }
213        for v in &mut self.spectrum_buffer {
214            *v = Complex::new(T::zero(), T::zero());
215        }
216        for v in &mut self.window_products {
217            *v = T::zero();
218        }
219
220        // Initialize window products
221        self.add_window_product();
222
223        // Accumulate window products for overlapping windows
224        for i in (0..self.block_samples - self.default_interval).rev() {
225            self.window_products[i] = self.window_products[i] + self.window_products[i + self.default_interval];
226        }
227
228        // Scale window products
229        for v in &mut self.window_products {
230            *v = *v * product_weight + self.almost_zero;
231        }
232
233        // Move output position to be ready for first block
234        self.move_output(self.default_interval);
235    }
236
237    /// Reset the STFT state with default product weight
238    pub fn reset_default(&mut self) {
239        self.reset(T::one());
240    }
241
242    /// Write input samples to a specific channel
243    pub fn write_input(&mut self, channel: usize, offset: usize, length: usize, input_array: &[T]) {
244        assert!(channel < self.analysis_channels, "Channel index out of bounds");
245        assert!(offset + length <= input_array.len(), "Input array too small");
246
247        let buffer_start = channel * self.input_length_samples;
248        let offset_pos = (self.input_pos + offset) % self.input_length_samples;
249
250        // Handle wrapping around the circular buffer
251        let input_wrap_index = self.input_length_samples - offset_pos;
252        let chunk1 = length.min(input_wrap_index);
253
254        // Copy first chunk (before wrap)
255        for i in 0..chunk1 {
256            let buffer_index = buffer_start + offset_pos + i;
257            self.input_buffer[buffer_index] = input_array[i];
258        }
259
260        // Copy second chunk (after wrap)
261        for i in chunk1..length {
262            let buffer_index = buffer_start + i + offset_pos - self.input_length_samples;
263            self.input_buffer[buffer_index] = input_array[i];
264        }
265    }
266
267    /// Write input samples to a specific channel (without offset)
268    pub fn write_input_simple(&mut self, channel: usize, input_array: &[T]) {
269        self.write_input(channel, 0, input_array.len(), input_array);
270    }
271
272    /// Read output samples from a specific channel
273    pub fn read_output(&self, channel: usize, offset: usize, length: usize, output_array: &mut [T]) {
274        assert!(channel < self.synthesis_channels, "Channel index out of bounds");
275        assert!(offset + length <= output_array.len(), "Output array too small");
276
277        let buffer_start = channel * self.block_samples;
278        let offset_pos = (self.output_pos + offset) % self.block_samples;
279
280        // Handle wrapping around the circular buffer
281        let output_wrap_index = self.block_samples - offset_pos;
282        let chunk1 = length.min(output_wrap_index);
283
284        // Copy first chunk (before wrap)
285        for i in 0..chunk1 {
286            let buffer_index = buffer_start + offset_pos + i;
287            output_array[i] = self.output_buffer[buffer_index];
288        }
289
290        // Copy second chunk (after wrap)
291        for i in chunk1..length {
292            let buffer_index = buffer_start + i + offset_pos - self.block_samples;
293            output_array[i] = self.output_buffer[buffer_index];
294        }
295    }
296
297    /// Read output samples from a specific channel (without offset)
298    pub fn read_output_simple(&self, channel: usize, output_array: &mut [T]) {
299        self.read_output(channel, 0, output_array.len(), output_array);
300    }
301
302    /// Move the input position
303    pub fn move_input(&mut self, samples: usize) {
304        self.input_pos = (self.input_pos + samples) % self.input_length_samples;
305    }
306
307    /// Move the output position
308    pub fn move_output(&mut self, samples: usize) {
309        self.output_pos = (self.output_pos + samples) % self.block_samples;
310    }
311
312    /// Set the interval between blocks and update windows
313    pub fn set_interval(&mut self, interval: usize, window_shape: WindowShape) {
314        self.default_interval = interval;
315
316        // Set window offsets
317        self.analysis_offset = self.block_samples / 2;
318        self.synthesis_offset = self.block_samples / 2;
319
320        // Create windows
321        match window_shape {
322            WindowShape::Ignore => {
323                // Rectangular window
324                for i in 0..self.block_samples {
325                    self.analysis_window[i] = T::one();
326                    self.synthesis_window[i] = T::one();
327                }
328            },
329            WindowShape::ACG => {
330                // Approximate Confined Gaussian window
331                let acg = windows::ApproximateConfinedGaussian::with_bandwidth(T::from_f32(2.5).unwrap());
332                acg.fill(self.analysis_window.as_mut_slice());
333                acg.fill(self.synthesis_window.as_mut_slice());
334            },
335            WindowShape::Kaiser => {
336                // Kaiser window
337                let kaiser = windows::Kaiser::with_bandwidth(T::from_f32(2.5).unwrap(), true);
338                kaiser.fill(self.analysis_window.as_mut_slice());
339                kaiser.fill(self.synthesis_window.as_mut_slice());
340            },
341        }
342
343        // Force perfect reconstruction
344        windows::force_perfect_reconstruction(&mut self.synthesis_window, self.block_samples, interval);
345    }
346
347    /// Add window product to the accumulation buffer
348    fn add_window_product(&mut self) {
349        for i in 0..self.block_samples {
350            self.window_products[i] += self.analysis_window[i] * self.synthesis_window[i];
351        }
352    }
353
354    /// Process a block of input samples to produce a spectrum
355    pub fn process_block_to_spectrum(&mut self, channel: usize) -> &[Complex<T>] {
356        assert!(channel < self.analysis_channels, "Channel index out of bounds");
357
358        // Copy input to time buffer with analysis window applied
359        let buffer_start = channel * self.input_length_samples;
360        for i in 0..self.block_samples {
361            let input_index = (self.input_pos + self.block_samples - self.analysis_offset + i) % self.input_length_samples;
362            self.time_buffer[i] = self.input_buffer[buffer_start + input_index] * self.analysis_window[i];
363        }
364
365        // Zero-pad the rest of the FFT buffer
366        for i in self.block_samples..self.fft_samples {
367            self.time_buffer[i] = T::zero();
368        }
369
370        // Perform FFT
371        let spectrum_start = channel * self.fft_bins;
372        let spectrum_slice = &mut self.spectrum_buffer[spectrum_start..spectrum_start + self.fft_bins];
373        self.fft.fft(&self.time_buffer, spectrum_slice);
374
375        // Return the spectrum for this channel
376        &self.spectrum_buffer[spectrum_start..spectrum_start + self.fft_bins]
377    }
378
379    /// Process a spectrum to produce a block of output samples
380    pub fn process_spectrum_to_block(&mut self, channel: usize, spectrum: &[Complex<T>]) {
381        assert!(channel < self.synthesis_channels, "Channel index out of bounds");
382        assert!(spectrum.len() >= self.fft_bins, "Spectrum too small");
383
384        // Perform inverse FFT
385        self.fft.ifft(spectrum, &mut self.time_buffer);
386
387        // Apply synthesis window and add to output buffer
388        let buffer_start = channel * self.block_samples;
389        for i in 0..self.block_samples {
390            // Calculate output index with proper circular buffer handling
391            let output_index = (self.output_pos + self.synthesis_offset + i) % self.block_samples;
392            let window_product = self.window_products[i];
393            let value = self.time_buffer[i] * self.synthesis_window[i] / window_product;
394            self.output_buffer[buffer_start + output_index] += value;
395        }
396    }
397
398    /// Process a block of input samples directly to output
399    pub fn process_block(&mut self, in_channel: usize, out_channel: usize) {
400        // Process input to spectrum
401        let spectrum = self.process_block_to_spectrum(in_channel);
402
403        // Copy spectrum to avoid borrowing issues
404        let spectrum_copy = spectrum.to_vec();
405
406        // Process spectrum to output
407        self.process_spectrum_to_block(out_channel, &spectrum_copy);
408    }
409
410    /// Process multiple channels at once
411    pub fn process_channels(&mut self, in_channels: &[usize], out_channels: &[usize]) {
412        assert!(in_channels.len() == out_channels.len(), "Channel arrays must have the same length");
413
414        for (in_ch, out_ch) in in_channels.iter().zip(out_channels.iter()) {
415            self.process_block(*in_ch, *out_ch);
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// A power-of-two block reports matching block, FFT, hop and band sizes.
    #[test]
    fn test_stft_configuration() {
        const BLOCK: usize = 1024;
        const HOP: usize = 256;

        let mut stft = STFT::<f32>::new(false);
        stft.configure(2, 2, BLOCK, 0, HOP);

        assert_eq!(stft.block_samples(), BLOCK);
        assert_eq!(stft.fft_samples(), BLOCK);
        assert_eq!(stft.default_interval(), HOP);
        assert_eq!(stft.bands(), BLOCK / 2 + 1); // N/2 + 1 for real FFT
    }

    /// An impulse followed by silence produces its largest output sample at
    /// the expected position of the read-out.
    #[test]
    fn test_stft_io() {
        const BLOCK: usize = 16;
        const HOP: usize = 4;

        let mut stft = STFT::<f32>::new(false);
        stft.configure(1, 1, BLOCK, 0, HOP);

        // First block: a unit impulse at sample zero
        let mut first_block = [0.0f32; BLOCK];
        first_block[0] = 1.0;
        stft.write_input_simple(0, &first_block);
        stft.process_block(0, 0);

        // Three further hops of silence, each followed by a processed block
        let silence = [0.0f32; HOP];
        for _ in 0..3 {
            stft.move_input(HOP);
            stft.write_input(0, 0, HOP, &silence);
            stft.process_block(0, 0);
        }

        // Read one block of output
        let mut output = [0.0f32; BLOCK];
        stft.read_output_simple(0, &mut output);

        // Locate the largest output sample
        let (max_index, _) = output
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .unwrap();

        // The largest sample is expected at index 4 of the read-out
        assert_eq!(max_index, 4);
    }

    /// Converting a bin to a frequency and back returns the original bin.
    #[test]
    fn test_stft_frequency_conversion() {
        let mut stft = STFT::<f32>::new(false);
        stft.configure(1, 1, 1024, 0, 256);

        let bin = 100.0_f32;
        let round_trip = stft.freq_to_bin(stft.bin_to_freq(bin));

        assert!((bin - round_trip).abs() < 1e-10);
    }
}