use log;
/// Tunable parameters for the high-pass filter, automatic gain control and
/// soft limiter.
///
/// Level fields (`target_rms`, `noise_gate_rms`, `limiter_knee`) are linear
/// amplitudes on the scale of the `i16` samples being processed; gain fields
/// are linear multipliers, not decibels.
#[derive(Debug, Clone)]
pub struct AgcConfig {
    /// Cutoff frequency of the high-pass pre-filter, in hertz.
    pub highpass_hz: f32,
    /// RMS level the gain control tries to bring each chunk to.
    pub target_rms: f32,
    /// Chunks whose RMS is at or below this level do not change the gain target.
    pub noise_gate_rms: f32,
    /// Upper bound on the gain target (linear).
    pub max_gain: f32,
    /// Lower bound on the gain target (linear).
    pub min_gain: f32,
    /// Smoothing time constant, in milliseconds, used while the gain is falling.
    pub attack_ms: f32,
    /// Smoothing time constant, in milliseconds, used while the gain is rising.
    pub release_ms: f32,
    /// Absolute sample level above which the soft limiter starts compressing.
    pub limiter_knee: f32,
}

impl Default for AgcConfig {
    /// Returns the stock tuning.
    fn default() -> Self {
        AgcConfig {
            // Pre-filter.
            highpass_hz: 90.0,

            // Levels, relative to a 16-bit full scale of 32 767.
            target_rms: 3277.0,      // about 10 % of full scale (≈ -20 dBFS)
            noise_gate_rms: 165.0,   // about 0.5 % of full scale (≈ -46 dBFS)
            limiter_knee: 22937.0,   // about 70 % of full scale

            // Gain bounds (linear).
            min_gain: 0.125, // ≈ -18 dB
            max_gain: 31.6,  // ≈ +30 dB

            // Gain smoothing: fast when turning down, slow when turning up.
            attack_ms: 10.0,
            release_ms: 400.0,
        }
    }
}
/// Second-order (biquad) high-pass filter with its own delay-line state.
///
/// Coefficients are stored already divided by `a0`, so `process` needs no
/// per-sample normalisation.
struct BiquadHighPass {
    /// Feed-forward coefficient applied to the current input.
    b0: f32,
    /// Feed-forward coefficient applied to the previous input.
    b1: f32,
    /// Feed-forward coefficient applied to the input two samples ago.
    b2: f32,
    /// Feedback coefficient applied to the previous output.
    a1: f32,
    /// Feedback coefficient applied to the output two samples ago.
    a2: f32,
    /// Previous input sample.
    x1: f32,
    /// Input sample from two calls ago.
    x2: f32,
    /// Previous output sample.
    y1: f32,
    /// Output sample from two calls ago.
    y2: f32,
}

impl BiquadHighPass {
    /// Builds a high-pass section with cutoff `cutoff_hz` for audio sampled at
    /// `sample_rate` Hz, using Q = 1/sqrt(2). The delay line starts zeroed.
    fn new(cutoff_hz: f32, sample_rate: u32) -> Self {
        // Normalised angular cutoff frequency in radians per sample.
        let omega = 2.0 * std::f32::consts::PI * cutoff_hz / sample_rate as f32;
        let cos_omega = omega.cos();
        let sin_omega = omega.sin();

        let quality = std::f32::consts::FRAC_1_SQRT_2;
        let alpha = sin_omega / (2.0 * quality);

        // Every coefficient is divided by a0 = 1 + alpha.
        let norm = 1.0 + alpha;
        let one_plus_cos = 1.0 + cos_omega;
        // b0 and b2 share the same value.
        let outer_tap = (one_plus_cos / 2.0) / norm;

        BiquadHighPass {
            b0: outer_tap,
            b1: -one_plus_cos / norm,
            b2: outer_tap,
            a1: (-2.0 * cos_omega) / norm,
            a2: (1.0 - alpha) / norm,
            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Filters one sample and advances the delay line.
    #[inline]
    fn process(&mut self, x: f32) -> f32 {
        let feed_forward = self.b0 * x + self.b1 * self.x1 + self.b2 * self.x2;
        let out = feed_forward - self.a1 * self.y1 - self.a2 * self.y2;

        // Shift the histories: the old "previous" value becomes "two ago".
        self.x2 = std::mem::replace(&mut self.x1, x);
        self.y2 = std::mem::replace(&mut self.y1, out);
        out
    }

    /// Zeroes the delay line; the coefficients are left as they are.
    fn reset(&mut self) {
        for slot in [&mut self.x1, &mut self.x2, &mut self.y1, &mut self.y2] {
            *slot = 0.0;
        }
    }
}
/// Clean-up chain for 16-bit PCM audio: a high-pass pre-filter plus a
/// chunk-wise automatic gain control (AGC) followed by a soft limiter.
///
/// `pre_filter` and `post_filter` are independent stages; each keeps its own
/// state between calls so consecutive chunks are processed as one stream.
pub struct AudioEnhancer {
    /// Tuning parameters supplied at construction.
    config: AgcConfig,
    /// High-pass section used by `pre_filter`.
    highpass: BiquadHighPass,
    /// Current smoothed linear gain applied by `post_filter`; starts at 1.0.
    gain: f32,
    /// Per-sample smoothing coefficient used while the gain is falling.
    attack_coef: f32,
    /// Per-sample smoothing coefficient used while the gain is rising or held.
    release_coef: f32,
    /// When `false`, both stages return their input unchanged.
    enabled: bool,
}

impl AudioEnhancer {
    /// Largest positive `i16` sample value; the soft limiter's output ceiling.
    const FULL_SCALE: f32 = 32_767.0;

    /// Creates an enhancer for audio sampled at `sample_rate` Hz with the
    /// default `AgcConfig`.
    pub fn new(sample_rate: u32) -> Self {
        Self::with_config(sample_rate, AgcConfig::default())
    }

    /// Creates an enhancer for audio sampled at `sample_rate` Hz with an
    /// explicit configuration, and logs the main settings at `info` level.
    ///
    /// The enhancer starts enabled with unity gain.
    pub fn with_config(sample_rate: u32, config: AgcConfig) -> Self {
        let sr = sample_rate as f32;
        // One-pole smoothing coefficient for a time constant given in ms:
        // exp(-1 / (seconds * samples_per_second)).
        let smoothing_coef = |time_ms: f32| (-1.0 / (time_ms / 1_000.0 * sr)).exp();
        let attack_coef = smoothing_coef(config.attack_ms);
        let release_coef = smoothing_coef(config.release_ms);
        log::info!(
            "AudioEnhancer: highpass={}Hz target_rms={:.0} max_gain={:+.1}dB",
            config.highpass_hz,
            config.target_rms,
            Self::linear_to_db(config.max_gain),
        );
        Self {
            highpass: BiquadHighPass::new(config.highpass_hz, sample_rate),
            gain: 1.0,
            attack_coef,
            release_coef,
            config,
            enabled: true,
        }
    }

    /// Turns processing on or off. Filter state and gain are not touched.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns whether processing is currently enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Runs `audio` through the high-pass filter and returns the result.
    ///
    /// The output has the same length as the input. When the enhancer is
    /// disabled or `audio` is empty, a copy of the input is returned and the
    /// filter state is left untouched.
    pub fn pre_filter(&mut self, audio: &[i16]) -> Vec<i16> {
        if !self.enabled || audio.is_empty() {
            return audio.to_vec();
        }
        audio
            .iter()
            .map(|&s| clamp_i16(self.highpass.process(f32::from(s))))
            .collect()
    }

    /// Applies automatic gain control and soft limiting to `audio`.
    ///
    /// One gain target is derived from the RMS of the whole chunk; the applied
    /// gain then moves toward that target sample by sample and persists across
    /// calls. The output has the same length as the input. When the enhancer
    /// is disabled or `audio` is empty, a copy of the input is returned and
    /// the gain is left untouched.
    pub fn post_filter(&mut self, audio: &[i16]) -> Vec<i16> {
        if !self.enabled || audio.is_empty() {
            return audio.to_vec();
        }
        let desired = self.desired_gain(chunk_rms(audio));
        let knee = self.config.limiter_knee;
        let headroom = Self::FULL_SCALE - knee;
        audio
            .iter()
            .map(|&s| {
                let gain = self.smooth_gain_toward(desired);
                clamp_i16(Self::soft_limit(f32::from(s) * gain, knee, headroom))
            })
            .collect()
    }

    /// Clears the high-pass filter state.
    ///
    /// The smoothed AGC gain is kept, so a stream restarted after `reset`
    /// continues with the gain learned so far.
    pub fn reset(&mut self) {
        self.highpass.reset();
    }

    /// Returns the current smoothed gain in decibels.
    pub fn gain_db(&self) -> f32 {
        Self::linear_to_db(self.gain)
    }

    /// Converts a linear amplitude ratio to decibels.
    fn linear_to_db(ratio: f32) -> f32 {
        20.0 * ratio.log10()
    }

    /// Gain target for a chunk whose RMS level is `rms`.
    ///
    /// Above the noise gate the target is `target_rms / rms`, bounded to
    /// `min_gain..=max_gain`. At or below the gate the current gain is
    /// returned so the gain is held.
    fn desired_gain(&self, rms: f32) -> f32 {
        if rms > self.config.noise_gate_rms {
            // `max`/`min` instead of `f32::clamp`: `clamp` panics when
            // `min_gain > max_gain` or a bound is NaN, and both come from a
            // caller-supplied config. Results are identical for valid bounds;
            // with inverted bounds `max_gain` wins.
            (self.config.target_rms / rms)
                .max(self.config.min_gain)
                .min(self.config.max_gain)
        } else {
            self.gain
        }
    }

    /// Moves the smoothed gain one sample toward `desired` and returns it.
    ///
    /// The attack coefficient is used when the gain has to fall, the release
    /// coefficient otherwise.
    fn smooth_gain_toward(&mut self, desired: f32) -> f32 {
        let coef = if desired < self.gain {
            self.attack_coef
        } else {
            self.release_coef
        };
        self.gain = coef * self.gain + (1.0 - coef) * desired;
        self.gain
    }

    /// Soft limiter: passes `x` through unchanged up to `knee`, then bends the
    /// excess with `tanh` so the magnitude approaches `knee + headroom`.
    fn soft_limit(x: f32, knee: f32, headroom: f32) -> f32 {
        let magnitude = x.abs();
        if magnitude <= knee {
            x
        } else {
            x.signum() * (knee + headroom * ((magnitude - knee) / headroom).tanh())
        }
    }
}
/// Root-mean-square level of a chunk of 16-bit samples.
///
/// Squares are accumulated in `f64` because a single squared `i16` can reach
/// 2^30, which already exceeds what `f32` can sum exactly over many samples.
///
/// Returns `0.0` for an empty chunk instead of the NaN that dividing by a
/// zero length would produce.
fn chunk_rms(audio: &[i16]) -> f32 {
    if audio.is_empty() {
        return 0.0;
    }
    let sum_sq: f64 = audio
        .iter()
        .map(|&s| f64::from(s) * f64::from(s))
        .sum();
    ((sum_sq / audio.len() as f64) as f32).sqrt()
}
/// Converts a floating-point sample to `i16`, saturating at the type's limits.
///
/// The fractional part is dropped (truncation toward zero) by the final cast,
/// and a NaN input becomes 0.
#[inline]
fn clamp_i16(s: f32) -> i16 {
    let lowest = f32::from(i16::MIN);
    let highest = f32::from(i16::MAX);
    let bounded = s.clamp(lowest, highest);
    bounded as i16
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate used by every test, in hertz.
    const RATE: u32 = 16_000;

    /// Generates `n` samples of a sine wave at `freq` Hz with amplitude `peak`.
    fn sine(freq: f32, peak: f32, sample_rate: u32, n: usize) -> Vec<i16> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f32 / sample_rate as f32;
            let phase = 2.0 * std::f32::consts::PI * freq * t;
            samples.push((peak * phase.sin()) as i16);
        }
        samples
    }

    /// Feeds four seconds of a 300 Hz tone with amplitude `peak` through the
    /// AGC and returns the RMS of the second half of the output.
    fn settled_rms_after_agc(peak: f32) -> f32 {
        let mut enh = AudioEnhancer::new(RATE);
        let input = sine(300.0, peak, RATE, 4 * RATE as usize);
        let out = enh.post_filter(&input);
        let settled = &out[out.len() / 2..];
        chunk_rms(settled)
    }

    /// Returns an enhancer that has already processed four seconds of a quiet
    /// 300 Hz tone.
    fn enhancer_after_quiet_speech() -> AudioEnhancer {
        let mut enh = AudioEnhancer::new(RATE);
        let speech = sine(300.0, 1_000.0, RATE, 4 * RATE as usize);
        let _ = enh.post_filter(&speech);
        enh
    }

    /// Both stages return exactly as many samples as they were given.
    #[test]
    fn output_length_matches_input() {
        let mut enh = AudioEnhancer::new(RATE);
        let input = sine(300.0, 8_000.0, RATE, 1_234);
        let expected_len = input.len();

        let filtered = enh.pre_filter(&input);
        assert_eq!(filtered.len(), expected_len);

        let levelled = enh.post_filter(&input);
        assert_eq!(levelled.len(), expected_len);
    }

    /// A constant offset has decayed to almost nothing after half a second.
    #[test]
    fn highpass_removes_dc_offset() {
        let mut enh = AudioEnhancer::new(RATE);
        let input = vec![5_000i16; 16_000];
        let out = enh.pre_filter(&input);

        let tail = &out[8_000..];
        let tail_sum: f64 = tail.iter().map(|&s| f64::from(s)).sum();
        let tail_mean: f64 = tail_sum / 8_000.0;
        assert!(tail_mean.abs() < 50.0, "residual DC: {tail_mean}");
    }

    /// A 300 Hz tone passes the pre-filter with less than about 1 dB of loss.
    #[test]
    fn highpass_passes_speech_band() {
        let mut enh = AudioEnhancer::new(RATE);
        let input = sine(300.0, 8_000.0, RATE, 16_000);
        let out = enh.pre_filter(&input);

        let in_rms = chunk_rms(&input);
        let out_rms = chunk_rms(&out[8_000..]);
        assert!(out_rms > in_rms * 0.89, "in={in_rms} out={out_rms}");
    }

    /// Quiet input ends up near the default target RMS.
    #[test]
    fn agc_boosts_quiet_audio_toward_target() {
        let out_rms = settled_rms_after_agc(1_000.0);
        assert!(
            out_rms > 2_500.0 && out_rms < 4_500.0,
            "rms after AGC: {out_rms}"
        );
    }

    /// Loud input ends up near the default target RMS.
    #[test]
    fn agc_reduces_loud_audio_toward_target() {
        let out_rms = settled_rms_after_agc(28_000.0);
        assert!(
            out_rms > 2_500.0 && out_rms < 4_500.0,
            "rms after AGC: {out_rms}"
        );
    }

    /// A chunk below the noise gate leaves the learned gain where it was.
    #[test]
    fn agc_holds_gain_during_silence() {
        let mut enh = enhancer_after_quiet_speech();
        let learned = enh.gain_db();
        assert!(learned > 6.0, "expected boost, got {learned} dB");

        let silence = vec![10i16; 16_000];
        let _ = enh.post_filter(&silence);
        let drift = (enh.gain_db() - learned).abs();
        assert!(drift < 0.5);
    }

    /// Loud input is never pinned to the negative rail, and only a few samples
    /// share the peak magnitude.
    #[test]
    fn limiter_prevents_hard_clipping() {
        let mut enh = AudioEnhancer::new(RATE);
        let input = sine(300.0, 30_000.0, RATE, 16_000);
        let out = enh.post_filter(&input);

        assert!(!out.contains(&i16::MIN));

        let magnitudes: Vec<u16> = out.iter().map(|&s| s.unsigned_abs()).collect();
        let max = magnitudes.iter().copied().max().unwrap();
        let at_max = magnitudes.iter().filter(|&&m| m == max).count();
        assert!(at_max < 20, "{at_max} samples pinned at peak {max}");
    }

    /// A disabled enhancer returns its input untouched from both stages.
    #[test]
    fn disabled_passes_through() {
        let mut enh = AudioEnhancer::new(RATE);
        enh.set_enabled(false);
        let input = sine(300.0, 1_000.0, RATE, 480);

        let filtered = enh.pre_filter(&input);
        assert_eq!(filtered, input);

        let levelled = enh.post_filter(&input);
        assert_eq!(levelled, input);
    }

    /// `reset` clears filter state only; the learned gain survives.
    #[test]
    fn reset_keeps_learned_gain() {
        let mut enh = enhancer_after_quiet_speech();
        let learned = enh.gain_db();

        enh.reset();

        let change = (enh.gain_db() - learned).abs();
        assert!(change < 0.01);
    }
}