1use crate::{CallbackAction, DEIntermediate};
2use ndarray::Array1;
3use std::fs::{File, create_dir_all};
4use std::io::{BufWriter, Write};
5use std::sync::{Arc, Mutex};
6
/// Recorder of every objective evaluation made during an optimization run.
///
/// Evaluations are buffered in memory and written to numbered CSV block files
/// under `output_dir` (when the buffer reaches 10 000 records, and on
/// `finalize`). All mutable state sits behind `Arc<Mutex<_>>`, so every method
/// works through `&self`.
#[derive(Debug)]
pub struct OptimizationRecorder {
    /// Name used as the prefix of every file this recorder writes.
    function_name: String,
    /// Directory that receives the block CSVs and the summary file.
    output_dir: String,
    /// Evaluations recorded but not yet written to a block file.
    records: Arc<Mutex<Vec<EvaluationRecord>>>,
    /// Best (lowest) objective value recorded so far; `None` before the first evaluation.
    best_value: Arc<Mutex<Option<f64>>>,
    /// Number of evaluations recorded; also the id of the most recent one.
    eval_counter: Arc<Mutex<usize>>,
    /// Generation stamped onto new records; set via `set_generation` or the
    /// callback returned by `create_callback`.
    current_generation: Arc<Mutex<usize>>,
    /// Number of block files allocated so far; also the id of the most recent one.
    block_counter: Arc<Mutex<usize>>,
}
25
/// One objective evaluation, as written to a row of a block CSV.
///
/// Derives `PartialEq` so records can be compared directly (e.g. with
/// `assert_eq!`); `Eq` is not possible because of the `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct EvaluationRecord {
    /// 1-based sequential id of this evaluation within its recorder.
    pub eval_id: usize,
    /// Optimizer generation that was current when the evaluation was recorded.
    pub generation: usize,
    /// Point at which the objective was evaluated.
    pub x: Vec<f64>,
    /// Objective value at `x`.
    pub f_value: f64,
    /// Best (lowest) objective value recorded up to and including this evaluation.
    pub best_so_far: f64,
    /// `true` for the first evaluation, or when `f_value` is strictly below the
    /// previous best.
    pub is_improvement: bool,
}
42
/// Per-iteration view of the optimization history.
///
/// Derives `PartialEq` so records can be compared directly (e.g. with
/// `assert_eq!`); `Eq` is not possible because of the `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationRecord {
    /// Iteration (generation) the record belongs to.
    pub iteration: usize,
    /// Parameter vector associated with this record.
    pub x: Vec<f64>,
    /// Best objective value known at this iteration.
    pub best_result: f64,
    /// NOTE(review): presumably a convergence measure; the recorder in this
    /// module never computes it and fills it with `0.0`.
    pub convergence: f64,
    /// Whether this record set a new best.
    pub is_improvement: bool,
}
57
58impl OptimizationRecorder {
59 pub fn new(function_name: String) -> Self {
62 Self::with_output_dir(function_name, "./data_generated/records".to_string())
63 }
64
65 pub fn with_output_dir(function_name: String, output_dir: String) -> Self {
67 Self {
68 function_name,
69 output_dir,
70 records: Arc::new(Mutex::new(Vec::new())),
71 best_value: Arc::new(Mutex::new(None)),
72 eval_counter: Arc::new(Mutex::new(0)),
73 current_generation: Arc::new(Mutex::new(0)),
74 block_counter: Arc::new(Mutex::new(0)),
75 }
76 }
77
78 pub fn record_evaluation(&self, x: &Array1<f64>, f_value: f64) {
80 let mut eval_counter_guard = self.eval_counter.lock().unwrap();
81 *eval_counter_guard += 1;
82 let eval_id = *eval_counter_guard;
83
84 drop(eval_counter_guard);
85
86 let mut best_guard = self.best_value.lock().unwrap();
88 let is_improvement = match *best_guard {
89 Some(best) => f_value < best,
90 None => true,
91 };
92
93 let best_so_far = if is_improvement {
94 *best_guard = Some(f_value);
95 f_value
96 } else {
97 best_guard.unwrap_or(f_value)
98 };
99 drop(best_guard);
100
101 let mut records_guard = self.records.lock().unwrap();
103 let current_gen = *self.current_generation.lock().unwrap();
104 records_guard.push(EvaluationRecord {
105 eval_id,
106 generation: current_gen,
107 x: x.to_vec(),
108 f_value,
109 best_so_far,
110 is_improvement,
111 });
112
113 if records_guard.len() >= 10_000 {
115 let records_to_save = records_guard.clone();
116 records_guard.clear();
117 drop(records_guard);
118
119 let mut block_counter = self.block_counter.lock().unwrap();
121 *block_counter += 1;
122 let block_id = *block_counter;
123 drop(block_counter);
124
125 if let Err(e) = self.save_block_to_csv(&records_to_save, block_id) {
126 eprintln!(
127 "Warning: Failed to save evaluation block {}: {}",
128 block_id, e
129 );
130 }
131 }
132 }
133
134 pub fn set_generation(&self, generation: usize) {
136 *self.current_generation.lock().unwrap() = generation;
137 }
138
139 pub fn create_callback(&self) -> Box<dyn FnMut(&DEIntermediate) -> CallbackAction + Send> {
141 let current_generation = self.current_generation.clone();
142
143 Box::new(move |intermediate: &DEIntermediate| -> CallbackAction {
144 *current_generation.lock().unwrap() = intermediate.iter;
145 CallbackAction::Continue
146 })
147 }
148
149 fn save_block_to_csv(
151 &self,
152 records: &[EvaluationRecord],
153 block_id: usize,
154 ) -> Result<(), Box<dyn std::error::Error>> {
155 create_dir_all(&self.output_dir)?;
157
158 let filename = format!(
159 "{}/{}_block_{:04}.csv",
160 self.output_dir, self.function_name, block_id
161 );
162 let mut file = BufWriter::new(File::create(&filename)?);
163
164 if records.is_empty() {
165 return Ok(());
166 }
167
168 let num_dimensions = records[0].x.len();
170 write!(file, "eval_id,generation,")?;
171 for i in 0..num_dimensions {
172 write!(file, "x{},", i)?;
173 }
174 writeln!(file, "f_value,best_so_far,is_improvement")?;
175
176 for record in records.iter() {
178 write!(file, "{},{},", record.eval_id, record.generation)?;
179 for &xi in &record.x {
180 write!(file, "{:.16},", xi)?;
181 }
182 writeln!(
183 file,
184 "{:.16},{:.16},{}",
185 record.f_value, record.best_so_far, record.is_improvement
186 )?;
187 }
188
189 file.flush()?;
190 Ok(())
191 }
192
193 pub fn finalize(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
195 let mut records_guard = self.records.lock().unwrap();
197 if !records_guard.is_empty() {
198 let records_to_save = records_guard.clone();
199 records_guard.clear();
200 drop(records_guard);
201
202 let mut block_counter = self.block_counter.lock().unwrap();
203 *block_counter += 1;
204 let block_id = *block_counter;
205 drop(block_counter);
206
207 self.save_block_to_csv(&records_to_save, block_id)?;
208 } else {
209 drop(records_guard);
210 }
211
212 self.save_summary(&[])?;
214
215 let total_blocks = *self.block_counter.lock().unwrap();
217 let mut saved_files = Vec::new();
218 for block_id in 1..=total_blocks {
219 saved_files.push(format!(
220 "{}/{}_block_{:04}.csv",
221 self.output_dir, self.function_name, block_id
222 ));
223 }
224
225 Ok(saved_files)
226 }
227
228 fn save_summary(&self, _block_files: &[String]) -> Result<(), Box<dyn std::error::Error>> {
230 let summary_filename = format!("{}/{}_summary.txt", self.output_dir, self.function_name);
231 let mut file = File::create(&summary_filename)?;
232
233 let total_evaluations = *self.eval_counter.lock().unwrap();
234 let total_blocks = *self.block_counter.lock().unwrap();
235 let best_value = *self.best_value.lock().unwrap();
236
237 writeln!(file, "Function: {}", self.function_name)?;
238 writeln!(file, "Total evaluations: {}", total_evaluations)?;
239 writeln!(file, "Total blocks: {}", total_blocks)?;
240 writeln!(file, "Best value found: {:?}", best_value)?;
241 writeln!(file, "Block files:")?;
242
243 for block_id in 1..=total_blocks {
245 writeln!(file, " {}_block_{:04}.csv", self.function_name, block_id)?;
246 }
247
248 Ok(())
249 }
250
251 pub fn get_stats(&self) -> (usize, Option<f64>, usize) {
253 let total_evals = *self.eval_counter.lock().unwrap();
254 let best_value = *self.best_value.lock().unwrap();
255 let total_blocks = *self.block_counter.lock().unwrap();
256 (total_evals, best_value, total_blocks)
257 }
258
259 pub fn save_to_csv(&self, output_dir: &str) -> Result<String, Box<dyn std::error::Error>> {
261 let saved_files = self.finalize()?;
263 if let Some(first_file) = saved_files.first() {
264 Ok(first_file.clone())
265 } else {
266 Ok(format!("{}/{}_no_data.csv", output_dir, self.function_name))
267 }
268 }
269
270 pub fn get_records(&self) -> Vec<OptimizationRecord> {
272 Vec::new()
274 }
275
276 #[cfg(test)]
278 pub fn get_test_records(&self) -> Vec<OptimizationRecord> {
279 let records_guard = self.records.lock().unwrap();
280 records_guard
281 .iter()
282 .map(|eval_record| {
283 OptimizationRecord {
284 iteration: eval_record.generation,
285 x: eval_record.x.clone(),
286 best_result: eval_record.best_so_far,
287 convergence: 0.0, is_improvement: eval_record.is_improvement,
289 }
290 })
291 .collect()
292 }
293
294 pub fn num_iterations(&self) -> usize {
296 *self.eval_counter.lock().unwrap()
297 }
298
299 pub fn clear(&self) {
301 self.records.lock().unwrap().clear();
302 *self.best_value.lock().unwrap() = None;
303 *self.eval_counter.lock().unwrap() = 0;
304 *self.current_generation.lock().unwrap() = 0;
305 *self.block_counter.lock().unwrap() = 0;
306 }
307
308 pub fn get_best_solution(&self) -> Option<(Vec<f64>, f64)> {
310 (*self.best_value.lock().unwrap()).map(|best_val| (Vec::new(), best_val))
313 }
314}
315
#[cfg(test)]
mod tests {
    use crate::{
        DEConfigBuilder, recorder::OptimizationRecorder, run_recorded_differential_evolution,
    };
    use math_audio_test_functions::quadratic;
    use ndarray::Array1;

    /// Feeding two successively better evaluations must buffer both, in order,
    /// each tagged with its generation and flagged as an improvement.
    #[test]
    fn test_optimization_recorder() {
        let recorder = OptimizationRecorder::new("test_function".to_string());

        // (generation, point, objective value) — each value beats the previous.
        let evaluations = [(0usize, vec![1.0, 2.0], 5.0), (1, vec![0.5, 1.0], 1.25)];
        for (generation, point, value) in &evaluations {
            recorder.set_generation(*generation);
            recorder.record_evaluation(&Array1::from(point.clone()), *value);
        }

        let records = recorder.get_test_records();
        assert_eq!(records.len(), evaluations.len());

        for (record, (generation, point, value)) in records.iter().zip(&evaluations) {
            assert_eq!(record.iteration, *generation);
            assert_eq!(&record.x, point);
            assert_eq!(record.best_result, *value);
            assert!(record.is_improvement);
        }
    }

    /// End-to-end run: the recorded optimizer must leave a CSV on disk with
    /// the expected header followed by at least one data row.
    #[test]
    fn test_recorded_optimization() {
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(50)
            .popsize(10)
            .build()
            .expect("popsize must be >= 4");

        let outcome = run_recorded_differential_evolution("quadratic", quadratic, &bounds, config);
        let (_de_report, csv_path) = match outcome {
            Ok(pair) => pair,
            Err(e) => panic!(
                "Test requires AUTOEQ_DIR to be set. Error: {}\nPlease run: export AUTOEQ_DIR=/path/to/autoeq",
                e
            ),
        };

        assert!(std::path::Path::new(&csv_path).exists());
        println!("CSV saved to: {}", csv_path);

        let csv_content = std::fs::read_to_string(&csv_path).expect("Failed to read CSV");
        let lines: Vec<&str> = csv_content.trim().split('\n').collect();
        assert!(lines.len() > 1, "CSV should have header plus data rows");

        let expected_header = "eval_id,generation,x0,x1,f_value,best_so_far,is_improvement";
        assert!(lines[0].starts_with(expected_header));

        let data_rows = lines.len() - 1;
        println!("Recording test passed - {} iterations recorded", data_rows);
    }
}