// math_audio_optimisation/run_recorded.rs
use std::sync::Arc;

use directories::ProjectDirs;
use ndarray::Array1;

use crate::recorder::OptimizationRecorder;
use crate::{DEConfig, DEReport, differential_evolution};

8fn get_records_dir() -> Result<std::path::PathBuf, String> {
10 let proj_dirs = ProjectDirs::from("org", "spinorama", "math-audio")
11 .ok_or("Failed to determine project directories")?;
12
13 let cache_dir = proj_dirs.cache_dir();
14 let records_dir = cache_dir.join("records");
15
16 std::fs::create_dir_all(&records_dir)
18 .map_err(|e| format!("Failed to create records directory: {}", e))?;
19
20 Ok(records_dir)
21}
22
23pub fn run_recorded_differential_evolution<F>(
28 function_name: &str,
29 func: F,
30 bounds: &[(f64, f64)],
31 config: DEConfig,
32) -> Result<(DEReport, String), Box<dyn std::error::Error>>
33where
34 F: Fn(&Array1<f64>) -> f64 + Send + Sync + 'static,
35{
36 let records_dir =
38 get_records_dir().map_err(|e| format!("Failed to get records directory: {}", e))?;
39 let output_dir = records_dir.to_string_lossy().to_string();
40
41 let recorder = std::sync::Arc::new(OptimizationRecorder::with_output_dir(
43 function_name.to_string(),
44 output_dir.to_string(),
45 ));
46
47 let recorder_clone = recorder.clone();
49 let recorded_func = move |x: &Array1<f64>| -> f64 {
50 let f_value = func(x);
51 recorder_clone.record_evaluation(x, f_value);
52 f_value
53 };
54
55 let result = differential_evolution(&recorded_func, bounds, config)?;
57
58 let csv_files = recorder.finalize()?;
60
61 let csv_path = if !csv_files.is_empty() {
63 csv_files[0].clone()
64 } else {
65 format!("{}/{}.csv", output_dir, function_name)
66 };
67
68 Ok((result, csv_path))
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEConfigBuilder;

    /// End-to-end check of the recorded DE run on a 2-D sum-of-squares objective.
    ///
    /// The run is seeded and short, and should land close to the origin.
    ///
    /// NOTE(review): this goes through `get_records_dir`, so it uses (and creates) the
    /// real user cache records directory rather than a temporary one.
    #[test]
    fn test_run_recorded_basic() {
        let quadratic = |x: &Array1<f64>| -> f64 { x.iter().map(|v| v * v).sum() };

        let bounds = [(-5.0, 5.0); 2];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(20)
            .popsize(10)
            .build()
            .expect("popsize must be >= 4");

        let (report, csv_path) =
            run_recorded_differential_evolution("test_quadratic", quadratic, &bounds, config)
                .unwrap_or_else(|e| {
                    panic!("Failed to run recorded differential evolution: {}", e)
                });

        println!("Result: f = {:.6e}, x = {:?}", report.fun, report.x);
        assert!(report.fun < 1e-3, "Function value too high: {}", report.fun);
        report
            .x
            .iter()
            .for_each(|&xi| assert!(xi.abs() < 1e-1, "Variable too far from 0: {}", xi));

        println!("CSV saved to: {}", csv_path);
    }
}