use crate::recorder::OptimizationRecorder;
use crate::{DEConfig, DEReport, differential_evolution};
use autoeq_env::get_records_dir;
use ndarray::Array1;
pub fn run_recorded_differential_evolution<F>(
function_name: &str,
func: F,
bounds: &[(f64, f64)],
config: DEConfig,
) -> Result<(DEReport, String), Box<dyn std::error::Error>>
where
F: Fn(&Array1<f64>) -> f64 + Send + Sync + 'static,
{
let records_dir =
get_records_dir().map_err(|e| format!("Failed to get records directory: {}", e))?;
let output_dir = records_dir.to_string_lossy().to_string();
let recorder = std::sync::Arc::new(OptimizationRecorder::with_output_dir(
function_name.to_string(),
output_dir.to_string(),
));
let recorder_clone = recorder.clone();
let recorded_func = move |x: &Array1<f64>| -> f64 {
let f_value = func(x);
recorder_clone.record_evaluation(x, f_value);
f_value
};
let result = differential_evolution(&recorded_func, bounds, config);
let csv_files = recorder.finalize()?;
let csv_path = if !csv_files.is_empty() {
csv_files[0].clone()
} else {
format!("{}/{}.csv", output_dir, function_name)
};
Ok((result, csv_path))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEConfigBuilder;

    /// Minimizes the 2-D sphere function (minimum 0 at the origin) with recording
    /// enabled and checks that the optimizer converges near the origin.
    #[test]
    fn test_run_recorded_basic() {
        let quadratic = |x: &Array1<f64>| -> f64 { x.iter().map(|&xi| xi * xi).sum() };
        let bounds = [(-5.0, 5.0), (-5.0, 5.0)];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(20)
            .popsize(10)
            .build();

        // Recording needs a records directory; presumably get_records_dir() fails when
        // AUTOEQ_DIR is unset. Report the actual error rather than assuming its cause.
        let (report, csv_path) =
            run_recorded_differential_evolution("test_quadratic", quadratic, &bounds, config)
                .unwrap_or_else(|e| {
                    panic!(
                        "run_recorded_differential_evolution failed: {}\n\
                         This test requires AUTOEQ_DIR to be set, e.g.: export AUTOEQ_DIR=/path/to/autoeq",
                        e
                    )
                });

        println!("Result: f = {:.6e}, x = {:?}", report.fun, report.x);
        assert!(report.fun < 1e-3, "Function value too high: {}", report.fun);
        for &xi in report.x.iter() {
            assert!(xi.abs() < 1e-1, "Variable too far from 0: {}", xi);
        }
        println!("CSV saved to: {}", csv_path);
    }
}