use num_traits::FloatConst;
use crate::operations::iir_filtering::IirFilter;
use crate::operations::traits::AudioParametricEq;
use crate::operations::types::{EqBand, EqBandType, ParametricEq};
use crate::traits::StandardSample;
use crate::utils::audio_math::db_to_amplitude as db_to_linear;
use crate::{AudioData, AudioSampleError, LayoutError, ParameterError};
use crate::{AudioSampleResult, AudioSamples, AudioTypeConversion, ConvertTo};
impl<T> AudioParametricEq for AudioSamples<'_, T>
where
T: StandardSample,
{
#[inline]
fn apply_parametric_eq(&mut self, eq: &ParametricEq) -> AudioSampleResult<()> {
let sample_rate = self.sample_rate_hz();
if eq.is_bypassed() {
return Ok(());
}
let eq = eq.clone().validate(sample_rate)?;
for band in &eq.bands {
if band.is_enabled() {
self.apply_eq_band(band)?;
}
}
if eq.output_gain_db != 0.0 {
let output_gain_linear = db_to_linear(eq.output_gain_db);
self.apply_linear_gain(output_gain_linear);
}
Ok(())
}
#[inline]
fn apply_eq_band(&mut self, band: &EqBand) -> AudioSampleResult<()> {
let sample_rate = self.sample_rate_hz();
if !band.is_enabled() {
return Ok(());
}
band.validate(sample_rate)?;
let (b_coeffs, a_coeffs) = design_eq_band_filter(band, sample_rate);
let mut filter = IirFilter::new(b_coeffs, a_coeffs);
match &mut self.data {
AudioData::Mono(_) => {
let mut working_samples = self.as_float();
let Some(mono_self) = self.as_mono_mut() else {
return Err(AudioSampleError::Layout(LayoutError::NonContiguous {
operation: "parametric EQ".to_string(),
layout_type: "non-contiguous mono samples".to_string(),
}));
};
let working_samples = working_samples.as_mono_mut().ok_or_else(|| {
AudioSampleError::Parameter(ParameterError::invalid_value(
"audio_format",
"Failed to get mono data. Underlying data is not mono.",
))
})?;
let working_samples = working_samples.as_slice_mut();
filter.process_samples_in_place(working_samples);
for (i, output) in working_samples.iter_mut().enumerate() {
mono_self[i] = (*output).convert_to();
}
}
AudioData::Multi(samples) => {
let num_channels = samples.nrows().get();
for channel in 0..num_channels {
let mut working_samples = self.as_float();
let Some(multi_self) = self.as_multi_channel_mut() else {
return Err(AudioSampleError::Layout(LayoutError::NonContiguous {
operation: "parametric EQ".to_string(),
layout_type: "non-contiguous multi-channel samples".to_string(),
}));
};
let working_samples = working_samples.as_multi_channel_mut() .ok_or_else(|| AudioSampleError::Parameter(ParameterError::invalid_value(
"audio_format",
"Failed to get multi-channel data. Underlying data is not multi-channel."
)))?;
let working_samples = working_samples.as_slice_mut().ok_or_else(|| {
AudioSampleError::Layout(LayoutError::NonContiguous {
operation: "parametric EQ".to_string(),
layout_type: "non-contiguous multi-channel samples".to_string(),
})
})?;
filter.process_samples_in_place(working_samples);
for (i, output) in working_samples.iter().enumerate() {
multi_self[[channel, i]] = (*output).convert_to();
}
filter.reset();
}
}
}
Ok(())
}
#[inline]
fn apply_peak_filter(
&mut self,
frequency: f64,
gain_db: f64,
q_factor: f64,
) -> AudioSampleResult<()> {
let band = EqBand::peak(frequency, gain_db, q_factor);
self.apply_eq_band(&band)
}
#[inline]
fn apply_low_shelf(
&mut self,
frequency: f64,
gain_db: f64,
q_factor: f64,
) -> AudioSampleResult<()> {
let band = EqBand::low_shelf(frequency, gain_db, q_factor);
self.apply_eq_band(&band)
}
#[inline]
fn apply_high_shelf(
&mut self,
frequency: f64,
gain_db: f64,
q_factor: f64,
) -> AudioSampleResult<()> {
let band = EqBand::high_shelf(frequency, gain_db, q_factor);
self.apply_eq_band(&band)
}
#[inline]
fn apply_three_band_eq(
&mut self,
low_freq: f64,
low_gain: f64,
mid_freq: f64,
mid_gain: f64,
mid_q: f64,
high_freq: f64,
high_gain: f64,
) -> AudioSampleResult<()> {
let eq = ParametricEq::three_band(
low_freq, low_gain, mid_freq, mid_gain, mid_q, high_freq, high_gain,
);
self.apply_parametric_eq(&eq)
}
#[inline]
fn eq_frequency_response(
&self,
eq: &ParametricEq,
frequencies: &[f64],
) -> AudioSampleResult<(Vec<f64>, Vec<f64>)> {
let mut combined_magnitude = vec![1.0; frequencies.len()];
let mut combined_phase = vec![0.0; frequencies.len()];
let sample_rate = self.sample_rate().get();
for band in &eq.bands {
if !band.is_enabled() {
continue;
}
let (b_coeffs, a_coeffs) = design_eq_band_filter(band, f64::from(sample_rate));
let filter = IirFilter::new(b_coeffs, a_coeffs);
let (magnitude, phase) = filter.frequency_response(frequencies, f64::from(sample_rate));
for i in 0..frequencies.len() {
combined_magnitude[i] *= magnitude[i];
combined_phase[i] += phase[i];
}
}
if eq.output_gain_db != 0.0 {
let output_gain_linear = db_to_linear(eq.output_gain_db);
for magnitude in &mut combined_magnitude {
*magnitude *= output_gain_linear;
}
}
Ok((combined_magnitude, combined_phase))
}
}
/// Dispatches `band` to the biquad designer for its `band_type`.
///
/// Returns normalised `(b, a)` coefficient vectors. Gain is only consulted by the
/// peak and shelf designs; the pass/stop designs ignore `band.gain_db`.
fn design_eq_band_filter(band: &EqBand, sample_rate: f64) -> (Vec<f64>, Vec<f64>) {
    let freq = band.frequency;
    let q = band.q_factor;
    match band.band_type {
        EqBandType::Peak => design_peak_filter(freq, band.gain_db, q, sample_rate),
        EqBandType::LowShelf => design_low_shelf_filter(freq, band.gain_db, q, sample_rate),
        EqBandType::HighShelf => design_high_shelf_filter(freq, band.gain_db, q, sample_rate),
        EqBandType::LowPass => design_lowpass_filter(freq, q, sample_rate),
        EqBandType::HighPass => design_highpass_filter(freq, q, sample_rate),
        EqBandType::BandPass => design_bandpass_filter(freq, q, sample_rate),
        EqBandType::BandStop => design_bandstop_filter(freq, q, sample_rate),
    }
}
/// Designs an RBJ "Audio EQ Cookbook" peaking (bell) biquad.
///
/// At `frequency` (Hz) the magnitude is exactly `10^(gain_db / 20)`. Far from it the
/// filter is transparent: unity gain at DC. `q_factor` sets the bandwidth.
///
/// Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_peak_filter(
    frequency: f64,
    gain_db: f64,
    q_factor: f64,
    sample_rate: f64,
) -> (Vec<f64>, Vec<f64>) {
    // A = 10^(gain_db / 40), the square root of the linear peak gain. The base must be
    // exactly 10; a stray `10.04` skews every boost/cut.
    let a = 10.0f64.powf(gain_db / 40.0);
    let omega = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_omega, cos_omega) = omega.sin_cos();
    let alpha = sin_omega / (2.0 * q_factor);
    let b0 = 1.0 + alpha * a;
    let b1 = -2.0 * cos_omega;
    let b2 = 1.0 - alpha * a;
    let a0 = 1.0 + alpha / a;
    let a1 = -2.0 * cos_omega;
    let a2 = 1.0 - alpha / a;
    let b_coeffs = vec![b0 / a0, b1 / a0, b2 / a0];
    let a_coeffs = vec![1.0, a1 / a0, a2 / a0];
    (b_coeffs, a_coeffs)
}
/// Designs an RBJ "Audio EQ Cookbook" low-shelf biquad.
///
/// Frequencies well below `frequency` (Hz) are scaled by `gain_db`, and those well above
/// pass at unity. At `frequency` itself the gain is `gain_db / 2` dB. `q_factor` shapes
/// the transition: `1/sqrt(2)` gives the steepest shelf without overshoot.
///
/// Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_low_shelf_filter(
    frequency: f64,
    gain_db: f64,
    q_factor: f64,
    sample_rate: f64,
) -> (Vec<f64>, Vec<f64>) {
    // A^2 is the linear shelf gain: A = 10^(gain_db / 40).
    let a = 10.0f64.powf(gain_db / 40.0);
    let omega = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_omega, cos_omega) = omega.sin_cos();
    let alpha = sin_omega / (2.0 * q_factor);
    // Cookbook term `2 * sqrt(A) * alpha` (== sqrt(A) * sin(w0) / Q). Using
    // `sqrt(2A) * alpha` instead would silently scale the effective Q by sqrt(2).
    let shelf_term = 2.0 * a.sqrt() * alpha;
    let b0 = a * ((a - 1.0).mul_add(-cos_omega, a + 1.0) + shelf_term);
    let b1 = 2.0 * a * (a + 1.0).mul_add(-cos_omega, a - 1.0);
    let b2 = a * ((a - 1.0).mul_add(-cos_omega, a + 1.0) - shelf_term);
    let a0 = (a - 1.0).mul_add(cos_omega, a + 1.0) + shelf_term;
    let a1 = -2.0 * (a + 1.0).mul_add(cos_omega, a - 1.0);
    let a2 = (a - 1.0).mul_add(cos_omega, a + 1.0) - shelf_term;
    let b_coeffs = vec![b0 / a0, b1 / a0, b2 / a0];
    let a_coeffs = vec![1.0, a1 / a0, a2 / a0];
    (b_coeffs, a_coeffs)
}
/// Designs an RBJ "Audio EQ Cookbook" high-shelf biquad.
///
/// Frequencies well above `frequency` (Hz) are scaled by `gain_db`, and those well below
/// pass at unity. At `frequency` itself the gain is `gain_db / 2` dB. `q_factor` shapes
/// the transition: `1/sqrt(2)` gives the steepest shelf without overshoot.
///
/// Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_high_shelf_filter(
    frequency: f64,
    gain_db: f64,
    q_factor: f64,
    sample_rate: f64,
) -> (Vec<f64>, Vec<f64>) {
    // A^2 is the linear shelf gain: A = 10^(gain_db / 40).
    let a = 10.0f64.powf(gain_db / 40.0);
    let omega = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_omega, cos_omega) = omega.sin_cos();
    let alpha = sin_omega / (2.0 * q_factor);
    // Cookbook term `2 * sqrt(A) * alpha` (== sqrt(A) * sin(w0) / Q). Using
    // `sqrt(2A) * alpha` instead would silently scale the effective Q by sqrt(2).
    let shelf_term = 2.0 * a.sqrt() * alpha;
    let b0 = a * ((a - 1.0).mul_add(cos_omega, a + 1.0) + shelf_term);
    let b1 = -2.0 * a * (a + 1.0).mul_add(cos_omega, a - 1.0);
    let b2 = a * ((a - 1.0).mul_add(cos_omega, a + 1.0) - shelf_term);
    let a0 = (a - 1.0).mul_add(-cos_omega, a + 1.0) + shelf_term;
    let a1 = 2.0 * (a + 1.0).mul_add(-cos_omega, a - 1.0);
    let a2 = (a - 1.0).mul_add(-cos_omega, a + 1.0) - shelf_term;
    let b_coeffs = vec![b0 / a0, b1 / a0, b2 / a0];
    let a_coeffs = vec![1.0, a1 / a0, a2 / a0];
    (b_coeffs, a_coeffs)
}
/// Designs an RBJ "Audio EQ Cookbook" second-order low-pass biquad.
///
/// The filter has unity gain at DC, a zero at Nyquist, and magnitude `q_factor` at the
/// cutoff `frequency` (Hz). Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_lowpass_filter(frequency: f64, q_factor: f64, sample_rate: f64) -> (Vec<f64>, Vec<f64>) {
    let w0 = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_w0, cos_w0) = w0.sin_cos();
    let alpha = sin_w0 / (2.0 * q_factor);
    let norm = 1.0 + alpha;
    let one_minus_cos = 1.0 - cos_w0;
    // The outer feed-forward taps are equal: (1 - cos w0) / 2.
    let outer = one_minus_cos / 2.0;
    let numerator = vec![outer / norm, one_minus_cos / norm, outer / norm];
    let denominator = vec![1.0, (-2.0 * cos_w0) / norm, (1.0 - alpha) / norm];
    (numerator, denominator)
}
/// Designs an RBJ "Audio EQ Cookbook" second-order high-pass biquad.
///
/// The filter has a zero at DC, unity gain at Nyquist, and magnitude `q_factor` at the
/// cutoff `frequency` (Hz). Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_highpass_filter(frequency: f64, q_factor: f64, sample_rate: f64) -> (Vec<f64>, Vec<f64>) {
    let w0 = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_w0, cos_w0) = w0.sin_cos();
    let alpha = sin_w0 / (2.0 * q_factor);
    let norm = 1.0 + alpha;
    let one_plus_cos = 1.0 + cos_w0;
    // The outer taps are the midpoint of 1 and cos w0, i.e. (1 + cos w0) / 2.
    let outer = one_plus_cos / 2.0;
    let numerator = [outer, -one_plus_cos, outer].map(|tap| tap / norm);
    let denominator = [1.0, (-2.0 * cos_w0) / norm, (1.0 - alpha) / norm];
    (numerator.to_vec(), denominator.to_vec())
}
/// Designs an RBJ "Audio EQ Cookbook" band-pass biquad with 0 dB peak gain.
///
/// The filter has unity gain at the centre `frequency` (Hz) and zeros at DC and Nyquist.
/// `q_factor` sets the bandwidth. Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_bandpass_filter(frequency: f64, q_factor: f64, sample_rate: f64) -> (Vec<f64>, Vec<f64>) {
    let w0 = 2.0 * f64::PI() * frequency / sample_rate;
    let (sin_w0, cos_w0) = w0.sin_cos();
    let alpha = sin_w0 / (2.0 * q_factor);
    let norm = 1.0 + alpha;
    // Feed-forward is antisymmetric (alpha, 0, -alpha), which zeroes DC and Nyquist.
    let numerator = vec![alpha / norm, 0.0 / norm, -alpha / norm];
    let denominator = vec![1.0, (-2.0 * cos_w0) / norm, (1.0 - alpha) / norm];
    (numerator, denominator)
}
/// Designs an RBJ "Audio EQ Cookbook" band-stop (notch) biquad.
///
/// The filter has a zero exactly at `frequency` (Hz) and unity gain at DC and Nyquist.
/// `q_factor` sets the notch width. Returns `(b, a)` normalised so that `a[0] == 1.0`.
fn design_bandstop_filter(frequency: f64, q_factor: f64, sample_rate: f64) -> (Vec<f64>, Vec<f64>) {
    let w0 = 2.0 * std::f64::consts::PI * frequency / sample_rate;
    let (sin_w0, cos_w0) = w0.sin_cos();
    let alpha = sin_w0 / (2.0 * q_factor);
    let norm = 1.0 + alpha;
    // Numerator and denominator share the same middle tap.
    let middle = (-2.0 * cos_w0) / norm;
    let edge = 1.0 / norm;
    (
        vec![edge, middle, edge],
        vec![1.0, middle, (1.0 - alpha) / norm],
    )
}
impl<T> AudioSamples<'_, T>
where
    T: StandardSample,
{
    /// Multiplies every sample by the linear factor `gain`, in place.
    ///
    /// Each sample is widened to `f64`, scaled, and cast back to `T`. How out-of-range
    /// results are handled is determined by `T::cast_from`, which is not shown here.
    fn apply_linear_gain(&mut self, gain: f64) {
        self.apply(|sample| {
            let widened: f64 = T::cast_into(sample);
            T::cast_from(widened * gain)
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operations::traits::AudioParametricEq;
    use crate::sample_rate;
    use crate::utils::audio_math::amplitude_to_db as linear_to_db;
    use non_empty_slice::{NonEmptyVec, non_empty_vec};
    use std::f64::consts::PI;

    /// Builds 100 ms of a 44.1 kHz signal: the sum of unit-amplitude sines at each of
    /// `tones_hz`, computed in `f64` and stored as `f32`.
    fn tone_mix(tones_hz: [f64; 3]) -> NonEmptyVec<f32> {
        let sample_rate = 44100.0;
        let duration = 0.1;
        let samples_count = (sample_rate * duration) as usize;
        let [f1, f2, f3] = tones_hz;
        let samples: Vec<f32> = (0..samples_count)
            .map(|i| {
                let t = i as f64 / sample_rate;
                let value = (2.0 * PI * f1 * t).sin()
                    + (2.0 * PI * f2 * t).sin()
                    + (2.0 * PI * f3 * t).sin();
                value as f32
            })
            .collect();
        NonEmptyVec::new(samples).unwrap()
    }

    #[test]
    fn test_peak_filter() {
        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(tone_mix([440.0, 880.0, 1760.0]), sample_rate!(44100));
        assert!(audio.apply_peak_filter(880.0, 6.0, 2.0).is_ok());
    }

    #[test]
    fn test_low_shelf_filter() {
        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(tone_mix([100.0, 1000.0, 5000.0]), sample_rate!(44100));
        assert!(audio.apply_low_shelf(500.0, -3.0, 0.707).is_ok());
    }

    #[test]
    fn test_high_shelf_filter() {
        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(tone_mix([100.0, 1000.0, 5000.0]), sample_rate!(44100));
        assert!(audio.apply_high_shelf(2000.0, 4.0, 0.707).is_ok());
    }

    #[test]
    fn test_three_band_eq() {
        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(tone_mix([100.0, 1000.0, 5000.0]), sample_rate!(44100));
        let outcome = audio.apply_three_band_eq(200.0, -2.0, 1000.0, 3.0, 2.0, 4000.0, 1.0);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_parametric_eq_configuration() {
        let mut eq = ParametricEq::new();
        eq.add_band(EqBand::peak(1000.0, 3.0, 2.0));
        eq.add_band(EqBand::low_shelf(100.0, -2.0, 0.707));
        eq.set_output_gain(1.0);

        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(non_empty_vec![1.0f32, 0.0, -1.0], sample_rate!(44100));
        assert!(audio.apply_parametric_eq(&eq).is_ok());

        assert_eq!(eq.band_count(), 2);
        assert_eq!(eq.output_gain_db, 1.0);
        assert!(!eq.is_bypassed());
    }

    #[test]
    fn test_eq_band_validation() {
        let rate = 44100.0;
        // A sensible band validates.
        assert!(EqBand::peak(1000.0, 3.0, 2.0).validate(rate).is_ok());
        // Centre frequency at the sample rate is rejected.
        assert!(EqBand::peak(rate, 3.0, 2.0).validate(rate).is_err());
        // A zero Q is rejected.
        assert!(EqBand::peak(1000.0, 3.0, 0.0).validate(rate).is_err());
        // An extreme gain is rejected.
        assert!(EqBand::peak(1000.0, 50.0, 2.0).validate(rate).is_err());
    }

    #[test]
    fn test_eq_band_enable_disable() {
        let mut band = EqBand::peak(1000.0, 3.0, 2.0);
        assert!(band.is_enabled());
        for enabled in [false, true] {
            band.set_enabled(enabled);
            assert_eq!(band.is_enabled(), enabled);
        }
    }

    #[test]
    fn test_parametric_eq_bypass() {
        let mut audio: AudioSamples<'_, f32> =
            AudioSamples::from_mono_vec(non_empty_vec![1.0f32, 0.5, -0.5], sample_rate!(44100));
        let snapshot = audio.data.clone();

        let mut eq = ParametricEq::new();
        eq.add_band(EqBand::peak(1000.0, 10.0, 2.0));
        eq.set_bypassed(true);
        assert!(audio.apply_parametric_eq(&eq).is_ok());

        let (AudioData::Mono(after), AudioData::Mono(before)) = (&audio.data, &snapshot) else {
            panic!("Expected mono audio");
        };
        assert_eq!(after, before);
    }

    #[test]
    fn test_db_linear_conversion() {
        for (db, linear) in [(0.0_f64, 1.0_f64), (20.0, 10.0), (-20.0, 0.1)] {
            assert!((db_to_linear(db) - linear).abs() < 1e-10);
            assert!((linear_to_db(linear) - db).abs() < 1e-10);
        }
    }

    #[test]
    fn test_five_band_eq() {
        let eq = ParametricEq::five_band();
        assert_eq!(eq.band_count(), 5);
        let expected = [(0, 100.0), (1, 300.0), (2, 1000.0), (3, 3000.0), (4, 8000.0)];
        for (index, frequency) in expected {
            assert_eq!(eq.get_band(index).unwrap().frequency, frequency);
        }
    }
}