use crate::Curve;
use crate::loss::{CrossoverType, DriverMeasurement, DriversLossData};
use crate::workflow::DriverOptimizationResult;
use clap::Parser;
use log::warn;
use ndarray::Array1;
use std::error::Error;
use super::types::{DBAConfig, OptimizerConfig};
use crate::read as load;
/// Optimizes the gain and delay of a DBA's front and rear arrays.
///
/// Each array's measurements are summed into a single response via
/// [`sum_array_response`]; the rear sum is then polarity-inverted. The two
/// resulting curves are treated as a two-driver problem without crossovers
/// (`CrossoverType::None`) and optimized with the `MultiSubFlat` loss.
///
/// The search vector is laid out as
/// `[front gain, rear gain, front delay, rear delay]` with these bounds:
///
/// | parameter   | lower                        | upper   | start |
/// |-------------|------------------------------|---------|-------|
/// | front gain  | `-0.01`                      | `0.01`  | `0.0` |
/// | rear gain   | `min(config.min_db, -30.0)`  | `0.0`   | `-3.0`|
/// | front delay | `0.0`                        | `0.001` | `0.0` |
/// | rear delay  | `0.0`                        | `100.0` | `10.0`|
///
/// NOTE(review): the units of the gain and delay entries are defined by the
/// optimizer / loss code, not by this function — confirm there.
///
/// # Arguments
/// * `dba_config` - provides the `front` and `rear` measurement source lists.
/// * `config` - provides `min_freq`, `max_freq`, `max_iter`, `algorithm`,
///   `seed` and `min_db` for the optimizer run.
/// * `sample_rate` - forwarded to the optimizer arguments and to the
///   combined-response computation.
///
/// # Returns
/// A tuple of:
/// * the optimization result, where `crossover_freqs` is empty,
///   `pre_objective` / `post_objective` are fixed at `0.0`, and `converged`
///   reports whether `optimize_filters` returned `Ok`;
/// * the combined response on the drivers' shared frequency grid, with no
///   phase information.
///
/// # Errors
/// Propagates any error from summing the front or rear array.
pub fn optimize_dba(
    dba_config: &DBAConfig,
    config: &OptimizerConfig,
    sample_rate: f64,
) -> Result<(DriverOptimizationResult, Curve), Box<dyn Error>> {
    // Collapse each array into one curve; the rear sum gets its polarity flipped.
    let front_sum = sum_array_response(&dba_config.front)?;
    let rear_sum = sum_array_response(&dba_config.rear)?;
    let rear_flipped = invert_polarity(&rear_sum);

    let as_measurement = |curve: &Curve| DriverMeasurement {
        freq: curve.freq.clone(),
        spl: curve.spl.clone(),
        phase: curve.phase.clone(),
    };
    let drivers_data = DriversLossData::new(
        vec![as_measurement(&front_sum), as_measurement(&rear_flipped)],
        CrossoverType::None,
    );

    // Start from the CLI defaults and override only what this workflow controls.
    let mut args = crate::cli::Args::parse_from(["autoeq"]);
    args.loss = crate::LossType::MultiSubFlat;
    args.sample_rate = sample_rate;
    args.min_freq = config.min_freq;
    args.max_freq = config.max_freq;
    args.maxeval = config.max_iter;
    args.algo = config.algorithm.clone();
    args.seed = config.seed;

    let objective_data =
        crate::workflow::setup_multisub_objective_data(&args, drivers_data.clone());

    // Parameter layout: [front gain, rear gain, front delay, rear delay].
    let rear_gain_floor = config.min_db.min(-30.0);
    let lower_bounds = vec![-0.01, rear_gain_floor, 0.0, 0.0];
    let upper_bounds = vec![0.01, 0.0, 0.001, 100.0];
    let mut params = vec![0.0, -3.0, 0.0, 10.0];

    // The optimizer writes its best point into `params`; only success/failure
    // of the run is kept from its return value.
    let converged = crate::optim::optimize_filters(
        &mut params,
        &lower_bounds,
        &upper_bounds,
        objective_data,
        &args,
    )
    .is_ok();

    let gains = params[..2].to_vec();
    let delays = params[2..4].to_vec();
    let crossover_freqs = vec![];

    let combined_curve = Curve {
        freq: drivers_data.freq_grid.clone(),
        spl: crate::loss::compute_drivers_combined_response(
            &drivers_data,
            &gains,
            &crossover_freqs,
            Some(&delays),
            sample_rate,
        ),
        phase: None,
    };

    let result = DriverOptimizationResult {
        gains,
        delays,
        crossover_freqs,
        pre_objective: 0.0,
        post_objective: 0.0,
        converged,
    };
    Ok((result, combined_curve))
}
/// Loads every measurement in `sources` and sums them as complex values.
///
/// The frequency grid of the first loaded measurement is used as the
/// reference grid; every measurement (including the first) is resampled onto
/// it with `interpolate_log_space`. Each sample is converted from SPL (dB) and
/// phase (degrees) to a complex number, the complex numbers are accumulated
/// per frequency bin, and the total is converted back to SPL and phase.
///
/// Measurements without phase data contribute with a phase of 0°; a warning
/// is logged when at least one such measurement is present. The magnitude of
/// the sum is floored at `1e-12` before taking the logarithm.
///
/// # Returns
/// A curve on the reference grid whose `phase` is always `Some`, in degrees.
///
/// # Errors
/// * `"Empty array"` when `sources` is empty.
/// * Any error returned while loading one of the sources.
pub fn sum_array_response(
    sources: &[super::types::MeasurementSource],
) -> Result<Curve, Box<dyn Error>> {
    use num_complex::Complex64;
    use std::f64::consts::PI;

    if sources.is_empty() {
        return Err("Empty array".into());
    }

    // Load in order, stopping at the first source that fails.
    let loaded = sources
        .iter()
        .map(|src| load::load_source(src))
        .collect::<Result<Vec<_>, _>>()?;

    let phaseless = loaded.iter().filter(|c| c.phase.is_none()).count();
    if phaseless > 0 {
        warn!(
            "DBA array summation: {} of {} measurements are missing phase data. \
             Assuming 0° phase for these measurements, which may reduce optimization accuracy. \
             For best results, include phase data in your measurements.",
            phaseless,
            sources.len()
        );
    }

    // The first measurement defines the grid every other one is resampled onto.
    let grid = loaded[0].freq.clone();
    let mut total = Array1::<Complex64>::zeros(grid.len());

    for measurement in &loaded {
        let resampled = crate::read::interpolate_log_space(&grid, measurement);
        let phase_deg = resampled.phase.as_ref();
        for (k, bin) in total.iter_mut().enumerate() {
            let magnitude = 10.0_f64.powf(resampled.spl[k] / 20.0);
            let angle = phase_deg.map(|p| p[k]).unwrap_or(0.0) * PI / 180.0;
            *bin += Complex64::from_polar(magnitude, angle);
        }
    }

    // Back to dB / degrees; the floor keeps log10 finite on full cancellation.
    let spl = total.mapv(|z| 20.0 * z.norm().max(1e-12).log10());
    let phase = total.mapv(|z| z.arg() * 180.0 / PI);

    Ok(Curve {
        freq: grid,
        spl,
        phase: Some(phase),
    })
}
fn invert_polarity(curve: &Curve) -> Curve {
let mut new_curve = curve.clone();
if let Some(ref mut phase) = new_curve.phase {
*phase = phase.mapv(|p| p + 180.0);
} else {
new_curve.phase = Some(Array1::from_elem(curve.freq.len(), 180.0));
}
new_curve
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every phase sample must come back shifted by +180°:
    /// 0° becomes 180° and -90° becomes 90°.
    #[test]
    fn test_invert_polarity() {
        let original = Curve {
            freq: Array1::from(vec![100.0, 1000.0]),
            spl: Array1::from(vec![80.0, 80.0]),
            phase: Some(Array1::from(vec![0.0, -90.0])),
        };

        let flipped_phase = invert_polarity(&original).phase.unwrap();

        let expected: [f64; 2] = [180.0, 90.0];
        for (idx, want) in expected.iter().enumerate() {
            assert!((flipped_phase[idx] - *want).abs() < 1e-6);
        }
    }
}