use super::enhanced_weights::{FrequencyBandWeights, combined_weighted_loss};
use ndarray::Array1;
/// First mixing weight handed to `combined_weighted_loss` by `weighted_mse_asymmetric`
/// (presumably the share of its ERB-based term — confirm in `enhanced_weights`).
const ASYMMETRIC_ERB_WEIGHT: f64 = 0.70;
/// Second mixing weight handed to `combined_weighted_loss` by `weighted_mse_asymmetric`
/// (presumably the share of its band-weighted term — confirm in `enhanced_weights`).
const ASYMMETRIC_BAND_WEIGHT: f64 = 0.30;
/// Tuning knobs for [`weighted_mse_asymmetric`].
///
/// Peaks (`error > 0`) and dips (every other error value) get separate weights. Each has a
/// low-frequency ("bass") value and a full-band value. The loss crossfades between the two on a
/// logistic curve in log-frequency centred on `transition_freq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AsymmetricLossConfig {
    /// Weight for peaks well above `transition_freq`.
    pub peak_weight: f64,
    /// Weight for dips well above `transition_freq`.
    pub dip_weight: f64,
    /// Weight for peaks well below `transition_freq`.
    pub bass_peak_weight: f64,
    /// Weight for dips well below `transition_freq`.
    pub bass_dip_weight: f64,
    /// Centre of the bass → full-band crossfade, in the same units as the `freqs` passed to the
    /// loss. Its natural log is taken, so it should be positive.
    pub transition_freq: f64,
}

impl Default for AsymmetricLossConfig {
    /// Defaults: peaks cost 2× and dips 1× above the 300 crossover. Below it, peaks cost 5× and
    /// dips 1×, so bass peaks are penalised much more heavily than bass dips.
    fn default() -> Self {
        Self {
            peak_weight: 2.0,
            dip_weight: 1.0,
            bass_peak_weight: 5.0,
            bass_dip_weight: 1.0,
            transition_freq: 300.0,
        }
    }
}
/// Per-sample weight used by the asymmetric loss.
///
/// A logistic function of `ln(freq)` centred on `log_transition` with slope `sigmoid_k` gives a
/// blend factor. The factor tends to 0 below the transition, selecting the `bass_*` weight, and
/// to 1 above it, selecting the full-band weight.
///
/// A strictly positive `error` selects the peak weights. Any other value (zero, negative, NaN)
/// selects the dip weights, which are then scaled by `null_mask`.
fn asymmetric_weight(
    freq: f64,
    error: f64,
    config: &AsymmetricLossConfig,
    log_transition: f64,
    sigmoid_k: f64,
    null_mask: f64,
) -> f64 {
    let scaled_offset = (freq.ln() - log_transition) * sigmoid_k;
    let t = 1.0 / (1.0 + (-scaled_offset).exp());
    let crossfade = |bass: f64, full_band: f64| bass + t * (full_band - bass);

    if error > 0.0 {
        crossfade(config.bass_peak_weight, config.peak_weight)
    } else {
        crossfade(config.bass_dip_weight, config.dip_weight) * null_mask
    }
}
/// Asymmetric, frequency-dependent weighted loss of a deviation curve.
///
/// Only bins with `min_freq <= freqs[i] <= max_freq` contribute. Each kept error `e` is replaced
/// by `e * sqrt(max(w, 0))`, where `w` comes from `asymmetric_weight`:
/// - peaks (`e > 0`) use the peak weights;
/// - every other value uses the dip weights, multiplied by the clamped null-suppression value
///   for that bin.
///
/// Both are crossfaded from their `bass_*` values to their full-band values around
/// `config.transition_freq`. The rescaled curve is reduced by `combined_weighted_loss` with
/// `ASYMMETRIC_ERB_WEIGHT` / `ASYMMETRIC_BAND_WEIGHT`. The `sqrt` presumably makes the effective
/// per-bin weight `w` if `combined_weighted_loss` squares its input — confirm there.
///
/// # Arguments
/// * `freqs` - frequency of each bin. The crossfade uses `ln(freq)`.
/// * `error` - deviation per bin; must have the same length as `freqs`.
/// * `min_freq`, `max_freq` - inclusive band of bins that contribute.
/// * `config` - peak/dip weights and crossfade frequency.
/// * `null_suppression` - optional per-bin multiplier for dip weights, same length as `freqs`.
///   Values are clamped to `[0, 1]`; 0 removes a dip's contribution entirely.
///
/// # Returns
/// The combined loss. Returns `f64::INFINITY` when `null_suppression` has the wrong length or
/// when no bin lies in `[min_freq, max_freq]`.
///
/// # Panics
/// Panics if `freqs.len() != error.len()`.
pub fn weighted_mse_asymmetric(
    freqs: &Array1<f64>,
    error: &Array1<f64>,
    min_freq: f64,
    max_freq: f64,
    config: &AsymmetricLossConfig,
    null_suppression: Option<&Array1<f64>>,
) -> f64 {
    assert_eq!(
        freqs.len(),
        error.len(),
        "weighted_mse_asymmetric: `freqs` and `error` must have the same length"
    );
    if let Some(mask) = null_suppression
        && mask.len() != freqs.len()
    {
        return f64::INFINITY;
    }

    let log_transition = config.transition_freq.ln();
    // Logistic slope chosen so the bass -> full-band blend goes from 10% to 90% across one octave
    // centred (geometrically) on `transition_freq`: sigmoid(k * ln2 / 2) = 0.9 <=> k = 2 ln9 / ln2.
    let sigmoid_k = 2.0 * 9.0_f64.ln() / 2.0_f64.ln();
    let band = min_freq..=max_freq;

    // Single pass: filter to the band and apply the asymmetric weight at once. This avoids
    // materialising separate kept-frequency / kept-error / kept-mask buffers first.
    let mut f_in: Vec<f64> = Vec::with_capacity(freqs.len());
    let mut weighted: Vec<f64> = Vec::with_capacity(freqs.len());
    for (i, (&f, &e)) in freqs.iter().zip(error.iter()).enumerate() {
        if !band.contains(&f) {
            continue;
        }
        let null_mask = null_suppression.map_or(1.0, |mask| mask[i].clamp(0.0, 1.0));
        let w = asymmetric_weight(f, e, config, log_transition, sigmoid_k, null_mask);
        f_in.push(f);
        // `f64::max` maps negative and NaN weights (e.g. from a NaN mask value) to 0.
        weighted.push(e * w.max(0.0).sqrt());
    }
    if f_in.is_empty() {
        return f64::INFINITY;
    }

    combined_weighted_loss(
        &Array1::from(f_in),
        &Array1::from(weighted),
        &FrequencyBandWeights::default(),
        ASYMMETRIC_ERB_WEIGHT,
        ASYMMETRIC_BAND_WEIGHT,
    )
}
pub fn flat_loss_asymmetric(
freqs: &Array1<f64>,
error: &Array1<f64>,
min_freq: f64,
max_freq: f64,
null_suppression: Option<&Array1<f64>>,
) -> f64 {
weighted_mse_asymmetric(
freqs,
error,
min_freq,
max_freq,
&AsymmetricLossConfig::default(),
null_suppression,
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` log-spaced points from `f_min` to `f_max` inclusive.
    ///
    /// The endpoints are written exactly rather than rebuilt via `exp(ln(x))`. The rebuilt values
    /// can round just outside `[f_min, f_max]` and then be silently dropped by the loss's band
    /// filter. `n == 0` yields an empty array and `n == 1` yields `[f_min]`, where the general
    /// formula would underflow or divide by zero.
    fn linspace_log(f_min: f64, f_max: f64, n: usize) -> Array1<f64> {
        if n < 2 {
            return Array1::from(vec![f_min; n]);
        }
        let lo = f_min.ln();
        let hi = f_max.ln();
        let last = n - 1;
        Array1::from_iter((0..n).map(|i| match i {
            0 => f_min,
            i if i == last => f_max,
            _ => (lo + (hi - lo) * i as f64 / last as f64).exp(),
        }))
    }

    /// Index of the element of `freqs` nearest to `target`.
    ///
    /// Returns the first such element on ties, and 0 for an empty array.
    fn closest_index(freqs: &Array1<f64>, target: f64) -> usize {
        freqs
            .iter()
            .map(|&f| (f - target).abs())
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map_or(0, |(i, _)| i)
    }

    #[test]
    fn zero_error_gives_zero_loss() {
        let freqs = linspace_log(20.0, 20000.0, 64);
        let error = Array1::zeros(freqs.len());
        let loss = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, None);
        assert!(loss.abs() < 1e-12, "expected zero loss, got {loss}");
    }

    #[test]
    fn asymmetric_equals_combined_when_weights_are_unit() {
        let freqs = linspace_log(50.0, 15000.0, 128);
        let error: Array1<f64> = freqs.mapv(|f| ((f.ln() * 3.0).sin()) * 2.0);
        let config = AsymmetricLossConfig {
            peak_weight: 1.0,
            dip_weight: 1.0,
            bass_peak_weight: 1.0,
            bass_dip_weight: 1.0,
            transition_freq: 300.0,
        };
        let asym = weighted_mse_asymmetric(&freqs, &error, 20.0, 20000.0, &config, None);
        let expected =
            combined_weighted_loss(&freqs, &error, &FrequencyBandWeights::default(), 0.7, 0.3);
        assert!(
            (asym - expected).abs() < 1e-9,
            "asymmetric with unit weights must equal combined_weighted_loss ({asym} vs {expected})"
        );
    }

    #[test]
    fn bass_peaks_penalized_more_than_bass_dips() {
        let freqs = Array1::from_vec(vec![80.0]);
        let err_peak = Array1::from_vec(vec![10.0]);
        let err_dip = Array1::from_vec(vec![-10.0]);
        let loss_peak = flat_loss_asymmetric(&freqs, &err_peak, 20.0, 20000.0, None);
        let loss_dip = flat_loss_asymmetric(&freqs, &err_dip, 20.0, 20000.0, None);
        assert!(
            loss_peak > loss_dip,
            "bass peak must be penalized more than bass dip (peak={loss_peak}, dip={loss_dip})"
        );
    }

    #[test]
    fn null_mask_suppresses_bass_dip() {
        let freqs = linspace_log(20.0, 20000.0, 256);
        let mut error: Array1<f64> = Array1::zeros(freqs.len());
        let dip_idx = closest_index(&freqs, 80.0);
        error[dip_idx] = -15.0;
        let mut mask: Array1<f64> = Array1::ones(freqs.len());
        mask[dip_idx] = 0.0;
        let loss_unmasked = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, None);
        let loss_masked = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, Some(&mask));
        assert!(
            loss_masked < loss_unmasked,
            "null-masked loss must be smaller (masked={loss_masked}, unmasked={loss_unmasked})"
        );
        assert!(
            loss_masked < 1e-6,
            "fully masked single-point dip should collapse to ~0 (got {loss_masked})"
        );
    }

    #[test]
    fn null_mask_does_not_suppress_peaks() {
        let freqs = linspace_log(20.0, 20000.0, 256);
        let mut error: Array1<f64> = Array1::zeros(freqs.len());
        let peak_idx = closest_index(&freqs, 80.0);
        error[peak_idx] = 10.0;
        let mut mask: Array1<f64> = Array1::ones(freqs.len());
        mask[peak_idx] = 0.0;
        let unmasked = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, None);
        let masked = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, Some(&mask));
        assert!(
            (unmasked - masked).abs() < 1e-9,
            "peak branch must not be affected by the null mask ({unmasked} vs {masked})"
        );
    }

    #[test]
    fn broad_bass_dip_is_penalized_under_new_defaults() {
        let freqs = linspace_log(20.0, 20000.0, 256);
        let error: Array1<f64> = freqs.mapv(|f| {
            if (60.0..=120.0).contains(&f) {
                -5.0
            } else {
                0.0
            }
        });
        let old_config = AsymmetricLossConfig {
            bass_dip_weight: 0.2,
            ..AsymmetricLossConfig::default()
        };
        let new_loss = flat_loss_asymmetric(&freqs, &error, 20.0, 20000.0, None);
        let old_loss = weighted_mse_asymmetric(&freqs, &error, 20.0, 20000.0, &old_config, None);
        assert!(
            new_loss > old_loss,
            "new default must penalize broad bass dips more (new={new_loss}, old={old_loss})"
        );
    }
}