1use crate::Strategy;
3use math_audio_test_functions::*;
4use ndarray::Array1;
5use std::collections::HashMap;
6
7pub type TestFunction = fn(&Array1<f64>) -> f64;
9
10pub type TracePoint = (Vec<f64>, f64, bool);
12
13#[derive(Clone, Debug)]
15pub struct BenchmarkConfig {
16 pub name: String,
18 pub function_name: String,
20 pub bounds: Vec<(f64, f64)>,
22 pub expected_optimum: Vec<f64>,
24 pub fun_tolerance: f64,
26 pub position_tolerance: f64,
28 pub maxiter: usize,
30 pub popsize: usize,
32 pub strategy: Strategy,
34 pub recombination: f64,
36 pub seed: u64,
38}
39
40pub struct FunctionRegistry {
42 functions: HashMap<String, TestFunction>,
43}
44
45impl FunctionRegistry {
46 pub fn new() -> Self {
48 let mut functions = HashMap::new();
49
50 functions.insert("sphere".to_string(), sphere as TestFunction);
52 functions.insert("rosenbrock".to_string(), rosenbrock as TestFunction);
53 functions.insert("booth".to_string(), booth as TestFunction);
54 functions.insert("matyas".to_string(), matyas as TestFunction);
55 functions.insert("beale".to_string(), beale as TestFunction);
56 functions.insert("himmelblau".to_string(), himmelblau as TestFunction);
57 functions.insert("sum_squares".to_string(), sum_squares as TestFunction);
58 functions.insert(
59 "different_powers".to_string(),
60 different_powers as TestFunction,
61 );
62 functions.insert("elliptic".to_string(), elliptic as TestFunction);
63 functions.insert("cigar".to_string(), cigar as TestFunction);
64 functions.insert("tablet".to_string(), tablet as TestFunction);
65 functions.insert("discus".to_string(), discus as TestFunction);
66 functions.insert("ridge".to_string(), ridge as TestFunction);
67 functions.insert("sharp_ridge".to_string(), sharp_ridge as TestFunction);
68 functions.insert("perm_0_d_beta".to_string(), perm_0_d_beta as TestFunction);
69 functions.insert("perm_d_beta".to_string(), perm_d_beta as TestFunction);
70
71 functions.insert("ackley".to_string(), ackley as TestFunction);
73 functions.insert("ackley_n2".to_string(), ackley_n2 as TestFunction);
74 functions.insert("ackley_n3".to_string(), ackley_n3 as TestFunction);
75 functions.insert("rastrigin".to_string(), rastrigin as TestFunction);
76 functions.insert("griewank".to_string(), griewank as TestFunction);
77 functions.insert("schwefel".to_string(), schwefel as TestFunction);
78 functions.insert("branin".to_string(), branin as TestFunction);
79 functions.insert(
80 "goldstein_price".to_string(),
81 goldstein_price as TestFunction,
82 );
83 functions.insert("six_hump_camel".to_string(), six_hump_camel as TestFunction);
84 functions.insert("hartman_3d".to_string(), hartman_3d as TestFunction);
85 functions.insert("hartman_4d".to_string(), hartman_4d as TestFunction);
86 functions.insert("hartman_6d".to_string(), hartman_6d as TestFunction);
87 functions.insert(
88 "xin_she_yang_n1".to_string(),
89 xin_she_yang_n1 as TestFunction,
90 );
91 functions.insert(
92 "xin_she_yang_n2".to_string(),
93 xin_she_yang_n2 as TestFunction,
94 );
95 functions.insert(
96 "xin_she_yang_n3".to_string(),
97 xin_she_yang_n3 as TestFunction,
98 );
99 functions.insert(
100 "xin_she_yang_n4".to_string(),
101 xin_she_yang_n4 as TestFunction,
102 );
103 functions.insert("katsuura".to_string(), katsuura as TestFunction);
104 functions.insert("happycat".to_string(), happycat as TestFunction);
105 functions.insert("happy_cat".to_string(), happy_cat as TestFunction);
106
107 functions.insert("alpine_n1".to_string(), alpine_n1 as TestFunction);
109 functions.insert("alpine_n2".to_string(), alpine_n2 as TestFunction);
110
111 functions.insert(
113 "gramacy_lee_2012".to_string(),
114 gramacy_lee_2012 as TestFunction,
115 );
116 functions.insert("forrester_2008".to_string(), forrester_2008 as TestFunction);
117 functions.insert("power_sum".to_string(), power_sum as TestFunction);
118 functions.insert("shekel".to_string(), shekel as TestFunction);
119 functions.insert(
120 "gramacy_lee_function".to_string(),
121 gramacy_lee_function as TestFunction,
122 );
123 functions.insert(
124 "expanded_griewank_rosenbrock".to_string(),
125 expanded_griewank_rosenbrock as TestFunction,
126 );
127
128 functions.insert("bohachevsky1".to_string(), bohachevsky1 as TestFunction);
130 functions.insert("bohachevsky2".to_string(), bohachevsky2 as TestFunction);
131 functions.insert("bohachevsky3".to_string(), bohachevsky3 as TestFunction);
132 functions.insert("bird".to_string(), bird as TestFunction);
133 functions.insert("bent_cigar".to_string(), bent_cigar as TestFunction);
134 functions.insert("bent_cigar_alt".to_string(), bent_cigar_alt as TestFunction);
135 functions.insert("brown".to_string(), brown as TestFunction);
136 functions.insert("bukin_n6".to_string(), bukin_n6 as TestFunction);
137 functions.insert("chung_reynolds".to_string(), chung_reynolds as TestFunction);
138 functions.insert("colville".to_string(), colville as TestFunction);
139 functions.insert("cosine_mixture".to_string(), cosine_mixture as TestFunction);
140 functions.insert("cross_in_tray".to_string(), cross_in_tray as TestFunction);
141 functions.insert("de_jong_step2".to_string(), de_jong_step2 as TestFunction);
142 functions.insert(
143 "dejong_f5_foxholes".to_string(),
144 dejong_f5_foxholes as TestFunction,
145 );
146 functions.insert("dixons_price".to_string(), dixons_price as TestFunction);
147 functions.insert("drop_wave".to_string(), drop_wave as TestFunction);
148 functions.insert("easom".to_string(), easom as TestFunction);
149 functions.insert("eggholder".to_string(), eggholder as TestFunction);
150 functions.insert(
151 "epistatic_michalewicz".to_string(),
152 epistatic_michalewicz as TestFunction,
153 );
154 functions.insert("exponential".to_string(), exponential as TestFunction);
155 functions.insert(
156 "freudenstein_roth".to_string(),
157 freudenstein_roth as TestFunction,
158 );
159 functions.insert("griewank2".to_string(), griewank2 as TestFunction);
160 functions.insert("holder_table".to_string(), holder_table as TestFunction);
161 functions.insert(
162 "lampinen_simplified".to_string(),
163 lampinen_simplified as TestFunction,
164 );
165 functions.insert("langermann".to_string(), langermann as TestFunction);
166 functions.insert("levi13".to_string(), levi13 as TestFunction);
167 functions.insert("levy".to_string(), levy as TestFunction);
168 functions.insert("levy_n13".to_string(), levy_n13 as TestFunction);
169 functions.insert("mccormick".to_string(), mccormick as TestFunction);
170 functions.insert("michalewicz".to_string(), michalewicz as TestFunction);
171 functions.insert("periodic".to_string(), periodic as TestFunction);
172 functions.insert("pinter".to_string(), pinter as TestFunction);
173 functions.insert("powell".to_string(), powell as TestFunction);
174 functions.insert("qing".to_string(), qing as TestFunction);
175 functions.insert("quadratic".to_string(), quadratic as TestFunction);
176 functions.insert("quartic".to_string(), quartic as TestFunction);
177 functions.insert(
178 "rotated_hyper_ellipsoid".to_string(),
179 rotated_hyper_ellipsoid as TestFunction,
180 );
181 functions.insert("salomon".to_string(), salomon as TestFunction);
182 functions.insert(
183 "salomon_corrected".to_string(),
184 salomon_corrected as TestFunction,
185 );
186 functions.insert("schaffer_n2".to_string(), schaffer_n2 as TestFunction);
187 functions.insert("schaffer_n4".to_string(), schaffer_n4 as TestFunction);
188 functions.insert("schwefel2".to_string(), schwefel2 as TestFunction);
189 functions.insert("shubert".to_string(), shubert as TestFunction);
190 functions.insert("step".to_string(), step as TestFunction);
191 functions.insert(
192 "styblinski_tang2".to_string(),
193 styblinski_tang2 as TestFunction,
194 );
195 functions.insert(
196 "sum_of_different_powers".to_string(),
197 sum_of_different_powers as TestFunction,
198 );
199 functions.insert(
200 "three_hump_camel".to_string(),
201 three_hump_camel as TestFunction,
202 );
203 functions.insert("trid".to_string(), trid as TestFunction);
204 functions.insert("vincent".to_string(), vincent as TestFunction);
205 functions.insert("whitley".to_string(), whitley as TestFunction);
206 functions.insert("zakharov".to_string(), zakharov as TestFunction);
207 functions.insert("zakharov2".to_string(), zakharov2 as TestFunction);
208
209 functions.insert(
211 "binh_korn_constraint1".to_string(),
212 binh_korn_constraint1 as TestFunction,
213 );
214 functions.insert(
215 "binh_korn_constraint2".to_string(),
216 binh_korn_constraint2 as TestFunction,
217 );
218 functions.insert(
219 "binh_korn_weighted".to_string(),
220 binh_korn_weighted as TestFunction,
221 );
222 functions.insert(
223 "keanes_bump_constraint1".to_string(),
224 keanes_bump_constraint1 as TestFunction,
225 );
226 functions.insert(
227 "keanes_bump_constraint2".to_string(),
228 keanes_bump_constraint2 as TestFunction,
229 );
230 functions.insert(
231 "keanes_bump_objective".to_string(),
232 keanes_bump_objective as TestFunction,
233 );
234 functions.insert(
235 "mishras_bird_constraint".to_string(),
236 mishras_bird_constraint as TestFunction,
237 );
238 functions.insert(
239 "mishras_bird_objective".to_string(),
240 mishras_bird_objective as TestFunction,
241 );
242 functions.insert(
243 "rosenbrock_disk_constraint".to_string(),
244 rosenbrock_disk_constraint as TestFunction,
245 );
246 functions.insert(
247 "rosenbrock_objective".to_string(),
248 rosenbrock_objective as TestFunction,
249 );
250
251 Self { functions }
252 }
253
254 pub fn get(&self, name: &str) -> Option<TestFunction> {
256 self.functions.get(name).copied()
257 }
258
259 pub fn list_functions(&self) -> Vec<String> {
261 let mut names: Vec<_> = self.functions.keys().cloned().collect();
262 names.sort();
263 names
264 }
265
266 pub fn iter(&self) -> impl Iterator<Item = (&String, &TestFunction)> {
268 self.functions.iter()
269 }
270}
271
272impl Default for FunctionRegistry {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
278#[allow(clippy::vec_init_then_push)]
280pub fn generate_benchmark_configs() -> Vec<BenchmarkConfig> {
281 let mut configs = Vec::new();
282
283 configs.push(BenchmarkConfig {
285 name: "ackley_2d".to_string(),
286 function_name: "ackley".to_string(),
287 bounds: vec![(-32.768, 32.768), (-32.768, 32.768)],
288 expected_optimum: vec![0.0, 0.0],
289 fun_tolerance: 1e-3,
290 position_tolerance: 0.5,
291 maxiter: 800,
292 popsize: 40,
293 strategy: Strategy::Best1Exp,
294 recombination: 0.9,
295 seed: 42,
296 });
297
298 configs.push(BenchmarkConfig {
299 name: "ackley_10d".to_string(),
300 function_name: "ackley".to_string(),
301 bounds: vec![(-32.768, 32.768); 10],
302 expected_optimum: vec![0.0; 10],
303 fun_tolerance: 1e-2,
304 position_tolerance: 0.5,
305 maxiter: 1200,
306 popsize: 100,
307 strategy: Strategy::Rand1Exp,
308 recombination: 0.95,
309 seed: 43,
310 });
311
312 configs.push(BenchmarkConfig {
314 name: "beale_2d".to_string(),
315 function_name: "beale".to_string(),
316 bounds: vec![(-4.5, 4.5); 2],
317 expected_optimum: vec![3.0, 0.5],
318 fun_tolerance: 1e-6,
319 position_tolerance: 1e-3,
320 maxiter: 800,
321 popsize: 40,
322 strategy: Strategy::Best1Exp,
323 recombination: 0.9,
324 seed: 108,
325 });
326
327 configs.push(BenchmarkConfig {
329 name: "rosenbrock_2d".to_string(),
330 function_name: "rosenbrock".to_string(),
331 bounds: vec![(-2.048, 2.048), (-2.048, 2.048)],
332 expected_optimum: vec![1.0, 1.0],
333 fun_tolerance: 1e-4,
334 position_tolerance: 1e-2,
335 maxiter: 800,
336 popsize: 40,
337 strategy: Strategy::Best1Exp,
338 recombination: 0.9,
339 seed: 48,
340 });
341
342 configs
346}
347
/// Collects the trace CSV files recorded for `function_name` under `csv_dir`.
///
/// Lookup order:
///
/// 1. If `"{csv_dir}/{function_name}.csv"` exists, that single path is
///    returned, spelled exactly as formatted here. The directory is not scanned.
/// 2. Otherwise the function scans the directory. It returns every entry whose
///    UTF-8 file name starts with `function_name`, contains `_block_` and ends
///    with `.csv`. Each result is the entry's path joined onto `csv_dir`, and
///    the list is sorted lexicographically.
///
/// An unreadable directory or a non-UTF-8 file name is not an error. Such cases
/// simply contribute nothing, so an empty vector means nothing was found.
///
/// NOTE(review): the name test is a plain prefix match. In a shared directory,
/// `"ackley"` would also pick up `ackley_n2_block_*.csv`, and `"rosenbrock"`
/// would pick up `rosenbrock_objective_block_*.csv`. Tightening the test to
/// `"{function_name}_block_"` depends on how the recorder names its block
/// files, so confirm that first. Likewise, lexicographic order only matches
/// block order when block indices are zero-padded.
pub fn find_csv_files_for_function(csv_dir: &str, function_name: &str) -> Vec<String> {
    use std::fs;
    use std::path::Path;

    let single_file = format!("{}/{}.csv", csv_dir, function_name);
    if Path::new(&single_file).exists() {
        return vec![single_file];
    }

    let entries = match fs::read_dir(csv_dir) {
        Ok(entries) => entries,
        Err(_) => return Vec::new(),
    };

    let mut block_files: Vec<String> = entries
        .flatten()
        .filter_map(|entry| {
            let file_name = entry.file_name();
            let name = file_name.to_str()?;
            let is_block_file = name.starts_with(function_name)
                && name.contains("_block_")
                && name.ends_with(".csv");
            is_block_file.then(|| entry.path().to_string_lossy().to_string())
        })
        .collect();

    block_files.sort();
    block_files
}
382
383pub fn read_combined_csv_traces(
385 csv_files: &[String],
386) -> Result<Vec<TracePoint>, Box<dyn std::error::Error>> {
387 use std::fs;
388
389 let mut all_points = Vec::new();
390
391 for csv_path in csv_files {
392 let content = fs::read_to_string(csv_path)?;
393 let lines: Vec<&str> = content.trim().split('\n').collect();
394
395 if lines.len() < 2 {
396 continue; }
398
399 let header = lines[0];
400
401 if !header.starts_with("eval_id,generation,") {
403 return Err(format!("Unsupported CSV format in {}", csv_path).into());
404 }
405
406 for line in lines.iter().skip(1) {
407 let parts: Vec<&str> = line.split(',').collect();
408
409 if parts.len() < 7 {
410 continue; }
412
413 let x_end = parts.len() - 3;
415 let mut x = Vec::new();
416 for part in parts.iter().take(x_end).skip(2) {
417 if let Ok(coord) = part.parse::<f64>() {
418 x.push(coord);
419 }
420 }
421
422 if let (Ok(f_value), Ok(is_improvement)) = (
423 parts[x_end].parse::<f64>(),
424 parts[x_end + 2].parse::<bool>(),
425 ) {
426 all_points.push((x, f_value, is_improvement));
427 }
428 }
429 }
430
431 Ok(all_points)
432}