Skip to main content

ringkernel_audio_fft/
fft.rs

1//! FFT and IFFT processing utilities.
2//!
3//! This module provides FFT/IFFT processing using rustfft, with support for
4//! overlap-add processing and various window functions.
5
6use std::sync::Arc;
7
8use num_complex::Complex as NumComplex;
9use rustfft::{Fft, FftPlanner};
10
11use crate::error::{AudioFftError, Result};
12use crate::messages::Complex;
13
/// Window function types for FFT analysis.
///
/// The cosine-sum windows (Hann, Hamming, Blackman, Blackman-Harris) are
/// generated with a denominator of `size`, whereas the Kaiser window is
/// centred on `(size - 1) / 2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFunction {
    /// Rectangular (no window).
    Rectangular,
    /// Hann window (cosine-squared).
    Hann,
    /// Hamming window.
    Hamming,
    /// Blackman window.
    Blackman,
    /// Blackman-Harris window.
    BlackmanHarris,
    /// Kaiser window. The payload is `beta * 10`, stored as an integer so the
    /// enum can stay `Eq` (e.g. `Kaiser(50)` means `beta = 5.0`).
    Kaiser(u8),
}

impl WindowFunction {
    /// Generate window coefficients.
    ///
    /// Returns a vector of exactly `size` coefficients; `size == 0` yields an
    /// empty vector. A Kaiser window of length 1 is `[1.0]` (its only sample
    /// is the centre of the window).
    pub fn generate(&self, size: usize) -> Vec<f32> {
        use std::f32::consts::PI;

        let n = size as f32;
        // cos(harmonic * pi * i / n): the shared building block of every
        // cosine-sum window below.
        let cos_term = |harmonic: f32, i: usize| (harmonic * PI * i as f32 / n).cos();

        match *self {
            Self::Rectangular => vec![1.0; size],
            Self::Hann => (0..size)
                .map(|i| 0.5 * (1.0 - cos_term(2.0, i)))
                .collect(),
            Self::Hamming => (0..size)
                .map(|i| 0.54 - 0.46 * cos_term(2.0, i))
                .collect(),
            Self::Blackman => {
                let (a0, a1, a2) = (0.42, 0.5, 0.08);
                (0..size)
                    .map(|i| a0 - a1 * cos_term(2.0, i) + a2 * cos_term(4.0, i))
                    .collect()
            }
            Self::BlackmanHarris => {
                let (a0, a1, a2, a3) = (0.35875, 0.48829, 0.14128, 0.01168);
                (0..size)
                    .map(|i| {
                        a0 - a1 * cos_term(2.0, i) + a2 * cos_term(4.0, i)
                            - a3 * cos_term(6.0, i)
                    })
                    .collect()
            }
            Self::Kaiser(beta_10) => {
                // With a single sample the half-width `alpha` is zero and the
                // normalised position would be 0/0 (NaN).
                if size <= 1 {
                    return vec![1.0; size];
                }

                let beta = f32::from(beta_10) / 10.0;
                let alpha = (n - 1.0) / 2.0;
                // The denominator does not depend on the sample index.
                let denominator = bessel_i0(beta);

                (0..size)
                    .map(|i| {
                        let r = (i as f32 - alpha) / alpha;
                        // Clamp so rounding can never push the radicand
                        // slightly below zero at the window edges.
                        let radicand = (1.0 - r * r).max(0.0);
                        bessel_i0(beta * radicand.sqrt()) / denominator
                    })
                    .collect()
            }
        }
    }

    /// Get the coherent gain for this window.
    ///
    /// For the cosine-sum windows this is the constant (`a0`) coefficient.
    /// The Kaiser value is a fixed approximation that ignores `beta`.
    pub fn coherent_gain(&self) -> f32 {
        match self {
            Self::Rectangular => 1.0,
            Self::Hann => 0.5,
            Self::Hamming => 0.54,
            Self::Blackman => 0.42,
            Self::BlackmanHarris => 0.35875,
            Self::Kaiser(_) => 0.5, // Approximate
        }
    }
}

/// Bessel I0 function for Kaiser window.
///
/// Evaluates the power series `sum_k ((x^2 / 4)^k / (k!)^2)` and stops once a
/// term no longer changes the running sum at `f32` precision. The relative
/// stopping rule keeps the result accurate for the largest `beta` the Kaiser
/// variant can express (25.5), where a short fixed-length series is not yet
/// converged.
fn bessel_i0(x: f32) -> f32 {
    // Safety net only; the series converges far earlier for every finite
    // argument that does not overflow `f32`.
    const MAX_TERMS: u16 = 256;

    let quarter_x_sq = x * x / 4.0;
    let mut sum = 1.0f32;
    let mut term = 1.0f32;

    for k in 1..=MAX_TERMS {
        let k = f32::from(k);
        term *= quarter_x_sq / (k * k);
        sum += term;
        if term <= sum * f32::EPSILON {
            break;
        }
    }
    sum
}
97
98/// FFT processor for time-to-frequency conversion.
99pub struct FftProcessor {
100    /// FFT size (must be power of 2).
101    fft_size: usize,
102    /// Hop size for overlap-add.
103    hop_size: usize,
104    /// Sample rate in Hz.
105    sample_rate: u32,
106    /// Window function (stored for potential reconfiguration).
107    #[allow(dead_code)]
108    window: WindowFunction,
109    /// Pre-computed window coefficients.
110    window_coeffs: Vec<f32>,
111    /// FFT planner.
112    fft: Arc<dyn Fft<f32>>,
113    /// Scratch buffer.
114    scratch: Vec<NumComplex<f32>>,
115    /// Input buffer for overlap.
116    input_buffer: Vec<f32>,
117}
118
119impl FftProcessor {
120    /// Create a new FFT processor.
121    pub fn new(fft_size: usize, hop_size: usize, sample_rate: u32) -> Result<Self> {
122        Self::with_window(fft_size, hop_size, sample_rate, WindowFunction::Hann)
123    }
124
125    /// Create with a specific window function.
126    pub fn with_window(
127        fft_size: usize,
128        hop_size: usize,
129        sample_rate: u32,
130        window: WindowFunction,
131    ) -> Result<Self> {
132        if !fft_size.is_power_of_two() {
133            return Err(AudioFftError::config(format!(
134                "FFT size must be power of 2, got {}",
135                fft_size
136            )));
137        }
138
139        if hop_size > fft_size {
140            return Err(AudioFftError::config(format!(
141                "Hop size {} cannot exceed FFT size {}",
142                hop_size, fft_size
143            )));
144        }
145
146        let mut planner = FftPlanner::new();
147        let fft = planner.plan_fft_forward(fft_size);
148        let scratch_len = fft.get_inplace_scratch_len();
149
150        Ok(Self {
151            fft_size,
152            hop_size,
153            sample_rate,
154            window,
155            window_coeffs: window.generate(fft_size),
156            fft,
157            scratch: vec![NumComplex::default(); scratch_len],
158            input_buffer: Vec::with_capacity(fft_size * 2),
159        })
160    }
161
162    /// Get the FFT size.
163    pub fn fft_size(&self) -> usize {
164        self.fft_size
165    }
166
167    /// Get the hop size.
168    pub fn hop_size(&self) -> usize {
169        self.hop_size
170    }
171
172    /// Get the number of frequency bins (positive frequencies only).
173    pub fn num_bins(&self) -> usize {
174        self.fft_size / 2 + 1
175    }
176
177    /// Get the frequency in Hz for a given bin.
178    pub fn bin_to_frequency(&self, bin: usize) -> f32 {
179        bin as f32 * self.sample_rate as f32 / self.fft_size as f32
180    }
181
182    /// Get the bin index for a given frequency.
183    pub fn frequency_to_bin(&self, freq: f32) -> usize {
184        (freq * self.fft_size as f32 / self.sample_rate as f32).round() as usize
185    }
186
187    /// Process a frame of audio and return FFT bins.
188    pub fn process_frame(&mut self, samples: &[f32]) -> Vec<Complex> {
189        // Add samples to input buffer
190        self.input_buffer.extend_from_slice(samples);
191
192        // Check if we have enough for a frame
193        if self.input_buffer.len() < self.fft_size {
194            return Vec::new();
195        }
196
197        // Take the FFT frame
198        let mut buffer: Vec<NumComplex<f32>> = self.input_buffer[..self.fft_size]
199            .iter()
200            .enumerate()
201            .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
202            .collect();
203
204        // Perform FFT
205        self.fft
206            .process_with_scratch(&mut buffer, &mut self.scratch);
207
208        // Remove processed samples (hop)
209        self.input_buffer.drain(..self.hop_size);
210
211        // Convert to our Complex type (positive frequencies only)
212        buffer[..self.num_bins()]
213            .iter()
214            .map(|c| Complex::new(c.re, c.im))
215            .collect()
216    }
217
218    /// Process all available frames in the buffer.
219    pub fn process_all(&mut self, samples: &[f32]) -> Vec<Vec<Complex>> {
220        self.input_buffer.extend_from_slice(samples);
221
222        let mut frames = Vec::new();
223
224        while self.input_buffer.len() >= self.fft_size {
225            // Take the FFT frame
226            let mut buffer: Vec<NumComplex<f32>> = self.input_buffer[..self.fft_size]
227                .iter()
228                .enumerate()
229                .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
230                .collect();
231
232            // Perform FFT
233            self.fft
234                .process_with_scratch(&mut buffer, &mut self.scratch);
235
236            // Remove processed samples (hop)
237            self.input_buffer.drain(..self.hop_size);
238
239            // Convert to our Complex type
240            frames.push(
241                buffer[..self.num_bins()]
242                    .iter()
243                    .map(|c| Complex::new(c.re, c.im))
244                    .collect(),
245            );
246        }
247
248        frames
249    }
250
251    /// Flush remaining samples (zero-pad if necessary).
252    pub fn flush(&mut self) -> Option<Vec<Complex>> {
253        if self.input_buffer.is_empty() {
254            return None;
255        }
256
257        // Zero-pad to FFT size
258        self.input_buffer.resize(self.fft_size, 0.0);
259
260        let mut buffer: Vec<NumComplex<f32>> = self
261            .input_buffer
262            .iter()
263            .enumerate()
264            .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
265            .collect();
266
267        self.fft
268            .process_with_scratch(&mut buffer, &mut self.scratch);
269        self.input_buffer.clear();
270
271        Some(
272            buffer[..self.num_bins()]
273                .iter()
274                .map(|c| Complex::new(c.re, c.im))
275                .collect(),
276        )
277    }
278
279    /// Reset the processor state.
280    pub fn reset(&mut self) {
281        self.input_buffer.clear();
282    }
283}
284
285/// IFFT processor for frequency-to-time conversion.
286pub struct IfftProcessor {
287    /// FFT size.
288    fft_size: usize,
289    /// Hop size.
290    hop_size: usize,
291    /// IFFT planner.
292    ifft: Arc<dyn Fft<f32>>,
293    /// Scratch buffer.
294    scratch: Vec<NumComplex<f32>>,
295    /// Synthesis window.
296    synthesis_window: Vec<f32>,
297    /// Output buffer for overlap-add.
298    output_buffer: Vec<f32>,
299    /// Normalization factor.
300    norm_factor: f32,
301}
302
303impl IfftProcessor {
304    /// Create a new IFFT processor.
305    pub fn new(fft_size: usize, hop_size: usize) -> Result<Self> {
306        Self::with_window(fft_size, hop_size, WindowFunction::Hann)
307    }
308
309    /// Create with a specific synthesis window.
310    pub fn with_window(fft_size: usize, hop_size: usize, window: WindowFunction) -> Result<Self> {
311        if !fft_size.is_power_of_two() {
312            return Err(AudioFftError::config(format!(
313                "FFT size must be power of 2, got {}",
314                fft_size
315            )));
316        }
317
318        let mut planner = FftPlanner::new();
319        let ifft = planner.plan_fft_inverse(fft_size);
320        let scratch_len = ifft.get_inplace_scratch_len();
321
322        // Calculate COLA normalization for overlap-add
323        // For Hann window with 50% overlap, sum of squared windows = 1.5
324        let window_coeffs = window.generate(fft_size);
325        let overlap_factor = fft_size / hop_size;
326
327        // Calculate the sum of squared windows at each output sample
328        let mut cola_sum = vec![0.0f32; hop_size];
329        for offset in 0..overlap_factor {
330            for (i, sum) in cola_sum.iter_mut().enumerate() {
331                let window_idx = offset * hop_size + i;
332                if window_idx < fft_size {
333                    *sum += window_coeffs[window_idx] * window_coeffs[window_idx];
334                }
335            }
336        }
337        let avg_cola = cola_sum.iter().sum::<f32>() / hop_size as f32;
338
339        Ok(Self {
340            fft_size,
341            hop_size,
342            ifft,
343            scratch: vec![NumComplex::default(); scratch_len],
344            synthesis_window: window_coeffs,
345            output_buffer: vec![0.0; fft_size * 2],
346            norm_factor: 1.0 / (fft_size as f32 * avg_cola.sqrt()),
347        })
348    }
349
350    /// Process FFT bins and return audio samples.
351    pub fn process_frame(&mut self, bins: &[Complex]) -> Vec<f32> {
352        // Reconstruct full spectrum (mirror conjugate)
353        let mut buffer: Vec<NumComplex<f32>> = Vec::with_capacity(self.fft_size);
354
355        // Positive frequencies
356        for bin in bins.iter().take(self.fft_size / 2 + 1) {
357            buffer.push(NumComplex::new(bin.re, bin.im));
358        }
359
360        // Negative frequencies (conjugate mirror)
361        for i in 1..self.fft_size / 2 {
362            let idx = self.fft_size / 2 - i;
363            if idx < bins.len() {
364                buffer.push(NumComplex::new(bins[idx].re, -bins[idx].im));
365            } else {
366                buffer.push(NumComplex::default());
367            }
368        }
369
370        // Pad if necessary
371        while buffer.len() < self.fft_size {
372            buffer.push(NumComplex::default());
373        }
374
375        // Perform IFFT
376        self.ifft
377            .process_with_scratch(&mut buffer, &mut self.scratch);
378
379        // Apply synthesis window and add to output buffer
380        for (i, c) in buffer.iter().enumerate() {
381            self.output_buffer[i] += c.re * self.synthesis_window[i] * self.norm_factor;
382        }
383
384        // Extract output samples
385        let output: Vec<f32> = self.output_buffer[..self.hop_size].to_vec();
386
387        // Shift buffer
388        self.output_buffer.copy_within(self.hop_size.., 0);
389        for i in (self.output_buffer.len() - self.hop_size)..self.output_buffer.len() {
390            self.output_buffer[i] = 0.0;
391        }
392
393        output
394    }
395
396    /// Flush remaining samples.
397    pub fn flush(&mut self) -> Vec<f32> {
398        let mut output = Vec::new();
399
400        // Drain the output buffer
401        while self.output_buffer.iter().any(|&x| x.abs() > 1e-10) {
402            output.extend_from_slice(&self.output_buffer[..self.hop_size]);
403            self.output_buffer.copy_within(self.hop_size.., 0);
404            for i in (self.output_buffer.len() - self.hop_size)..self.output_buffer.len() {
405                self.output_buffer[i] = 0.0;
406            }
407        }
408
409        output
410    }
411
412    /// Reset the processor state.
413    pub fn reset(&mut self) {
414        self.output_buffer.fill(0.0);
415    }
416}
417
418/// Helper for STFT processing with proper overlap-add.
419pub struct StftProcessor {
420    /// FFT processor.
421    pub fft: FftProcessor,
422    /// IFFT processor.
423    pub ifft: IfftProcessor,
424}
425
426impl StftProcessor {
427    /// Create a new STFT processor.
428    pub fn new(fft_size: usize, hop_size: usize, sample_rate: u32) -> Result<Self> {
429        Self::with_window(fft_size, hop_size, sample_rate, WindowFunction::Hann)
430    }
431
432    /// Create with a specific window function.
433    pub fn with_window(
434        fft_size: usize,
435        hop_size: usize,
436        sample_rate: u32,
437        window: WindowFunction,
438    ) -> Result<Self> {
439        Ok(Self {
440            fft: FftProcessor::with_window(fft_size, hop_size, sample_rate, window)?,
441            ifft: IfftProcessor::with_window(fft_size, hop_size, window)?,
442        })
443    }
444
445    /// Process samples through FFT, apply a function, and IFFT back.
446    pub fn process<F>(&mut self, samples: &[f32], mut processor: F) -> Vec<f32>
447    where
448        F: FnMut(&mut [Complex]),
449    {
450        let mut output = Vec::new();
451
452        for mut frame in self.fft.process_all(samples) {
453            processor(&mut frame);
454            output.extend(self.ifft.process_frame(&frame));
455        }
456
457        output
458    }
459
460    /// Flush remaining samples.
461    pub fn flush<F>(&mut self, mut processor: F) -> Vec<f32>
462    where
463        F: FnMut(&mut [Complex]),
464    {
465        let mut output = Vec::new();
466
467        if let Some(mut frame) = self.fft.flush() {
468            processor(&mut frame);
469            output.extend(self.ifft.process_frame(&frame));
470        }
471
472        output.extend(self.ifft.flush());
473        output
474    }
475
476    /// Reset the processor state.
477    pub fn reset(&mut self) {
478        self.fft.reset();
479        self.ifft.reset();
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tolerance` of `expected`.
    fn close(actual: f32, expected: f32, tolerance: f32) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_window_functions() {
        let len = 1024;

        // Hann starts at zero and peaks at the midpoint.
        let hann = WindowFunction::Hann.generate(len);
        assert!(close(hann[0], 0.0, 1e-6));
        assert!(close(hann[len / 2], 1.0, 1e-6));

        // Hamming starts at its 0.08 pedestal.
        let hamming = WindowFunction::Hamming.generate(len);
        assert!(close(hamming[0], 0.08, 0.01));
    }

    #[test]
    fn test_fft_roundtrip() {
        let fft_size: usize = 1024;
        let hop_size: usize = 256;
        let sample_rate: u32 = 44100;

        let mut stft = StftProcessor::new(fft_size, hop_size, sample_rate).unwrap();

        // 0.1 s of a 440 Hz sine wave.
        let duration = 0.1_f32;
        let sample_count = (sample_rate as f32 * duration) as usize;
        let tone: Vec<f32> = (0..sample_count)
            .map(|i| {
                let phase =
                    2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32;
                phase.sin()
            })
            .collect();

        // Identity transform: the bins are passed through untouched.
        let output = stft.process(&tone, |_bins| {});

        // Only the presence of output is checked; windowing and overlap-add
        // add latency and edge distortion, so samples are not compared.
        assert!(!output.is_empty());
    }

    #[test]
    fn test_bin_frequency_conversion() {
        let processor = FftProcessor::new(2048, 512, 44100).unwrap();

        // Bin 0 is DC.
        assert!(close(processor.bin_to_frequency(0), 0.0, 1e-6));

        // Bin fft_size / 2 is the Nyquist frequency.
        assert!(close(processor.bin_to_frequency(1024), 22050.0, 1.0));

        // Frequency -> bin -> frequency stays within one bin width.
        let target = 1000.0;
        let nearest_bin = processor.frequency_to_bin(target);
        let recovered = processor.bin_to_frequency(nearest_bin);
        assert!(close(recovered, target, 50.0));
    }
}