optirs_core/
visualization.rs

1// Visualization tools for optimization metrics
2//
3// This module provides comprehensive visualization capabilities for tracking
4// optimization progress, comparing optimizers, and analyzing training dynamics.
5
6#[allow(dead_code)]
7use crate::error::{OptimError, Result};
8use std::collections::{HashMap, VecDeque};
9use std::fmt::Write as FmtWrite;
10use std::io::Write;
11use std::path::Path;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13
14/// Configuration for visualization output
15#[derive(Debug, Clone)]
16pub struct VisualizationConfig {
17    /// Output directory for saved plots
18    pub output_dir: String,
19
20    /// Maximum number of data points to display
21    pub max_points: usize,
22
23    /// Update frequency for real-time plots (in steps)
24    pub update_frequency: usize,
25
26    /// Enable interactive HTML output
27    pub interactive_html: bool,
28
29    /// Enable SVG output for publication
30    pub svg_output: bool,
31
32    /// Color scheme for plots
33    pub color_scheme: ColorScheme,
34
35    /// Default figure size (width, height)
36    pub figure_size: (u32, u32),
37
38    /// DPI for raster outputs
39    pub dpi: u32,
40
41    /// Enable grid lines
42    pub show_grid: bool,
43
44    /// Enable legends
45    pub show_legend: bool,
46}
47
48impl Default for VisualizationConfig {
49    fn default() -> Self {
50        Self {
51            output_dir: "optimization_plots".to_string(),
52            max_points: 10000,
53            update_frequency: 100,
54            interactive_html: true,
55            svg_output: false,
56            color_scheme: ColorScheme::Default,
57            figure_size: (800, 600),
58            dpi: 300,
59            show_grid: true,
60            show_legend: true,
61        }
62    }
63}
64
/// Colour palettes available for plot series.
///
/// The concrete hex codes for each palette are listed in
/// `OptimizationVisualizer::get_color`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorScheme {
    /// Ten general-purpose colours (the hex codes of matplotlib's default cycle).
    Default,
    /// Ten light, pastel colours — presumably intended for dark backgrounds.
    Dark,
    /// Eight colours from the Okabe–Ito colour-blind-safe palette.
    Colorblind,
    /// Five shades from black to light grey, for greyscale print.
    Publication,
    /// Ten saturated colours.
    Vibrant,
}
74
/// Time series of a single optimization metric (loss, accuracy, ...).
///
/// `values`, `timestamps` and `steps` are parallel queues: index `i` of each one
/// describes the same observation. At most 50 000 observations are retained; the
/// oldest are discarded first.
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationMetric {
    /// Metric name.
    pub name: String,

    /// Recorded values, oldest first.
    pub values: VecDeque<f64>,

    /// Wall-clock time of each value, in whole seconds since the Unix epoch.
    pub timestamps: VecDeque<u64>,

    /// Training step at which each value was recorded.
    pub steps: VecDeque<usize>,

    /// Target value, if any.
    pub target: Option<f64>,

    /// Whether higher values are better (affects the sign of improvements).
    pub higher_isbetter: bool,

    /// Units used when displaying the metric.
    pub units: String,

    /// Window size used by [`OptimizationMetric::get_smoothed_values`].
    pub smoothing_window: usize,
}

impl OptimizationMetric {
    /// Maximum number of observations kept in memory per metric.
    const MAX_HISTORY: usize = 50_000;

    /// Create an empty metric tracker with a smoothing window of 10.
    pub fn new(name: String, higher_isbetter: bool, units: String) -> Self {
        Self {
            name,
            values: VecDeque::new(),
            timestamps: VecDeque::new(),
            steps: VecDeque::new(),
            target: None,
            higher_isbetter,
            units,
            smoothing_window: 10,
        }
    }

    /// Append `value` observed at `step`, stamped with the current wall-clock time.
    ///
    /// If the system clock reads earlier than the Unix epoch, the timestamp is
    /// recorded as `0` instead of panicking. Once more than 50 000 observations are
    /// stored, the oldest ones are dropped from all three queues together.
    pub fn add_value(&mut self, value: f64, step: usize) {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0);

        self.values.push_back(value);
        self.timestamps.push_back(timestamp);
        self.steps.push_back(step);

        // Bound memory use; keep the three queues aligned.
        while self.values.len() > Self::MAX_HISTORY {
            self.values.pop_front();
            self.timestamps.pop_front();
            self.steps.pop_front();
        }
    }

    /// Centered moving average of the stored values.
    ///
    /// Each output element averages the values within `smoothing_window / 2`
    /// positions on either side of it (fewer at the edges). For an even window
    /// this spans `smoothing_window + 1` values. If fewer than `smoothing_window`
    /// values are stored, the raw values are returned unchanged.
    pub fn get_smoothed_values(&self) -> Vec<f64> {
        let len = self.values.len();
        if len < self.smoothing_window {
            return self.values.iter().copied().collect();
        }

        let half = self.smoothing_window / 2;
        (0..len)
            .map(|i| {
                let start = i.saturating_sub(half);
                let end = (i + half + 1).min(len);
                let sum: f64 = self.values.range(start..end).sum();
                sum / (end - start) as f64
            })
            .collect()
    }

    /// Improvement of the latest `windowsize` values over the `windowsize` before them.
    ///
    /// Compares the two window averages. The result is positive when the metric got
    /// better, taking `higher_isbetter` into account.
    ///
    /// Returns `None` if `windowsize` is 0, or if fewer than `2 * windowsize`
    /// values are stored.
    pub fn get_recent_improvement(&self, windowsize: usize) -> Option<f64> {
        // A zero window would divide 0.0 by 0 and report NaN as an "improvement".
        if windowsize == 0 {
            return None;
        }
        let required = windowsize.checked_mul(2)?;
        if self.values.len() < required {
            return None;
        }

        let divisor = windowsize as f64;
        let recent_avg = self.values.iter().rev().take(windowsize).sum::<f64>() / divisor;
        let older_avg = self
            .values
            .iter()
            .rev()
            .skip(windowsize)
            .take(windowsize)
            .sum::<f64>()
            / divisor;

        Some(if self.higher_isbetter {
            recent_avg - older_avg
        } else {
            older_avg - recent_avg
        })
    }
}
182
183/// Optimizer comparison data
184#[derive(Debug, Clone)]
185pub struct OptimizerComparison {
186    /// Optimizer name
187    pub name: String,
188
189    /// Performance metrics
190    pub metrics: HashMap<String, Vec<f64>>,
191
192    /// Hyperparameters used
193    pub hyperparameters: HashMap<String, f64>,
194
195    /// Total training time
196    pub training_time: Duration,
197
198    /// Memory usage statistics
199    pub memory_stats: MemoryStats,
200
201    /// Convergence information
202    pub convergence_info: ConvergenceInfo,
203}
204
/// Memory usage statistics for an optimizer run.
///
/// `Default` yields all-zero statistics (nothing measured).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct MemoryStats {
    /// Peak memory usage, in MB.
    pub peak_memory_mb: f64,

    /// Average memory usage, in MB.
    pub avg_memory_mb: f64,

    /// Memory efficiency, in operations per MB.
    pub memory_efficiency: f64,
}
217
/// Convergence summary for an optimizer run.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvergenceInfo {
    /// Whether convergence was achieved.
    pub converged: bool,

    /// Step at which convergence was achieved, if it was.
    pub convergence_step: Option<usize>,

    /// Metric value at the end of the run.
    pub final_value: f64,

    /// Best metric value observed during the run.
    pub best_value: f64,

    /// Convergence rate (improvement per step).
    pub convergence_rate: f64,
}
236
237/// Main visualization engine
238pub struct OptimizationVisualizer {
239    /// Configuration
240    config: VisualizationConfig,
241
242    /// Tracked metrics
243    metrics: HashMap<String, OptimizationMetric>,
244
245    /// Optimizer comparisons
246    comparisons: Vec<OptimizerComparison>,
247
248    /// Real-time dashboard state
249    dashboard_state: DashboardState,
250
251    /// Step counter
252    current_step: usize,
253
254    /// Last update step
255    last_update_step: usize,
256}
257
258/// Dashboard state for real-time visualization
259#[derive(Debug)]
260struct DashboardState {
261    /// Active plots
262    active_plots: HashMap<String, PlotState>,
263
264    /// Layout configuration
265    layout: DashboardLayout,
266
267    /// Update timestamps
268    last_update: SystemTime,
269}
270
271/// Individual plot state
272#[derive(Debug)]
273struct PlotState {
274    /// Plot type
275    plot_type: PlotType,
276
277    /// Data series
278    series: Vec<DataSeries>,
279
280    /// Axis configuration
281    x_axis: AxisConfig,
282    y_axis: AxisConfig,
283
284    /// Plot title
285    title: String,
286}
287
/// Kinds of chart a dashboard plot can have.
///
/// NOTE(review): only referenced by `PlotState`; the renderers in this module emit
/// Plotly `scatter` traces regardless of this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotType {
    /// Connected line chart.
    Line,
    /// Unconnected markers.
    Scatter,
    /// Frequency histogram.
    Histogram,
    /// 2-D colour-coded grid.
    Heatmap,
    /// Bar chart.
    Bar,
    /// Box-and-whisker plot.
    Box,
    /// Violin (density) plot.
    Violin,
    /// Three-dimensional surface.
    Surface3D,
}
300
301/// Data series for plotting
302#[derive(Debug, Clone)]
303pub struct DataSeries {
304    /// Series name
305    pub name: String,
306
307    /// X values
308    pub x_values: Vec<f64>,
309
310    /// Y values
311    pub y_values: Vec<f64>,
312
313    /// Z values (for 3D plots)
314    pub z_values: Option<Vec<f64>>,
315
316    /// Color
317    pub color: String,
318
319    /// Line style
320    pub line_style: LineStyle,
321
322    /// Marker style
323    pub marker_style: MarkerStyle,
324}
325
/// How the line connecting a series' points is drawn.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineStyle {
    /// Continuous line.
    Solid,
    /// Dashed line.
    Dashed,
    /// Dotted line.
    Dotted,
    /// Alternating dashes and dots.
    DashDot,
    /// No connecting line.
    None,
}
335
/// Symbol drawn at each data point of a series.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MarkerStyle {
    /// Filled circle.
    Circle,
    /// Square.
    Square,
    /// Triangle.
    Triangle,
    /// Diamond.
    Diamond,
    /// Plus sign.
    Plus,
    /// Diagonal cross.
    Cross,
    /// No marker.
    None,
}
347
348/// Axis configuration
349#[derive(Debug, Clone)]
350pub struct AxisConfig {
351    /// Axis label
352    pub label: String,
353
354    /// Scale type
355    pub scale: AxisScale,
356
357    /// Range (min, max)
358    pub range: Option<(f64, f64)>,
359
360    /// Tick configuration
361    pub ticks: TickConfig,
362}
363
/// Scale applied to a plot axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AxisScale {
    /// Linear scale.
    Linear,
    /// Logarithmic scale (positive values only).
    Log,
    /// Symmetric logarithmic scale, presumably linear near zero so that signed
    /// values can be shown — confirm against the renderer that consumes it.
    Symlog,
}
371
/// Tick placement and labelling for an axis.
#[derive(Debug, Clone, PartialEq)]
pub struct TickConfig {
    /// Distance between major ticks; `None` presumably lets the renderer choose.
    pub major_spacing: Option<f64>,

    /// Number of minor ticks between consecutive major ticks.
    pub minor_count: usize,

    /// Whether tick labels are shown.
    pub show_labels: bool,
}
384
/// Grid layout of the dashboard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DashboardLayout {
    /// Number of grid rows.
    pub rows: usize,

    /// Number of grid columns.
    pub cols: usize,

    /// Grid cell assigned to each plot, keyed by plot identifier.
    ///
    /// Presumably `(row, col)` — confirm with the code that fills it.
    pub plot_positions: HashMap<String, (usize, usize)>,
}
397
398impl OptimizationVisualizer {
399    /// Create new visualization engine
400    pub fn new(config: VisualizationConfig) -> Result<Self> {
401        // Create output directory if it doesn't exist
402        std::fs::create_dir_all(&config.output_dir).map_err(|e| {
403            OptimError::InvalidConfig(format!("Failed to create output directory: {e}"))
404        })?;
405
406        let dashboard_state = DashboardState {
407            active_plots: HashMap::new(),
408            layout: DashboardLayout {
409                rows: 2,
410                cols: 2,
411                plot_positions: HashMap::new(),
412            },
413            last_update: SystemTime::now(),
414        };
415
416        Ok(Self {
417            config,
418            metrics: HashMap::new(),
419            comparisons: Vec::new(),
420            dashboard_state,
421            current_step: 0,
422            last_update_step: 0,
423        })
424    }
425
426    /// Add or update a metric
427    pub fn add_metric(&mut self, name: String, value: f64, higher_isbetter: bool, units: String) {
428        let metric = self
429            .metrics
430            .entry(name.clone())
431            .or_insert_with(|| OptimizationMetric::new(name, higher_isbetter, units));
432
433        metric.add_value(value, self.current_step);
434    }
435
436    /// Set target value for a metric
437    pub fn set_target(&mut self, metricname: &str, target: f64) {
438        if let Some(metric) = self.metrics.get_mut(metricname) {
439            metric.target = Some(target);
440        }
441    }
442
443    /// Update step counter
444    pub fn step(&mut self) {
445        self.current_step += 1;
446
447        if self.current_step - self.last_update_step >= self.config.update_frequency {
448            if let Err(e) = self.update_dashboard() {
449                eprintln!("Failed to update dashboard: {e}");
450            }
451            self.last_update_step = self.current_step;
452        }
453    }
454
455    /// Create loss curve plot
456    pub fn plot_loss_curve(&self, metricname: &str) -> Result<String> {
457        let metric = self
458            .metrics
459            .get(metricname)
460            .ok_or_else(|| OptimError::InvalidConfig(format!("Metric '{metricname}' not found")))?;
461
462        let steps: Vec<f64> = metric.steps.iter().map(|&s| s as f64).collect();
463        let values = metric.get_smoothed_values();
464
465        let plotdata = self.create_line_plot(
466            &steps,
467            &values,
468            &format!("{} over Training Steps", metric.name),
469            "Training Steps",
470            &format!("{} ({})", metric.name, metric.units),
471        )?;
472
473        self.save_plot(&plotdata, &format!("{metricname}_curve"))
474    }
475
476    /// Create learning rate schedule plot
477    pub fn plot_learning_rate_schedule(&self) -> Result<String> {
478        if let Some(lr_metric) = self.metrics.get("learning_rate") {
479            let steps: Vec<f64> = lr_metric.steps.iter().map(|&s| s as f64).collect();
480            let values: Vec<f64> = lr_metric.values.iter().copied().collect();
481
482            let plotdata = self.create_line_plot(
483                &steps,
484                &values,
485                "Learning Rate Schedule",
486                "Training Steps",
487                "Learning Rate",
488            )?;
489
490            self.save_plot(&plotdata, "learning_rate_schedule")
491        } else {
492            Err(OptimError::InvalidConfig(
493                "Learning rate metric not found".to_string(),
494            ))
495        }
496    }
497
498    /// Create optimizer comparison plot
499    pub fn plot_optimizer_comparison(&self, metricname: &str) -> Result<String> {
500        if self.comparisons.is_empty() {
501            return Err(OptimError::InvalidConfig(
502                "No optimizer comparisons available".to_string(),
503            ));
504        }
505
506        let mut plotdata = String::new();
507
508        // HTML header for interactive plot
509        if self.config.interactive_html {
510            plotdata.push_str(&self.create_html_header("Optimizer Comparison")?);
511            plotdata.push_str("<div id='comparison-plot'></div>\n");
512            plotdata.push_str("<script>\n");
513            plotdata.push_str("const traces = [];\n");
514
515            for comparison in &self.comparisons {
516                if let Some(values) = comparison.metrics.get(metricname) {
517                    let x_values: Vec<String> = (0..values.len()).map(|i| i.to_string()).collect();
518                    writeln!(&mut plotdata,
519                        "traces.push({{x: {:?}, y: {:?}, name: '{}', type: 'scatter', mode: 'lines'}});",
520                        x_values, values, comparison.name
521                    ).unwrap();
522                }
523            }
524
525            plotdata.push_str("Plotly.newPlot('comparison-plot', traces, {\n");
526            plotdata.push_str("  title: 'Optimizer Comparison',\n");
527            plotdata.push_str("  xaxis: {title: 'Training Steps'},\n");
528            writeln!(&mut plotdata, "  yaxis: {{title: '{metricname}'}}").unwrap();
529            plotdata.push_str("});\n");
530            plotdata.push_str("</script>\n");
531            plotdata.push_str("</body></html>\n");
532        }
533
534        self.save_plot(&plotdata, &format!("{metricname}_comparison"))
535    }
536
537    /// Create gradient norm visualization
538    pub fn plot_gradient_norm(&self) -> Result<String> {
539        if let Some(grad_metric) = self.metrics.get("gradient_norm") {
540            let steps: Vec<f64> = grad_metric.steps.iter().map(|&s| s as f64).collect();
541            let values: Vec<f64> = grad_metric.values.iter().copied().collect();
542
543            let mut plotdata = self.create_line_plot(
544                &steps,
545                &values,
546                "Gradient Norm",
547                "Training Steps",
548                "Gradient Norm",
549            )?;
550
551            // Add log scale for y-axis if values span multiple orders of magnitude
552            let max_val = values.iter().fold(0.0f64, |a, &b| a.max(b));
553            let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
554
555            if max_val / min_val > 100.0 {
556                plotdata = plotdata.replace("yaxis: {", "yaxis: {type: 'log', ");
557            }
558
559            self.save_plot(&plotdata, "gradient_norm")
560        } else {
561            Err(OptimError::InvalidConfig(
562                "Gradient norm metric not found".to_string(),
563            ))
564        }
565    }
566
567    /// Create training throughput plot
568    pub fn plot_throughput(&self) -> Result<String> {
569        if let Some(throughput_metric) = self.metrics.get("throughput") {
570            let steps: Vec<f64> = throughput_metric.steps.iter().map(|&s| s as f64).collect();
571            let values: Vec<f64> = throughput_metric.values.iter().copied().collect();
572
573            let plotdata = self.create_line_plot(
574                &steps,
575                &values,
576                "Training Throughput",
577                "Training Steps",
578                "Samples/Second",
579            )?;
580
581            self.save_plot(&plotdata, "throughput")
582        } else {
583            Err(OptimError::InvalidConfig(
584                "Throughput metric not found".to_string(),
585            ))
586        }
587    }
588
589    /// Create memory usage visualization
590    pub fn plot_memory_usage(&self) -> Result<String> {
591        if let Some(memory_metric) = self.metrics.get("memory_usage") {
592            let steps: Vec<f64> = memory_metric.steps.iter().map(|&s| s as f64).collect();
593            let values: Vec<f64> = memory_metric.values.iter().copied().collect();
594
595            let plotdata = self.create_line_plot(
596                &steps,
597                &values,
598                "Memory Usage",
599                "Training Steps",
600                "Memory (MB)",
601            )?;
602
603            self.save_plot(&plotdata, "memory_usage")
604        } else {
605            Err(OptimError::InvalidConfig(
606                "Memory usage metric not found".to_string(),
607            ))
608        }
609    }
610
611    /// Create hyperparameter sensitivity analysis
612    pub fn plot_hyperparameter_sensitivity(
613        &self,
614        param_name: &str,
615        metricname: &str,
616    ) -> Result<String> {
617        let mut param_values = Vec::new();
618        let mut metric_values = Vec::new();
619
620        for comparison in &self.comparisons {
621            if let (Some(&param_val), Some(metric_vals)) = (
622                comparison.hyperparameters.get(param_name),
623                comparison.metrics.get(metricname),
624            ) {
625                if let Some(&final_metric) = metric_vals.last() {
626                    param_values.push(param_val);
627                    metric_values.push(final_metric);
628                }
629            }
630        }
631
632        if param_values.is_empty() {
633            return Err(OptimError::InvalidConfig(format!(
634                "No data available for hyperparameter '{}' and metric '{}'",
635                param_name, metricname
636            )));
637        }
638
639        let plotdata = self.create_scatter_plot(
640            &param_values,
641            &metric_values,
642            &format!("Sensitivity of {} to {}", metricname, param_name),
643            param_name,
644            metricname,
645        )?;
646
647        self.save_plot(
648            &plotdata,
649            &format!("sensitivity_{}_{}", param_name, metricname),
650        )
651    }
652
653    /// Create comprehensive dashboard
654    pub fn create_dashboard(&self) -> Result<String> {
655        let mut dashboard = String::new();
656
657        if self.config.interactive_html {
658            dashboard.push_str(&self.create_html_header("Optimization Dashboard")?);
659
660            // Add CSS for layout
661            dashboard.push_str(
662                r#"
663<style>
664.dashboard-container {
665    display: grid;
666    grid-template-columns: 1fr 1fr;
667    grid-template-rows: 1fr 1fr;
668    gap: 20px;
669    height: 100vh;
670    padding: 20px;
671}
672.plot-container {
673    border: 1px solid #ddd;
674    border-radius: 8px;
675    padding: 10px;
676}
677.metrics-summary {
678    grid-column: span 2;
679    padding: 20px;
680    background-color: #f8f9fa;
681    border-radius: 8px;
682    margin-bottom: 20px;
683}
684</style>
685"#,
686            );
687
688            // Metrics summary
689            dashboard.push_str("<div class='metrics-summary'>\n");
690            dashboard.push_str("<h2>Current Metrics</h2>\n");
691            dashboard.push_str("<div style='display: flex; gap: 20px;'>\n");
692
693            for (name, metric) in &self.metrics {
694                if let Some(&latest_value) = metric.values.back() {
695                    writeln!(
696                        &mut dashboard,
697                        "<div><strong>{}:</strong> {:.4} {}</div>",
698                        name, latest_value, metric.units
699                    )
700                    .unwrap();
701                }
702            }
703
704            dashboard.push_str("</div></div>\n");
705
706            // Plot containers
707            dashboard.push_str("<div class='dashboard-container'>\n");
708
709            let mut plot_id = 0;
710            for _ in &self.metrics {
711                if plot_id >= 4 {
712                    break;
713                } // Limit to 4 plots in 2x2 grid
714
715                writeln!(
716                    &mut dashboard,
717                    "<div class='plot-container'><div id='plot-{}'></div></div>",
718                    plot_id
719                )
720                .unwrap();
721
722                plot_id += 1;
723            }
724
725            dashboard.push_str("</div>\n");
726
727            // JavaScript for plots
728            dashboard.push_str("<script>\n");
729
730            plot_id = 0;
731            for (name, metric) in &self.metrics {
732                if plot_id >= 4 {
733                    break;
734                }
735
736                let steps: Vec<String> = metric.steps.iter().map(|&s| s.to_string()).collect();
737                let values: Vec<f64> = metric.values.iter().copied().collect();
738
739                writeln!(&mut dashboard,
740                    "Plotly.newPlot('plot-{}', [{{x: {:?}, y: {:?}, type: 'scatter', mode: 'lines', name: '{}'}}], {{title: '{}', xaxis: {{title: 'Steps'}}, yaxis: {{title: '{}'}}}});",
741                    plot_id, steps, values, name, name, metric.units
742                ).unwrap();
743
744                plot_id += 1;
745            }
746
747            dashboard.push_str("</script>\n");
748            dashboard.push_str("</body></html>\n");
749        }
750
751        self.save_plot(&dashboard, "dashboard")
752    }
753
754    /// Update real-time dashboard
755    fn update_dashboard(&mut self) -> Result<()> {
756        self.dashboard_state.last_update = SystemTime::now();
757
758        // In a real implementation, this would update the live dashboard
759        // For now, we'll just regenerate static files
760        self.create_dashboard()?;
761
762        Ok(())
763    }
764
765    /// Add optimizer comparison data
766    pub fn add_optimizer_comparison(&mut self, comparison: OptimizerComparison) {
767        self.comparisons.push(comparison);
768    }
769
770    /// Export all visualizations
771    pub fn export_all(&self) -> Result<Vec<String>> {
772        let mut exported_files = Vec::new();
773
774        // Export individual metric plots
775        for metricname in self.metrics.keys() {
776            if let Ok(filename) = self.plot_loss_curve(metricname) {
777                exported_files.push(filename);
778            }
779        }
780
781        // Export comparisons
782        for metricname in ["loss", "accuracy", "throughput"] {
783            if let Ok(filename) = self.plot_optimizer_comparison(metricname) {
784                exported_files.push(filename);
785            }
786        }
787
788        // Export specialized plots
789        if let Ok(filename) = self.plot_gradient_norm() {
790            exported_files.push(filename);
791        }
792
793        if let Ok(filename) = self.plot_throughput() {
794            exported_files.push(filename);
795        }
796
797        if let Ok(filename) = self.plot_memory_usage() {
798            exported_files.push(filename);
799        }
800
801        // Export dashboard
802        if let Ok(filename) = self.create_dashboard() {
803            exported_files.push(filename);
804        }
805
806        Ok(exported_files)
807    }
808
809    /// Helper function to create line plot
810    fn create_line_plot(
811        &self,
812        x_values: &[f64],
813        y_values: &[f64],
814        title: &str,
815        x_label: &str,
816        y_label: &str,
817    ) -> Result<String> {
818        if !self.config.interactive_html {
819            return Ok(format!("# {}\nX: {:?}\nY: {:?}", title, x_values, y_values));
820        }
821
822        let mut plot = String::new();
823        plot.push_str(&self.create_html_header(title)?);
824        plot.push_str("<div id='plot'></div>\n");
825        plot.push_str("<script>\n");
826
827        writeln!(
828            &mut plot,
829            "const trace = {{x: {:?}, y: {:?}, type: 'scatter', mode: 'lines', name: '{}'}};",
830            x_values, y_values, title
831        )
832        .unwrap();
833
834        writeln!(&mut plot,
835            "Plotly.newPlot('plot', [trace], {{title: '{}', xaxis: {{title: '{}'}}, yaxis: {{title: '{}'}}}});",
836            title, x_label, y_label
837        ).unwrap();
838
839        plot.push_str("</script></body></html>");
840
841        Ok(plot)
842    }
843
844    /// Helper function to create scatter plot
845    fn create_scatter_plot(
846        &self,
847        x_values: &[f64],
848        y_values: &[f64],
849        title: &str,
850        x_label: &str,
851        y_label: &str,
852    ) -> Result<String> {
853        if !self.config.interactive_html {
854            return Ok(format!("# {}\nX: {:?}\nY: {:?}", title, x_values, y_values));
855        }
856
857        let mut plot = String::new();
858        plot.push_str(&self.create_html_header(title)?);
859        plot.push_str("<div id='plot'></div>\n");
860        plot.push_str("<script>\n");
861
862        writeln!(
863            &mut plot,
864            "const trace = {{x: {:?}, y: {:?}, type: 'scatter', mode: 'markers', name: '{}'}};",
865            x_values, y_values, title
866        )
867        .unwrap();
868
869        writeln!(&mut plot,
870            "Plotly.newPlot('plot', [trace], {{title: '{}', xaxis: {{title: '{}'}}, yaxis: {{title: '{}'}}}});",
871            title, x_label, y_label
872        ).unwrap();
873
874        plot.push_str("</script></body></html>");
875
876        Ok(plot)
877    }
878
879    /// Create HTML header for interactive plots
880    fn create_html_header(&self, title: &str) -> Result<String> {
881        Ok(format!(
882            r#"
883<!DOCTYPE html>
884<html>
885<head>
886    <title>{}</title>
887    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
888    <style>
889        body {{ font-family: Arial, sans-serif; margin: 20px; }}
890        #plot {{ width: 100%; height: 500px; }}
891    </style>
892</head>
893<body>
894    <h1>{}</h1>
895"#,
896            title, title
897        ))
898    }
899
900    /// Save plot to file
901    fn save_plot(&self, plotdata: &str, filename: &str) -> Result<String> {
902        let extension = if self.config.interactive_html {
903            "html"
904        } else {
905            "txt"
906        };
907        let full_filename = format!("{}.{}", filename, extension);
908        let filepath = Path::new(&self.config.output_dir).join(&full_filename);
909
910        let mut file = std::fs::File::create(&filepath).map_err(|e| {
911            OptimError::InvalidConfig(format!(
912                "Failed to create file {}: {}",
913                filepath.display(),
914                e
915            ))
916        })?;
917
918        file.write_all(plotdata.as_bytes()).map_err(|e| {
919            OptimError::InvalidConfig(format!(
920                "Failed to write to file {}: {}",
921                filepath.display(),
922                e
923            ))
924        })?;
925
926        Ok(full_filename)
927    }
928
929    /// Get color for plot series
930    fn get_color(&self, index: usize) -> String {
931        let colors = match self.config.color_scheme {
932            ColorScheme::Default => vec![
933                "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2",
934                "#7f7f7f", "#bcbd22", "#17becf",
935            ],
936            ColorScheme::Dark => vec![
937                "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69",
938                "#fccde5", "#d9d9d9", "#bc80bd",
939            ],
940            ColorScheme::Colorblind => vec![
941                "#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00",
942                "#CC79A7",
943            ],
944            ColorScheme::Publication => vec!["#000000", "#333333", "#666666", "#999999", "#CCCCCC"],
945            ColorScheme::Vibrant => vec![
946                "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD", "#98D8C8",
947                "#F7DC6F", "#BB8FCE", "#85C1E9",
948            ],
949        };
950
951        colors[index % colors.len()].to_string()
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Per-test output directory under the platform temp dir.
    ///
    /// `std::env::temp_dir` works on platforms without `/tmp`, and a directory per
    /// test keeps parallel tests from writing into the same place.
    fn test_output_dir(test_name: &str) -> String {
        std::env::temp_dir()
            .join("optirs_visualization_tests")
            .join(test_name)
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn test_visualization_config_default() {
        let config = VisualizationConfig::default();
        assert_eq!(config.max_points, 10000);
        assert!(config.interactive_html);
        assert!(config.show_grid);
    }

    #[test]
    fn test_optimization_metric() {
        let mut metric = OptimizationMetric::new("loss".to_string(), false, "nats".to_string());

        metric.add_value(1.0, 0);
        metric.add_value(0.8, 1);
        metric.add_value(0.6, 2);
        metric.add_value(0.4, 3); // 4 values satisfy the windowsize * 2 requirement below

        assert_eq!(metric.values.len(), 4);
        assert_eq!(metric.steps.len(), 4);

        let improvement = metric.get_recent_improvement(2);
        assert!(improvement.is_some());
    }

    #[test]
    fn test_visualizer_creation() {
        let config = VisualizationConfig {
            output_dir: test_output_dir("creation"),
            ..Default::default()
        };

        let visualizer = OptimizationVisualizer::new(config);
        assert!(visualizer.is_ok());
    }

    #[test]
    fn test_add_metric() {
        let config = VisualizationConfig {
            output_dir: test_output_dir("add_metric"),
            ..Default::default()
        };

        let mut visualizer = OptimizationVisualizer::new(config).unwrap();

        visualizer.add_metric("loss".to_string(), 1.0, false, "nats".to_string());
        visualizer.step();
        visualizer.add_metric("loss".to_string(), 0.8, false, "nats".to_string());

        assert!(visualizer.metrics.contains_key("loss"));
        assert_eq!(visualizer.metrics["loss"].values.len(), 2);
    }

    #[test]
    fn test_plot_loss_curve_writes_html() {
        let dir = test_output_dir("loss_curve");
        let config = VisualizationConfig {
            output_dir: dir.clone(),
            ..Default::default()
        };

        let mut visualizer = OptimizationVisualizer::new(config).unwrap();
        visualizer.add_metric("loss".to_string(), 1.0, false, "nats".to_string());
        visualizer.step();
        visualizer.add_metric("loss".to_string(), 0.5, false, "nats".to_string());

        let filename = visualizer.plot_loss_curve("loss").unwrap();
        assert_eq!(filename, "loss_curve.html");

        let contents = std::fs::read_to_string(Path::new(&dir).join(&filename)).unwrap();
        assert!(contents.contains("Plotly.newPlot"));

        assert!(visualizer.plot_loss_curve("missing").is_err());
    }

    #[test]
    fn test_optimizer_comparison() {
        let comparison = OptimizerComparison {
            name: "Adam".to_string(),
            metrics: {
                let mut map = HashMap::new();
                map.insert("loss".to_string(), vec![1.0, 0.8, 0.6]);
                map
            },
            hyperparameters: {
                let mut map = HashMap::new();
                map.insert("learning_rate".to_string(), 0.001);
                map
            },
            training_time: Duration::from_secs(120),
            memory_stats: MemoryStats {
                peak_memory_mb: 1024.0,
                avg_memory_mb: 512.0,
                memory_efficiency: 100.0,
            },
            convergence_info: ConvergenceInfo {
                converged: true,
                convergence_step: Some(100),
                final_value: 0.6,
                best_value: 0.6,
                convergence_rate: 0.004,
            },
        };

        assert_eq!(comparison.name, "Adam");
        assert!(comparison.convergence_info.converged);
    }

    #[test]
    fn test_color_schemes() {
        let config = VisualizationConfig {
            color_scheme: ColorScheme::Colorblind,
            output_dir: test_output_dir("color_schemes"),
            ..Default::default()
        };

        let visualizer = OptimizationVisualizer::new(config).unwrap();
        let color = visualizer.get_color(0);
        assert_eq!(color, "#000000");
    }
}