Skip to main content

math_audio_optimisation/
function_registry.rs

//! Shared function registry for differential evolution benchmarks and plotting.
2use crate::Strategy;
3use math_audio_test_functions::*;
4use ndarray::Array1;
5use std::collections::HashMap;
6
7/// Test function type definition
8pub type TestFunction = fn(&Array1<f64>) -> f64;
9
/// A single point of an optimisation trace loaded from CSV, stored as
/// `(x_coordinates, f_value, is_improvement)`.
pub type TracePoint = (Vec<f64>, f64, bool);
12
13/// Configuration for a benchmark run.
14#[derive(Clone, Debug)]
15pub struct BenchmarkConfig {
16    /// Descriptive name for the benchmark.
17    pub name: String,
18    /// Name of the test function.
19    pub function_name: String,
20    /// Variable bounds as (lower, upper) pairs.
21    pub bounds: Vec<(f64, f64)>,
22    /// Expected optimal solution coordinates.
23    pub expected_optimum: Vec<f64>,
24    /// Tolerance for objective function value comparison.
25    pub fun_tolerance: f64,
26    /// Tolerance for solution position comparison.
27    pub position_tolerance: f64,
28    /// Maximum iterations for the benchmark.
29    pub maxiter: usize,
30    /// Population size multiplier.
31    pub popsize: usize,
32    /// DE strategy to use.
33    pub strategy: Strategy,
34    /// Crossover probability.
35    pub recombination: f64,
36    /// Random seed for reproducibility.
37    pub seed: u64,
38}
39
40/// Function registry mapping names to actual function pointers.
41pub struct FunctionRegistry {
42    functions: HashMap<String, TestFunction>,
43}
44
45impl FunctionRegistry {
46    /// Creates a new registry with all standard test functions.
47    pub fn new() -> Self {
48        let mut functions = HashMap::new();
49
50        // Unimodal functions
51        functions.insert("sphere".to_string(), sphere as TestFunction);
52        functions.insert("rosenbrock".to_string(), rosenbrock as TestFunction);
53        functions.insert("booth".to_string(), booth as TestFunction);
54        functions.insert("matyas".to_string(), matyas as TestFunction);
55        functions.insert("beale".to_string(), beale as TestFunction);
56        functions.insert("himmelblau".to_string(), himmelblau as TestFunction);
57        functions.insert("sum_squares".to_string(), sum_squares as TestFunction);
58        functions.insert(
59            "different_powers".to_string(),
60            different_powers as TestFunction,
61        );
62        functions.insert("elliptic".to_string(), elliptic as TestFunction);
63        functions.insert("cigar".to_string(), cigar as TestFunction);
64        functions.insert("tablet".to_string(), tablet as TestFunction);
65        functions.insert("discus".to_string(), discus as TestFunction);
66        functions.insert("ridge".to_string(), ridge as TestFunction);
67        functions.insert("sharp_ridge".to_string(), sharp_ridge as TestFunction);
68        functions.insert("perm_0_d_beta".to_string(), perm_0_d_beta as TestFunction);
69        functions.insert("perm_d_beta".to_string(), perm_d_beta as TestFunction);
70
71        // Multimodal functions
72        functions.insert("ackley".to_string(), ackley as TestFunction);
73        functions.insert("ackley_n2".to_string(), ackley_n2 as TestFunction);
74        functions.insert("ackley_n3".to_string(), ackley_n3 as TestFunction);
75        functions.insert("rastrigin".to_string(), rastrigin as TestFunction);
76        functions.insert("griewank".to_string(), griewank as TestFunction);
77        functions.insert("schwefel".to_string(), schwefel as TestFunction);
78        functions.insert("branin".to_string(), branin as TestFunction);
79        functions.insert(
80            "goldstein_price".to_string(),
81            goldstein_price as TestFunction,
82        );
83        functions.insert("six_hump_camel".to_string(), six_hump_camel as TestFunction);
84        functions.insert("hartman_3d".to_string(), hartman_3d as TestFunction);
85        functions.insert("hartman_4d".to_string(), hartman_4d as TestFunction);
86        functions.insert("hartman_6d".to_string(), hartman_6d as TestFunction);
87        functions.insert(
88            "xin_she_yang_n1".to_string(),
89            xin_she_yang_n1 as TestFunction,
90        );
91        functions.insert(
92            "xin_she_yang_n2".to_string(),
93            xin_she_yang_n2 as TestFunction,
94        );
95        functions.insert(
96            "xin_she_yang_n3".to_string(),
97            xin_she_yang_n3 as TestFunction,
98        );
99        functions.insert(
100            "xin_she_yang_n4".to_string(),
101            xin_she_yang_n4 as TestFunction,
102        );
103        functions.insert("katsuura".to_string(), katsuura as TestFunction);
104        functions.insert("happycat".to_string(), happycat as TestFunction);
105        functions.insert("happy_cat".to_string(), happy_cat as TestFunction);
106
107        // Alpine functions
108        functions.insert("alpine_n1".to_string(), alpine_n1 as TestFunction);
109        functions.insert("alpine_n2".to_string(), alpine_n2 as TestFunction);
110
111        // Additional functions
112        functions.insert(
113            "gramacy_lee_2012".to_string(),
114            gramacy_lee_2012 as TestFunction,
115        );
116        functions.insert("forrester_2008".to_string(), forrester_2008 as TestFunction);
117        functions.insert("power_sum".to_string(), power_sum as TestFunction);
118        functions.insert("shekel".to_string(), shekel as TestFunction);
119        functions.insert(
120            "gramacy_lee_function".to_string(),
121            gramacy_lee_function as TestFunction,
122        );
123        functions.insert(
124            "expanded_griewank_rosenbrock".to_string(),
125            expanded_griewank_rosenbrock as TestFunction,
126        );
127
128        // More classical functions
129        functions.insert("bohachevsky1".to_string(), bohachevsky1 as TestFunction);
130        functions.insert("bohachevsky2".to_string(), bohachevsky2 as TestFunction);
131        functions.insert("bohachevsky3".to_string(), bohachevsky3 as TestFunction);
132        functions.insert("bird".to_string(), bird as TestFunction);
133        functions.insert("bent_cigar".to_string(), bent_cigar as TestFunction);
134        functions.insert("bent_cigar_alt".to_string(), bent_cigar_alt as TestFunction);
135        functions.insert("brown".to_string(), brown as TestFunction);
136        functions.insert("bukin_n6".to_string(), bukin_n6 as TestFunction);
137        functions.insert("chung_reynolds".to_string(), chung_reynolds as TestFunction);
138        functions.insert("colville".to_string(), colville as TestFunction);
139        functions.insert("cosine_mixture".to_string(), cosine_mixture as TestFunction);
140        functions.insert("cross_in_tray".to_string(), cross_in_tray as TestFunction);
141        functions.insert("de_jong_step2".to_string(), de_jong_step2 as TestFunction);
142        functions.insert(
143            "dejong_f5_foxholes".to_string(),
144            dejong_f5_foxholes as TestFunction,
145        );
146        functions.insert("dixons_price".to_string(), dixons_price as TestFunction);
147        functions.insert("drop_wave".to_string(), drop_wave as TestFunction);
148        functions.insert("easom".to_string(), easom as TestFunction);
149        functions.insert("eggholder".to_string(), eggholder as TestFunction);
150        functions.insert(
151            "epistatic_michalewicz".to_string(),
152            epistatic_michalewicz as TestFunction,
153        );
154        functions.insert("exponential".to_string(), exponential as TestFunction);
155        functions.insert(
156            "freudenstein_roth".to_string(),
157            freudenstein_roth as TestFunction,
158        );
159        functions.insert("griewank2".to_string(), griewank2 as TestFunction);
160        functions.insert("holder_table".to_string(), holder_table as TestFunction);
161        functions.insert(
162            "lampinen_simplified".to_string(),
163            lampinen_simplified as TestFunction,
164        );
165        functions.insert("langermann".to_string(), langermann as TestFunction);
166        functions.insert("levi13".to_string(), levi13 as TestFunction);
167        functions.insert("levy".to_string(), levy as TestFunction);
168        functions.insert("levy_n13".to_string(), levy_n13 as TestFunction);
169        functions.insert("mccormick".to_string(), mccormick as TestFunction);
170        functions.insert("michalewicz".to_string(), michalewicz as TestFunction);
171        functions.insert("periodic".to_string(), periodic as TestFunction);
172        functions.insert("pinter".to_string(), pinter as TestFunction);
173        functions.insert("powell".to_string(), powell as TestFunction);
174        functions.insert("qing".to_string(), qing as TestFunction);
175        functions.insert("quadratic".to_string(), quadratic as TestFunction);
176        functions.insert("quartic".to_string(), quartic as TestFunction);
177        functions.insert(
178            "rotated_hyper_ellipsoid".to_string(),
179            rotated_hyper_ellipsoid as TestFunction,
180        );
181        functions.insert("salomon".to_string(), salomon as TestFunction);
182        functions.insert(
183            "salomon_corrected".to_string(),
184            salomon_corrected as TestFunction,
185        );
186        functions.insert("schaffer_n2".to_string(), schaffer_n2 as TestFunction);
187        functions.insert("schaffer_n4".to_string(), schaffer_n4 as TestFunction);
188        functions.insert("schwefel2".to_string(), schwefel2 as TestFunction);
189        functions.insert("shubert".to_string(), shubert as TestFunction);
190        functions.insert("step".to_string(), step as TestFunction);
191        functions.insert(
192            "styblinski_tang2".to_string(),
193            styblinski_tang2 as TestFunction,
194        );
195        functions.insert(
196            "sum_of_different_powers".to_string(),
197            sum_of_different_powers as TestFunction,
198        );
199        functions.insert(
200            "three_hump_camel".to_string(),
201            three_hump_camel as TestFunction,
202        );
203        functions.insert("trid".to_string(), trid as TestFunction);
204        functions.insert("vincent".to_string(), vincent as TestFunction);
205        functions.insert("whitley".to_string(), whitley as TestFunction);
206        functions.insert("zakharov".to_string(), zakharov as TestFunction);
207        functions.insert("zakharov2".to_string(), zakharov2 as TestFunction);
208
209        // Constraint functions (for completeness)
210        functions.insert(
211            "binh_korn_constraint1".to_string(),
212            binh_korn_constraint1 as TestFunction,
213        );
214        functions.insert(
215            "binh_korn_constraint2".to_string(),
216            binh_korn_constraint2 as TestFunction,
217        );
218        functions.insert(
219            "binh_korn_weighted".to_string(),
220            binh_korn_weighted as TestFunction,
221        );
222        functions.insert(
223            "keanes_bump_constraint1".to_string(),
224            keanes_bump_constraint1 as TestFunction,
225        );
226        functions.insert(
227            "keanes_bump_constraint2".to_string(),
228            keanes_bump_constraint2 as TestFunction,
229        );
230        functions.insert(
231            "keanes_bump_objective".to_string(),
232            keanes_bump_objective as TestFunction,
233        );
234        functions.insert(
235            "mishras_bird_constraint".to_string(),
236            mishras_bird_constraint as TestFunction,
237        );
238        functions.insert(
239            "mishras_bird_objective".to_string(),
240            mishras_bird_objective as TestFunction,
241        );
242        functions.insert(
243            "rosenbrock_disk_constraint".to_string(),
244            rosenbrock_disk_constraint as TestFunction,
245        );
246        functions.insert(
247            "rosenbrock_objective".to_string(),
248            rosenbrock_objective as TestFunction,
249        );
250
251        Self { functions }
252    }
253
254    /// Gets a test function by name.
255    pub fn get(&self, name: &str) -> Option<TestFunction> {
256        self.functions.get(name).copied()
257    }
258
259    /// Lists all available function names, sorted alphabetically.
260    pub fn list_functions(&self) -> Vec<String> {
261        let mut names: Vec<_> = self.functions.keys().cloned().collect();
262        names.sort();
263        names
264    }
265
266    /// Returns an iterator over all (name, function) pairs.
267    pub fn iter(&self) -> impl Iterator<Item = (&String, &TestFunction)> {
268        self.functions.iter()
269    }
270}
271
272impl Default for FunctionRegistry {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
278/// Generate all benchmark configurations
279#[allow(clippy::vec_init_then_push)]
280pub fn generate_benchmark_configs() -> Vec<BenchmarkConfig> {
281    let mut configs = Vec::new();
282
283    // ACKLEY function benchmarks
284    configs.push(BenchmarkConfig {
285        name: "ackley_2d".to_string(),
286        function_name: "ackley".to_string(),
287        bounds: vec![(-32.768, 32.768), (-32.768, 32.768)],
288        expected_optimum: vec![0.0, 0.0],
289        fun_tolerance: 1e-3,
290        position_tolerance: 0.5,
291        maxiter: 800,
292        popsize: 40,
293        strategy: Strategy::Best1Exp,
294        recombination: 0.9,
295        seed: 42,
296    });
297
298    configs.push(BenchmarkConfig {
299        name: "ackley_10d".to_string(),
300        function_name: "ackley".to_string(),
301        bounds: vec![(-32.768, 32.768); 10],
302        expected_optimum: vec![0.0; 10],
303        fun_tolerance: 1e-2,
304        position_tolerance: 0.5,
305        maxiter: 1200,
306        popsize: 100,
307        strategy: Strategy::Rand1Exp,
308        recombination: 0.95,
309        seed: 43,
310    });
311
312    // BEALE function
313    configs.push(BenchmarkConfig {
314        name: "beale_2d".to_string(),
315        function_name: "beale".to_string(),
316        bounds: vec![(-4.5, 4.5); 2],
317        expected_optimum: vec![3.0, 0.5],
318        fun_tolerance: 1e-6,
319        position_tolerance: 1e-3,
320        maxiter: 800,
321        popsize: 40,
322        strategy: Strategy::Best1Exp,
323        recombination: 0.9,
324        seed: 108,
325    });
326
327    // ROSENBROCK function benchmarks
328    configs.push(BenchmarkConfig {
329        name: "rosenbrock_2d".to_string(),
330        function_name: "rosenbrock".to_string(),
331        bounds: vec![(-2.048, 2.048), (-2.048, 2.048)],
332        expected_optimum: vec![1.0, 1.0],
333        fun_tolerance: 1e-4,
334        position_tolerance: 1e-2,
335        maxiter: 800,
336        popsize: 40,
337        strategy: Strategy::Best1Exp,
338        recombination: 0.9,
339        seed: 48,
340    });
341
342    // Add more configurations as needed...
343    // (For brevity, I'm not including all configs here, but they should all be moved from benchmark_convergence.rs)
344
345    configs
346}
347
/// Finds the CSV trace files recorded for `function_name` in `csv_dir`.
///
/// Two layouts are recognised:
/// * the old single-file format `<csv_dir>/<function_name>.csv`, which is
///   returned on its own whenever it exists as a file;
/// * the block-based format `<csv_dir>/<function_name>_block_<NNNN>.csv`,
///   whose files are returned sorted by path, so zero-padded block numbers
///   come back in order.
///
/// A missing or unreadable directory yields an empty list rather than an error.
pub fn find_csv_files_for_function(csv_dir: &str, function_name: &str) -> Vec<String> {
    use std::fs;
    use std::path::Path;

    let dir = Path::new(csv_dir);

    // Old format: a single `<function_name>.csv` file takes precedence.
    let old_path = dir.join(format!("{}.csv", function_name));
    if old_path.is_file() {
        return vec![old_path.to_string_lossy().into_owned()];
    }

    // Block-based format. Match the full `<function_name>_block_` prefix. A bare
    // `starts_with(function_name)` would also collect the blocks of any other
    // function whose name extends this one (e.g. `ackley_n2_block_0000.csv`
    // when looking up `ackley`, or `griewank2_block_*` for `griewank`).
    let block_prefix = format!("{}_block_", function_name);
    let mut csv_files = Vec::new();
    if let Ok(entries) = fs::read_dir(dir) {
        for entry in entries.flatten() {
            let file_name = entry.file_name();
            // Names that are not valid UTF-8 cannot match and are skipped.
            if let Some(name) = file_name.to_str() {
                if name.starts_with(&block_prefix) && name.ends_with(".csv") {
                    csv_files.push(entry.path().to_string_lossy().into_owned());
                }
            }
        }
    }

    // Sort so blocks are read in order (assumes zero-padded block numbers).
    csv_files.sort();
    csv_files
}
382
383/// Read and combine multiple CSV files for a function
384pub fn read_combined_csv_traces(
385    csv_files: &[String],
386) -> Result<Vec<TracePoint>, Box<dyn std::error::Error>> {
387    use std::fs;
388
389    let mut all_points = Vec::new();
390
391    for csv_path in csv_files {
392        let content = fs::read_to_string(csv_path)?;
393        let lines: Vec<&str> = content.trim().split('\n').collect();
394
395        if lines.len() < 2 {
396            continue; // Skip empty files
397        }
398
399        let header = lines[0];
400
401        // Check if it's the new format
402        if !header.starts_with("eval_id,generation,") {
403            return Err(format!("Unsupported CSV format in {}", csv_path).into());
404        }
405
406        for line in lines.iter().skip(1) {
407            let parts: Vec<&str> = line.split(',').collect();
408
409            if parts.len() < 7 {
410                continue; // Skip malformed lines
411            }
412
413            // Parse x coordinates (between generation and f_value/best_so_far/is_improvement)
414            let x_end = parts.len() - 3;
415            let mut x = Vec::new();
416            for part in parts.iter().take(x_end).skip(2) {
417                if let Ok(coord) = part.parse::<f64>() {
418                    x.push(coord);
419                }
420            }
421
422            if let (Ok(f_value), Ok(is_improvement)) = (
423                parts[x_end].parse::<f64>(),
424                parts[x_end + 2].parse::<bool>(),
425            ) {
426                all_points.push((x, f_value, is_improvement));
427            }
428        }
429    }
430
431    Ok(all_points)
432}