#![allow(clippy::cast_precision_loss)]
use std::f32::consts::PI;
/// Filter shape of an [`EqBand`].
///
/// The variant selects which set of biquad coefficient formulas
/// `EqBand::compute_biquad` uses. The type is a plain tag without payload, so
/// it is `Copy` and can be compared, hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EqBandType {
    /// Shelving filter on the low side of the band frequency; uses the band gain.
    LowShelf,
    /// Shelving filter on the high side of the band frequency; uses the band gain.
    HighShelf,
    /// Bell-shaped boost or cut around the band frequency; uses the band gain and Q.
    Peaking,
    /// Narrow rejection around the band frequency; the band gain is not used.
    Notch,
    /// Second-order low-pass; the band gain is not used.
    LowPass,
    /// Second-order high-pass; the band gain is not used.
    HighPass,
    /// Second-order all-pass; the band gain is not used.
    AllPass,
}
/// One equaliser band: a filter shape plus the parameters used to design its
/// biquad coefficients.
#[derive(Debug, Clone)]
pub struct EqBand {
/// Filter shape; selects the coefficient formulas used by `compute_biquad`.
pub band_type: EqBandType,
/// Centre / corner frequency in Hz; divided by the sample rate to obtain the
/// normalised angular frequency.
pub frequency_hz: f32,
/// Gain in dB. Only read for the `Peaking`, `LowShelf` and `HighShelf` shapes;
/// the other shapes ignore it.
pub gain_db: f32,
/// Quality factor. Values below `f32::EPSILON` are raised to `f32::EPSILON`
/// when coefficients are computed.
pub q: f32,
/// When `false`, `apply` returns the input unchanged and `ParametricEq` skips
/// the band.
pub enabled: bool,
}
impl EqBand {
#[must_use]
pub fn peaking(freq_hz: f32, gain_db: f32, q: f32) -> Self {
Self {
band_type: EqBandType::Peaking,
frequency_hz: freq_hz,
gain_db,
q,
enabled: true,
}
}
#[must_use]
pub fn low_shelf(freq_hz: f32, gain_db: f32) -> Self {
Self {
band_type: EqBandType::LowShelf,
frequency_hz: freq_hz,
gain_db,
q: 0.707,
enabled: true,
}
}
#[must_use]
pub fn high_shelf(freq_hz: f32, gain_db: f32) -> Self {
Self {
band_type: EqBandType::HighShelf,
frequency_hz: freq_hz,
gain_db,
q: 0.707,
enabled: true,
}
}
#[must_use]
pub fn notch(freq_hz: f32, q: f32) -> Self {
Self {
band_type: EqBandType::Notch,
frequency_hz: freq_hz,
gain_db: 0.0,
q,
enabled: true,
}
}
#[must_use]
pub fn low_pass(freq_hz: f32, q: f32) -> Self {
Self {
band_type: EqBandType::LowPass,
frequency_hz: freq_hz,
gain_db: 0.0,
q,
enabled: true,
}
}
#[must_use]
pub fn high_pass(freq_hz: f32, q: f32) -> Self {
Self {
band_type: EqBandType::HighPass,
frequency_hz: freq_hz,
gain_db: 0.0,
q,
enabled: true,
}
}
#[must_use]
pub fn new(freq_hz: f32, gain_db: f32, q: f32) -> Self {
Self::peaking(freq_hz, gain_db, q)
}
#[must_use]
pub fn apply(&self, samples: &[f32], sample_rate: u32) -> Vec<f32> {
let sr = sample_rate as f32;
let coeffs = self.compute_biquad(sr);
let mut state = BiquadState::default();
if !self.enabled {
return samples.to_vec();
}
samples
.iter()
.map(|&s| state.process_sample(s, &coeffs))
.collect()
}
#[must_use]
pub fn compute_biquad(&self, sample_rate: f32) -> [f32; 6] {
let w0 = 2.0 * PI * self.frequency_hz / sample_rate;
let cos_w0 = w0.cos();
let sin_w0 = w0.sin();
let alpha = sin_w0 / (2.0 * self.q.max(f32::EPSILON));
match self.band_type {
EqBandType::Peaking => {
let a = 10.0_f32.powf(self.gain_db / 40.0);
let b0 = 1.0 + alpha * a;
let b1 = -2.0 * cos_w0;
let b2 = 1.0 - alpha * a;
let a0 = 1.0 + alpha / a;
let a1 = -2.0 * cos_w0;
let a2 = 1.0 - alpha / a;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::LowShelf => {
let a = 10.0_f32.powf(self.gain_db / 40.0);
let sqrt_a = a.sqrt();
let b0 = a * ((a + 1.0) - (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha);
let b1 = 2.0 * a * ((a - 1.0) - (a + 1.0) * cos_w0);
let b2 = a * ((a + 1.0) - (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha);
let a0 = (a + 1.0) + (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha;
let a1 = -2.0 * ((a - 1.0) + (a + 1.0) * cos_w0);
let a2 = (a + 1.0) + (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::HighShelf => {
let a = 10.0_f32.powf(self.gain_db / 40.0);
let sqrt_a = a.sqrt();
let b0 = a * ((a + 1.0) + (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha);
let b1 = -2.0 * a * ((a - 1.0) + (a + 1.0) * cos_w0);
let b2 = a * ((a + 1.0) + (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha);
let a0 = (a + 1.0) - (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha;
let a1 = 2.0 * ((a - 1.0) - (a + 1.0) * cos_w0);
let a2 = (a + 1.0) - (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::Notch => {
let b0 = 1.0;
let b1 = -2.0 * cos_w0;
let b2 = 1.0;
let a0 = 1.0 + alpha;
let a1 = -2.0 * cos_w0;
let a2 = 1.0 - alpha;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::LowPass => {
let b0 = (1.0 - cos_w0) / 2.0;
let b1 = 1.0 - cos_w0;
let b2 = (1.0 - cos_w0) / 2.0;
let a0 = 1.0 + alpha;
let a1 = -2.0 * cos_w0;
let a2 = 1.0 - alpha;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::HighPass => {
let b0 = (1.0 + cos_w0) / 2.0;
let b1 = -(1.0 + cos_w0);
let b2 = (1.0 + cos_w0) / 2.0;
let a0 = 1.0 + alpha;
let a1 = -2.0 * cos_w0;
let a2 = 1.0 - alpha;
[b0, b1, b2, a0, a1, a2]
}
EqBandType::AllPass => {
let b0 = 1.0 - alpha;
let b1 = -2.0 * cos_w0;
let b2 = 1.0 + alpha;
let a0 = 1.0 + alpha;
let a1 = -2.0 * cos_w0;
let a2 = 1.0 - alpha;
[b0, b1, b2, a0, a1, a2]
}
}
}
}
/// Memory of one second-order (biquad) filter section: the two most recent
/// input samples and the two most recent output samples.
#[derive(Debug, Clone, Default)]
pub struct BiquadState {
    /// Previous input sample.
    x1: f32,
    /// Input sample before `x1`.
    x2: f32,
    /// Previous output sample.
    y1: f32,
    /// Output sample before `y1`.
    y2: f32,
}

impl BiquadState {
    /// Runs one sample through the section described by
    /// `coeffs = [b0, b1, b2, a0, a1, a2]` and returns the filtered sample.
    ///
    /// All coefficients are divided by `a0` before use; when `|a0|` is smaller
    /// than `f32::EPSILON` the divisor is replaced by `1.0`. The stored history
    /// is shifted by one sample afterwards.
    #[inline]
    pub fn process_sample(&mut self, x: f32, coeffs: &[f32; 6]) -> f32 {
        let a0 = coeffs[3];
        let divisor = if a0.abs() < f32::EPSILON { 1.0 } else { a0 };

        // Normalise every coefficient by the (guarded) a0 in one pass.
        let [nb0, nb1, nb2, _, na1, na2] = (*coeffs).map(|c| c / divisor);

        let out = nb0 * x + nb1 * self.x1 + nb2 * self.x2 - na1 * self.y1 - na2 * self.y2;

        // Shift the delay lines: the old "previous" value becomes "two back".
        self.x2 = std::mem::replace(&mut self.x1, x);
        self.y2 = std::mem::replace(&mut self.y1, out);
        out
    }

    /// Clears the stored input/output history back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// A serial chain of [`EqBand`]s, each with its own filter memory.
///
/// Bands are processed in the order they appear in `bands`; the output of one
/// band feeds the next.
#[derive(Debug, Clone)]
pub struct ParametricEq {
    /// Bands in processing order.
    ///
    /// The field is public, so callers may edit it directly. Filter memory is
    /// matched to bands by position and its length is re-synchronised before
    /// processing; removing or reordering bands does not move their memory, so
    /// call [`ParametricEq::reset`] after such edits.
    pub bands: Vec<EqBand>,
    /// Filter memory, one entry per band (by index).
    states: Vec<BiquadState>,
    /// Sample rate in Hz passed to `EqBand::compute_biquad`.
    pub sample_rate: f32,
}

impl ParametricEq {
    /// Creates an equaliser with no bands for the given sample rate (Hz).
    #[must_use]
    pub fn new(sample_rate: f32) -> Self {
        Self {
            bands: Vec::new(),
            states: Vec::new(),
            sample_rate,
        }
    }

    /// Builder-style variant of [`ParametricEq::add_band`].
    #[must_use]
    pub fn with_band(mut self, band: EqBand) -> Self {
        self.add_band(band);
        self
    }

    /// Appends `band` to the end of the chain with zeroed filter memory.
    pub fn add_band(&mut self, band: EqBand) {
        // Bring the state list in line first so the new state lands at the
        // same index as the new band even if `bands` was edited directly.
        self.sync_states();
        self.bands.push(band);
        self.states.push(BiquadState::default());
    }

    /// Sets the gain (dB) of the band at `index`.
    ///
    /// # Errors
    /// Returns a descriptive message when `index` is out of range.
    pub fn set_band_gain(&mut self, index: usize, gain_db: f32) -> Result<(), String> {
        self.band_mut(index)?.gain_db = gain_db;
        Ok(())
    }

    /// Enables or disables the band at `index`.
    ///
    /// # Errors
    /// Returns a descriptive message when `index` is out of range.
    pub fn set_band_enabled(&mut self, index: usize, enabled: bool) -> Result<(), String> {
        self.band_mut(index)?.enabled = enabled;
        Ok(())
    }

    /// Runs a single sample through every enabled band and returns the result.
    ///
    /// Coefficients are re-designed for every call, so parameter changes take
    /// effect immediately; for blocks of audio prefer
    /// [`ParametricEq::process_buffer`], which designs them once per block.
    /// Disabled bands are skipped and their filter memory is left untouched.
    pub fn process_sample(&mut self, sample: f32) -> f32 {
        self.sync_states();
        let sample_rate = self.sample_rate;
        self.bands
            .iter()
            .zip(self.states.iter_mut())
            .filter(|(band, _)| band.enabled)
            .fold(sample, |acc, (band, state)| {
                state.process_sample(acc, &band.compute_biquad(sample_rate))
            })
    }

    /// Filters `input` and returns a new vector of the same length.
    ///
    /// Band parameters cannot change while this method holds `&mut self`, so
    /// the coefficients of every enabled band are designed once up front
    /// instead of once per sample; the output is the same as calling
    /// [`ParametricEq::process_sample`] for each element.
    #[must_use]
    pub fn process_buffer(&mut self, input: &[f32]) -> Vec<f32> {
        self.sync_states();
        let sample_rate = self.sample_rate;

        // Pair each enabled band's coefficients with its own filter memory.
        let mut active: Vec<([f32; 6], &mut BiquadState)> = self
            .bands
            .iter()
            .zip(self.states.iter_mut())
            .filter(|(band, _)| band.enabled)
            .map(|(band, state)| (band.compute_biquad(sample_rate), state))
            .collect();

        input
            .iter()
            .map(|&sample| {
                active
                    .iter_mut()
                    .fold(sample, |acc, (coeffs, state)| state.process_sample(acc, coeffs))
            })
            .collect()
    }

    /// Clears the filter memory of every band.
    pub fn reset(&mut self) {
        for state in &mut self.states {
            state.reset();
        }
    }

    /// Makes `states` exactly as long as `bands`.
    ///
    /// Without this, bands pushed straight onto the public `bands` vector would
    /// have no state and would be silently skipped when zipping the two lists.
    fn sync_states(&mut self) {
        if self.states.len() != self.bands.len() {
            self.states
                .resize_with(self.bands.len(), BiquadState::default);
        }
    }

    /// Looks up a band for mutation, producing the shared out-of-range message.
    fn band_mut(&mut self, index: usize) -> Result<&mut EqBand, String> {
        let count = self.bands.len();
        self.bands
            .get_mut(index)
            .ok_or_else(|| format!("Band index {index} out of range (have {count} bands)"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate shared by all tests (Hz).
    const SR: f32 = 48_000.0;

    /// Generates `n` samples of a unit-amplitude sine at `freq_hz`.
    fn make_sine(freq_hz: f32, sr: f32, n: usize) -> Vec<f32> {
        (0..n)
            .map(|i| (2.0 * PI * freq_hz * i as f32 / sr).sin())
            .collect()
    }

    /// Root-mean-square level of `buf`.
    fn rms(buf: &[f32]) -> f32 {
        (buf.iter().map(|&s| s * s).sum::<f32>() / buf.len() as f32).sqrt()
    }

    // A 0 dB peaking band must leave the signal untouched.
    #[test]
    fn test_flat_eq_passes_signal_unchanged() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 0.0, 1.0));
        let input: Vec<f32> = make_sine(440.0, SR, 512);
        let output = eq.process_buffer(&input);
        for (i, (&a, &b)) in input.iter().zip(output.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "Flat EQ should be unity, sample {i}: in={a}, out={b}"
            );
        }
    }

    // A boosted peaking band raises the level of a tone at its centre frequency.
    #[test]
    fn test_peaking_band_boosts_at_center_freq() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 12.0, 1.0));
        // Let the filter settle before measuring.
        let settle = make_sine(1000.0, SR, 2048);
        let _ = eq.process_buffer(&settle);
        let input = make_sine(1000.0, SR, 512);
        let output = eq.process_buffer(&input);
        assert!(
            rms(&output) > rms(&input),
            "Peak +12 dB at 1 kHz should increase RMS of 1 kHz sine"
        );
    }

    // A notch lowers the level of a tone at its centre frequency.
    #[test]
    fn test_notch_reduces_at_center_freq() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::notch(1000.0, 1.0));
        let settle = make_sine(1000.0, SR, 8192);
        let _ = eq.process_buffer(&settle);
        let input = make_sine(1000.0, SR, 1024);
        let output = eq.process_buffer(&input);
        assert!(
            rms(&output) < rms(&input) * 0.8,
            "Notch at 1 kHz should reduce RMS: in={:.4}, out={:.4}",
            rms(&input),
            rms(&output)
        );
    }

    // A boosted low shelf raises the level of a tone below its corner.
    #[test]
    fn test_low_shelf_affects_low_frequencies() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::low_shelf(200.0, 6.0));
        let settle = make_sine(100.0, SR, 2048);
        let _ = eq.process_buffer(&settle);
        let input = make_sine(100.0, SR, 512);
        let output = eq.process_buffer(&input);
        assert!(
            rms(&output) > rms(&input),
            "Low shelf +6 dB should boost 100 Hz: in={:.4}, out={:.4}",
            rms(&input),
            rms(&output)
        );
    }

    // Output length always equals input length.
    #[test]
    fn test_process_buffer_length_correct() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 6.0, 1.0));
        let input = vec![0.5_f32; 128];
        let output = eq.process_buffer(&input);
        assert_eq!(output.len(), 128);
    }

    // After reset, silence in gives silence out.
    #[test]
    fn test_reset_clears_state() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 6.0, 1.0));
        // A stack array is enough here; no heap allocation needed.
        let _ = eq.process_buffer(&[1.0_f32; 64]);
        eq.reset();
        assert_eq!(eq.process_sample(0.0), 0.0, "reset should clear history");
    }

    // A disabled band must not alter the signal, whatever its gain.
    #[test]
    fn test_disabled_band_bypasses_signal() {
        let band = EqBand {
            band_type: EqBandType::Peaking,
            frequency_hz: 1000.0,
            gain_db: 20.0,
            q: 1.0,
            enabled: false,
        };
        let mut eq = ParametricEq::new(SR).with_band(band);
        let input: Vec<f32> = make_sine(1000.0, SR, 256);
        let output = eq.process_buffer(&input);
        for (i, (&a, &b)) in input.iter().zip(output.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "Disabled band should pass unchanged, sample {i}: in={a}, out={b}"
            );
        }
    }

    // set_band_gain writes through to the stored band.
    #[test]
    fn test_set_band_gain_updates_correctly() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 0.0, 1.0));
        eq.set_band_gain(0, 6.0).expect("index 0 should be valid");
        assert!((eq.bands[0].gain_db - 6.0).abs() < 1e-6);
    }

    // An invalid index is reported as an error, not a panic.
    #[test]
    fn test_set_band_gain_out_of_range() {
        let mut eq = ParametricEq::new(SR);
        let result = eq.set_band_gain(5, 3.0);
        assert!(result.is_err(), "Out-of-range index should return Err");
    }

    // set_band_enabled toggles the stored flag both ways.
    #[test]
    fn test_set_band_enabled_toggles() {
        let mut eq = ParametricEq::new(SR).with_band(EqBand::peaking(1000.0, 12.0, 1.0));
        eq.set_band_enabled(0, false).expect("index 0 should be valid");
        assert!(!eq.bands[0].enabled);
        eq.set_band_enabled(0, true).expect("index 0 should be valid");
        assert!(eq.bands[0].enabled);
    }

    // A realistic multi-band chain never produces NaN or infinity.
    #[test]
    fn test_all_outputs_finite() {
        let mut eq = ParametricEq::new(SR)
            .with_band(EqBand::low_pass(5000.0, 0.707))
            .with_band(EqBand::high_pass(80.0, 0.707))
            .with_band(EqBand::peaking(1000.0, 3.0, 1.5))
            .with_band(EqBand::notch(2000.0, 8.0));
        let sine = make_sine(500.0, SR, 1024);
        let output = eq.process_buffer(&sine);
        for (i, &s) in output.iter().enumerate() {
            assert!(s.is_finite(), "Output at sample {i} is not finite: {s}");
        }
    }
}