use ndarray::Array1;
use std::sync::Arc;
use super::constraints::{
CeilingConstraintData, MinGainConstraintData, SpacingConstraintData, constraint_ceiling,
constraint_min_gain, constraint_spacing,
};
use crate::de::init_sobol::init_sobol;
use super::initial_guess::{SmartInitConfig, create_smart_initial_guesses};
use super::optim::{ObjectiveData, PenaltyMode, compute_fitness_penalties_ref};
use super::optim_callback::{ProgressTracker, format_param_summary};
use crate::de::{
CallbackAction, DEConfig, DEConfigBuilder, DEIntermediate, DEReport, Init, Mutation,
NonlinearConstraintHelper, ParallelConfig, Strategy, differential_evolution,
};
/// Everything a differential-evolution run needs that is shared between entry points.
///
/// Produced by `setup_de_common`.
pub struct DESetup {
    /// One `(lower, upper)` pair per optimisation parameter, in parameter order.
    pub bounds: Vec<(f64, f64)>,
    /// Copy of the objective data with its soft penalties switched off
    /// (`PenaltyMode::Disabled`).
    pub penalty_data: ObjectiveData,
    /// Population multiplier per free dimension; handed to the DE builder as `popsize`.
    pub pop_multiplier: usize,
    /// `pop_multiplier` times the number of free (non-pinned) dimensions.
    pub population_size: usize,
    /// Generation budget for the solver (never below `MIN_DE_GENERATIONS`).
    pub max_iter: usize,
}
/// Counts the parameters the optimiser can actually move.
///
/// A dimension is free when its upper bound is strictly greater than its lower bound.
/// Pinned (`lo == hi`) or inverted bounds do not count, and neither do pairs involving
/// NaN. Extra entries in the longer slice are ignored. The result is never below 1, so
/// callers may divide by it.
fn count_free_dimensions(lower_bounds: &[f64], upper_bounds: &[f64]) -> usize {
    let free = lower_bounds
        .iter()
        .zip(upper_bounds)
        .fold(0usize, |acc, (&lo, &hi)| if hi > lo { acc + 1 } else { acc });
    std::cmp::max(free, 1)
}

/// Lower bound on the number of DE generations, applied regardless of `maxeval`.
const MIN_DE_GENERATIONS: usize = 5000;

/// Derives `(pop_multiplier, population_size, max_iter)` for a DE run.
///
/// * `pop_multiplier`: the requested `population` is clamped to `1..=max(maxeval, 1)`.
///   It is then split across the free dimensions, rounded up, and never allowed below 4.
/// * `population_size`: `pop_multiplier` times the number of free dimensions.
/// * `max_iter`: the evaluations left after the initial population, divided by the
///   population size, but never below [`MIN_DE_GENERATIONS`]. The resulting
///   budget can therefore exceed `maxeval`.
fn derive_de_budget(
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    population: usize,
    maxeval: usize,
) -> (usize, usize, usize) {
    let free_dims = count_free_dimensions(lower_bounds, upper_bounds);
    let eval_cap = std::cmp::max(maxeval, 1);
    let requested = std::cmp::min(std::cmp::max(population, 1), eval_cap);
    let multiplier = std::cmp::max(requested.div_ceil(free_dims), 4);
    let pop_size = multiplier * free_dims;
    let remaining_evals = maxeval.saturating_sub(pop_size);
    let generations = std::cmp::max(remaining_evals / pop_size, MIN_DE_GENERATIONS);
    (multiplier, pop_size, generations)
}
/// Prepares the inputs shared by every differential-evolution entry point.
///
/// The function does four things:
/// * Zips `lower_bounds`/`upper_bounds` into `(lo, hi)` pairs.
/// * Derives the population and generation budget from `population` and `maxeval`.
/// * Takes ownership of `objective_data` and switches its soft penalties off
///   (`PenaltyMode::Disabled`). Callers such as `optimize_filters_autoeq_with_callback`
///   register the ceiling / min-gain / spacing limits as DE constraints instead.
/// * Logs the configuration at debug level, unless `qa_mode` is set.
pub fn setup_de_common(
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    population: usize,
    maxeval: usize,
    qa_mode: bool,
) -> DESetup {
    let bounds: Vec<(f64, f64)> = lower_bounds
        .iter()
        .copied()
        .zip(upper_bounds.iter().copied())
        .collect();
    let (pop_multiplier, population_size, max_iter) =
        derive_de_budget(lower_bounds, upper_bounds, population, maxeval);
    // `objective_data` is already owned and not used again, so adjust it in place
    // rather than deep-cloning it.
    let mut penalty_data = objective_data;
    penalty_data.configure_penalties(PenaltyMode::Disabled);
    if !qa_mode {
        let params_desc = if penalty_data.loss_type == crate::LossType::DriversFlat {
            format!("{} parameters", bounds.len())
        } else {
            let params_per_filter = crate::param_utils::params_per_filter(penalty_data.peq_model);
            let num_filters = bounds.len() / params_per_filter;
            format!("{} filters", num_filters)
        };
        log::debug!(
            "DE Setup: {}, pop_multiplier={}, population_size={}, max_iter={}, maxeval={}",
            params_desc,
            pop_multiplier,
            population_size,
            max_iter,
            maxeval
        );
        log::debug!(
            " Penalty weights: ceiling={:.1e}, spacing={:.1e}, mingain={:.1e}",
            penalty_data.penalty_w_ceiling,
            penalty_data.penalty_w_spacing,
            penalty_data.penalty_w_mingain
        );
        log::debug!(
            " Constraints: max_db={:.1}, min_spacing={:.3} oct, min_db={:.1}",
            penalty_data.max_db,
            penalty_data.min_spacing_oct,
            penalty_data.min_db
        );
    }
    DESetup {
        bounds,
        penalty_data,
        pop_multiplier,
        population_size,
        max_iter,
    }
}
/// Builds the per-generation progress callback handed to the DE solver.
///
/// The callback feeds each generation's best fitness into a [`ProgressTracker`]. Unless
/// `qa_mode` is set, it also emits debug logs in two cases:
/// * when the tracker reports the run has just started stalling, or when
///   `stall_at_interval(25)` fires;
/// * a summary of the best parameters every 100 iterations.
///
/// The callback always returns [`CallbackAction::Continue`]; it never stops the run.
pub fn create_de_callback(
    algo_name: &str,
    qa_mode: bool,
) -> Box<dyn FnMut(&DEIntermediate) -> CallbackAction + Send> {
    let name = algo_name.to_string();
    let mut tracker = ProgressTracker::default();
    Box::new(move |intermediate: &DEIntermediate| -> CallbackAction {
        // The tracker is updated on every call, including in QA mode.
        let (improvement, _) = tracker.update(intermediate.fun);
        if !qa_mode && (tracker.just_started_stalling() || tracker.stall_at_interval(25)) {
            log::debug!(
                "{} iter {:4} fitness={:.6e} {} conv={:.3e}",
                name,
                intermediate.iter,
                intermediate.fun,
                improvement,
                intermediate.convergence
            );
        }
        if !qa_mode && intermediate.iter.is_multiple_of(100) {
            // `as_slice` is `None` when the array is not contiguous in standard order.
            // Fall back to a copy instead of panicking inside a logging path.
            let summary = match intermediate.x.as_slice() {
                Some(params) => format_param_summary(params, 3),
                None => format_param_summary(&intermediate.x.to_vec(), 3),
            };
            log::debug!(" --> Best params: {}", summary);
        }
        CallbackAction::Continue
    })
}
/// Wraps `compute_fitness_penalties_ref` as the `Fn(&Array1<f64>) -> f64` objective
/// the DE solver expects. The closure takes ownership of `penalty_data`.
pub fn create_de_objective(penalty_data: ObjectiveData) -> impl Fn(&Array1<f64>) -> f64 {
    move |x_arr: &Array1<f64>| -> f64 {
        match x_arr.as_slice() {
            Some(x_slice) => compute_fitness_penalties_ref(x_slice, &penalty_data),
            // Arrays that are not contiguous in standard order have no slice view;
            // evaluate a copy instead of panicking.
            None => compute_fitness_penalties_ref(&x_arr.to_vec(), &penalty_data),
        }
    }
}
/// Attaches a single scalar nonlinear constraint to `config`.
///
/// `constraint_fn` is evaluated as `constraint_fn(x, None, &mut data)`. The point is
/// considered feasible when the returned value lies in `[-1e30, 0.0]`, i.e. when it is
/// `<= 0`. The two `1e3` arguments to `apply_to` are passed through unchanged.
/// NOTE(review): their exact meaning is defined by `NonlinearConstraintHelper::apply_to`;
/// confirm there.
fn register_de_constraint<T, F>(config: &mut DEConfig, constraint_fn: F, data: T)
where
    T: Clone + Send + Sync + 'static,
    F: Fn(&[f64], Option<&mut [f64]>, &mut T) -> f64 + Send + Sync + 'static,
{
    let constraint = NonlinearConstraintHelper {
        fun: Arc::new(move |x: &Array1<f64>| {
            // `constraint_fn` needs `&mut T`, but this closure is a shared `Fn` that may be
            // called from several threads. Each evaluation therefore gets its own copy.
            let mut scratch = data.clone();
            let value = match x.as_slice() {
                Some(xs) => constraint_fn(xs, None, &mut scratch),
                // Arrays that are not in standard order have no slice view; copy instead
                // of panicking.
                None => constraint_fn(&x.to_vec(), None, &mut scratch),
            };
            Array1::from(vec![value])
        }),
        lb: Array1::from(vec![-1e30]),
        ub: Array1::from(vec![0.0]),
    };
    constraint.apply_to(config, 1e3, 1e3);
}
/// Copies the solver's best parameter vector into `x` and builds the status message.
///
/// On success this returns `Ok((status, best_fitness))`. The status is suffixed with
/// `(not converged)` when the report's `success` flag is false.
///
/// # Errors
///
/// Returns `Err((message, f64::INFINITY))` when the solver returned a different number
/// of parameters than `x` holds. In that case `x` is left untouched. Previously this
/// case was silently ignored, and a fitness was reported that did not describe `x`.
pub fn process_de_results(
    x: &mut [f64],
    result: DEReport,
    algo_name: &str,
) -> Result<(String, f64), (String, f64)> {
    if result.x.len() != x.len() {
        return Err((
            format!(
                "AutoEQ {}: solver returned {} parameters, expected {}",
                algo_name,
                result.x.len(),
                x.len()
            ),
            f64::INFINITY,
        ));
    }
    for (dst, &src) in x.iter_mut().zip(result.x.iter()) {
        *dst = src;
    }
    let status = if result.success {
        format!("AutoEQ {}: {}", algo_name, result.message)
    } else {
        format!("AutoEQ {}: {} (not converged)", algo_name, result.message)
    };
    Ok((status, result.fun))
}
/// Runs the AutoEQ DE optimiser with the default progress callback.
///
/// The callback is labelled `"autoeq::DE"` and is silenced when `cli_args.qa` is set.
/// All other arguments are forwarded unchanged to `optimize_filters_autoeq_with_callback`.
pub fn optimize_filters_autoeq(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    autoeq_name: &str,
    cli_args: &crate::cli::Args,
) -> Result<(String, f64), (String, f64)> {
    let quiet = cli_args.qa.is_some();
    let progress_callback = create_de_callback("autoeq::DE", quiet);
    optimize_filters_autoeq_with_callback(
        x,
        lower_bounds,
        upper_bounds,
        objective_data,
        autoeq_name,
        cli_args,
        progress_callback,
    )
}
pub fn optimize_filters_autoeq_with_callback(
x: &mut [f64],
lower_bounds: &[f64],
upper_bounds: &[f64],
objective_data: ObjectiveData,
_autoeq_name: &str,
cli_args: &crate::cli::Args,
mut callback: Box<dyn FnMut(&DEIntermediate) -> CallbackAction + Send>,
) -> Result<(String, f64), (String, f64)> {
let population = cli_args.population;
let maxeval = cli_args.maxeval;
let setup = setup_de_common(
lower_bounds,
upper_bounds,
objective_data.clone(),
population,
maxeval,
cli_args.qa.is_some(),
);
let base_objective_fn = create_de_objective(setup.penalty_data.clone());
let smart_guesses = if matches!(
setup.penalty_data.loss_type,
crate::LossType::DriversFlat | crate::LossType::MultiSubFlat
) {
Vec::new()
} else {
let params_per_filter =
crate::param_utils::params_per_filter(cli_args.effective_peq_model());
let num_filters = x.len() / params_per_filter;
let smart_config = SmartInitConfig {
seed: cli_args.seed, ..SmartInitConfig::default()
};
let target_response = &setup.penalty_data.deviation;
let freq_grid = &setup.penalty_data.freqs;
if cli_args.qa.is_none() {
log::debug!(
"🧠 Generating smart initial guesses based on frequency response analysis..."
);
}
let guesses = create_smart_initial_guesses(
target_response,
freq_grid,
num_filters,
&setup.bounds,
&smart_config,
cli_args.effective_peq_model(),
);
if cli_args.qa.is_none() {
log::debug!("📊 Generated {} smart initial guesses", guesses.len());
}
guesses
};
let sobol_samples = init_sobol(
x.len(),
setup.population_size.saturating_sub(smart_guesses.len()),
&setup.bounds,
);
if cli_args.qa.is_none() {
log::debug!(
"🎯 Generated {} Sobol quasi-random samples",
sobol_samples.len()
);
}
let best_initial_guess = if !smart_guesses.is_empty() {
Array1::from(smart_guesses[0].clone())
} else if !sobol_samples.is_empty() {
Array1::from(sobol_samples[0].clone())
} else {
Array1::from(x.to_vec())
};
if cli_args.qa.is_none() {
log::debug!("🚀 Using smart initial guess with Sobol population initialization");
}
use std::str::FromStr;
let strategy = Strategy::from_str(&cli_args.strategy).unwrap_or_else(|_| {
if cli_args.qa.is_none() {
log::debug!(
"⚠️ Warning: Invalid strategy '{}', falling back to CurrentToBest1Bin",
cli_args.strategy
);
}
Strategy::CurrentToBest1Bin
});
let adaptive_config = if matches!(strategy, Strategy::AdaptiveBin | Strategy::AdaptiveExp) {
Some(crate::de::AdaptiveConfig {
adaptive_mutation: true,
wls_enabled: false, w_max: 0.8, w_min: 0.2, w_f: cli_args.adaptive_weight_f * 0.5, w_cr: cli_args.adaptive_weight_cr * 0.5, f_m: 0.6, cr_m: 0.5, wls_prob: 0.0, wls_scale: 0.0, })
} else {
None
};
let (tolerance, atolerance) =
if matches!(strategy, Strategy::AdaptiveBin | Strategy::AdaptiveExp) {
(cli_args.tolerance * 10.0, cli_args.atolerance * 10.0)
} else {
(cli_args.tolerance, cli_args.atolerance)
};
let mut config_builder = DEConfigBuilder::new()
.maxiter(setup.max_iter)
.popsize(setup.pop_multiplier)
.tol(tolerance)
.atol(atolerance)
.strategy(strategy)
.mutation(Mutation::Range { min: 0.4, max: 1.2 })
.recombination(cli_args.recombination)
.init(Init::LatinHypercube) .x0(best_initial_guess) .disp(false)
.callback(Box::new(move |intermediate| callback(intermediate)));
if let Some(seed_value) = cli_args.seed {
config_builder = config_builder.seed(seed_value);
if cli_args.qa.is_none() {
log::debug!("🎲 Using deterministic seed: {}", seed_value);
}
}
if let Some(adaptive_cfg) = adaptive_config {
config_builder = config_builder.adaptive(adaptive_cfg);
}
let parallel_config = ParallelConfig {
enabled: !cli_args.no_parallel,
num_threads: if cli_args.parallel_threads == 0 {
None } else {
Some(cli_args.parallel_threads)
},
};
config_builder = config_builder.parallel(parallel_config);
if !cli_args.no_parallel && cli_args.qa.is_none() {
log::debug!(
"🚄 Parallel evaluation enabled with {} threads",
if cli_args.parallel_threads.eq(&0) {
"all available".to_string()
} else {
cli_args.parallel_threads.to_string()
}
);
}
let mut config = config_builder
.build()
.map_err(|e| (format!("DE config build failed: {:?}", e), f64::INFINITY))?;
if setup.penalty_data.max_db > 0.0 {
register_de_constraint(
&mut config,
constraint_ceiling,
CeilingConstraintData {
freqs: setup.penalty_data.freqs.clone(),
srate: setup.penalty_data.srate,
max_db: setup.penalty_data.max_db,
peq_model: setup.penalty_data.peq_model,
},
);
}
if setup.penalty_data.min_db > 0.0 {
register_de_constraint(
&mut config,
constraint_min_gain,
MinGainConstraintData {
min_db: setup.penalty_data.min_db,
peq_model: setup.penalty_data.peq_model,
},
);
}
if setup.penalty_data.min_spacing_oct > 0.0 {
register_de_constraint(
&mut config,
constraint_spacing,
SpacingConstraintData {
min_spacing_oct: setup.penalty_data.min_spacing_oct,
peq_model: setup.penalty_data.peq_model,
},
);
}
let result = differential_evolution(&base_objective_fn, &setup.bounds, config)
.map_err(|e| (format!("DE optimization failed: {:?}", e), f64::INFINITY))?;
process_de_results(x, result, "AutoDE")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LossType;
    use crate::cli::PeqModel;
    use ndarray::{Array1, array};

    /// Minimal objective data: two frequency points, ceiling at 6 dB, all penalties off.
    fn test_objective_data() -> ObjectiveData {
        ObjectiveData {
            freqs: array![100.0, 1000.0],
            target: Array1::zeros(2),
            deviation: Array1::zeros(2),
            srate: 48_000.0,
            min_spacing_oct: 0.0,
            spacing_weight: 0.0,
            max_db: 6.0,
            min_db: 0.0,
            min_freq: 20.0,
            max_freq: 20_000.0,
            peq_model: PeqModel::Pk,
            loss_type: LossType::SpeakerFlat,
            speaker_score_data: None,
            headphone_score_data: None,
            input_curve: None,
            drivers_data: None,
            fixed_crossover_freqs: None,
            penalty_w_ceiling: 0.0,
            penalty_w_spacing: 0.0,
            penalty_w_mingain: 0.0,
            integrality: None,
            multi_objective: None,
            smooth: false,
            smooth_n: 2,
            max_boost_envelope: None,
            min_cut_envelope: None,
        }
    }

    #[test]
    fn setup_de_common_enforces_minimum_generations() {
        let lower_bounds = vec![-1.0, -1.0];
        let upper_bounds = vec![1.0, 1.0];
        let setup = setup_de_common(
            &lower_bounds,
            &upper_bounds,
            test_objective_data(),
            20,
            55,
            true,
        );
        assert_eq!(setup.population_size, 20);
        assert_eq!(setup.max_iter, MIN_DE_GENERATIONS);
    }

    #[test]
    fn setup_de_common_respects_large_maxeval() {
        let lower_bounds = vec![-1.0, -1.0, -1.0];
        let upper_bounds = vec![1.0, 1.0, 1.0];
        let setup = setup_de_common(
            &lower_bounds,
            &upper_bounds,
            test_objective_data(),
            20,
            1_000_000,
            true,
        );
        assert!(setup.max_iter >= MIN_DE_GENERATIONS);
        let expected = (1_000_000 - setup.population_size) / setup.population_size;
        assert_eq!(setup.max_iter, expected);
    }

    #[test]
    fn pinned_dimensions_do_not_inflate_population() {
        // Only the middle parameter is free (lo < hi); the other two are pinned.
        let lower_bounds = vec![0.0, -1.0, 2.0];
        let upper_bounds = vec![0.0, 1.0, 2.0];
        let setup = setup_de_common(
            &lower_bounds,
            &upper_bounds,
            test_objective_data(),
            10,
            10_000,
            true,
        );
        assert_eq!(setup.pop_multiplier, 10);
        assert_eq!(setup.population_size, 10);
        assert_eq!(setup.max_iter, MIN_DE_GENERATIONS);
        // Bounds still cover every parameter, pinned or not.
        assert_eq!(setup.bounds, vec![(0.0, 0.0), (-1.0, 1.0), (2.0, 2.0)]);
    }

    #[test]
    fn small_population_is_raised_to_minimum_multiplier() {
        let lower_bounds = vec![-1.0; 4];
        let upper_bounds = vec![1.0; 4];
        let setup = setup_de_common(
            &lower_bounds,
            &upper_bounds,
            test_objective_data(),
            2,
            1_000_000,
            true,
        );
        assert_eq!(setup.pop_multiplier, 4);
        assert_eq!(setup.population_size, 16);
        assert_eq!(setup.max_iter, (1_000_000 - 16) / 16);
    }

    #[test]
    fn population_is_capped_by_maxeval() {
        let lower_bounds = vec![-1.0, -1.0];
        let upper_bounds = vec![1.0, 1.0];
        let setup = setup_de_common(
            &lower_bounds,
            &upper_bounds,
            test_objective_data(),
            50,
            10,
            true,
        );
        assert_eq!(setup.pop_multiplier, 5);
        assert_eq!(setup.population_size, 10);
        assert_eq!(setup.max_iter, MIN_DE_GENERATIONS);
    }

    #[test]
    fn count_free_dimensions_never_returns_zero() {
        assert_eq!(count_free_dimensions(&[1.0, 1.0], &[1.0, 1.0]), 1);
        assert_eq!(count_free_dimensions(&[], &[]), 1);
        assert_eq!(count_free_dimensions(&[0.0, 0.0, 0.0], &[1.0, 0.0, 1.0]), 2);
    }
}