use crate::Curve;
use crate::loss::{CrossoverType, DriverMeasurement, DriversLossData};
use crate::workflow::DriverOptimizationResult;
use clap::Parser;
use ndarray::Array1;
use std::error::Error;
use super::types::{DBAConfig, OptimizerConfig};
use crate::read as load;
/// Optimizes the relative gain and delay between the front and rear arrays of
/// a DBA configuration.
///
/// Each array is first collapsed into a single complex response with
/// [`sum_array_response`]. The rear response is then polarity-inverted, and the
/// two are handed to the multi-sub optimizer as two "drivers" with
/// `CrossoverType::None`.
///
/// The optimizer parameter vector has the layout
/// `[front_gain_db, rear_gain_db, front_delay_ms, rear_delay_ms]`:
/// * The front gain is bounded to ±0.01 dB.
/// * The front delay is bounded to 0–0.001 ms.
/// * In practice, therefore, only the rear gain (`[min_gain, 0]` dB) and the
///   rear delay (0–100 ms) are free.
///
/// Returns the optimization result together with the predicted combined curve
/// on the optimizer's frequency grid. Note the following about the result:
/// * `converged` only reflects whether `optimize_filters` returned `Ok`.
/// * `pre_objective` / `post_objective` are currently always `0.0`.
///
/// # Errors
///
/// Propagates errors from summing either array and from building the combined
/// curve. A failing optimizer run is not an error; it is reported through
/// `converged == false`.
pub fn optimize_dba(
    dba_config: &DBAConfig,
    config: &OptimizerConfig,
    sample_rate: f64,
) -> Result<(DriverOptimizationResult, Curve), Box<dyn Error>> {
    // Collapse each physical array into one complex response; the rear array
    // is driven in opposite polarity.
    let front_curve = sum_array_response(&dba_config.front)?;
    let rear_curve = invert_polarity(&sum_array_response(&dba_config.rear)?);

    let as_measurement = |curve: &Curve| DriverMeasurement {
        freq: curve.freq.clone(),
        spl: curve.spl.clone(),
        phase: curve.phase.clone(),
    };
    let drivers_data = DriversLossData::new(
        vec![as_measurement(&front_curve), as_measurement(&rear_curve)],
        CrossoverType::None,
    );

    // Start from the CLI defaults and override only what the DBA run needs.
    let mut args = crate::cli::Args::parse_from(["autoeq"]);
    args.sample_rate = sample_rate;
    args.min_freq = config.min_freq;
    args.max_freq = config.max_freq;
    args.maxeval = config.max_iter;
    args.population = config.population;
    args.algo = config.algorithm.clone();
    args.seed = config.seed;
    args.loss = crate::LossType::MultiSubFlat;

    let optim_params = crate::OptimParams::from(&args);
    let objective_data =
        crate::workflow::setup_multisub_objective_data(&optim_params, drivers_data.clone());

    // NOTE(review): `.min(-30.0)` means the rear-gain floor is never above
    // -30 dB, whatever `config.min_db` says. Confirm this is intended; `.max`
    // would instead clamp the user's floor at -30 dB.
    let min_gain = config.min_db.min(-30.0);
    //                      front gain  rear gain  front delay  rear delay (ms)
    let lower_bounds = vec![-0.01, min_gain, 0.0, 0.0];
    let upper_bounds = vec![0.01, 0.0, 0.001, 100.0];
    let mut x = vec![0.0, -3.0, 0.0, 10.0];

    let converged = crate::optim::optimize_filters(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data,
        &optim_params,
    )
    .is_ok();

    let gains = x[..2].to_vec();
    let delays = x[2..4].to_vec();

    let combined_curve = compute_dba_combined_curve(
        &front_curve,
        &rear_curve,
        &gains,
        &delays,
        &drivers_data.freq_grid,
        sample_rate,
    )?;

    let result = DriverOptimizationResult {
        gains,
        delays,
        crossover_freqs: Vec::new(),
        pre_objective: 0.0,
        post_objective: 0.0,
        converged,
    };
    Ok((result, combined_curve))
}
/// Coherently sums the complex responses of every measurement in an array.
///
/// The procedure is:
/// 1. Load each source with `load::load_source`; every source must carry
///    phase data.
/// 2. Interpolate each curve with `crate::read::interpolate_log_space` onto
///    the frequency grid of the first source.
/// 3. Convert each bin from SPL (dB) / phase (degrees) to a complex value and
///    add the values bin by bin.
///
/// The returned curve holds the summed SPL (floored at -240 dB via the `1e-12`
/// magnitude floor) and the summed phase in degrees.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * `sources` is empty.
/// * A source fails to load.
/// * A source has no phase data.
/// * Interpolation drops the phase.
/// * Interpolation returns SPL or phase arrays whose length does not match the
///   reference grid.
pub fn sum_array_response(
    sources: &[super::types::MeasurementSource],
) -> Result<Curve, Box<dyn Error>> {
    use num_complex::Complex64;
    use std::f64::consts::PI;

    if sources.is_empty() {
        return Err("Empty array".into());
    }

    // Load and validate every source up front, so a source without phase is
    // reported before any summation work is done.
    let mut curves = Vec::with_capacity(sources.len());
    for source in sources {
        let curve = load::load_source(source)?;
        if curve.phase.is_none() {
            return Err(format!(
                "DBA array summation requires phase data for source {:?}",
                source
            )
            .into());
        }
        curves.push(curve);
    }

    // `curves` is non-empty because `sources` was checked above.
    let ref_freq = curves[0].freq.clone();
    let n_bins = ref_freq.len();
    let mut sum_complex = Array1::<Complex64>::zeros(n_bins);
    for curve in &curves {
        let interp = crate::read::interpolate_log_space(&ref_freq, curve);
        // Resolve the phase once per curve, not once per frequency bin.
        let phase = interp
            .phase
            .as_ref()
            .ok_or("DBA interpolation lost required phase data")?;
        // `zip` would silently truncate on a length mismatch; report it instead.
        if interp.spl.len() != n_bins || phase.len() != n_bins {
            return Err(format!(
                "DBA interpolation produced {} SPL and {} phase bins, expected {}",
                interp.spl.len(),
                phase.len(),
                n_bins
            )
            .into());
        }
        for ((acc, &spl), &phase_deg) in sum_complex
            .iter_mut()
            .zip(interp.spl.iter())
            .zip(phase.iter())
        {
            let m = 10.0_f64.powf(spl / 20.0);
            let phi = phase_deg * PI / 180.0;
            *acc += Complex64::from_polar(m, phi);
        }
    }

    let spl = sum_complex.mapv(|z| 20.0 * z.norm().max(1e-12).log10());
    let phase = sum_complex.mapv(|z| z.arg() * 180.0 / PI);
    Ok(Curve {
        freq: ref_freq,
        spl,
        phase: Some(phase),
        ..Default::default()
    })
}
fn invert_polarity(curve: &Curve) -> Curve {
let mut new_curve = curve.clone();
if let Some(ref mut phase) = new_curve.phase {
*phase = phase.mapv(|p| p + 180.0);
}
new_curve
}
/// Predicts the acoustic sum of the front and rear curves after applying a
/// per-array gain and delay.
///
/// Both curves are interpolated onto `freq_grid` with
/// `crate::read::interpolate_log_space`. Each bin is then turned into a complex
/// value as follows, and the front and rear values are summed:
/// * magnitude from `spl + gain_db`;
/// * phase from the measured phase minus the linear-phase term
///   `2π · f · delay`.
///
/// The rear curve is used as given, so any polarity inversion must already be
/// applied by the caller.
///
/// * `gains` – `[front_db, rear_db]`; missing entries default to 0 dB.
/// * `delays_ms` – `[front_ms, rear_ms]` in milliseconds; missing entries
///   default to 0 ms.
/// * `_sample_rate` – currently unused.
///
/// # Errors
///
/// Returns an error if either interpolated curve has no phase data.
fn compute_dba_combined_curve(
    front_curve: &Curve,
    rear_curve: &Curve,
    gains: &[f64],
    delays_ms: &[f64],
    freq_grid: &Array1<f64>,
    _sample_rate: f64,
) -> Result<Curve, Box<dyn Error>> {
    use num_complex::Complex64;
    use std::f64::consts::PI;

    let front = crate::read::interpolate_log_space(freq_grid, front_curve);
    let rear = crate::read::interpolate_log_space(freq_grid, rear_curve);
    let front_phase = match front.phase.as_ref() {
        Some(p) => p,
        None => return Err("DBA combined curve requires front phase data".into()),
    };
    let rear_phase = match rear.phase.as_ref() {
        Some(p) => p,
        None => return Err("DBA combined curve requires rear phase data".into()),
    };

    // Index 0 = front, index 1 = rear; absent entries fall back to zero.
    let pick = |values: &[f64], idx: usize| values.get(idx).copied().unwrap_or(0.0);
    let front_gain_db = pick(gains, 0);
    let rear_gain_db = pick(gains, 1);
    let front_delay_s = pick(delays_ms, 0) / 1000.0;
    let rear_delay_s = pick(delays_ms, 1) / 1000.0;

    // Complex value of one source at frequency `f` after gain and delay.
    let to_phasor = |spl_db: f64, phase_deg: f64, gain_db: f64, delay_s: f64, f: f64| {
        Complex64::from_polar(
            10.0_f64.powf((spl_db + gain_db) / 20.0),
            phase_deg.to_radians() - 2.0 * PI * f * delay_s,
        )
    };

    let sum_complex = Array1::from_shape_fn(freq_grid.len(), |i| {
        let f = freq_grid[i];
        to_phasor(front.spl[i], front_phase[i], front_gain_db, front_delay_s, f)
            + to_phasor(rear.spl[i], rear_phase[i], rear_gain_db, rear_delay_s, f)
    });

    let spl = sum_complex.mapv(|z| 20.0 * z.norm().max(1e-12).log10());
    let phase = sum_complex.mapv(|z| z.arg().to_degrees());
    Ok(Curve {
        freq: freq_grid.clone(),
        spl,
        phase: Some(phase),
        ..Default::default()
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MeasurementSource;

    /// Builds a one-bin curve at `freq_hz` with the given SPL (dB) and phase (degrees).
    fn single_bin(freq_hz: f64, spl_db: f64, phase_deg: f64) -> Curve {
        Curve {
            freq: Array1::from(vec![freq_hz]),
            spl: Array1::from(vec![spl_db]),
            phase: Some(Array1::from(vec![phase_deg])),
            ..Default::default()
        }
    }

    /// Runs `compute_dba_combined_curve` on a single 100 Hz bin.
    fn combine_at_100hz(front: &Curve, rear: &Curve, gains: &[f64], delays_ms: &[f64]) -> Curve {
        let grid = Array1::from(vec![100.0]);
        compute_dba_combined_curve(front, rear, gains, delays_ms, &grid, 48000.0).unwrap()
    }

    #[test]
    fn test_invert_polarity() {
        let curve = Curve {
            freq: Array1::from(vec![100.0, 1000.0]),
            spl: Array1::from(vec![80.0, 80.0]),
            phase: Some(Array1::from(vec![0.0, -90.0])),
            ..Default::default()
        };
        let inverted = invert_polarity(&curve);
        assert_eq!(inverted.spl, curve.spl);
        let inv_phase = inverted.phase.unwrap();
        assert!((inv_phase[0] - 180.0).abs() < 1e-6);
        assert!((inv_phase[1] - 90.0).abs() < 1e-6);
    }

    #[test]
    fn sum_array_response_rejects_empty_input() {
        assert!(sum_array_response(&[]).is_err());
    }

    #[test]
    fn sum_array_response_rejects_missing_phase() {
        let curve = Curve {
            freq: Array1::from(vec![50.0, 100.0]),
            spl: Array1::from(vec![80.0, 80.0]),
            phase: None,
            ..Default::default()
        };
        let err = sum_array_response(&[MeasurementSource::InMemory(curve)]).unwrap_err();
        assert!(
            err.to_string().contains("requires phase data"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn sum_array_response_preserves_complex_phase() {
        let summed = sum_array_response(&[
            MeasurementSource::InMemory(single_bin(100.0, 80.0, 0.0)),
            MeasurementSource::InMemory(single_bin(100.0, 80.0, 90.0)),
        ])
        .unwrap();
        assert!(summed.phase.is_some());
        assert!((summed.phase.as_ref().unwrap()[0] - 45.0).abs() < 1e-6);
    }

    #[test]
    fn combined_curve_sums_in_phase_sources_coherently() {
        let combined = combine_at_100hz(
            &single_bin(100.0, 80.0, 0.0),
            &single_bin(100.0, 80.0, 0.0),
            &[0.0, 0.0],
            &[0.0, 0.0],
        );
        // Two equal, in-phase pressures add up to +20*log10(2) ≈ +6.02 dB.
        let expected = 80.0 + 20.0 * 2.0_f64.log10();
        assert!((combined.spl[0] - expected).abs() < 1e-6);
    }

    #[test]
    fn combined_curve_nulls_polarity_inverted_rear() {
        let rear = invert_polarity(&single_bin(100.0, 80.0, 0.0));
        let combined = combine_at_100hz(
            &single_bin(100.0, 80.0, 0.0),
            &rear,
            &[0.0, 0.0],
            &[0.0, 0.0],
        );
        assert!(
            combined.spl[0] < -100.0,
            "expected a deep null, got {} dB",
            combined.spl[0]
        );
    }

    #[test]
    fn combined_curve_interprets_delay_in_milliseconds() {
        // 5 ms is half a period at 100 Hz, which turns an in-phase rear into a null.
        let combined = combine_at_100hz(
            &single_bin(100.0, 80.0, 0.0),
            &single_bin(100.0, 80.0, 0.0),
            &[0.0, 0.0],
            &[0.0, 5.0],
        );
        assert!(
            combined.spl[0] < -100.0,
            "expected a deep null, got {} dB",
            combined.spl[0]
        );
    }

    #[test]
    fn combined_curve_applies_gain_in_db() {
        let combined = combine_at_100hz(
            &single_bin(100.0, 80.0, 0.0),
            &single_bin(100.0, 80.0, 0.0),
            &[-6.0, 0.0],
            &[0.0, 0.0],
        );
        let expected = 20.0 * (10.0_f64.powf(74.0 / 20.0) + 10.0_f64.powf(80.0 / 20.0)).log10();
        assert!((combined.spl[0] - expected).abs() < 1e-6);
    }
}