Skip to main content

audio_engine_core/processor/
fir_eq.rs

//! FIR EQ: generates an impulse response (IR) from a frequency-response specification.
//!
//! This module designs linear-phase and minimum-phase FIR filters from band gain
//! specifications. The generated IR is used with FFTConvolver for efficient convolution.
5
6use rustfft::{num_complex::Complex, FftPlanner};
7use std::f64::consts::PI;
8
/// Default 10-band graphic-EQ layout as `(centre_frequency_hz, gain_db)` pairs.
///
/// The centres are octave-spaced and listed in ascending order. They are close
/// to, but not exactly, the ISO nominal octave centres (which use 31.5 Hz and
/// 63 Hz for the two lowest bands). Every gain starts at 0 dB, i.e. a flat EQ.
pub const STANDARD_BANDS: [(f64, f64); 10] = [
    // Sub-bass / bass
    (31.0, 0.0),
    (62.0, 0.0),
    (125.0, 0.0),
    // Low mids / mids
    (250.0, 0.0),
    (500.0, 0.0),
    (1000.0, 0.0),
    (2000.0, 0.0),
    // Presence / treble
    (4000.0, 0.0),
    (8000.0, 0.0),
    (16000.0, 0.0),
];
22
/// Selects how the FIR EQ distributes phase in the generated impulse response.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FirPhaseMode {
    /// Linear phase: the IR is symmetric around its centre tap, so the filter
    /// adds a latency of half the tap count.
    #[default]
    Linear,
    /// Minimum phase: no added latency, at the cost of a non-linear phase
    /// response.
    Minimum,
}
30
31/// FIR EQ generator: creates IR from band gain specifications
32pub struct FirEq {
33    /// Number of FIR taps (must be odd for linear phase)
34    num_taps: usize,
35    /// Sample rate
36    sample_rate: f64,
37    /// Band gains: (freq_hz, gain_db) pairs, sorted by frequency
38    bands: [(f64, f64); 10],
39    /// Phase mode
40    phase_mode: FirPhaseMode,
41    /// Cached IR (regenerated when bands change)
42    cached_ir: Vec<f64>,
43}
44
45impl FirEq {
46    /// Create a new FIR EQ generator
47    ///
48    /// # Arguments
49    /// * `sample_rate` - Audio sample rate in Hz
50    /// * `num_taps` - Number of FIR taps (must be odd, will be forced to odd if even)
51    pub fn new(sample_rate: f64, num_taps: usize) -> Self {
52        // Ensure odd number of taps for symmetric IR
53        let num_taps = if num_taps.is_multiple_of(2) {
54            num_taps + 1
55        } else {
56            num_taps
57        };
58
59        let mut fir_eq = Self {
60            num_taps,
61            sample_rate,
62            bands: STANDARD_BANDS,
63            phase_mode: FirPhaseMode::Linear,
64            cached_ir: Vec::new(),
65        };
66
67        // Generate initial IR
68        fir_eq.regenerate_ir();
69        fir_eq
70    }
71
72    /// Set sample rate (triggers IR regeneration)
73    pub fn set_sample_rate(&mut self, sr: f64) {
74        self.sample_rate = sr;
75        self.regenerate_ir();
76    }
77
78    /// Set number of taps (triggers IR regeneration)
79    pub fn set_num_taps(&mut self, taps: usize) {
80        self.num_taps = if taps.is_multiple_of(2) {
81            taps + 1
82        } else {
83            taps
84        };
85        self.regenerate_ir();
86    }
87
88    /// Set phase mode (triggers IR regeneration)
89    pub fn set_phase_mode(&mut self, mode: FirPhaseMode) {
90        self.phase_mode = mode;
91        self.regenerate_ir();
92    }
93
94    /// Update a band gain (triggers IR regeneration)
95    ///
96    /// # Arguments
97    /// * `band_idx` - Band index (0-9 for standard 10-band EQ)
98    /// * `gain_db` - Gain in dB (-15 to +15)
99    pub fn set_band(&mut self, band_idx: usize, gain_db: f64) {
100        if band_idx < self.bands.len() {
101            self.bands[band_idx].1 = gain_db.clamp(-15.0, 15.0);
102            self.regenerate_ir();
103        }
104    }
105
106    /// Set all bands at once (single regeneration)
107    pub fn set_bands(&mut self, gains_db: &[f64; 10]) {
108        for (i, &gain) in gains_db.iter().enumerate() {
109            self.bands[i].1 = gain.clamp(-15.0, 15.0);
110        }
111        self.regenerate_ir();
112    }
113
114    /// Get current band gains
115    pub fn get_bands(&self) -> [(f64, f64); 10] {
116        self.bands
117    }
118
119    /// Get current IR (interleaved for all channels)
120    /// Returns IR repeated for each channel
121    pub fn get_ir(&self, channels: usize) -> Vec<f64> {
122        let mut ir = Vec::with_capacity(self.cached_ir.len() * channels);
123        for &sample in &self.cached_ir {
124            for _ in 0..channels {
125                ir.push(sample);
126            }
127        }
128        ir
129    }
130
131    /// Get IR length (per channel)
132    pub fn ir_length(&self) -> usize {
133        self.cached_ir.len()
134    }
135
136    /// Get number of taps
137    pub fn num_taps(&self) -> usize {
138        self.num_taps
139    }
140
141    /// Regenerate IR from current band settings
142    fn regenerate_ir(&mut self) {
143        match self.phase_mode {
144            FirPhaseMode::Linear => self.generate_linear_phase_ir(),
145            FirPhaseMode::Minimum => self.generate_minimum_phase_ir(),
146        }
147    }
148
149    /// Generate linear-phase FIR IR using frequency sampling method
150    fn generate_linear_phase_ir(&mut self) {
151        let num_taps = self.num_taps;
152        let sr = self.sample_rate;
153
154        // FFT size must be at least 2x num_taps for linear convolution
155        let mut fft_size = 1;
156        while fft_size < num_taps * 2 {
157            fft_size <<= 1;
158        }
159
160        // 1. Build desired frequency response magnitude at each FFT bin
161        let num_bins = fft_size / 2 + 1;
162        let mut magnitude = vec![1.0f64; num_bins];
163
164        for (bin, mag) in magnitude.iter_mut().enumerate() {
165            let freq = bin as f64 * sr / fft_size as f64;
166            *mag = self.interpolate_gain(freq);
167        }
168
169        // 2. Convert dB magnitude to linear
170        let linear_mag: Vec<f64> = magnitude
171            .iter()
172            .map(|&db| 10.0_f64.powf(db / 20.0))
173            .collect();
174
175        // 3. Build symmetric frequency response (Hermitian symmetry for real output)
176        let mut spectrum = vec![Complex::new(0.0, 0.0); fft_size];
177        for k in 0..linear_mag.len() {
178            spectrum[k] = Complex::new(linear_mag[k], 0.0);
179            if k > 0 && k < fft_size / 2 {
180                spectrum[fft_size - k] = Complex::new(linear_mag[k], 0.0);
181            }
182        }
183
184        // 4. IFFT to get the ideal IR
185        let mut planner = FftPlanner::new();
186        let ifft = planner.plan_fft_inverse(fft_size);
187        ifft.process(&mut spectrum);
188
189        // 5. Extract center num_taps samples (circular shift to make causal)
190        let half = num_taps / 2;
191        let mut ir_mono: Vec<f64> = (0..num_taps)
192            .map(|i| {
193                let idx = (i + fft_size - half) % fft_size;
194                spectrum[idx].re / fft_size as f64
195            })
196            .collect();
197
198        // 6. Apply Hann window to reduce Gibbs phenomenon
199        for (i, sample) in ir_mono.iter_mut().enumerate() {
200            let w = 0.5 * (1.0 - (2.0 * PI * i as f64 / (num_taps - 1) as f64).cos());
201            *sample *= w;
202        }
203
204        // 7. Normalize to preserve overall gain (0 dB at 1 kHz reference)
205        let ref_gain = self.interpolate_gain(1000.0);
206        let norm_factor = 10.0_f64.powf(-ref_gain / 20.0);
207        for sample in ir_mono.iter_mut() {
208            *sample *= norm_factor;
209        }
210
211        self.cached_ir = ir_mono;
212    }
213
214    /// Generate minimum-phase FIR IR
215    /// Uses cepstral method: log|H(w)| -> IFFT -> cosine transform -> FFT -> exp -> IFFT
216    fn generate_minimum_phase_ir(&mut self) {
217        let num_taps = self.num_taps;
218        let sr = self.sample_rate;
219
220        // FFT size
221        let mut fft_size = 1;
222        while fft_size < num_taps * 4 {
223            fft_size <<= 1;
224        }
225
226        let num_bins = fft_size / 2 + 1;
227
228        // 1. Build desired magnitude response
229        let mut log_mag = vec![0.0f64; fft_size];
230        for bin in 0..num_bins {
231            let freq = bin as f64 * sr / fft_size as f64;
232            let gain_db = self.interpolate_gain(freq);
233            log_mag[bin] = gain_db / 20.0 * std::f64::consts::LN_10; // Convert to natural log
234            if bin > 0 && bin < fft_size / 2 {
235                log_mag[fft_size - bin] = log_mag[bin];
236            }
237        }
238
239        // 2. IFFT of log magnitude to get cepstral coefficients
240        let mut spectrum: Vec<Complex<f64>> =
241            log_mag.iter().map(|&lm| Complex::new(lm, 0.0)).collect();
242
243        let mut planner = FftPlanner::new();
244        let ifft = planner.plan_fft_inverse(fft_size);
245        ifft.process(&mut spectrum);
246
247        // FIX for Defect 7: rustfft's IFFT does not apply 1/N normalization.
248        // Without this, cepstral coefficients are amplified by N, which propagates
249        // through FFT→exp→IFFT and distorts the frequency response shape
250        // (gains raised to the N-th power instead of being preserved).
251        let inv_n = 1.0 / fft_size as f64;
252        for s in spectrum.iter_mut() {
253            *s *= inv_n;
254        }
255
256        // 3. Apply cepstral window (keep positive frequencies, double, zero negative)
257        let half = fft_size / 2;
258        for (i, s) in spectrum.iter_mut().enumerate() {
259            if i == 0 || i == half {
260                // Keep DC and Nyquist as-is
261            } else if i < half {
262                *s *= 2.0; // Double positive frequencies
263            } else {
264                *s = Complex::new(0.0, 0.0); // Zero negative frequencies
265            }
266        }
267
268        // 4. FFT back to frequency domain
269        let fft = planner.plan_fft_forward(fft_size);
270        fft.process(&mut spectrum);
271
272        // 5. Exponentiate to get minimum phase frequency response
273        for s in spectrum.iter_mut() {
274            *s = s.exp();
275        }
276
277        // 6. IFFT to get minimum phase IR
278        ifft.process(&mut spectrum);
279
280        // 7. Extract first num_taps samples
281        let mut ir_mono: Vec<f64> = (0..num_taps)
282            .map(|i| spectrum[i].re / fft_size as f64)
283            .collect();
284
285        // 8. Apply half-window (fade out at the end)
286        for (i, sample) in ir_mono.iter_mut().enumerate() {
287            if i > num_taps / 2 {
288                let w =
289                    0.5 * (1.0 + ((num_taps - 1 - i) as f64 / (num_taps / 2) as f64 * PI).cos());
290                *sample *= w;
291            }
292        }
293
294        // 9. Normalize
295        let ref_gain = self.interpolate_gain(1000.0);
296        let norm_factor = 10.0_f64.powf(-ref_gain / 20.0);
297        for sample in ir_mono.iter_mut() {
298            *sample *= norm_factor;
299        }
300
301        self.cached_ir = ir_mono;
302    }
303
304    /// Log-frequency interpolation of gain across EQ bands
305    fn interpolate_gain(&self, freq_hz: f64) -> f64 {
306        if freq_hz <= 0.0 {
307            return self.bands[0].1;
308        }
309
310        // Find surrounding bands
311        for i in 0..self.bands.len() - 1 {
312            let (f0, g0) = self.bands[i];
313            let (f1, g1) = self.bands[i + 1];
314
315            if freq_hz >= f0 && freq_hz <= f1 {
316                // Linear interpolation in log-frequency space
317                let log_f0 = f0.log2();
318                let log_f1 = f1.log2();
319                let log_freq = freq_hz.log2();
320
321                if (log_f1 - log_f0).abs() < 1e-10 {
322                    return g0;
323                }
324
325                let t = (log_freq - log_f0) / (log_f1 - log_f0);
326                return g0 + (g1 - g0) * t;
327            }
328        }
329
330        // Extrapolate from nearest band
331        if freq_hz < self.bands[0].0 {
332            return self.bands[0].1;
333        }
334        self.bands[self.bands.len() - 1].1
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the cached single-channel taps.
    fn dc_sum(eq: &FirEq) -> f64 {
        eq.cached_ir.iter().sum()
    }

    /// With every band at 0 dB the linear-phase taps must sum to roughly 1.
    #[test]
    fn test_fir_eq_flat() {
        let eq = FirEq::new(44100.0, 1023);
        assert!(!eq.get_ir(2).is_empty());

        let sum = dc_sum(&eq);
        assert!(
            (sum - 1.0).abs() < 0.1,
            "Flat IR sum should be ~1.0, got {}",
            sum
        );
    }

    /// Boosting the lowest band must still yield an IR, with a tap sum above 1.
    #[test]
    fn test_fir_eq_bass_boost() {
        let mut eq = FirEq::new(44100.0, 1023);
        eq.set_band(0, 6.0); // +6 dB at 31 Hz

        assert!(!eq.get_ir(2).is_empty());

        let sum = dc_sum(&eq);
        assert!(sum > 1.0, "Bass boost IR sum should be > 1.0, got {}", sum);
    }

    /// A flat EQ must interpolate to 0 dB both on and between band centres.
    #[test]
    fn test_interpolate_gain() {
        let eq = FirEq::new(44100.0, 1023);

        // Exactly on two neighbouring band centres
        for centre_hz in [500.0, 1000.0] {
            assert!(eq.interpolate_gain(centre_hz).abs() < 0.01);
        }

        // Midway between them
        let between = eq.interpolate_gain(750.0);
        assert!(between.abs() < 0.01, "Gain at 750 Hz should be ~0 dB");
    }

    /// A flat EQ in minimum-phase mode must also have a tap sum close to 1.
    #[test]
    fn test_minimum_phase_flat() {
        let mut eq = FirEq::new(44100.0, 1023);
        eq.set_phase_mode(FirPhaseMode::Minimum);

        let sum = dc_sum(&eq);
        assert!(
            (sum - 1.0).abs() < 0.15,
            "Minimum phase flat IR sum should be ~1.0, got {}",
            sum
        );
    }

    /// Defect 7 regression test: thanks to the 1/N cepstrum normalization, a
    /// 6 dB bass boost must give a moderate tap sum rather than one inflated by
    /// the FFT size (4096 for 1023 taps).
    #[test]
    fn test_minimum_phase_boost_bounded() {
        let mut eq = FirEq::new(44100.0, 1023);
        eq.set_phase_mode(FirPhaseMode::Minimum);
        eq.set_band(0, 6.0); // +6 dB at 31 Hz

        let sum = dc_sum(&eq);
        assert!(
            sum.abs() < 100.0,
            "Minimum phase boosted IR sum should be bounded, got {}",
            sum
        );
        assert!(
            sum > 0.5,
            "Minimum phase boosted IR sum should be positive and > 0.5, got {}",
            sum
        );
    }
}
425}