use self::de::optimize_filters_autoeq_with_callback;
use super::cli::PeqModel;
use super::constraints::{viol_ceiling_from_spl, viol_min_gain_from_xs, viol_spacing_from_xs};
use super::loss::{
DriversLossData, HeadphoneLossData, LossType, SpeakerLossData, drivers_flat_loss, flat_loss,
flat_loss_asymmetric, headphone_loss, speaker_score_loss,
};
use super::x2peq::x2spl;
use crate::Curve;
use ndarray::Array1;
pub mod backend;
pub mod callback;
pub mod cobyla;
pub mod constraints_install;
pub mod de;
pub mod isres;
pub mod mh;
pub mod params;
pub mod pareto;
pub mod registry;
pub mod setup;
pub use backend::{AlgorithmType, ConstraintCapabilities, FilterOptimizer};
/// Descriptive metadata for one optimization backend.
///
/// Instances are built from a [`FilterOptimizer`] by `AlgorithmInfo::from_backend`, which copies
/// the values the backend reports about itself.
#[derive(Debug, Clone)]
pub struct AlgorithmInfo {
/// Value returned by the backend's `name()`.
pub name: &'static str,
/// Value returned by the backend's `library()`.
pub library: &'static str,
/// Value returned by the backend's `algorithm_type()`.
pub algorithm_type: AlgorithmType,
/// Copy of the `linear` flag of the backend's capabilities.
pub supports_linear_constraints: bool,
/// Copy of the `nonlinear_ineq` flag of the backend's capabilities.
pub supports_nonlinear_constraints: bool,
}
impl AlgorithmInfo {
    /// Snapshots the self-reported metadata of `backend` into an owned [`AlgorithmInfo`].
    ///
    /// The capabilities are queried first, then name, library and algorithm type.
    fn from_backend(backend: &dyn FilterOptimizer) -> Self {
        let capabilities = backend.capabilities();
        let name = backend.name();
        let library = backend.library();
        let algorithm_type = backend.algorithm_type();

        Self {
            name,
            library,
            algorithm_type,
            supports_linear_constraints: capabilities.linear,
            supports_nonlinear_constraints: capabilities.nonlinear_ineq,
        }
    }
}
/// Returns one [`AlgorithmInfo`] per backend yielded by `registry::all_algorithms()`,
/// in the order the registry lists them.
pub fn get_all_algorithms() -> Vec<AlgorithmInfo> {
    let mut infos = Vec::new();
    for backend in registry::all_algorithms().iter() {
        infos.push(AlgorithmInfo::from_backend(backend.as_ref()));
    }
    infos
}
/// Looks up `name` through `registry::resolve` and describes the matching backend.
///
/// Returns `None` when the registry does not resolve `name`.
pub fn find_algorithm_info(name: &str) -> Option<AlgorithmInfo> {
    let backend = registry::resolve(name)?;
    Some(AlgorithmInfo::from_backend(backend.as_ref()))
}
/// Everything the objective function needs to score one parameter vector.
///
/// The field notes below describe how each field is consumed by the functions in this file;
/// fields marked "not read in this file" are presumably used by the backends or callers.
#[derive(Debug, Clone)]
pub struct ObjectiveData {
/// Frequency grid handed to `x2spl` and to the loss functions.
pub freqs: Array1<f64>,
/// Target curve, indexed like `freqs`; used when rebuilding the response scored by EPA.
pub target: Array1<f64>,
/// Curve subtracted from the PEQ response to form the error that the flat losses score.
pub deviation: Array1<f64>,
/// Passed to `x2spl` and the driver losses (presumably the sample rate in Hz; confirm).
pub srate: f64,
// NOTE(review): this field is read by `compute_fitness_penalties_ref`, so the
// `allow(dead_code)` below looks stale; confirm before removing it.
/// Minimum spacing handed to `viol_spacing_from_xs` (presumably in octaves; confirm).
#[allow(dead_code)]
pub min_spacing_oct: f64,
/// User weight for the spacing penalty; `configure_penalties` clamps it at zero and scales it.
pub spacing_weight: f64,
/// Limit handed to `viol_ceiling_from_spl` for the ceiling penalty.
pub max_db: f64,
/// Limit handed to `viol_min_gain_from_xs`; the min-gain penalty only applies when `> 0`.
pub min_db: f64,
/// Lower bound of the frequency range passed to the loss functions.
pub min_freq: f64,
/// Upper bound of the frequency range passed to the loss functions.
pub max_freq: f64,
/// Filter parameterization used to decode the parameter vector.
pub peq_model: PeqModel,
/// Selects which loss branch is evaluated.
pub loss_type: LossType,
/// Required by `LossType::SpeakerScore`; when missing that loss evaluates to infinity.
pub speaker_score_data: Option<SpeakerLossData>,
/// Must be `Some` for `LossType::HeadphoneScore`; its contents are not read in this file.
pub headphone_score_data: Option<HeadphoneLossData>,
/// Not read in this file.
pub input_curve: Option<Curve>,
/// Required by `LossType::DriversFlat` and `LossType::MultiSubFlat`.
pub drivers_data: Option<DriversLossData>,
/// When set, `DriversFlat` uses these crossover frequencies instead of decoding them from `x`.
pub fixed_crossover_freqs: Option<Vec<f64>>,
/// Multiplies the squared ceiling violation; `0` disables the penalty.
pub penalty_w_ceiling: f64,
/// Multiplies the squared spacing violation; `0` disables the penalty.
pub penalty_w_spacing: f64,
/// Multiplies the squared min-gain violation; `0` disables the penalty.
pub penalty_w_mingain: f64,
/// Not read in this file.
pub integrality: Option<Vec<bool>>,
/// When set, `compute_base_fitness` combines the per-measurement losses described here
/// instead of evaluating this object's own loss.
pub multi_objective: Option<MultiObjectiveData>,
/// When true, the flat losses score the error after `smooth_one_over_n_octave`.
pub smooth: bool,
/// Smoothing parameter passed to `smooth_one_over_n_octave`.
pub smooth_n: usize,
/// `(frequency, dB)` points limiting positive filter gains (see `clamp_gains_to_envelope`).
pub max_boost_envelope: Option<Vec<(f64, f64)>>,
/// `(frequency, dB)` points limiting negative filter gains (see `clamp_cuts_to_envelope`).
pub min_cut_envelope: Option<Vec<(f64, f64)>>,
/// EPA scoring configuration; the default configuration is used when `None`.
pub epa_config: Option<crate::loss::epa::score::EpaConfig>,
/// Not read in this file; the meaning of the tuple components is not visible here.
pub detected_problems: Vec<(f64, f64, f64)>,
/// Optional mask forwarded to `flat_loss_asymmetric`.
pub null_suppression: Option<Array1<f64>>,
}
/// Per-measurement objectives plus the rule used to merge their losses into one scalar.
#[derive(Debug, Clone)]
pub struct MultiObjectiveData {
/// One objective per measurement; each is scored independently with the same parameters.
pub objectives: Vec<ObjectiveData>,
/// How the individual losses are combined (average, weighted sum, max, ...).
pub strategy: crate::roomeq::MultiMeasurementStrategy,
/// Weights paired positionally with the losses by the `WeightedSum` strategy.
pub weights: Vec<f64>,
/// Factor applied to the variance of the losses by the `VariancePenalized` strategy.
pub variance_lambda: f64,
}
/// Selects the set of penalty weights applied on top of the base loss.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyMode {
    /// All penalty weights are zero.
    Disabled,
    /// The larger set of penalty weights.
    Standard,
    /// The smaller set of penalty weights.
    Pso,
}

impl PenaltyMode {
    /// Weight applied to the squared ceiling violation in this mode.
    pub const fn ceiling_weight(&self) -> f64 {
        match *self {
            Self::Disabled => 0.0,
            Self::Standard => 10_000.0,
            Self::Pso => 500.0,
        }
    }

    /// Weight applied to the squared minimum-gain violation in this mode.
    pub const fn mingain_weight(&self) -> f64 {
        match *self {
            Self::Disabled => 0.0,
            Self::Standard => 1_000.0,
            Self::Pso => 50.0,
        }
    }
}
impl ObjectiveData {
    /// Sets the three penalty weights according to `mode`.
    ///
    /// Ceiling and min-gain weights come straight from [`PenaltyMode`]. The spacing weight is
    /// the user-supplied `spacing_weight` (negative values are treated as zero) multiplied by a
    /// mode-dependent scale.
    pub fn configure_penalties(&mut self, mode: PenaltyMode) {
        let spacing_scale = match mode {
            PenaltyMode::Disabled => 0.0,
            PenaltyMode::Standard => 1e3,
            PenaltyMode::Pso => 5e2,
        };
        let user_spacing_weight = self.spacing_weight.max(0.0);

        self.penalty_w_ceiling = mode.ceiling_weight();
        self.penalty_w_mingain = mode.mingain_weight();
        self.penalty_w_spacing = user_spacing_weight * spacing_scale;
    }
}
/// Scores `x` against every objective in `mo` and merges the losses per `mo.strategy`.
///
/// * `Average`: arithmetic mean of the losses.
/// * `WeightedSum`: sum of `loss * weight`, pairing losses and weights by position.
/// * `Minimax`: the largest loss.
/// * `VariancePenalized`: mean plus `variance_lambda` times the population variance.
///
/// # Panics
///
/// Panics for `SpatialRobustness`, which is not expected to reach this code path.
fn compute_multi_objective_fitness(x: &[f64], mo: &MultiObjectiveData) -> f64 {
    use crate::roomeq::MultiMeasurementStrategy as Strategy;

    let mut losses: Vec<f64> = Vec::with_capacity(mo.objectives.len());
    for objective in &mo.objectives {
        losses.push(compute_base_fitness_single(x, objective));
    }
    let count = losses.len() as f64;

    match mo.strategy {
        Strategy::Average => losses.iter().sum::<f64>() / count,
        Strategy::WeightedSum => losses
            .iter()
            .zip(mo.weights.iter())
            .map(|(loss, weight)| loss * weight)
            .sum(),
        Strategy::Minimax => losses.iter().copied().fold(f64::NEG_INFINITY, f64::max),
        Strategy::VariancePenalized => {
            let mean = losses.iter().sum::<f64>() / count;
            let variance = losses
                .iter()
                .map(|loss| (loss - mean).powi(2))
                .sum::<f64>()
                / count;
            mean + mo.variance_lambda * variance
        }
        Strategy::SpatialRobustness => {
            unreachable!("SpatialRobustness strategy should not use multi-objective loss path")
        }
    }
}
/// Returns a copy of `x` in which every positive filter gain is limited by `envelope`.
///
/// Each filter's frequency parameter is stored as `log10(Hz)`; the limit at that frequency is
/// looked up with `interpolate_boost_envelope`. The gain is taken to be the last parameter of
/// each filter's slot. Filters with non-positive gain are left untouched.
pub fn clamp_gains_to_envelope(
    x: &[f64],
    envelope: &[(f64, f64)],
    peq_model: PeqModel,
) -> Vec<f64> {
    use crate::param_utils;

    let mut limited = x.to_vec();
    let filter_count = param_utils::num_filters(x, peq_model);
    for filter_idx in 0..filter_count {
        let filter = param_utils::get_filter_params(x, filter_idx, peq_model);
        let hz = 10f64.powf(filter.freq);
        if filter.gain > 0.0 {
            let limit = interpolate_boost_envelope(envelope, hz);
            if filter.gain > limit {
                let stride = param_utils::params_per_filter(peq_model);
                limited[filter_idx * stride + (stride - 1)] = limit;
            }
        }
    }
    limited
}
/// Evaluates a piecewise envelope of `(frequency, dB)` points at `freq_hz`.
///
/// * An empty envelope yields `f64::INFINITY`.
/// * Frequencies at or outside the first/last point take that point's dB value.
/// * In between, the dB value is interpolated linearly in `ln(frequency)` on the first
///   segment whose end points bracket `freq_hz`.
/// * If no segment brackets `freq_hz`, the last point's dB value is returned.
fn interpolate_boost_envelope(envelope: &[(f64, f64)], freq_hz: f64) -> f64 {
    let (Some(&(first_hz, first_db)), Some(&(last_hz, last_db))) =
        (envelope.first(), envelope.last())
    else {
        return f64::INFINITY;
    };
    if freq_hz <= first_hz {
        return first_db;
    }
    if freq_hz >= last_hz {
        return last_db;
    }

    envelope
        .windows(2)
        .find(|segment| freq_hz >= segment[0].0 && freq_hz <= segment[1].0)
        .map_or(last_db, |segment| {
            let (lo_hz, lo_db) = segment[0];
            let (hi_hz, hi_db) = segment[1];
            let fraction = (freq_hz.ln() - lo_hz.ln()) / (hi_hz.ln() - lo_hz.ln());
            lo_db + fraction * (hi_db - lo_db)
        })
}
/// Returns a copy of `x` in which every negative filter gain is limited by `envelope`.
///
/// Each filter's frequency parameter is stored as `log10(Hz)`; the deepest allowed cut at that
/// frequency is interpolated from the `(frequency, dB)` points of `envelope`. A gain below
/// that value is raised to it. The gain is taken to be the last parameter of each filter's
/// slot. Filters with non-negative gain are left untouched.
///
/// An empty envelope imposes no limit and `x` is returned unchanged.
pub fn clamp_cuts_to_envelope(
    x: &[f64],
    envelope: &[(f64, f64)],
    peq_model: PeqModel,
) -> Vec<f64> {
    use crate::param_utils;

    let mut clamped = x.to_vec();
    // `interpolate_boost_envelope` reports "no limit" for an empty envelope as +INFINITY,
    // which is only meaningful for boosts. Used as a cut floor it would overwrite every
    // negative gain with +INFINITY, so bail out before consulting it.
    if envelope.is_empty() {
        return clamped;
    }

    let num_filters = param_utils::num_filters(x, peq_model);
    for i in 0..num_filters {
        let params = param_utils::get_filter_params(x, i, peq_model);
        let freq_hz = 10f64.powf(params.freq);
        if params.gain < 0.0 {
            let max_cut = interpolate_boost_envelope(envelope, freq_hz);
            if params.gain < max_cut {
                let ppf = param_utils::params_per_filter(peq_model);
                let gain_idx = i * ppf + (ppf - 1);
                clamped[gain_idx] = max_cut;
            }
        }
    }
    clamped
}
fn compute_base_fitness_single(x: &[f64], data: &ObjectiveData) -> f64 {
let clamped_boost;
let clamped_cut;
let x = {
let skip = matches!(
data.loss_type,
LossType::DriversFlat | LossType::MultiSubFlat
);
let x = if !skip && let Some(ref env) = data.max_boost_envelope {
clamped_boost = clamp_gains_to_envelope(x, env, data.peq_model);
&clamped_boost
} else {
x
};
if !skip && let Some(ref env) = data.min_cut_envelope {
clamped_cut = clamp_cuts_to_envelope(x, env, data.peq_model);
&clamped_cut
} else {
x
}
};
match data.loss_type {
LossType::DriversFlat => {
if let Some(ref drivers_data) = data.drivers_data {
let n_drivers = drivers_data.drivers.len();
let gains = &x[0..n_drivers];
let delays = &x[n_drivers..2 * n_drivers];
let xover_freqs: Vec<f64> = if let Some(ref fixed) = data.fixed_crossover_freqs {
fixed.clone()
} else {
let xover_freqs_log10 = &x[2 * n_drivers..];
xover_freqs_log10
.iter()
.map(|f| 10.0_f64.powf(*f))
.collect()
};
drivers_flat_loss(
drivers_data,
gains,
&xover_freqs,
Some(delays),
data.srate,
data.min_freq,
data.max_freq,
)
} else {
log::error!("drivers-flat loss requested but driver data is missing");
f64::INFINITY
}
}
LossType::MultiSubFlat => {
if let Some(ref drivers_data) = data.drivers_data {
let n_drivers = drivers_data.drivers.len();
let gains = &x[0..n_drivers];
let delays = &x[n_drivers..2 * n_drivers];
crate::loss::multisub_flat_loss(
drivers_data,
gains,
delays,
data.srate,
data.min_freq,
data.max_freq,
)
} else {
log::error!("multi-sub-flat loss requested but driver data is missing");
f64::INFINITY
}
}
LossType::HeadphoneFlat | LossType::SpeakerFlat => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
if data.smooth {
let curve = Curve {
freq: data.freqs.clone(),
spl: error,
phase: None,
..Default::default()
};
let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
flat_loss(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
} else {
flat_loss(&data.freqs, &error, data.min_freq, data.max_freq)
}
}
LossType::SpeakerFlatAsymmetric => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
let null_mask = data.null_suppression.as_ref();
if data.smooth {
let curve = Curve {
freq: data.freqs.clone(),
spl: error,
phase: None,
..Default::default()
};
let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
flat_loss_asymmetric(
&data.freqs,
&smoothed.spl,
data.min_freq,
data.max_freq,
null_mask,
)
} else {
flat_loss_asymmetric(&data.freqs, &error, data.min_freq, data.max_freq, null_mask)
}
}
LossType::SpeakerScore => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
if let Some(ref sd) = data.speaker_score_data {
let error = &peq_spl - &data.deviation;
let s = speaker_score_loss(sd, &data.freqs, &peq_spl);
let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq) / 3.0;
100.0 - s + p
} else {
log::error!("speaker score loss requested but score data is missing");
f64::INFINITY
}
}
LossType::HeadphoneScore => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
if let Some(ref _hd) = data.headphone_score_data {
let error = &data.deviation - &peq_spl;
let error_curve = Curve {
freq: data.freqs.clone(),
spl: error.clone(),
phase: None,
..Default::default()
};
let s = headphone_loss(&error_curve);
let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
1000.0 - s + p * 20.0
} else {
log::error!("headphone score loss requested but headphone data is missing");
f64::INFINITY
}
}
LossType::Epa => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
let epa_config = data.epa_config.clone().unwrap_or_default();
let flatness = crate::loss::epa::score::epa_flatness(
&data.freqs,
&error,
data.min_freq,
data.max_freq,
&epa_config,
);
let freqs_vec: Vec<f64> = data.freqs.iter().copied().collect();
let corrected_spl: Vec<f64> = data
.freqs
.iter()
.enumerate()
.map(|(i, _)| data.target[i] + data.deviation[i] + peq_spl[i])
.collect();
crate::loss::epa::score::epa_loss_normalized(
&freqs_vec,
&corrected_spl,
&epa_config,
flatness,
)
}
}
}
pub fn compute_base_fitness(x: &[f64], data: &ObjectiveData) -> f64 {
if let Some(ref mo) = data.multi_objective {
return compute_multi_objective_fitness(x, mo);
}
match data.loss_type {
LossType::DriversFlat => {
if let Some(ref drivers_data) = data.drivers_data {
let n_drivers = drivers_data.drivers.len();
let gains = &x[0..n_drivers];
let delays = &x[n_drivers..2 * n_drivers];
let xover_freqs: Vec<f64> = if let Some(ref fixed) = data.fixed_crossover_freqs {
fixed.clone()
} else {
let xover_freqs_log10 = &x[2 * n_drivers..];
xover_freqs_log10
.iter()
.map(|f| 10.0_f64.powf(*f))
.collect()
};
drivers_flat_loss(
drivers_data,
gains,
&xover_freqs,
Some(delays),
data.srate,
data.min_freq,
data.max_freq,
)
} else {
log::error!("drivers-flat loss requested but driver data is missing");
f64::INFINITY
}
}
LossType::MultiSubFlat => {
if let Some(ref drivers_data) = data.drivers_data {
let n_drivers = drivers_data.drivers.len();
let gains = &x[0..n_drivers];
let delays = &x[n_drivers..2 * n_drivers];
crate::loss::multisub_flat_loss(
drivers_data,
gains,
delays,
data.srate,
data.min_freq,
data.max_freq,
)
} else {
log::error!("multi-sub-flat loss requested but driver data is missing");
f64::INFINITY
}
}
LossType::HeadphoneFlat | LossType::SpeakerFlat => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
if data.smooth {
let curve = Curve {
freq: data.freqs.clone(),
spl: error,
phase: None,
..Default::default()
};
let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
flat_loss(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
} else {
flat_loss(&data.freqs, &error, data.min_freq, data.max_freq)
}
}
LossType::SpeakerFlatAsymmetric => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
let null_mask = data.null_suppression.as_ref();
if data.smooth {
let curve = Curve {
freq: data.freqs.clone(),
spl: error,
phase: None,
..Default::default()
};
let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
flat_loss_asymmetric(
&data.freqs,
&smoothed.spl,
data.min_freq,
data.max_freq,
null_mask,
)
} else {
flat_loss_asymmetric(&data.freqs, &error, data.min_freq, data.max_freq, null_mask)
}
}
LossType::SpeakerScore => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
if let Some(ref sd) = data.speaker_score_data {
let error = &peq_spl - &data.deviation;
let s = speaker_score_loss(sd, &data.freqs, &peq_spl);
let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq) / 3.0;
100.0 - s + p
} else {
log::error!("speaker score loss requested but score data is missing");
f64::INFINITY
}
}
LossType::HeadphoneScore => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
if let Some(ref _hd) = data.headphone_score_data {
let error = &data.deviation - &peq_spl;
let error_curve = Curve {
freq: data.freqs.clone(),
spl: error.clone(),
phase: None,
..Default::default()
};
let s = headphone_loss(&error_curve);
let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
1000.0 - s + p * 20.0
} else {
log::error!("headphone score loss requested but headphone data is missing");
f64::INFINITY
}
}
LossType::Epa => {
let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
let error = &peq_spl - &data.deviation;
let flatness = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
let freqs_vec: Vec<f64> = data.freqs.iter().copied().collect();
let corrected_spl: Vec<f64> = data
.freqs
.iter()
.enumerate()
.map(|(i, _)| data.target[i] + data.deviation[i] + peq_spl[i])
.collect();
let epa_config = data.epa_config.clone().unwrap_or_default();
crate::loss::epa::score::epa_loss_normalized(
&freqs_vec,
&corrected_spl,
&epa_config,
flatness,
)
}
}
}
/// Computes the base loss of `x` plus the enabled quadratic constraint penalties.
///
/// Penalties (ceiling, spacing, minimum gain) are each `weight * violation^2` and are only
/// added for PEQ-style losses, i.e. never for `DriversFlat` or `MultiSubFlat`. A penalty is
/// skipped when its weight is not positive; the minimum-gain penalty additionally requires
/// `data.min_db > 0`.
pub fn compute_fitness_penalties_ref(x: &[f64], data: &ObjectiveData) -> f64 {
    let base = compute_base_fitness(x, data);
    if matches!(
        data.loss_type,
        LossType::DriversFlat | LossType::MultiSubFlat
    ) {
        return base;
    }

    let mut total = base;
    if data.penalty_w_ceiling > 0.0 {
        let response = x2spl(&data.freqs, x, data.srate, data.peq_model);
        let excess = viol_ceiling_from_spl(&response, data.max_db, data.peq_model);
        total += data.penalty_w_ceiling * excess * excess;
    }
    if data.penalty_w_spacing > 0.0 {
        let excess = viol_spacing_from_xs(x, data.peq_model, data.min_spacing_oct);
        total += data.penalty_w_spacing * excess * excess;
    }
    if data.penalty_w_mingain > 0.0 && data.min_db > 0.0 {
        let excess = viol_min_gain_from_xs(x, data.peq_model, data.min_db);
        total += data.penalty_w_mingain * excess * excess;
    }
    total
}
/// Adapter exposing [`compute_fitness_penalties_ref`] with a gradient slot and mutable data.
///
/// `_gradient` is never written and `data` is only read.
pub fn compute_fitness_penalties(
    x: &[f64],
    _gradient: Option<&mut [f64]>,
    data: &mut ObjectiveData,
) -> f64 {
    let shared: &ObjectiveData = data;
    compute_fitness_penalties_ref(x, shared)
}
/// Runs the optimizer named by `params.algo`.
///
/// Thin wrapper over [`optimize_filters_with_algo_override`] with no algorithm override; see
/// that function for the meaning of the result.
pub fn optimize_filters(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    params: &crate::OptimParams,
) -> Result<(String, f64), (String, f64)> {
    let no_override: Option<&str> = None;
    optimize_filters_with_algo_override(
        x,
        lower_bounds,
        upper_bounds,
        objective_data,
        params,
        no_override,
    )
}
/// Runs an optimizer, optionally overriding the algorithm configured in `params`.
///
/// The algorithm name is `algo_override` when given, otherwise `params.algo`. It is resolved
/// through the registry and the resulting backend is run without a progress callback.
///
/// # Errors
///
/// Returns `("Unknown algorithm: <name>", f64::INFINITY)` when the registry does not resolve
/// the name; otherwise whatever the backend's `optimize` returns.
pub fn optimize_filters_with_algo_override(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    params: &crate::OptimParams,
    algo_override: Option<&str>,
) -> Result<(String, f64), (String, f64)> {
    let algo = algo_override.unwrap_or(&params.algo);
    let backend = registry::resolve(algo)
        .ok_or_else(|| (format!("Unknown algorithm: {}", algo), f64::INFINITY))?;
    backend.optimize(x, lower_bounds, upper_bounds, objective_data, params, None)
}
/// Boxed, sendable progress callback accepted by `optimize_filters_with_callback`.
///
/// In this file it is invoked as `callback(intermediate.iter, intermediate.fun, epa)`: the
/// iteration counter and `fun` value reported by the DE optimizer (presumably the current
/// objective value; confirm), plus an EPA preference score that is only `Some` on the
/// iterations where it is recomputed. The returned `CallbackAction` is handed back to DE.
pub type OptimProgressCallback =
Box<dyn FnMut(usize, f64, Option<f64>) -> crate::de::CallbackAction + Send>;
/// Runs the optimizer named by `params.algo`, reporting progress through `callback`.
///
/// * The backend named `autoeq:de` (compared ASCII-case-insensitively) is routed through the
///   EPA-aware DE wrapper, which also feeds periodic EPA scores to `callback`.
/// * Other backends receive `callback` only if their capabilities advertise
///   `iteration_callback`; otherwise they run without it and `callback` is never invoked.
///
/// # Errors
///
/// Returns `("Unknown algorithm: <name>", f64::INFINITY)` when the registry does not resolve
/// `params.algo`; otherwise whatever the selected backend returns.
pub fn optimize_filters_with_callback(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    params: &crate::OptimParams,
    callback: OptimProgressCallback,
) -> Result<(String, f64), (String, f64)> {
    let backend = registry::resolve(&params.algo)
        .ok_or_else(|| (format!("Unknown algorithm: {}", params.algo), f64::INFINITY))?;

    if backend.name().eq_ignore_ascii_case("autoeq:de") {
        return run_autoeq_de_with_epa_callback(
            x,
            lower_bounds,
            upper_bounds,
            objective_data,
            params,
            backend.name(),
            callback,
        );
    }

    // Kept as an explicit branch so an unused callback stays alive until this function
    // returns, exactly as before.
    let cb_for_backend: Option<OptimProgressCallback> = if backend.capabilities().iteration_callback
    {
        Some(callback)
    } else {
        None
    };
    backend.optimize(
        x,
        lower_bounds,
        upper_bounds,
        objective_data,
        params,
        cb_for_backend,
    )
}
fn run_autoeq_de_with_epa_callback(
x: &mut [f64],
lower_bounds: &[f64],
upper_bounds: &[f64],
objective_data: ObjectiveData,
params: &crate::OptimParams,
autoeq_name: &str,
mut callback: OptimProgressCallback,
) -> Result<(String, f64), (String, f64)> {
const EPA_INTERVAL: usize = 10;
let epa_config = objective_data.epa_config.clone();
let epa_freqs =
ndarray::Array1::from(objective_data.freqs.iter().copied().collect::<Vec<f64>>());
let epa_normalized: Vec<f64> = objective_data
.target
.iter()
.zip(objective_data.deviation.iter())
.map(|(&t, &d)| t - d)
.collect();
let epa_srate = objective_data.srate;
let epa_model = objective_data.peq_model;
let mut epa_gen_counter: usize = 0;
let de_cb: Box<dyn FnMut(&crate::de::DEIntermediate) -> crate::de::CallbackAction + Send> =
Box::new(move |intermediate| {
epa_gen_counter += 1;
let epa = if epa_gen_counter.is_multiple_of(EPA_INTERVAL) {
let peq_spl = x2spl(
&epa_freqs,
intermediate.x.as_slice().unwrap(),
epa_srate,
epa_model,
);
let corrected: Vec<f64> = epa_normalized
.iter()
.enumerate()
.map(|(i, &n)| n + peq_spl[i])
.collect();
let cfg = epa_config.clone().unwrap_or_default();
let score = crate::loss::epa::score::compute_epa_normalized(
epa_freqs.as_slice().unwrap(),
&corrected,
&cfg,
);
Some(score.preference)
} else {
None
};
callback(intermediate.iter, intermediate.fun, epa)
});
optimize_filters_autoeq_with_callback(
x,
lower_bounds,
upper_bounds,
objective_data,
autoeq_name,
params,
de_cb,
)
}
/// Returns the filter center frequencies in ascending order and the octave distance between
/// each pair of neighbours.
///
/// Frequencies are decoded as `10^freq` from each filter's parameters. Spacing `i` is
/// `|log2(f[i + 1] / f[i])|`, with both frequencies floored at `1e-9` before dividing. Fewer
/// than two filters yield an empty spacing list.
pub fn compute_sorted_freqs_and_adjacent_octave_spacings(
    x: &[f64],
    peq_model: PeqModel,
) -> (Vec<f64>, Vec<f64>) {
    let filter_count = crate::param_utils::num_filters(x, peq_model);

    let mut sorted_hz: Vec<f64> = Vec::with_capacity(filter_count);
    sorted_hz.extend(
        (0..filter_count)
            .map(|idx| 10f64.powf(crate::param_utils::get_filter_params(x, idx, peq_model).freq)),
    );
    sorted_hz.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

    // `windows(2)` is empty for fewer than two entries, which covers the short-input case.
    let octave_gaps: Vec<f64> = sorted_hz
        .windows(2)
        .map(|pair| {
            let lower = pair[0].max(1e-9);
            let upper = pair[1].max(1e-9);
            (upper / lower).log2().abs()
        })
        .collect();

    (sorted_hz, octave_gaps)
}
#[cfg(test)]
mod dispatch_tests {
    use super::*;

    /// Each AutoEQ backend must resolve under its own name rather than aliasing `autoeq:de`.
    #[test]
    fn autoeq_cobyla_and_isres_have_own_names() {
        let mut resolved = Vec::new();
        for id in ["autoeq:cobyla", "autoeq:isres", "autoeq:de"] {
            let backend = registry::resolve(id).unwrap_or_else(|| panic!("{id} missing"));
            assert_eq!(backend.name(), id);
            assert_eq!(backend.library(), "AutoEQ");
            resolved.push(backend);
        }

        let de_name = resolved[2].name();
        assert_ne!(resolved[0].name(), de_name);
        assert_ne!(resolved[1].name(), de_name);
    }
}
#[cfg(test)]
mod spacing_diag_tests {
    use super::compute_sorted_freqs_and_adjacent_octave_spacings;
    use crate::cli::PeqModel;

    /// Filters at 100, 200 and 400 Hz are exactly one octave apart.
    #[test]
    fn adjacent_octave_spacings_basic() {
        // Three parameters per filter: log10(frequency) followed by 1.0 and 0.0.
        let mut x: Vec<f64> = Vec::new();
        for hz in [100f64, 200f64, 400f64] {
            x.extend([hz.log10(), 1.0, 0.0]);
        }

        let (_freqs, spacings) =
            compute_sorted_freqs_and_adjacent_octave_spacings(&x, PeqModel::Pk);

        let first_gap = spacings[0];
        let second_gap = spacings[1];
        assert!((first_gap - 1.0).abs() < 1e-12);
        assert!((second_gap - 1.0).abs() < 1e-12);
    }
}