use ndarray::Array1;
/// Frequency ranges and per-band weights used by `band_weighted_loss`.
///
/// Each band is an inclusive `[min, max]` range expressed in the same units
/// as the `freqs` array passed to the loss functions (the defaults assume Hz).
/// `band_weighted_loss` tests bands in the order bass, mid, treble and assigns
/// a frequency to the first band that contains it. A shared edge (such as
/// the default 200.0 between bass and mid) therefore counts toward the lower band.
#[derive(Debug, Clone, PartialEq)]
pub struct FrequencyBandWeights {
    /// Inclusive lower edge of the bass band.
    pub bass_min: f64,
    /// Inclusive upper edge of the bass band.
    pub bass_max: f64,
    /// Inclusive lower edge of the mid band.
    pub mid_min: f64,
    /// Inclusive upper edge of the mid band.
    pub mid_max: f64,
    /// Inclusive lower edge of the treble band.
    pub treble_min: f64,
    /// Inclusive upper edge of the treble band.
    pub treble_max: f64,
    /// Multiplier applied to the bass-band RMS error.
    pub bass_weight: f64,
    /// Multiplier applied to the mid-band RMS error.
    pub mid_weight: f64,
    /// Multiplier applied to the treble-band RMS error.
    pub treble_weight: f64,
}

impl Default for FrequencyBandWeights {
    /// Bands of 20–200, 200–4000 and 4000–20000 (Hz), weighted 2.0 / 1.0 / 0.8.
    /// The bass band counts most and treble least.
    fn default() -> Self {
        Self {
            bass_min: 20.0,
            bass_max: 200.0,
            mid_min: 200.0,
            mid_max: 4000.0,
            treble_min: 4000.0,
            treble_max: 20000.0,
            bass_weight: 2.0,
            mid_weight: 1.0,
            treble_weight: 0.8,
        }
    }
}
/// Equivalent rectangular bandwidth (ERB) of the auditory filter centred at
/// `frequency`. The input is in Hz (it is divided by 1000 to get kHz), and the
/// result is also in Hz.
///
/// Computes `24.7 * (1 + 4.37 * f / 1000)`, the Glasberg & Moore (1990)
/// approximation. The value grows linearly with frequency, so `erb(0.0) == 24.7`.
pub fn erb(frequency: f64) -> f64 {
    const BASE_BANDWIDTH: f64 = 24.7;
    const SLOPE_PER_KHZ: f64 = 4.37;

    // Keep the multiply-then-divide order so results are bit-identical.
    let scaled = SLOPE_PER_KHZ * frequency / 1000.0;
    BASE_BANDWIDTH * (1.0 + scaled)
}
pub fn erb_weighted_loss(freqs: &Array1<f64>, error: &Array1<f64>) -> f64 {
assert_eq!(freqs.len(), error.len());
let erbs: Array1<f64> = freqs.mapv(erb);
let weights: Array1<f64> = erbs.mapv(|e| 1.0 / e);
let total_weight: f64 = weights.iter().sum();
if total_weight == 0.0 {
return 0.0;
}
let weighted_sum: f64 = error
.iter()
.zip(weights.iter())
.map(|(e, w)| e * e * w)
.sum();
(weighted_sum / total_weight).sqrt()
}
/// Weighted sum of per-band RMS errors (bass, mid, treble).
///
/// Each `(freq, err)` pair is assigned to the first band, tested in the order
/// bass, mid, treble, whose inclusive `[min, max]` range contains `freq`.
/// Pairs outside every band are ignored. A band with no samples contributes
/// an RMS of `0.0`. The result is
/// `bass_weight * bass_rms + mid_weight * mid_rms + treble_weight * treble_rms`.
///
/// # Panics
///
/// Panics if `freqs` and `error` have different lengths.
pub fn band_weighted_loss(
    freqs: &Array1<f64>,
    error: &Array1<f64>,
    bands: &FrequencyBandWeights,
) -> f64 {
    assert_eq!(freqs.len(), error.len());

    // Band edges, in priority order: bass, mid, treble.
    let ranges = [
        (bands.bass_min, bands.bass_max),
        (bands.mid_min, bands.mid_max),
        (bands.treble_min, bands.treble_max),
    ];
    // Per-band (sum of squared error, sample count), same order as `ranges`.
    let mut acc = [(0.0_f64, 0_usize); 3];

    for (&f, &e) in freqs.iter().zip(error.iter()) {
        // First match wins, so a shared edge goes to the lower band.
        let hit = ranges
            .iter()
            .position(|&(lo, hi)| (lo..=hi).contains(&f));
        if let Some(slot) = hit {
            acc[slot].0 += e * e;
            acc[slot].1 += 1;
        }
    }

    let rms = |(ss, n): (f64, usize)| {
        if n == 0 {
            0.0
        } else {
            (ss / n as f64).sqrt()
        }
    };
    let [bass, mid, treble] = acc.map(rms);

    bands.bass_weight * bass + bands.mid_weight * mid + bands.treble_weight * treble
}
pub fn combined_weighted_loss(
freqs: &Array1<f64>,
error: &Array1<f64>,
bands: &FrequencyBandWeights,
erb_weight: f64,
band_weight: f64,
) -> f64 {
let erb_loss = erb_weighted_loss(freqs, error);
let band_loss = band_weighted_loss(freqs, error, bands);
erb_weight * erb_loss + band_weight * band_loss
}