use crate::visualization::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
use std::collections::HashMap;
use std::error::Error;
/// Visualizer for metric values recorded over the course of training.
///
/// Each entry of `history` is one record (by default labelled as an epoch)
/// mapping metric names to values. An optional, separately supplied list of
/// validation records can be attached with [`with_validation`].
///
/// [`with_validation`]: TrainingHistoryVisualizer::with_validation
#[derive(Debug, Clone)]
pub struct TrainingHistoryVisualizer {
    /// Title given to the resulting plot.
    title: String,
    /// Names of the metrics to look up in every history record.
    metrics: Vec<String>,
    /// Training records, one map of metric name to value per entry.
    history: Vec<HashMap<String, f64>>,
    /// Validation records in the same shape as `history`, if any were supplied.
    val_history: Option<Vec<HashMap<String, f64>>>,
    /// Label for the x axis; `"Epoch"` unless overridden.
    x_label: String,
    /// Explicit label for the y axis, if one was set.
    y_label: Option<String>,
}

impl TrainingHistoryVisualizer {
    /// Creates a visualizer for `metrics` over the given training `history`.
    ///
    /// The x-axis label starts out as `"Epoch"`; no validation history and no
    /// explicit y-axis label are set.
    #[must_use]
    pub fn new(
        title: impl Into<String>,
        metrics: Vec<String>,
        history: Vec<HashMap<String, f64>>,
    ) -> Self {
        Self {
            title: title.into(),
            metrics,
            history,
            val_history: None,
            x_label: String::from("Epoch"),
            y_label: None,
        }
    }

    /// Attaches validation records, replacing any previously attached ones.
    #[must_use]
    pub fn with_validation(self, valhistory: Vec<HashMap<String, f64>>) -> Self {
        Self {
            val_history: Some(valhistory),
            ..self
        }
    }

    /// Overrides the x-axis label.
    #[must_use]
    pub fn with_x_label(self, xlabel: impl Into<String>) -> Self {
        Self {
            x_label: xlabel.into(),
            ..self
        }
    }

    /// Sets an explicit y-axis label.
    #[must_use]
    pub fn with_y_label(self, ylabel: impl Into<String>) -> Self {
        Self {
            y_label: Some(ylabel.into()),
            ..self
        }
    }
}
impl MetricVisualizer for TrainingHistoryVisualizer {
fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
let mut data = VisualizationData::new();
let epochs: Vec<f64> = (0..self.history.len()).map(|i| i as f64).collect();
data.add_series("epochs", epochs.clone());
for metric_name in &self.metrics {
let metric_values: Vec<f64> = self
.history
.iter()
.map(|epoch_data| *epoch_data.get(metric_name).unwrap_or(&f64::NAN))
.collect();
data.add_series(metric_name.clone(), metric_values);
if let Some(val_history) = &self.val_history {
let val_metric_values: Vec<f64> = val_history
.iter()
.map(|epoch_data| *epoch_data.get(metric_name).unwrap_or(&f64::NAN))
.collect();
data.add_series(format!("val_{}", metric_name), val_metric_values);
}
}
Ok(data)
}
fn get_metadata(&self) -> VisualizationMetadata {
let mut metadata = VisualizationMetadata::new(self.title.clone());
metadata.set_plot_type(PlotType::Line);
metadata.set_x_label(self.x_label.clone());
if let Some(y_label) = &self.y_label {
metadata.set_y_label(y_label.clone());
} else if self.metrics.len() == 1 {
metadata.set_y_label(self.metrics[0].clone());
}
metadata
}
}
/// Convenience constructor returning a boxed [`TrainingHistoryVisualizer`].
///
/// The plot title is `Training History (<names>)`, where `<names>` is
/// `metric_names` joined with `", "`. When `val_history` is `Some`, the
/// validation records are attached to the visualizer before it is boxed.
#[allow(dead_code)]
pub fn training_history_visualization(
    metric_names: Vec<String>,
    history: Vec<HashMap<String, f64>>,
    val_history: Option<Vec<HashMap<String, f64>>>,
) -> Box<dyn MetricVisualizer> {
    // Build the title first: `metric_names` is moved into the visualizer next.
    let title = format!("Training History ({})", metric_names.join(", "));
    let base = TrainingHistoryVisualizer::new(title, metric_names, history);

    let visualizer = match val_history {
        Some(validation) => base.with_validation(validation),
        None => base,
    };

    Box::new(visualizer)
}