use crate::recorder::OptimizationRecorder;
use crate::{differential_evolution, DEConfig, DEReport};
use directories::ProjectDirs;
use ndarray::Array1;
/// Resolves the directory used for optimization records, creating it when missing.
///
/// The directory is the `records` subfolder of the cache directory that
/// `ProjectDirs` reports for the `("org", "spinorama", "math-audio")` project.
///
/// # Errors
///
/// Returns a descriptive `String` when the project directories cannot be
/// determined or when the directory cannot be created.
fn get_records_dir() -> Result<std::path::PathBuf, String> {
    let project = match ProjectDirs::from("org", "spinorama", "math-audio") {
        Some(dirs) => dirs,
        None => return Err(String::from("Failed to determine project directories")),
    };

    let target = project.cache_dir().join("records");

    // `create_dir_all` also creates any missing parent directories.
    match std::fs::create_dir_all(&target) {
        Ok(()) => Ok(target),
        Err(io_err) => Err(format!("Failed to create records directory: {}", io_err)),
    }
}
pub fn run_recorded_differential_evolution<F>(
function_name: &str,
func: F,
bounds: &[(f64, f64)],
config: DEConfig,
) -> Result<(DEReport, String), Box<dyn std::error::Error>>
where
F: Fn(&Array1<f64>) -> f64 + Send + Sync + 'static,
{
let records_dir =
get_records_dir().map_err(|e| format!("Failed to get records directory: {}", e))?;
let output_dir = records_dir.to_string_lossy().to_string();
let recorder = std::sync::Arc::new(OptimizationRecorder::with_output_dir(
function_name.to_string(),
output_dir.to_string(),
));
let recorder_clone = recorder.clone();
let recorded_func = move |x: &Array1<f64>| -> f64 {
let f_value = func(x);
recorder_clone.record_evaluation(x, f_value);
f_value
};
let result = differential_evolution(&recorded_func, bounds, config)?;
let csv_files = recorder.finalize()?;
let csv_path = if !csv_files.is_empty() {
csv_files[0].clone()
} else {
format!("{}/{}.csv", output_dir, function_name)
};
Ok((result, csv_path))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEConfigBuilder;

    /// Minimizes a two-variable sum of squares over `[-5, 5]` per variable with
    /// a fixed seed, then checks that the reported minimum and every coordinate
    /// are close to zero.
    #[test]
    fn test_run_recorded_basic() {
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(20)
            .popsize(10)
            .build()
            .expect("popsize must be >= 4");
        let sphere = |x: &Array1<f64>| -> f64 { x.iter().map(|&xi| xi * xi).sum() };

        let (report, csv_path) =
            run_recorded_differential_evolution("test_quadratic", sphere, &bounds, config)
                .unwrap_or_else(|e| {
                    panic!("Failed to run recorded differential evolution: {}", e)
                });

        println!("Result: f = {:.6e}, x = {:?}", report.fun, report.x);
        assert!(report.fun < 1e-3, "Function value too high: {}", report.fun);
        for xi in report.x.iter().copied() {
            assert!(xi.abs() < 1e-1, "Variable too far from 0: {}", xi);
        }
        println!("CSV saved to: {}", csv_path);
    }
}