#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
use std::f32::consts::PI;
/// Response shape of a single equaliser band.
///
/// Each variant selects one coefficient formula in `BiquadCoeff::from_band`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BandType {
    /// High-pass response: removes content below the band frequency.
    LowCut,
    /// Shelf that boosts or cuts content below the band frequency.
    LowShelf,
    /// Bell-shaped boost or cut around the band frequency.
    Peak,
    /// Shelf that boosts or cuts content above the band frequency.
    HighShelf,
    /// Low-pass response: removes content above the band frequency.
    HighCut,
    /// Narrow rejection at the band frequency.
    Notch,
}

impl BandType {
    /// Filter order reported for this band type.
    ///
    /// Shelving bands report 1; cut, peak and notch bands report 2.
    /// NOTE(review): `BiquadCoeff::from_band` designs a full second-order section for the
    /// shelves as well, so confirm what "order" is meant to describe for them.
    #[must_use]
    pub fn order(self) -> u32 {
        match self {
            Self::LowShelf | Self::HighShelf => 1,
            Self::LowCut | Self::Peak | Self::HighCut | Self::Notch => 2,
        }
    }
}
/// Parameters of one equaliser band.
#[derive(Debug, Clone)]
pub struct EqBand {
    /// Centre or corner frequency, in the same unit as the sample rate the filter is built
    /// with (normally Hz).
    pub frequency: f32,
    /// Boost (positive) or cut (negative) in dB. The cut and notch constructors set it to 0.
    pub gain_db: f32,
    /// Quality factor; enters the coefficient formulas as `alpha = sin(w0) / (2 * q)`.
    pub q: f32,
    /// Response shape of the band.
    pub band_type: BandType,
}

impl EqBand {
    /// Q used by the shelf constructors.
    const SHELF_Q: f32 = 0.707;

    /// Creates a band from all four parameters.
    #[must_use]
    pub fn new(frequency: f32, gain_db: f32, q: f32, band_type: BandType) -> Self {
        EqBand {
            frequency,
            gain_db,
            q,
            band_type,
        }
    }

    /// Bell-shaped band centred on `frequency`.
    #[must_use]
    pub fn peak(frequency: f32, gain_db: f32, q: f32) -> Self {
        EqBand::new(frequency, gain_db, q, BandType::Peak)
    }

    /// Low shelf at `frequency`, with Q fixed to 0.707.
    #[must_use]
    pub fn low_shelf(frequency: f32, gain_db: f32) -> Self {
        EqBand::new(frequency, gain_db, Self::SHELF_Q, BandType::LowShelf)
    }

    /// High shelf at `frequency`, with Q fixed to 0.707.
    #[must_use]
    pub fn high_shelf(frequency: f32, gain_db: f32) -> Self {
        EqBand::new(frequency, gain_db, Self::SHELF_Q, BandType::HighShelf)
    }

    /// Low cut (high-pass) at `frequency`; gain is 0 dB.
    #[must_use]
    pub fn low_cut(frequency: f32, q: f32) -> Self {
        Self::without_gain(frequency, q, BandType::LowCut)
    }

    /// High cut (low-pass) at `frequency`; gain is 0 dB.
    #[must_use]
    pub fn high_cut(frequency: f32, q: f32) -> Self {
        Self::without_gain(frequency, q, BandType::HighCut)
    }

    /// Notch at `frequency`; gain is 0 dB.
    #[must_use]
    pub fn notch(frequency: f32, q: f32) -> Self {
        Self::without_gain(frequency, q, BandType::Notch)
    }

    /// Shared constructor for band types whose formulas ignore the gain.
    fn without_gain(frequency: f32, q: f32, band_type: BandType) -> Self {
        EqBand::new(frequency, 0.0, q, band_type)
    }
}
/// Coefficients of one biquad section, normalised so the leading denominator term is 1.
///
/// The filter realises
/// `H(z) = (b0 + b1·z⁻¹ + b2·z⁻²) / (1 + a1·z⁻¹ + a2·z⁻²)`.
#[derive(Clone, Copy, Debug)]
pub struct BiquadCoeff {
    /// Numerator coefficient for the current input.
    pub b0: f32,
    /// Numerator coefficient for the input one sample back.
    pub b1: f32,
    /// Numerator coefficient for the input two samples back.
    pub b2: f32,
    /// Denominator coefficient for the output one sample back.
    pub a1: f32,
    /// Denominator coefficient for the output two samples back.
    pub a2: f32,
}
impl BiquadCoeff {
    /// Smallest Q used for design. Because `alpha = sin(w0) / (2 * q)`, a zero Q makes
    /// `alpha` infinite (NaN coefficients) and a negative Q makes it negative (unstable poles).
    const MIN_Q: f32 = 1.0e-3;
    /// Lowest design frequency as a fraction of the sample rate; keeps `w0 > 0`.
    const MIN_FREQUENCY_RATIO: f32 = 1.0e-6;
    /// Highest design frequency as a fraction of the sample rate. Nyquist is 0.5. At or above
    /// it `sin(w0) <= 0`, which puts the poles on or outside the unit circle.
    const MAX_FREQUENCY_RATIO: f32 = 0.499;

    /// Pass-through coefficients: `b0 = 1`, every other coefficient zero.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            b0: 1.0,
            b1: 0.0,
            b2: 0.0,
            a1: 0.0,
            a2: 0.0,
        }
    }

    /// Designs normalised biquad coefficients for `band` at `sample_rate`.
    ///
    /// The formulas are those of the RBJ "Audio EQ Cookbook": peaking EQ, low/high shelf,
    /// high-pass for `LowCut`, low-pass for `HighCut`, and notch. They use
    /// `w0 = 2π·frequency / sample_rate` and `alpha = sin(w0) / (2·q)`. Peak and shelf bands
    /// use the amplitude `A = 10^(gain_db / 40)`.
    ///
    /// Parameters that would otherwise give NaN or unstable coefficients are sanitised:
    /// - [`BiquadCoeff::identity`] is returned if any of these hold:
    ///   - `sample_rate` is non-finite or not positive;
    ///   - `band.frequency` is non-finite;
    ///   - `band.q` is NaN;
    ///   - `band.gain_db` is non-finite on a peak/shelf band.
    /// - `band.frequency` is clamped to `[1e-6, 0.499] · sample_rate`, strictly inside
    ///   `(0, Nyquist)`.
    /// - `band.q` is raised to at least `1e-3`.
    ///
    /// Parameters already inside these ranges are used unchanged.
    #[must_use]
    pub fn from_band(band: &EqBand, sample_rate: f32) -> Self {
        let uses_gain = matches!(
            band.band_type,
            BandType::Peak | BandType::LowShelf | BandType::HighShelf
        );
        let valid = sample_rate.is_finite()
            && sample_rate > 0.0
            && band.frequency.is_finite()
            && !band.q.is_nan()
            && (!uses_gain || band.gain_db.is_finite());
        if !valid {
            // Pass the signal through rather than emitting NaNs that would
            // permanently corrupt the filter state.
            return Self::identity();
        }

        // Both bounds are finite and min <= max here, so `clamp` cannot panic.
        let frequency = band.frequency.clamp(
            Self::MIN_FREQUENCY_RATIO * sample_rate,
            Self::MAX_FREQUENCY_RATIO * sample_rate,
        );
        let q = band.q.max(Self::MIN_Q);

        let w0 = 2.0 * PI * frequency / sample_rate;
        let cos_w0 = w0.cos();
        let sin_w0 = w0.sin();
        let alpha = sin_w0 / (2.0 * q);
        match band.band_type {
            BandType::Peak => {
                let a = Self::amplitude(band.gain_db);
                let b0 = 1.0 + alpha * a;
                let b1 = -2.0 * cos_w0;
                let b2 = 1.0 - alpha * a;
                let a0 = 1.0 + alpha / a;
                let a1 = -2.0 * cos_w0;
                let a2 = 1.0 - alpha / a;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
            BandType::LowShelf => {
                let a = Self::amplitude(band.gain_db);
                let sqrt_a = a.sqrt();
                let b0 = a * ((a + 1.0) - (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha);
                let b1 = 2.0 * a * ((a - 1.0) - (a + 1.0) * cos_w0);
                let b2 = a * ((a + 1.0) - (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha);
                let a0 = (a + 1.0) + (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha;
                let a1 = -2.0 * ((a - 1.0) + (a + 1.0) * cos_w0);
                let a2 = (a + 1.0) + (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
            BandType::HighShelf => {
                let a = Self::amplitude(band.gain_db);
                let sqrt_a = a.sqrt();
                let b0 = a * ((a + 1.0) + (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha);
                let b1 = -2.0 * a * ((a - 1.0) + (a + 1.0) * cos_w0);
                let b2 = a * ((a + 1.0) + (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha);
                let a0 = (a + 1.0) - (a - 1.0) * cos_w0 + 2.0 * sqrt_a * alpha;
                let a1 = 2.0 * ((a - 1.0) - (a + 1.0) * cos_w0);
                let a2 = (a + 1.0) - (a - 1.0) * cos_w0 - 2.0 * sqrt_a * alpha;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
            BandType::LowCut => {
                // RBJ high-pass.
                let b0 = (1.0 + cos_w0) / 2.0;
                let b1 = -(1.0 + cos_w0);
                let b2 = (1.0 + cos_w0) / 2.0;
                let a0 = 1.0 + alpha;
                let a1 = -2.0 * cos_w0;
                let a2 = 1.0 - alpha;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
            BandType::HighCut => {
                // RBJ low-pass.
                let b0 = (1.0 - cos_w0) / 2.0;
                let b1 = 1.0 - cos_w0;
                let b2 = (1.0 - cos_w0) / 2.0;
                let a0 = 1.0 + alpha;
                let a1 = -2.0 * cos_w0;
                let a2 = 1.0 - alpha;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
            BandType::Notch => {
                let b0 = 1.0;
                let b1 = -2.0 * cos_w0;
                let b2 = 1.0;
                let a0 = 1.0 + alpha;
                let a1 = -2.0 * cos_w0;
                let a2 = 1.0 - alpha;
                Self::normalize(b0, b1, b2, a0, a1, a2)
            }
        }
    }

    /// Linear amplitude `A = 10^(gain_db / 40)` shared by the peak and shelf formulas.
    fn amplitude(gain_db: f32) -> f32 {
        10.0_f32.powf(gain_db / 40.0)
    }

    /// Divides every coefficient by `a0`, so the filter can treat the leading denominator
    /// coefficient as 1.
    fn normalize(b0: f32, b1: f32, b2: f32, a0: f32, a1: f32, a2: f32) -> Self {
        Self {
            b0: b0 / a0,
            b1: b1 / a0,
            b2: b2 / a0,
            a1: a1 / a0,
            a2: a2 / a0,
        }
    }
}
/// One biquad section together with its running state.
///
/// Samples are processed in transposed direct form II, so `x1` and `x2` hold the two delayed
/// partial sums of that structure rather than past input samples.
pub struct BiquadFilter {
    /// Coefficients currently applied to the signal.
    coeff: BiquadCoeff,
    /// First transposed-direct-form-II state register (added to `b0·x` to form the output).
    x1: f32,
    /// Second transposed-direct-form-II state register (fed into `x1` on the next sample).
    x2: f32,
    /// Output history written by `process_sample` and cleared by `reset`; it does not feed
    /// back into the computed output.
    y1: f32,
    /// Second slot of the output history; also not used to compute the output.
    y2: f32,
}
impl BiquadFilter {
    /// Creates a filter with the given coefficients and all state cleared.
    #[must_use]
    pub fn new(coeff: BiquadCoeff) -> Self {
        Self {
            coeff,
            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Creates a pass-through filter (see [`BiquadCoeff::identity`]).
    #[must_use]
    pub fn identity() -> Self {
        Self::new(BiquadCoeff::identity())
    }

    /// Creates a filter whose coefficients are designed from `band` at `sample_rate`.
    #[must_use]
    pub fn from_band(band: &EqBand, sample_rate: f32) -> Self {
        Self::new(BiquadCoeff::from_band(band, sample_rate))
    }

    /// Replaces the coefficients. The running state is kept, not cleared.
    pub fn set_coeff(&mut self, coeff: BiquadCoeff) {
        self.coeff = coeff;
    }

    /// Filters one input sample and returns the corresponding output sample.
    ///
    /// Implements the transposed direct form II recurrence:
    /// `y = b0·x + s1`, `s1' = b1·x − a1·y + s2`, `s2' = b2·x − a2·y`.
    pub fn process_sample(&mut self, x: f32) -> f32 {
        let c = &self.coeff;
        let y = c.b0 * x + self.x1;
        self.x1 = c.b1 * x - c.a1 * y + self.x2;
        self.x2 = c.b2 * x - c.a2 * y;
        // Shift the output history: the previous output must move into `y2` *before* `y1`
        // is overwritten. (Assigning in the opposite order made `y2` always equal `y`.)
        self.y2 = self.y1;
        self.y1 = y;
        y
    }

    /// Clears all state so the next sample is processed as if the filter were new.
    /// The coefficients are kept.
    pub fn reset(&mut self) {
        self.x1 = 0.0;
        self.x2 = 0.0;
        self.y1 = 0.0;
        self.y2 = 0.0;
    }
}
/// Cascade of equaliser bands applied in insertion order.
///
/// NOTE(review): `bands` is public and each filter is built once when the band is added,
/// so mutating an `EqBand` stored here does not redesign its paired filter.
pub struct ParametricEq {
    /// `(parameters, filter)` pairs, processed from first to last.
    pub bands: Vec<(EqBand, BiquadFilter)>,
}
impl ParametricEq {
    /// Sample rate (Hz) assumed by the parameterless preset constructors
    /// (`broadcast_presence_boost`, `bass_boost`, `vocal_clarity`).
    pub const DEFAULT_SAMPLE_RATE: f32 = 48_000.0;

    /// Creates an equaliser with no bands; `process` then returns its input unchanged.
    #[must_use]
    pub fn new() -> Self {
        Self { bands: Vec::new() }
    }

    /// Appends `band` and builds its filter for `sample_rate`.
    ///
    /// Bands run in the order they were added. Each band is designed for the rate passed
    /// here, so all bands of one equaliser should use the same rate.
    pub fn add_band(&mut self, band: EqBand, sample_rate: f32) {
        let filter = BiquadFilter::from_band(&band, sample_rate);
        self.bands.push((band, filter));
    }

    /// Runs `samples` through every band and returns the filtered copy.
    ///
    /// Filter state carries over between calls, so successive calls continue one stream.
    #[must_use]
    pub fn process(&mut self, samples: &[f32]) -> Vec<f32> {
        let mut output = samples.to_vec();
        self.process_in_place(&mut output);
        output
    }

    /// In-place variant of [`ParametricEq::process`]: overwrites `samples` with the filtered
    /// signal without allocating.
    pub fn process_in_place(&mut self, samples: &mut [f32]) {
        // Each band only sees the previous band's output stream. Running the whole buffer
        // through one band before the next therefore yields exactly the same samples as a
        // per-sample cascade, with better locality.
        for (_, filter) in &mut self.bands {
            for sample in samples.iter_mut() {
                *sample = filter.process_sample(*sample);
            }
        }
    }

    /// Clears the state of every band's filter; the band parameters are kept.
    pub fn reset(&mut self) {
        for (_, filter) in &mut self.bands {
            filter.reset();
        }
    }

    /// Builds an equaliser from `bands`, all designed for `sample_rate`.
    fn from_bands(bands: impl IntoIterator<Item = EqBand>, sample_rate: f32) -> Self {
        let mut eq = Self::new();
        for band in bands {
            eq.add_band(band, sample_rate);
        }
        eq
    }

    /// Presence preset at [`ParametricEq::DEFAULT_SAMPLE_RATE`].
    /// See [`ParametricEq::broadcast_presence_boost_at`].
    #[must_use]
    pub fn broadcast_presence_boost() -> Self {
        Self::broadcast_presence_boost_at(Self::DEFAULT_SAMPLE_RATE)
    }

    /// Presence preset designed for `sample_rate`. It has three bands:
    /// - an 80 Hz low cut (Q 0.707);
    /// - a +3 dB peak at 4 kHz (Q 1.5);
    /// - a +1.5 dB high shelf at 12 kHz.
    ///
    /// Band frequencies are fixed in Hz, so `sample_rate` should put Nyquist above 12 kHz.
    #[must_use]
    pub fn broadcast_presence_boost_at(sample_rate: f32) -> Self {
        Self::from_bands(
            [
                EqBand::low_cut(80.0, 0.707),
                EqBand::peak(4000.0, 3.0, 1.5),
                EqBand::high_shelf(12000.0, 1.5),
            ],
            sample_rate,
        )
    }

    /// Bass preset at [`ParametricEq::DEFAULT_SAMPLE_RATE`].
    /// See [`ParametricEq::bass_boost_at`].
    #[must_use]
    pub fn bass_boost() -> Self {
        Self::bass_boost_at(Self::DEFAULT_SAMPLE_RATE)
    }

    /// Bass preset designed for `sample_rate`. It has three bands:
    /// - a +6 dB low shelf at 120 Hz;
    /// - a −2 dB peak at 300 Hz (Q 1.0);
    /// - a +1 dB high shelf at 8 kHz.
    ///
    /// Band frequencies are fixed in Hz, so `sample_rate` should put Nyquist above 8 kHz.
    #[must_use]
    pub fn bass_boost_at(sample_rate: f32) -> Self {
        Self::from_bands(
            [
                EqBand::low_shelf(120.0, 6.0),
                EqBand::peak(300.0, -2.0, 1.0),
                EqBand::high_shelf(8000.0, 1.0),
            ],
            sample_rate,
        )
    }

    /// Vocal preset at [`ParametricEq::DEFAULT_SAMPLE_RATE`].
    /// See [`ParametricEq::vocal_clarity_at`].
    #[must_use]
    pub fn vocal_clarity() -> Self {
        Self::vocal_clarity_at(Self::DEFAULT_SAMPLE_RATE)
    }

    /// Vocal preset designed for `sample_rate`. It has five bands:
    /// - a 120 Hz low cut (Q 0.707);
    /// - a −2 dB peak at 250 Hz (Q 0.8);
    /// - a +3 dB peak at 3.5 kHz (Q 1.2);
    /// - a −2 dB peak at 8 kHz (Q 2.0);
    /// - a +2 dB high shelf at 10 kHz.
    ///
    /// Band frequencies are fixed in Hz, so `sample_rate` should put Nyquist above 10 kHz.
    #[must_use]
    pub fn vocal_clarity_at(sample_rate: f32) -> Self {
        Self::from_bands(
            [
                EqBand::low_cut(120.0, 0.707),
                EqBand::peak(250.0, -2.0, 0.8),
                EqBand::peak(3500.0, 3.0, 1.2),
                EqBand::peak(8000.0, -2.0, 2.0),
                EqBand::high_shelf(10000.0, 2.0),
            ],
            sample_rate,
        )
    }
}
impl Default for ParametricEq {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `num_samples` samples of a unit-amplitude sine at `freq`, starting at phase 0.
    fn make_sine(freq: f32, sample_rate: f32, num_samples: usize) -> Vec<f32> {
        (0..num_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate).sin())
            .collect()
    }

    /// Root-mean-square level of `samples` (0.0 for an empty slice).
    fn rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        (samples.iter().map(|&s| s * s).sum::<f32>() / samples.len() as f32).sqrt()
    }

    #[test]
    fn test_band_type_order() {
        assert_eq!(BandType::LowCut.order(), 2);
        assert_eq!(BandType::LowShelf.order(), 1);
        assert_eq!(BandType::Peak.order(), 2);
        assert_eq!(BandType::HighShelf.order(), 1);
        assert_eq!(BandType::HighCut.order(), 2);
        assert_eq!(BandType::Notch.order(), 2);
    }

    #[test]
    fn test_biquad_identity() {
        // Identity coefficients must pass every sample through unchanged,
        // not merely produce finite output.
        let mut filter = BiquadFilter::identity();
        for &x in &[1.0_f32, -0.5, 0.25, 0.0, 3.0] {
            assert_eq!(filter.process_sample(x), x);
        }
    }

    #[test]
    fn test_peak_band_coefficients() {
        let band = EqBand::peak(1000.0, 6.0, 1.0);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
        assert!(coeff.b1.is_finite());
        assert!(coeff.b2.is_finite());
        assert!(coeff.a1.is_finite());
        assert!(coeff.a2.is_finite());
    }

    #[test]
    fn test_low_shelf_coefficients() {
        let band = EqBand::low_shelf(200.0, 6.0);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
        assert!(coeff.b0 > 0.0);
    }

    #[test]
    fn test_high_shelf_coefficients() {
        let band = EqBand::high_shelf(8000.0, -6.0);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
    }

    #[test]
    fn test_low_cut_coefficients() {
        let band = EqBand::low_cut(80.0, 0.707);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
    }

    #[test]
    fn test_high_cut_coefficients() {
        let band = EqBand::high_cut(16000.0, 0.707);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
    }

    #[test]
    fn test_notch_coefficients() {
        let band = EqBand::notch(1000.0, 5.0);
        let coeff = BiquadCoeff::from_band(&band, 48000.0);
        assert!(coeff.b0.is_finite());
    }

    #[test]
    fn test_filter_output_is_finite() {
        let band = EqBand::peak(1000.0, 6.0, 1.0);
        let mut filter = BiquadFilter::from_band(&band, 48000.0);
        let sine = make_sine(440.0, 48000.0, 512);
        for s in &sine {
            let out = filter.process_sample(*s);
            assert!(out.is_finite(), "Output not finite: {out}");
        }
    }

    #[test]
    fn test_low_cut_rejects_dc() {
        // A high-pass has zero gain at DC; after the transient (~100 ms here) a constant
        // input must be almost completely removed.
        let mut filter = BiquadFilter::from_band(&EqBand::low_cut(80.0, 0.707), 48000.0);
        let mut last = 1.0_f32;
        for _ in 0..4800 {
            last = filter.process_sample(1.0);
        }
        assert!(last.abs() < 0.05, "DC leaked through low cut: {last}");
    }

    #[test]
    fn test_high_cut_passes_dc() {
        // A low-pass has unity gain at DC.
        let mut filter = BiquadFilter::from_band(&EqBand::high_cut(16000.0, 0.707), 48000.0);
        let mut last = 0.0_f32;
        for _ in 0..480 {
            last = filter.process_sample(1.0);
        }
        assert!((last - 1.0).abs() < 1e-3, "High cut altered DC: {last}");
    }

    #[test]
    fn test_notch_attenuates_center_frequency() {
        let mut filter = BiquadFilter::from_band(&EqBand::notch(1000.0, 5.0), 48000.0);
        // One continuous sine: the first 4096 samples let the filter settle, the tail is measured.
        let signal = make_sine(1000.0, 48000.0, 4096 + 512);
        let output: Vec<f32> = signal.iter().map(|&s| filter.process_sample(s)).collect();
        let in_rms = rms(&signal[4096..]);
        let out_rms = rms(&output[4096..]);
        assert!(
            out_rms < 0.1 * in_rms,
            "Notch did not attenuate: out_rms {out_rms}, in_rms {in_rms}"
        );
    }

    #[test]
    fn test_parametric_eq_process() {
        let mut eq = ParametricEq::new();
        eq.add_band(EqBand::peak(1000.0, 6.0, 1.0), 48000.0);
        let input = vec![0.5f32; 128];
        let output = eq.process(&input);
        assert_eq!(output.len(), 128);
        assert!(output.iter().all(|&s| s.is_finite()));
    }

    #[test]
    fn test_empty_eq_is_passthrough() {
        let mut eq = ParametricEq::new();
        let input = [0.1_f32, -0.2, 0.3, 0.0];
        assert_eq!(eq.process(&input), input.to_vec());
        assert!(eq.process(&[]).is_empty());
    }

    #[test]
    fn test_default_has_no_bands() {
        assert!(ParametricEq::default().bands.is_empty());
    }

    #[test]
    fn test_broadcast_presence_boost() {
        let mut eq = ParametricEq::broadcast_presence_boost();
        assert_eq!(eq.bands.len(), 3);
        let input = make_sine(1000.0, 48000.0, 256);
        let output = eq.process(&input);
        assert!(output.iter().all(|&s| s.is_finite()));
    }

    #[test]
    fn test_bass_boost() {
        let mut eq = ParametricEq::bass_boost();
        assert_eq!(eq.bands.len(), 3);
        let input = make_sine(100.0, 48000.0, 256);
        let output = eq.process(&input);
        assert!(output.iter().all(|&s| s.is_finite()));
    }

    #[test]
    fn test_vocal_clarity() {
        let mut eq = ParametricEq::vocal_clarity();
        assert_eq!(eq.bands.len(), 5);
        let input = make_sine(3000.0, 48000.0, 256);
        let output = eq.process(&input);
        assert!(output.iter().all(|&s| s.is_finite()));
    }

    #[test]
    fn test_filter_reset() {
        let band = EqBand::peak(1000.0, 6.0, 1.0);
        let mut filter = BiquadFilter::from_band(&band, 48000.0);
        filter.process_sample(1.0);
        filter.process_sample(0.5);
        filter.reset();
        assert_eq!(filter.x1, 0.0);
        assert_eq!(filter.x2, 0.0);
        assert_eq!(filter.y1, 0.0);
        assert_eq!(filter.y2, 0.0);
    }

    #[test]
    fn test_eq_multi_band() {
        let mut eq = ParametricEq::new();
        eq.add_band(EqBand::low_cut(80.0, 0.707), 48000.0);
        eq.add_band(EqBand::peak(500.0, -3.0, 1.0), 48000.0);
        eq.add_band(EqBand::peak(4000.0, 4.0, 1.5), 48000.0);
        eq.add_band(EqBand::high_shelf(10000.0, 2.0), 48000.0);
        let input = vec![0.1f32; 512];
        let output = eq.process(&input);
        assert!(output.iter().all(|&s| s.is_finite()));
    }

    #[test]
    fn test_peak_boost_increases_level_near_center() {
        let mut eq = ParametricEq::new();
        eq.add_band(EqBand::peak(1000.0, 12.0, 1.0), 48000.0);
        // One continuous sine (no phase jump between "settle" and "measure"): the first
        // 2048 samples let the filter settle, the last 512 are measured.
        let signal = make_sine(1000.0, 48000.0, 2048 + 512);
        let output = eq.process(&signal);
        let in_rms = rms(&signal[2048..]);
        let out_rms = rms(&output[2048..]);
        // +12 dB at the centre frequency is a linear gain of about 3.98.
        assert!(
            out_rms > 3.5 * in_rms,
            "Expected out_rms {out_rms} > 3.5 * in_rms {in_rms}"
        );
    }
}