cute_dsp/
spectral.rs

1//! Spectral Processing
2//!
3//! This module provides tools for frequency-domain manipulation of audio signals.
4
5#![allow(unused_imports)]
6
7#[cfg(feature = "std")]
8use std::{f32::consts::PI, vec::Vec, marker::PhantomData};
9
10#[cfg(feature = "std")]
11use std::ops::AddAssign;
12
13#[cfg(not(feature = "std"))]
14use core::{f32::consts::PI, marker::PhantomData};
15
16#[cfg(not(feature = "std"))]
17use core::ops::AddAssign;
18
19#[cfg(all(not(feature = "std"), feature = "alloc"))]
20use alloc::vec::Vec;
21
22
23use num_complex::Complex;
24use num_traits::{Float, FromPrimitive};
25
26use crate::fft;
27use crate::windows;
28use crate::delay;
29use crate::perf;
30
31/// An FFT with built-in windowing and round-trip scaling
32///
33/// This uses a Modified Real FFT, which applies half-bin shift before the transform.
34/// The result therefore has `N/2` bins, centred at the frequencies: `(i + 0.5)/N`.
35///
36/// This avoids the awkward (real-valued) bands for DC-offset and Nyquist.
37pub struct WindowedFFT<T: Float> {
38    fft: fft::Pow2RealFFT<T>,
39    fft_window: Vec<T>,
40    time_buffer: Vec<T>,
41    offset_samples: usize,
42}
43
44impl<T: Float + num_traits::FloatConst + num_traits::FromPrimitive> WindowedFFT<T> {
45    /// Create a new WindowedFFT with the specified size
46    pub fn new(size: usize, rotate_samples: usize) -> Self {
47        let mut result = Self {
48            fft: fft::Pow2RealFFT::new(0),
49            fft_window: Vec::new(),
50            time_buffer: Vec::new(),
51            offset_samples: 0,
52        };
53        result.set_size(size, rotate_samples);
54        result
55    }
56    
57    /// Create a new WindowedFFT with a custom window function
58    pub fn with_window<F>(size: usize, window_fn: F, window_offset: T, rotate_samples: usize) -> Self
59    where
60        F: Fn(T) -> T,
61    {
62        let mut result = Self {
63            fft: fft::Pow2RealFFT::new(0),
64            fft_window: Vec::new(),
65            time_buffer: Vec::new(),
66            offset_samples: 0,
67        };
68        result.set_size_with_window(size, window_fn, window_offset, rotate_samples);
69        result
70    }
71    
72    /// Returns a fast FFT size >= `size`
73    pub fn fast_size_above(size: usize, divisor: usize) -> usize {
74        // Find the next power of 2 >= size/divisor, then multiply by divisor
75        let target = (size + divisor - 1) / divisor; // Ceiling division
76        let mut result = 1;
77        while result < target {
78            result *= 2;
79        }
80        result * divisor
81    }
82    
83    /// Returns a fast FFT size <= `size`
84    pub fn fast_size_below(size: usize, divisor: usize) -> usize {
85        // Find the largest power of 2 <= size/divisor, then multiply by divisor
86        let target = size / divisor;
87        let mut result = 1;
88        while result * 2 <= target {
89            result *= 2;
90        }
91        result * divisor
92    }
93    
94    /// Sets the size, returning the window for modification (initially all 1s)
95    pub fn set_size_window(&mut self, size: usize, rotate_samples: usize) -> &mut Vec<T> {
96        self.fft.resize(size);
97        self.fft_window = vec![T::one(); size];
98        self.time_buffer.resize(size, T::zero());
99        self.offset_samples = rotate_samples % size;
100        self.fft_window.as_mut()
101    }
102    
103    /// Sets the FFT size, with a user-defined function for the window
104    pub fn set_size_with_window<F>(&mut self, size: usize, window_fn: F, window_offset: T, rotate_samples: usize)
105    where
106        F: Fn(T) -> T,
107    {
108        self.set_size_window(size, rotate_samples);
109        
110        let inv_size = T::from_f32(1.0).unwrap() / T::from_f32(size as f32).unwrap();
111        for i in 0..size {
112            let r = (T::from_f32(i as f32).unwrap() + window_offset) * inv_size;
113            self.fft_window[i] = window_fn(r);
114        }
115    }
116    
117    /// Sets the size (using the default Blackman-Harris window)
118    pub fn set_size(&mut self, size: usize, rotate_samples: usize) {
119        self.set_size_with_window(
120            size,
121            |x| {
122                let phase = T::PI() * T::from_f32(2.0).unwrap() * x;
123                // Blackman-Harris
124                T::from_f32(0.35875).unwrap() -
125                T::from_f32(0.48829).unwrap() * phase.cos() +
126                T::from_f32(0.14128).unwrap() * (phase * T::from_f32(2.0).unwrap()).cos() -
127                T::from_f32(0.01168).unwrap() * (phase * T::from_f32(3.0).unwrap()).cos()
128            },
129            T::from_f32(0.5).unwrap(),
130            rotate_samples,
131        );
132    }
133    
134    /// Get a reference to the window
135    pub fn window(&self) -> &[T] {
136        &self.fft_window
137    }
138    
139    /// Get the FFT size
140    pub fn size(&self) -> usize {
141        self.fft_window.len()
142    }
143    
144    /// Performs an FFT, with windowing and rotation (if enabled)
145    pub fn fft<I, O>(&mut self, input: I, output: &mut [O], with_window: bool, with_scaling: bool)
146    where
147        I: AsRef<[T]>,
148        O: From<Complex<T>> + Copy,
149    {
150        let input = input.as_ref();
151        let fft_size = self.size();
152        let norm = if with_scaling {
153            T::from_f32(1.0).unwrap() / T::from_f32(fft_size as f32).unwrap()
154        } else {
155            T::one()
156        };
157        
158        // Apply window and handle rotation
159        for i in 0..self.offset_samples {
160            // Inverted polarity since we're using the Modified Real FFT
161            self.time_buffer[i + fft_size - self.offset_samples] = 
162                -input[i] * norm * if with_window { self.fft_window[i] } else { T::one() };
163        }
164        for i in self.offset_samples..fft_size {
165            self.time_buffer[i - self.offset_samples] = 
166                input[i] * norm * if with_window { self.fft_window[i] } else { T::one() };
167        }
168        
169        // Perform FFT
170        let mut complex_output = vec![Complex::new(T::zero(), T::zero()); fft_size / 2 + 1];
171        self.fft.fft(&self.time_buffer, &mut complex_output);
172        
173        // Copy to output
174        for i in 0..complex_output.len() {
175            output[i] = complex_output[i].into();
176        }
177    }
178    
179    /// Performs an inverse FFT, with windowing and rotation (if enabled)
180    pub fn ifft<I, O>(&mut self, input: &[I], mut output: O, with_window: bool)
181    where
182        I: Copy + Into<Complex<T>>,
183        O: AsMut<[T]>,
184    {
185        let output = output.as_mut();
186        let fft_size = self.size();
187        
188        // Convert input to complex
189        let mut complex_input = vec![Complex::new(T::zero(), T::zero()); fft_size / 2 + 1];
190        for i in 0..complex_input.len() {
191            complex_input[i] = input[i].into();
192        }
193        
194        // Perform inverse FFT
195        self.fft.ifft(&complex_input, &mut self.time_buffer);
196        
197        // Apply window and handle rotation
198        for i in 0..self.offset_samples {
199            output[i] = self.time_buffer[i + fft_size - self.offset_samples] * 
200                if with_window { self.fft_window[i] } else { T::one() };
201        }
202        for i in self.offset_samples..fft_size {
203            output[i] = self.time_buffer[i - self.offset_samples] * 
204                if with_window { self.fft_window[i] } else { T::one() };
205        }
206    }
207}
208
209/// A processor for spectral manipulation of audio
210pub struct SpectralProcessor<T: Float> {
211    fft: WindowedFFT<T>,
212    overlap: usize,
213    hop_size: usize,
214    input_buffer: Vec<T>,
215    output_buffer: Vec<T>,
216    spectrum: Vec<Complex<T>>,
217    window_sum: Vec<T>,
218    steady_state: Vec<T>, // Added for steady-state normalization
219}
220
221
222
223impl<T: Float + AddAssign + num_traits::FloatConst + FromPrimitive> SpectralProcessor<T> {
224    /// Create a new SpectralProcessor with the specified parameters
225    pub fn new(fft_size: usize, overlap: usize) -> Self {
226        let mut result = Self {
227            fft: WindowedFFT::new(fft_size, 0),
228            overlap,
229            hop_size: fft_size / overlap,
230            input_buffer: Vec::new(),
231            output_buffer: Vec::new(),
232            spectrum: Vec::new(),
233            window_sum: Vec::new(),
234            steady_state: Vec::new(), // Added
235        };
236        result.reset();
237        result
238    }
239
240    /// Reset the processor state
241    pub fn reset(&mut self) {
242        let fft_size = self.fft.size();
243        self.input_buffer.resize(fft_size, T::zero());
244        self.output_buffer.resize(fft_size, T::zero());
245        self.spectrum.resize(fft_size / 2 + 1, Complex::new(T::zero(), T::zero()));
246
247        // Calculate window sum for normalization (for each absolute sample in the first fft_size samples)
248        self.window_sum = vec![T::zero(); fft_size];
249        for i in 0..self.overlap {
250            let hop = self.hop_size;
251            for j in 0..fft_size {
252                let absolute_index = i * hop + j;
253                if absolute_index < fft_size {
254                    let win_val = self.fft.window()[j];
255                    self.window_sum[absolute_index] += win_val * win_val;
256                }
257            }
258        }
259
260        // Precompute steady-state normalization factors for remainders
261        self.steady_state = vec![T::zero(); self.hop_size];
262        for r in 0..self.hop_size {
263            let mut offset = r;
264            while offset < fft_size {
265                let win_val = self.fft.window()[offset];
266                self.steady_state[r] += win_val * win_val;
267                offset += self.hop_size;
268            }
269            // Avoid division by zero in steady-state
270            if self.steady_state[r] < T::from_f32(1e-10).unwrap() {
271                self.steady_state[r] = T::one();
272            }
273        }
274
275        // Avoid division by zero in window_sum
276        for value in self.window_sum.iter_mut() {
277            if *value < T::from_f32(1e-10).unwrap() {
278                *value = T::one();
279            }
280        }
281    }
282    
283    /// Get the FFT size
284    pub fn fft_size(&self) -> usize {
285        self.fft.size()
286    }
287    
288    /// Get the hop size (distance between consecutive frames)
289    pub fn hop_size(&self) -> usize {
290        self.hop_size
291    }
292    
293    /// Get the overlap factor
294    pub fn overlap(&self) -> usize {
295        self.overlap
296    }
297    
298    /// Process a block of input samples with a spectral processing function
299    pub fn process<F>(&mut self, input: &[T], output: &mut [T], processor: F)
300    where
301        F: FnMut(&mut [Complex<T>]),
302    {
303        self.process_with_options(input, output, processor, true, true);
304    }
305
306    /// Process a block of input samples with a spectral processing function and options
307    pub fn process_with_options<F>(
308        &mut self,
309        input: &[T],
310        output: &mut [T],
311        mut processor: F,
312        with_window: bool,
313        with_scaling: bool,
314    )
315    where
316        F: FnMut(&mut [Complex<T>]),
317    {
318        let fft_size = self.fft.size();
319        let input_len = input.len();
320        let output_len = output.len();
321
322        // Process in overlapping blocks
323        for i in (0..input_len).step_by(self.hop_size) {
324            // Copy input to buffer with bounds checking
325            let copy_len = (input_len - i).min(fft_size);
326            self.input_buffer[..copy_len].copy_from_slice(&input[i..i + copy_len]);
327            self.input_buffer[copy_len..].fill(T::zero());
328
329            // Perform FFT
330            self.fft.fft(&self.input_buffer, &mut self.spectrum, with_window, with_scaling);
331
332            // Apply spectral processing
333            processor(&mut self.spectrum);
334
335            // Perform inverse FFT
336            self.fft.ifft(&self.spectrum, &mut self.output_buffer, with_window);
337
338            // Overlap-add to output with safe bounds checking
339            let output_offset = i;
340            let add_len = (output_len.saturating_sub(output_offset)).min(fft_size);
341            for j in 0..add_len {
342                let abs_index = output_offset + j;
343                let norm_factor = if abs_index < fft_size {
344                    // Use exact normalization factor for initial samples
345                    self.window_sum[abs_index]
346                } else {
347                    // Use steady-state factor for periodic part
348                    let r = abs_index % self.hop_size;
349                    self.steady_state[r]
350                };
351                output[abs_index] += self.output_buffer[j] / norm_factor;
352            }
353        }
354    }
355}
356
357/// Utility functions for spectral processing
358pub mod utils {
359    use super::*;
360    
361    /// Convert magnitude and phase to complex
362    pub fn mag_phase_to_complex<T: Float>(mag: T, phase: T) -> Complex<T> {
363        Complex::new(mag * phase.cos(), mag * phase.sin())
364    }
365    
366    /// Convert complex to magnitude and phase
367    pub fn complex_to_mag_phase<T: Float>(complex: Complex<T>) -> (T, T) {
368        (complex.norm(), complex.arg())
369    }
370    
371    /// Convert linear magnitude to decibels
372    pub fn linear_to_db<T: Float>(linear: T) -> T {
373        T::from(20.0).unwrap() * linear.log10()
374    }
375    
376    /// Convert decibels to linear magnitude
377    pub fn db_to_linear<T: Float>(db: T) -> T {
378        T::from(10.0).unwrap().powf(db / T::from(20.0).unwrap())
379    }
380    
381    /// Apply a gain to a spectrum (in decibels)
382    pub fn apply_gain<T: Float>(spectrum: &mut [Complex<T>], gain_db: T) {
383        let gain_linear = db_to_linear(gain_db);
384        for bin in spectrum {
385            *bin = *bin * gain_linear;
386        }
387    }
388    
389    /// Apply a phase shift to a spectrum (in radians)
390    pub fn apply_phase_shift<T: Float>(spectrum: &mut [Complex<T>], phase_shift: T) {
391        for bin in spectrum {
392            let (mag, phase) = complex_to_mag_phase(*bin);
393            *bin = mag_phase_to_complex(mag, phase + phase_shift);
394        }
395    }
396    
397    /// Apply a time shift to a spectrum
398    pub fn apply_time_shift<T: Float>(spectrum: &mut [Complex<T>], time_shift: T, sample_rate: T) {
399        let fft_size = spectrum.len() * 2 - 2;
400        let bin_width = sample_rate / T::from(fft_size as f32).unwrap();
401
402        for (i, bin) in spectrum.iter_mut().enumerate() {
403            let freq = T::from(i as f32).unwrap() * bin_width;
404            let phase_shift = T::from(2.0 * PI).unwrap() * freq * time_shift;
405            let (mag, phase) = complex_to_mag_phase(*bin);
406            *bin = mag_phase_to_complex(mag, phase + phase_shift);
407        }
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_windowed_fft() {
        const SIZE: usize = 1024;
        let mut fft = WindowedFFT::<f32>::new(SIZE, 0);

        // Sinusoid advancing 0.1 radians per sample.
        let input: Vec<f32> = (0..SIZE).map(|n| (n as f32 * 0.1).sin()).collect();

        // A real FFT of SIZE samples yields SIZE/2 + 1 bins.
        let mut spectrum = vec![Complex::new(0.0f32, 0.0); SIZE / 2 + 1];
        fft.fft(&input, &mut spectrum, true, true);

        // 0.1 rad/sample corresponds to 0.1 * 1024 / (2π) ≈ 16.3 bins.
        let (peak_bin, _) = spectrum
            .iter()
            .map(|bin| bin.norm())
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        assert!((16..=18).contains(&peak_bin));
    }

    #[test]
    fn test_spectral_processor() {
        let mut processor = SpectralProcessor::<f32>::new(1024, 4);

        let input: Vec<f32> = (0..2048).map(|n| (n as f32 * 0.1).sin()).collect();
        let mut output = vec![0.0f32; input.len()];

        // An identity spectral callback should reconstruct the input.
        processor.process(&input, &mut output, |_| {});

        // Only the interior is checked; the edges lack full frame coverage.
        let interior = 512..1536;
        assert!(input[interior.clone()]
            .iter()
            .zip(&output[interior])
            .all(|(expected, actual)| (expected - actual).abs() < 0.1));
    }

    #[test]
    fn test_spectral_utils() {
        // Polar round trip.
        let original = Complex::new(3.0, 4.0);
        let (mag, phase) = utils::complex_to_mag_phase(original);
        let rebuilt = utils::mag_phase_to_complex(mag, phase);
        assert!((original.re - rebuilt.re).abs() < 1e-10);
        assert!((original.im - rebuilt.im).abs() < 1e-10);

        // Decibel round trip.
        let linear = 10.0;
        let round_trip = utils::db_to_linear(utils::linear_to_db(linear));
        assert!((linear - round_trip).abs() < 1e-10);
    }
}