use super::cli::PeqModel;
use super::constraints::{viol_ceiling_from_spl, viol_min_gain_from_xs, viol_spacing_from_xs};
use super::loss::{
DriversLossData, HeadphoneLossData, LossType, SpeakerLossData, drivers_flat_loss, flat_loss,
flat_loss_asymmetric, headphone_loss, speaker_score_loss,
};
use super::optim_de::{optimize_filters_autoeq, optimize_filters_autoeq_with_callback};
use super::optim_mh::{optimize_filters_mh, optimize_filters_mh_with_callback};
#[cfg(feature = "nlopt")]
use super::optim_nlopt::optimize_filters_nlopt;
use super::x2peq::x2spl;
use crate::Curve;
use ndarray::Array1;
#[cfg(feature = "nlopt")]
use nlopt::Algorithm;
pub mod pareto;
/// Static description of one optimization algorithm the crate can dispatch to.
#[derive(Debug, Clone)]
pub struct AlgorithmInfo {
    /// Canonical name in `library:algo` form (e.g. "nlopt:isres", "mh:de").
    pub name: &'static str,
    /// Display name of the providing library ("NLOPT", "Metaheuristics", "AutoEQ").
    pub library: &'static str,
    /// Whether the algorithm searches globally or refines locally.
    pub algorithm_type: AlgorithmType,
    /// True if the backend can enforce linear constraints natively.
    pub supports_linear_constraints: bool,
    /// True if the backend can enforce nonlinear constraints natively.
    pub supports_nonlinear_constraints: bool,
}
/// Broad classification of an optimizer's search behavior.
#[derive(Debug, Clone, PartialEq)]
pub enum AlgorithmType {
    /// Global (exploratory) search over the whole parameter space.
    Global,
    /// Local refinement starting from an initial point.
    Local,
}
/// Enumerate every optimizer this crate can dispatch to.
///
/// NLOPT-backed entries are only present when the `nlopt` feature is
/// enabled. Names are canonical `library:algo` strings and are matched
/// case-insensitively (and by suffix) by `find_algorithm_info`.
pub fn get_all_algorithms() -> Vec<AlgorithmInfo> {
    // Return the vector directly — no need for an intermediate binding.
    vec![
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:isres",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: true,
            supports_nonlinear_constraints: true,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:ags",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: true,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:origdirect",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: true,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:crs2lm",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:direct",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:directl",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:gmlsl",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:gmlsllds",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:sbplx",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Local,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:slsqp",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Local,
            supports_linear_constraints: true,
            supports_nonlinear_constraints: true,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:stogo",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:stogorand",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:bobyqa",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Local,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:cobyla",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Local,
            supports_linear_constraints: true,
            supports_nonlinear_constraints: true,
        },
        #[cfg(feature = "nlopt")]
        AlgorithmInfo {
            name: "nlopt:neldermead",
            library: "NLOPT",
            algorithm_type: AlgorithmType::Local,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "mh:de",
            library: "Metaheuristics",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "mh:pso",
            library: "Metaheuristics",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "mh:rga",
            library: "Metaheuristics",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "mh:tlbo",
            library: "Metaheuristics",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "mh:firefly",
            library: "Metaheuristics",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: false,
            supports_nonlinear_constraints: false,
        },
        AlgorithmInfo {
            name: "autoeq:de",
            library: "AutoEQ",
            algorithm_type: AlgorithmType::Global,
            supports_linear_constraints: true,
            supports_nonlinear_constraints: true,
        },
    ]
}
/// Look up an algorithm by name, case-insensitively.
///
/// Exact canonical names ("mh:de") are matched first; if nothing matches,
/// the bare suffix after the `library:` prefix ("de") is tried, returning
/// the first algorithm whose suffix matches.
pub fn find_algorithm_info(name: &str) -> Option<AlgorithmInfo> {
    let algorithms = get_all_algorithms();
    if let Some(algo) = algorithms
        .iter()
        .find(|a| a.name.eq_ignore_ascii_case(name))
    {
        return Some(algo.clone());
    }
    // Suffix fallback. `eq_ignore_ascii_case` already handles case, so the
    // previous `to_lowercase()` allocation was redundant.
    algorithms
        .iter()
        .find(|a| {
            a.name
                .split(':')
                .nth(1)
                .is_some_and(|suffix| suffix.eq_ignore_ascii_case(name))
        })
        .cloned()
}
/// Everything the fitness functions need to evaluate one optimization target.
#[derive(Debug, Clone)]
pub struct ObjectiveData {
    /// Evaluation frequency grid (Hz), shared by all curves below.
    pub freqs: Array1<f64>,
    /// Target response sampled on `freqs` (used by the EPA loss).
    pub target: Array1<f64>,
    /// Deviation curve the PEQ response is compared against
    /// (loss error is computed as `peq_spl - deviation`, or the reverse
    /// for the headphone score).
    pub deviation: Array1<f64>,
    /// Sample rate used when realizing the filters (passed to `x2spl`).
    pub srate: f64,
    #[allow(dead_code)]
    /// Minimum allowed spacing between filter centers, in octaves
    /// (consumed by the spacing-violation penalty).
    pub min_spacing_oct: f64,
    /// User-supplied base weight for the spacing penalty (clamped to >= 0
    /// and scaled by the penalty mode in `configure_penalties`).
    pub spacing_weight: f64,
    /// Ceiling (dB) for the combined EQ response in the ceiling penalty.
    pub max_db: f64,
    /// Minimum filter gain magnitude (dB); the min-gain penalty is only
    /// active when this is > 0.
    pub min_db: f64,
    /// Lower edge of the frequency band the losses evaluate over (Hz).
    pub min_freq: f64,
    /// Upper edge of the frequency band the losses evaluate over (Hz).
    pub max_freq: f64,
    /// Parameterization of the PEQ (how `x` maps to filters).
    pub peq_model: PeqModel,
    /// Which loss branch to evaluate.
    pub loss_type: LossType,
    /// Extra data required by `LossType::SpeakerScore`.
    pub speaker_score_data: Option<SpeakerLossData>,
    /// Extra data required by `LossType::HeadphoneScore`.
    pub headphone_score_data: Option<HeadphoneLossData>,
    // Not read anywhere in this module — presumably consumed by callers;
    // TODO confirm before removing.
    pub input_curve: Option<Curve>,
    /// Driver measurements for the drivers-flat / multi-sub-flat losses.
    pub drivers_data: Option<DriversLossData>,
    /// When set, crossover frequencies are fixed to these values (Hz)
    /// instead of being read from the tail of `x`.
    pub fixed_crossover_freqs: Option<Vec<f64>>,
    /// Penalty weight for exceeding `max_db`.
    pub penalty_w_ceiling: f64,
    /// Penalty weight for violating `min_spacing_oct`.
    pub penalty_w_spacing: f64,
    /// Penalty weight for gains below `min_db` in magnitude.
    pub penalty_w_mingain: f64,
    // Not read in this module — per-parameter integrality flags for some
    // backend; TODO confirm usage at the optimizer layer.
    pub integrality: Option<Vec<bool>>,
    /// When set, fitness is combined across multiple objectives instead of
    /// using the single-objective path.
    pub multi_objective: Option<MultiObjectiveData>,
    /// If true, smooth the error curve before evaluating flatness losses.
    pub smooth: bool,
    /// Smoothing resolution: 1/`smooth_n` octave.
    pub smooth_n: usize,
    /// Optional `(freq_hz, max_boost_db)` breakpoints; positive gains are
    /// clamped to this envelope before loss evaluation.
    pub max_boost_envelope: Option<Vec<(f64, f64)>>,
    /// Optional `(freq_hz, max_cut_db)` breakpoints; negative gains are
    /// clamped to this envelope before loss evaluation.
    pub min_cut_envelope: Option<Vec<(f64, f64)>>,
}
/// Inputs for combining several measurement objectives into one fitness value.
#[derive(Debug, Clone)]
pub struct MultiObjectiveData {
    /// One objective per measurement; each is evaluated independently.
    pub objectives: Vec<ObjectiveData>,
    /// How the per-objective losses are combined.
    pub strategy: crate::roomeq::MultiMeasurementStrategy,
    /// Per-objective weights (used by the weighted-sum strategy).
    pub weights: Vec<f64>,
    /// Variance multiplier for the variance-penalized strategy.
    pub variance_lambda: f64,
}
/// Preset for how strongly constraint violations are penalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyMode {
    /// No constraint penalties at all.
    Disabled,
    /// Default (strong) penalty weights.
    Standard,
    /// Softer weights used with the PSO-style settings.
    Pso,
}

impl PenaltyMode {
    /// Weight applied to ceiling-violation penalties.
    pub const fn ceiling_weight(&self) -> f64 {
        match self {
            Self::Disabled => 0.0,
            Self::Standard => 1e4,
            Self::Pso => 5e2,
        }
    }

    /// Weight applied to minimum-gain-violation penalties.
    pub const fn mingain_weight(&self) -> f64 {
        match self {
            Self::Disabled => 0.0,
            Self::Standard => 1e3,
            Self::Pso => 50.0,
        }
    }
}
impl ObjectiveData {
    /// Set the three penalty weights according to `mode`.
    ///
    /// The spacing penalty additionally scales with the user-supplied
    /// `spacing_weight` (negative values are treated as zero).
    pub fn configure_penalties(&mut self, mode: PenaltyMode) {
        self.penalty_w_ceiling = mode.ceiling_weight();
        self.penalty_w_mingain = mode.mingain_weight();
        let scale = match mode {
            PenaltyMode::Disabled => 0.0,
            PenaltyMode::Standard => 1e3,
            PenaltyMode::Pso => 5e2,
        };
        self.penalty_w_spacing = scale * self.spacing_weight.max(0.0);
    }
}
/// Which optimizer backend a parsed algorithm name belongs to.
#[derive(Debug, Clone)]
pub enum AlgorithmCategory {
    /// NLOPT backend with its concrete algorithm (feature-gated).
    #[cfg(feature = "nlopt")]
    Nlopt(Algorithm),
    /// Metaheuristics backend, identified by its short name (e.g. "de").
    Metaheuristics(String),
    /// Built-in AutoEQ backend, identified by its short name.
    AutoEQ(String),
}
/// Resolve a (possibly abbreviated or mixed-case) algorithm name to its
/// backend category. Returns `None` for unknown names.
pub fn parse_algorithm_name(name: &str) -> Option<AlgorithmCategory> {
    // Normalize through the registry first so aliases and case are handled
    // in one place; dispatch is then on the canonical `library:algo` name.
    let info = find_algorithm_info(name)?;
    let canonical = info.name;
    #[cfg(feature = "nlopt")]
    if let Some(nlopt_name) = canonical.strip_prefix("nlopt:") {
        let nlopt_algo = match nlopt_name {
            "bobyqa" => Algorithm::Bobyqa,
            "cobyla" => Algorithm::Cobyla,
            "neldermead" => Algorithm::Neldermead,
            "isres" => Algorithm::Isres,
            "ags" => Algorithm::Ags,
            "origdirect" => Algorithm::OrigDirect,
            "crs2lm" => Algorithm::Crs2Lm,
            "direct" => Algorithm::Direct,
            "directl" => Algorithm::DirectL,
            "gmlsl" => Algorithm::GMlsl,
            "gmlsllds" => Algorithm::GMlslLds,
            "sbplx" => Algorithm::Sbplx,
            "slsqp" => Algorithm::Slsqp,
            "stogo" => Algorithm::StoGo,
            "stogorand" => Algorithm::StoGoRand,
            // Canonical names come from `get_all_algorithms`, so this
            // fallback should never be hit in practice.
            _ => Algorithm::Isres,
        };
        return Some(AlgorithmCategory::Nlopt(nlopt_algo));
    }
    if let Some(mh_name) = canonical.strip_prefix("mh:") {
        return Some(AlgorithmCategory::Metaheuristics(mh_name.to_string()));
    }
    if let Some(autoeq_name) = canonical.strip_prefix("autoeq:") {
        return Some(AlgorithmCategory::AutoEQ(autoeq_name.to_string()));
    }
    None
}
/// Combine per-measurement losses into a single scalar fitness according
/// to the configured multi-measurement strategy.
fn compute_multi_objective_fitness(x: &[f64], mo: &MultiObjectiveData) -> f64 {
    use crate::roomeq::MultiMeasurementStrategy;
    // One base loss per objective/measurement.
    let losses: Vec<f64> = mo
        .objectives
        .iter()
        .map(|objective| compute_base_fitness_single(x, objective))
        .collect();
    match mo.strategy {
        MultiMeasurementStrategy::Average => {
            losses.iter().sum::<f64>() / losses.len() as f64
        }
        MultiMeasurementStrategy::WeightedSum => losses
            .iter()
            .zip(mo.weights.iter())
            .map(|(loss, weight)| loss * weight)
            .sum(),
        MultiMeasurementStrategy::Minimax => {
            // Worst (largest) loss across all measurements.
            losses.iter().copied().fold(f64::NEG_INFINITY, f64::max)
        }
        MultiMeasurementStrategy::VariancePenalized => {
            let count = losses.len() as f64;
            let mean = losses.iter().sum::<f64>() / count;
            let variance =
                losses.iter().map(|loss| (loss - mean).powi(2)).sum::<f64>() / count;
            mean + mo.variance_lambda * variance
        }
        MultiMeasurementStrategy::SpatialRobustness => {
            unreachable!("SpatialRobustness strategy should not use multi-objective loss path")
        }
    }
}
/// Return a copy of `x` with every positive (boost) filter gain clamped to
/// the interpolated boost envelope at that filter's center frequency.
pub fn clamp_gains_to_envelope(x: &[f64], envelope: &[(f64, f64)], peq_model: PeqModel) -> Vec<f64> {
    use crate::param_utils;
    let mut out = x.to_vec();
    for idx in 0..param_utils::num_filters(x, peq_model) {
        let params = param_utils::get_filter_params(x, idx, peq_model);
        // Center frequency is stored as log10(Hz).
        let center_hz = 10f64.powf(params.freq);
        if params.gain > 0.0 {
            let ceiling_db = interpolate_boost_envelope(envelope, center_hz);
            if params.gain > ceiling_db {
                // Gain is the last parameter of each filter's slot.
                let ppf = param_utils::params_per_filter(peq_model);
                out[idx * ppf + (ppf - 1)] = ceiling_db;
            }
        }
    }
    out
}
/// Piecewise interpolation of a gain envelope, linear in log-frequency.
///
/// `envelope` is a list of `(freq_hz, db)` breakpoints assumed sorted by
/// frequency. Outside the covered range the nearest endpoint's value is
/// returned; an empty envelope imposes no limit (`+inf`).
fn interpolate_boost_envelope(envelope: &[(f64, f64)], freq_hz: f64) -> f64 {
    let Some(&(first_f, first_db)) = envelope.first() else {
        // No breakpoints: unbounded.
        return f64::INFINITY;
    };
    if freq_hz <= first_f {
        return first_db;
    }
    let &(last_f, last_db) = envelope.last().expect("envelope is non-empty");
    if freq_hz >= last_f {
        return last_db;
    }
    // Find the surrounding segment and interpolate in log-frequency.
    for segment in envelope.windows(2) {
        let (f0, db0) = segment[0];
        let (f1, db1) = segment[1];
        if freq_hz >= f0 && freq_hz <= f1 {
            let t = (freq_hz.ln() - f0.ln()) / (f1.ln() - f0.ln());
            return db0 + t * (db1 - db0);
        }
    }
    // Unreachable for a sorted envelope, but keep a sane fallback.
    last_db
}
/// Return a copy of `x` with every negative (cut) filter gain clamped so it
/// does not fall below the interpolated cut envelope at that filter's
/// center frequency.
pub fn clamp_cuts_to_envelope(x: &[f64], envelope: &[(f64, f64)], peq_model: PeqModel) -> Vec<f64> {
    use crate::param_utils;
    let mut out = x.to_vec();
    for idx in 0..param_utils::num_filters(x, peq_model) {
        let params = param_utils::get_filter_params(x, idx, peq_model);
        // Center frequency is stored as log10(Hz).
        let center_hz = 10f64.powf(params.freq);
        if params.gain < 0.0 {
            // The envelope interpolation is shared with the boost path.
            let floor_db = interpolate_boost_envelope(envelope, center_hz);
            if params.gain < floor_db {
                // Gain is the last parameter of each filter's slot.
                let ppf = param_utils::params_per_filter(peq_model);
                out[idx * ppf + (ppf - 1)] = floor_db;
            }
        }
    }
    out
}
/// Compute the un-penalized loss for one objective, clamping PEQ gains to
/// the optional boost/cut envelopes before evaluation.
///
/// Used per-measurement by the multi-objective combiner.
fn compute_base_fitness_single(x: &[f64], data: &ObjectiveData) -> f64 {
    // Holders for clamped copies of `x`; declared here so the reborrows
    // created inside the block below can outlive it.
    let clamped_boost;
    let clamped_cut;
    let x = {
        // Driver-level losses interpret `x` as gains/delays/crossovers
        // rather than PEQ filter parameters, so envelope clamping is skipped.
        let skip = matches!(data.loss_type, LossType::DriversFlat | LossType::MultiSubFlat);
        let x = if !skip && let Some(ref env) = data.max_boost_envelope {
            clamped_boost = clamp_gains_to_envelope(x, env, data.peq_model);
            &clamped_boost
        } else {
            x
        };
        if !skip && let Some(ref env) = data.min_cut_envelope {
            clamped_cut = clamp_cuts_to_envelope(x, env, data.peq_model);
            &clamped_cut
        } else {
            x
        }
    };
    match data.loss_type {
        LossType::DriversFlat => {
            if let Some(ref drivers_data) = data.drivers_data {
                // Parameter layout: [gains | delays | log10(crossover freqs)].
                let n_drivers = drivers_data.drivers.len();
                let gains = &x[0..n_drivers];
                let delays = &x[n_drivers..2 * n_drivers];
                let xover_freqs: Vec<f64> = if let Some(ref fixed) = data.fixed_crossover_freqs {
                    fixed.clone()
                } else {
                    // Crossover frequencies are optimized in log10 space.
                    let xover_freqs_log10 = &x[2 * n_drivers..];
                    xover_freqs_log10
                        .iter()
                        .map(|f| 10.0_f64.powf(*f))
                        .collect()
                };
                drivers_flat_loss(
                    drivers_data,
                    gains,
                    &xover_freqs,
                    Some(delays),
                    data.srate,
                    data.min_freq,
                    data.max_freq,
                )
            } else {
                log::error!("drivers-flat loss requested but driver data is missing");
                f64::INFINITY
            }
        }
        LossType::MultiSubFlat => {
            if let Some(ref drivers_data) = data.drivers_data {
                // Parameter layout: [gains | delays]; no crossovers here.
                let n_drivers = drivers_data.drivers.len();
                let gains = &x[0..n_drivers];
                let delays = &x[n_drivers..2 * n_drivers];
                crate::loss::multisub_flat_loss(
                    drivers_data,
                    gains,
                    delays,
                    data.srate,
                    data.min_freq,
                    data.max_freq,
                )
            } else {
                log::error!("multi-sub-flat loss requested but driver data is missing");
                f64::INFINITY
            }
        }
        LossType::HeadphoneFlat | LossType::SpeakerFlat => {
            // Residual error: EQ response minus the deviation it should cancel.
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            if data.smooth {
                // Smooth the error by 1/smooth_n octave before scoring flatness.
                let curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error,
                    phase: None,
                };
                let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
                flat_loss(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
            } else {
                flat_loss(&data.freqs, &error, data.min_freq, data.max_freq)
            }
        }
        LossType::SpeakerFlatAsymmetric => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            if data.smooth {
                let curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error,
                    phase: None,
                };
                let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
                flat_loss_asymmetric(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
            } else {
                flat_loss_asymmetric(&data.freqs, &error, data.min_freq, data.max_freq)
            }
        }
        LossType::SpeakerScore => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            if let Some(ref sd) = data.speaker_score_data {
                let error = &peq_spl - &data.deviation;
                // Higher score is better, so it is subtracted; a flatness
                // term (weighted 1/3) is added as a regularizer.
                let s = speaker_score_loss(sd, &data.freqs, &peq_spl);
                let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq) / 3.0;
                100.0 - s + p
            } else {
                log::error!("speaker score loss requested but score data is missing");
                f64::INFINITY
            }
        }
        LossType::HeadphoneScore => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            if let Some(ref _hd) = data.headphone_score_data {
                // Note the sign flip relative to the speaker losses.
                let error = &data.deviation - &peq_spl;
                let error_curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error.clone(),
                    phase: None,
                };
                let s = headphone_loss(&error_curve);
                let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
                1000.0 - s + p * 20.0
            } else {
                log::error!("headphone score loss requested but headphone data is missing");
                f64::INFINITY
            }
        }
        LossType::Epa => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            let flatness = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
            let freqs_vec: Vec<f64> = data.freqs.iter().copied().collect();
            // Corrected response reconstructed as target + deviation + EQ.
            let corrected_spl: Vec<f64> = data
                .freqs
                .iter()
                .enumerate()
                .map(|(i, _)| data.target[i] + data.deviation[i] + peq_spl[i])
                .collect();
            let epa_config = crate::epa::score::EpaConfig::default();
            crate::epa::score::epa_loss(&freqs_vec, &corrected_spl, &epa_config, flatness)
        }
    }
}
/// Compute the un-penalized loss for `data`, dispatching to the
/// multi-objective combiner when `data.multi_objective` is set.
///
/// NOTE(review): the single-objective match below duplicates
/// `compute_base_fitness_single`, except that it does NOT apply the
/// boost/cut envelope clamping — confirm whether that asymmetry is
/// intentional before deduplicating the two functions.
pub fn compute_base_fitness(x: &[f64], data: &ObjectiveData) -> f64 {
    if let Some(ref mo) = data.multi_objective {
        return compute_multi_objective_fitness(x, mo);
    }
    match data.loss_type {
        LossType::DriversFlat => {
            if let Some(ref drivers_data) = data.drivers_data {
                // Parameter layout: [gains | delays | log10(crossover freqs)].
                let n_drivers = drivers_data.drivers.len();
                let gains = &x[0..n_drivers];
                let delays = &x[n_drivers..2 * n_drivers];
                let xover_freqs: Vec<f64> = if let Some(ref fixed) = data.fixed_crossover_freqs {
                    fixed.clone()
                } else {
                    // Crossover frequencies are optimized in log10 space.
                    let xover_freqs_log10 = &x[2 * n_drivers..];
                    xover_freqs_log10
                        .iter()
                        .map(|f| 10.0_f64.powf(*f))
                        .collect()
                };
                drivers_flat_loss(
                    drivers_data,
                    gains,
                    &xover_freqs,
                    Some(delays),
                    data.srate,
                    data.min_freq,
                    data.max_freq,
                )
            } else {
                log::error!("drivers-flat loss requested but driver data is missing");
                f64::INFINITY
            }
        }
        LossType::MultiSubFlat => {
            if let Some(ref drivers_data) = data.drivers_data {
                // Parameter layout: [gains | delays]; no crossovers here.
                let n_drivers = drivers_data.drivers.len();
                let gains = &x[0..n_drivers];
                let delays = &x[n_drivers..2 * n_drivers];
                crate::loss::multisub_flat_loss(
                    drivers_data,
                    gains,
                    delays,
                    data.srate,
                    data.min_freq,
                    data.max_freq,
                )
            } else {
                log::error!("multi-sub-flat loss requested but driver data is missing");
                f64::INFINITY
            }
        }
        LossType::HeadphoneFlat | LossType::SpeakerFlat => {
            // Residual error: EQ response minus the deviation it should cancel.
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            if data.smooth {
                // Smooth the error by 1/smooth_n octave before scoring flatness.
                let curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error,
                    phase: None,
                };
                let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
                flat_loss(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
            } else {
                flat_loss(&data.freqs, &error, data.min_freq, data.max_freq)
            }
        }
        LossType::SpeakerFlatAsymmetric => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            if data.smooth {
                let curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error,
                    phase: None,
                };
                let smoothed = crate::read::smooth_one_over_n_octave(&curve, data.smooth_n);
                flat_loss_asymmetric(&data.freqs, &smoothed.spl, data.min_freq, data.max_freq)
            } else {
                flat_loss_asymmetric(&data.freqs, &error, data.min_freq, data.max_freq)
            }
        }
        LossType::SpeakerScore => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            if let Some(ref sd) = data.speaker_score_data {
                let error = &peq_spl - &data.deviation;
                // Higher score is better, so it is subtracted; a flatness
                // term (weighted 1/3) is added as a regularizer.
                let s = speaker_score_loss(sd, &data.freqs, &peq_spl);
                let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq) / 3.0;
                100.0 - s + p
            } else {
                log::error!("speaker score loss requested but score data is missing");
                f64::INFINITY
            }
        }
        LossType::HeadphoneScore => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            if let Some(ref _hd) = data.headphone_score_data {
                // Note the sign flip relative to the speaker losses.
                let error = &data.deviation - &peq_spl;
                let error_curve = Curve {
                    freq: data.freqs.clone(),
                    spl: error.clone(),
                    phase: None,
                };
                let s = headphone_loss(&error_curve);
                let p = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
                1000.0 - s + p * 20.0
            } else {
                log::error!("headphone score loss requested but headphone data is missing");
                f64::INFINITY
            }
        }
        LossType::Epa => {
            let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
            let error = &peq_spl - &data.deviation;
            let flatness = flat_loss(&data.freqs, &error, data.min_freq, data.max_freq);
            let freqs_vec: Vec<f64> = data.freqs.iter().copied().collect();
            // Corrected response reconstructed as target + deviation + EQ.
            let corrected_spl: Vec<f64> = data
                .freqs
                .iter()
                .enumerate()
                .map(|(i, _)| data.target[i] + data.deviation[i] + peq_spl[i])
                .collect();
            let epa_config = crate::epa::score::EpaConfig::default();
            crate::epa::score::epa_loss(&freqs_vec, &corrected_spl, &epa_config, flatness)
        }
    }
}
/// Base fitness plus quadratic penalties for ceiling, spacing and
/// minimum-gain constraint violations.
///
/// Penalties only apply to PEQ-style losses; driver/crossover losses use a
/// different parameter layout and are returned unpenalized.
pub fn compute_fitness_penalties_ref(x: &[f64], data: &ObjectiveData) -> f64 {
    let base = compute_base_fitness(x, data);
    let is_peq_loss = !matches!(
        data.loss_type,
        LossType::DriversFlat | LossType::MultiSubFlat
    );
    if !is_peq_loss {
        return base;
    }
    let mut total = base;
    if data.penalty_w_ceiling > 0.0 {
        let peq_spl = x2spl(&data.freqs, x, data.srate, data.peq_model);
        let viol = viol_ceiling_from_spl(&peq_spl, data.max_db, data.peq_model);
        total += data.penalty_w_ceiling * viol * viol;
    }
    if data.penalty_w_spacing > 0.0 {
        let viol = viol_spacing_from_xs(x, data.peq_model, data.min_spacing_oct);
        total += data.penalty_w_spacing * viol * viol;
    }
    // The min-gain penalty is only meaningful when a threshold is set.
    if data.min_db > 0.0 && data.penalty_w_mingain > 0.0 {
        let viol = viol_min_gain_from_xs(x, data.peq_model, data.min_db);
        total += data.penalty_w_mingain * viol * viol;
    }
    total
}
/// Objective-function wrapper around [`compute_fitness_penalties_ref`].
///
/// The unused gradient slice and `&mut` data match the callback signature
/// expected by the optimizer backends — presumably the NLopt-style
/// objective interface; confirm at the call sites.
pub fn compute_fitness_penalties(
    x: &[f64],
    _gradient: Option<&mut [f64]>,
    data: &mut ObjectiveData,
) -> f64 {
    compute_fitness_penalties_ref(x, data)
}
/// Optimize the filter parameters `x` in place using the algorithm selected
/// by `cli_args.algo`.
///
/// Thin wrapper over [`optimize_filters_with_algo_override`] with no
/// algorithm override.
pub fn optimize_filters(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    cli_args: &crate::cli::Args,
) -> Result<(String, f64), (String, f64)> {
    optimize_filters_with_algo_override(x, lower_bounds, upper_bounds, objective_data, cli_args, None)
}
/// Dispatch an optimization run to the backend selected by `algo_override`
/// (falling back to `cli_args.algo` when `None`).
///
/// `x` is updated in place. Both the `Ok` and `Err` payloads carry a status
/// message and an objective value — see the backend implementations for
/// their exact meaning.
pub fn optimize_filters_with_algo_override(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    cli_args: &crate::cli::Args,
    algo_override: Option<&str>,
) -> Result<(String, f64), (String, f64)> {
    // An explicit override takes precedence over the CLI-selected algorithm.
    let algo = algo_override.unwrap_or(&cli_args.algo);
    let population = cli_args.population;
    let maxeval = cli_args.maxeval;
    match parse_algorithm_name(algo) {
        #[cfg(feature = "nlopt")]
        Some(AlgorithmCategory::Nlopt(nlopt_algo)) => optimize_filters_nlopt(
            x,
            lower_bounds,
            upper_bounds,
            objective_data,
            nlopt_algo,
            population,
            maxeval,
        ),
        Some(AlgorithmCategory::Metaheuristics(mh_name)) => optimize_filters_mh(
            x,
            lower_bounds,
            upper_bounds,
            objective_data,
            &mh_name,
            population,
            maxeval,
        ),
        Some(AlgorithmCategory::AutoEQ(autoeq_name)) => optimize_filters_autoeq(
            x,
            lower_bounds,
            upper_bounds,
            objective_data,
            &autoeq_name,
            cli_args,
        ),
        None => Err((format!("Unknown algorithm: {}", algo), f64::INFINITY)),
    }
}
/// Progress callback invoked as `(iteration, objective value)`; the returned
/// [`crate::de::CallbackAction`] controls whether the run continues.
pub type OptimProgressCallback = Box<dyn FnMut(usize, f64) -> crate::de::CallbackAction + Send>;
/// Like [`optimize_filters`], but reports progress through `callback`.
///
/// The callback is adapted to each backend's intermediate-result type.
/// NOTE(review): on the NLOPT path the callback is dropped without ever
/// being invoked — confirm this is intentional.
#[allow(clippy::too_many_arguments)]
pub fn optimize_filters_with_callback(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    cli_args: &crate::cli::Args,
    mut callback: OptimProgressCallback,
) -> Result<(String, f64), (String, f64)> {
    let algo = &cli_args.algo;
    let population = cli_args.population;
    let maxeval = cli_args.maxeval;
    match parse_algorithm_name(algo) {
        #[cfg(feature = "nlopt")]
        Some(AlgorithmCategory::Nlopt(nlopt_algo)) => {
            // No callback support on this path; runs to completion silently.
            optimize_filters_nlopt(
                x,
                lower_bounds,
                upper_bounds,
                objective_data,
                nlopt_algo,
                population,
                maxeval,
            )
        }
        Some(AlgorithmCategory::Metaheuristics(mh_name)) => {
            // Adapt the generic progress callback to the MH intermediate type.
            let mh_cb: Box<
                dyn FnMut(&super::optim_mh::MHIntermediate) -> crate::de::CallbackAction + Send,
            > = Box::new(move |intermediate| callback(intermediate.iter, intermediate.fun));
            optimize_filters_mh_with_callback(
                x,
                lower_bounds,
                upper_bounds,
                objective_data,
                &mh_name,
                population,
                maxeval,
                mh_cb,
            )
        }
        Some(AlgorithmCategory::AutoEQ(autoeq_name)) => {
            // Adapt the generic progress callback to the DE intermediate type.
            let de_cb: Box<
                dyn FnMut(&crate::de::DEIntermediate) -> crate::de::CallbackAction + Send,
            > = Box::new(move |intermediate| callback(intermediate.iter, intermediate.fun));
            optimize_filters_autoeq_with_callback(
                x,
                lower_bounds,
                upper_bounds,
                objective_data,
                &autoeq_name,
                cli_args,
                de_cb,
            )
        }
        None => Err((format!("Unknown algorithm: {}", algo), f64::INFINITY)),
    }
}
/// Extract all filter center frequencies (Hz), sorted ascending, together
/// with the octave spacing between each adjacent pair.
///
/// Returns `(sorted_freqs, spacings)`; `spacings` is empty when fewer than
/// two filters are present.
pub fn compute_sorted_freqs_and_adjacent_octave_spacings(
    x: &[f64],
    peq_model: PeqModel,
) -> (Vec<f64>, Vec<f64>) {
    let count = crate::param_utils::num_filters(x, peq_model);
    // Center frequencies are stored as log10(Hz) in the parameter vector.
    let mut freqs: Vec<f64> = (0..count)
        .map(|i| 10f64.powf(crate::param_utils::get_filter_params(x, i, peq_model).freq))
        .collect();
    freqs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    // Octave distance between neighbors; tiny floor avoids log2(0).
    let spacings: Vec<f64> = freqs
        .windows(2)
        .map(|pair| (pair[1].max(1e-9) / pair[0].max(1e-9)).log2().abs())
        .collect();
    (freqs, spacings)
}
#[cfg(test)]
mod spacing_diag_tests {
    use super::compute_sorted_freqs_and_adjacent_octave_spacings;
    // Three filters at 100/200/400 Hz (frequencies stored as log10; the
    // remaining two values per filter appear to be Q and gain — confirm the
    // layout against `param_utils`) must be exactly one octave apart.
    #[test]
    fn adjacent_octave_spacings_basic() {
        let x = [
            100f64.log10(),
            1.0,
            0.0,
            200f64.log10(),
            1.0,
            0.0,
            400f64.log10(),
            1.0,
            0.0,
        ];
        use crate::cli::PeqModel;
        let (_freqs, spacings) =
            compute_sorted_freqs_and_adjacent_octave_spacings(&x, PeqModel::Pk);
        assert!((spacings[0] - 1.0).abs() < 1e-12);
        assert!((spacings[1] - 1.0).abs() < 1e-12);
    }
}