use core::f32::consts::PI;
/// Selects which processing [`apply`] performs on a frame.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EnhancementMode {
    /// No processing: `apply` returns immediately without touching the PCM or the state.
    #[default]
    None,
    /// Run the biquad / compressor / fade / output-gain chain described by the config.
    Classical(ClassicalConfig),
}
/// Parameters for the [`EnhancementMode::Classical`] processing chain.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ClassicalConfig {
    /// Up to two biquad stages, run in slot order; a `None` slot is skipped.
    pub biquads: [Option<Biquad>; 2],
    /// Optional compressor, applied after the biquads.
    pub compressor: Option<Compressor>,
    /// Length in samples of the raised-cosine fade-in that `apply` starts when it
    /// is called with `prev_was_use == false`.
    pub boundary_fade_samples: usize,
    /// Gain in dB applied after every other stage.
    pub output_gain_db: f32,
}

impl ClassicalConfig {
    /// Builds the default chain designed for `sample_rate_hz`:
    /// a 250 Hz high-pass (Q 0.707), a +3 dB peak at 2.5 kHz (Q 1.0), no compressor,
    /// unity output gain, and a 5 ms boundary fade.
    ///
    /// Biquad coefficients depend on the sample rate, so a config built for one
    /// rate should be used with `apply` at that same rate.
    pub fn for_sample_rate(sample_rate_hz: f32) -> Self {
        Self {
            biquads: [
                Some(Biquad::high_pass(sample_rate_hz, 250.0, 0.707)),
                Some(Biquad::peaking(sample_rate_hz, 2_500.0, 1.0, 3.0)),
            ],
            compressor: None,
            // 5 ms of audio; exactly 40 samples at 8 kHz. Non-finite or negative
            // rates saturate to 0 (no fade) through the `as` cast.
            boundary_fade_samples: (sample_rate_hz * 0.005).round() as usize,
            output_gain_db: 0.0,
        }
    }
}

impl Default for ClassicalConfig {
    /// Same as [`ClassicalConfig::for_sample_rate`] at 8 kHz.
    fn default() -> Self {
        Self::for_sample_rate(8_000.0)
    }
}
/// Second-order IIR section with coefficients already divided by `a0`.
///
/// The filter state in this module evaluates
/// `y[n] = b0·x[n] + b1·x[n-1] + b2·x[n-2] - a1·y[n-1] - a2·y[n-2]`.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Biquad {
    /// Feed-forward weight of `x[n]`.
    pub b0: f32,
    /// Feed-forward weight of `x[n-1]`.
    pub b1: f32,
    /// Feed-forward weight of `x[n-2]`.
    pub b2: f32,
    /// Feedback weight of `y[n-1]` (subtracted).
    pub a1: f32,
    /// Feedback weight of `y[n-2]` (subtracted).
    pub a2: f32,
}

impl Biquad {
    /// High-pass section (RBJ Audio EQ Cookbook) with corner `fc_hz` and quality `q`,
    /// designed for sample rate `fs_hz`.
    pub fn high_pass(fs_hz: f32, fc_hz: f32, q: f32) -> Self {
        let (cos_w0, sin_w0) = Self::omega(fs_hz, fc_hz);
        let alpha = sin_w0 / (2.0 * q);
        let one_plus_cos = 1.0 + cos_w0;
        Self::normalized(
            one_plus_cos / 2.0,
            -one_plus_cos,
            one_plus_cos / 2.0,
            1.0 + alpha,
            -2.0 * cos_w0,
            1.0 - alpha,
        )
    }

    /// Peaking-EQ section (RBJ Audio EQ Cookbook): `gain_db` of boost/cut centred
    /// on `fc_hz` with quality `q`, designed for sample rate `fs_hz`.
    pub fn peaking(fs_hz: f32, fc_hz: f32, q: f32, gain_db: f32) -> Self {
        let amp = 10f32.powf(gain_db / 40.0);
        let (cos_w0, sin_w0) = Self::omega(fs_hz, fc_hz);
        let alpha = sin_w0 / (2.0 * q);
        Self::normalized(
            1.0 + alpha * amp,
            -2.0 * cos_w0,
            1.0 - alpha * amp,
            1.0 + alpha / amp,
            -2.0 * cos_w0,
            1.0 - alpha / amp,
        )
    }

    /// Low-shelf section (RBJ Audio EQ Cookbook): `gain_db` applied below `fc_hz`,
    /// designed for sample rate `fs_hz`, with a fixed shelf slope of 1.
    pub fn low_shelf(fs_hz: f32, fc_hz: f32, gain_db: f32) -> Self {
        let amp = 10f32.powf(gain_db / 40.0);
        let (cos_w0, sin_w0) = Self::omega(fs_hz, fc_hz);
        // With shelf slope S = 1 the cookbook's alpha reduces to sin(w0)/2 · √2.
        let alpha = sin_w0 / 2.0 * 2f32.sqrt();
        let k = 2.0 * amp.sqrt() * alpha;
        let (ap1, am1) = (amp + 1.0, amp - 1.0);
        Self::normalized(
            amp * (ap1 - am1 * cos_w0 + k),
            2.0 * amp * (am1 - ap1 * cos_w0),
            amp * (ap1 - am1 * cos_w0 - k),
            ap1 + am1 * cos_w0 + k,
            -2.0 * (am1 + ap1 * cos_w0),
            ap1 + am1 * cos_w0 - k,
        )
    }

    /// Returns `(cos w0, sin w0)` for the normalized angular frequency `w0 = 2π·fc/fs`.
    fn omega(fs_hz: f32, fc_hz: f32) -> (f32, f32) {
        let w0 = 2.0 * PI * fc_hz / fs_hz;
        (w0.cos(), w0.sin())
    }

    /// Builds a section from raw cookbook coefficients by dividing each by `a0`.
    fn normalized(b0: f32, b1: f32, b2: f32, a0: f32, a1: f32, a2: f32) -> Self {
        Self {
            b0: b0 / a0,
            b1: b1 / a0,
            b2: b2 / a0,
            a1: a1 / a0,
            a2: a2 / a0,
        }
    }
}
/// Settings for the envelope-follower compressor stage of the classical chain.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Compressor {
    /// Envelope level, in dB relative to a normalized sample magnitude of 1.0,
    /// above which gain reduction starts.
    pub threshold_db: f32,
    /// Compression ratio; `apply` treats values below 1.0 as 1.0 (no reduction).
    pub ratio: f32,
    /// Envelope time constant, in ms, while the level rises (floored at 0.1 ms by `apply`).
    pub attack_ms: f32,
    /// Envelope time constant, in ms, while the level falls (floored at 0.1 ms by `apply`).
    pub release_ms: f32,
    /// Gain in dB applied on top of the gain reduction.
    pub makeup_db: f32,
}
/// Per-stream state carried from one [`apply`] call to the next.
#[derive(Debug, Clone, Default)]
pub struct EnhancementState {
    /// Filter history for biquad slots 0 and 1, paired by index with `ClassicalConfig::biquads`.
    biquad_state: [BiquadState; 2],
    /// Compressor envelope, on the normalized (full scale = 1.0) sample scale.
    env: f32,
    /// Fade-in samples still to be applied; can carry over into later frames.
    pending_fade: usize,
}
/// Direct Form I history for one biquad: the last two inputs and outputs.
#[derive(Debug, Clone, Copy, Default)]
struct BiquadState {
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl BiquadState {
    /// Filters one sample through `coeffs`, updates the history, and returns the output.
    #[inline]
    fn process(&mut self, coeffs: &Biquad, input: f32) -> f32 {
        let output = coeffs.b0 * input + coeffs.b1 * self.x1 + coeffs.b2 * self.x2
            - coeffs.a1 * self.y1
            - coeffs.a2 * self.y2;
        // Shift the delay lines by one sample.
        (self.x2, self.x1) = (self.x1, input);
        (self.y2, self.y1) = (self.y1, output);
        output
    }
}
pub fn apply(
mode: &EnhancementMode,
state: &mut EnhancementState,
pcm: &mut [i16],
sample_rate_hz: f32,
prev_was_use: bool,
) {
let cfg = match mode {
EnhancementMode::None => return,
EnhancementMode::Classical(c) => c,
};
if !prev_was_use {
state.pending_fade = cfg.boundary_fade_samples;
}
let comp = cfg.compressor.as_ref();
let (alpha_a, alpha_r, threshold_lin, makeup_lin, ratio) = if let Some(c) = comp {
let a = (-1.0 / (c.attack_ms.max(0.1) * 0.001 * sample_rate_hz)).exp();
let r = (-1.0 / (c.release_ms.max(0.1) * 0.001 * sample_rate_hz)).exp();
let t = 10f32.powf(c.threshold_db / 20.0);
let m = 10f32.powf(c.makeup_db / 20.0);
(a, r, t, m, c.ratio.max(1.0))
} else {
(0.0, 0.0, 0.0, 1.0, 1.0)
};
let output_gain_lin = if cfg.output_gain_db != 0.0 {
10f32.powf(cfg.output_gain_db / 20.0)
} else {
1.0
};
for (i, sample) in pcm.iter_mut().enumerate() {
let mut x = (*sample as f32) / 32_768.0;
for (slot_cfg, slot_state) in cfg.biquads.iter().zip(state.biquad_state.iter_mut()) {
if let Some(b) = slot_cfg {
x = slot_state.process(b, x);
}
}
if comp.is_some() {
let abs_x = x.abs();
let alpha = if abs_x > state.env { alpha_a } else { alpha_r };
state.env = alpha * state.env + (1.0 - alpha) * abs_x;
let gain = if state.env > threshold_lin && state.env > 0.0 {
let over_db = 20.0 * (state.env / threshold_lin).log10();
let reduce_db = over_db * (1.0 - 1.0 / ratio);
10f32.powf(-reduce_db / 20.0)
} else {
1.0
};
x *= gain * makeup_lin;
}
if state.pending_fade > 0 && i < cfg.boundary_fade_samples {
let n = cfg.boundary_fade_samples as f32;
let pos = (cfg.boundary_fade_samples - state.pending_fade) as f32;
let w = 0.5 - 0.5 * (PI * pos / n).cos();
x *= w;
state.pending_fade -= 1;
}
x *= output_gain_lin;
let y = (x * 32_768.0).clamp(-32_768.0, 32_767.0);
*sample = y as i16;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn none_is_pass_through() {
        let mut pcm = vec![1234i16; 160];
        let snapshot = pcm.clone();
        let mut state = EnhancementState::default();
        apply(&EnhancementMode::None, &mut state, &mut pcm, 8_000.0, true);
        assert_eq!(pcm, snapshot);
    }

    #[test]
    fn classical_default_runs_without_panic() {
        let mut pcm: Vec<i16> = (0..160).map(|i| (i as i16) * 100).collect();
        let mut state = EnhancementState::default();
        apply(
            &EnhancementMode::Classical(ClassicalConfig::default()),
            &mut state,
            &mut pcm,
            8_000.0,
            true,
        );
        // `unsigned_abs`, not `abs`: a clipped -32_768 must fail this assertion
        // rather than panic with an `i16::abs` overflow.
        for &s in &pcm {
            assert!(s.unsigned_abs() < 32_767, "output clipped: {s}");
        }
    }

    #[test]
    fn high_pass_attenuates_dc() {
        let mut pcm = vec![10_000i16; 8_000];
        let mut state = EnhancementState::default();
        let cfg = ClassicalConfig {
            biquads: [Some(Biquad::high_pass(8_000.0, 250.0, 0.707)), None],
            compressor: None,
            boundary_fade_samples: 0,
            output_gain_db: 0.0,
        };
        apply(
            &EnhancementMode::Classical(cfg),
            &mut state,
            &mut pcm,
            8_000.0,
            true,
        );
        let tail_max = pcm[4_000..]
            .iter()
            .map(|s| s.unsigned_abs())
            .max()
            .unwrap_or(0);
        assert!(tail_max < 100, "DC not attenuated: tail max {tail_max}");
    }

    #[test]
    fn output_gain_db_doubles_amplitude_at_plus_6db() {
        let original: Vec<i16> = (0..160)
            .map(|i| ((i as f32 * 0.1).sin() * 5000.0) as i16)
            .collect();
        let mut pcm = original.clone();
        let cfg = ClassicalConfig {
            biquads: [None, None],
            compressor: None,
            boundary_fade_samples: 0,
            output_gain_db: 6.0,
        };
        let mut state = EnhancementState::default();
        apply(
            &EnhancementMode::Classical(cfg),
            &mut state,
            &mut pcm,
            8_000.0,
            true,
        );
        for (i, (&out, &inp)) in pcm.iter().zip(original.iter()).enumerate() {
            if inp.unsigned_abs() < 100 {
                continue;
            }
            let ratio = out as f32 / inp as f32;
            assert!(
                (ratio - 2.0).abs() < 0.05,
                "sample {i}: expected ratio ~2.0, got {ratio}"
            );
        }
    }

    #[test]
    fn fade_starts_after_mute() {
        let mut pcm = vec![20_000i16; 160];
        let mut state = EnhancementState::default();
        let cfg = ClassicalConfig {
            biquads: [None, None],
            compressor: None,
            boundary_fade_samples: 40,
            output_gain_db: 0.0,
        };
        apply(
            &EnhancementMode::Classical(cfg),
            &mut state,
            &mut pcm,
            8_000.0,
            false,
        );
        assert!(pcm[0].unsigned_abs() < 100);
        assert!(pcm[40] > 19_000);
    }

    #[test]
    fn fade_continues_across_short_frames() {
        let mode = EnhancementMode::Classical(ClassicalConfig {
            biquads: [None, None],
            compressor: None,
            boundary_fade_samples: 40,
            output_gain_db: 0.0,
        });
        let mut state = EnhancementState::default();
        // First frame covers only half of the 40-sample fade.
        let mut first = vec![20_000i16; 20];
        apply(&mode, &mut state, &mut first, 8_000.0, false);
        let mut second = vec![20_000i16; 40];
        apply(&mode, &mut state, &mut second, 8_000.0, true);
        // The ramp resumes mid-way (weight 0.5) and keeps rising...
        assert!(second[0] < 19_000, "fade did not carry over: {}", second[0]);
        assert!(second[0] > first[19], "fade went backwards");
        // ...and the signal is untouched once the 40 samples are done.
        assert_eq!(second[20], 20_000);
    }
}