audio_engine_core/processor/
fir_eq.rs1use rustfft::{num_complex::Complex, FftPlanner};
7use std::f64::consts::PI;
8
/// Default band table used by [`FirEq::new`]: ten `(centre frequency in Hz,
/// gain in dB)` pairs, roughly one octave apart from 31 Hz to 16 kHz, all
/// starting flat at 0 dB.
pub const STANDARD_BANDS: [(f64, f64); 10] = [
    (31.0, 0.0),
    (62.0, 0.0),
    (125.0, 0.0),
    (250.0, 0.0),
    (500.0, 0.0),
    (1000.0, 0.0),
    (2000.0, 0.0),
    (4000.0, 0.0),
    (8000.0, 0.0),
    (16000.0, 0.0),
];
22
/// Phase response used when [`FirEq`] generates its impulse response.
///
/// The enum is fieldless and `Copy`, so it also derives `Eq` and `Hash`
/// to be usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FirPhaseMode {
    /// Symmetric IR whose peak sits on the middle tap (`num_taps / 2`).
    /// This is the default mode.
    #[default]
    Linear,
    /// IR derived from the log-magnitude response via the real cepstrum.
    Minimum,
}
30
31pub struct FirEq {
33 num_taps: usize,
35 sample_rate: f64,
37 bands: [(f64, f64); 10],
39 phase_mode: FirPhaseMode,
41 cached_ir: Vec<f64>,
43}
44
45impl FirEq {
46 pub fn new(sample_rate: f64, num_taps: usize) -> Self {
52 let num_taps = if num_taps.is_multiple_of(2) {
54 num_taps + 1
55 } else {
56 num_taps
57 };
58
59 let mut fir_eq = Self {
60 num_taps,
61 sample_rate,
62 bands: STANDARD_BANDS,
63 phase_mode: FirPhaseMode::Linear,
64 cached_ir: Vec::new(),
65 };
66
67 fir_eq.regenerate_ir();
69 fir_eq
70 }
71
72 pub fn set_sample_rate(&mut self, sr: f64) {
74 self.sample_rate = sr;
75 self.regenerate_ir();
76 }
77
78 pub fn set_num_taps(&mut self, taps: usize) {
80 self.num_taps = if taps.is_multiple_of(2) {
81 taps + 1
82 } else {
83 taps
84 };
85 self.regenerate_ir();
86 }
87
88 pub fn set_phase_mode(&mut self, mode: FirPhaseMode) {
90 self.phase_mode = mode;
91 self.regenerate_ir();
92 }
93
94 pub fn set_band(&mut self, band_idx: usize, gain_db: f64) {
100 if band_idx < self.bands.len() {
101 self.bands[band_idx].1 = gain_db.clamp(-15.0, 15.0);
102 self.regenerate_ir();
103 }
104 }
105
106 pub fn set_bands(&mut self, gains_db: &[f64; 10]) {
108 for (i, &gain) in gains_db.iter().enumerate() {
109 self.bands[i].1 = gain.clamp(-15.0, 15.0);
110 }
111 self.regenerate_ir();
112 }
113
114 pub fn get_bands(&self) -> [(f64, f64); 10] {
116 self.bands
117 }
118
119 pub fn get_ir(&self, channels: usize) -> Vec<f64> {
122 let mut ir = Vec::with_capacity(self.cached_ir.len() * channels);
123 for &sample in &self.cached_ir {
124 for _ in 0..channels {
125 ir.push(sample);
126 }
127 }
128 ir
129 }
130
131 pub fn ir_length(&self) -> usize {
133 self.cached_ir.len()
134 }
135
136 pub fn num_taps(&self) -> usize {
138 self.num_taps
139 }
140
141 fn regenerate_ir(&mut self) {
143 match self.phase_mode {
144 FirPhaseMode::Linear => self.generate_linear_phase_ir(),
145 FirPhaseMode::Minimum => self.generate_minimum_phase_ir(),
146 }
147 }
148
149 fn generate_linear_phase_ir(&mut self) {
151 let num_taps = self.num_taps;
152 let sr = self.sample_rate;
153
154 let mut fft_size = 1;
156 while fft_size < num_taps * 2 {
157 fft_size <<= 1;
158 }
159
160 let num_bins = fft_size / 2 + 1;
162 let mut magnitude = vec![1.0f64; num_bins];
163
164 for (bin, mag) in magnitude.iter_mut().enumerate() {
165 let freq = bin as f64 * sr / fft_size as f64;
166 *mag = self.interpolate_gain(freq);
167 }
168
169 let linear_mag: Vec<f64> = magnitude
171 .iter()
172 .map(|&db| 10.0_f64.powf(db / 20.0))
173 .collect();
174
175 let mut spectrum = vec![Complex::new(0.0, 0.0); fft_size];
177 for k in 0..linear_mag.len() {
178 spectrum[k] = Complex::new(linear_mag[k], 0.0);
179 if k > 0 && k < fft_size / 2 {
180 spectrum[fft_size - k] = Complex::new(linear_mag[k], 0.0);
181 }
182 }
183
184 let mut planner = FftPlanner::new();
186 let ifft = planner.plan_fft_inverse(fft_size);
187 ifft.process(&mut spectrum);
188
189 let half = num_taps / 2;
191 let mut ir_mono: Vec<f64> = (0..num_taps)
192 .map(|i| {
193 let idx = (i + fft_size - half) % fft_size;
194 spectrum[idx].re / fft_size as f64
195 })
196 .collect();
197
198 for (i, sample) in ir_mono.iter_mut().enumerate() {
200 let w = 0.5 * (1.0 - (2.0 * PI * i as f64 / (num_taps - 1) as f64).cos());
201 *sample *= w;
202 }
203
204 let ref_gain = self.interpolate_gain(1000.0);
206 let norm_factor = 10.0_f64.powf(-ref_gain / 20.0);
207 for sample in ir_mono.iter_mut() {
208 *sample *= norm_factor;
209 }
210
211 self.cached_ir = ir_mono;
212 }
213
214 fn generate_minimum_phase_ir(&mut self) {
217 let num_taps = self.num_taps;
218 let sr = self.sample_rate;
219
220 let mut fft_size = 1;
222 while fft_size < num_taps * 4 {
223 fft_size <<= 1;
224 }
225
226 let num_bins = fft_size / 2 + 1;
227
228 let mut log_mag = vec![0.0f64; fft_size];
230 for bin in 0..num_bins {
231 let freq = bin as f64 * sr / fft_size as f64;
232 let gain_db = self.interpolate_gain(freq);
233 log_mag[bin] = gain_db / 20.0 * std::f64::consts::LN_10; if bin > 0 && bin < fft_size / 2 {
235 log_mag[fft_size - bin] = log_mag[bin];
236 }
237 }
238
239 let mut spectrum: Vec<Complex<f64>> =
241 log_mag.iter().map(|&lm| Complex::new(lm, 0.0)).collect();
242
243 let mut planner = FftPlanner::new();
244 let ifft = planner.plan_fft_inverse(fft_size);
245 ifft.process(&mut spectrum);
246
247 let inv_n = 1.0 / fft_size as f64;
252 for s in spectrum.iter_mut() {
253 *s *= inv_n;
254 }
255
256 let half = fft_size / 2;
258 for (i, s) in spectrum.iter_mut().enumerate() {
259 if i == 0 || i == half {
260 } else if i < half {
262 *s *= 2.0; } else {
264 *s = Complex::new(0.0, 0.0); }
266 }
267
268 let fft = planner.plan_fft_forward(fft_size);
270 fft.process(&mut spectrum);
271
272 for s in spectrum.iter_mut() {
274 *s = s.exp();
275 }
276
277 ifft.process(&mut spectrum);
279
280 let mut ir_mono: Vec<f64> = (0..num_taps)
282 .map(|i| spectrum[i].re / fft_size as f64)
283 .collect();
284
285 for (i, sample) in ir_mono.iter_mut().enumerate() {
287 if i > num_taps / 2 {
288 let w =
289 0.5 * (1.0 + ((num_taps - 1 - i) as f64 / (num_taps / 2) as f64 * PI).cos());
290 *sample *= w;
291 }
292 }
293
294 let ref_gain = self.interpolate_gain(1000.0);
296 let norm_factor = 10.0_f64.powf(-ref_gain / 20.0);
297 for sample in ir_mono.iter_mut() {
298 *sample *= norm_factor;
299 }
300
301 self.cached_ir = ir_mono;
302 }
303
304 fn interpolate_gain(&self, freq_hz: f64) -> f64 {
306 if freq_hz <= 0.0 {
307 return self.bands[0].1;
308 }
309
310 for i in 0..self.bands.len() - 1 {
312 let (f0, g0) = self.bands[i];
313 let (f1, g1) = self.bands[i + 1];
314
315 if freq_hz >= f0 && freq_hz <= f1 {
316 let log_f0 = f0.log2();
318 let log_f1 = f1.log2();
319 let log_freq = freq_hz.log2();
320
321 if (log_f1 - log_f0).abs() < 1e-10 {
322 return g0;
323 }
324
325 let t = (log_freq - log_f0) / (log_f1 - log_f0);
326 return g0 + (g1 - g0) * t;
327 }
328 }
329
330 if freq_hz < self.bands[0].0 {
332 return self.bands[0].1;
333 }
334 self.bands[self.bands.len() - 1].1
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the mono IR taps, i.e. the filter's gain at DC.
    fn dc_gain(fir: &FirEq) -> f64 {
        fir.cached_ir.iter().sum()
    }

    #[test]
    fn test_fir_eq_flat() {
        let fir = FirEq::new(44100.0, 1023);
        let ir = fir.get_ir(2);
        assert!(!ir.is_empty());

        let sum = dc_gain(&fir);
        assert!(
            (sum - 1.0).abs() < 0.1,
            "Flat IR sum should be ~1.0, got {}",
            sum
        );
    }

    #[test]
    fn test_fir_eq_bass_boost() {
        let mut fir = FirEq::new(44100.0, 1023);
        fir.set_band(0, 6.0);
        let ir = fir.get_ir(2);
        assert!(!ir.is_empty());

        let sum = dc_gain(&fir);
        assert!(sum > 1.0, "Bass boost IR sum should be > 1.0, got {}", sum);
    }

    #[test]
    fn test_interpolate_gain() {
        let fir = FirEq::new(44100.0, 1023);

        let gain_750 = fir.interpolate_gain(750.0);
        let gain_500 = fir.interpolate_gain(500.0);
        let gain_1000 = fir.interpolate_gain(1000.0);
        assert!((gain_500 - 0.0).abs() < 0.01);
        assert!((gain_1000 - 0.0).abs() < 0.01);

        assert!(
            (gain_750 - 0.0).abs() < 0.01,
            "Gain at 750 Hz should be ~0 dB"
        );
    }

    /// With all bands flat any interpolation scheme returns 0 dB. This
    /// checks the log-frequency interpolation against a non-flat curve.
    #[test]
    fn test_interpolate_gain_non_flat() {
        let mut fir = FirEq::new(44100.0, 63);
        fir.set_band(4, 6.0); // 500 Hz band
        assert!((fir.interpolate_gain(500.0) - 6.0).abs() < 1e-9);
        assert!(fir.interpolate_gain(1000.0).abs() < 1e-9);

        // Linear in log2(f): a geometric mean sits halfway between the gains.
        let mid_hz = (500.0_f64 * 1000.0).sqrt();
        assert!((fir.interpolate_gain(mid_hz) - 3.0).abs() < 1e-9);
        let low_mid_hz = (250.0_f64 * 500.0).sqrt();
        assert!((fir.interpolate_gain(low_mid_hz) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_interpolate_gain_holds_edge_bands() {
        let mut fir = FirEq::new(44100.0, 63);
        fir.set_band(0, 6.0);
        fir.set_band(9, -3.0);
        assert_eq!(fir.interpolate_gain(0.0), 6.0);
        assert_eq!(fir.interpolate_gain(10.0), 6.0);
        assert_eq!(fir.interpolate_gain(20000.0), -3.0);
    }

    #[test]
    fn test_set_band_clamps_and_ignores_out_of_range() {
        let mut fir = FirEq::new(44100.0, 63);
        fir.set_band(0, 40.0);
        fir.set_band(1, -40.0);
        let before = fir.get_bands();
        assert_eq!(before[0].1, 15.0);
        assert_eq!(before[1].1, -15.0);

        fir.set_band(10, 5.0);
        assert_eq!(fir.get_bands(), before);
    }

    #[test]
    fn test_even_tap_counts_round_up_to_odd() {
        let mut fir = FirEq::new(44100.0, 1024);
        assert_eq!(fir.num_taps(), 1025);
        assert_eq!(fir.ir_length(), 1025);

        fir.set_num_taps(64);
        assert_eq!(fir.num_taps(), 65);
        assert_eq!(fir.ir_length(), 65);
    }

    #[test]
    fn test_get_ir_interleaves_channels() {
        let fir = FirEq::new(44100.0, 63);
        let stereo = fir.get_ir(2);
        assert_eq!(stereo.len(), 2 * fir.ir_length());
        for (frame, &mono) in stereo.chunks(2).zip(fir.cached_ir.iter()) {
            assert_eq!(frame, &[mono, mono][..]);
        }
        assert!(fir.get_ir(0).is_empty());
    }

    #[test]
    fn test_minimum_phase_flat() {
        let mut fir = FirEq::new(44100.0, 1023);
        fir.set_phase_mode(FirPhaseMode::Minimum);

        let sum = dc_gain(&fir);
        assert!(
            (sum - 1.0).abs() < 0.15,
            "Minimum phase flat IR sum should be ~1.0, got {}",
            sum
        );
    }

    #[test]
    fn test_minimum_phase_boost_bounded() {
        let mut fir = FirEq::new(44100.0, 1023);
        fir.set_phase_mode(FirPhaseMode::Minimum);
        fir.set_band(0, 6.0);

        let sum = dc_gain(&fir);
        assert!(
            sum.abs() < 100.0,
            "Minimum phase boosted IR sum should be bounded, got {}",
            sum
        );
        assert!(
            sum > 0.5,
            "Minimum phase boosted IR sum should be positive and > 0.5, got {}",
            sum
        );
    }
}