use crate::AutoeqError;
use crate::Curve;
use crate::HeadphoneLossData;
use crate::PeqModel;
use crate::SpeakerLossData;
use crate::iir::Biquad;
use crate::loss::DriversLossData;
use super::{ObjectiveData, optimize_filters_with_algo_override};
use super::de::optimize_filters_autoeq_with_callback;
use crate::read;
use crate::x2peq;
use ndarray::Array1;
use std::collections::HashMap;
use std::error::Error;
/// Assembles the [`ObjectiveData`] for a single-curve EQ optimization.
///
/// When `spin_data` is present, speaker score data is built from it with
/// `SpeakerLossData::try_new`, and neither headphone loss data nor the input curve is stored.
/// Otherwise headphone loss data is built from `params.smooth` / `params.smooth_n` and a copy
/// of `input_curve` is kept. Driver, crossover, penalty, envelope and problem-detection fields
/// start empty or zero.
///
/// Returns the objective together with the `use_cea` flag (`true` iff `spin_data` is `Some`).
///
/// # Errors
/// Propagates the error from `SpeakerLossData::try_new` when the spin data is rejected.
pub fn setup_objective_data(
    params: &crate::OptimParams,
    input_curve: &Curve,
    target_curve: &Curve,
    deviation_curve: &Curve,
    spin_data: &Option<HashMap<String, Curve>>,
) -> Result<(ObjectiveData, bool), AutoeqError> {
    let speaker_score_data = spin_data
        .as_ref()
        .map(|spin| SpeakerLossData::try_new(spin))
        .transpose()?;
    let use_cea = spin_data.is_some();

    // Headphone scoring and the stored input curve only apply when no spin data is given.
    let (headphone_score_data, stored_input_curve) = if use_cea {
        (None, None)
    } else {
        (
            Some(HeadphoneLossData::new(params.smooth, params.smooth_n)),
            Some(input_curve.clone()),
        )
    };

    let objective = ObjectiveData {
        freqs: input_curve.freq.clone(),
        target: target_curve.spl.clone(),
        deviation: deviation_curve.spl.clone(),
        srate: params.sample_rate,
        min_spacing_oct: params.min_spacing_oct,
        spacing_weight: params.spacing_weight,
        max_db: params.max_db,
        min_db: params.min_db,
        min_freq: params.min_freq,
        max_freq: params.max_freq,
        peq_model: params.peq_model,
        loss_type: params.loss,
        speaker_score_data,
        headphone_score_data,
        input_curve: stored_input_curve,
        drivers_data: None,
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: params.smooth,
        smooth_n: params.smooth_n,
        max_boost_envelope: None,
        min_cut_envelope: None,
        epa_config: None,
        detected_problems: Vec::new(),
        null_suppression: None,
    };
    Ok((objective, use_cea))
}
/// Builds the [`ObjectiveData`] for a multi-driver crossover optimization.
///
/// The evaluation grid is `drivers_data.freq_grid`. Target and deviation are all-zero arrays
/// of the same length, and the loss is `LossType::DriversFlat`. Spacing constraints,
/// smoothing, scoring data, penalties and envelopes are all switched off; only the gain and
/// frequency limits and the PEQ model are taken from `params`.
pub fn setup_drivers_objective_data(
    params: &crate::OptimParams,
    drivers_data: DriversLossData,
) -> ObjectiveData {
    let freqs = drivers_data.freq_grid.clone();
    let n_points = freqs.len();
    ObjectiveData {
        target: Array1::zeros(n_points),
        deviation: Array1::zeros(n_points),
        freqs,
        srate: params.sample_rate,
        min_spacing_oct: 0.0,
        spacing_weight: 0.0,
        max_db: params.max_db,
        min_db: params.min_db,
        min_freq: params.min_freq,
        max_freq: params.max_freq,
        peq_model: params.peq_model,
        loss_type: crate::LossType::DriversFlat,
        speaker_score_data: None,
        headphone_score_data: None,
        input_curve: None,
        drivers_data: Some(drivers_data),
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: false,
        smooth_n: 1,
        max_boost_envelope: None,
        min_cut_envelope: None,
        epa_config: None,
        detected_problems: Vec::new(),
        null_suppression: None,
    }
}
/// Builds optimizer bounds for a multi-driver crossover search.
///
/// Layout for `n = drivers_data.drivers.len()`:
/// - `[0, n)`: per-driver gain, bounded by `±params.max_db`;
/// - `[n, 2n)`: a second per-driver parameter bounded by `±20.0` (presumably a delay —
///   TODO confirm the unit against the drivers loss);
/// - `[2n, 3n - 1)`: `log10` of the crossover frequency between drivers `i` and `i + 1`.
///   Each is searched from half to twice the geometric mean of the two drivers'
///   `mean_freq()`, clipped to `[params.min_freq, params.max_freq]`. The lower edge is pulled
///   down if needed so the window stays at least 0.1 decade wide.
///
/// An empty driver list yields empty bounds instead of underflowing `n - 1`.
pub fn setup_drivers_bounds(
    params: &crate::OptimParams,
    drivers_data: &DriversLossData,
) -> (Vec<f64>, Vec<f64>) {
    let drivers = &drivers_data.drivers;
    let n_drivers = drivers.len();
    // `saturating_sub` keeps an empty driver list from underflowing (a panic in debug builds).
    let n_crossovers = n_drivers.saturating_sub(1);
    let n_params = n_drivers * 2 + n_crossovers;
    let mut lower_bounds: Vec<f64> = Vec::with_capacity(n_params);
    let mut upper_bounds: Vec<f64> = Vec::with_capacity(n_params);

    // Per-driver gains.
    lower_bounds.extend(std::iter::repeat(-params.max_db).take(n_drivers));
    upper_bounds.extend(std::iter::repeat(params.max_db).take(n_drivers));
    // Second per-driver parameter.
    lower_bounds.extend(std::iter::repeat(-20.0).take(n_drivers));
    upper_bounds.extend(std::iter::repeat(20.0).take(n_drivers));

    // One crossover between each adjacent pair of drivers (none when n_drivers < 2).
    for i in 1..n_drivers {
        let mean_low = drivers[i - 1].mean_freq();
        let mean_high = drivers[i].mean_freq();
        let geometric_center = (mean_low * mean_high).sqrt();
        let xover_max = (geometric_center * 2.0).min(params.max_freq).log10();
        // Clipping to min/max freq can collapse or invert the window; keep >= 0.1 decade.
        let xover_min = (geometric_center * 0.5)
            .max(params.min_freq)
            .log10()
            .min(xover_max - 0.1);
        lower_bounds.push(xover_min);
        upper_bounds.push(xover_max);
    }
    (lower_bounds, upper_bounds)
}
/// Starting point for [`setup_drivers_bounds`]-style parameter vectors.
///
/// The first `2 * n_drivers` entries (per-driver gains and the second per-driver parameter)
/// start at zero. Every remaining slot, i.e. each crossover `log10` frequency, starts at the
/// midpoint of its bounds.
///
/// # Panics
/// Panics if `upper_bounds` is shorter than `lower_bounds` past index `2 * n_drivers`.
pub fn drivers_initial_guess(
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    n_drivers: usize,
) -> Vec<f64> {
    let first_xover = 2 * n_drivers;
    let mut guess = vec![0.0; first_xover];
    guess.extend(
        (first_xover..lower_bounds.len()).map(|k| (lower_bounds[k] + upper_bounds[k]) / 2.0),
    );
    guess
}
/// Builds optimizer bounds for a multi-driver search with the crossover frequencies fixed.
///
/// Returns `2 * n` bounds for `n = drivers_data.drivers.len()`. The first `n` are per-driver
/// gains in `±params.max_db`. The next `n` are a second per-driver parameter in `±20.0`
/// (presumably a delay — TODO confirm).
pub fn setup_drivers_bounds_fixed_freqs(
    params: &crate::OptimParams,
    drivers_data: &DriversLossData,
) -> (Vec<f64>, Vec<f64>) {
    let count = drivers_data.drivers.len();
    let gain_ranges = std::iter::repeat((-params.max_db, params.max_db)).take(count);
    let second_ranges = std::iter::repeat((-20.0, 20.0)).take(count);
    gain_ranges.chain(second_ranges).unzip()
}
/// Starting point for [`setup_drivers_bounds_fixed_freqs`]-style parameter vectors.
///
/// Returns `2 * n_drivers` zeros. The bound slices are accepted for signature parity with
/// `drivers_initial_guess` and are not consulted.
pub fn drivers_initial_guess_fixed_freqs(
    _lower_bounds: &[f64],
    _upper_bounds: &[f64],
    n_drivers: usize,
) -> Vec<f64> {
    std::iter::repeat(0.0).take(n_drivers * 2).collect()
}
/// Turns filters into cut-only when their frequency window lies at or below `schroeder_hz`.
///
/// For each of the `params.num_filters` filters, a positive gain upper bound is clamped to 0
/// when the upper bound of the filter's `log10(freq)` slot is at or below
/// `log10(schroeder_hz)`. Filters whose slots fall outside `upper_bounds` are skipped. A
/// non-positive `schroeder_hz` disables the restriction entirely.
///
/// NOTE(review): the condition restricts boosts on filters confined *below* Schroeder, while
/// the function name says "above" — confirm which behaviour is intended.
pub fn restrict_boost_above_schroeder(
    upper_bounds: &mut [f64],
    params: &crate::OptimParams,
    schroeder_hz: f64,
) {
    use crate::cli::PeqModel;
    if schroeder_hz <= 0.0 {
        return;
    }
    let model = params.peq_model;
    let stride = crate::param_utils::params_per_filter(model);
    // Position of the frequency and gain slots within one filter's parameter group.
    let (freq_slot, gain_slot) = match model {
        PeqModel::Pk
        | PeqModel::HpPk
        | PeqModel::HpPkLp
        | PeqModel::LsPk
        | PeqModel::LsPkHs => (0, 2),
        PeqModel::FreePkFree | PeqModel::Free => (1, 3),
    };
    let schroeder_log = schroeder_hz.log10();
    let len = upper_bounds.len();
    for base in (0..params.num_filters).map(|filter| filter * stride) {
        let (f, g) = (base + freq_slot, base + gain_slot);
        let out_of_range = f >= len || g >= len;
        if !out_of_range && upper_bounds[f] <= schroeder_log && upper_bounds[g] > 0.0 {
            upper_bounds[g] = 0.0;
        }
    }
}
/// Builds the lower/upper optimizer bounds for a parametric-EQ filter bank.
///
/// The parameter vector holds `params.num_filters` groups of
/// `crate::param_utils::params_per_filter(model)` values:
/// - 3-parameter models (`Pk`, `HpPk`, `HpPkLp`, `LsPk`, `LsPkHs`): `[log10(freq), q, gain]`;
/// - 4-parameter models (`FreePkFree`, `Free`): `[type, log10(freq), q, gain]`.
///
/// Nominal centres are spread evenly in log-frequency from `params.min_freq` towards
/// `params.max_freq`. Each window spans one spacing step either side of its centre and is
/// made non-decreasing relative to the previous filter's window. Model-specific edge filters
/// then get dedicated bounds:
/// - the first filter is high-pass (`HpPk`, `HpPkLp`) or low-shelf (`LsPk`, `LsPkHs`);
/// - the last filter is low-pass (`HpPkLp`) or high-shelf (`LsPkHs`).
///
/// Unless `params.quiet` is set, the resulting table is logged.
///
/// Edge-filter frequency windows that would come out inverted are collapsed to a single
/// point, so the optimizer (and `f64::clamp` in [`initial_guess`]) never sees `lower > upper`.
/// This happens for the first filter when `min_freq > 120 Hz`, and for the last filter when
/// `max_freq < 5 kHz`. With `num_filters == 0` empty bounds are returned.
pub fn setup_bounds(params: &crate::OptimParams) -> (Vec<f64>, Vec<f64>) {
    use crate::cli::PeqModel;
    let model = params.peq_model;
    let ppf = crate::param_utils::params_per_filter(model);
    let num_params = params.num_filters * ppf;
    let mut lower_bounds: Vec<f64> = Vec::with_capacity(num_params);
    let mut upper_bounds: Vec<f64> = Vec::with_capacity(num_params);
    // Each frequency window extends this many spacing steps either side of its centre.
    let spacing = 1.0;
    // NOTE(review): cuts may go three times deeper than boosts — confirm this is intended.
    let gain_lower = -3.0 * params.max_db;
    // Q is never allowed below 0.1, whatever `params.min_q` says.
    let q_lower = params.min_q.max(0.1);
    let log_min_freq = params.min_freq.log10();
    let log_max_freq = params.max_freq.log10();
    let range = (log_max_freq - log_min_freq) / (params.num_filters as f64);

    for i in 0..params.num_filters {
        let f_center = log_min_freq + (i as f64) * range;
        let f_low = (f_center - spacing * range).max(log_min_freq);
        let f_high = (f_center + spacing * range).min(log_max_freq);
        // Keep windows ordered: neither edge may fall below the previous filter's edge.
        let (f_low_adjusted, f_high_adjusted) = if i > 0 {
            let prev_freq_idx = if ppf == 3 {
                (i - 1) * 3
            } else {
                (i - 1) * 4 + 1
            };
            (
                f_low.max(lower_bounds[prev_freq_idx]),
                f_high.max(upper_bounds[prev_freq_idx]),
            )
        } else {
            (f_low, f_high)
        };
        let f_high_adjusted = f_high_adjusted.max(f_low_adjusted);
        match model {
            PeqModel::Pk
            | PeqModel::HpPk
            | PeqModel::HpPkLp
            | PeqModel::LsPk
            | PeqModel::LsPkHs => {
                lower_bounds.extend_from_slice(&[f_low_adjusted, q_lower, gain_lower]);
                upper_bounds.extend_from_slice(&[f_high_adjusted, params.max_q, params.max_db]);
            }
            PeqModel::FreePkFree | PeqModel::Free => {
                // Free slots may take any filter type; inner FreePkFree slots stay peaking.
                let (type_low, type_high) = if model == PeqModel::Free
                    || (model == PeqModel::FreePkFree && (i == 0 || i == params.num_filters - 1))
                {
                    crate::param_utils::filter_type_bounds()
                } else {
                    (0.0, 0.999)
                };
                lower_bounds.extend_from_slice(&[type_low, f_low_adjusted, q_lower, gain_lower]);
                upper_bounds.extend_from_slice(&[
                    type_high,
                    f_high_adjusted,
                    params.max_q,
                    params.max_db,
                ]);
            }
        }
    }

    // First-filter window: from max(20 Hz, min_freq) up to min(120 Hz, min_freq + 20 Hz).
    // Collapse onto the lower edge instead of inverting when min_freq > 120 Hz.
    let low_edge_lower = 20.0_f64.max(params.min_freq).log10();
    let low_edge_upper = 120.0_f64
        .min(params.min_freq + 20.0)
        .log10()
        .max(low_edge_lower);
    // Last-filter window: from max(max_freq - 2 kHz, 5 kHz) up to max_freq.
    // Collapse onto the upper edge instead of inverting when max_freq < 5 kHz.
    let high_edge_upper = log_max_freq;
    let high_edge_lower = (params.max_freq - 2000.0)
        .max(5000.0)
        .log10()
        .min(high_edge_upper);

    // Indexing slot 0..2 below requires at least one filter.
    if params.num_filters > 0 {
        match model {
            PeqModel::HpPk | PeqModel::HpPkLp => {
                lower_bounds[0] = low_edge_lower;
                upper_bounds[0] = low_edge_upper;
                // High-pass: Q in [1.0, 1.5], gain pinned to 0.
                lower_bounds[1] = 1.0;
                upper_bounds[1] = 1.5;
                lower_bounds[2] = 0.0;
                upper_bounds[2] = 0.0;
            }
            PeqModel::LsPk | PeqModel::LsPkHs => {
                lower_bounds[0] = low_edge_lower;
                upper_bounds[0] = low_edge_upper;
                lower_bounds[1] = q_lower;
                upper_bounds[1] = params.max_q;
                lower_bounds[2] = -params.max_db;
                upper_bounds[2] = params.max_db;
            }
            PeqModel::Pk | PeqModel::FreePkFree | PeqModel::Free => {}
        }
    }

    if params.num_filters > 1 && ppf == 3 {
        let last_idx = (params.num_filters - 1) * ppf;
        match model {
            PeqModel::HpPkLp => {
                lower_bounds[last_idx] = high_edge_lower;
                upper_bounds[last_idx] = high_edge_upper;
                // Low-pass: Q in [1.0, 1.5], gain pinned to 0.
                lower_bounds[last_idx + 1] = 1.0;
                upper_bounds[last_idx + 1] = 1.5;
                lower_bounds[last_idx + 2] = 0.0;
                upper_bounds[last_idx + 2] = 0.0;
            }
            PeqModel::LsPkHs => {
                lower_bounds[last_idx] = high_edge_lower;
                upper_bounds[last_idx] = high_edge_upper;
                lower_bounds[last_idx + 1] = q_lower;
                upper_bounds[last_idx + 1] = params.max_q;
                lower_bounds[last_idx + 2] = -params.max_db;
                upper_bounds[last_idx + 2] = params.max_db;
            }
            PeqModel::Pk
            | PeqModel::HpPk
            | PeqModel::LsPk
            | PeqModel::FreePkFree
            | PeqModel::Free => {}
        }
    }

    if !params.quiet {
        log::info!("\n📏 Parameter Bounds (Model: {}):", model);
        log::info!("+----+-------------------+---------------+-----------------+--------+");
        log::info!("| #  | Freq Range (Hz)   | Q Range       | Gain Range (dB) | Type   |");
        log::info!("+----+-------------------+---------------+-----------------+--------+");
        for i in 0..params.num_filters {
            let offset = i * ppf;
            let (freq_idx, q_idx, gain_idx) = if ppf == 3 {
                (offset, offset + 1, offset + 2)
            } else {
                (offset + 1, offset + 2, offset + 3)
            };
            let freq_low_hz = 10f64.powf(lower_bounds[freq_idx]);
            let freq_high_hz = 10f64.powf(upper_bounds[freq_idx]);
            let q_low = lower_bounds[q_idx];
            let q_high = upper_bounds[q_idx];
            let gain_low = lower_bounds[gain_idx];
            let gain_high = upper_bounds[gain_idx];
            let filter_type = match model {
                PeqModel::Pk => "PK",
                PeqModel::HpPk if i == 0 => "HP",
                PeqModel::HpPk => "PK",
                PeqModel::HpPkLp if i == 0 => "HP",
                PeqModel::HpPkLp if i == params.num_filters - 1 => "LP",
                PeqModel::HpPkLp => "PK",
                PeqModel::LsPk if i == 0 => "LS",
                PeqModel::LsPk => "PK",
                PeqModel::LsPkHs if i == 0 => "LS",
                PeqModel::LsPkHs if i == params.num_filters - 1 => "HS",
                PeqModel::LsPkHs => "PK",
                PeqModel::FreePkFree if i == 0 || i == params.num_filters - 1 => "??",
                PeqModel::FreePkFree => "PK",
                PeqModel::Free => "??",
            };
            log::info!(
                "| {:2} | {:7.1} - {:7.1} | {:5.2} - {:5.2} | {:+6.2} - {:+6.2} | {:6} |",
                i + 1,
                freq_low_hz,
                freq_high_hz,
                q_low,
                q_high,
                gain_low,
                gain_high,
                filter_type
            );
        }
        log::info!("+----+-------------------+---------------+-----------------+--------+\n");
    }
    (lower_bounds, upper_bounds)
}
/// Produces the optimizer's starting point inside bounds laid out like [`setup_bounds`].
///
/// For every filter:
/// - the type slot (4-parameter models only) starts at 0, clamped into its bounds;
/// - frequency starts at its lower bound, capped at `log10(params.max_freq)`;
/// - Q starts at the geometric mean of its bounds;
/// - gain starts at ±half of `max(gain upper bound, params.min_db)`, positive on
///   even-indexed filters and negative on odd ones, clamped into its bounds.
///
/// # Panics
/// Panics if the bound slices are too short for `num_filters` filters. `f64::clamp` also
/// panics if any used bound pair has `lower > upper` or contains NaN.
pub fn initial_guess(
    params: &crate::OptimParams,
    lower_bounds: &[f64],
    upper_bounds: &[f64],
) -> Vec<f64> {
    let model = params.peq_model;
    let stride = crate::param_utils::params_per_filter(model);
    let log_max_freq = params.max_freq.log10();
    // 4-parameter models carry a leading filter-type slot before [freq, q, gain].
    let has_type_slot = match model {
        PeqModel::Pk
        | PeqModel::HpPk
        | PeqModel::HpPkLp
        | PeqModel::LsPk
        | PeqModel::LsPkHs => false,
        PeqModel::FreePkFree | PeqModel::Free => true,
    };
    let mut guess = Vec::new();
    for filter in 0..params.num_filters {
        let base = filter * stride;
        let first = if has_type_slot {
            guess.push(0.0_f64.clamp(lower_bounds[base], upper_bounds[base]));
            base + 1
        } else {
            base
        };
        let (fi, qi, gi) = (first, first + 1, first + 2);
        let freq = lower_bounds[fi]
            .min(log_max_freq)
            .clamp(lower_bounds[fi], upper_bounds[fi]);
        let q = (upper_bounds[qi] * lower_bounds[qi])
            .sqrt()
            .clamp(lower_bounds[qi], upper_bounds[qi]);
        let half = if filter % 2 == 0 { 0.5 } else { -0.5 };
        let gain = (half * upper_bounds[gi].max(params.min_db))
            .clamp(lower_bounds[gi], upper_bounds[gi]);
        guess.extend_from_slice(&[freq, q, gain]);
    }
    guess
}
/// Runs [`perform_optimization_with_callback`] with a callback that never stops early.
///
/// # Errors
/// Propagates any optimizer error from [`perform_optimization_with_callback`].
pub fn perform_optimization(
    args: &crate::cli::Args,
    objective_data: &ObjectiveData,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let never_stop = |_: &crate::de::DEIntermediate| crate::de::CallbackAction::Continue;
    perform_optimization_with_callback(args, objective_data, Box::new(never_stop))
}
/// Optimizes a PEQ filter bank for `objective_data` and returns the final parameter vector.
///
/// `OptimParams` are derived from `args`; bounds and the starting point come from
/// [`setup_bounds`] and [`initial_guess`]. The global search runs through
/// `optimize_filters_autoeq_with_callback` with `params.algo`, and `callback` is handed to it.
/// When `params.refine` is set, a second pass runs `optimize_filters_with_algo_override`
/// with `params.local_algo`, starting from the global result.
///
/// # Errors
/// Any optimizer failure is returned as a boxed `std::io::Error` wrapping the optimizer's
/// error. The objective value reported alongside the failure is discarded.
pub fn perform_optimization_with_callback(
    args: &crate::cli::Args,
    objective_data: &ObjectiveData,
    callback: Box<dyn FnMut(&crate::de::DEIntermediate) -> crate::de::CallbackAction + Send>,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let params = crate::OptimParams::from(args);
    let (lower_bounds, upper_bounds) = setup_bounds(&params);
    let mut x = initial_guess(&params, &lower_bounds, &upper_bounds);

    optimize_filters_autoeq_with_callback(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data.clone(),
        &params.algo,
        &params,
        callback,
    )
    .map_err(|(e, _final_value)| std::io::Error::other(e))?;

    if params.refine {
        // Local polish starting from the global optimum found above (x is updated in place).
        optimize_filters_with_algo_override(
            &mut x,
            &lower_bounds,
            &upper_bounds,
            objective_data.clone(),
            &params,
            Some(&params.local_algo),
        )
        .map_err(|(e, _final_value)| std::io::Error::other(e))?;
    }
    Ok(x)
}
/// Snapshot of optimizer progress handed to the user callback of
/// `perform_optimization_with_progress`.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    /// Iteration index reported by the optimizer (`DEIntermediate::iter`).
    pub iteration: usize,
    /// Evaluation budget, taken from `args.maxeval`.
    pub max_iterations: usize,
    /// Objective value of the reported candidate (`DEIntermediate::fun`).
    pub loss: f64,
    /// Value of `crate::loss::speaker_score_loss` for the candidate's filter response; `None`
    /// when the objective carries no `speaker_score_data`.
    pub score: Option<f64>,
    /// Convergence measure reported by the optimizer (`DEIntermediate::convergence`).
    pub convergence: f64,
    /// Raw optimizer parameter vector of the candidate.
    pub params: Vec<f64>,
    /// Biquads decoded from `params` via `x2peq`; may be empty when not requested in the
    /// callback configuration.
    pub biquads: Vec<Biquad>,
    /// Sum of `Biquad::log_result` over all biquads at each configured frequency; may be
    /// empty when not requested. NOTE(review): presumably dB — confirm against `log_result`.
    pub filter_response: Vec<f64>,
}
/// Controls how often, and with what detail, `perform_optimization_with_progress` reports.
#[derive(Debug, Clone)]
pub struct ProgressCallbackConfig {
    /// Minimum number of iterations between two reports; iteration 0 is always reported.
    pub interval: usize,
    /// Whether each update carries the decoded biquads.
    pub include_biquads: bool,
    /// Whether each update carries the summed filter response.
    pub include_filter_response: bool,
    /// Frequencies (Hz) at which the response is evaluated. When empty, a 200-point
    /// log-spaced grid from 20 Hz to 20 kHz is used.
    pub frequencies: Vec<f64>,
}
impl Default for ProgressCallbackConfig {
    /// Report every 25 iterations with biquads and filter response, on the default grid.
    fn default() -> Self {
        let frequencies: Vec<f64> = vec![];
        ProgressCallbackConfig {
            frequencies,
            include_filter_response: true,
            include_biquads: true,
            interval: 25,
        }
    }
}
/// Result of `perform_optimization_with_progress`.
#[derive(Debug, Clone)]
pub struct OptimizationOutput {
    /// Final optimizer parameter vector.
    pub params: Vec<f64>,
    /// `(iteration, objective value)` for every intermediate result the optimizer passed to
    /// the callback, in the order received.
    pub history: Vec<(usize, f64)>,
}
pub fn perform_optimization_with_progress<F>(
args: &crate::cli::Args,
objective_data: &ObjectiveData,
config: ProgressCallbackConfig,
mut callback: F,
) -> Result<OptimizationOutput, Box<dyn Error>>
where
F: FnMut(&ProgressUpdate) -> crate::de::CallbackAction + Send + 'static,
{
use std::sync::{Arc, Mutex};
let frequencies: Vec<f64> = if config.frequencies.is_empty() {
read::create_log_frequency_grid(200, 20.0, 20000.0)
.iter()
.copied()
.collect()
} else {
config.frequencies.clone()
};
let freq_array = Array1::from(frequencies.clone());
let speaker_score_data = objective_data.speaker_score_data.clone();
let sample_rate = args.sample_rate;
let peq_model = args.peq_model;
let maxeval = args.maxeval;
let last_reported = Arc::new(Mutex::new(0usize));
let history = Arc::new(Mutex::new(Vec::new()));
let last_reported_clone = Arc::clone(&last_reported);
let history_clone = Arc::clone(&history);
let freq_array_clone = freq_array.clone();
let frequencies_clone = frequencies.clone();
let de_callback = move |intermediate: &crate::de::DEIntermediate| -> crate::de::CallbackAction {
{
let mut hist = history_clone.lock().unwrap();
hist.push((intermediate.iter, intermediate.fun));
}
let mut last = last_reported_clone.lock().unwrap();
if intermediate.iter == 0 || intermediate.iter.saturating_sub(*last) >= config.interval {
*last = intermediate.iter;
let biquads: Vec<Biquad> = if config.include_biquads {
x2peq(&intermediate.x.to_vec(), sample_rate, peq_model)
.into_iter()
.map(|(_, b)| b)
.collect()
} else {
Vec::new()
};
let filter_response: Vec<f64> = if config.include_filter_response && !biquads.is_empty()
{
frequencies_clone
.iter()
.map(|&f| biquads.iter().map(|b| b.log_result(f)).sum())
.collect()
} else {
Vec::new()
};
let score = speaker_score_data.as_ref().map(|sd| {
let peq_response = if !filter_response.is_empty() {
Array1::from(filter_response.clone())
} else {
let bs = x2peq(&intermediate.x.to_vec(), sample_rate, peq_model);
let resp: Vec<f64> = frequencies_clone
.iter()
.map(|&f| bs.iter().map(|(_, b)| b.log_result(f)).sum())
.collect();
Array1::from(resp)
};
crate::loss::speaker_score_loss(sd, &freq_array_clone, &peq_response)
});
let update = ProgressUpdate {
iteration: intermediate.iter,
max_iterations: maxeval,
loss: intermediate.fun,
score,
convergence: intermediate.convergence,
params: intermediate.x.to_vec(),
biquads,
filter_response,
};
callback(&update)
} else {
crate::de::CallbackAction::Continue
}
};
let params = perform_optimization_with_callback(args, objective_data, Box::new(de_callback))?;
let final_history = Arc::try_unwrap(history)
.map(|m| m.into_inner().unwrap())
.unwrap_or_default();
Ok(OptimizationOutput {
params,
history: final_history,
})
}
pub fn setup_multisub_objective_data(
params: &crate::OptimParams,
drivers_data: DriversLossData,
) -> ObjectiveData {
ObjectiveData {
freqs: drivers_data.freq_grid.clone(),
target: Array1::zeros(drivers_data.freq_grid.len()),
deviation: Array1::zeros(drivers_data.freq_grid.len()),
srate: params.sample_rate,
min_spacing_oct: 0.0,
spacing_weight: 0.0,
max_db: params.max_db,
min_db: params.min_db,
min_freq: params.min_freq,
max_freq: params.max_freq,
peq_model: params.peq_model,
loss_type: crate::LossType::MultiSubFlat,
speaker_score_data: None,
headphone_score_data: None,
input_curve: None,
drivers_data: Some(drivers_data),
fixed_crossover_freqs: None,
penalty_w_ceiling: 0.0,
penalty_w_spacing: 0.0,
penalty_w_mingain: 0.0,
integrality: None,
multi_objective: None,
smooth: false, smooth_n: 1,
max_boost_envelope: None,
min_cut_envelope: None,
epa_config: None,
detected_problems: Vec::new(),
null_suppression: None,
}
}
/// Builds optimizer bounds for a multi-subwoofer optimization with `n_drivers` subs.
///
/// Returns `2 * n_drivers` bounds. The first `n_drivers` are gains in `±params.max_db`. The
/// next `n_drivers` are a second per-sub parameter in `[0.0, 20.0]` (presumably a delay —
/// TODO confirm the unit).
pub fn setup_multisub_bounds(params: &crate::OptimParams, n_drivers: usize) -> (Vec<f64>, Vec<f64>) {
    let gain_ranges = std::iter::repeat((-params.max_db, params.max_db)).take(n_drivers);
    let second_ranges = std::iter::repeat((0.0, 20.0)).take(n_drivers);
    gain_ranges.chain(second_ranges).unzip()
}
/// Starting point for multi-subwoofer parameter vectors: `2 * n_drivers` zeros.
pub fn multisub_initial_guess(n_drivers: usize) -> Vec<f64> {
    std::iter::repeat(0.0).take(2 * n_drivers).collect()
}