Skip to main content

math_audio_optimisation/
recorder.rs

1use crate::{CallbackAction, DEIntermediate};
2use ndarray::Array1;
3use std::fs::{File, create_dir_all};
4use std::io::{BufWriter, Write};
5use std::sync::{Arc, Mutex};
6
7/// Records optimization progress for every function evaluation
8#[derive(Debug)]
9pub struct OptimizationRecorder {
10    /// Function name (used for CSV filename)
11    function_name: String,
12    /// Output directory for CSV files
13    output_dir: String,
14    /// Shared evaluation records storage
15    records: Arc<Mutex<Vec<EvaluationRecord>>>,
16    /// Best function value seen so far
17    best_value: Arc<Mutex<Option<f64>>>,
18    /// Counter for function evaluations
19    eval_counter: Arc<Mutex<usize>>,
20    /// Current generation number
21    current_generation: Arc<Mutex<usize>>,
22    /// Block counter for periodic saves
23    block_counter: Arc<Mutex<usize>>,
24}
25
/// A single objective-function evaluation, as written to one CSV row.
#[derive(Debug, Clone, PartialEq)]
pub struct EvaluationRecord {
    /// 1-based sequence number of this evaluation.
    pub eval_id: usize,
    /// Optimizer generation during which the evaluation happened.
    pub generation: usize,
    /// Input parameters `x`.
    pub x: Vec<f64>,
    /// Objective value `f(x)`.
    pub f_value: f64,
    /// Best (lowest) objective value known after this evaluation.
    pub best_so_far: f64,
    /// Whether this evaluation lowered the global best value.
    pub is_improvement: bool,
}
42
/// Legacy per-iteration record type, kept for API compatibility.
///
/// The recorder now stores per-evaluation data on disk. Its legacy accessor
/// `get_records` always returns an empty list.
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationRecord {
    /// Iteration (generation) number.
    pub iteration: usize,
    /// Parameters associated with this record.
    pub x: Vec<f64>,
    /// Best objective value found so far.
    pub best_result: f64,
    /// Convergence measure (documented as the population's standard deviation).
    pub convergence: f64,
    /// Whether this iteration improved the best known result.
    pub is_improvement: bool,
}
57
58impl OptimizationRecorder {
59    /// Create a new optimization recorder for the given function
60    /// Uses the default records directory under AUTOEQ_DIR/data_generated/records
61    pub fn new(function_name: String) -> Self {
62        Self::with_output_dir(function_name, "./data_generated/records".to_string())
63    }
64
65    /// Create a new optimization recorder with custom output directory
66    pub fn with_output_dir(function_name: String, output_dir: String) -> Self {
67        Self {
68            function_name,
69            output_dir,
70            records: Arc::new(Mutex::new(Vec::new())),
71            best_value: Arc::new(Mutex::new(None)),
72            eval_counter: Arc::new(Mutex::new(0)),
73            current_generation: Arc::new(Mutex::new(0)),
74            block_counter: Arc::new(Mutex::new(0)),
75        }
76    }
77
78    /// Record a single function evaluation
79    pub fn record_evaluation(&self, x: &Array1<f64>, f_value: f64) {
80        let mut eval_counter_guard = self.eval_counter.lock().unwrap();
81        *eval_counter_guard += 1;
82        let eval_id = *eval_counter_guard;
83
84        drop(eval_counter_guard);
85
86        // Update best value
87        let mut best_guard = self.best_value.lock().unwrap();
88        let is_improvement = match *best_guard {
89            Some(best) => f_value < best,
90            None => true,
91        };
92
93        let best_so_far = if is_improvement {
94            *best_guard = Some(f_value);
95            f_value
96        } else {
97            best_guard.unwrap_or(f_value)
98        };
99        drop(best_guard);
100
101        // Record the evaluation
102        let mut records_guard = self.records.lock().unwrap();
103        let current_gen = *self.current_generation.lock().unwrap();
104        records_guard.push(EvaluationRecord {
105            eval_id,
106            generation: current_gen,
107            x: x.to_vec(),
108            f_value,
109            best_so_far,
110            is_improvement,
111        });
112
113        // Check if we need to save a block (every 10k evaluations)
114        if records_guard.len() >= 10_000 {
115            let records_to_save = records_guard.clone();
116            records_guard.clear();
117            drop(records_guard);
118
119            // Save block in background
120            let mut block_counter = self.block_counter.lock().unwrap();
121            *block_counter += 1;
122            let block_id = *block_counter;
123            drop(block_counter);
124
125            if let Err(e) = self.save_block_to_csv(&records_to_save, block_id) {
126                eprintln!(
127                    "Warning: Failed to save evaluation block {}: {}",
128                    block_id, e
129                );
130            }
131        }
132    }
133
134    /// Set the current generation number
135    pub fn set_generation(&self, generation: usize) {
136        *self.current_generation.lock().unwrap() = generation;
137    }
138
139    /// Create a callback function that updates generation number
140    pub fn create_callback(&self) -> Box<dyn FnMut(&DEIntermediate) -> CallbackAction + Send> {
141        let current_generation = self.current_generation.clone();
142
143        Box::new(move |intermediate: &DEIntermediate| -> CallbackAction {
144            *current_generation.lock().unwrap() = intermediate.iter;
145            CallbackAction::Continue
146        })
147    }
148
149    /// Save a block of evaluations to CSV file
150    fn save_block_to_csv(
151        &self,
152        records: &[EvaluationRecord],
153        block_id: usize,
154    ) -> Result<(), Box<dyn std::error::Error>> {
155        // Create output directory if it doesn't exist
156        create_dir_all(&self.output_dir)?;
157
158        let filename = format!(
159            "{}/{}_block_{:04}.csv",
160            self.output_dir, self.function_name, block_id
161        );
162        let mut file = BufWriter::new(File::create(&filename)?);
163
164        if records.is_empty() {
165            return Ok(());
166        }
167
168        // Write CSV header
169        let num_dimensions = records[0].x.len();
170        write!(file, "eval_id,generation,")?;
171        for i in 0..num_dimensions {
172            write!(file, "x{},", i)?;
173        }
174        writeln!(file, "f_value,best_so_far,is_improvement")?;
175
176        // Write data rows
177        for record in records.iter() {
178            write!(file, "{},{},", record.eval_id, record.generation)?;
179            for &xi in &record.x {
180                write!(file, "{:.16},", xi)?;
181            }
182            writeln!(
183                file,
184                "{:.16},{:.16},{}",
185                record.f_value, record.best_so_far, record.is_improvement
186            )?;
187        }
188
189        file.flush()?;
190        Ok(())
191    }
192
193    /// Save any remaining records and finalize
194    pub fn finalize(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
195        // Save any remaining records
196        let mut records_guard = self.records.lock().unwrap();
197        if !records_guard.is_empty() {
198            let records_to_save = records_guard.clone();
199            records_guard.clear();
200            drop(records_guard);
201
202            let mut block_counter = self.block_counter.lock().unwrap();
203            *block_counter += 1;
204            let block_id = *block_counter;
205            drop(block_counter);
206
207            self.save_block_to_csv(&records_to_save, block_id)?;
208        } else {
209            drop(records_guard);
210        }
211
212        // Create a summary file with metadata
213        self.save_summary(&[])?;
214
215        // Return all saved CSV files
216        let total_blocks = *self.block_counter.lock().unwrap();
217        let mut saved_files = Vec::new();
218        for block_id in 1..=total_blocks {
219            saved_files.push(format!(
220                "{}/{}_block_{:04}.csv",
221                self.output_dir, self.function_name, block_id
222            ));
223        }
224
225        Ok(saved_files)
226    }
227
228    /// Save summary file with metadata
229    fn save_summary(&self, _block_files: &[String]) -> Result<(), Box<dyn std::error::Error>> {
230        let summary_filename = format!("{}/{}_summary.txt", self.output_dir, self.function_name);
231        let mut file = File::create(&summary_filename)?;
232
233        let total_evaluations = *self.eval_counter.lock().unwrap();
234        let total_blocks = *self.block_counter.lock().unwrap();
235        let best_value = *self.best_value.lock().unwrap();
236
237        writeln!(file, "Function: {}", self.function_name)?;
238        writeln!(file, "Total evaluations: {}", total_evaluations)?;
239        writeln!(file, "Total blocks: {}", total_blocks)?;
240        writeln!(file, "Best value found: {:?}", best_value)?;
241        writeln!(file, "Block files:")?;
242
243        // List all block files that were saved (from 1 to total_blocks)
244        for block_id in 1..=total_blocks {
245            writeln!(file, "  {}_block_{:04}.csv", self.function_name, block_id)?;
246        }
247
248        Ok(())
249    }
250
251    /// Get evaluation statistics
252    pub fn get_stats(&self) -> (usize, Option<f64>, usize) {
253        let total_evals = *self.eval_counter.lock().unwrap();
254        let best_value = *self.best_value.lock().unwrap();
255        let total_blocks = *self.block_counter.lock().unwrap();
256        (total_evals, best_value, total_blocks)
257    }
258
259    /// Legacy method: Save all recorded iterations to a CSV file (for compatibility)
260    pub fn save_to_csv(&self, output_dir: &str) -> Result<String, Box<dyn std::error::Error>> {
261        // For backward compatibility, just finalize and return the first block filename
262        let saved_files = self.finalize()?;
263        if let Some(first_file) = saved_files.first() {
264            Ok(first_file.clone())
265        } else {
266            Ok(format!("{}/{}_no_data.csv", output_dir, self.function_name))
267        }
268    }
269
270    /// Get a copy of all recorded iterations (legacy compatibility - returns empty)
271    pub fn get_records(&self) -> Vec<OptimizationRecord> {
272        // Legacy compatibility: evaluation records are saved to disk, not kept in memory
273        Vec::new()
274    }
275
276    /// Test-only method: Get evaluation records converted to legacy format
277    #[cfg(test)]
278    pub fn get_test_records(&self) -> Vec<OptimizationRecord> {
279        let records_guard = self.records.lock().unwrap();
280        records_guard
281            .iter()
282            .map(|eval_record| {
283                OptimizationRecord {
284                    iteration: eval_record.generation,
285                    x: eval_record.x.clone(),
286                    best_result: eval_record.best_so_far,
287                    convergence: 0.0, // Not tracked in new system
288                    is_improvement: eval_record.is_improvement,
289                }
290            })
291            .collect()
292    }
293
294    /// Get the number of evaluations recorded
295    pub fn num_iterations(&self) -> usize {
296        *self.eval_counter.lock().unwrap()
297    }
298
299    /// Clear all recorded evaluations
300    pub fn clear(&self) {
301        self.records.lock().unwrap().clear();
302        *self.best_value.lock().unwrap() = None;
303        *self.eval_counter.lock().unwrap() = 0;
304        *self.current_generation.lock().unwrap() = 0;
305        *self.block_counter.lock().unwrap() = 0;
306    }
307
308    /// Get the final best solution if any evaluations were recorded
309    pub fn get_best_solution(&self) -> Option<(Vec<f64>, f64)> {
310        // Since we don't keep all records in memory, we can't return the exact solution
311        // This would need to be reconstructed from the CSV files if needed
312        (*self.best_value.lock().unwrap()).map(|best_val| (Vec::new(), best_val))
313    }
314}
315
#[cfg(test)]
mod tests {
    use crate::{
        DEConfigBuilder, recorder::OptimizationRecorder, run_recorded_differential_evolution,
    };
    use math_audio_test_functions::quadratic;
    use ndarray::Array1;

    /// Directly recorded evaluations keep their generation, inputs and running best.
    #[test]
    fn test_optimization_recorder() {
        let recorder = OptimizationRecorder::new("test_function".to_string());

        let x1 = Array1::from(vec![1.0, 2.0]);
        recorder.set_generation(0);
        recorder.record_evaluation(&x1, 5.0);

        let x2 = Array1::from(vec![0.5, 1.0]);
        recorder.set_generation(1);
        recorder.record_evaluation(&x2, 1.25);

        let records = recorder.get_test_records();
        assert_eq!(records.len(), 2);

        assert_eq!(records[0].iteration, 0);
        assert_eq!(records[0].x, vec![1.0, 2.0]);
        assert_eq!(records[0].best_result, 5.0);
        assert!(records[0].is_improvement);

        assert_eq!(records[1].iteration, 1);
        assert_eq!(records[1].x, vec![0.5, 1.0]);
        assert_eq!(records[1].best_result, 1.25);
        assert!(records[1].is_improvement);
    }

    /// A worse evaluation keeps the previous best, and `clear` resets every counter.
    #[test]
    fn test_non_improvement_and_clear() {
        let recorder = OptimizationRecorder::new("test_non_improvement".to_string());
        recorder.record_evaluation(&Array1::from(vec![0.0, 0.0]), 3.0);
        recorder.record_evaluation(&Array1::from(vec![1.0, 1.0]), 4.0);

        let records = recorder.get_test_records();
        assert_eq!(records.len(), 2);
        assert!(!records[1].is_improvement);
        assert_eq!(records[1].best_result, 3.0);
        assert_eq!(recorder.get_stats(), (2, Some(3.0), 0));
        assert_eq!(recorder.num_iterations(), 2);

        recorder.clear();
        assert_eq!(recorder.get_stats(), (0, None, 0));
        assert!(recorder.get_test_records().is_empty());
    }

    /// End to end: a short DE run writes a CSV block with the expected header.
    #[test]
    fn test_recorded_optimization() {
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let config = DEConfigBuilder::new()
            .seed(42)
            .maxiter(50) // Keep it short for testing
            .popsize(10)
            .build()
            .expect("popsize must be >= 4");

        let result = run_recorded_differential_evolution("quadratic", quadratic, &bounds, config);

        match result {
            Ok((_de_report, csv_path)) => {
                assert!(std::path::Path::new(&csv_path).exists());
                println!("CSV saved to: {}", csv_path);

                let csv_content = std::fs::read_to_string(&csv_path).expect("Failed to read CSV");
                let lines: Vec<&str> = csv_content.lines().collect();

                // Header plus at least one data row.
                assert!(lines.len() > 1, "CSV should have header plus data rows");

                let header = lines[0];
                assert!(
                    header
                        .starts_with("eval_id,generation,x0,x1,f_value,best_so_far,is_improvement")
                );

                println!(
                    "Recording test passed - {} iterations recorded",
                    lines.len() - 1
                );
            }
            Err(e) => {
                panic!(
                    "Test requires AUTOEQ_DIR to be set. Error: {}\nPlease run: export AUTOEQ_DIR=/path/to/autoeq",
                    e
                );
            }
        }
    }
}