use crate::callbacks::{Callback, CallbackContext, CallbackTiming};
use crate::error::Result;
use crate::utils::visualization::{float_vec_to_f64, plot_metrics, PlotOptions};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::PathBuf;
pub struct VisualizationCallback<F: Float + Debug + ScalarOperand> {
pub frequency: usize,
pub show_plots: bool,
pub save_path: Option<PathBuf>,
pub tracked_metrics: Vec<String>,
pub plot_options: PlotOptions,
epoch_history: HashMap<String, Vec<F>>,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> VisualizationCallback<F> {
    /// Creates a callback that draws an intermediate plot every `frequency` epochs.
    ///
    /// Defaults:
    /// - plots are shown;
    /// - nothing is saved;
    /// - `train_loss` and `val_loss` are tracked;
    /// - `PlotOptions::default()` is used;
    /// - history starts empty.
    pub fn new(frequency: usize) -> Self {
        let default_metrics: Vec<String> = ["train_loss", "val_loss"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        Self {
            frequency,
            show_plots: true,
            save_path: None,
            tracked_metrics: default_metrics,
            plot_options: PlotOptions::default(),
            epoch_history: HashMap::new(),
        }
    }

    /// Sets the file that receives the final plot after training.
    pub fn with_save_path<P: Into<PathBuf>>(self, path: P) -> Self {
        Self {
            save_path: Some(path.into()),
            ..self
        }
    }

    /// Enables or disables printing plots to stdout.
    pub fn with_show_plots(self, show_plots: bool) -> Self {
        Self { show_plots, ..self }
    }

    /// Replaces the list of metric names to record.
    ///
    /// The first two names are treated as the loss series.
    pub fn with_tracked_metrics(self, metrics: Vec<String>) -> Self {
        Self {
            tracked_metrics: metrics,
            ..self
        }
    }

    /// Replaces the options passed to the plotting routine.
    pub fn with_plot_options(self, options: PlotOptions) -> Self {
        Self {
            plot_options: options,
            ..self
        }
    }

    /// Converts the recorded history to `f64` series for plotting.
    ///
    /// Any series that `float_vec_to_f64` fails to convert is omitted from the result.
    fn history_as_f64(&self) -> HashMap<String, Vec<f64>> {
        let mut converted = HashMap::new();
        for (name, series) in &self.epoch_history {
            if let Ok(as_f64) = float_vec_to_f64(series) {
                converted.insert(name.clone(), as_f64);
            }
        }
        converted
    }
}
impl<F: Float + Debug + ScalarOperand + std::fmt::Display + NumAssign> Callback<F>
    for VisualizationCallback<F>
{
    /// Records losses and metrics and renders plots at the relevant points of training.
    ///
    /// * `BeforeTraining`: discards any history left over from a previous run. It then
    ///   seeds an empty series for every tracked metric.
    /// * `AfterEpoch`: appends `epoch_loss` / `val_loss` to the `train_loss` / `val_loss`
    ///   series, but only when those series exist. It also appends every entry of
    ///   `context.metrics`. Finally it prints a plot every `frequency` epochs when
    ///   `show_plots` is set.
    /// * `AfterTraining`: renders the final plot once. It prints the plot and/or writes the
    ///   same rendering to `save_path`.
    ///
    /// Visualization is best-effort and this method always returns `Ok(())`.
    /// - Render failures are ignored for display.
    /// - Render failures are reported on stderr when a save was requested.
    /// - File-write failures are reported on stderr.
    fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        match timing {
            CallbackTiming::BeforeTraining => {
                // Reset fully. Re-seeding the tracked keys alone would leave stale
                // auto-named `metric_{i}` series behind when the callback is reused.
                self.epoch_history.clear();
                for metric in &self.tracked_metrics {
                    self.epoch_history.insert(metric.clone(), Vec::new());
                }
            }
            CallbackTiming::AfterEpoch => {
                // Losses are only recorded into series that already exist, i.e. when the
                // loss keys are tracked and `BeforeTraining` has seeded them.
                if let Some(train_loss) = context.epoch_loss {
                    if let Some(values) = self.epoch_history.get_mut("train_loss") {
                        values.push(train_loss);
                    }
                }
                if let Some(val_loss) = context.val_loss {
                    if let Some(values) = self.epoch_history.get_mut("val_loss") {
                        values.push(val_loss);
                    }
                }

                // The first two tracked names are reserved for the loss series above. The
                // remaining names label `context.metrics` by position; unnamed extras
                // fall back to `metric_{i}`.
                const LOSS_SLOTS: usize = 2;
                for (i, &metric_value) in context.metrics.iter().enumerate() {
                    let metric_name = self
                        .tracked_metrics
                        .get(i + LOSS_SLOTS)
                        .cloned()
                        .unwrap_or_else(|| format!("metric_{}", i));
                    self.epoch_history
                        .entry(metric_name)
                        .or_default()
                        .push(metric_value);
                }

                if context.epoch.is_multiple_of(self.frequency)
                    && self.show_plots
                    && !self.epoch_history.is_empty()
                {
                    if let Ok(plot) = plot_metrics(
                        &self.history_as_f64(),
                        Some("Training Metrics"),
                        Some(self.plot_options.clone()),
                    ) {
                        println!("\n{plot}");
                    }
                }
            }
            CallbackTiming::AfterTraining => {
                let display = self.show_plots && !self.epoch_history.is_empty();
                if display || self.save_path.is_some() {
                    // Render once. The displayed and saved outputs are then identical, and
                    // the history conversion is not repeated.
                    let rendered = plot_metrics(
                        &self.history_as_f64(),
                        Some("Final Training Metrics"),
                        Some(self.plot_options.clone()),
                    );
                    match rendered {
                        Ok(plot) => {
                            if display {
                                println!("\n{plot}");
                            }
                            if let Some(save_path) = &self.save_path {
                                if let Err(e) = std::fs::write(save_path, &plot) {
                                    eprintln!(
                                        "Failed to save plot to {}: {}",
                                        save_path.display(),
                                        e
                                    );
                                } else {
                                    println!("Plot saved to {}", save_path.display());
                                }
                            }
                        }
                        Err(_) => {
                            // A requested save must not vanish silently; plain display
                            // stays best-effort as before.
                            if let Some(save_path) = &self.save_path {
                                eprintln!("Failed to render plot for {}", save_path.display());
                            }
                        }
                    }
                }
            }
            _ => {}
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned name list from string literals.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| name.to_string()).collect()
    }

    /// `new` yields the documented defaults; builder methods override them.
    #[test]
    fn test_visualization_callback_creation() {
        // Defaults straight from the constructor.
        let default_cb = VisualizationCallback::<f32>::new(1);
        assert_eq!(default_cb.frequency, 1);
        assert!(default_cb.show_plots);
        assert!(default_cb.save_path.is_none());
        assert_eq!(default_cb.tracked_metrics, names(&["train_loss", "val_loss"]));
        assert!(default_cb.epoch_history.is_empty());

        // Every builder applied on top of the defaults.
        let extended = names(&["train_loss", "val_loss", "accuracy"]);
        let configured = VisualizationCallback::<f32>::new(2)
            .with_save_path("test.txt")
            .with_show_plots(false)
            .with_tracked_metrics(extended.clone());
        assert_eq!(configured.frequency, 2);
        assert!(!configured.show_plots);
        assert!(configured.save_path.is_some());
        assert_eq!(configured.tracked_metrics, extended);
    }
}