use crate::Curve;
use crate::cea2034 as score;
use crate::error::{AutoeqError, Result};
use crate::read;
use clap::ValueEnum;
use ndarray::{Array1, Zip};
use num_complex::Complex64;
use std::collections::HashMap;
use std::f64::consts::PI;
pub mod bass_boost;
pub mod enhanced_weights;
pub mod phase_aware;
// Selects which objective the optimizer minimises.
//
// NOTE(review): plain `//` comments are used deliberately — `///` doc comments on a
// `clap::ValueEnum` would become CLI help text. The per-variant notes are inferred from
// the loss functions defined in this module; confirm against the dispatch site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum LossType {
    // Flat target on a speaker measurement (presumably `flat_loss`).
    SpeakerFlat,
    // Flat target with peak/dip-asymmetric weighting (presumably `flat_loss_asymmetric`).
    SpeakerFlatAsymmetric,
    // Speaker preference score (presumably `speaker_score_loss`).
    SpeakerScore,
    // Flat target on a headphone measurement.
    HeadphoneFlat,
    // Headphone preference score against a target (presumably `headphone_loss_with_target`).
    HeadphoneScore,
    // Flat summed response of a multi-way speaker (presumably `drivers_flat_loss`).
    DriversFlat,
    // Flat summed response of several subwoofers (presumably `multisub_flat_loss`).
    MultiSubFlat,
    // NOTE(review): meaning not visible in this module — document at the dispatch site.
    Epa,
}
/// The four CEA2034 SPL curves needed by the speaker losses, all of equal length.
#[derive(Debug, Clone)]
pub struct SpeakerLossData {
    /// SPL of the "On Axis" curve.
    pub on: Array1<f64>,
    /// SPL of the "Listening Window" curve.
    pub lw: Array1<f64>,
    /// SPL of the "Sound Power" curve.
    pub sp: Array1<f64>,
    /// SPL of the "Estimated In-Room Response" curve.
    pub pir: Array1<f64>,
}
impl SpeakerLossData {
    /// Extracts the four required curves from a spin map.
    ///
    /// # Errors
    /// Returns `MissingCea2034Curve` for the first missing curve (checked in the order
    /// On Axis, Listening Window, Sound Power, Estimated In-Room Response), and
    /// `CurveLengthMismatch` if the four SPL arrays do not all have the same length.
    pub fn try_new(spin: &HashMap<String, Curve>) -> Result<Self> {
        // Copies the SPL of the named curve, or reports it as missing.
        let spl_of = |name: &str| -> Result<Array1<f64>> {
            spin.get(name)
                .map(|curve| curve.spl.clone())
                .ok_or_else(|| AutoeqError::MissingCea2034Curve {
                    curve_name: name.to_string(),
                })
        };
        let on = spl_of("On Axis")?;
        let lw = spl_of("Listening Window")?;
        let sp = spl_of("Sound Power")?;
        let pir = spl_of("Estimated In-Room Response")?;

        let reference_len = on.len();
        let mismatched = [lw.len(), sp.len(), pir.len()]
            .iter()
            .any(|&len| len != reference_len);
        if mismatched {
            return Err(AutoeqError::CurveLengthMismatch {
                on_len: on.len(),
                lw_len: lw.len(),
                sp_len: sp.len(),
                pir_len: pir.len(),
            });
        }
        Ok(SpeakerLossData { on, lw, sp, pir })
    }
}
/// Options for `headphone_loss_with_target`.
#[derive(Debug, Clone)]
pub struct HeadphoneLossData {
    /// When true, the target-minus-response deviation is smoothed before scoring.
    pub smooth: bool,
    /// Value passed as `n` to `read::smooth_one_over_n_octave` when `smooth` is set.
    pub smooth_n: usize,
}
impl HeadphoneLossData {
    /// Creates the options from the smoothing flag and its `n` parameter.
    pub fn new(smooth: bool, smooth_n: usize) -> Self {
        HeadphoneLossData { smooth_n, smooth }
    }
}
/// Crossover filter family and order used to split the band between adjacent drivers.
///
/// Deserialization accepts the Rust variant name as well as the short plugin name
/// returned by [`CrossoverType::to_plugin_string`] (e.g. `"LR24"`), so every value this
/// type prints can be read back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CrossoverType {
    /// 2nd-order Butterworth high/low-pass pair.
    #[serde(alias = "Butterworth12")]
    Butterworth2,
    /// 2nd-order Linkwitz-Riley.
    #[serde(alias = "LR12")]
    LinkwitzRiley2,
    /// 4th-order Linkwitz-Riley (the default).
    #[serde(alias = "LR24")]
    LinkwitzRiley4,
    /// 8th-order Linkwitz-Riley.
    #[serde(alias = "LR48")]
    LinkwitzRiley8,
    /// No crossover filtering: every driver runs full range (multi-sub setups).
    None,
}
impl Default for CrossoverType {
    fn default() -> Self {
        CrossoverType::LinkwitzRiley4
    }
}
impl CrossoverType {
    /// Short identifier used by the plugin, e.g. `"LR24"`; also the `Display` output.
    pub fn to_plugin_string(&self) -> &'static str {
        match self {
            CrossoverType::Butterworth2 => "Butterworth12",
            CrossoverType::LinkwitzRiley2 => "LR12",
            CrossoverType::LinkwitzRiley4 => "LR24",
            CrossoverType::LinkwitzRiley8 => "LR48",
            CrossoverType::None => "None",
        }
    }
    /// Human-readable description for UIs and logs.
    pub fn display_name(&self) -> &'static str {
        match self {
            CrossoverType::Butterworth2 => "2nd order Butterworth",
            CrossoverType::LinkwitzRiley2 => "2nd order Linkwitz-Riley",
            CrossoverType::LinkwitzRiley4 => "4th order Linkwitz-Riley",
            CrossoverType::LinkwitzRiley8 => "8th order Linkwitz-Riley",
            CrossoverType::None => "No Crossover (Multi-Sub)",
        }
    }
}
impl std::str::FromStr for CrossoverType {
    type Err = String;
    /// Case-insensitive parse of order-based (`lr4`) and slope-based (`lr24`) names;
    /// accepts every string produced by `to_plugin_string`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "butterworth2" | "bw2" | "butterworth12" | "bw12" => Ok(CrossoverType::Butterworth2),
            "lr2" | "lr12" | "linkwitzriley2" | "linkwitzriley12" => {
                Ok(CrossoverType::LinkwitzRiley2)
            }
            "lr4" | "lr24" | "linkwitzriley4" | "linkwitzriley24" => {
                Ok(CrossoverType::LinkwitzRiley4)
            }
            "lr8" | "lr48" | "linkwitzriley8" | "linkwitzriley48" => {
                Ok(CrossoverType::LinkwitzRiley8)
            }
            "none" => Ok(CrossoverType::None),
            _ => Err(format!("Unknown crossover type: {}", s)),
        }
    }
}
impl std::fmt::Display for CrossoverType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.to_plugin_string())
    }
}
/// One driver's measured frequency response.
#[derive(Debug, Clone)]
pub struct DriverMeasurement {
    /// Measurement frequencies (presumably Hz).
    pub freq: Array1<f64>,
    /// Level at each frequency (presumably dB SPL).
    pub spl: Array1<f64>,
    /// Optional phase at each frequency; converted with `to_radians` when drivers are
    /// summed, so presumably degrees.
    pub phase: Option<Array1<f64>>,
}
impl DriverMeasurement {
    /// Builds a measurement from parallel arrays.
    ///
    /// # Panics
    /// Panics if `spl` (or `phase`, when given) differs in length from `freq`.
    pub fn new(freq: Array1<f64>, spl: Array1<f64>, phase: Option<Array1<f64>>) -> Self {
        assert_eq!(freq.len(), spl.len(), "freq and spl must have same length");
        if let Some(p) = phase.as_ref() {
            assert_eq!(freq.len(), p.len(), "freq and phase must have same length");
        }
        DriverMeasurement { freq, spl, phase }
    }
    /// Lowest and highest measured frequency, ignoring NaN entries.
    /// An empty measurement yields `(INFINITY, NEG_INFINITY)`.
    pub fn freq_range(&self) -> (f64, f64) {
        self.freq
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &f| {
                (lo.min(f), hi.max(f))
            })
    }
    /// Geometric mean of the frequency range's endpoints, used to order drivers.
    pub fn mean_freq(&self) -> f64 {
        let (lo, hi) = self.freq_range();
        (lo * hi).sqrt()
    }
}
#[derive(Debug, Clone)]
pub struct DriversLossData {
pub drivers: Vec<DriverMeasurement>,
pub crossover_type: CrossoverType,
pub freq_grid: Array1<f64>,
}
impl DriversLossData {
pub fn new(mut drivers: Vec<DriverMeasurement>, crossover_type: CrossoverType) -> Self {
assert!(
drivers.len() >= 2 && drivers.len() <= 4,
"Must have 2-4 drivers, got {}",
drivers.len()
);
drivers.sort_by(|a, b| {
a.mean_freq()
.partial_cmp(&b.mean_freq())
.unwrap_or(std::cmp::Ordering::Equal)
});
let min_freq = drivers
.iter()
.map(|d| d.freq_range().0)
.fold(f64::INFINITY, f64::min);
let max_freq = drivers
.iter()
.map(|d| d.freq_range().1)
.fold(f64::NEG_INFINITY, f64::max);
let freq_grid = crate::read::create_log_frequency_grid(
10 * 10, min_freq.max(20.0),
max_freq.min(20000.0),
);
Self {
drivers,
crossover_type,
freq_grid,
}
}
}
/// Flat-target loss: the two-band RMS of `error` over `[min_freq, max_freq]`.
///
/// Thin public entry point over `weighted_mse` (bass/treble split at the module default).
/// Returns `f64::INFINITY` when no frequency lies in range.
pub fn flat_loss(freqs: &Array1<f64>, error: &Array1<f64>, min_freq: f64, max_freq: f64) -> f64 {
    weighted_mse(freqs, error, min_freq, max_freq)
}
pub fn speaker_score_loss(
score_data: &SpeakerLossData,
freq: &Array1<f64>,
peq_response: &Array1<f64>,
) -> f64 {
let intervals = score::octave_intervals(2, freq);
let metrics = if peq_response.iter().all(|v| v.abs() < 1e-12) {
score::score(
freq,
&intervals,
&score_data.on,
&score_data.lw,
&score_data.sp,
&score_data.pir,
)
} else {
score::score_peq_approx(
freq,
&intervals,
&score_data.lw,
&score_data.sp,
&score_data.pir,
&score_data.on,
peq_response,
)
};
metrics.pref_score
}
/// Slope-based loss over 100 Hz–10 kHz after applying `peq_response`.
///
/// The loss is the sum of two squared terms:
/// - how far the equalised Listening Window slope is from -0.5 per octave;
/// - how much the equalisation changes the Estimated In-Room Response slope.
///
/// Slopes are in the curves' units per octave (presumably dB). Returns `f64::INFINITY`
/// if any slope cannot be estimated.
pub fn mixed_loss(
    score_data: &SpeakerLossData,
    freq: &Array1<f64>,
    peq_response: &Array1<f64>,
) -> f64 {
    const FMIN: f64 = 100.0;
    const FMAX: f64 = 10000.0;
    let slope = |y: &Array1<f64>| regression_slope_per_octave_in_range(freq, y, FMIN, FMAX);
    let lw_eq = &score_data.lw + peq_response;
    let pir_eq = &score_data.pir + peq_response;
    match (slope(&lw_eq), slope(&score_data.pir), slope(&pir_eq)) {
        (Some(lw_eq_slope), Some(pir_orig_slope), Some(pir_eq_slope)) => {
            (0.5 + lw_eq_slope).powi(2) + (pir_orig_slope - pir_eq_slope).powi(2)
        }
        _ => f64::INFINITY,
    }
}
const DEFAULT_BASS_TREBLE_SPLIT_HZ: f64 = 3000.0;
fn weighted_mse(freqs: &Array1<f64>, error: &Array1<f64>, min_freq: f64, max_freq: f64) -> f64 {
weighted_mse_with_split(
freqs,
error,
min_freq,
max_freq,
DEFAULT_BASS_TREBLE_SPLIT_HZ,
)
}
fn weighted_mse_with_split(
freqs: &Array1<f64>,
error: &Array1<f64>,
min_freq: f64,
max_freq: f64,
bass_treble_split_hz: f64,
) -> f64 {
let _in_range = freqs.mapv(|f| f >= min_freq && f <= max_freq);
let bass_band = freqs.mapv(|f| f < bass_treble_split_hz && f >= min_freq && f <= max_freq);
let treble_band = freqs.mapv(|f| f >= bass_treble_split_hz && f >= min_freq && f <= max_freq);
let n1: usize = bass_band.iter().filter(|&&b| b).count();
let n2: usize = treble_band.iter().filter(|&&b| b).count();
if n1 == 0 && n2 == 0 {
return f64::INFINITY;
}
let squared_errors = error.mapv(|e| e * e);
let ss1: f64 = Zip::from(&bass_band)
.and(&squared_errors)
.fold(0.0, |acc, &mask, &err| if mask { acc + err } else { acc });
let ss2: f64 = Zip::from(&treble_band)
.and(&squared_errors)
.fold(0.0, |acc, &mask, &err| if mask { acc + err } else { acc });
let err1 = if n1 > 0 {
(ss1 / n1 as f64).sqrt()
} else {
0.0
};
let err2 = if n2 > 0 {
(ss2 / n2 as f64).sqrt()
} else {
0.0
};
match (n1 > 0, n2 > 0) {
(true, true) => err1 + err2 / 3.0,
(true, false) => err1,
(false, true) => err2,
(false, false) => f64::INFINITY,
}
}
/// Weights for the asymmetric flat loss.
///
/// Positive errors ("peaks") and non-positive errors ("dips") get different weights, and
/// the weights blend from a bass set to a regular set around `transition_freq`.
#[derive(Debug, Clone, Copy)]
pub struct AsymmetricLossConfig {
    /// Weight on squared positive error well above `transition_freq`.
    pub peak_weight: f64,
    /// Weight on squared non-positive error well above `transition_freq`.
    pub dip_weight: f64,
    /// Weight on squared positive error well below `transition_freq`.
    pub bass_peak_weight: f64,
    /// Weight on squared non-positive error well below `transition_freq`.
    pub bass_dip_weight: f64,
    /// Centre frequency (Hz) of the blend between the bass and regular weights.
    pub transition_freq: f64,
}
impl Default for AsymmetricLossConfig {
    /// Peaks weigh 2× and dips 1× in general. In the bass, peaks weigh 5× and dips only
    /// 0.2×. The blend is centred on 300 Hz.
    fn default() -> Self {
        AsymmetricLossConfig {
            transition_freq: 300.0,
            peak_weight: 2.0,
            dip_weight: 1.0,
            bass_peak_weight: 5.0,
            bass_dip_weight: 0.2,
        }
    }
}
pub fn weighted_mse_asymmetric(
freqs: &Array1<f64>,
error: &Array1<f64>,
min_freq: f64,
max_freq: f64,
config: &AsymmetricLossConfig,
) -> f64 {
weighted_mse_asymmetric_with_split(
freqs,
error,
min_freq,
max_freq,
config,
DEFAULT_BASS_TREBLE_SPLIT_HZ,
)
}
pub fn weighted_mse_asymmetric_with_split(
freqs: &Array1<f64>,
error: &Array1<f64>,
min_freq: f64,
max_freq: f64,
config: &AsymmetricLossConfig,
bass_treble_split_hz: f64,
) -> f64 {
let bass_band = freqs.mapv(|f| f < bass_treble_split_hz && f >= min_freq && f <= max_freq);
let treble_band = freqs.mapv(|f| f >= bass_treble_split_hz && f >= min_freq && f <= max_freq);
let n1: usize = bass_band.iter().filter(|&&b| b).count();
let n2: usize = treble_band.iter().filter(|&&b| b).count();
if n1 == 0 && n2 == 0 {
return f64::INFINITY;
}
let log_transition = config.transition_freq.ln();
let sigmoid_k = 2.0 * 9.0_f64.ln() / 2.0_f64.ln();
let weighted_squared_errors = Zip::from(freqs).and(error).map_collect(|&f, &e| {
let blend = 1.0 / (1.0 + (-(f.ln() - log_transition) * sigmoid_k).exp());
let peak_w =
config.bass_peak_weight + blend * (config.peak_weight - config.bass_peak_weight);
let dip_w =
config.bass_dip_weight + blend * (config.dip_weight - config.bass_dip_weight);
let weight = if e > 0.0 { peak_w } else { dip_w };
weight * e * e
});
let ss1: f64 = Zip::from(&bass_band)
.and(&weighted_squared_errors)
.fold(0.0, |acc, &mask, &err| if mask { acc + err } else { acc });
let ss2: f64 = Zip::from(&treble_band)
.and(&weighted_squared_errors)
.fold(0.0, |acc, &mask, &err| if mask { acc + err } else { acc });
let err1 = if n1 > 0 {
(ss1 / n1 as f64).sqrt()
} else {
0.0
};
let err2 = if n2 > 0 {
(ss2 / n2 as f64).sqrt()
} else {
0.0
};
match (n1 > 0, n2 > 0) {
(true, true) => err1 + err2 / 3.0,
(true, false) => err1,
(false, true) => err2,
(false, false) => f64::INFINITY,
}
}
pub fn flat_loss_asymmetric(
freqs: &Array1<f64>,
error: &Array1<f64>,
min_freq: f64,
max_freq: f64,
) -> f64 {
weighted_mse_asymmetric(
freqs,
error,
min_freq,
max_freq,
&AsymmetricLossConfig::default(),
)
}
/// Least-squares slope of `y` against `log2(freq)`, i.e. change in `y` per octave.
///
/// Only points with `freq > 0` and `fmin <= freq <= fmax` are used. Returns `None` if
/// `fmax <= fmin`, fewer than two points qualify, or the log-frequency variance is below
/// 1e-10 (e.g. all selected frequencies are identical).
///
/// # Panics
/// Panics if `freq` and `y` have different lengths.
pub fn regression_slope_per_octave_in_range(
    freq: &Array1<f64>,
    y: &Array1<f64>,
    fmin: f64,
    fmax: f64,
) -> Option<f64> {
    assert_eq!(freq.len(), y.len(), "freq and y must have same length");
    if fmax <= fmin {
        return None;
    }
    // Running sums for the normal equations, accumulated in index order.
    let (count, sx, sy, sxy, sxx) = freq
        .iter()
        .zip(y.iter())
        .filter(|&(&f, _)| f > 0.0 && f >= fmin && f <= fmax)
        .fold(
            (0_usize, 0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64),
            |(count, sx, sy, sxy, sxx), (&f, &yi)| {
                let xi = f.log2();
                (count + 1, sx + xi, sy + yi, sxy + xi * yi, sxx + xi * xi)
            },
        );
    if count < 2 {
        return None;
    }
    let n = count as f64;
    let covariance = sxy - (sx * sy) / n;
    let variance = sxx - (sx * sx) / n;
    if variance.abs() < 1e-10 {
        None
    } else {
        Some(covariance / variance)
    }
}
pub fn curve_slope_per_octave_in_range(curve: &crate::Curve, fmin: f64, fmax: f64) -> Option<f64> {
regression_slope_per_octave_in_range(&curve.freq, &curve.spl, fmin, fmax)
}
/// Complex frequency response of one biquad at frequency `f`.
///
/// Evaluates H(z) = (b0 + b1·z⁻¹ + b2·z⁻²) / (1 + a1·z⁻¹ + a2·z⁻²) at z = e^{jω},
/// where ω = 2πf / `biquad.srate`.
fn biquad_complex_response(biquad: &crate::iir::Biquad, f: f64) -> Complex64 {
    let (a1, a2, b0, b1, b2) = biquad.constants();
    let w = 2.0 * PI * f / biquad.srate;
    // z⁻¹ and z⁻² on the unit circle.
    let zm1 = Complex64::from_polar(1.0, -w);
    let zm2 = zm1 * zm1;
    let numerator = b0 + b1 * zm1 + b2 * zm2;
    let denominator = 1.0 + a1 * zm1 + a2 * zm2;
    numerator / denominator
}
/// Interpolates every driver onto `data.freq_grid`, normalised over its passband.
///
/// With a crossover, driver `i`'s passband runs from `crossover_freqs[i - 1]` (20 Hz for
/// the first driver) to `crossover_freqs[i]` (20 kHz for the last). With
/// `CrossoverType::None`, every driver uses 20 Hz–20 kHz.
fn prepare_driver_curves(data: &DriversLossData, crossover_freqs: &[f64]) -> Vec<Curve> {
    let n_drivers = data.drivers.len();
    let full_band = matches!(data.crossover_type, CrossoverType::None);
    data.drivers
        .iter()
        .enumerate()
        .map(|(i, driver)| {
            let (passband_low, passband_high) = if full_band {
                (20.0, 20000.0)
            } else {
                let low = if i == 0 { 20.0 } else { crossover_freqs[i - 1] };
                let high = if i == n_drivers - 1 {
                    20000.0
                } else {
                    crossover_freqs[i]
                };
                (low, high)
            };
            let raw = Curve {
                freq: driver.freq.clone(),
                spl: driver.spl.clone(),
                phase: driver.phase.clone(),
            };
            crate::read::normalize_and_interpolate_response_with_range(
                &data.freq_grid,
                &raw,
                passband_low,
                passband_high,
            )
        })
        .collect()
}
/// Builds the crossover filter chain for one driver.
///
/// Every driver except the lowest gets a high-pass at `crossover_freqs[driver_index - 1]`.
/// Every driver except the highest gets a low-pass at `crossover_freqs[driver_index]`.
/// `CrossoverType::None` yields an empty chain.
fn build_crossover_filters_for_driver(
    driver_index: usize,
    n_drivers: usize,
    crossover_type: CrossoverType,
    crossover_freqs: &[f64],
    sample_rate: f64,
) -> Vec<(f64, crate::iir::Biquad)> {
    use crate::iir::{
        peq_butterworth_highpass, peq_butterworth_lowpass, peq_linkwitzriley_highpass,
        peq_linkwitzriley_lowpass,
    };
    if matches!(crossover_type, CrossoverType::None) {
        return Vec::new();
    }
    let high_pass = |fc: f64| match crossover_type {
        CrossoverType::Butterworth2 => peq_butterworth_highpass(2, fc, sample_rate),
        CrossoverType::LinkwitzRiley2 => peq_linkwitzriley_highpass(2, fc, sample_rate),
        CrossoverType::LinkwitzRiley4 => peq_linkwitzriley_highpass(4, fc, sample_rate),
        CrossoverType::LinkwitzRiley8 => peq_linkwitzriley_highpass(8, fc, sample_rate),
        CrossoverType::None => Vec::new(),
    };
    let low_pass = |fc: f64| match crossover_type {
        CrossoverType::Butterworth2 => peq_butterworth_lowpass(2, fc, sample_rate),
        CrossoverType::LinkwitzRiley2 => peq_linkwitzriley_lowpass(2, fc, sample_rate),
        CrossoverType::LinkwitzRiley4 => peq_linkwitzriley_lowpass(4, fc, sample_rate),
        CrossoverType::LinkwitzRiley8 => peq_linkwitzriley_lowpass(8, fc, sample_rate),
        CrossoverType::None => Vec::new(),
    };
    let mut chain = Vec::new();
    if driver_index > 0 {
        chain.extend(high_pass(crossover_freqs[driver_index - 1]));
    }
    if driver_index < n_drivers - 1 {
        chain.extend(low_pass(crossover_freqs[driver_index]));
    }
    chain
}
/// Complex response of one driver on `freq_grid`.
///
/// Combines four factors:
/// - the driver's magnitude, from `curve.spl` in dB, with its phase when present
///   (converted with `to_radians`);
/// - the linear gain `10^(gain/20)`;
/// - the cascaded crossover biquads;
/// - a pure delay of `delay_s` seconds.
///
/// `curve` must hold one value per grid point.
fn compute_single_driver_complex(
    freq_grid: &Array1<f64>,
    curve: &Curve,
    gain: f64,
    delay_s: f64,
    filters: &[(f64, crate::iir::Biquad)],
) -> Array1<Complex64> {
    let gain_linear = 10.0_f64.powf(gain / 20.0);
    freq_grid
        .iter()
        .enumerate()
        .map(|(j, &f)| {
            let magnitude = 10.0_f64.powf(curve.spl[j] / 20.0);
            let driver = match &curve.phase {
                Some(phase) => Complex64::from_polar(magnitude, phase[j].to_radians()),
                None => Complex64::new(magnitude, 0.0),
            };
            let crossover = filters
                .iter()
                .fold(Complex64::new(1.0, 0.0), |acc, (_, biquad)| {
                    acc * biquad_complex_response(biquad, f)
                });
            let delay = Complex64::from_polar(1.0, -2.0 * PI * f * delay_s);
            driver * gain_linear * crossover * delay
        })
        .collect()
}
/// Checks that the optimisation parameters match the driver configuration.
///
/// # Panics
/// Panics with a descriptive message in three cases:
/// - `gains` does not have one entry per driver;
/// - a crossover is configured and `crossover_freqs` does not have one entry per driver
///   boundary (`n_drivers - 1`);
/// - `delays` is given without one entry per driver.
fn validate_driver_args(
    data: &DriversLossData,
    gains: &[f64],
    crossover_freqs: &[f64],
    delays: Option<&[f64]>,
) {
    let n_drivers = data.drivers.len();
    assert_eq!(
        gains.len(),
        n_drivers,
        "expected one gain per driver ({} drivers), got {} gains",
        n_drivers,
        gains.len()
    );
    if !matches!(data.crossover_type, CrossoverType::None) {
        assert_eq!(
            crossover_freqs.len(),
            n_drivers - 1,
            "{} drivers with crossover {} need {} crossover frequencies, got {}",
            n_drivers,
            data.crossover_type,
            n_drivers - 1,
            crossover_freqs.len()
        );
    }
    if let Some(d) = delays {
        assert_eq!(
            d.len(),
            n_drivers,
            "expected one delay per driver ({} drivers), got {} delays",
            n_drivers,
            d.len()
        );
    }
}
/// Magnitude (dB) of the complex sum of all drivers after gain, crossover and delay.
///
/// `gains` are in dB, `delays` in milliseconds (0 when `None`), one per driver.
/// Magnitudes are floored at 1e-12 before the log, so the output is at least -240 dB.
///
/// # Panics
/// Panics if the argument lengths do not match the driver configuration.
pub fn compute_drivers_combined_response(
    data: &DriversLossData,
    gains: &[f64],
    crossover_freqs: &[f64],
    delays: Option<&[f64]>,
    sample_rate: f64,
) -> Array1<f64> {
    validate_driver_args(data, gains, crossover_freqs, delays);
    let driver_count = data.drivers.len();
    let curves = prepare_driver_curves(data, crossover_freqs);
    let summed = curves.iter().enumerate().fold(
        Array1::<Complex64>::zeros(data.freq_grid.len()),
        |mut acc, (idx, curve)| {
            let delay_s = delays.map_or(0.0, |d| d[idx]) / 1000.0;
            let filters = build_crossover_filters_for_driver(
                idx,
                driver_count,
                data.crossover_type,
                crossover_freqs,
                sample_rate,
            );
            acc += &compute_single_driver_complex(
                &data.freq_grid,
                curve,
                gains[idx],
                delay_s,
                &filters,
            );
            acc
        },
    );
    summed.mapv(|z| 20.0 * z.norm().max(1e-12).log10())
}
/// Magnitude (dB) of each driver on its own after gain, crossover and delay.
///
/// Same conventions as `compute_drivers_combined_response`, but nothing is summed: one
/// array is returned per driver, in the driver order of `data`.
///
/// # Panics
/// Panics if the argument lengths do not match the driver configuration.
pub fn compute_per_driver_responses(
    data: &DriversLossData,
    gains: &[f64],
    crossover_freqs: &[f64],
    delays: Option<&[f64]>,
    sample_rate: f64,
) -> Vec<Array1<f64>> {
    validate_driver_args(data, gains, crossover_freqs, delays);
    let driver_count = data.drivers.len();
    prepare_driver_curves(data, crossover_freqs)
        .iter()
        .enumerate()
        .map(|(idx, curve)| {
            let delay_s = delays.map_or(0.0, |d| d[idx]) / 1000.0;
            let filters = build_crossover_filters_for_driver(
                idx,
                driver_count,
                data.crossover_type,
                crossover_freqs,
                sample_rate,
            );
            compute_single_driver_complex(&data.freq_grid, curve, gains[idx], delay_s, &filters)
                .mapv(|z| 20.0 * z.norm().max(1e-12).log10())
        })
        .collect()
}
/// Flat loss of the combined multi-driver response.
///
/// The combined dB response has its mean over `[min_freq, max_freq]` subtracted, so only
/// its shape is penalised, and is then scored with `flat_loss`.
pub fn drivers_flat_loss(
    data: &DriversLossData,
    gains: &[f64],
    crossover_freqs: &[f64],
    delays: Option<&[f64]>,
    sample_rate: f64,
    min_freq: f64,
    max_freq: f64,
) -> f64 {
    let response =
        compute_drivers_combined_response(data, gains, crossover_freqs, delays, sample_rate);
    let (total, count) = data
        .freq_grid
        .iter()
        .zip(response.iter())
        .filter(|&(&f, _)| f >= min_freq && f <= max_freq)
        .fold((0.0, 0_usize), |(total, count), (_, &level)| {
            (total + level, count + 1)
        });
    let mean = if count > 0 { total / count as f64 } else { 0.0 };
    let normalized = &response - mean;
    flat_loss(&data.freq_grid, &normalized, min_freq, max_freq)
}
/// Flat loss of several subwoofers summed with per-sub gain and delay and no crossover.
///
/// The combined response is mean-normalised over `[min_freq, max_freq]` and then scored
/// with `flat_loss`.
///
/// # Panics
/// No crossover frequencies are passed, so this panics unless `data.crossover_type` is
/// `CrossoverType::None` (or there is exactly one driver).
pub fn multisub_flat_loss(
    data: &DriversLossData,
    gains: &[f64],
    delays: &[f64],
    sample_rate: f64,
    min_freq: f64,
    max_freq: f64,
) -> f64 {
    let response =
        compute_drivers_combined_response(data, gains, &[], Some(delays), sample_rate);
    let (total, count) = data
        .freq_grid
        .iter()
        .zip(response.iter())
        .filter(|&(&f, _)| f >= min_freq && f <= max_freq)
        .fold((0.0, 0_usize), |(total, count), (_, &level)| {
            (total + level, count + 1)
        });
    let mean = if count > 0 { total / count as f64 } else { 0.0 };
    let normalized = &response - mean;
    flat_loss(&data.freq_grid, &normalized, min_freq, max_freq)
}
/// Population standard deviation of `deviation` over points with `fmin <= freq <= fmax`.
///
/// Returns 0.0 when no point lies in range.
///
/// # Panics
/// Panics if `freq` and `deviation` have different lengths.
pub fn calculate_standard_deviation_in_range(
    freq: &Array1<f64>,
    deviation: &Array1<f64>,
    fmin: f64,
    fmax: f64,
) -> f64 {
    assert_eq!(
        freq.len(),
        deviation.len(),
        "freq and deviation must have same length"
    );
    // Re-creatable view of the in-range values, so no temporary Vec is needed.
    let selected = || {
        freq.iter()
            .zip(deviation.iter())
            .filter(|&(&f, _)| f >= fmin && f <= fmax)
            .map(|(_, &d)| d)
    };
    let count = selected().count();
    if count == 0 {
        return 0.0;
    }
    let n = count as f64;
    let mean = selected().sum::<f64>() / n;
    let variance = selected().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    variance.sqrt()
}
/// Absolute per-octave regression slope of `deviation` over `[fmin, fmax]`.
/// Returns 0.0 when the slope cannot be estimated.
fn calculate_absolute_slope_in_range(
    freq: &Array1<f64>,
    deviation: &Array1<f64>,
    fmin: f64,
    fmax: f64,
) -> f64 {
    regression_slope_per_octave_in_range(freq, deviation, fmin, fmax).map_or(0.0, f64::abs)
}
pub fn headphone_loss(curve: &Curve) -> f64 {
let freq = &curve.freq;
let deviation = &curve.spl;
const FMIN: f64 = 50.0;
const FMAX: f64 = 10000.0;
let sd = calculate_standard_deviation_in_range(freq, deviation, FMIN, FMAX);
let as_value = calculate_absolute_slope_in_range(freq, deviation, FMIN, FMAX);
114.49 - (12.62 * sd) - (15.52 * as_value)
}
/// Headphone preference score of `response` relative to `target`.
///
/// Steps:
/// 1. Normalise and interpolate both curves onto a 120-point log grid from 20 Hz to
///    20 kHz.
/// 2. Take the deviation `target - response`.
/// 3. Optionally smooth it with `read::smooth_one_over_n_octave`.
/// 4. Score it with `headphone_loss`.
pub fn headphone_loss_with_target(
    data: &HeadphoneLossData,
    response: &Curve,
    target: &Curve,
) -> f64 {
    let grid = read::create_log_frequency_grid(10 * 12, 20.0, 20000.0);
    let measured = read::normalize_and_interpolate_response(&grid, response);
    let wanted = read::normalize_and_interpolate_response(&grid, target);
    let deviation = Curve {
        spl: &wanted.spl - &measured.spl,
        freq: grid,
        phase: None,
    };
    if data.smooth {
        headphone_loss(&read::smooth_one_over_n_octave(&deviation, data.smooth_n))
    } else {
        headphone_loss(&deviation)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;
    use ndarray::array;
    use std::collections::HashMap;

    /// Score of a perfect headphone correction (intercept of `headphone_loss`).
    const PERFECT_SCORE: f64 = 114.49;

    /// Builds a CEA2034 spin map with the four curves `SpeakerLossData::try_new`
    /// requires, all sharing the frequency axis `freq`.
    fn make_spin(
        freq: &Array1<f64>,
        on: &Array1<f64>,
        lw: &Array1<f64>,
        sp: &Array1<f64>,
        pir: &Array1<f64>,
    ) -> HashMap<String, Curve> {
        let mut spin = HashMap::new();
        for (name, spl) in [
            ("On Axis", on),
            ("Listening Window", lw),
            ("Sound Power", sp),
            ("Estimated In-Room Response", pir),
        ] {
            spin.insert(
                name.to_string(),
                Curve {
                    freq: freq.clone(),
                    spl: spl.clone(),
                    phase: None,
                },
            );
        }
        spin
    }

    #[test]
    fn score_loss_matches_score_when_peq_zero() {
        let freq = Array1::from(vec![100.0, 1000.0]);
        let on = Array1::from(vec![80.0_f64, 85.0_f64]);
        let lw = Array1::from(vec![81.0_f64, 84.0_f64]);
        let sp = Array1::from(vec![78.0_f64, 82.0_f64]);
        let pir = Array1::from(vec![80.5_f64, 84.0_f64]);
        let spin = make_spin(&freq, &on, &lw, &sp, &pir);
        let sd = SpeakerLossData::try_new(&spin).expect("test spin data should be valid");
        let zero = Array1::zeros(freq.len());
        let intervals = super::score::octave_intervals(2, &freq);
        let expected = super::score::score(&freq, &intervals, &on, &lw, &sp, &pir);
        let got = speaker_score_loss(&sd, &freq, &zero);
        // NaN never compares equal, so "both NaN" is accepted explicitly as a match.
        assert!(
            (got.is_nan() && expected.pref_score.is_nan())
                || (got - expected.pref_score).abs() < 1e-12,
            "got {}, expected {}",
            got,
            expected.pref_score
        );
    }

    #[test]
    fn regression_slope_per_octave_linear_log_relation_full_range() {
        let freq = Array1::from(vec![100.0, 200.0, 400.0, 800.0]);
        let y = freq.mapv(|f: f64| 3.0 * f.log2() + 1.0);
        let slope = regression_slope_per_octave_in_range(&freq, &y, 100.0, 800.0).unwrap();
        assert!((slope - 3.0).abs() < 1e-12);
    }

    #[test]
    fn regression_slope_per_octave_sub_range() {
        let freq = Array1::from(vec![100.0, 200.0, 400.0, 800.0]);
        let y = freq.mapv(|f: f64| -2.5 * f.log2() + 4.0);
        let slope = regression_slope_per_octave_in_range(&freq, &y, 200.0, 800.0).unwrap();
        assert!((slope + 2.5).abs() < 1e-12);
    }

    #[test]
    fn regression_slope_rejects_inverted_or_sparse_range() {
        let freq = Array1::from(vec![100.0, 200.0, 400.0]);
        let y = Array1::from(vec![1.0, 2.0, 3.0]);
        assert!(regression_slope_per_octave_in_range(&freq, &y, 400.0, 100.0).is_none());
        // Only 200 Hz falls in range: a single point has no slope.
        assert!(regression_slope_per_octave_in_range(&freq, &y, 150.0, 250.0).is_none());
    }

    #[test]
    fn test_calculate_standard_deviation_in_range() {
        let freq = Array1::from(vec![50.0, 100.0, 1000.0, 5000.0, 10000.0]);
        let deviation = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let sd = calculate_standard_deviation_in_range(&freq, &deviation, 50.0, 10000.0);
        let expected_sd = 2.0_f64.sqrt();
        assert!(
            (sd - expected_sd).abs() < 1e-12,
            "SD calculation incorrect: got {}, expected {}",
            sd,
            expected_sd
        );
    }

    #[test]
    fn test_calculate_standard_deviation_filtered_range() {
        // Only 100, 1000 and 5000 Hz fall inside 50 Hz–10 kHz.
        let freq = Array1::from(vec![20.0, 100.0, 1000.0, 5000.0, 15000.0]);
        let deviation = Array1::from(vec![10.0, 2.0, 4.0, 6.0, 20.0]);
        let sd = calculate_standard_deviation_in_range(&freq, &deviation, 50.0, 10000.0);
        let expected_sd = (8.0_f64 / 3.0_f64).sqrt();
        assert!(
            (sd - expected_sd).abs() < 1e-12,
            "SD calculation with filtering incorrect: got {}, expected {}",
            sd,
            expected_sd
        );
    }

    #[test]
    fn test_calculate_absolute_slope_in_range() {
        let freq = Array1::from(vec![
            50.0, 100.0, 200.0, 400.0, 800.0, 1600.0, 3200.0, 6400.0, 10000.0,
        ]);
        let deviation = freq.mapv(|f: f64| 2.0 * f.log2());
        let as_value = calculate_absolute_slope_in_range(&freq, &deviation, 50.0, 10000.0);
        assert!(
            (as_value - 2.0).abs() < 1e-12,
            "AS calculation incorrect: got {}, expected 2.0",
            as_value
        );
    }

    #[test]
    fn test_calculate_absolute_slope_negative() {
        let freq = Array1::from(vec![
            50.0, 100.0, 200.0, 400.0, 800.0, 1600.0, 3200.0, 6400.0, 10000.0,
        ]);
        let deviation = freq.mapv(|f: f64| -3.0 * f.log2());
        let as_value = calculate_absolute_slope_in_range(&freq, &deviation, 50.0, 10000.0);
        assert!(
            (as_value - 3.0).abs() < 1e-12,
            "AS calculation with negative slope incorrect: got {}, expected 3.0",
            as_value
        );
    }

    #[test]
    fn test_headphone_loss_perfect_harman_deviation() {
        let freq = Array1::from(vec![50.0, 100.0, 1000.0, 5000.0, 10000.0]);
        let curve = Curve {
            freq: freq.clone(),
            spl: Array1::zeros(5),
            phase: None,
        };
        let score = headphone_loss(&curve);
        assert!(
            (score - PERFECT_SCORE).abs() < 1e-12,
            "Perfect Harman score incorrect: got {}, expected {}",
            score,
            PERFECT_SCORE
        );
    }

    #[test]
    fn test_headphone_loss_with_deviation() {
        // A constant offset has zero SD and zero slope, so it scores as perfect.
        let freq = Array1::from(vec![50.0, 100.0, 1000.0, 5000.0, 10000.0]);
        let curve = Curve {
            freq: freq.clone(),
            spl: Array1::from(vec![1.0, 1.0, 1.0, 1.0, 1.0]),
            phase: None,
        };
        let score = headphone_loss(&curve);
        assert!(
            (score - PERFECT_SCORE).abs() < 1e-10,
            "Constant deviation score incorrect: got {}, expected {}",
            score,
            PERFECT_SCORE
        );
    }

    #[test]
    fn test_headphone_loss_with_slope() {
        let freq = Array1::from(vec![
            50.0, 100.0, 200.0, 400.0, 800.0, 1600.0, 3200.0, 6400.0, 10000.0,
        ]);
        let deviation = freq.mapv(|f: f64| 1.0 * f.log2());
        let curve = Curve {
            freq: freq.clone(),
            spl: deviation,
            phase: None,
        };
        let score = headphone_loss(&curve);
        assert!(
            score < PERFECT_SCORE,
            "Sloped deviation should have lower preference than perfect: got {}",
            score
        );
        assert!(
            score > 50.0,
            "A 1 dB/octave tilt should not collapse the score: got {}",
            score
        );
    }

    #[test]
    fn test_headphone_loss_with_target() {
        let freq = Array1::logspace(10.0, 1.301, 4.301, 100);
        let response = Array1::from_elem(100, 5.0);
        let target = Array1::from_elem(100, 5.0);
        let response_curve = Curve {
            freq: freq.clone(),
            spl: response,
            phase: None,
        };
        let target_curve = Curve {
            freq: freq.clone(),
            spl: target,
            phase: None,
        };
        let data = HeadphoneLossData::new(false, 2);
        let score = headphone_loss_with_target(&data, &response_curve, &target_curve);
        assert!(
            (score - PERFECT_SCORE).abs() < 1e-10,
            "Perfect target match score incorrect: got {}, expected {}",
            score,
            PERFECT_SCORE
        );
    }

    #[test]
    fn mixed_loss_finite_with_zero_peq() {
        let freq = Array1::from(vec![
            100.0, 200.0, 400.0, 800.0, 1600.0, 3200.0, 6400.0, 10000.0,
        ]);
        let zeros = Array1::zeros(freq.len());
        let spin = make_spin(&freq, &zeros, &zeros, &zeros, &zeros);
        let sd = SpeakerLossData::try_new(&spin).expect("test spin data should be valid");
        let peq = Array1::zeros(freq.len());
        let v = mixed_loss(&sd, &freq, &peq);
        assert!(v.is_finite(), "mixed_loss should be finite, got {}", v);
    }

    #[test]
    fn test_weighted_mse_basic() {
        let freqs = array![1000.0, 2000.0, 4000.0, 8000.0];
        let err = array![1.0, 1.0, 1.0, 1.0];
        let v = weighted_mse(&freqs, &err, 100.0, 10000.0);
        assert!((v - (1.0 + 1.0 / 3.0)).abs() < 1e-12, "got {}", v);
    }

    #[test]
    fn test_weighted_mse_empty_upper_segment() {
        let freqs = array![100.0, 200.0, 500.0];
        let err = array![2.0, 2.0, 2.0];
        let v = weighted_mse(&freqs, &err, 50.0, 10000.0);
        assert!((v - 2.0).abs() < 1e-12, "got {}", v);
    }

    #[test]
    fn test_weighted_mse_scaling() {
        let freqs = array![1000.0, 1500.0, 4000.0, 6000.0];
        let err = array![2.0, 2.0, 3.0, 3.0];
        let v = weighted_mse(&freqs, &err, 500.0, 10000.0);
        let expected = 2.0 + 3.0 / 3.0;
        assert!((v - expected).abs() < 1e-12, "got {}", v);
    }

    #[test]
    fn test_weighted_mse_frequency_filtering() {
        let freqs = array![100.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0, 16000.0];
        let err = array![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0];
        let v = weighted_mse(&freqs, &err, 1000.0, 4000.0);
        let expected = 2.0 + 3.0 / 3.0;
        assert!(
            (v - expected).abs() < 1e-12,
            "got {} expected {}",
            v,
            expected
        );
    }

    #[test]
    fn test_weighted_mse_no_frequencies_in_range() {
        let freqs = array![100.0, 200.0, 500.0];
        let err = array![2.0, 3.0, 1.0];
        let v = weighted_mse(&freqs, &err, 1000.0, 5000.0);
        assert!(
            v.is_infinite(),
            "Should return INFINITY when no frequencies in range"
        );
    }

    #[test]
    #[should_panic]
    fn test_weighted_mse_panics_on_length_mismatch() {
        let freqs = array![1000.0, 2000.0];
        let err = array![1.0];
        let _ = weighted_mse(&freqs, &err, 100.0, 10000.0);
    }

    #[test]
    fn test_weighted_mse_partial_range() {
        let freqs = array![100.0, 1000.0, 2000.0, 4000.0, 8000.0];
        let err = array![1.0, 2.0, 2.0, 3.0, 4.0];
        let v = weighted_mse(&freqs, &err, 500.0, 3000.0);
        let expected = 2.0;
        assert!(
            (v - expected).abs() < 1e-12,
            "got {} expected {}",
            v,
            expected
        );
    }

    #[test]
    fn test_flat_loss_frequency_filtering() {
        let freqs = array![100.0, 1000.0, 2000.0, 4000.0, 8000.0];
        let err = array![1.0, 2.0, 2.0, 3.0, 4.0];
        let v1 = flat_loss(&freqs, &err, 1000.0, 4000.0);
        let v2 = weighted_mse(&freqs, &err, 1000.0, 4000.0);
        assert_eq!(v1, v2, "flat_loss should equal weighted_mse");
    }

    #[test]
    fn test_frequency_filtering_boundary_conditions() {
        let freqs = array![1000.0, 2000.0, 3000.0];
        let err = array![1.0, 1.0, 1.0];
        // Range bounds are inclusive.
        let v = weighted_mse(&freqs, &err, 1000.0, 3000.0);
        let expected = 1.0 + 1.0 / 3.0;
        assert!(
            (v - expected).abs() < 1e-12,
            "got {} expected {}",
            v,
            expected
        );
        let v2 = weighted_mse(&freqs, &err, 1001.0, 2999.0);
        let expected2 = 1.0;
        assert!(
            (v2 - expected2).abs() < 1e-12,
            "got {} expected {}",
            v2,
            expected2
        );
    }

    #[test]
    fn test_headphone_loss_perfect_correction() {
        let freq = Array1::logspace(10.0, 1.699, 4.0, 100);
        let curve = Curve {
            freq: freq.clone(),
            spl: Array1::zeros(100),
            phase: None,
        };
        let score = headphone_loss(&curve);
        assert!(
            (score - PERFECT_SCORE).abs() < 1e-10,
            "Perfect correction score incorrect: got {}, expected {}",
            score,
            PERFECT_SCORE
        );
    }

    #[test]
    fn test_headphone_loss_sign_independence() {
        let freq = Array1::logspace(10.0, 1.699, 4.0, 100);
        let deviation_positive = freq.mapv(|f: f64| 0.5 * f.log2() + 2.0);
        let deviation_negative = -&deviation_positive;
        let curve_pos = Curve {
            freq: freq.clone(),
            spl: deviation_positive,
            phase: None,
        };
        let curve_neg = Curve {
            freq: freq.clone(),
            spl: deviation_negative,
            phase: None,
        };
        let score_pos = headphone_loss(&curve_pos);
        let score_neg = headphone_loss(&curve_neg);
        assert!(
            (score_pos - score_neg).abs() < 1e-10,
            "Sign independence violated: pos={}, neg={}",
            score_pos,
            score_neg
        );
    }

    #[test]
    fn test_headphone_loss_worse_than_perfect() {
        let freq = Array1::logspace(10.0, 1.699, 4.0, 100);
        let perfect_curve = Curve {
            freq: freq.clone(),
            spl: Array1::zeros(100),
            phase: None,
        };
        let imperfect_curve = Curve {
            freq: freq.clone(),
            spl: Array1::from_elem(100, 3.0),
            phase: None,
        };
        let perfect_score = headphone_loss(&perfect_curve);
        let imperfect_score = headphone_loss(&imperfect_curve);
        assert!(
            imperfect_score < perfect_score,
            "Imperfect correction should score lower: perfect={}, imperfect={}",
            perfect_score,
            imperfect_score
        );
        assert!(
            (perfect_score - PERFECT_SCORE).abs() < 1e-10,
            "Perfect score should be 114.49, got {}",
            perfect_score
        );
    }

    #[test]
    fn test_weighted_mse_treble_only() {
        let freqs = Array1::from_vec(vec![4000.0, 8000.0, 16000.0]);
        let error = Array1::from_vec(vec![2.0, 2.0, 2.0]);
        let result = weighted_mse(&freqs, &error, 4000.0, 16000.0);
        assert!(
            (result - 2.0).abs() < 1e-10,
            "treble-only loss should be full RMS, got {}",
            result
        );
    }

    #[test]
    fn test_weighted_mse_bass_only() {
        let freqs = Array1::from_vec(vec![100.0, 500.0, 2000.0]);
        let error = Array1::from_vec(vec![3.0, 3.0, 3.0]);
        let result = weighted_mse(&freqs, &error, 100.0, 2000.0);
        assert!(
            (result - 3.0).abs() < 1e-10,
            "bass-only loss should be full RMS, got {}",
            result
        );
    }

    #[test]
    fn test_weighted_mse_both_bands() {
        let freqs = Array1::from_vec(vec![100.0, 1000.0, 5000.0, 10000.0]);
        let error = Array1::from_vec(vec![3.0, 3.0, 6.0, 6.0]);
        let result = weighted_mse(&freqs, &error, 100.0, 10000.0);
        assert!(
            (result - 5.0).abs() < 1e-10,
            "two-band loss incorrect, got {}",
            result
        );
    }

    #[test]
    fn test_regression_slope_identical_freqs() {
        let freq = Array1::from_vec(vec![1000.0, 1000.0, 1000.0]);
        let y = Array1::from_vec(vec![80.0, 85.0, 90.0]);
        let result = regression_slope_per_octave_in_range(&freq, &y, 999.0, 1001.0);
        assert!(
            result.is_none(),
            "identical frequencies should return None, got {:?}",
            result
        );
    }

    #[test]
    fn test_bass_asymmetry_penalizes_peaks_heavily() {
        let config = AsymmetricLossConfig::default();
        let freqs_bass = Array1::from_vec(vec![80.0]);
        let error_bass = Array1::from_vec(vec![10.0]);
        let loss_bass = weighted_mse_asymmetric_with_split(
            &freqs_bass,
            &error_bass,
            20.0,
            20000.0,
            &config,
            3000.0,
        );
        let freqs_mid = Array1::from_vec(vec![2000.0]);
        let error_mid = Array1::from_vec(vec![10.0]);
        let loss_mid = weighted_mse_asymmetric_with_split(
            &freqs_mid,
            &error_mid,
            20.0,
            20000.0,
            &config,
            3000.0,
        );
        let ratio = loss_bass / loss_mid;
        let expected_ratio = (5.0_f64 / 2.0).sqrt();
        assert!(
            (ratio - expected_ratio).abs() < 0.1,
            "bass peak penalty ratio should be ~{:.2}x (sqrt(5/2)), got {:.2}",
            expected_ratio,
            ratio
        );
    }

    #[test]
    fn test_bass_dips_nearly_ignored() {
        let config = AsymmetricLossConfig::default();
        let freqs = Array1::from_vec(vec![80.0]);
        let error_dip = Array1::from_vec(vec![-10.0]);
        let loss_dip = weighted_mse_asymmetric_with_split(
            &freqs,
            &error_dip,
            20.0,
            20000.0,
            &config,
            3000.0,
        );
        let error_peak = Array1::from_vec(vec![10.0]);
        let loss_peak = weighted_mse_asymmetric_with_split(
            &freqs,
            &error_peak,
            20.0,
            20000.0,
            &config,
            3000.0,
        );
        let ratio = loss_dip / loss_peak;
        assert!(
            ratio < 0.25,
            "bass dip penalty should be much smaller than bass peak, ratio={:.4}",
            ratio
        );
    }

    #[test]
    fn crossover_type_display_roundtrips_through_from_str() {
        for ct in [
            CrossoverType::Butterworth2,
            CrossoverType::LinkwitzRiley2,
            CrossoverType::LinkwitzRiley4,
            CrossoverType::LinkwitzRiley8,
            CrossoverType::None,
        ] {
            let parsed: CrossoverType = ct
                .to_string()
                .parse()
                .expect("Display output should parse back");
            assert_eq!(parsed, ct);
        }
        assert!("not-a-crossover".parse::<CrossoverType>().is_err());
    }
}