1#[allow(dead_code)]
7use crate::error::{OptimError, Result};
8use std::collections::{HashMap, VecDeque};
9use std::fmt::Write as FmtWrite;
10use std::io::Write;
11use std::path::Path;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13
/// Settings controlling where and how `OptimizationVisualizer` writes plots.
///
/// NOTE(review): `max_points`, `svg_output`, `figure_size`, `dpi`, `show_grid` and
/// `show_legend` are not read by any code in this file.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Directory that plot files are written into; `OptimizationVisualizer::new`
    /// creates it (including parents) if it does not exist.
    pub output_dir: String,

    /// Not read in this file; presumably a cap on plotted points — TODO confirm.
    pub max_points: usize,

    /// Number of `OptimizationVisualizer::step` calls between automatic dashboard
    /// refreshes.
    pub update_frequency: usize,

    /// `true`: plots are saved as `.html` pages that load Plotly from a CDN.
    /// `false`: plots are saved with a `.txt` extension.
    pub interactive_html: bool,

    /// Not read in this file.
    pub svg_output: bool,

    /// Selects the palette returned by `OptimizationVisualizer::get_color`.
    pub color_scheme: ColorScheme,

    /// Not read in this file; presumably (width, height) in pixels — TODO confirm.
    pub figure_size: (u32, u32),

    /// Not read in this file.
    pub dpi: u32,

    /// Not read in this file.
    pub show_grid: bool,

    /// Not read in this file.
    pub show_legend: bool,
}
47
48impl Default for VisualizationConfig {
49 fn default() -> Self {
50 Self {
51 output_dir: "optimization_plots".to_string(),
52 max_points: 10000,
53 update_frequency: 100,
54 interactive_html: true,
55 svg_output: false,
56 color_scheme: ColorScheme::Default,
57 figure_size: (800, 600),
58 dpi: 300,
59 show_grid: true,
60 show_legend: true,
61 }
62 }
63}
64
/// Named colour palette for plot series.
///
/// `OptimizationVisualizer::get_color` maps each variant to a fixed list of hex
/// colours and cycles through it by series index.
#[derive(Debug, Clone, Copy)]
pub enum ColorScheme {
    /// 10 colours, starting with `#1f77b4`.
    Default,
    /// 10 light colours, starting with `#8dd3c7` (presumably meant for dark backgrounds).
    Dark,
    /// 8 colours, starting with black (`#000000`).
    Colorblind,
    /// 5 greys, from black to `#CCCCCC`.
    Publication,
    /// 10 saturated colours, starting with `#FF6B6B`.
    Vibrant,
}
74
/// Time series of one scalar training metric (e.g. loss, learning rate).
///
/// `values`, `timestamps` and `steps` are parallel queues: entry `i` of each
/// describes the same sample. They are appended together by `add_value`.
#[derive(Debug, Clone)]
pub struct OptimizationMetric {
    /// Display name, used in plot titles.
    pub name: String,

    /// Recorded values, oldest first; `add_value` keeps at most 50 000 of them.
    pub values: VecDeque<f64>,

    /// Wall-clock time of each sample, in whole seconds since the Unix epoch.
    pub timestamps: VecDeque<u64>,

    /// Training step at which each sample was recorded.
    pub steps: VecDeque<usize>,

    /// Optional target value, set via `OptimizationVisualizer::set_target`;
    /// not otherwise read in this file.
    pub target: Option<f64>,

    /// Whether larger values are improvements; flips the sign in
    /// `get_recent_improvement`.
    pub higher_isbetter: bool,

    /// Unit label shown next to values in plot axes and the dashboard.
    pub units: String,

    /// Window width used by `get_smoothed_values`; `new` sets it to 10.
    pub smoothing_window: usize,
}
102
103impl OptimizationMetric {
104 pub fn new(name: String, higher_isbetter: bool, units: String) -> Self {
106 Self {
107 name,
108 values: VecDeque::new(),
109 timestamps: VecDeque::new(),
110 steps: VecDeque::new(),
111 target: None,
112 higher_isbetter,
113 units,
114 smoothing_window: 10,
115 }
116 }
117
118 pub fn add_value(&mut self, value: f64, step: usize) {
120 let timestamp = SystemTime::now()
121 .duration_since(UNIX_EPOCH)
122 .unwrap()
123 .as_secs();
124
125 self.values.push_back(value);
126 self.timestamps.push_back(timestamp);
127 self.steps.push_back(step);
128
129 while self.values.len() > 50000 {
131 self.values.pop_front();
132 self.timestamps.pop_front();
133 self.steps.pop_front();
134 }
135 }
136
137 pub fn get_smoothed_values(&self) -> Vec<f64> {
139 if self.values.len() < self.smoothing_window {
140 return self.values.iter().copied().collect();
141 }
142
143 let mut smoothed = Vec::new();
144 let window = self.smoothing_window.min(self.values.len());
145
146 for i in 0..self.values.len() {
147 let start = i.saturating_sub(window / 2);
148 let end = (i + window / 2 + 1).min(self.values.len());
149
150 let sum: f64 = self.values.range(start..end).sum();
151 let avg = sum / (end - start) as f64;
152 smoothed.push(avg);
153 }
154
155 smoothed
156 }
157
158 pub fn get_recent_improvement(&self, windowsize: usize) -> Option<f64> {
160 if self.values.len() < windowsize * 2 {
161 return None;
162 }
163
164 let recent_avg: f64 =
165 self.values.iter().rev().take(windowsize).sum::<f64>() / windowsize as f64;
166 let older_avg: f64 = self
167 .values
168 .iter()
169 .rev()
170 .skip(windowsize)
171 .take(windowsize)
172 .sum::<f64>()
173 / windowsize as f64;
174
175 Some(if self.higher_isbetter {
176 recent_avg - older_avg
177 } else {
178 older_avg - recent_avg
179 })
180 }
181}
182
/// Results of one optimizer run, used for side-by-side comparison plots.
#[derive(Debug, Clone)]
pub struct OptimizerComparison {
    /// Optimizer label, used as the trace name in comparison plots.
    pub name: String,

    /// Metric name -> recorded series. The comparison plot uses the index as the
    /// x coordinate, and the sensitivity plot uses the last element.
    pub metrics: HashMap<String, Vec<f64>>,

    /// Hyperparameter name -> value, read by the sensitivity plot.
    pub hyperparameters: HashMap<String, f64>,

    /// Wall-clock training duration; not read in this file.
    pub training_time: Duration,

    /// Memory statistics for the run; not read in this file.
    pub memory_stats: MemoryStats,

    /// Convergence summary for the run; not read in this file.
    pub convergence_info: ConvergenceInfo,
}
204
/// Memory statistics for an optimizer run. No code in this file reads these fields.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Peak memory use; presumably megabytes, per the field name — TODO confirm.
    pub peak_memory_mb: f64,

    /// Average memory use; presumably megabytes, per the field name — TODO confirm.
    pub avg_memory_mb: f64,

    /// Efficiency score; its definition and scale are not established in this file.
    pub memory_efficiency: f64,
}
217
/// Convergence summary for an optimizer run. No code in this file reads these fields.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Whether the run was judged to have converged (criterion defined by the caller).
    pub converged: bool,

    /// Step at which convergence was reached, if it was.
    pub convergence_step: Option<usize>,

    /// Metric value at the end of the run.
    pub final_value: f64,

    /// Best metric value observed during the run.
    pub best_value: f64,

    /// Rate of convergence; its units/definition are not established here — TODO confirm.
    pub convergence_rate: f64,
}
236
/// Collects optimization metrics across training steps and renders them to plot
/// files in `VisualizationConfig::output_dir`.
pub struct OptimizationVisualizer {
    /// Output and rendering settings.
    config: VisualizationConfig,

    /// Recorded metrics, keyed by metric name.
    metrics: HashMap<String, OptimizationMetric>,

    /// Per-optimizer results used by the comparison and sensitivity plots.
    comparisons: Vec<OptimizerComparison>,

    /// Dashboard bookkeeping (layout and time of last refresh).
    dashboard_state: DashboardState,

    /// Step counter advanced by `step()`; new samples are tagged with it.
    current_step: usize,

    /// Value of `current_step` when `step()` last attempted a dashboard refresh.
    last_update_step: usize,
}
257
/// Internal state backing the dashboard.
#[derive(Debug)]
struct DashboardState {
    /// Plot state keyed by name; not populated anywhere in this file.
    active_plots: HashMap<String, PlotState>,

    /// Grid layout; initialised to 2×2 by `OptimizationVisualizer::new`.
    layout: DashboardLayout,

    /// Time of the most recent refresh, set by `update_dashboard`.
    last_update: SystemTime,
}
270
/// Description of one dashboard plot. Not constructed anywhere in this file.
#[derive(Debug)]
struct PlotState {
    /// Kind of chart.
    plot_type: PlotType,

    /// Data series drawn in the plot.
    series: Vec<DataSeries>,

    /// Horizontal axis configuration.
    x_axis: AxisConfig,
    /// Vertical axis configuration.
    y_axis: AxisConfig,

    /// Plot title.
    title: String,
}
287
/// Kind of chart a `PlotState` describes.
///
/// Variant names indicate the intended chart type. No rendering code in this file
/// dispatches on this enum.
#[derive(Debug, Clone, Copy)]
pub enum PlotType {
    Line,
    Scatter,
    Histogram,
    Heatmap,
    Bar,
    Box,
    Violin,
    Surface3D,
}
300
/// One named series of plot data plus its styling. Not constructed in this file.
#[derive(Debug, Clone)]
pub struct DataSeries {
    /// Series label.
    pub name: String,

    /// X coordinates.
    pub x_values: Vec<f64>,

    /// Y coordinates; presumably the same length as `x_values` (not enforced here).
    pub y_values: Vec<f64>,

    /// Optional third coordinate, presumably for 3-D or heatmap plots — TODO confirm.
    pub z_values: Option<Vec<f64>>,

    /// Series colour; the palettes in `get_color` use `#rrggbb` strings.
    pub color: String,

    /// Stroke style of the connecting line.
    pub line_style: LineStyle,

    /// Marker drawn at each data point.
    pub marker_style: MarkerStyle,
}
325
/// Stroke style for a series line; `None` presumably means no line is drawn.
/// Not read by any rendering code in this file.
#[derive(Debug, Clone, Copy)]
pub enum LineStyle {
    Solid,
    Dashed,
    Dotted,
    DashDot,
    None,
}
335
/// Marker shape for data points; `None` presumably means no markers.
/// Not read by any rendering code in this file.
#[derive(Debug, Clone, Copy)]
pub enum MarkerStyle {
    Circle,
    Square,
    Triangle,
    Diamond,
    Plus,
    Cross,
    None,
}
347
/// Configuration of a single plot axis. Not constructed in this file.
#[derive(Debug, Clone)]
pub struct AxisConfig {
    /// Axis title.
    pub label: String,

    /// Scale type.
    pub scale: AxisScale,

    /// Explicit (min, max) bounds; `None` presumably means auto-range — TODO confirm.
    pub range: Option<(f64, f64)>,

    /// Tick-mark settings.
    pub ticks: TickConfig,
}
363
/// Axis scale type. Not read by any rendering code in this file.
/// (`plot_gradient_norm` switches to a log axis by editing the Plotly layout text directly.)
#[derive(Debug, Clone, Copy)]
pub enum AxisScale {
    Linear,
    Log,
    Symlog,
}
371
/// Tick-mark settings for an axis. Not constructed in this file.
#[derive(Debug, Clone)]
pub struct TickConfig {
    /// Distance between major ticks; `None` presumably means automatic spacing.
    pub major_spacing: Option<f64>,

    /// Number of minor ticks, presumably between adjacent major ticks — TODO confirm.
    pub minor_count: usize,

    /// Whether tick labels are shown.
    pub show_labels: bool,
}
384
/// Grid layout of the dashboard; `OptimizationVisualizer::new` creates it as 2×2.
#[derive(Debug, Clone)]
pub struct DashboardLayout {
    /// Number of grid rows.
    pub rows: usize,

    /// Number of grid columns.
    pub cols: usize,

    /// Plot name -> grid cell, presumably (row, col); not populated in this file.
    pub plot_positions: HashMap<String, (usize, usize)>,
}
397
398impl OptimizationVisualizer {
399 pub fn new(config: VisualizationConfig) -> Result<Self> {
401 std::fs::create_dir_all(&config.output_dir).map_err(|e| {
403 OptimError::InvalidConfig(format!("Failed to create output directory: {e}"))
404 })?;
405
406 let dashboard_state = DashboardState {
407 active_plots: HashMap::new(),
408 layout: DashboardLayout {
409 rows: 2,
410 cols: 2,
411 plot_positions: HashMap::new(),
412 },
413 last_update: SystemTime::now(),
414 };
415
416 Ok(Self {
417 config,
418 metrics: HashMap::new(),
419 comparisons: Vec::new(),
420 dashboard_state,
421 current_step: 0,
422 last_update_step: 0,
423 })
424 }
425
426 pub fn add_metric(&mut self, name: String, value: f64, higher_isbetter: bool, units: String) {
428 let metric = self
429 .metrics
430 .entry(name.clone())
431 .or_insert_with(|| OptimizationMetric::new(name, higher_isbetter, units));
432
433 metric.add_value(value, self.current_step);
434 }
435
436 pub fn set_target(&mut self, metricname: &str, target: f64) {
438 if let Some(metric) = self.metrics.get_mut(metricname) {
439 metric.target = Some(target);
440 }
441 }
442
443 pub fn step(&mut self) {
445 self.current_step += 1;
446
447 if self.current_step - self.last_update_step >= self.config.update_frequency {
448 if let Err(e) = self.update_dashboard() {
449 eprintln!("Failed to update dashboard: {e}");
450 }
451 self.last_update_step = self.current_step;
452 }
453 }
454
455 pub fn plot_loss_curve(&self, metricname: &str) -> Result<String> {
457 let metric = self
458 .metrics
459 .get(metricname)
460 .ok_or_else(|| OptimError::InvalidConfig(format!("Metric '{metricname}' not found")))?;
461
462 let steps: Vec<f64> = metric.steps.iter().map(|&s| s as f64).collect();
463 let values = metric.get_smoothed_values();
464
465 let plotdata = self.create_line_plot(
466 &steps,
467 &values,
468 &format!("{} over Training Steps", metric.name),
469 "Training Steps",
470 &format!("{} ({})", metric.name, metric.units),
471 )?;
472
473 self.save_plot(&plotdata, &format!("{metricname}_curve"))
474 }
475
476 pub fn plot_learning_rate_schedule(&self) -> Result<String> {
478 if let Some(lr_metric) = self.metrics.get("learning_rate") {
479 let steps: Vec<f64> = lr_metric.steps.iter().map(|&s| s as f64).collect();
480 let values: Vec<f64> = lr_metric.values.iter().copied().collect();
481
482 let plotdata = self.create_line_plot(
483 &steps,
484 &values,
485 "Learning Rate Schedule",
486 "Training Steps",
487 "Learning Rate",
488 )?;
489
490 self.save_plot(&plotdata, "learning_rate_schedule")
491 } else {
492 Err(OptimError::InvalidConfig(
493 "Learning rate metric not found".to_string(),
494 ))
495 }
496 }
497
498 pub fn plot_optimizer_comparison(&self, metricname: &str) -> Result<String> {
500 if self.comparisons.is_empty() {
501 return Err(OptimError::InvalidConfig(
502 "No optimizer comparisons available".to_string(),
503 ));
504 }
505
506 let mut plotdata = String::new();
507
508 if self.config.interactive_html {
510 plotdata.push_str(&self.create_html_header("Optimizer Comparison")?);
511 plotdata.push_str("<div id='comparison-plot'></div>\n");
512 plotdata.push_str("<script>\n");
513 plotdata.push_str("const traces = [];\n");
514
515 for comparison in &self.comparisons {
516 if let Some(values) = comparison.metrics.get(metricname) {
517 let x_values: Vec<String> = (0..values.len()).map(|i| i.to_string()).collect();
518 writeln!(&mut plotdata,
519 "traces.push({{x: {:?}, y: {:?}, name: '{}', type: 'scatter', mode: 'lines'}});",
520 x_values, values, comparison.name
521 ).unwrap();
522 }
523 }
524
525 plotdata.push_str("Plotly.newPlot('comparison-plot', traces, {\n");
526 plotdata.push_str(" title: 'Optimizer Comparison',\n");
527 plotdata.push_str(" xaxis: {title: 'Training Steps'},\n");
528 writeln!(&mut plotdata, " yaxis: {{title: '{metricname}'}}").unwrap();
529 plotdata.push_str("});\n");
530 plotdata.push_str("</script>\n");
531 plotdata.push_str("</body></html>\n");
532 }
533
534 self.save_plot(&plotdata, &format!("{metricname}_comparison"))
535 }
536
537 pub fn plot_gradient_norm(&self) -> Result<String> {
539 if let Some(grad_metric) = self.metrics.get("gradient_norm") {
540 let steps: Vec<f64> = grad_metric.steps.iter().map(|&s| s as f64).collect();
541 let values: Vec<f64> = grad_metric.values.iter().copied().collect();
542
543 let mut plotdata = self.create_line_plot(
544 &steps,
545 &values,
546 "Gradient Norm",
547 "Training Steps",
548 "Gradient Norm",
549 )?;
550
551 let max_val = values.iter().fold(0.0f64, |a, &b| a.max(b));
553 let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
554
555 if max_val / min_val > 100.0 {
556 plotdata = plotdata.replace("yaxis: {", "yaxis: {type: 'log', ");
557 }
558
559 self.save_plot(&plotdata, "gradient_norm")
560 } else {
561 Err(OptimError::InvalidConfig(
562 "Gradient norm metric not found".to_string(),
563 ))
564 }
565 }
566
567 pub fn plot_throughput(&self) -> Result<String> {
569 if let Some(throughput_metric) = self.metrics.get("throughput") {
570 let steps: Vec<f64> = throughput_metric.steps.iter().map(|&s| s as f64).collect();
571 let values: Vec<f64> = throughput_metric.values.iter().copied().collect();
572
573 let plotdata = self.create_line_plot(
574 &steps,
575 &values,
576 "Training Throughput",
577 "Training Steps",
578 "Samples/Second",
579 )?;
580
581 self.save_plot(&plotdata, "throughput")
582 } else {
583 Err(OptimError::InvalidConfig(
584 "Throughput metric not found".to_string(),
585 ))
586 }
587 }
588
589 pub fn plot_memory_usage(&self) -> Result<String> {
591 if let Some(memory_metric) = self.metrics.get("memory_usage") {
592 let steps: Vec<f64> = memory_metric.steps.iter().map(|&s| s as f64).collect();
593 let values: Vec<f64> = memory_metric.values.iter().copied().collect();
594
595 let plotdata = self.create_line_plot(
596 &steps,
597 &values,
598 "Memory Usage",
599 "Training Steps",
600 "Memory (MB)",
601 )?;
602
603 self.save_plot(&plotdata, "memory_usage")
604 } else {
605 Err(OptimError::InvalidConfig(
606 "Memory usage metric not found".to_string(),
607 ))
608 }
609 }
610
611 pub fn plot_hyperparameter_sensitivity(
613 &self,
614 param_name: &str,
615 metricname: &str,
616 ) -> Result<String> {
617 let mut param_values = Vec::new();
618 let mut metric_values = Vec::new();
619
620 for comparison in &self.comparisons {
621 if let (Some(¶m_val), Some(metric_vals)) = (
622 comparison.hyperparameters.get(param_name),
623 comparison.metrics.get(metricname),
624 ) {
625 if let Some(&final_metric) = metric_vals.last() {
626 param_values.push(param_val);
627 metric_values.push(final_metric);
628 }
629 }
630 }
631
632 if param_values.is_empty() {
633 return Err(OptimError::InvalidConfig(format!(
634 "No data available for hyperparameter '{}' and metric '{}'",
635 param_name, metricname
636 )));
637 }
638
639 let plotdata = self.create_scatter_plot(
640 ¶m_values,
641 &metric_values,
642 &format!("Sensitivity of {} to {}", metricname, param_name),
643 param_name,
644 metricname,
645 )?;
646
647 self.save_plot(
648 &plotdata,
649 &format!("sensitivity_{}_{}", param_name, metricname),
650 )
651 }
652
653 pub fn create_dashboard(&self) -> Result<String> {
655 let mut dashboard = String::new();
656
657 if self.config.interactive_html {
658 dashboard.push_str(&self.create_html_header("Optimization Dashboard")?);
659
660 dashboard.push_str(
662 r#"
663<style>
664.dashboard-container {
665 display: grid;
666 grid-template-columns: 1fr 1fr;
667 grid-template-rows: 1fr 1fr;
668 gap: 20px;
669 height: 100vh;
670 padding: 20px;
671}
672.plot-container {
673 border: 1px solid #ddd;
674 border-radius: 8px;
675 padding: 10px;
676}
677.metrics-summary {
678 grid-column: span 2;
679 padding: 20px;
680 background-color: #f8f9fa;
681 border-radius: 8px;
682 margin-bottom: 20px;
683}
684</style>
685"#,
686 );
687
688 dashboard.push_str("<div class='metrics-summary'>\n");
690 dashboard.push_str("<h2>Current Metrics</h2>\n");
691 dashboard.push_str("<div style='display: flex; gap: 20px;'>\n");
692
693 for (name, metric) in &self.metrics {
694 if let Some(&latest_value) = metric.values.back() {
695 writeln!(
696 &mut dashboard,
697 "<div><strong>{}:</strong> {:.4} {}</div>",
698 name, latest_value, metric.units
699 )
700 .unwrap();
701 }
702 }
703
704 dashboard.push_str("</div></div>\n");
705
706 dashboard.push_str("<div class='dashboard-container'>\n");
708
709 let mut plot_id = 0;
710 for _ in &self.metrics {
711 if plot_id >= 4 {
712 break;
713 } writeln!(
716 &mut dashboard,
717 "<div class='plot-container'><div id='plot-{}'></div></div>",
718 plot_id
719 )
720 .unwrap();
721
722 plot_id += 1;
723 }
724
725 dashboard.push_str("</div>\n");
726
727 dashboard.push_str("<script>\n");
729
730 plot_id = 0;
731 for (name, metric) in &self.metrics {
732 if plot_id >= 4 {
733 break;
734 }
735
736 let steps: Vec<String> = metric.steps.iter().map(|&s| s.to_string()).collect();
737 let values: Vec<f64> = metric.values.iter().copied().collect();
738
739 writeln!(&mut dashboard,
740 "Plotly.newPlot('plot-{}', [{{x: {:?}, y: {:?}, type: 'scatter', mode: 'lines', name: '{}'}}], {{title: '{}', xaxis: {{title: 'Steps'}}, yaxis: {{title: '{}'}}}});",
741 plot_id, steps, values, name, name, metric.units
742 ).unwrap();
743
744 plot_id += 1;
745 }
746
747 dashboard.push_str("</script>\n");
748 dashboard.push_str("</body></html>\n");
749 }
750
751 self.save_plot(&dashboard, "dashboard")
752 }
753
754 fn update_dashboard(&mut self) -> Result<()> {
756 self.dashboard_state.last_update = SystemTime::now();
757
758 self.create_dashboard()?;
761
762 Ok(())
763 }
764
765 pub fn add_optimizer_comparison(&mut self, comparison: OptimizerComparison) {
767 self.comparisons.push(comparison);
768 }
769
770 pub fn export_all(&self) -> Result<Vec<String>> {
772 let mut exported_files = Vec::new();
773
774 for metricname in self.metrics.keys() {
776 if let Ok(filename) = self.plot_loss_curve(metricname) {
777 exported_files.push(filename);
778 }
779 }
780
781 for metricname in ["loss", "accuracy", "throughput"] {
783 if let Ok(filename) = self.plot_optimizer_comparison(metricname) {
784 exported_files.push(filename);
785 }
786 }
787
788 if let Ok(filename) = self.plot_gradient_norm() {
790 exported_files.push(filename);
791 }
792
793 if let Ok(filename) = self.plot_throughput() {
794 exported_files.push(filename);
795 }
796
797 if let Ok(filename) = self.plot_memory_usage() {
798 exported_files.push(filename);
799 }
800
801 if let Ok(filename) = self.create_dashboard() {
803 exported_files.push(filename);
804 }
805
806 Ok(exported_files)
807 }
808
809 fn create_line_plot(
811 &self,
812 x_values: &[f64],
813 y_values: &[f64],
814 title: &str,
815 x_label: &str,
816 y_label: &str,
817 ) -> Result<String> {
818 if !self.config.interactive_html {
819 return Ok(format!("# {}\nX: {:?}\nY: {:?}", title, x_values, y_values));
820 }
821
822 let mut plot = String::new();
823 plot.push_str(&self.create_html_header(title)?);
824 plot.push_str("<div id='plot'></div>\n");
825 plot.push_str("<script>\n");
826
827 writeln!(
828 &mut plot,
829 "const trace = {{x: {:?}, y: {:?}, type: 'scatter', mode: 'lines', name: '{}'}};",
830 x_values, y_values, title
831 )
832 .unwrap();
833
834 writeln!(&mut plot,
835 "Plotly.newPlot('plot', [trace], {{title: '{}', xaxis: {{title: '{}'}}, yaxis: {{title: '{}'}}}});",
836 title, x_label, y_label
837 ).unwrap();
838
839 plot.push_str("</script></body></html>");
840
841 Ok(plot)
842 }
843
844 fn create_scatter_plot(
846 &self,
847 x_values: &[f64],
848 y_values: &[f64],
849 title: &str,
850 x_label: &str,
851 y_label: &str,
852 ) -> Result<String> {
853 if !self.config.interactive_html {
854 return Ok(format!("# {}\nX: {:?}\nY: {:?}", title, x_values, y_values));
855 }
856
857 let mut plot = String::new();
858 plot.push_str(&self.create_html_header(title)?);
859 plot.push_str("<div id='plot'></div>\n");
860 plot.push_str("<script>\n");
861
862 writeln!(
863 &mut plot,
864 "const trace = {{x: {:?}, y: {:?}, type: 'scatter', mode: 'markers', name: '{}'}};",
865 x_values, y_values, title
866 )
867 .unwrap();
868
869 writeln!(&mut plot,
870 "Plotly.newPlot('plot', [trace], {{title: '{}', xaxis: {{title: '{}'}}, yaxis: {{title: '{}'}}}});",
871 title, x_label, y_label
872 ).unwrap();
873
874 plot.push_str("</script></body></html>");
875
876 Ok(plot)
877 }
878
879 fn create_html_header(&self, title: &str) -> Result<String> {
881 Ok(format!(
882 r#"
883<!DOCTYPE html>
884<html>
885<head>
886 <title>{}</title>
887 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
888 <style>
889 body {{ font-family: Arial, sans-serif; margin: 20px; }}
890 #plot {{ width: 100%; height: 500px; }}
891 </style>
892</head>
893<body>
894 <h1>{}</h1>
895"#,
896 title, title
897 ))
898 }
899
900 fn save_plot(&self, plotdata: &str, filename: &str) -> Result<String> {
902 let extension = if self.config.interactive_html {
903 "html"
904 } else {
905 "txt"
906 };
907 let full_filename = format!("{}.{}", filename, extension);
908 let filepath = Path::new(&self.config.output_dir).join(&full_filename);
909
910 let mut file = std::fs::File::create(&filepath).map_err(|e| {
911 OptimError::InvalidConfig(format!(
912 "Failed to create file {}: {}",
913 filepath.display(),
914 e
915 ))
916 })?;
917
918 file.write_all(plotdata.as_bytes()).map_err(|e| {
919 OptimError::InvalidConfig(format!(
920 "Failed to write to file {}: {}",
921 filepath.display(),
922 e
923 ))
924 })?;
925
926 Ok(full_filename)
927 }
928
929 fn get_color(&self, index: usize) -> String {
931 let colors = match self.config.color_scheme {
932 ColorScheme::Default => vec![
933 "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2",
934 "#7f7f7f", "#bcbd22", "#17becf",
935 ],
936 ColorScheme::Dark => vec![
937 "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69",
938 "#fccde5", "#d9d9d9", "#bc80bd",
939 ],
940 ColorScheme::Colorblind => vec![
941 "#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00",
942 "#CC79A7",
943 ],
944 ColorScheme::Publication => vec!["#000000", "#333333", "#666666", "#999999", "#CCCCCC"],
945 ColorScheme::Vibrant => vec![
946 "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD", "#98D8C8",
947 "#F7DC6F", "#BB8FCE", "#85C1E9",
948 ],
949 };
950
951 colors[index % colors.len()].to_string()
952 }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Returns an output directory under the platform temp dir as a `String`.
    /// Using `temp_dir()` instead of a hard-coded `/tmp` lets the tests run on
    /// Windows as well.
    fn scratch_dir(name: &str) -> String {
        std::env::temp_dir()
            .join(name)
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn test_visualization_config_default() {
        let config = VisualizationConfig::default();
        assert_eq!(config.max_points, 10000);
        assert!(config.interactive_html);
        assert!(config.show_grid);
    }

    #[test]
    fn test_optimization_metric() {
        let mut metric = OptimizationMetric::new("loss".to_string(), false, "nats".to_string());

        metric.add_value(1.0, 0);
        metric.add_value(0.8, 1);
        metric.add_value(0.6, 2);
        metric.add_value(0.4, 3);

        assert_eq!(metric.values.len(), 4);
        assert_eq!(metric.steps.len(), 4);

        let improvement = metric.get_recent_improvement(2);
        assert!(improvement.is_some());
    }

    #[test]
    fn test_visualizer_creation() {
        let config = VisualizationConfig {
            output_dir: scratch_dir("test_plots"),
            ..Default::default()
        };

        let visualizer = OptimizationVisualizer::new(config);
        assert!(visualizer.is_ok());
    }

    #[test]
    fn test_add_metric() {
        let config = VisualizationConfig {
            output_dir: scratch_dir("test_plots"),
            ..Default::default()
        };

        let mut visualizer = OptimizationVisualizer::new(config).unwrap();

        visualizer.add_metric("loss".to_string(), 1.0, false, "nats".to_string());
        visualizer.step();
        visualizer.add_metric("loss".to_string(), 0.8, false, "nats".to_string());

        assert!(visualizer.metrics.contains_key("loss"));
        assert_eq!(visualizer.metrics["loss"].values.len(), 2);
    }

    #[test]
    fn test_plot_loss_curve_writes_file() {
        let dir = scratch_dir("test_plots_loss_curve");
        let config = VisualizationConfig {
            output_dir: dir.clone(),
            ..Default::default()
        };

        let mut visualizer = OptimizationVisualizer::new(config).unwrap();
        visualizer.add_metric("loss".to_string(), 1.0, false, "nats".to_string());
        visualizer.step();
        visualizer.add_metric("loss".to_string(), 0.5, false, "nats".to_string());

        let filename = visualizer.plot_loss_curve("loss").unwrap();
        assert_eq!(filename, "loss_curve.html");
        assert!(Path::new(&dir).join(&filename).exists());

        assert!(visualizer.plot_loss_curve("missing").is_err());
    }

    #[test]
    fn test_optimizer_comparison() {
        let comparison = OptimizerComparison {
            name: "Adam".to_string(),
            metrics: {
                let mut map = HashMap::new();
                map.insert("loss".to_string(), vec![1.0, 0.8, 0.6]);
                map
            },
            hyperparameters: {
                let mut map = HashMap::new();
                map.insert("learning_rate".to_string(), 0.001);
                map
            },
            training_time: Duration::from_secs(120),
            memory_stats: MemoryStats {
                peak_memory_mb: 1024.0,
                avg_memory_mb: 512.0,
                memory_efficiency: 100.0,
            },
            convergence_info: ConvergenceInfo {
                converged: true,
                convergence_step: Some(100),
                final_value: 0.6,
                best_value: 0.6,
                convergence_rate: 0.004,
            },
        };

        assert_eq!(comparison.name, "Adam");
        assert!(comparison.convergence_info.converged);
    }

    #[test]
    fn test_color_schemes() {
        let config = VisualizationConfig {
            color_scheme: ColorScheme::Colorblind,
            output_dir: scratch_dir("test_plots"),
            ..Default::default()
        };

        let visualizer = OptimizationVisualizer::new(config).unwrap();
        let color = visualizer.get_color(0);
        assert_eq!(color, "#000000");
    }
}
1057}