Skip to main content

scirs2_neural/visualization/
training.rs

1//! Training metrics and curve visualization for neural networks
2//!
3//! This module provides comprehensive tools for visualizing training progress
4//! including loss curves, accuracy metrics, learning rate schedules, and system performance.
5
6use super::config::{DownsamplingStrategy, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use scirs2_core::numeric::Float;
9use scirs2_core::NumAssign;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::fs;
14use std::path::PathBuf;
15/// Training metrics visualizer
16#[allow(dead_code)]
17pub struct TrainingVisualizer<F: Float + Debug + NumAssign> {
18    /// Training history
19    metrics_history: Vec<TrainingMetrics<F>>,
20    /// Visualization configuration
21    config: VisualizationConfig,
22    /// Active plots
23    active_plots: HashMap<String, PlotConfig>,
24}
25/// Training metrics for a single epoch/step
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TrainingMetrics<F: Float + Debug + NumAssign> {
28    /// Epoch number
29    pub epoch: usize,
30    /// Step number within epoch
31    pub step: usize,
32    /// Timestamp
33    pub timestamp: String,
34    /// Loss values
35    pub losses: HashMap<String, F>,
36    /// Accuracy metrics
37    pub accuracies: HashMap<String, F>,
38    /// Learning rate
39    pub learning_rate: F,
40    /// Other custom metrics
41    pub custom_metrics: HashMap<String, F>,
42    /// System metrics
43    pub system_metrics: SystemMetrics,
44}
45
46/// System performance metrics during training
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct SystemMetrics {
49    /// Memory usage in MB
50    pub memory_usage_mb: f64,
51    /// GPU memory usage in MB (if available)
52    pub gpu_memory_mb: Option<f64>,
53    /// CPU utilization percentage
54    pub cpu_utilization: f64,
55    /// GPU utilization percentage (if available)
56    pub gpu_utilization: Option<f64>,
57    /// Training step duration in milliseconds
58    pub step_duration_ms: f64,
59    /// Samples processed per second
60    pub samples_per_second: f64,
61}
62
63/// Plot configuration
64#[derive(Debug, Clone, Serialize)]
65pub struct PlotConfig {
66    /// Plot title
67    pub title: String,
68    /// X-axis configuration
69    pub x_axis: AxisConfig,
70    /// Y-axis configuration
71    pub y_axis: AxisConfig,
72    /// Series to plot
73    pub series: Vec<SeriesConfig>,
74    /// Plot type
75    pub plot_type: PlotType,
76    /// Update mode
77    pub update_mode: UpdateMode,
78}
79
80/// Axis configuration
81#[derive(Debug, Clone, Serialize)]
82pub struct AxisConfig {
83    /// Axis label
84    pub label: String,
85    /// Axis scale
86    pub scale: AxisScale,
87    /// Range (None for auto)
88    pub range: Option<(f64, f64)>,
89    /// Show grid lines
90    pub show_grid: bool,
91    /// Tick configuration
92    pub ticks: TickConfig,
93}
94
95/// Axis scale type
96#[derive(Debug, Clone, PartialEq, Serialize)]
97pub enum AxisScale {
98    /// Linear scale
99    Linear,
100    /// Logarithmic scale
101    Log,
102    /// Square root scale
103    Sqrt,
104    /// Custom scale
105    Custom(String),
106}
107
108/// Tick configuration
109#[derive(Debug, Clone, Serialize)]
110pub struct TickConfig {
111    /// Tick interval (None for auto)
112    pub interval: Option<f64>,
113    /// Tick format
114    pub format: TickFormat,
115    /// Show tick labels
116    pub show_labels: bool,
117    /// Tick rotation angle
118    pub rotation: f32,
119}
120
121/// Tick format options
122#[derive(Debug, Clone, Serialize)]
123pub enum TickFormat {
124    /// Automatic formatting
125    Auto,
126    /// Fixed decimal places
127    Fixed(u32),
128    /// Scientific notation
129    Scientific,
130    /// Percentage
131    Percentage,
132    /// Custom format string
133    Custom(String),
134}
135
136/// Data series configuration
137#[derive(Debug, Clone, Serialize)]
138pub struct SeriesConfig {
139    /// Series name
140    pub name: String,
141    /// Data source (metric name)
142    pub data_source: String,
143    /// Line style
144    pub style: LineStyleConfig,
145    /// Marker style
146    pub markers: MarkerConfig,
147    /// Series color
148    pub color: String,
149    /// Series opacity
150    pub opacity: f32,
151}
152
153/// Line style configuration for series
154#[derive(Debug, Clone, Serialize)]
155pub struct LineStyleConfig {
156    pub style: LineStyle,
157    /// Line width
158    pub width: f32,
159    /// Smoothing enabled
160    pub smoothing: bool,
161    /// Smoothing window size
162    pub smoothing_window: usize,
163}
164
165/// Line style options (re-exported from network module)
166#[derive(Debug, Clone, PartialEq, Serialize)]
167pub enum LineStyle {
168    /// Solid line
169    Solid,
170    /// Dashed line
171    Dashed,
172    /// Dotted line
173    Dotted,
174    /// Dash-dot line
175    DashDot,
176}
177
178/// Marker configuration for data points
179#[derive(Debug, Clone, Serialize)]
180pub struct MarkerConfig {
181    /// Show markers
182    pub show: bool,
183    /// Marker shape
184    pub shape: MarkerShape,
185    /// Marker size
186    pub size: f32,
187    /// Marker fill color
188    pub fill_color: String,
189    /// Marker border color
190    pub border_color: String,
191}
192
193/// Marker shape options
194#[derive(Debug, Clone, PartialEq, Serialize)]
195pub enum MarkerShape {
196    /// Circle marker
197    Circle,
198    /// Square marker
199    Square,
200    /// Triangle marker
201    Triangle,
202    /// Diamond marker
203    Diamond,
204    /// Cross marker
205    Cross,
206    /// Plus marker
207    Plus,
208}
209
210/// Plot type options
211#[derive(Debug, Clone, PartialEq, Serialize)]
212pub enum PlotType {
213    /// Line plot
214    Line,
215    /// Scatter plot
216    Scatter,
217    /// Bar plot
218    Bar,
219    /// Area plot
220    Area,
221    /// Histogram
222    Histogram,
223    /// Box plot
224    Box,
225    /// Heatmap
226    Heatmap,
227}
228
229/// Update mode for plots
230#[derive(Debug, Clone, PartialEq, Serialize)]
231pub enum UpdateMode {
232    /// Append new data
233    Append,
234    /// Replace all data
235    Replace,
236    /// Rolling window
237    Rolling(usize),
238}
239
240// Implementation for TrainingVisualizer
241impl<
242        F: Float + Debug + NumAssign + 'static + scirs2_core::numeric::FromPrimitive + Send + Sync,
243    > TrainingVisualizer<F>
244{
245    /// Create a new training visualizer
246    pub fn new(config: VisualizationConfig) -> Self {
247        Self {
248            metrics_history: Vec::new(),
249            config,
250            active_plots: HashMap::new(),
251        }
252    }
253    /// Add training metrics for visualization
254    pub fn add_metrics(&mut self, metrics: TrainingMetrics<F>) {
255        self.metrics_history.push(metrics);
256        // Apply downsampling if needed
257        if self.metrics_history.len() > self.config.performance.max_points_per_plot
258            && self.config.performance.enable_downsampling
259        {
260            self.downsample_metrics();
261        }
262    }
263
264    /// Generate training curves visualization
265    pub fn visualize_training_curves(&self) -> Result<Vec<PathBuf>> {
266        let mut output_files = Vec::new();
267        // Generate loss curves
268        if let Some(loss_plot) = self.create_loss_plot()? {
269            let loss_path = self.config.output_dir.join("training_loss.html");
270            fs::write(&loss_path, loss_plot)
271                .map_err(|e| NeuralError::IOError(format!("Failed to write loss plot: {}", e)))?;
272            output_files.push(loss_path);
273        }
274
275        // Generate accuracy curves
276        if let Some(accuracy_plot) = self.create_accuracy_plot()? {
277            let accuracy_path = self.config.output_dir.join("training_accuracy.html");
278            fs::write(&accuracy_path, accuracy_plot).map_err(|e| {
279                NeuralError::IOError(format!("Failed to write accuracy plot: {}", e))
280            })?;
281            output_files.push(accuracy_path);
282        }
283
284        // Generate learning rate plot
285        if let Some(lr_plot) = self.create_learning_rate_plot()? {
286            let lr_path = self.config.output_dir.join("learning_rate.html");
287            fs::write(&lr_path, lr_plot).map_err(|e| {
288                NeuralError::IOError(format!("Failed to write learning rate plot: {}", e))
289            })?;
290            output_files.push(lr_path);
291        }
292
293        // Generate system metrics plot
294        if let Some(system_plot) = self.create_system_metrics_plot()? {
295            let system_path = self.config.output_dir.join("system_metrics.html");
296            fs::write(&system_path, system_plot).map_err(|e| {
297                NeuralError::IOError(format!("Failed to write system metrics plot: {}", e))
298            })?;
299            output_files.push(system_path);
300        }
301
302        Ok(output_files)
303    }
304
305    /// Get the current metrics history
306    pub fn get_metrics_history(&self) -> &[TrainingMetrics<F>] {
307        &self.metrics_history
308    }
309
310    /// Clear the metrics history
311    pub fn clear_history(&mut self) {
312        self.metrics_history.clear();
313    }
314    /// Add a custom plot configuration
315    pub fn add_plot(&mut self, name: String, config: PlotConfig) {
316        self.active_plots.insert(name, config);
317    }
318
319    /// Remove a plot configuration
320    pub fn remove_plot(&mut self, name: &str) -> Option<PlotConfig> {
321        self.active_plots.remove(name)
322    }
323
324    /// Update the visualization configuration
325    pub fn update_config(&mut self, config: VisualizationConfig) {
326        self.config = config;
327    }
328
329    fn downsample_metrics(&mut self) {
330        // Implement downsampling based on strategy
331        if self.metrics_history.len() <= self.config.performance.max_points_per_plot {
332            return; // No downsampling needed
333        }
334
335        match self.config.performance.downsampling_strategy {
336            DownsamplingStrategy::Uniform => {
337                // Keep every nth point
338                let step = self.metrics_history.len() / self.config.performance.max_points_per_plot;
339                if step > 1 {
340                    let mut downsampled = Vec::new();
341                    for (i, metric) in self.metrics_history.iter().enumerate() {
342                        if i % step == 0 {
343                            downsampled.push(metric.clone());
344                        }
345                    }
346                    self.metrics_history = downsampled;
347                }
348            }
349            DownsamplingStrategy::LTTB => {
350                // Largest Triangle Three Bucket algorithm - simplified implementation
351                self.downsample_lttb();
352            }
353            DownsamplingStrategy::MinMax => {
354                // Min-max decimation - keep local minima and maxima
355                self.downsample_minmax();
356            }
357            DownsamplingStrategy::Statistical => {
358                // Statistical sampling - sample based on variance/importance
359                self.downsample_statistical();
360            }
361        }
362    }
363
364    /// Largest Triangle Three Bucket (LTTB) downsampling algorithm
365    fn downsample_lttb(&mut self) {
366        let target_points = self.config.performance.max_points_per_plot;
367        if self.metrics_history.len() <= target_points {
368            return;
369        }
370        let bucket_size = self.metrics_history.len() as f64 / target_points as f64;
371        let mut downsampled = Vec::new();
372        // Always keep first point
373        downsampled.push(self.metrics_history[0].clone());
374        // For each bucket, select the point that forms the largest triangle
375        for bucket in 1..(target_points - 1) {
376            let bucket_start = (bucket as f64 * bucket_size) as usize;
377            let bucket_end =
378                ((bucket + 1) as f64 * bucket_size).min(self.metrics_history.len() as f64) as usize;
379            // Calculate average point of next bucket
380            let next_bucket_start = bucket_end;
381            let next_bucket_end =
382                ((bucket + 2) as f64 * bucket_size).min(self.metrics_history.len() as f64) as usize;
383            let avg_epoch = if next_bucket_end > next_bucket_start {
384                let sum: usize = (next_bucket_start..next_bucket_end)
385                    .map(|i| self.metrics_history[i].epoch)
386                    .sum();
387                sum as f64 / (next_bucket_end - next_bucket_start) as f64
388            } else {
389                self.metrics_history[self.metrics_history.len() - 1].epoch as f64
390            };
391            // Find point in current bucket that maximizes triangle area
392            let mut max_area = 0.0f64;
393            let mut selected_idx = bucket_start;
394            let prev_epoch = downsampled.last().expect("Operation failed").epoch as f64;
395            for i in bucket_start..bucket_end {
396                let curr_epoch = self.metrics_history[i].epoch as f64;
397                // Calculate triangle area (simplified - using epoch as primary metric)
398                let area = ((prev_epoch - avg_epoch) * (curr_epoch - prev_epoch)).abs();
399                if area > max_area {
400                    max_area = area;
401                    selected_idx = i;
402                }
403            }
404
405            downsampled.push(self.metrics_history[selected_idx].clone());
406        }
407
408        // Always keep last point
409        downsampled.push(self.metrics_history[self.metrics_history.len() - 1].clone());
410        self.metrics_history = downsampled;
411    }
412
413    /// Min-max decimation downsampling
414    fn downsample_minmax(&mut self) {
415        let target_points = self.config.performance.max_points_per_plot;
416        if self.metrics_history.len() <= target_points {
417            return;
418        }
419
420        let mut downsampled = Vec::new();
421        let bucket_size = self.metrics_history.len() / (target_points / 2); // Divide by 2 because we keep min and max
422        if bucket_size == 0 {
423            return;
424        }
425
426        for chunk in self.metrics_history.chunks(bucket_size) {
427            if chunk.is_empty() {
428                continue;
429            }
430
431            // Find min and max based on a primary loss metric
432            let mut min_metric = &chunk[0];
433            let mut max_metric = &chunk[0];
434
435            for metric in chunk {
436                // Use first loss value as comparison metric, or epoch if no losses
437                let current_value = metric
438                    .losses
439                    .values()
440                    .next()
441                    .map(|v| v.to_f64().unwrap_or(0.0))
442                    .unwrap_or(metric.epoch as f64);
443                let min_value = min_metric
444                    .losses
445                    .values()
446                    .next()
447                    .map(|v| v.to_f64().unwrap_or(0.0))
448                    .unwrap_or(min_metric.epoch as f64);
449                let max_value = max_metric
450                    .losses
451                    .values()
452                    .next()
453                    .map(|v| v.to_f64().unwrap_or(0.0))
454                    .unwrap_or(max_metric.epoch as f64);
455
456                if current_value < min_value {
457                    min_metric = metric;
458                }
459                if current_value > max_value {
460                    max_metric = metric;
461                }
462            }
463
464            // Add min and max (avoid duplicates)
465            if min_metric.epoch <= max_metric.epoch {
466                downsampled.push(min_metric.clone());
467                if min_metric.epoch != max_metric.epoch {
468                    downsampled.push(max_metric.clone());
469                }
470            } else {
471                downsampled.push(max_metric.clone());
472            }
473        }
474        // Sort by epoch to maintain temporal order
475        downsampled.sort_by_key(|m| m.epoch);
476
477        // If still too many points, apply uniform sampling
478        if downsampled.len() > target_points {
479            let step = downsampled.len() / target_points;
480            let mut final_downsampled = Vec::new();
481            for (i, metric) in downsampled.iter().enumerate() {
482                if i % step == 0 {
483                    final_downsampled.push(metric.clone());
484                }
485            }
486            self.metrics_history = final_downsampled;
487        } else {
488            self.metrics_history = downsampled;
489        }
490    }
491
492    /// Statistical downsampling based on variance and importance
493    fn downsample_statistical(&mut self) {
494        let target_points = self.config.performance.max_points_per_plot;
495        if self.metrics_history.len() <= target_points {
496            return;
497        }
498
499        let mut downsampled = Vec::new();
500        // Calculate importance scores for each point
501        let mut importance_scores: Vec<(usize, f64)> = Vec::new();
502        for (i, metric) in self.metrics_history.iter().enumerate() {
503            let mut score = 0.0f64;
504            // Base importance: changes in loss values
505            if i > 0 && i < self.metrics_history.len() - 1 {
506                let prev_metric = &self.metrics_history[i - 1];
507                let next_metric = &self.metrics_history[i + 1];
508                // Calculate variance in loss values
509                for (loss_name, &loss_value) in &metric.losses {
510                    if let (Some(&prev_loss), Some(&next_loss)) = (
511                        prev_metric.losses.get(loss_name),
512                        next_metric.losses.get(loss_name),
513                    ) {
514                        let prev_val = prev_loss.to_f64().unwrap_or(0.0);
515                        let curr_val = loss_value.to_f64().unwrap_or(0.0);
516                        let next_val = next_loss.to_f64().unwrap_or(0.0);
517                        // Second derivative (curvature) as importance measure
518                        let curvature = ((next_val - curr_val) - (curr_val - prev_val)).abs();
519                        score += curvature;
520                    }
521                }
522            }
523
524            // Always keep first and last points
525            if i == 0 || i == self.metrics_history.len() - 1 {
526                score += 1000.0; // High importance
527            }
528
529            importance_scores.push((i, score));
530        }
531        // Sort by importance score (descending)
532        importance_scores
533            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
534
535        // Select top points and sort by original index to maintain temporal order
536        let mut selected_indices: Vec<usize> = importance_scores
537            .iter()
538            .take(target_points)
539            .map(|(idx, _)| *idx)
540            .collect();
541        selected_indices.sort();
542
543        for &idx in &selected_indices {
544            downsampled.push(self.metrics_history[idx].clone());
545        }
546
547        self.metrics_history = downsampled;
548    }
549
550    fn create_loss_plot(&self) -> Result<Option<String>> {
551        if self.metrics_history.is_empty() {
552            return Ok(None);
553        }
554
555        // Extract loss data from metrics history
556        let mut loss_data = std::collections::HashMap::new();
557        let mut epochs = Vec::new();
558
559        for metric in &self.metrics_history {
560            epochs.push(metric.epoch);
561            for (loss_name, loss_value) in &metric.losses {
562                loss_data
563                    .entry(loss_name.clone())
564                    .or_insert_with(Vec::new)
565                    .push(loss_value.to_f64().unwrap_or(0.0));
566            }
567        }
568
569        if loss_data.is_empty() {
570            return Ok(None);
571        }
572
573        // Generate HTML with Plotly.js
574        let mut traces = Vec::new();
575        let colors = [
576            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
577        ];
578
579        for (i, (loss_name, values)) in loss_data.iter().enumerate() {
580            let color = colors[i % colors.len()];
581            let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
582            let values_json = serde_json::to_string(values).unwrap_or_default();
583
584            traces.push(format!(
585                r#"{{
586                    x: {},
587                    y: {},
588                    type: 'scatter',
589                    mode: 'lines+markers',
590                    name: '{}',
591                    line: {{ color: '{}', width: 2 }},
592                    marker: {{ size: 6, color: '{}' }}
593                }}"#,
594                epochs_json, values_json, loss_name, color, color
595            ));
596        }
597
598        let traces_str = traces.join(",\n            ");
599        let plot_html = format!(
600            r#"
601<!DOCTYPE html>
602<html>
603<head>
604    <title>Training Loss</title>
605    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
606    <style>
607        body {{ font-family: Arial, sans-serif; margin: 20px; }}
608        .plot-container {{ width: 100%; height: 600px; }}
609    </style>
610</head>
611<body>
612    <h2>Training Loss Curves</h2>
613    <div id="lossPlot" class="plot-container"></div>
614    <script>
615        var traces = [
616            {}
617        
618        var layout = {{
619            title: {{
620                text: 'Training Loss Over Time',
621                font: {{ size: 18 }}
622            }},
623            xaxis: {{ 
624                title: 'Epoch',
625                showgrid: true,
626                gridcolor: '#e0e0e0'
627            yaxis: {{ 
628                title: 'Loss',
629            hovermode: 'x unified',
630            legend: {{
631                x: 1,
632                y: 1,
633                bgcolor: 'rgba(255,255,255,0.8)',
634                bordercolor: '#000',
635                borderwidth: 1
636            plot_bgcolor: '#ffffff',
637            paper_bgcolor: '#ffffff'
638        }};
639        var config = {{
640            responsive: true,
641            displayModeBar: true,
642            modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
643        }};
644
645        Plotly.newPlot('lossPlot', traces, layout, config);
646    </script>
647</body>
648</html>"#,
649            traces_str
650        );
651
652        Ok(Some(plot_html))
653    }
654
655    fn create_accuracy_plot(&self) -> Result<Option<String>> {
656        if self.metrics_history.is_empty() {
657            return Ok(None);
658        }
659
660        // Extract accuracy data from metrics history
661        let mut accuracy_data = std::collections::HashMap::new();
662        let mut epochs = Vec::new();
663
664        for metric in &self.metrics_history {
665            epochs.push(metric.epoch);
666            for (acc_name, acc_value) in &metric.accuracies {
667                accuracy_data
668                    .entry(acc_name.clone())
669                    .or_insert_with(Vec::new)
670                    .push(acc_value.to_f64().unwrap_or(0.0));
671            }
672        }
673
674        if accuracy_data.is_empty() {
675            return Ok(None);
676        }
677
678        // Generate HTML with Plotly.js
679        let mut traces = Vec::new();
680        let colors = [
681            "#2ca02c", "#ff7f0e", "#1f77b4", "#d62728", "#9467bd", "#8c564b",
682        ];
683
684        for (i, (acc_name, values)) in accuracy_data.iter().enumerate() {
685            let color = colors[i % colors.len()];
686            let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
687            let values_json = serde_json::to_string(values).unwrap_or_default();
688
689            traces.push(format!(
690                r#"{{
691                    x: {},
692                    y: {},
693                    type: 'scatter',
694                    mode: 'lines+markers',
695                    name: '{}',
696                    line: {{ color: '{}', width: 2 }},
697                    marker: {{ size: 6, color: '{}' }}
698                }}"#,
699                epochs_json, values_json, acc_name, color, color
700            ));
701        }
702
703        let traces_str = traces.join(",\n            ");
704
705        let plot_html = format!(
706            r#"
707<!DOCTYPE html>
708<html>
709<head>
710    <title>Training Accuracy</title>
711    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
712    <style>
713        body {{ font-family: Arial, sans-serif; margin: 20px; }}
714        .plot-container {{ width: 100%; height: 600px; }}
715    </style>
716</head>
717<body>
718    <h2>Training Accuracy Curves</h2>
719    <div id="accuracyPlot" class="plot-container"></div>
720    <script>
721        var traces = [
722            {}
723        ];
724
725        var layout = {{
726            title: {{
727                text: "Training Accuracy Over Time",
728                font: {{ size: 18 }}
729            }},
730            xaxis: {{
731                title: 'Epoch',
732                showgrid: true,
733                gridcolor: '#e0e0e0'
734            }},
735            yaxis: {{
736                title: 'Accuracy',
737                showgrid: true,
738                gridcolor: '#e0e0e0',
739                range: [0, 1]
740            }},
741            hovermode: 'x unified',
742            legend: {{
743                x: 1,
744                y: 0,
745                bgcolor: 'rgba(255,255,255,0.8)',
746                bordercolor: '#000',
747                borderwidth: 1
748            }},
749            plot_bgcolor: '#ffffff',
750            paper_bgcolor: '#ffffff'
751        }};
752
753        var config = {{
754            responsive: true,
755            displayModeBar: true,
756            modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
757        }};
758
759        Plotly.newPlot('accuracyPlot', traces, layout, config);
760    </script>
761</body>
762</html>"#,
763            traces_str
764        );
765
766        Ok(Some(plot_html))
767    }
768    fn create_learning_rate_plot(&self) -> Result<Option<String>> {
769        if self.metrics_history.is_empty() {
770            return Ok(None);
771        }
772
773        // Extract learning rate data from metrics history
774        let mut learning_rates = Vec::new();
775        let mut epochs = Vec::new();
776
777        for metric in &self.metrics_history {
778            epochs.push(metric.epoch);
779            learning_rates.push(metric.learning_rate.to_f64().unwrap_or(0.0));
780        }
781
782        if learning_rates.is_empty() {
783            return Ok(None);
784        }
785
786        let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
787        let lr_json = serde_json::to_string(&learning_rates).unwrap_or_default();
788
789        let plot_html = format!(
790            r#"
791<!DOCTYPE html>
792<html>
793<head>
794    <title>Learning Rate Schedule</title>
795    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
796    <style>
797        body {{ font-family: Arial, sans-serif; margin: 20px; }}
798        .plot-container {{ width: 100%; height: 600px; }}
799    </style>
800</head>
801<body>
802    <h2>Learning Rate Schedule</h2>
803    <div id="lrPlot" class="plot-container"></div>
804    <script>
805        var trace = {{
806            x: {},
807            y: {},
808            type: 'scatter',
809            mode: 'lines+markers',
810            name: 'Learning Rate',
811            line: {{ color: '#d62728', width: 3 }},
812            marker: {{ size: 8, color: '#d62728' }}
813        }};
814
815        var layout = {{
816            title: {{
817                text: "Learning Rate Over Time",
818                font: {{ size: 18 }}
819            }},
820            xaxis: {{
821                title: 'Epoch',
822                showgrid: true,
823                gridcolor: '#e0e0e0'
824            }},
825            yaxis: {{
826                title: 'Learning Rate',
827                showgrid: true,
828                gridcolor: '#e0e0e0',
829                type: 'log'
830            }},
831            hovermode: 'x unified',
832            legend: {{
833                x: 1,
834                y: 1,
835                bgcolor: 'rgba(255,255,255,0.8)',
836                bordercolor: '#000',
837                borderwidth: 1
838            }},
839            plot_bgcolor: '#ffffff',
840            paper_bgcolor: '#ffffff'
841        }};
842
843        var config = {{
844            responsive: true,
845            displayModeBar: true,
846            modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
847        }};
848
849        Plotly.newPlot('lrPlot', [trace], layout, config);
850    </script>
851</body>
852</html>"#,
853            epochs_json, lr_json
854        );
855
856        Ok(Some(plot_html))
857    }
858    fn create_system_metrics_plot(&self) -> Result<Option<String>> {
859        if self.metrics_history.is_empty() {
860            return Ok(None);
861        }
862
863        // Extract system metrics from history
864        let mut memory_usage = Vec::new();
865        let mut cpu_utilization = Vec::new();
866        let mut gpu_utilization = Vec::new();
867        let mut samples_per_second = Vec::new();
868        let mut epochs = Vec::new();
869
870        for metric in &self.metrics_history {
871            epochs.push(metric.epoch);
872            memory_usage.push(metric.system_metrics.memory_usage_mb);
873            cpu_utilization.push(metric.system_metrics.cpu_utilization);
874            if let Some(gpu_util) = metric.system_metrics.gpu_utilization {
875                gpu_utilization.push(gpu_util);
876            }
877            samples_per_second.push(metric.system_metrics.samples_per_second);
878        }
879
880        let epochs_json = serde_json::to_string(&epochs).unwrap_or_default();
881        let memory_json = serde_json::to_string(&memory_usage).unwrap_or_default();
882        let cpu_json = serde_json::to_string(&cpu_utilization).unwrap_or_default();
883        let gpu_json = if !gpu_utilization.is_empty() {
884            serde_json::to_string(&gpu_utilization).unwrap_or_default()
885        } else {
886            "[]".to_string()
887        };
888        let sps_json = serde_json::to_string(&samples_per_second).unwrap_or_default();
889
890        let plot_html = format!(
891            r#"
892<!DOCTYPE html>
893<html>
894<head>
895    <title>System Metrics</title>
896    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
897    <style>
898        body {{ font-family: Arial, sans-serif; margin: 20px; }}
899        .plot-container {{ width: 100%; height: 400px; margin-bottom: 20px; }}
900    </style>
901</head>
902<body>
903    <h2>System Performance Metrics</h2>
904
905    <h3>Memory Usage</h3>
906    <div id="memoryPlot" class="plot-container"></div>
907
908    <h3>CPU & GPU Utilization</h3>
909    <div id="utilizationPlot" class="plot-container"></div>
910
911    <h3>Training Throughput</h3>
912    <div id="throughputPlot" class="plot-container"></div>
913
914    <script>
915        var epochs = {};
916
917        // Memory usage plot
918        var memoryTrace = {{
919            x: epochs,
920            y: {},
921            type: 'scatter',
922            mode: 'lines+markers',
923            name: 'Memory Usage (MB)',
924            line: {{ color: '#ff7f0e', width: 2 }},
925            marker: {{ size: 6, color: '#ff7f0e' }}
926        }};
927
928        var memoryLayout = {{
929            title: "Memory Usage Over Time",
930            xaxis: {{ title: 'Epoch' }},
931            yaxis: {{ title: 'Memory (MB)' }},
932            showlegend: false
933        }};
934
935        var config = {{
936            responsive: true,
937            displayModeBar: true,
938            modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d']
939        }};
940
941        Plotly.newPlot('memoryPlot', [memoryTrace], memoryLayout, config);
942
943        // CPU and GPU utilization plot
944        var traces = [{{
945            x: epochs,
946            y: {},
947            type: 'scatter',
948            mode: 'lines+markers',
949            name: 'CPU Utilization (%)',
950            line: {{ color: '#1f77b4', width: 2 }},
951            marker: {{ size: 6, color: '#1f77b4' }}
952        }}];
953
954        if ({}.length > 0) {{
955            traces.push({{
956                x: epochs,
957                y: {},
958                type: 'scatter',
959                mode: 'lines+markers',
960                name: 'GPU Utilization (%)',
961                line: {{ color: '#2ca02c', width: 2 }},
962                marker: {{ size: 6, color: '#2ca02c' }}
963            }});
964        }}
965
966        var utilizationLayout = {{
967            title: "CPU & GPU Utilization",
968            xaxis: {{ title: 'Epoch' }},
969            yaxis: {{ title: 'Utilization (%)', range: [0, 100] }}
970        }};
971
972        Plotly.newPlot('utilizationPlot', traces, utilizationLayout, config);
973
974        // Throughput plot
975        var throughputTrace = {{
976            x: epochs,
977            y: {},
978            type: 'scatter',
979            mode: 'lines+markers',
980            name: 'Samples/Second',
981            line: {{ color: '#9467bd', width: 2 }},
982            marker: {{ size: 6, color: '#9467bd' }}
983        }};
984
985        var throughputLayout = {{
986            title: "Training Throughput",
987            xaxis: {{ title: 'Epoch' }},
988            yaxis: {{ title: 'Samples per Second' }},
989            showlegend: false
990        }};
991
992        Plotly.newPlot('throughputPlot', [throughputTrace], throughputLayout, config);
993    </script>
994</body>
995</html>"#,
996            epochs_json, memory_json, cpu_json, gpu_json, gpu_json, sps_json
997        );
998
999        Ok(Some(plot_html))
1000    }
1001}
1002
1003// Default implementations for configuration types
1004impl Default for PlotConfig {
1005    fn default() -> Self {
1006        Self {
1007            title: "Training Metrics".to_string(),
1008            x_axis: AxisConfig::default(),
1009            y_axis: AxisConfig::default(),
1010            series: Vec::new(),
1011            plot_type: PlotType::Line,
1012            update_mode: UpdateMode::Append,
1013        }
1014    }
1015}
1016
1017impl Default for AxisConfig {
1018    fn default() -> Self {
1019        Self {
1020            label: "".to_string(),
1021            scale: AxisScale::Linear,
1022            range: None,
1023            show_grid: true,
1024            ticks: TickConfig::default(),
1025        }
1026    }
1027}
1028
1029impl Default for TickConfig {
1030    fn default() -> Self {
1031        Self {
1032            interval: None,
1033            format: TickFormat::Auto,
1034            show_labels: true,
1035            rotation: 0.0,
1036        }
1037    }
1038}
1039
1040impl Default for SeriesConfig {
1041    fn default() -> Self {
1042        Self {
1043            name: "Series".to_string(),
1044            data_source: "".to_string(),
1045            style: LineStyleConfig::default(),
1046            markers: MarkerConfig::default(),
1047            color: "#1f77b4".to_string(), // Default blue
1048            opacity: 1.0,
1049        }
1050    }
1051}
1052
1053impl Default for LineStyleConfig {
1054    fn default() -> Self {
1055        Self {
1056            style: LineStyle::Solid,
1057            width: 2.0,
1058            smoothing: false,
1059            smoothing_window: 5,
1060        }
1061    }
1062}
1063
1064impl Default for MarkerConfig {
1065    fn default() -> Self {
1066        Self {
1067            show: false,
1068            shape: MarkerShape::Circle,
1069            size: 6.0,
1070            fill_color: "#1f77b4".to_string(),
1071            border_color: "#1f77b4".to_string(),
1072        }
1073    }
1074}
1075
1076impl Default for SystemMetrics {
1077    fn default() -> Self {
1078        Self {
1079            memory_usage_mb: 0.0,
1080            gpu_memory_mb: None,
1081            cpu_utilization: 0.0,
1082            gpu_utilization: None,
1083            step_duration_ms: 0.0,
1084            samples_per_second: 0.0,
1085        }
1086    }
1087}
#[cfg(test)]
mod tests {
    //! Unit tests for the training visualizer and its configuration types.

    use super::*;

    /// Builds one representative metrics record (epoch 1, step 100) with a
    /// single loss, a single accuracy, and default system metrics.
    fn sample_metrics() -> TrainingMetrics<f32> {
        TrainingMetrics {
            epoch: 1,
            step: 100,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            losses: HashMap::from([("train_loss".to_string(), 0.5)]),
            accuracies: HashMap::from([("train_acc".to_string(), 0.8)]),
            learning_rate: 0.001,
            custom_metrics: HashMap::new(),
            system_metrics: SystemMetrics::default(),
        }
    }

    /// A freshly constructed visualizer has no history and no plots.
    #[test]
    fn test_training_visualizer_creation() {
        let config = VisualizationConfig::default();
        let visualizer = TrainingVisualizer::<f32>::new(config);
        assert!(visualizer.metrics_history.is_empty());
        assert!(visualizer.active_plots.is_empty());
    }

    /// Adding one metrics record grows the history to exactly one entry.
    #[test]
    fn test_add_metrics() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);
        visualizer.add_metrics(sample_metrics());
        assert_eq!(visualizer.metrics_history.len(), 1);
    }

    /// `PlotConfig::default()` yields the expected title, type, and mode.
    #[test]
    fn test_plot_config_defaults() {
        let config = PlotConfig::default();
        assert_eq!(config.title, "Training Metrics");
        assert_eq!(config.plot_type, PlotType::Line);
        assert_eq!(config.update_mode, UpdateMode::Append);
    }

    /// Axis scale variants compare equal to themselves and the custom
    /// variant carries its name.
    #[test]
    fn test_axis_scale_variants() {
        assert_eq!(AxisScale::Linear, AxisScale::Linear);
        assert_eq!(AxisScale::Log, AxisScale::Log);
        assert_eq!(AxisScale::Sqrt, AxisScale::Sqrt);
        let custom = AxisScale::Custom("symlog".to_string());
        match custom {
            AxisScale::Custom(name) => assert_eq!(name, "symlog"),
            _ => panic!("Expected custom scale"),
        }
    }

    /// All six marker shapes can be constructed and compared.
    #[test]
    fn test_markershapes() {
        let shapes = [
            MarkerShape::Circle,
            MarkerShape::Square,
            MarkerShape::Triangle,
            MarkerShape::Diamond,
            MarkerShape::Cross,
            MarkerShape::Plus,
        ];
        assert_eq!(shapes.len(), 6);
        assert_eq!(shapes[0], MarkerShape::Circle);
    }

    /// All seven plot types can be constructed and compared.
    #[test]
    fn test_plot_types() {
        let types = [
            PlotType::Line,
            PlotType::Scatter,
            PlotType::Bar,
            PlotType::Area,
            PlotType::Histogram,
            PlotType::Box,
            PlotType::Heatmap,
        ];
        assert_eq!(types.len(), 7);
        assert_eq!(types[0], PlotType::Line);
    }

    /// Update modes compare equal to themselves and the rolling variant
    /// carries its window size.
    #[test]
    fn test_update_modes() {
        let append = UpdateMode::Append;
        let replace = UpdateMode::Replace;
        let rolling = UpdateMode::Rolling(100);
        assert_eq!(append, UpdateMode::Append);
        assert_eq!(replace, UpdateMode::Replace);
        match rolling {
            UpdateMode::Rolling(size) => assert_eq!(size, 100),
            _ => panic!("Expected rolling update mode"),
        }
    }

    /// Clearing the history removes previously recorded metrics.
    #[test]
    fn test_clear_history() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);

        // Populate the history first; clearing an already-empty history would
        // pass even if `clear_history` did nothing.
        visualizer.add_metrics(sample_metrics());
        assert_eq!(visualizer.metrics_history.len(), 1);

        visualizer.clear_history();
        assert!(visualizer.metrics_history.is_empty());
    }

    /// A plot can be registered under a name and later removed by that name.
    #[test]
    fn test_plot_management() {
        let config = VisualizationConfig::default();
        let mut visualizer = TrainingVisualizer::<f32>::new(config);
        let plot_config = PlotConfig::default();
        visualizer.add_plot("test_plot".to_string(), plot_config);
        assert!(visualizer.active_plots.contains_key("test_plot"));
        let removed = visualizer.remove_plot("test_plot");
        assert!(removed.is_some());
        assert!(!visualizer.active_plots.contains_key("test_plot"));
    }
}