use crate::AutoeqError;
use crate::Cea2034Data;
use crate::Curve;
use crate::HeadphoneLossData;
use crate::PeqModel;
use crate::SpeakerLossData;
use crate::iir::Biquad;
use crate::loss::DriversLossData;
use crate::optim::{ObjectiveData, optimize_filters_with_algo_override};
use crate::optim_de::optimize_filters_autoeq_with_callback;
use crate::read;
use crate::x2peq;
use ndarray::Array1;
use std::collections::HashMap;
use std::error::Error;
use std::path::PathBuf;
pub mod resume;
/// Loads the curve to equalise and, when available, the CEA2034 spin curves.
///
/// When `--speaker`, `--version` and `--measurement` are all set, the plot data
/// is fetched through `read::fetch_measurement_plot_data`; otherwise the curve
/// is read from the CSV file given by `--curve`.
///
/// # Returns
/// The input curve and, for `"CEA2034"` or `"Estimated In-Room Response"`
/// measurements, the map of CEA2034 curves keyed by name (`None` otherwise).
///
/// # Errors
/// Fails when neither input source is fully specified, when fetching, parsing
/// or reading the data fails, or when the PIR curve is missing from the
/// CEA2034 data.
pub async fn load_input_curve(
    args: &crate::cli::Args,
) -> Result<(Curve, Option<HashMap<String, Curve>>), Box<dyn Error>> {
    const PIR: &str = "Estimated In-Room Response";
    const SPIN: &str = "CEA2034";

    // Without a complete speaker/version/measurement triple, fall back to CSV.
    let (Some(speaker), Some(version), Some(measurement)) =
        (&args.speaker, &args.version, &args.measurement)
    else {
        let csv_path = args.curve.as_ref().ok_or(
            "Either --curve or all of --speaker, --version, and --measurement must be provided",
        )?;
        return Ok((read::read_curve_from_csv(csv_path)?, None));
    };

    // The PIR curve lives inside the CEA2034 plot, so that plot is fetched instead.
    if measurement == PIR {
        let plot_data = read::fetch_measurement_plot_data(speaker, version, SPIN).await?;
        let spin = read::extract_cea2034_curves_original(&plot_data, SPIN)?;
        let pir = spin
            .get(PIR)
            .ok_or("PIR curve not found in CEA2034 data")?
            .clone();
        return Ok((pir, Some(spin)));
    }

    let plot_data = read::fetch_measurement_plot_data(speaker, version, measurement).await?;
    let curve = read::extract_curve_by_name(&plot_data, measurement, &args.curve_name)?;
    let spin = if measurement == SPIN {
        Some(read::extract_cea2034_curves_original(&plot_data, SPIN)?)
    } else {
        None
    };
    Ok((curve, spin))
}
pub fn build_target_curve(
args: &crate::cli::Args,
freqs: &Array1<f64>,
input_curve: &Curve,
) -> Result<Curve, AutoeqError> {
if let Some(ref target_path) = args.target {
crate::qa_println!(
args,
"[RUST DEBUG] Loading target curve from path: {}",
target_path.display()
);
crate::qa_println!(
args,
"[RUST DEBUG] Current working directory: {:?}",
std::env::current_dir()
);
let target_curve =
read::read_curve_from_csv(target_path).map_err(|e| AutoeqError::TargetCurveLoad {
path: target_path.display().to_string(),
message: e.to_string(),
})?;
Ok(read::normalize_and_interpolate_response(
freqs,
&target_curve,
))
} else {
match args.curve_name.as_str() {
"Listening Window" => {
let log_f_min = 1000.0_f64.log10();
let log_f_max = 20000.0_f64.log10();
let denom = log_f_max - log_f_min;
let spl = Array1::from_shape_fn(freqs.len(), |i| {
let f_hz = freqs[i].max(1e-12);
let fl = f_hz.log10();
if fl < log_f_min {
0.0
} else if fl >= log_f_max {
-0.5
} else {
let t = (fl - log_f_min) / denom;
-0.5 * t
}
});
Ok(Curve {
freq: freqs.clone(),
spl,
phase: None,
})
}
"Sound Power" | "Early Reflections" | "Estimated In-Room Response" => {
let slope =
crate::loss::curve_slope_per_octave_in_range(input_curve, 100.0, 10000.0)
.unwrap_or(-1.2)
- 0.2;
let lo = 100.0_f64;
let hi = 20000.0_f64;
let hi_val = slope * (hi / lo).log2();
let spl = Array1::from_shape_fn(freqs.len(), |i| {
let f = freqs[i].max(1e-12);
if f < lo {
0.0
} else if f >= hi {
hi_val
} else {
slope * (f / lo).log2()
}
});
Ok(Curve {
freq: freqs.clone(),
spl,
phase: None,
})
}
_ => {
let spl = Array1::zeros(freqs.len());
Ok(Curve {
freq: freqs.clone(),
spl,
phase: None,
})
}
}
}
}
/// Assembles the `ObjectiveData` used by the filter optimiser.
///
/// When `spin_data` is present the speaker-score data is built from it and the
/// headphone data and raw input curve are left out; without spin data it is the
/// other way round.
///
/// # Returns
/// The objective data and a flag telling whether CEA2034 data is in use.
///
/// # Errors
/// Propagates the failure of `SpeakerLossData::try_new`.
pub fn setup_objective_data(
    args: &crate::cli::Args,
    input_curve: &Curve,
    target_curve: &Curve,
    deviation_curve: &Curve,
    spin_data: &Option<HashMap<String, Curve>>,
) -> Result<(ObjectiveData, bool), AutoeqError> {
    let use_cea = spin_data.is_some();

    let speaker_score_data = match spin_data {
        Some(spin) => Some(SpeakerLossData::try_new(spin)?),
        None => None,
    };
    let headphone_score_data =
        (!use_cea).then(|| HeadphoneLossData::new(args.smooth, args.smooth_n));
    let raw_input = (!use_cea).then(|| input_curve.clone());

    let objective_data = ObjectiveData {
        freqs: input_curve.freq.clone(),
        target: target_curve.spl.clone(),
        deviation: deviation_curve.spl.clone(),
        srate: args.sample_rate,
        min_spacing_oct: args.min_spacing_oct,
        spacing_weight: args.spacing_weight,
        max_db: args.max_db,
        min_db: args.min_db,
        min_freq: args.min_freq,
        max_freq: args.max_freq,
        peq_model: args.effective_peq_model(),
        loss_type: args.loss,
        speaker_score_data,
        headphone_score_data,
        input_curve: raw_input,
        drivers_data: None,
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: args.smooth,
        smooth_n: args.smooth_n,
        max_boost_envelope: None,
        min_cut_envelope: None,
    };

    Ok((objective_data, use_cea))
}
/// Resamples every CEA2034 curve onto `standard_freq`.
///
/// The returned `curves` map holds the five named response curves ("On Axis",
/// "Listening Window", "Early Reflections", "Sound Power", "Estimated In-Room
/// Response"); the two directivity-index curves are only available as fields.
fn interpolate_cea2034_data(spin_data: &Cea2034Data, standard_freq: &Array1<f64>) -> Cea2034Data {
    let resample = |curve: &Curve| read::interpolate_response(standard_freq, curve);

    let on_axis = resample(&spin_data.on_axis);
    let listening_window = resample(&spin_data.listening_window);
    let early_reflections = resample(&spin_data.early_reflections);
    let sound_power = resample(&spin_data.sound_power);
    let estimated_in_room = resample(&spin_data.estimated_in_room);
    let er_di = resample(&spin_data.er_di);
    let sp_di = resample(&spin_data.sp_di);

    let curves: HashMap<String, Curve> = [
        ("On Axis", &on_axis),
        ("Listening Window", &listening_window),
        ("Early Reflections", &early_reflections),
        ("Sound Power", &sound_power),
        ("Estimated In-Room Response", &estimated_in_room),
    ]
    .into_iter()
    .map(|(name, curve)| (name.to_string(), curve.clone()))
    .collect();

    Cea2034Data {
        on_axis,
        listening_window,
        early_reflections,
        sound_power,
        estimated_in_room,
        er_di,
        sp_di,
        curves,
    }
}
/// Builds the `ObjectiveData` for multi-driver crossover optimisation.
///
/// The frequency grid comes from `drivers_data`; target and deviation are all
/// zeros on that grid, spacing penalties and smoothing are disabled, and the
/// loss type is fixed to `LossType::DriversFlat`.
pub fn setup_drivers_objective_data(
    args: &crate::cli::Args,
    drivers_data: DriversLossData,
) -> ObjectiveData {
    let grid = drivers_data.freq_grid.clone();
    let grid_len = drivers_data.freq_grid.len();

    ObjectiveData {
        freqs: grid,
        target: Array1::zeros(grid_len),
        deviation: Array1::zeros(grid_len),
        srate: args.sample_rate,
        min_spacing_oct: 0.0,
        spacing_weight: 0.0,
        max_db: args.max_db,
        min_db: args.min_db,
        min_freq: args.min_freq,
        max_freq: args.max_freq,
        peq_model: args.effective_peq_model(),
        loss_type: crate::LossType::DriversFlat,
        speaker_score_data: None,
        headphone_score_data: None,
        input_curve: None,
        drivers_data: Some(drivers_data),
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: false,
        smooth_n: 1,
        max_boost_envelope: None,
        min_cut_envelope: None,
    }
}
/// Builds the lower/upper bounds for the free-crossover driver optimisation.
///
/// Parameter layout (matching how `optimize_drivers_crossover` slices the
/// solution): `n` gains, then `n` delays, then `n - 1` crossover frequencies
/// expressed as `log10(Hz)`.
///
/// * gains are bounded by `±args.max_db`;
/// * delays are bounded by `±20.0`;
/// * each crossover may move one octave either side of the geometric mean of
///   the two adjacent drivers' `mean_freq()`, clamped to
///   `[args.min_freq, args.max_freq]`.
///
/// An empty driver list yields two empty vectors instead of underflowing
/// `n_drivers - 1`.
pub fn setup_drivers_bounds(
    args: &crate::cli::Args,
    drivers_data: &DriversLossData,
) -> (Vec<f64>, Vec<f64>) {
    let n_drivers = drivers_data.drivers.len();
    // `n - 1` crossovers, but never underflow when there are no drivers at all.
    let n_crossovers = n_drivers.saturating_sub(1);
    let n_params = n_drivers * 2 + n_crossovers;

    let mut lower_bounds = Vec::with_capacity(n_params);
    let mut upper_bounds = Vec::with_capacity(n_params);

    // Per-driver gains.
    for _ in 0..n_drivers {
        lower_bounds.push(-args.max_db);
        upper_bounds.push(args.max_db);
    }
    // Per-driver delays.
    for _ in 0..n_drivers {
        lower_bounds.push(-20.0);
        upper_bounds.push(20.0);
    }
    // One crossover between each pair of adjacent drivers.
    for i in 0..n_crossovers {
        let mean_low = drivers_data.drivers[i].mean_freq();
        let mean_high = drivers_data.drivers[i + 1].mean_freq();
        let geometric_center = (mean_low * mean_high).sqrt();

        let xover_max = (geometric_center * 2.0).min(args.max_freq).log10();
        // Keep at least 0.1 decade between the bounds even after clamping.
        let xover_min = (geometric_center * 0.5)
            .max(args.min_freq)
            .log10()
            .min(xover_max - 0.1);

        lower_bounds.push(xover_min);
        upper_bounds.push(xover_max);
    }

    (lower_bounds, upper_bounds)
}
/// Builds the starting point for the free-crossover driver optimisation.
///
/// The first `2 * n_drivers` entries (gains, then delays) start at zero; every
/// remaining slot up to `lower_bounds.len()` starts at the midpoint of its
/// bounds.
///
/// # Panics
/// Panics if a trailing index is read past the end of `upper_bounds`.
pub fn drivers_initial_guess(
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    n_drivers: usize,
) -> Vec<f64> {
    let first_xover = 2 * n_drivers;
    let mut guess = vec![0.0; first_xover];
    guess.extend(
        (first_xover..lower_bounds.len()).map(|idx| (lower_bounds[idx] + upper_bounds[idx]) / 2.0),
    );
    guess
}
/// Builds the bounds for driver optimisation when crossover frequencies are
/// fixed: `n` gains bounded by `±args.max_db` followed by `n` delays bounded
/// by `±20.0`.
pub fn setup_drivers_bounds_fixed_freqs(
    args: &crate::cli::Args,
    drivers_data: &DriversLossData,
) -> (Vec<f64>, Vec<f64>) {
    let n_drivers = drivers_data.drivers.len();
    let total = n_drivers * 2;

    // Gains first, then pad the tail with the delay limits.
    let mut lower_bounds = vec![-args.max_db; n_drivers];
    lower_bounds.resize(total, -20.0);

    let mut upper_bounds = vec![args.max_db; n_drivers];
    upper_bounds.resize(total, 20.0);

    (lower_bounds, upper_bounds)
}
/// Starting point for the fixed-crossover driver optimisation: `n_drivers`
/// zero gains followed by `n_drivers` zero delays. The bounds are accepted for
/// signature symmetry with `drivers_initial_guess` but are not read.
pub fn drivers_initial_guess_fixed_freqs(
    _lower_bounds: &[f64],
    _upper_bounds: &[f64],
    n_drivers: usize,
) -> Vec<f64> {
    vec![0.0; n_drivers * 2]
}
/// Builds the per-parameter lower and upper bounds for the PEQ optimiser.
///
/// Each filter contributes `ppf` parameters. With `ppf == 3` the layout is
/// `[freq, q, gain]`; otherwise it is `[type, freq, q, gain]`. Frequency
/// bounds are stored as `log10(Hz)`.
///
/// After the generic per-filter bounds are laid out, the first filter (and,
/// for `HpPkLp` / `LsPkHs` with more than one filter, the last filter) is
/// overwritten with model-specific bounds. Unless `args.qa` is set, the final
/// bounds are logged as a table.
///
/// NOTE(review): with `num_filters == 0` and an `Hp*`/`Ls*` model the
/// `lower_bounds[0]` writes below index an empty vector — confirm callers
/// never pass zero filters for those models.
pub fn setup_bounds(args: &crate::cli::Args) -> (Vec<f64>, Vec<f64>) {
use crate::cli::PeqModel;
let model = args.effective_peq_model();
let ppf = crate::param_utils::params_per_filter(model);
let num_params = args.num_filters * ppf;
let mut lower_bounds = Vec::with_capacity(num_params);
let mut upper_bounds = Vec::with_capacity(num_params);
// `spacing` scales each filter's frequency window in units of `range`;
// the gain floor is three times the allowed boost, negated.
let spacing = 1.0; let gain_lower = -3.0 * args.max_db;
// Q never goes below 0.1 regardless of `min_q`.
let q_lower = args.min_q.max(0.1);
// Width (in log10 Hz) of one filter's share of the [min_freq, max_freq] span.
let range = (args.max_freq.log10() - args.min_freq.log10()) / (args.num_filters as f64);
for i in 0..args.num_filters {
// Window of +/- `spacing * range` around the i-th slot, clamped to the global range.
let f_center = args.min_freq.log10() + (i as f64) * range;
let f_low = (f_center - spacing * range).max(args.min_freq.log10());
let f_high = (f_center + spacing * range).min(args.max_freq.log10());
// Keep the lower frequency bound non-decreasing from one filter to the next.
let f_low_adjusted = if i > 0 {
let prev_freq_idx = if ppf == 3 {
(i - 1) * 3
} else {
(i - 1) * 4 + 1
};
f_low.max(lower_bounds[prev_freq_idx])
} else {
f_low
};
// Likewise the upper frequency bound never drops below the previous filter's.
let f_high_adjusted = if i > 0 {
let prev_freq_idx = if ppf == 3 {
(i - 1) * 3
} else {
(i - 1) * 4 + 1
};
f_high.max(upper_bounds[prev_freq_idx])
} else {
f_high
};
match model {
// Three-parameter models: [freq, q, gain].
PeqModel::Pk
| PeqModel::HpPk
| PeqModel::HpPkLp
| PeqModel::LsPk
| PeqModel::LsPkHs => {
lower_bounds.extend_from_slice(&[f_low_adjusted, q_lower, gain_lower]);
upper_bounds.extend_from_slice(&[f_high_adjusted, args.max_q, args.max_db]);
}
// Four-parameter models: [type, freq, q, gain].
PeqModel::FreePkFree | PeqModel::Free => {
// `Free` lets every filter use the full type range; `FreePkFree` only
// does so for the first and last filter and restricts the others to
// (0.0, 0.999) — presumably the peak type; verify against `param_utils`.
let (type_low, type_high) = if model == PeqModel::Free
|| (model == PeqModel::FreePkFree && (i == 0 || i == args.num_filters - 1))
{
crate::param_utils::filter_type_bounds()
} else {
(0.0, 0.999) };
lower_bounds.extend_from_slice(&[type_low, f_low_adjusted, q_lower, gain_lower]);
upper_bounds.extend_from_slice(&[
type_high,
f_high_adjusted,
args.max_q,
args.max_db,
]);
}
}
}
// Model-specific override of the first filter.
match model {
PeqModel::HpPk | PeqModel::HpPkLp => {
// Frequency between max(20, min_freq) and min(120, min_freq + 20) Hz,
// Q in [1.0, 1.5], gain pinned to 0.
lower_bounds[0] = 20.0_f64.max(args.min_freq).log10();
upper_bounds[0] = 120.0_f64.min(args.min_freq + 20.0).log10();
lower_bounds[1] = 1.0;
upper_bounds[1] = 1.5; lower_bounds[2] = 0.0;
upper_bounds[2] = 0.0;
}
PeqModel::LsPk | PeqModel::LsPkHs => {
// Same frequency window, but Q and gain use the user limits
// (gain symmetric around 0).
lower_bounds[0] = 20.0_f64.max(args.min_freq).log10();
upper_bounds[0] = 120.0_f64.min(args.min_freq + 20.0).log10();
lower_bounds[1] = args.min_q;
upper_bounds[1] = args.max_q;
lower_bounds[2] = -args.max_db;
upper_bounds[2] = args.max_db;
}
_ => {}
}
// Model-specific override of the last filter (only when it is not also the first).
if args.num_filters > 1 {
if matches!(model, PeqModel::HpPkLp) {
let last_idx = (args.num_filters - 1) * ppf;
if ppf == 3 {
// Frequency between max(max_freq - 2000, 5000) Hz and max_freq,
// Q in [1.0, 1.5], gain pinned to 0.
lower_bounds[last_idx] = (args.max_freq - 2000.0).max(5000.0).log10();
upper_bounds[last_idx] = args.max_freq.log10();
lower_bounds[last_idx + 1] = 1.0;
upper_bounds[last_idx + 1] = 1.5;
lower_bounds[last_idx + 2] = 0.0;
upper_bounds[last_idx + 2] = 0.0;
}
}
if matches!(model, PeqModel::LsPkHs) {
let last_idx = (args.num_filters - 1) * ppf;
if ppf == 3 {
// Same frequency window, with the user Q limits and symmetric gain.
lower_bounds[last_idx] = (args.max_freq - 2000.0).max(5000.0).log10();
upper_bounds[last_idx] = args.max_freq.log10();
lower_bounds[last_idx + 1] = args.min_q;
upper_bounds[last_idx + 1] = args.max_q;
lower_bounds[last_idx + 2] = -args.max_db;
upper_bounds[last_idx + 2] = args.max_db;
}
}
}
// Log the final bounds as a table, unless `args.qa` is set.
if args.qa.is_none() {
log::info!("\n📏 Parameter Bounds (Model: {}):", model);
log::info!("+----+-------------------+---------------+-----------------+--------+");
log::info!("| # | Freq Range (Hz) | Q Range | Gain Range (dB) | Type |");
log::info!("+----+-------------------+---------------+-----------------+--------+");
for i in 0..args.num_filters {
let offset = i * ppf;
// Four-parameter layouts carry the filter type in the first slot.
let (freq_idx, q_idx, gain_idx) = if ppf == 3 {
(offset, offset + 1, offset + 2)
} else {
(offset + 1, offset + 2, offset + 3)
};
// Bounds are stored as log10(Hz); convert back for display.
let freq_low_hz = 10f64.powf(lower_bounds[freq_idx]);
let freq_high_hz = 10f64.powf(upper_bounds[freq_idx]);
let q_low = lower_bounds[q_idx];
let q_high = upper_bounds[q_idx];
let gain_low = lower_bounds[gain_idx];
let gain_high = upper_bounds[gain_idx];
// Label shown in the table; "??" marks filters whose type is left to the optimiser.
let filter_type = match model {
PeqModel::Pk => "PK",
PeqModel::HpPk if i == 0 => "HP",
PeqModel::HpPk => "PK",
PeqModel::HpPkLp if i == 0 => "HP",
PeqModel::HpPkLp if i == args.num_filters - 1 => "LP",
PeqModel::HpPkLp => "PK",
PeqModel::LsPk if i == 0 => "LS",
PeqModel::LsPk => "PK",
PeqModel::LsPkHs if i == 0 => "LS",
PeqModel::LsPkHs if i == args.num_filters - 1 => "HS",
PeqModel::LsPkHs => "PK",
PeqModel::FreePkFree if i == 0 || i == args.num_filters - 1 => "??",
PeqModel::FreePkFree => "PK",
PeqModel::Free => "??",
};
log::info!(
"| {:2} | {:7.1} - {:7.1} | {:5.2} - {:5.2} | {:+6.2} - {:+6.2} | {:6} |",
i + 1,
freq_low_hz,
freq_high_hz,
q_low,
q_high,
gain_low,
gain_high,
filter_type
);
}
log::info!("+----+-------------------+---------------+-----------------+--------+\n");
}
(lower_bounds, upper_bounds)
}
/// Produces the optimiser's starting point from the parameter bounds.
///
/// For each filter:
/// * frequency starts at its lower bound (capped at `log10(max_freq)`);
/// * Q starts at the geometric mean of its bounds;
/// * gain alternates between `+0.5` and `-0.5` times
///   `max(upper gain bound, args.min_db)`;
/// * for four-parameter models the filter type starts at `0.0`.
///
/// Every value is clamped into its own bounds.
///
/// # Panics
/// Panics if the bound slices are shorter than `num_filters * ppf`, or if a
/// lower bound exceeds its upper bound (`f64::clamp`).
pub fn initial_guess(
    args: &crate::cli::Args,
    lower_bounds: &[f64],
    upper_bounds: &[f64],
) -> Vec<f64> {
    let model = args.effective_peq_model();
    let ppf = crate::param_utils::params_per_filter(model);
    let log_max_freq = args.max_freq.log10();

    // Seeds the [freq, q, gain] triple starting at index `base` for filter `filter_idx`.
    let seed_triple = |base: usize, filter_idx: usize| -> [f64; 3] {
        let (f_idx, q_idx, g_idx) = (base, base + 1, base + 2);

        let freq = lower_bounds[f_idx]
            .min(log_max_freq)
            .clamp(lower_bounds[f_idx], upper_bounds[f_idx]);

        let q = (upper_bounds[q_idx] * lower_bounds[q_idx])
            .sqrt()
            .clamp(lower_bounds[q_idx], upper_bounds[q_idx]);

        // Alternate boost/cut so neighbouring filters start on opposite sides.
        let sign = if filter_idx % 2 == 0 { 0.5 } else { -0.5 };
        let gain = (sign * upper_bounds[g_idx].max(args.min_db))
            .clamp(lower_bounds[g_idx], upper_bounds[g_idx]);

        [freq, q, gain]
    };

    let mut x = Vec::new();
    for i in 0..args.num_filters {
        let offset = i * ppf;
        match model {
            PeqModel::Pk
            | PeqModel::HpPk
            | PeqModel::HpPkLp
            | PeqModel::LsPk
            | PeqModel::LsPkHs => {
                x.extend_from_slice(&seed_triple(offset, i));
            }
            PeqModel::FreePkFree | PeqModel::Free => {
                // The filter type occupies the first slot, ahead of the triple.
                let filter_type = 0.0_f64.clamp(lower_bounds[offset], upper_bounds[offset]);
                let [freq, q, gain] = seed_triple(offset + 1, i);
                x.extend_from_slice(&[filter_type, freq, q, gain]);
            }
        }
    }
    x
}
/// Runs the optimisation with a callback that never requests a stop.
///
/// Thin wrapper around [`perform_optimization_with_callback`]; returns the
/// optimised parameter vector.
pub fn perform_optimization(
    args: &crate::cli::Args,
    objective_data: &ObjectiveData,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let always_continue: Box<
        dyn FnMut(&crate::de::DEIntermediate) -> crate::de::CallbackAction + Send,
    > = Box::new(|_: &crate::de::DEIntermediate| crate::de::CallbackAction::Continue);

    perform_optimization_with_callback(args, objective_data, always_continue)
}
/// Runs the global optimisation and, when `args.refine` is set, a local
/// refinement pass using `args.local_algo`.
///
/// Bounds and the starting point come from [`setup_bounds`] and
/// [`initial_guess`]; `callback` is handed to the global optimiser.
///
/// # Errors
/// An optimiser failure (global or local) is wrapped in a `std::io::Error`.
pub fn perform_optimization_with_callback(
    args: &crate::cli::Args,
    objective_data: &ObjectiveData,
    callback: Box<dyn FnMut(&crate::de::DEIntermediate) -> crate::de::CallbackAction + Send>,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let (lower_bounds, upper_bounds) = setup_bounds(args);
    let mut x = initial_guess(args, &lower_bounds, &upper_bounds);

    // Global pass: only the error matters here, `x` is updated in place.
    optimize_filters_autoeq_with_callback(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data.clone(),
        &args.algo,
        args,
        callback,
    )
    .map_err(|(e, _final_value)| std::io::Error::other(e))?;

    // Optional local polish starting from the global result.
    if args.refine {
        optimize_filters_with_algo_override(
            &mut x,
            &lower_bounds,
            &upper_bounds,
            objective_data.clone(),
            args,
            Some(&args.local_algo),
        )
        .map_err(|(e, _final_value)| std::io::Error::other(e))?;
    }

    Ok(x)
}
/// Snapshot of the optimiser state passed to a progress callback.
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
/// Iteration reported by the optimiser for this snapshot.
pub iteration: usize,
/// Evaluation budget, copied from `args.maxeval`.
pub max_iterations: usize,
/// Objective value reported by the optimiser at this iteration.
pub loss: f64,
/// Speaker score loss for the current parameters; `None` when no speaker
/// score data is available.
pub score: Option<f64>,
/// Convergence measure reported by the optimiser.
pub convergence: f64,
/// Current parameter vector.
pub params: Vec<f64>,
/// Filters decoded from `params`; empty when biquads were not requested.
pub biquads: Vec<Biquad>,
/// Sum of the filters' `log_result` at each configured frequency; empty when
/// not requested or when `biquads` is empty.
pub filter_response: Vec<f64>,
}
/// Controls how often progress is reported and what each report contains.
#[derive(Debug, Clone)]
pub struct ProgressCallbackConfig {
    /// Minimum number of iterations between two reports.
    pub interval: usize,
    /// Whether each report carries the decoded biquads.
    pub include_biquads: bool,
    /// Whether each report carries the combined filter response.
    pub include_filter_response: bool,
    /// Frequencies at which the filter response is evaluated; when empty, a
    /// 200-point logarithmic grid from 20 Hz to 20 kHz is used instead.
    pub frequencies: Vec<f64>,
}

impl Default for ProgressCallbackConfig {
    /// Reports every 25 iterations with biquads and filter response included,
    /// evaluated on the default frequency grid.
    fn default() -> Self {
        ProgressCallbackConfig {
            frequencies: vec![],
            include_filter_response: true,
            include_biquads: true,
            interval: 25,
        }
    }
}
/// Result of an optimisation run that recorded its progress.
#[derive(Debug, Clone)]
pub struct OptimizationOutput {
/// Optimised parameter vector.
pub params: Vec<f64>,
/// One `(iteration, objective value)` pair per optimiser callback invocation.
pub history: Vec<(usize, f64)>,
}
/// Runs the optimisation while recording the loss history and periodically
/// reporting progress through `callback`.
///
/// Every optimiser callback appends `(iteration, loss)` to the history. A
/// [`ProgressUpdate`] is sent on iteration 0 and then whenever at least
/// `config.interval` iterations have passed since the last report. The value
/// returned by `callback` is forwarded to the optimiser; between reports the
/// optimiser is told to continue.
///
/// Parameters are decoded with `args.effective_peq_model()`, the same model
/// used to lay out the bounds and the initial guess.
///
/// # Errors
/// Propagates any failure of [`perform_optimization_with_callback`].
///
/// # Panics
/// Panics if the history mutex is poisoned.
pub fn perform_optimization_with_progress<F>(
    args: &crate::cli::Args,
    objective_data: &ObjectiveData,
    config: ProgressCallbackConfig,
    mut callback: F,
) -> Result<OptimizationOutput, Box<dyn Error>>
where
    F: FnMut(&ProgressUpdate) -> crate::de::CallbackAction + Send + 'static,
{
    use std::sync::{Arc, Mutex};

    // Fall back to the standard grid when the caller did not supply frequencies.
    let frequencies: Vec<f64> = if config.frequencies.is_empty() {
        read::create_log_frequency_grid(200, 20.0, 20000.0)
            .iter()
            .copied()
            .collect()
    } else {
        config.frequencies.clone()
    };
    let freq_array = Array1::from(frequencies.clone());
    let speaker_score_data = objective_data.speaker_score_data.clone();
    let sample_rate = args.sample_rate;
    // Must match the model used by `setup_bounds` / `initial_guess`, otherwise
    // the parameter vector would be decoded with the wrong layout.
    let peq_model = args.effective_peq_model();
    let maxeval = args.maxeval;

    let history: Arc<Mutex<Vec<(usize, f64)>>> = Arc::new(Mutex::new(Vec::new()));
    let history_for_callback = Arc::clone(&history);
    // Owned by the closure: no lock is held while the user callback runs.
    let mut last_reported = 0usize;

    let de_callback = move |intermediate: &crate::de::DEIntermediate| -> crate::de::CallbackAction {
        history_for_callback
            .lock()
            .expect("optimization history mutex poisoned")
            .push((intermediate.iter, intermediate.fun));

        let report_due = intermediate.iter == 0
            || intermediate.iter.saturating_sub(last_reported) >= config.interval;
        if !report_due {
            return crate::de::CallbackAction::Continue;
        }
        last_reported = intermediate.iter;

        let params = intermediate.x.to_vec();
        let decode = || -> Vec<Biquad> {
            x2peq(&params, sample_rate, peq_model)
                .into_iter()
                .map(|(_, b)| b)
                .collect()
        };
        let summed_response = |filters: &[Biquad]| -> Vec<f64> {
            frequencies
                .iter()
                .map(|&f| filters.iter().map(|b| b.log_result(f)).sum::<f64>())
                .collect()
        };

        let biquads: Vec<Biquad> = if config.include_biquads {
            decode()
        } else {
            Vec::new()
        };
        let filter_response: Vec<f64> = if config.include_filter_response && !biquads.is_empty() {
            summed_response(&biquads)
        } else {
            Vec::new()
        };

        // The score needs a response even when the report itself omits it.
        let score = speaker_score_data.as_ref().map(|sd| {
            let peq_response = if filter_response.is_empty() {
                Array1::from(summed_response(&decode()))
            } else {
                Array1::from(filter_response.clone())
            };
            crate::loss::speaker_score_loss(sd, &freq_array, &peq_response)
        });

        let update = ProgressUpdate {
            iteration: intermediate.iter,
            max_iterations: maxeval,
            loss: intermediate.fun,
            score,
            convergence: intermediate.convergence,
            params,
            biquads,
            filter_response,
        };
        callback(&update)
    };

    let params = perform_optimization_with_callback(args, objective_data, Box::new(de_callback))?;

    // Normally the callback (and its Arc clone) is gone by now; if something
    // still holds it, take a snapshot rather than silently dropping the history.
    let final_history = match Arc::try_unwrap(history) {
        Ok(mutex) => mutex
            .into_inner()
            .expect("optimization history mutex poisoned"),
        Err(shared) => {
            let snapshot = shared
                .lock()
                .expect("optimization history mutex poisoned")
                .clone();
            snapshot
        }
    };

    Ok(OptimizationOutput {
        params,
        history: final_history,
    })
}
/// Curves derived from an optimisation result, ready for plotting.
#[derive(Debug, Clone)]
pub struct VisualizationCurves {
/// Frequencies at which the filter responses were evaluated.
pub frequencies: Vec<f64>,
/// SPL values of the input curve.
pub input_curve: Vec<f64>,
/// SPL values of the target curve.
pub target_curve: Vec<f64>,
/// Target minus input, element by element.
pub deviation_curve: Vec<f64>,
/// Sum of all filters' `log_result` at each frequency.
pub filter_response: Vec<f64>,
/// Deviation minus filter response, element by element.
pub error_curve: Vec<f64>,
/// Input plus filter response, element by element.
pub corrected_curve: Vec<f64>,
/// One response vector per filter, in filter order.
pub individual_filter_responses: Vec<Vec<f64>>,
}
/// Derives the plotting curves for a set of filters.
///
/// * deviation = target - input
/// * filter response = sum of every biquad's `log_result` at each frequency
/// * error = deviation - filter response
/// * corrected = input + filter response
///
/// Element-wise combinations stop at the shorter of the two operands.
pub fn compute_visualization_curves(
    frequencies: &[f64],
    input_curve: &Curve,
    target_curve: &Curve,
    biquads: &[Biquad],
) -> VisualizationCurves {
    let measured: Vec<f64> = input_curve.spl.iter().copied().collect();
    let wanted: Vec<f64> = target_curve.spl.iter().copied().collect();

    let deviation: Vec<f64> = wanted
        .iter()
        .zip(&measured)
        .map(|(target, input)| target - input)
        .collect();

    let mut combined_response: Vec<f64> = Vec::with_capacity(frequencies.len());
    for &freq in frequencies {
        combined_response.push(biquads.iter().map(|b| b.log_result(freq)).sum());
    }

    let mut per_filter: Vec<Vec<f64>> = Vec::with_capacity(biquads.len());
    for biquad in biquads {
        per_filter.push(frequencies.iter().map(|&freq| biquad.log_result(freq)).collect());
    }

    let residual: Vec<f64> = deviation
        .iter()
        .zip(&combined_response)
        .map(|(dev, resp)| dev - resp)
        .collect();
    let corrected: Vec<f64> = measured
        .iter()
        .zip(&combined_response)
        .map(|(input, resp)| input + resp)
        .collect();

    VisualizationCurves {
        frequencies: frequencies.to_vec(),
        input_curve: measured,
        target_curve: wanted,
        deviation_curve: deviation,
        filter_response: combined_response,
        error_curve: residual,
        corrected_curve: corrected,
        individual_filter_responses: per_filter,
    }
}
/// Outcome of a speaker optimisation run.
#[derive(Debug, Clone)]
pub struct SpeakerOptResult {
/// Optimised filters.
pub biquads: Vec<Biquad>,
/// Plotting curves computed from the normalised input, the target and `biquads`.
pub curves: VisualizationCurves,
/// CEA2034 data interpolated onto the standard frequency grid, when the
/// measurement provided it.
pub spin_data: Option<Cea2034Data>,
/// `(iteration, loss)` pairs; empty when no progress callback was supplied.
pub history: Vec<(usize, f64)>,
/// Loss of the first history entry, or `0.0` when the history is empty.
pub initial_loss: f64,
/// Loss of the last history entry, or `0.0` when the history is empty.
pub final_loss: f64,
}
/// Fetches a speaker measurement, optimises a PEQ for it and returns the
/// filters together with plotting data.
///
/// The measurement is resampled onto a 200-point logarithmic grid (20 Hz to
/// 20 kHz), the target is built with [`build_target_curve`], and spin data
/// (when present) is normalised onto the same grid for the scoring objective.
///
/// Progress is reported only when both `progress_config` and
/// `progress_callback` are supplied; otherwise `history` stays empty and both
/// reported losses are `0.0`.
///
/// # Errors
/// Fails when loading the measurement, building the target, preparing the
/// objective, or the optimisation itself fails.
pub async fn optimize_speaker<F>(
    speaker: &str,
    version: &str,
    measurement: &str,
    args: &crate::cli::Args,
    progress_config: Option<ProgressCallbackConfig>,
    progress_callback: Option<F>,
) -> Result<SpeakerOptResult, Box<dyn Error>>
where
    F: FnMut(&ProgressUpdate) -> crate::de::CallbackAction + Send + 'static,
{
    let (input_curve, spin_data) =
        read::load_spinorama_with_spin(speaker, version, measurement, &args.curve_name).await?;

    let standard_freq = read::create_log_frequency_grid(200, 20.0, 20000.0);
    let input_normalized = read::normalize_and_interpolate_response(&standard_freq, &input_curve);
    let target_curve = build_target_curve(args, &standard_freq, &input_normalized)?;
    let deviation_curve = Curve {
        freq: target_curve.freq.clone(),
        spl: &target_curve.spl - &input_normalized.spl,
        phase: None,
    };

    // The scoring objective expects every spin curve on the standard grid.
    let spin_map = spin_data.as_ref().map(|s| {
        s.curves
            .iter()
            .map(|(name, curve)| {
                let normalized = read::normalize_and_interpolate_response(&standard_freq, curve);
                (name.clone(), normalized)
            })
            .collect::<HashMap<String, Curve>>()
    });

    let (objective_data, _) = setup_objective_data(
        args,
        &input_normalized,
        &target_curve,
        &deviation_curve,
        &spin_map,
    )?;

    let (params, history) = if let (Some(config), Some(callback)) =
        (progress_config, progress_callback)
    {
        let output = perform_optimization_with_progress(args, &objective_data, config, callback)?;
        (output.params, output.history)
    } else {
        let params = perform_optimization_with_callback(
            args,
            &objective_data,
            Box::new(|_| crate::de::CallbackAction::Continue),
        )?;
        (params, Vec::new())
    };

    // Decode with the same model that defined the parameter layout.
    let biquads: Vec<Biquad> = x2peq(&params, args.sample_rate, args.effective_peq_model())
        .into_iter()
        .map(|(_, b)| b)
        .collect();

    let frequencies: Vec<f64> = standard_freq.iter().copied().collect();
    let curves =
        compute_visualization_curves(&frequencies, &input_normalized, &target_curve, &biquads);

    let initial_loss = history.first().map(|x| x.1).unwrap_or(0.0);
    let final_loss = history.last().map(|x| x.1).unwrap_or(0.0);
    let interpolated_spin_data = spin_data.map(|s| interpolate_cea2034_data(&s, &standard_freq));

    Ok(SpeakerOptResult {
        biquads,
        curves,
        spin_data: interpolated_spin_data,
        history,
        initial_loss,
        final_loss,
    })
}
/// Outcome of a headphone optimisation run.
#[derive(Debug, Clone)]
pub struct HeadphoneOptResult {
/// Optimised filters.
pub biquads: Vec<Biquad>,
/// Plotting curves computed from the normalised input, the normalised target
/// and `biquads`.
pub curves: VisualizationCurves,
/// `(iteration, loss)` pairs; empty when no progress callback was supplied.
pub history: Vec<(usize, f64)>,
/// Loss of the first history entry, or `0.0` when the history is empty.
pub initial_loss: f64,
/// Loss of the last history entry, or `0.0` when the history is empty.
pub final_loss: f64,
}
/// Reads a headphone measurement from CSV, optimises a PEQ towards
/// `target_curve` and returns the filters together with plotting data.
///
/// Both the measurement and the target are resampled onto a 200-point
/// logarithmic grid (20 Hz to 20 kHz). No spin data is used, so the objective
/// is set up without CEA2034 scoring.
///
/// Progress is reported only when both `progress_config` and
/// `progress_callback` are supplied; otherwise `history` stays empty and both
/// reported losses are `0.0`.
///
/// # Errors
/// Fails when the CSV cannot be read, the objective cannot be prepared, or the
/// optimisation itself fails.
pub fn optimize_headphone<F>(
    curve_path: &PathBuf,
    target_curve: &Curve,
    args: &crate::cli::Args,
    progress_config: Option<ProgressCallbackConfig>,
    progress_callback: Option<F>,
) -> Result<HeadphoneOptResult, Box<dyn Error>>
where
    F: FnMut(&ProgressUpdate) -> crate::de::CallbackAction + Send + 'static,
{
    let input_curve = read::read_curve_from_csv(curve_path)?;

    let standard_freq = read::create_log_frequency_grid(200, 20.0, 20000.0);
    let input_normalized = read::normalize_and_interpolate_response(&standard_freq, &input_curve);
    let target_normalized = read::normalize_and_interpolate_response(&standard_freq, target_curve);
    let deviation_curve = Curve {
        freq: target_normalized.freq.clone(),
        spl: &target_normalized.spl - &input_normalized.spl,
        phase: None,
    };

    let (objective_data, _) = setup_objective_data(
        args,
        &input_normalized,
        &target_normalized,
        &deviation_curve,
        &None,
    )?;

    let (params, history) = if let (Some(config), Some(callback)) =
        (progress_config, progress_callback)
    {
        let output = perform_optimization_with_progress(args, &objective_data, config, callback)?;
        (output.params, output.history)
    } else {
        let params = perform_optimization_with_callback(
            args,
            &objective_data,
            Box::new(|_| crate::de::CallbackAction::Continue),
        )?;
        (params, Vec::new())
    };

    // Decode with the same model that defined the parameter layout.
    let biquads: Vec<Biquad> = x2peq(&params, args.sample_rate, args.effective_peq_model())
        .into_iter()
        .map(|(_, b)| b)
        .collect();

    let frequencies: Vec<f64> = standard_freq.iter().copied().collect();
    let curves = compute_visualization_curves(
        &frequencies,
        &input_normalized,
        &target_normalized,
        &biquads,
    );

    let initial_loss = history.first().map(|x| x.1).unwrap_or(0.0);
    let final_loss = history.last().map(|x| x.1).unwrap_or(0.0);

    Ok(HeadphoneOptResult {
        biquads,
        curves,
        history,
        initial_loss,
        final_loss,
    })
}
/// Outcome of a multi-driver (crossover or multi-sub) optimisation.
#[derive(Debug, Clone)]
pub struct DriverOptimizationResult {
/// Optimised gain per driver (first `n` entries of the solution vector).
pub gains: Vec<f64>,
/// Optimised delay per driver (next `n` entries of the solution vector).
pub delays: Vec<f64>,
/// Crossover frequencies in Hz: the fixed ones when supplied, otherwise the
/// optimised ones; empty for multi-sub optimisation.
pub crossover_freqs: Vec<f64>,
/// Base fitness evaluated at the initial guess.
pub pre_objective: f64,
/// Base fitness evaluated at the optimised solution.
pub post_objective: f64,
/// `true` when the optimiser returned `Ok`.
pub converged: bool,
}
/// Builds the `Args` value used internally by the driver optimisers.
///
/// Only the values passed in are configurable. Everything else is fixed:
/// no PEQ filters (`num_filters: 0`), `PeqModel::Pk`, `LossType::DriversFlat`,
/// population 300, no refinement, no smoothing, no spacing penalty, and QA
/// output disabled. `max_iter` becomes `maxeval`, and the thread count is the
/// number of CPUs reported by `num_cpus`.
// Mirrors the public entry points, which carry the same allowance.
#[allow(clippy::too_many_arguments)]
fn create_driver_optimization_args(
    min_freq: f64,
    max_freq: f64,
    sample_rate: f64,
    algorithm: &str,
    max_iter: usize,
    min_db: f64,
    max_db: f64,
    seed: Option<u64>,
) -> crate::cli::Args {
    use crate::LossType;
    use crate::cli::{Args, PeqModel};
    Args {
        num_filters: 0,
        curve: None,
        target: None,
        speaker: None,
        version: None,
        measurement: None,
        curve_name: "On Axis".to_string(),
        sample_rate,
        min_freq,
        max_freq,
        min_q: 0.5,
        max_q: 10.0,
        min_db,
        max_db,
        algo: algorithm.to_string(),
        strategy: "currenttobest1bin".to_string(),
        algo_list: false,
        strategy_list: false,
        peq_model: PeqModel::Pk,
        peq_model_list: false,
        population: 300,
        maxeval: max_iter,
        refine: false,
        local_algo: "cobyla".to_string(),
        min_spacing_oct: 0.0,
        spacing_weight: 0.0,
        smooth: false,
        smooth_n: 1,
        loss: LossType::DriversFlat,
        tolerance: 1e-3,
        atolerance: 1e-4,
        recombination: 0.9,
        adaptive_weight_f: 0.9,
        adaptive_weight_cr: 0.9,
        no_parallel: false,
        output: None,
        driver1: None,
        driver2: None,
        driver3: None,
        driver4: None,
        crossover_type: "linkwitzriley4".to_string(),
        parallel_threads: num_cpus::get(),
        seed,
        qa: None,
        preset: None,
    }
}
/// Optimises per-driver gains and delays and, unless `fixed_freqs` is given,
/// the crossover frequencies between adjacent drivers.
///
/// The solution vector is laid out as `n` gains, `n` delays and (free mode
/// only) `n - 1` crossover frequencies in `log10(Hz)`; the latter are converted
/// back to Hz in the result.
///
/// `converged` reflects whether the optimiser returned `Ok`; the solution it
/// left in place is reported either way.
#[allow(clippy::too_many_arguments)]
pub fn optimize_drivers_crossover(
    drivers_data: crate::loss::DriversLossData,
    min_freq: f64,
    max_freq: f64,
    sample_rate: f64,
    algorithm: &str,
    max_iter: usize,
    min_db: f64,
    max_db: f64,
    fixed_freqs: Option<Vec<f64>>,
    seed: Option<u64>,
) -> Result<DriverOptimizationResult, Box<dyn std::error::Error>> {
    let n_drivers = drivers_data.drivers.len();
    let crossovers_are_fixed = fixed_freqs.is_some();
    let args = create_driver_optimization_args(
        min_freq,
        max_freq,
        sample_rate,
        algorithm,
        max_iter,
        min_db,
        max_db,
        seed,
    );

    let mut objective_data = setup_drivers_objective_data(&args, drivers_data.clone());
    if let Some(freqs) = fixed_freqs.as_ref() {
        objective_data.fixed_crossover_freqs = Some(freqs.clone());
    }

    // Bounds and starting point depend on whether crossovers are optimised.
    let (lower_bounds, upper_bounds, mut x) = if crossovers_are_fixed {
        let (lo, hi) = setup_drivers_bounds_fixed_freqs(&args, &drivers_data);
        let start = drivers_initial_guess_fixed_freqs(&lo, &hi, n_drivers);
        (lo, hi, start)
    } else {
        let (lo, hi) = setup_drivers_bounds(&args, &drivers_data);
        let start = drivers_initial_guess(&lo, &hi, n_drivers);
        (lo, hi, start)
    };

    let pre_objective = crate::optim::compute_base_fitness(&x, &objective_data);
    let converged = crate::optim::optimize_filters(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data.clone(),
        &args,
    )
    .is_ok();
    let post_objective = crate::optim::compute_base_fitness(&x, &objective_data);

    let delays_start = n_drivers;
    let xover_start = 2 * n_drivers;
    let gains = x[..delays_start].to_vec();
    let delays = x[delays_start..xover_start].to_vec();
    let crossover_freqs = match fixed_freqs {
        Some(freqs) => freqs,
        // Stored as log10(Hz) in the solution vector.
        None => x[xover_start..]
            .iter()
            .map(|log_hz| 10_f64.powf(*log_hz))
            .collect(),
    };

    Ok(DriverOptimizationResult {
        gains,
        delays,
        crossover_freqs,
        pre_objective,
        post_objective,
        converged,
    })
}
/// Loads one driver measurement per path, in order.
///
/// # Errors
/// Stops at the first file that fails to load and reports the 1-based driver
/// number, the path and the underlying error.
pub fn load_driver_measurements_from_files(
    driver_paths: &[std::path::PathBuf],
) -> Result<Vec<crate::loss::DriverMeasurement>, Box<dyn std::error::Error>> {
    use crate::loss::DriverMeasurement;
    use crate::read::load_driver_measurement;

    let mut loaded = Vec::new();
    for (idx, path) in driver_paths.iter().enumerate() {
        let driver_number = idx + 1;
        let (freq, spl, phase) = load_driver_measurement(path).map_err(|e| {
            format!(
                "Failed to load driver {} from {}: {}",
                driver_number,
                path.display(),
                e
            )
        })?;
        loaded.push(DriverMeasurement::new(freq, spl, phase));
        log::debug!("✓ Loaded driver {} from {}", driver_number, path.display());
    }
    Ok(loaded)
}
/// Optimises per-driver gains and delays for a multi-subwoofer setup.
///
/// Uses the `LossType::MultiSubFlat` objective. The solution vector holds `n`
/// gains followed by `n` delays; no crossover frequencies are produced.
///
/// `converged` reflects whether the optimiser returned `Ok`; the solution it
/// left in place is reported either way.
#[allow(clippy::too_many_arguments)]
pub fn optimize_multisub(
    drivers_data: crate::loss::DriversLossData,
    min_freq: f64,
    max_freq: f64,
    sample_rate: f64,
    algorithm: &str,
    max_iter: usize,
    min_db: f64,
    max_db: f64,
    seed: Option<u64>,
) -> Result<DriverOptimizationResult, Box<dyn std::error::Error>> {
    let n_drivers = drivers_data.drivers.len();

    let args = {
        let mut base = create_driver_optimization_args(
            min_freq,
            max_freq,
            sample_rate,
            algorithm,
            max_iter,
            min_db,
            max_db,
            seed,
        );
        base.loss = crate::LossType::MultiSubFlat;
        base
    };

    let objective_data = setup_multisub_objective_data(&args, drivers_data.clone());
    let (lower_bounds, upper_bounds) = setup_multisub_bounds(&args, n_drivers);
    let mut x = multisub_initial_guess(n_drivers);

    let pre_objective = crate::optim::compute_base_fitness(&x, &objective_data);
    let outcome = crate::optim::optimize_filters(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data.clone(),
        &args,
    );
    let post_objective = crate::optim::compute_base_fitness(&x, &objective_data);

    Ok(DriverOptimizationResult {
        gains: x[..n_drivers].to_vec(),
        delays: x[n_drivers..2 * n_drivers].to_vec(),
        crossover_freqs: Vec::new(),
        pre_objective,
        post_objective,
        converged: outcome.is_ok(),
    })
}
/// Builds the `ObjectiveData` for multi-subwoofer optimisation.
///
/// The frequency grid comes from `drivers_data`; target and deviation are all
/// zeros on that grid, spacing penalties and smoothing are disabled, and the
/// loss type is fixed to `LossType::MultiSubFlat`.
pub fn setup_multisub_objective_data(
    args: &crate::cli::Args,
    drivers_data: DriversLossData,
) -> ObjectiveData {
    let grid = drivers_data.freq_grid.clone();
    let grid_len = drivers_data.freq_grid.len();

    ObjectiveData {
        freqs: grid,
        target: Array1::zeros(grid_len),
        deviation: Array1::zeros(grid_len),
        srate: args.sample_rate,
        min_spacing_oct: 0.0,
        spacing_weight: 0.0,
        max_db: args.max_db,
        min_db: args.min_db,
        min_freq: args.min_freq,
        max_freq: args.max_freq,
        peq_model: args.effective_peq_model(),
        loss_type: crate::LossType::MultiSubFlat,
        speaker_score_data: None,
        headphone_score_data: None,
        input_curve: None,
        drivers_data: Some(drivers_data),
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: false,
        smooth_n: 1,
        max_boost_envelope: None,
        min_cut_envelope: None,
    }
}
/// Builds the bounds for multi-subwoofer optimisation: `n_drivers` gains
/// bounded by `±args.max_db` followed by `n_drivers` delays in `[0.0, 20.0]`.
pub fn setup_multisub_bounds(args: &crate::cli::Args, n_drivers: usize) -> (Vec<f64>, Vec<f64>) {
    let total = n_drivers * 2;

    // Gains first, then pad the tail with the delay limits.
    let mut lower_bounds = vec![-args.max_db; n_drivers];
    lower_bounds.resize(total, 0.0);

    let mut upper_bounds = vec![args.max_db; n_drivers];
    upper_bounds.resize(total, 20.0);

    (lower_bounds, upper_bounds)
}
/// Starting point for multi-subwoofer optimisation: zero gain and zero delay
/// for every driver (`2 * n_drivers` zeros).
pub fn multisub_initial_guess(n_drivers: usize) -> Vec<f64> {
    let slots = n_drivers * 2;
    std::iter::repeat(0.0).take(slots).collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cli::Args;
use clap::Parser;
fn zero_curve(freqs: Vec<f64>) -> Curve {
let n = freqs.len();
Curve {
freq: Array1::from(freqs),
spl: Array1::zeros(n),
phase: None,
}
}
/// Smoke test: `build_target_curve` succeeds with smoothing disabled and
/// with smoothing enabled.
///
/// This only verifies that the call returns `Ok` for both settings of
/// `args.smooth`; it does not inspect the returned curve.
#[test]
fn build_target_curve_respects_smoothing_flag() {
    let mut args = Args::parse_from(["autoeq-test"]);
    args.curve_name = "Listening Window".to_string();

    // The input curve deliberately has one more point than the evaluation grid.
    let curve = zero_curve(vec![100.0, 1000.0, 10000.0, 20000.0]);
    let freqs = Array1::from(vec![100.0, 1000.0, 10000.0]);

    for smooth in [false, true] {
        args.smooth = smooth;
        super::build_target_curve(&args, &freqs, &curve)
            .expect("build_target_curve should succeed");
    }
}
/// Checks when `setup_objective_data` reports CEA2034 usage.
///
/// With spin data present, both a `"CEA2034"` and an `"On Axis"` measurement
/// must yield `use_cea == true` and populated `speaker_score_data`; with no
/// spin data the flag must be `false` and `speaker_score_data` must be `None`.
#[test]
fn setup_objective_data_sets_use_cea_when_expected() {
    let grid = || vec![100.0, 1000.0];

    let mut args = Args::parse_from(["autoeq-test"]);
    args.speaker = Some("spk".to_string());
    args.version = Some("v".to_string());
    args.measurement = Some("CEA2034".to_string());

    // Input, target and deviation all share the same grid and are flat zero.
    let input_curve = zero_curve(grid());
    let target = zero_curve(grid());
    let deviation = zero_curve(grid());

    let spin_opt: Option<HashMap<String, Curve>> = Some(
        [
            "On Axis",
            "Listening Window",
            "Sound Power",
            "Estimated In-Room Response",
        ]
        .iter()
        .map(|&name| (name.to_string(), zero_curve(grid())))
        .collect(),
    );

    // Case 1: CEA2034 measurement with spin data.
    let (obj, use_cea) =
        super::setup_objective_data(&args, &input_curve, &target, &deviation, &spin_opt)
            .expect("setup_objective_data should succeed with valid spin data");
    assert!(use_cea);
    assert!(obj.speaker_score_data.is_some());

    // Case 2: a different measurement name, spin data still supplied.
    let mut args2 = args.clone();
    args2.measurement = Some("On Axis".to_string());
    let (obj2, use_cea2) =
        super::setup_objective_data(&args2, &input_curve, &target, &deviation, &spin_opt)
            .expect("setup_objective_data should succeed with valid spin data");
    assert!(use_cea2);
    assert!(obj2.speaker_score_data.is_some());

    // Case 3: no spin data at all.
    let (obj3, use_cea3) =
        super::setup_objective_data(&args, &input_curve, &target, &deviation, &None)
            .expect("setup_objective_data should succeed with no spin data");
    assert!(!use_cea3);
    assert!(obj3.speaker_score_data.is_none());
}
/// Pins the values produced by `Args::speaker_defaults()`.
#[test]
fn test_args_speaker_defaults() {
    let defaults = Args::speaker_defaults();

    // Optimizer setup.
    assert_eq!(defaults.num_filters, 5);
    assert_eq!(defaults.sample_rate, 48000.0);
    assert_eq!(defaults.loss, crate::LossType::SpeakerFlat);
    assert_eq!(defaults.algo, "autoeq:de");

    // Curve selection and frequency range.
    assert_eq!(defaults.curve_name, "Listening Window");
    assert_eq!(defaults.min_freq, 20.0);
    assert_eq!(defaults.max_freq, 20000.0);
}
/// Pins the values produced by `Args::headphone_defaults()`.
#[test]
fn test_args_headphone_defaults() {
    let defaults = Args::headphone_defaults();

    assert_eq!(defaults.num_filters, 7);
    assert_eq!(defaults.loss, crate::LossType::HeadphoneScore);
    assert_eq!(defaults.sample_rate, 48000.0);
    assert_eq!(defaults.algo, "autoeq:de");
}
/// Pins the values produced by `Args::roomeq_defaults()`.
#[test]
fn test_args_roomeq_defaults() {
    let defaults = Args::roomeq_defaults();

    assert_eq!(defaults.num_filters, 10);
    assert_eq!(defaults.max_freq, 500.0);
    assert_eq!(defaults.sample_rate, 48000.0);
    assert_eq!(defaults.loss, crate::LossType::SpeakerFlat);
}
/// Pins the `Default` implementation of `ProgressCallbackConfig`.
#[test]
fn test_progress_callback_config_default() {
    let defaults: ProgressCallbackConfig = Default::default();

    assert_eq!(defaults.interval, 25);
    assert!(defaults.include_biquads);
    assert!(defaults.include_filter_response);
    assert!(defaults.frequencies.is_empty());
}
/// Exercises `compute_visualization_curves` with a single peak filter.
///
/// Verifies that every output series has one value per frequency, that there
/// is one individual response per biquad, that the deviation is
/// `target - input`, and that the corrected curve is `input + filter_response`.
#[test]
fn test_compute_visualization_curves() {
    use crate::iir::BiquadFilterType;

    let frequencies = vec![100.0, 1000.0, 10000.0];
    let n_points = frequencies.len();

    // Both curves share the same frequency grid and differ only in SPL.
    let make_curve = |spl: Vec<f64>| Curve {
        freq: Array1::from(frequencies.clone()),
        spl: Array1::from(spl),
        phase: None,
    };
    let input_curve = make_curve(vec![80.0, 85.0, 82.0]);
    let target_curve = make_curve(vec![80.0, 80.0, 80.0]);

    let biquads = vec![Biquad::new(
        BiquadFilterType::Peak,
        1000.0,
        48000.0,
        1.0,
        -5.0,
    )];

    let curves =
        compute_visualization_curves(&frequencies, &input_curve, &target_curve, &biquads);

    assert_eq!(curves.frequencies.len(), n_points);
    assert_eq!(curves.input_curve.len(), n_points);
    assert_eq!(curves.target_curve.len(), n_points);
    assert_eq!(curves.deviation_curve.len(), n_points);
    assert_eq!(curves.filter_response.len(), n_points);
    assert_eq!(curves.error_curve.len(), n_points);
    assert_eq!(curves.corrected_curve.len(), n_points);
    assert_eq!(curves.individual_filter_responses.len(), 1);

    for idx in 0..n_points {
        let want = target_curve.spl[idx] - input_curve.spl[idx];
        let got = curves.deviation_curve[idx];
        assert!((got - want).abs() < 1e-10);
    }

    for idx in 0..n_points {
        let want = input_curve.spl[idx] + curves.filter_response[idx];
        let got = curves.corrected_curve[idx];
        assert!((got - want).abs() < 1e-10);
    }
}
/// Exercises `compute_visualization_curves` with no filters at all.
///
/// With an empty biquad list the filter response must be zero, the corrected
/// curve must equal the input SPL, and the error curve must equal the
/// deviation curve (all within `1e-10`).
#[test]
fn test_visualization_curves_empty_biquads() {
    let frequencies = vec![100.0, 1000.0, 10000.0];
    let n_points = frequencies.len();

    let make_curve = |spl: Vec<f64>| Curve {
        freq: Array1::from(frequencies.clone()),
        spl: Array1::from(spl),
        phase: None,
    };
    let input_curve = make_curve(vec![80.0, 85.0, 82.0]);
    let target_curve = make_curve(vec![80.0, 80.0, 80.0]);

    let no_filters: Vec<Biquad> = Vec::new();
    let curves =
        compute_visualization_curves(&frequencies, &input_curve, &target_curve, &no_filters);

    for response in &curves.filter_response {
        assert!(response.abs() < 1e-10);
    }

    for idx in 0..n_points {
        let diff = curves.corrected_curve[idx] - input_curve.spl[idx];
        assert!(diff.abs() < 1e-10);
    }

    for idx in 0..n_points {
        let diff = curves.error_curve[idx] - curves.deviation_curve[idx];
        assert!(diff.abs() < 1e-10);
    }
}
/// For the `hp-pk-lp` PEQ model with three filters, the initial guess must lie
/// inside the bounds from `setup_bounds`, and the gain entries of the first
/// and last filter (indices 2 and 8) must be exactly `0.0`.
#[test]
fn initial_guess_respects_fixed_bounds_for_special_filters() {
    let cli = ["autoeq-test", "--peq-model", "hp-pk-lp", "-n", "3"];
    let mut args = Args::parse_from(cli);
    args.min_freq = 20.0;
    args.max_freq = 20_000.0;

    let (lower_bounds, upper_bounds) = setup_bounds(&args);
    let x = initial_guess(&args, &lower_bounds, &upper_bounds);

    let bounds = lower_bounds.iter().zip(upper_bounds.iter());
    for (value, (lower, upper)) in x.iter().zip(bounds) {
        assert!(
            (*lower..=*upper).contains(value),
            "initial guess must lie within bounds: value={}, lower={}, upper={}",
            value,
            lower,
            upper
        );
    }

    // Gain slots of the two fixed filters, paired with their failure message.
    let fixed_gains = [
        (2, "fixed high-pass gain should stay at its fixed bound"),
        (8, "fixed low-pass gain should stay at its fixed bound"),
    ];
    for (idx, message) in fixed_gains {
        assert_eq!(x[idx], 0.0, "{}", message);
    }
}
/// With `max_db = 12`, every filter's gain bound from `setup_bounds` must be
/// `[-36, 12]`, i.e. the lower bound is `-3 * max_db` and the upper is `max_db`.
#[test]
fn setup_bounds_gain_lower_is_minus_3x_max_db() {
    let mut args = Args::parse_from(["autoeq-test"]);
    args.num_filters = 3;
    args.min_freq = 20.0;
    args.max_freq = 20000.0;
    args.min_q = 0.5;
    args.max_q = 10.0;
    args.min_db = 1.0;
    args.max_db = 12.0;

    let (lower, upper) = setup_bounds(&args);

    let expected_lower = -36.0;
    let expected_upper = 12.0;
    for filter in 0..args.num_filters {
        // Three parameters per filter; the gain is the third one.
        let gain_idx = filter * 3 + 2;
        let gain_lower = lower[gain_idx];
        let gain_upper = upper[gain_idx];

        assert!(
            (gain_lower - expected_lower).abs() < 1e-9,
            "filter {} gain lower bound should be -3*max_db=-36, got {}",
            filter,
            gain_lower
        );
        assert!(
            (gain_upper - expected_upper).abs() < 1e-9,
            "filter {} gain upper bound should be max_db=12, got {}",
            filter,
            gain_upper
        );
    }
}
/// The lower gain bound from `setup_bounds` must track `-3 * max_db` for
/// several values of `max_db`.
#[test]
fn setup_bounds_gain_scales_with_max_db() {
    for max_db in [4.0, 6.0, 12.0] {
        let mut args = Args::parse_from(["autoeq-test"]);
        args.num_filters = 1;
        args.min_freq = 100.0;
        args.max_freq = 10000.0;
        args.min_q = 0.5;
        args.max_q = 5.0;
        args.min_db = 1.0;
        args.max_db = max_db;

        let (lower, _upper) = setup_bounds(&args);

        // Single filter: its gain bound sits at index 2.
        let gain_lower = lower[2];
        let expected = -3.0 * max_db;
        let error = (gain_lower - expected).abs();
        assert!(
            error < 1e-9,
            "max_db={}: gain_lower should be {}, got {}",
            max_db,
            expected,
            gain_lower
        );
    }
}
}