audio_viz/
filter.rs

1// SPDX-FileCopyrightText: The audio-viz authors
2// SPDX-License-Identifier: MPL-2.0
3
4use biquad::{Biquad as _, Coefficients, DirectForm2Transposed, Hertz, Q_BUTTERWORTH_F32};
5
6use super::{FilteredWaveformBin, WaveformBin, WaveformVal};
7
// Only needed for default initialization.
const DEFAULT_SAMPLE_RATE_HZ: f32 = 44_100.0;

// Only needed for default initialization.
//
// Adopted from [Superpowered](https://docs.superpowered.com/reference/latest/analyzer),
// which uses a resolution of 150 points/sec.
const DEFAULT_BINS_PER_SEC: f32 = 150.0;

// Lower bound for the number of samples per waveform bin, applied after
// dividing the sample rate by the configured bins per second.
const MIN_SAMPLES_PER_BIN: f32 = 64.0;

// Crossover frequencies used by other tools for reference:
// - Rekordbox: ~200/2000 Hz
// - [Superpowered](https://docs.superpowered.com/reference/latest/analyzer): 200/1600 Hz

/// Crossover low/mid (low pass)
const DEFAULT_LOW_LP_FILTER_HZ: f32 = 200.0;

/// Crossover low/mid (high pass)
///
/// Overlapping with lows, i.e. lower than [`DEFAULT_LOW_LP_FILTER_HZ`].
const DEFAULT_LOW_HP_FILTER_HZ: f32 = 160.0;

/// Crossover mid/high (low pass)
///
/// Overlapping with highs, i.e. greater than [`DEFAULT_HIGH_HP_FILTER_HZ`].
const DEFAULT_HIGH_LP_FILTER_HZ: f32 = 1600.0;

/// Crossover mid/high (high pass)
const DEFAULT_HIGH_HP_FILTER_HZ: f32 = 1200.0;
38
39#[derive(Debug, Clone, PartialEq)]
40pub struct ThreeBandFilterFreqConfig {
41    pub low_lp_hz: f32,
42    pub low_hp_hz: f32,
43    pub high_lp_hz: f32,
44    pub high_hp_hz: f32,
45}
46
47impl ThreeBandFilterFreqConfig {
48    pub const MIN_FREQ_HZ: f32 = 20.0;
49    pub const MAX_FREQ_HZ: f32 = 20_000.0;
50
51    pub const DEFAULT: Self = Self {
52        low_lp_hz: DEFAULT_LOW_LP_FILTER_HZ,
53        low_hp_hz: DEFAULT_LOW_HP_FILTER_HZ,
54        high_lp_hz: DEFAULT_HIGH_LP_FILTER_HZ,
55        high_hp_hz: DEFAULT_HIGH_HP_FILTER_HZ,
56    };
57}
58
59impl Default for ThreeBandFilterFreqConfig {
60    fn default() -> Self {
61        Self::DEFAULT
62    }
63}
64
65// 3-band crossover using 4th-order Linkwitz-Riley (LR4) LP/HP filters (2 cascaded 2nd-order Butterworth)
66// and two 2nd-order Butterworth LP/HP filters for the mid band.
67#[derive(Debug)]
68struct ThreeBandFilterBank {
69    low_lp: [DirectForm2Transposed<f32>; 2],
70    mid_bp: [DirectForm2Transposed<f32>; 2],
71    high_hp: [DirectForm2Transposed<f32>; 2],
72}
73
74impl ThreeBandFilterBank {
75    #[expect(clippy::needless_pass_by_value)]
76    fn new(fs: Hertz<f32>, config: ThreeBandFilterFreqConfig) -> Self {
77        let ThreeBandFilterFreqConfig {
78            low_lp_hz,
79            low_hp_hz,
80            high_lp_hz,
81            high_hp_hz,
82        } = config;
83        debug_assert!(low_hp_hz >= ThreeBandFilterFreqConfig::MIN_FREQ_HZ);
84        debug_assert!(low_hp_hz <= low_lp_hz); // Overlapping mids with lows
85        debug_assert!(low_lp_hz < high_hp_hz); // Non-empty mids
86        debug_assert!(high_hp_hz <= high_lp_hz); // Overlapping mids with highs
87        debug_assert!(high_lp_hz <= ThreeBandFilterFreqConfig::MAX_FREQ_HZ);
88        let low_lp_f0 = Hertz::<f32>::from_hz(low_lp_hz).expect("valid frequency");
89        let low_lp = DirectForm2Transposed::<f32>::new(
90            Coefficients::<f32>::from_params(
91                biquad::Type::LowPass,
92                fs,
93                low_lp_f0,
94                Q_BUTTERWORTH_F32,
95            )
96            .expect("valid params"),
97        );
98        let low_hp_f0 = Hertz::<f32>::from_hz(low_hp_hz).expect("valid frequency");
99        let low_hp = DirectForm2Transposed::<f32>::new(
100            Coefficients::<f32>::from_params(
101                biquad::Type::HighPass,
102                fs,
103                low_hp_f0,
104                Q_BUTTERWORTH_F32,
105            )
106            .expect("valid params"),
107        );
108        let high_lp_f0 = Hertz::<f32>::from_hz(high_lp_hz).expect("valid frequency");
109        let high_lp = DirectForm2Transposed::<f32>::new(
110            Coefficients::<f32>::from_params(
111                biquad::Type::LowPass,
112                fs,
113                high_lp_f0,
114                Q_BUTTERWORTH_F32,
115            )
116            .expect("valid params"),
117        );
118        let high_hp_f0 = Hertz::<f32>::from_hz(high_hp_hz).expect("valid frequency");
119        let high_hp = DirectForm2Transposed::<f32>::new(
120            Coefficients::<f32>::from_params(
121                biquad::Type::HighPass,
122                fs,
123                high_hp_f0,
124                Q_BUTTERWORTH_F32,
125            )
126            .expect("valid params"),
127        );
128        Self {
129            low_lp: [low_lp, low_lp],
130            mid_bp: [low_hp, high_lp],
131            high_hp: [high_hp, high_hp],
132        }
133    }
134
135    #[expect(clippy::unused_self, reason = "TODO")]
136    #[expect(
137        clippy::missing_const_for_fn,
138        reason = "won't remain const if implemented"
139    )]
140    fn shape_input_signal(&mut self, sample: f32) -> f32 {
141        // TODO: Apply filtering to shape the input signal according to the
142        // ISO 226:2003 equal-loudness-level contour at 40 phons (A-weighting).
143        sample
144    }
145
146    fn run(&mut self, sample: f32) -> FilteredSample {
147        let all = self.shape_input_signal(sample);
148        let Self {
149            low_lp,
150            mid_bp,
151            high_hp,
152        } = self;
153        let low = low_lp
154            .iter_mut()
155            .fold(all, |sample, filter| filter.run(sample));
156        let mid = mid_bp
157            .iter_mut()
158            .fold(all, |sample, filter| filter.run(sample));
159        let high = high_hp
160            .iter_mut()
161            .fold(all, |sample, filter| filter.run(sample));
162        FilteredSample {
163            all,
164            low,
165            mid,
166            high,
167        }
168    }
169}
170
/// Running peak and sum of squares of a single signal within one bin.
#[derive(Debug, Default)]
struct WaveformBinAccumulator {
    /// Largest absolute sample value added so far.
    peak: f32,
    /// Sum of the squared sample values, accumulated as `f64`.
    rms_sum: f64,
}
176
/// Output of the filter bank for a single input sample.
#[derive(Debug)]
struct FilteredSample {
    /// The (shaped) unfiltered sample.
    all: f32,
    /// Output of the low band filter cascade.
    low: f32,
    /// Output of the mid band filter cascade.
    mid: f32,
    /// Output of the high band filter cascade.
    high: f32,
}
184
185impl WaveformBinAccumulator {
186    fn add_sample(&mut self, sample: f32) {
187        let sample_f64 = f64::from(sample);
188        self.peak = self.peak.max(sample.abs());
189        self.rms_sum += sample_f64 * sample_f64;
190    }
191
192    fn finish(self, rms_div: f64) -> WaveformBin {
193        debug_assert!(rms_div > 0.0);
194        let Self { peak, rms_sum } = self;
195        // For a sinusoidal signal, the RMS equals `SQRT_2` times the peak
196        // value. This is a good enough approximation of our expected input
197        // signal and we scale and clamp the RMS accordingly.
198        let energy = ((rms_sum / rms_div).sqrt() * std::f64::consts::SQRT_2).min(1.0);
199        #[expect(clippy::cast_possible_truncation)]
200        WaveformBin {
201            peak: WaveformVal::from_f32(peak),
202            energy: WaveformVal::from_f32(energy as f32),
203        }
204    }
205}
206
207#[derive(Debug, Default)]
208struct FilteredWaveformBinAccumulator {
209    sample_count: u32,
210    all: WaveformBinAccumulator,
211    low: WaveformBinAccumulator,
212    mid: WaveformBinAccumulator,
213    high: WaveformBinAccumulator,
214}
215
216impl FilteredWaveformBinAccumulator {
217    fn add_sample(&mut self, filter_bank: &mut ThreeBandFilterBank, sample: f32) {
218        self.sample_count += 1;
219        let FilteredSample {
220            all,
221            low,
222            mid,
223            high,
224        } = filter_bank.run(sample);
225        self.all.add_sample(all);
226        self.low.add_sample(low);
227        self.mid.add_sample(mid);
228        self.high.add_sample(high);
229    }
230
231    fn finish(self) -> Option<FilteredWaveformBin> {
232        let Self {
233            sample_count,
234            all,
235            low,
236            mid,
237            high,
238        } = self;
239        if sample_count == 0 {
240            return None;
241        }
242        let rms_div = f64::from(sample_count);
243        let all = all.finish(rms_div);
244        let low = low.finish(rms_div);
245        let mid = mid.finish(rms_div);
246        let high = high.finish(rms_div);
247        Some(FilteredWaveformBin {
248            all,
249            low,
250            mid,
251            high,
252        })
253    }
254}
255
256#[derive(Debug, Clone, PartialEq)]
257pub struct WaveformFilterConfig {
258    pub sample_rate_hz: f32,
259    pub bins_per_sec: f32,
260    pub filter_freqs: ThreeBandFilterFreqConfig,
261}
262
263impl WaveformFilterConfig {
264    pub const DEFAULT: Self = Self {
265        sample_rate_hz: DEFAULT_SAMPLE_RATE_HZ,
266        bins_per_sec: DEFAULT_BINS_PER_SEC,
267        filter_freqs: ThreeBandFilterFreqConfig::DEFAULT,
268    };
269}
270
271impl Default for WaveformFilterConfig {
272    fn default() -> Self {
273        Self::DEFAULT
274    }
275}
276
277#[derive(Debug)]
278pub struct WaveformFilter {
279    pending_samples_count: f32,
280    samples_per_bin: f32,
281    filter_bank: ThreeBandFilterBank,
282    filtered_accumulator: FilteredWaveformBinAccumulator,
283}
284
285impl Default for WaveformFilter {
286    fn default() -> Self {
287        Self::new(Default::default())
288    }
289}
290
291impl WaveformFilter {
292    #[must_use]
293    #[expect(clippy::missing_panics_doc)]
294    pub fn new(config: WaveformFilterConfig) -> Self {
295        let WaveformFilterConfig {
296            sample_rate_hz,
297            bins_per_sec,
298            filter_freqs,
299        } = config;
300        let sample_rate = Hertz::<f32>::from_hz(sample_rate_hz).expect("valid sample rate");
301        let samples_per_bin = (sample_rate_hz / bins_per_sec).max(MIN_SAMPLES_PER_BIN);
302        Self {
303            pending_samples_count: 0.0,
304            samples_per_bin,
305            filter_bank: ThreeBandFilterBank::new(sample_rate, filter_freqs),
306            filtered_accumulator: Default::default(),
307        }
308    }
309
310    fn finish_bin(&mut self) -> Option<FilteredWaveformBin> {
311        std::mem::take(&mut self.filtered_accumulator).finish()
312    }
313
314    pub fn add_sample(&mut self, sample: f32) -> Option<FilteredWaveformBin> {
315        let next_bin = if self.pending_samples_count >= self.samples_per_bin {
316            self.pending_samples_count -= self.samples_per_bin;
317            self.finish_bin()
318        } else {
319            None
320        };
321        self.filtered_accumulator
322            .add_sample(&mut self.filter_bank, sample);
323        self.pending_samples_count += 1.0;
324        next_bin
325    }
326
327    #[must_use]
328    pub fn finish(mut self) -> Option<FilteredWaveformBin> {
329        self.finish_bin()
330    }
331}