use autoeq::LossType;
use autoeq::OptimParams;
use autoeq::cli::{Args, PeqModel};
use autoeq::de::CallbackAction;
use autoeq::optim_mh::{MHIntermediate, create_mh_callback, optimize_filters_mh_with_callback};
use autoeq::workflow::{initial_guess, setup_bounds};
use clap::Parser;
use ndarray::Array1;
use std::sync::{Arc, Mutex};
/// Builds a small, deterministic `ObjectiveData` fixture for the
/// metaheuristic-optimiser tests.
///
/// The fixture has:
/// - three frequency points (100, 1000, 10000; presumably Hz);
/// - a flat target of 1.0 and a constant deviation of 0.5 at each point;
/// - `srate` of 48000.0;
/// - a `Pk` PEQ model with `LossType::SpeakerFlat`.
///
/// All `penalty_w_*` weights are 0.0, smoothing is disabled, and every
/// optional input (score data, drivers, envelopes, EPA config, null
/// suppression, ...) is `None`.
fn create_test_objective_data() -> autoeq::optim::ObjectiveData {
    let freqs = Array1::from(vec![100.0, 1000.0, 10000.0]);
    let target = Array1::from(vec![1.0, 1.0, 1.0]);
    let deviation = Array1::from(vec![0.5, 0.5, 0.5]);
    autoeq::optim::ObjectiveData {
        // `freqs` is not used after this point, so move it rather than clone.
        freqs,
        target,
        deviation,
        input_curve: None,
        srate: 48000.0,
        min_spacing_oct: 0.5,
        spacing_weight: 20.0,
        max_db: 3.0,
        min_db: 1.0,
        min_freq: 60.0,
        max_freq: 16000.0,
        peq_model: PeqModel::Pk,
        loss_type: LossType::SpeakerFlat,
        speaker_score_data: None,
        headphone_score_data: None,
        drivers_data: None,
        fixed_crossover_freqs: None,
        penalty_w_ceiling: 0.0,
        penalty_w_spacing: 0.0,
        penalty_w_mingain: 0.0,
        integrality: None,
        multi_objective: None,
        smooth: false,
        smooth_n: 3,
        max_boost_envelope: None,
        min_cut_envelope: None,
        epa_config: None,
        detected_problems: Vec::new(),
        null_suppression: None,
    }
}
/// Smoke test for the PSO backend.
///
/// Runs a small optimisation (2 filters, 200 evaluations, population 20) with
/// a progress callback attached. The test then checks that the callback fired
/// at least once. Invocations are tallied in a shared counter that is read
/// back after the optimiser returns.
#[test]
fn test_mh_callback_is_invoked() {
    let cli = [
        "autoeq-test",
        "--algo",
        "mh:pso",
        "--num-filters",
        "2",
        "--maxeval",
        "200",
    ];
    let mut args = Args::parse_from(cli);
    args.population = 20;

    let objective_data = create_test_objective_data();
    let (lo, hi) = setup_bounds(&OptimParams::from(&args));
    let mut x0 = initial_guess(&OptimParams::from(&args), &lo, &hi);

    // Shared tally: the closure owns its own handle to the same counter.
    let invocations = Arc::new(Mutex::new(0));
    let tally = Arc::clone(&invocations);
    let on_progress = Box::new(move |_snapshot: &MHIntermediate| -> CallbackAction {
        // A poisoned lock is skipped here; the final `unwrap` below surfaces it.
        if let Ok(mut n) = tally.lock() {
            *n += 1;
        }
        CallbackAction::Continue
    });

    let outcome = optimize_filters_mh_with_callback(
        &mut x0,
        &lo,
        &hi,
        objective_data,
        "pso",
        args.population,
        args.maxeval,
        on_progress,
    );
    assert!(outcome.is_ok(), "Optimization should succeed: {:?}", outcome);

    let total = *invocations.lock().unwrap();
    assert!(
        total > 0,
        "Callback should have been invoked at least once, got {} invocations",
        total
    );
    println!("✅ Callback was invoked {} times during optimization", total);
}
/// Checks that the DE backend hands well-formed progress snapshots to the
/// callback.
///
/// Each snapshot must have `iter > 0`, a non-empty `x` and a finite `fun`.
/// The callback also records the lowest `fun` it observed, and the test
/// asserts that at least one value was recorded.
#[test]
fn test_mh_callback_receives_progress_data() {
    let mut args = Args::parse_from([
        "autoeq-test",
        "--algo",
        "mh:de",
        "--num-filters",
        "2",
        "--maxeval",
        "200",
    ]);
    args.population = 15;
    let objective_data = create_test_objective_data();
    let (lower_bounds, upper_bounds) = setup_bounds(&OptimParams::from(&args));
    let mut x = initial_guess(&OptimParams::from(&args), &lower_bounds, &upper_bounds);

    let best_fitness = Arc::new(Mutex::new(f64::INFINITY));
    let best_fitness_clone = Arc::clone(&best_fitness);
    let callback = Box::new(move |intermediate: &MHIntermediate| -> CallbackAction {
        // Validate the snapshot before taking the lock. This ensures the checks
        // always run, even if the mutex is poisoned. It also ensures a failing
        // assertion cannot poison the mutex, which would hide the real failure
        // behind a `PoisonError` at the final `lock().unwrap()`.
        assert!(intermediate.iter > 0, "Iteration should be positive");
        assert!(!intermediate.x.is_empty(), "Parameters should not be empty");
        assert!(intermediate.fun.is_finite(), "Fitness should be finite");

        if let Ok(mut best) = best_fitness_clone.lock() {
            if intermediate.fun < *best {
                *best = intermediate.fun;
            }
        }
        CallbackAction::Continue
    });

    let result = optimize_filters_mh_with_callback(
        &mut x,
        &lower_bounds,
        &upper_bounds,
        objective_data,
        "de",
        args.population,
        args.maxeval,
        callback,
    );
    // Include the error in the failure message, matching the PSO test.
    assert!(result.is_ok(), "Optimization should succeed: {:?}", result);

    let final_best = *best_fitness.lock().unwrap();
    assert!(
        final_best < f64::INFINITY,
        "Should have recorded at least one fitness value"
    );
    println!("✅ Best fitness observed via callback: {:.6e}", final_best);
}
/// The stock callback returned by `create_mh_callback` should accept a
/// progress snapshot and answer `CallbackAction::Continue`.
#[test]
fn test_default_mh_callback_works() {
    let snapshot = MHIntermediate {
        x: Array1::from(vec![1.0, 2.0, 3.0]),
        fun: 0.5,
        iter: 10,
    };
    let mut default_cb = create_mh_callback("test_algo");

    let action = default_cb(&snapshot);
    assert!(matches!(action, CallbackAction::Continue));

    println!("✅ Default callback works without crashing");
}