/// Magnitude of a full-scale 16-bit PCM sample; an RMS of this value maps to 0 dBFS.
const FULL_SCALE: f64 = 32768.0;

/// Frames whose RMS is below this fraction of [`FULL_SCALE`] are treated as
/// silence and returned untouched by [`Normalizer::process`].
const MIN_RMS_THRESHOLD: f64 = 0.001;

/// Frame-based RMS loudness normalizer for 16-bit PCM samples.
///
/// Each call to [`Normalizer::process`] measures the RMS of the frame, moves an
/// internal gain a step towards the gain that would hit the target level, and
/// applies that gain to every sample of the frame.
#[derive(Debug, Clone)]
pub struct Normalizer {
    /// RMS level, in sample units, that processed frames are steered towards.
    target_rms: f64,
    /// Linear gain applied to the most recently processed frame.
    current_gain: f64,
    /// Fraction of the gap between current and target gain closed per frame.
    smoothing: f64,
    /// Whether samples above the limiter knee are soft-limited before output.
    peak_limit: bool,
}

impl Normalizer {
    /// Smoothing coefficient used by [`Normalizer::new`].
    const DEFAULT_SMOOTHING: f64 = 0.1;
    /// Smallest smoothing coefficient accepted by [`Normalizer::with_settings`].
    const MIN_SMOOTHING: f64 = 0.01;
    /// Fraction of full scale above which the soft limiter starts compressing.
    const LIMITER_KNEE: f64 = 0.9;
    /// Width of the compression region between the knee and full scale.
    const LIMITER_HEADROOM: f64 = 0.1;
    /// Value reported for silence by [`Normalizer::rms_to_dbfs`], and its lower bound.
    const DBFS_FLOOR: f64 = -96.0;

    /// Creates a normalizer targeting `target_dbfs`, with unity initial gain,
    /// the default smoothing coefficient and peak limiting enabled.
    ///
    /// # Panics
    ///
    /// Panics if `target_dbfs` is greater than zero (or NaN).
    pub fn new(target_dbfs: f32) -> Self {
        assert!(
            target_dbfs <= 0.0,
            "Target dBFS must be <= 0, got {target_dbfs}"
        );
        let target_rms = FULL_SCALE * 10_f64.powf(f64::from(target_dbfs) / 20.0);
        Self {
            target_rms,
            current_gain: 1.0,
            smoothing: Self::DEFAULT_SMOOTHING,
            peak_limit: true,
        }
    }

    /// Creates a normalizer with an explicit smoothing coefficient and limiter switch.
    ///
    /// `smoothing` is clamped to `0.01..=1.0`; a value of `1.0` makes the gain
    /// jump to the target within a single frame. A NaN `smoothing` is ignored
    /// and the default coefficient is kept.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Normalizer::new`].
    pub fn with_settings(target_dbfs: f32, smoothing: f64, peak_limit: bool) -> Self {
        let mut normalizer = Self::new(target_dbfs);
        // `f64::clamp` hands NaN straight back, and a NaN coefficient would turn
        // `current_gain` into NaN on the first processed frame.
        if !smoothing.is_nan() {
            normalizer.smoothing = smoothing.clamp(Self::MIN_SMOOTHING, 1.0);
        }
        normalizer.peak_limit = peak_limit;
        normalizer
    }

    /// Returns the root-mean-square of `samples` in sample units, or `0.0` for
    /// an empty slice.
    fn calculate_rms(samples: &[i16]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_squares: f64 = samples.iter().map(|&s| f64::from(s).powi(2)).sum();
        (sum_squares / samples.len() as f64).sqrt()
    }

    /// Converts an RMS level in sample units to dBFS relative to [`FULL_SCALE`].
    ///
    /// Non-positive input yields the `-96.0` floor, and no result is ever below
    /// that floor, so the mapping is monotonic in `rms`.
    // NOTE(review): the allow is kept from the original; within this file the
    // function is only called from tests — confirm whether it is still needed.
    #[allow(dead_code)]
    pub fn rms_to_dbfs(rms: f64) -> f64 {
        if rms <= 0.0 {
            return Self::DBFS_FLOOR;
        }
        // `f64::max` also maps a NaN result (from NaN input) onto the floor.
        (20.0 * (rms / FULL_SCALE).log10()).max(Self::DBFS_FLOOR)
    }

    /// Normalizes one frame of samples and returns the processed copy.
    ///
    /// Frames that are empty or whose RMS is below the silence threshold are
    /// returned unchanged and leave the internal gain untouched. Otherwise the
    /// gain is moved towards `target_rms / input_rms` by the smoothing
    /// coefficient and applied to every sample; results are rounded and
    /// saturated to the `i16` range, after soft limiting when it is enabled.
    pub fn process(&mut self, samples: &[i16]) -> Vec<i16> {
        // An empty frame has an RMS of zero, so it takes the silence path too.
        let input_rms = Self::calculate_rms(samples);
        if input_rms < MIN_RMS_THRESHOLD * FULL_SCALE {
            return samples.to_vec();
        }

        let target_gain = self.target_rms / input_rms;
        self.current_gain += self.smoothing * (target_gain - self.current_gain);

        let gain = self.current_gain;
        let peak_limit = self.peak_limit;
        samples
            .iter()
            .map(|&s| {
                let amplified = f64::from(s) * gain;
                let shaped = if peak_limit {
                    Self::soft_limit(amplified)
                } else {
                    amplified
                };
                Self::to_sample(shaped)
            })
            .collect()
    }

    /// Compresses the part of `amplified` above the knee with a `tanh` curve
    /// so the magnitude approaches full scale instead of overshooting it.
    fn soft_limit(amplified: f64) -> f64 {
        let magnitude = amplified.abs() / FULL_SCALE;
        if magnitude > Self::LIMITER_KNEE {
            let overshoot = (magnitude - Self::LIMITER_KNEE) / Self::LIMITER_HEADROOM;
            let compressed = Self::LIMITER_KNEE + Self::LIMITER_HEADROOM * overshoot.tanh();
            amplified.signum() * compressed * FULL_SCALE
        } else {
            amplified
        }
    }

    /// Rounds `value` to the nearest integer and saturates it to the `i16` range.
    fn to_sample(value: f64) -> i16 {
        // The clamp guarantees the cast is in range; a NaN would cast to 0.
        value
            .round()
            .clamp(f64::from(i16::MIN), f64::from(i16::MAX)) as i16
    }

    /// Restores the internal gain to unity.
    pub fn reset(&mut self) {
        self.current_gain = 1.0;
    }

    /// Returns the linear gain applied to the most recently processed frame.
    pub fn current_gain(&self) -> f64 {
        self.current_gain
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid negative target yields a positive RMS target below full scale.
    #[test]
    fn test_normalizer_creation() {
        let norm = Normalizer::new(-20.0);
        assert!(norm.target_rms > 0.0);
        assert!(norm.target_rms < FULL_SCALE);
    }

    /// Targets above 0 dBFS are rejected at construction time.
    #[test]
    #[should_panic(expected = "Target dBFS must be <= 0")]
    fn test_normalizer_invalid_target() {
        Normalizer::new(6.0);
    }

    /// The RMS of a constant signal is its magnitude; the RMS of zeros is zero.
    #[test]
    fn test_rms_calculation() {
        let dc: Vec<i16> = vec![1000; 100];
        let rms = Normalizer::calculate_rms(&dc);
        assert!((rms - 1000.0).abs() < 1.0);

        let silence: Vec<i16> = vec![0; 100];
        let rms = Normalizer::calculate_rms(&silence);
        assert_eq!(rms, 0.0);
    }

    /// A frame below the target level comes out louder than it went in.
    #[test]
    fn test_normalizer_amplifies_quiet() {
        let mut norm = Normalizer::new(-20.0);
        let quiet: Vec<i16> = vec![100; 480];
        let output = norm.process(&quiet);
        let input_rms = Normalizer::calculate_rms(&quiet);
        let output_rms = Normalizer::calculate_rms(&output);
        assert!(
            output_rms > input_rms,
            "Output RMS {output_rms} should be > input RMS {input_rms}"
        );
    }

    /// A frame above the target level comes out quieter than it went in.
    #[test]
    fn test_normalizer_attenuates_loud() {
        let mut norm = Normalizer::new(-20.0);
        let loud: Vec<i16> = vec![16000; 480];
        let output = norm.process(&loud);
        let input_rms = Normalizer::calculate_rms(&loud);
        let output_rms = Normalizer::calculate_rms(&output);
        assert!(
            output_rms < input_rms,
            "Output RMS {output_rms} should be < input RMS {input_rms}"
        );
    }

    /// Frames below the silence threshold are returned unchanged and do not
    /// move the gain.
    #[test]
    fn test_normalizer_skips_silence() {
        let mut norm = Normalizer::new(-20.0);
        let silence: Vec<i16> = vec![1; 480];
        let output = norm.process(&silence);
        assert_eq!(output, silence);
        assert_eq!(norm.current_gain(), 1.0);
    }

    /// Samples driven past the limiter knee are compressed below full scale,
    /// whereas the same frame hard-clips when limiting is disabled.
    #[test]
    fn test_normalizer_peak_limiting() {
        // A 0 dBFS target with smoothing 1.0 applies a gain of 3.2768 on the
        // first frame, pushing the 10000 DC input to full scale — well past
        // the 0.9 knee.
        let input: Vec<i16> = vec![10000; 480];
        let knee = (0.9 * FULL_SCALE) as i16;

        let mut limited = Normalizer::with_settings(0.0, 1.0, true);
        let output = limited.process(&input);
        assert_eq!(output.len(), input.len());
        assert!(
            output.iter().all(|&s| s > knee && s < i16::MAX),
            "limited samples must lie strictly between the knee and full scale"
        );

        let mut unlimited = Normalizer::with_settings(0.0, 1.0, false);
        let clipped = unlimited.process(&input);
        assert!(
            clipped.iter().all(|&s| s == i16::MAX),
            "without limiting the same frame must saturate at i16::MAX"
        );
    }

    /// `reset` restores unity gain after a frame has moved it.
    #[test]
    fn test_normalizer_reset() {
        let mut norm = Normalizer::new(-20.0);
        let samples: Vec<i16> = vec![1000; 480];
        norm.process(&samples);
        assert!(norm.current_gain() != 1.0);
        norm.reset();
        assert_eq!(norm.current_gain(), 1.0);
    }

    /// Full scale maps to 0 dBFS and half of full scale to roughly -6.02 dBFS.
    #[test]
    fn test_dbfs_conversion() {
        let dbfs = Normalizer::rms_to_dbfs(FULL_SCALE);
        assert!((dbfs - 0.0).abs() < 0.01);
        let dbfs = Normalizer::rms_to_dbfs(FULL_SCALE / 2.0);
        assert!((dbfs - (-6.02)).abs() < 0.1);
    }
}