Skip to main content

math_audio_optimisation/
run_recorded.rs

1//! Recording wrapper for differential evolution for testing purposes
2
3use crate::recorder::OptimizationRecorder;
4use crate::{DEConfig, DEReport, differential_evolution};
5use directories::ProjectDirs;
6use ndarray::Array1;
7
8/// Get the records directory using the directories crate
9fn get_records_dir() -> Result<std::path::PathBuf, String> {
10    let proj_dirs = ProjectDirs::from("org", "spinorama", "math-audio")
11        .ok_or("Failed to determine project directories")?;
12
13    let cache_dir = proj_dirs.cache_dir();
14    let records_dir = cache_dir.join("records");
15
16    // Create the directory if it doesn't exist
17    std::fs::create_dir_all(&records_dir)
18        .map_err(|e| format!("Failed to create records directory: {}", e))?;
19
20    Ok(records_dir)
21}
22
23/// Run differential evolution with recording of evaluations to CSV
24///
25/// This wrapper function is primarily used for testing and analysis.
26/// It records every function evaluation to CSV files for later analysis.
27pub fn run_recorded_differential_evolution<F>(
28    function_name: &str,
29    func: F,
30    bounds: &[(f64, f64)],
31    config: DEConfig,
32) -> Result<(DEReport, String), Box<dyn std::error::Error>>
33where
34    F: Fn(&Array1<f64>) -> f64 + Send + Sync + 'static,
35{
36    // Get the records directory from AUTOEQ_DIR environment variable
37    let records_dir =
38        get_records_dir().map_err(|e| format!("Failed to get records directory: {}", e))?;
39    let output_dir = records_dir.to_string_lossy().to_string();
40
41    // Create recorder for this optimization run
42    let recorder = std::sync::Arc::new(OptimizationRecorder::with_output_dir(
43        function_name.to_string(),
44        output_dir.to_string(),
45    ));
46
47    // Create wrapped objective function that records evaluations
48    let recorder_clone = recorder.clone();
49    let recorded_func = move |x: &Array1<f64>| -> f64 {
50        let f_value = func(x);
51        recorder_clone.record_evaluation(x, f_value);
52        f_value
53    };
54
55    // Run differential evolution with the wrapped function
56    let result = differential_evolution(&recorded_func, bounds, config)?;
57
58    // Finalize recording and get CSV file paths
59    let csv_files = recorder.finalize()?;
60
61    // Return the DE result and the primary CSV file path
62    let csv_path = if !csv_files.is_empty() {
63        csv_files[0].clone()
64    } else {
65        format!("{}/{}.csv", output_dir, function_name)
66    };
67
68    Ok((result, csv_path))
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEConfigBuilder;

    /// Runs the recording wrapper on a two-variable sum-of-squares objective
    /// and checks that the reported optimum is close to the origin.
    #[test]
    fn test_run_recorded_basic() {
        // Sum of squares: smallest value is 0 at the origin
        let sum_of_squares =
            |point: &Array1<f64>| -> f64 { point.iter().map(|&v| v * v).sum() };

        let bounds = [(-5.0, 5.0), (-5.0, 5.0)];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(20)
            .popsize(10)
            .build()
            .expect("popsize must be >= 4");

        // Any error from the wrapper fails the test immediately
        let (report, csv_path) = run_recorded_differential_evolution(
            "test_quadratic",
            sum_of_squares,
            &bounds,
            config,
        )
        .unwrap_or_else(|e| panic!("Failed to run recorded differential evolution: {}", e));

        // Objective value and every coordinate should be near zero
        println!("Result: f = {:.6e}, x = {:?}", report.fun, report.x);
        assert!(report.fun < 1e-3, "Function value too high: {}", report.fun);
        for xi in report.x.iter() {
            assert!(xi.abs() < 1e-1, "Variable too far from 0: {}", xi);
        }

        // Report where the wrapper says the CSV was written
        println!("CSV saved to: {}", csv_path);
    }
}