#![allow(dead_code)]
use ndarray::Array1;
/// Selects how `compute_weights` assigns a weight to each frequency bin.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FrequencyWeighting {
    /// Every frequency bin receives weight 1.0.
    Flat,
    /// Linear A-weighting gain, normalised to 1.0 at 1 kHz.
    AWeighting,
    /// Three-step bass / midrange / treble weighting taken from the config's emphasis fields.
    BassEmphasis,
    /// Piecewise-constant weighting taken from the config's `custom_bands` list.
    Custom,
}
/// Parameters for `compute_weights`.
///
/// Which fields are consulted depends on `weighting`:
/// - `BassEmphasis` reads `bass_emphasis`, `midrange_emphasis` and `treble_emphasis`;
/// - `Custom` reads `custom_bands`;
/// - `Flat` and `AWeighting` read none of the numeric fields.
#[derive(Clone, Debug)]
pub struct WeightedLossConfig {
    /// Weighting scheme used to build the per-bin weights.
    pub weighting: FrequencyWeighting,
    /// Weight for bins below 200 Hz (`BassEmphasis` mode only).
    pub bass_emphasis: f64,
    /// Weight for bins in `[200 Hz, 2 kHz)` (`BassEmphasis` mode only).
    pub midrange_emphasis: f64,
    /// Weight for all remaining bins, i.e. 2 kHz and above (`BassEmphasis` mode only).
    pub treble_emphasis: f64,
    /// `(low_hz, high_hz, weight)` bands, each half-open `[low, high)`.
    /// The first matching band wins; unmatched bins get 1.0 (`Custom` mode only).
    pub custom_bands: Vec<(f64, f64, f64)>,
}
impl Default for WeightedLossConfig {
fn default() -> Self {
Self {
weighting: FrequencyWeighting::Flat,
bass_emphasis: 1.0,
midrange_emphasis: 1.0,
treble_emphasis: 1.0,
custom_bands: Vec::new(),
}
}
}
/// Linear A-weighting gain at `freq` (Hz), normalised so that 1 kHz maps to exactly 1.0.
///
/// This uses the analog A-weighting magnitude response R_A(f) (IEC 61672-1 pole
/// frequencies 20.6 Hz, 107.7 Hz, 737.9 Hz and 12194 Hz) and divides it by R_A(1 kHz).
///
/// Only `freq²` enters the formula, so a negative frequency gives the same gain as its
/// magnitude. `freq == 0.0` yields 0.0, and NaN propagates.
pub fn a_weighting_linear(freq: f64) -> f64 {
    a_weighting_response(freq) / a_weighting_response(1000.0)
}

/// Un-normalised A-weighting magnitude response R_A(f) for a frequency in Hz.
fn a_weighting_response(freq: f64) -> f64 {
    // Squared pole frequencies of the A-weighting filter.
    let c1 = 12194.0_f64.powi(2);
    let c2 = 20.6_f64.powi(2);
    let c3 = 107.7_f64.powi(2);
    let c4 = 737.9_f64.powi(2);

    let f2 = freq * freq;
    let f4 = f2 * f2;

    let num = c1 * f4;
    let den1 = f2 + c2;
    let den2 = ((f2 + c3) * (f2 + c4)).sqrt();
    let den3 = f2 + c1;
    num / (den1 * den2 * den3)
}
/// A-weighting gain at `freq` (Hz) in decibels, relative to 1 kHz (0 dB there).
///
/// At `freq == 0.0` the linear gain is zero, so this returns `f64::NEG_INFINITY`.
pub fn a_weighting_db(freq: f64) -> f64 {
    let gain = a_weighting_linear(freq);
    gain.log10() * 20.0
}
pub fn compute_weights(freq: &Array1<f64>, config: &WeightedLossConfig) -> Array1<f64> {
match config.weighting {
FrequencyWeighting::Flat => Array1::ones(freq.len()),
FrequencyWeighting::AWeighting => freq.map(|&f| a_weighting_linear(f)),
FrequencyWeighting::BassEmphasis => freq.map(|&f| {
if f < 200.0 {
config.bass_emphasis
} else if f < 2000.0 {
config.midrange_emphasis
} else {
config.treble_emphasis
}
}),
FrequencyWeighting::Custom => {
freq.map(|&f| {
for &(low, high, weight) in &config.custom_bands {
if f >= low && f < high {
return weight;
}
}
1.0
})
}
}
}
/// Weighted root-mean-square of `error`: `sqrt(Σ wᵢ·eᵢ² / Σ wᵢ)`.
///
/// Only index-aligned `(error, weight)` pairs contribute. If the arrays differ in
/// length, the extra tail of the longer one is ignored in BOTH the numerator and the
/// normaliser, so the result stays a proper weighted mean.
///
/// Returns 0.0 when the total weight is not positive, which includes empty input.
/// Negative individual weights are not rejected and can make the result NaN.
pub fn weighted_rms_error(error: &Array1<f64>, weights: &Array1<f64>) -> f64 {
    // Accumulate both sums over the same zipped pairs. Previously the weight total ran
    // over every weight, inflating the denominator when `weights` was longer than `error`.
    let (weighted_sq, total_weight) = error.iter().zip(weights.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(acc_sq, acc_w), (&e, &w)| (acc_sq + e * e * w, acc_w + w),
    );
    if total_weight > 0.0 {
        (weighted_sq / total_weight).sqrt()
    } else {
        0.0
    }
}
/// Weighted mean absolute error: `Σ wᵢ·|eᵢ| / Σ wᵢ`.
///
/// Only index-aligned `(error, weight)` pairs contribute. If the arrays differ in
/// length, the extra tail of the longer one is ignored in BOTH the numerator and the
/// normaliser.
///
/// Returns 0.0 when the total weight is not positive, which includes empty input.
pub fn weighted_mae(error: &Array1<f64>, weights: &Array1<f64>) -> f64 {
    // Single pass over the zipped pairs keeps numerator and denominator consistent.
    // Previously the weight total included weights that had no matching error.
    let (weighted_abs, total_weight) = error.iter().zip(weights.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(acc_abs, acc_w), (&e, &w)| (acc_abs + e.abs() * w, acc_w + w),
    );
    if total_weight > 0.0 {
        weighted_abs / total_weight
    } else {
        0.0
    }
}
/// Linear blend of the weighted RMS error and the weighted peak error.
///
/// The peak term is `max |eᵢ|·sqrt(wᵢ)`. The square root puts it on the same scale as
/// the RMS term, whose squared errors are multiplied by `wᵢ`. The peak starts from
/// 0.0, so empty input gives a peak of 0. `f64::max` discards NaN candidates.
///
/// The result is `(1 - peak_weight)·rms + peak_weight·peak`. `peak_weight` is not
/// clamped, so values outside `[0, 1]` extrapolate.
pub fn weighted_combined_loss(error: &Array1<f64>, weights: &Array1<f64>, peak_weight: f64) -> f64 {
    let mut peak = 0.0_f64;
    for (e, w) in error.iter().zip(weights.iter()) {
        peak = peak.max(e.abs() * w.sqrt());
    }
    let rms = weighted_rms_error(error, weights);
    rms * (1.0 - peak_weight) + peak * peak_weight
}
/// Preset that emphasises low frequencies with a six-band custom table (20 Hz – 20 kHz).
///
/// It uses `FrequencyWeighting::Custom`, so only `custom_bands` affects
/// `compute_weights`. The `*_emphasis` fields are set but not consulted in that mode.
/// Frequencies below 20 Hz or at/above 20 kHz fall outside every band and get the
/// default weight 1.0.
pub fn bass_emphasis_config() -> WeightedLossConfig {
    // (low_hz, high_hz, weight): half-open [low, high) bands, highest weight in the bass.
    const BANDS: [(f64, f64, f64); 6] = [
        (20.0, 80.0, 2.5),
        (80.0, 200.0, 2.0),
        (200.0, 500.0, 1.5),
        (500.0, 2000.0, 1.0),
        (2000.0, 8000.0, 0.8),
        (8000.0, 20000.0, 0.5),
    ];

    WeightedLossConfig {
        custom_bands: BANDS.to_vec(),
        weighting: FrequencyWeighting::Custom,
        bass_emphasis: 2.0,
        midrange_emphasis: 1.0,
        treble_emphasis: 0.5,
    }
}
pub fn a_weighted_config() -> WeightedLossConfig {
WeightedLossConfig {
weighting: FrequencyWeighting::AWeighting,
..Default::default()
}
}
// Unit tests: A-weighting reference points, weight construction for each
// `FrequencyWeighting` variant, and the weighted error metrics.
#[cfg(test)]
mod tests {
use super::*;
// Panics with a diagnostic message unless |a - b| < epsilon.
fn assert_approx_eq(a: f64, b: f64, epsilon: f64) {
assert!(
(a - b).abs() < epsilon,
"assertion failed: {} ≈ {} (diff = {}, epsilon = {})",
a,
b,
(a - b).abs(),
epsilon
);
}
// The A-weighting curve is normalised to 0 dB at 1 kHz.
#[test]
fn test_a_weighting_1khz() {
let db = a_weighting_db(1000.0);
assert_approx_eq(db, 0.0, 0.1);
}
// Low frequencies are strongly attenuated by A-weighting.
#[test]
fn test_a_weighting_low_freq() {
let db_20 = a_weighting_db(20.0);
let db_100 = a_weighting_db(100.0);
assert!(db_20 < -40.0, "20 Hz should be < -40 dB, got {}", db_20);
assert!(db_100 < -15.0, "100 Hz should be < -15 dB, got {}", db_100);
}
// Around 10 kHz the A-weighting gain stays within a few dB of 0 dB.
#[test]
fn test_a_weighting_high_freq() {
let db_10k = a_weighting_db(10000.0);
assert!(
db_10k > -5.0 && db_10k < 5.0,
"10 kHz should be near 0 dB, got {}",
db_10k
);
}
// The default config uses Flat weighting, so every weight must be 1.0.
#[test]
fn test_flat_weighting() {
let freq = Array1::linspace(20.0, 20000.0, 100);
let config = WeightedLossConfig::default();
let weights = compute_weights(&freq, &config);
assert!(weights.iter().all(|&w| (w - 1.0).abs() < 0.001));
}
// Uniform weights reduce to plain RMS: sqrt((4 + 16 + 36) / 3) ≈ 4.32.
#[test]
fn test_weighted_rms() {
let error = Array1::from_vec(vec![2.0, 4.0, 6.0]);
let weights = Array1::from_vec(vec![1.0, 1.0, 1.0]);
let rms = weighted_rms_error(&error, &weights);
assert_approx_eq(rms, 4.32, 0.1);
}
// Uniform weights reduce to plain MAE: (2 + 4 + 6) / 3 = 4; error signs are ignored.
#[test]
fn test_weighted_mae() {
let error = Array1::from_vec(vec![-2.0, 4.0, -6.0]);
let weights = Array1::from_vec(vec![1.0, 1.0, 1.0]);
let mae = weighted_mae(&error, &weights);
assert_approx_eq(mae, 4.0, 0.01);
}
// The preset's bands give 50 Hz (2.5) and 150 Hz (2.0) more weight than 10 kHz (0.5).
#[test]
fn test_bass_emphasis_config() {
let config = bass_emphasis_config();
let freq = Array1::from_vec(vec![50.0, 150.0, 1000.0, 10000.0]);
let weights = compute_weights(&freq, &config);
assert!(weights[0] > weights[3]);
assert!(weights[1] > weights[3]);
}
// Each frequency takes the weight of the half-open [low, high) band that contains it.
#[test]
fn test_custom_bands() {
let config = WeightedLossConfig {
weighting: FrequencyWeighting::Custom,
custom_bands: vec![
(0.0, 100.0, 3.0),
(100.0, 1000.0, 2.0),
(1000.0, 20000.0, 1.0),
],
..Default::default()
};
let freq = Array1::from_vec(vec![50.0, 500.0, 5000.0]);
let weights = compute_weights(&freq, &config);
assert_approx_eq(weights[0], 3.0, 0.01);
assert_approx_eq(weights[1], 2.0, 0.01);
assert_approx_eq(weights[2], 1.0, 0.01);
}
// The tabulated A-weighting value at 100 Hz is about -19.1 dB.
#[test]
fn test_a_weighting_at_100hz() {
let db = a_weighting_db(100.0);
assert_approx_eq(db, -19.1, 0.5);
}
// One large error among 1s: RMS alone is sqrt(103 / 4) ≈ 5.07, while
// peak_weight = 0.9 gives about 0.1 * 5.07 + 0.9 * 10 ≈ 9.5.
#[test]
fn test_weighted_combined_loss_peak_dominance() {
let error = Array1::from_vec(vec![1.0, 1.0, 10.0, 1.0]);
let weights = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0]);
let loss_low_peak = weighted_combined_loss(&error, &weights, 0.0);
let loss_high_peak = weighted_combined_loss(&error, &weights, 0.9);
assert!(
loss_high_peak > loss_low_peak,
"High peak_weight loss ({:.2}) should exceed low peak_weight loss ({:.2})",
loss_high_peak,
loss_low_peak
);
assert!(
loss_high_peak > 8.0,
"With peak_weight=0.9 and peak=10, combined loss should be > 8.0, got {:.2}",
loss_high_peak
);
}
// 200 Hz lies in both bands; the first matching band in list order wins (3.0).
#[test]
fn test_custom_bands_overlap() {
let config = WeightedLossConfig {
weighting: FrequencyWeighting::Custom,
custom_bands: vec![
(50.0, 500.0, 3.0),
(100.0, 1000.0, 5.0), ],
..Default::default()
};
let freq = Array1::from_vec(vec![200.0]); let weights = compute_weights(&freq, &config);
assert_approx_eq(weights[0], 3.0, 0.01);
}
}