//! Visualization tools for debugging
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Plot type for visualization.
///
/// NOTE(review): not referenced by the methods in this part of the file;
/// confirm where (or whether) it is consumed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlotType {
/// Line plot.
Line,
/// Scatter plot.
Scatter,
/// Bar chart.
Bar,
/// Histogram.
Histogram,
/// Heatmap.
Heatmap,
/// Three-dimensional plot.
ThreeDimensional,
}
/// Configuration for visualizations.
///
/// The `Default` implementation writes PNG files into `./debug_plots` on an
/// 800x600 canvas with font size 12 and the default color scheme.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
/// Directory every rendered file is written into; created on demand
/// (including missing parents) by `DebugVisualizer::ensure_output_directory`.
pub output_directory: String,
/// Output format; selects the writer and the file extension used for output.
pub image_format: ImageFormat,
/// Canvas width in pixels handed to the plotters backends.
pub plot_width: u32,
/// Canvas height in pixels handed to the plotters backends.
pub plot_height: u32,
/// Font size used for chart captions.
pub font_size: u32,
/// Color scheme preference.
/// NOTE(review): not read by any rendering code in this section — confirm it
/// is honored elsewhere.
pub color_scheme: ColorScheme,
}
impl Default for VisualizationConfig {
    /// PNG output into `./debug_plots` on an 800x600 canvas, font size 12,
    /// using the default color scheme.
    fn default() -> Self {
        let (plot_width, plot_height) = (800, 600);
        Self {
            color_scheme: ColorScheme::Default,
            font_size: 12,
            image_format: ImageFormat::PNG,
            output_directory: String::from("./debug_plots"),
            plot_height,
            plot_width,
        }
    }
}
/// Output format for rendered plots.
///
/// `DebugVisualizer::save_plot` dispatches on this value and
/// `DebugVisualizer::format_extension` maps each variant to its file
/// extension. This type does not derive `PartialEq`, so compare variants with
/// `match`/`matches!`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageFormat {
/// Raster image, extension `png`.
PNG,
/// Vector image, extension `svg`.
SVG,
/// Document output, extension `pdf`.
PDF,
/// HTML output, extension `html`.
HTML,
/// LaTeX source, extension `tex`.
LaTeX,
/// JSON data, extension `json`.
JSON,
/// MP4 video format for animated visualizations
MP4,
/// GIF format for animated visualizations
GIF,
/// WebM video format for web-compatible animations
WebM,
}
/// Color palette preference stored in `VisualizationConfig::color_scheme`.
///
/// NOTE(review): none of the rendering code in this section reads it; the 3D
/// renderers use hard-coded plotters colors. Confirm where it is applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColorScheme {
Default,
Dark,
Colorblind,
Viridis,
Plasma,
}
/// Visualization types.
///
/// NOTE(review): the plotting methods in this section choose their renderer
/// directly and do not consult this enum; confirm where it is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VisualizationType {
/// Line plot for time series data
LinePlot,
/// Histogram for distribution analysis
Histogram,
/// Heatmap for 2D tensor visualization
Heatmap,
/// Scatter plot for correlation analysis
ScatterPlot,
/// Box plot for statistical summaries
BoxPlot,
/// 3D surface plot for advanced visualization
SurfacePlot,
/// 3D loss landscape visualization
LossLandscape,
/// 3D optimization trajectory
OptimizationTrajectory,
/// 3D weight space exploration
WeightSpaceExploration,
/// 3D embedding projections
EmbeddingProjection,
/// Architecture diagram
ArchitectureDiagram,
}
/// Plot data structure for a single 2D series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlotData {
/// X coordinates (training steps, or histogram bin centers).
pub x_values: Vec<f64>,
/// Y values, paired positionally with `x_values`. The CSV export zips the
/// two vectors, so the shorter one determines the row count.
pub y_values: Vec<f64>,
/// Series labels.
pub labels: Vec<String>,
/// Chart title.
pub title: String,
/// X-axis label; also the first CSV header column.
pub x_label: String,
/// Y-axis label; also the second CSV header column.
pub y_label: String,
}
/// 3D Plot data structure.
///
/// The three coordinate vectors are paired positionally. The plotters-based
/// renderers pass them to `build_cartesian_3d` in `(x, z, y)` order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plot3DData {
/// X coordinates.
pub x_values: Vec<f64>,
/// Y coordinates.
pub y_values: Vec<f64>,
/// Z coordinates (loss, for the landscape and trajectory plots).
pub z_values: Vec<f64>,
/// Chart title.
pub title: String,
/// X-axis label.
pub x_label: String,
/// Y-axis label.
pub y_label: String,
/// Z-axis label.
pub z_label: String,
/// Optional per-point labels (step names, layer names, class labels).
pub point_labels: Vec<String>,
/// Optional per-point color values. The 3D scatter renderer truncates each
/// value to an integer and takes it modulo its palette size.
pub color_values: Option<Vec<f64>>,
/// Optional per-point sizes. The 3D scatter renderer doubles each value to
/// get the marker size.
pub size_values: Option<Vec<f64>>,
}
/// Architecture node for diagram visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureNode {
/// Unique id; `create_architecture_diagram` uses the layer name.
pub id: String,
/// Display name (the layer name).
pub name: String,
/// Layer type string, e.g. "linear" or "attention".
pub node_type: String,
/// `(x, y, z)` position; layers are stacked along y.
pub position: (f64, f64, f64),
/// `(x, y, z)` extent derived from the layer's parameter count.
pub size: (f64, f64, f64),
/// Hex color string chosen from the layer type.
pub color: String,
/// Free-form metadata; `create_architecture_diagram` fills `"shape"` and `"parameters"`.
pub metadata: HashMap<String, String>,
}
/// Architecture connection for diagram visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureConnection {
/// Id of the source node.
pub from_node: String,
/// Id of the destination node.
pub to_node: String,
/// Edge kind; `create_architecture_diagram` always sets `"forward"`.
pub connection_type: String,
/// Edge weight; its sign and magnitude drive `color`.
pub weight: f64,
/// CSS `rgba(...)` color string (green for positive weights, red otherwise).
pub color: String,
/// Free-form metadata (left empty by `create_architecture_diagram`).
pub metadata: HashMap<String, String>,
}
/// Visualization system for debugging data.
///
/// Holds the output configuration plus in-memory registries of everything it
/// has plotted. Plots can later be re-saved, exported, or listed on a dashboard.
#[derive(Debug)]
pub struct DebugVisualizer {
/// Output settings shared by every render call.
config: VisualizationConfig,
/// Registered 2D plots, keyed by plot name (e.g. `"{tensor}_distribution"`).
plots: HashMap<String, PlotData>,
/// Registered 3D plots, keyed by plot name (e.g. `"loss_landscape"`).
plots_3d: HashMap<String, Plot3DData>,
/// Nodes of the most recent architecture diagram, keyed by layer name.
architecture_nodes: HashMap<String, ArchitectureNode>,
/// Edges of the most recent architecture diagram, keyed `"conn_{index}"`.
architecture_connections: HashMap<String, ArchitectureConnection>,
}
impl DebugVisualizer {
/// Create a new debug visualizer that writes output according to `config`.
///
/// The 2D-plot, 3D-plot and architecture registries all start empty.
/// Construction performs no I/O.
pub fn new(config: VisualizationConfig) -> Self {
    let empty_plots = HashMap::new();
    let empty_plots_3d = HashMap::new();
    Self {
        architecture_connections: HashMap::new(),
        architecture_nodes: HashMap::new(),
        plots_3d: empty_plots_3d,
        plots: empty_plots,
        config,
    }
}
/// Create the output directory if it doesn't exist
pub fn ensure_output_directory(&self) -> Result<()> {
std::fs::create_dir_all(&self.config.output_directory)?;
Ok(())
}
/// Save the registered plot `plot_id` to `<output_directory>/<filename>`,
/// using the configured [`ImageFormat`].
///
/// Static formats are delegated to the matching `save_plot_*` writer. For
/// the animated formats (MP4, GIF, WebM), the single plot is first expanded
/// into a simple animation sequence. That sequence is then handed to a
/// default `VideoGenerator`.
///
/// # Errors
/// Fails in any of these cases:
/// - `plot_id` is not registered.
/// - The output directory cannot be created.
/// - The format writer fails.
/// - For animated formats only, the output path is not valid UTF-8. Their
///   generator takes a `&str` path.
pub fn save_plot(&self, plot_id: &str, filename: &str) -> Result<()> {
    let plot = self
        .plots
        .get(plot_id)
        .ok_or_else(|| anyhow::anyhow!("Plot with id '{}' not found", plot_id))?;
    self.ensure_output_directory()?;
    let file_path = Path::new(&self.config.output_directory).join(filename);
    // The video generators take `&str`. Convert on demand so the static formats
    // still accept non-UTF-8 paths, and report a failed conversion as an error
    // instead of panicking.
    let video_path = || {
        file_path.to_str().ok_or_else(|| {
            anyhow::anyhow!("Output path '{}' is not valid UTF-8", file_path.display())
        })
    };
    match self.config.image_format {
        ImageFormat::PNG => self.save_plot_png(plot, &file_path),
        ImageFormat::SVG => self.save_plot_svg(plot, &file_path),
        ImageFormat::PDF => self.save_plot_pdf(plot, &file_path),
        ImageFormat::HTML => self.save_plot_html(plot, &file_path),
        ImageFormat::LaTeX => self.save_plot_latex(plot, &file_path),
        ImageFormat::JSON => self.save_plot_json(plot, &file_path),
        ImageFormat::MP4 => {
            // For a single plot, create a simple animation sequence.
            let video_generator = VideoGenerator::default();
            let sequence = self.create_simple_animation_sequence(plot)?;
            video_generator.generate_mp4(&sequence, video_path()?)
        }
        ImageFormat::GIF => {
            let video_generator = VideoGenerator::default();
            let sequence = self.create_simple_animation_sequence(plot)?;
            video_generator.generate_gif(&sequence, video_path()?)
        }
        ImageFormat::WebM => {
            let video_generator = VideoGenerator::default();
            let sequence = self.create_simple_animation_sequence(plot)?;
            video_generator.generate_webm(&sequence, video_path()?)
        }
    }
}
/// Export the registered plot `plot_id` once per supported format.
///
/// Each file is written to `<output_directory>/<base_filename>.<ext>`. The
/// extension is applied with `Path::with_extension`, so a trailing `.suffix`
/// in `base_filename` is replaced rather than kept.
///
/// # Errors
/// Stops at, and returns, the first failure. Possible failures:
/// - The plot is missing.
/// - The directory cannot be created.
/// - A writer fails.
/// - A video format's output path is not valid UTF-8.
pub fn export_plot_all_formats(&self, plot_id: &str, base_filename: &str) -> Result<()> {
    let plot = self
        .plots
        .get(plot_id)
        .ok_or_else(|| anyhow::anyhow!("Plot with id '{}' not found", plot_id))?;
    self.ensure_output_directory()?;
    let base_path = Path::new(&self.config.output_directory).join(base_filename);
    // Export to all formats
    let formats = [
        (ImageFormat::PNG, "png"),
        (ImageFormat::SVG, "svg"),
        (ImageFormat::PDF, "pdf"),
        (ImageFormat::HTML, "html"),
        (ImageFormat::LaTeX, "tex"),
        (ImageFormat::JSON, "json"),
        (ImageFormat::MP4, "mp4"),
        (ImageFormat::GIF, "gif"),
        (ImageFormat::WebM, "webm"),
    ];
    for (format, extension) in &formats {
        let file_path = base_path.with_extension(extension);
        // Video generators take `&str`; report a non-UTF-8 path instead of panicking.
        let video_path = || {
            file_path.to_str().ok_or_else(|| {
                anyhow::anyhow!("Output path '{}' is not valid UTF-8", file_path.display())
            })
        };
        match format {
            ImageFormat::PNG => self.save_plot_png(plot, &file_path)?,
            ImageFormat::SVG => self.save_plot_svg(plot, &file_path)?,
            ImageFormat::PDF => self.save_plot_pdf(plot, &file_path)?,
            ImageFormat::HTML => self.save_plot_html(plot, &file_path)?,
            ImageFormat::LaTeX => self.save_plot_latex(plot, &file_path)?,
            ImageFormat::JSON => self.save_plot_json(plot, &file_path)?,
            ImageFormat::MP4 => {
                let video_generator = VideoGenerator::default();
                let sequence = self.create_simple_animation_sequence(plot)?;
                video_generator.generate_mp4(&sequence, video_path()?)?;
            }
            ImageFormat::GIF => {
                let video_generator = VideoGenerator::default();
                let sequence = self.create_simple_animation_sequence(plot)?;
                video_generator.generate_gif(&sequence, video_path()?)?;
            }
            ImageFormat::WebM => {
                let video_generator = VideoGenerator::default();
                let sequence = self.create_simple_animation_sequence(plot)?;
                video_generator.generate_webm(&sequence, video_path()?)?;
            }
        }
    }
    Ok(())
}
/// Plot tensor distribution as histogram
pub fn plot_tensor_distribution(
&mut self,
tensor_name: &str,
values: &[f64],
num_bins: usize,
) -> Result<String> {
self.ensure_output_directory()?;
// Create histogram data
let (min_val, max_val) =
values.iter().fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
(min.min(x), max.max(x))
});
let bin_width = (max_val - min_val) / num_bins as f64;
let mut bins = vec![0; num_bins];
let mut bin_centers = Vec::new();
for i in 0..num_bins {
bin_centers.push(min_val + (i as f64 + 0.5) * bin_width);
}
for &value in values {
if value.is_finite() {
let bin_idx = ((value - min_val) / bin_width).floor() as usize;
let bin_idx = bin_idx.min(num_bins - 1);
bins[bin_idx] += 1;
}
}
let plot_data = PlotData {
x_values: bin_centers,
y_values: bins.into_iter().map(|x| x as f64).collect(),
labels: vec!["Frequency".to_string()],
title: format!("Distribution of {}", tensor_name),
x_label: "Value".to_string(),
y_label: "Frequency".to_string(),
};
let plot_name = format!("{}_distribution", tensor_name);
self.plots.insert(plot_name.clone(), plot_data);
self.render_histogram(&plot_name)
}
/// Plot the gradient norm of `layer_name` against the training step.
///
/// `steps[i]` is paired with `gradient_norms[i]`. The plot is registered as
/// `"{layer_name}_gradient_flow"` and rendered as a line plot. The rendered
/// file path is returned.
pub fn plot_gradient_flow(
    &mut self,
    layer_name: &str,
    steps: &[usize],
    gradient_norms: &[f64],
) -> Result<String> {
    self.ensure_output_directory()?;
    let x_values: Vec<f64> = steps.iter().map(|&step| step as f64).collect();
    let plot_data = PlotData {
        title: format!("Gradient Flow - {}", layer_name),
        labels: vec![format!("{} Gradient Norm", layer_name)],
        x_label: String::from("Training Step"),
        y_label: String::from("Gradient Norm"),
        x_values,
        y_values: Vec::from(gradient_norms),
    };
    let plot_name = format!("{}_gradient_flow", layer_name);
    self.plots.insert(plot_name.clone(), plot_data);
    self.render_line_plot(&plot_name)
}
/// Plot training metrics against the training step.
///
/// The loss curve is registered as `"training_metrics"` (title
/// "Training Metrics") and its rendered file path is returned. `PlotData`
/// holds a single y-series, so a supplied `accuracies` series cannot share
/// that plot. It is registered and rendered as a separate
/// `"training_accuracy"` plot rather than being silently dropped.
///
/// # Errors
/// Fails if the output directory cannot be created or if either plot fails to render.
pub fn plot_training_metrics(
    &mut self,
    steps: &[usize],
    losses: &[f64],
    accuracies: Option<&[f64]>,
) -> Result<String> {
    self.ensure_output_directory()?;
    let step_axis = || -> Vec<f64> { steps.iter().map(|&step| step as f64).collect() };
    let loss_plot = PlotData {
        x_values: step_axis(),
        y_values: losses.to_vec(),
        labels: vec!["Loss".to_string()],
        title: "Training Metrics".to_string(),
        x_label: "Training Step".to_string(),
        y_label: "Value".to_string(),
    };
    let plot_name = "training_metrics".to_string();
    self.plots.insert(plot_name.clone(), loss_plot);
    let loss_file = self.render_line_plot(&plot_name)?;
    if let Some(acc) = accuracies {
        let accuracy_plot = PlotData {
            x_values: step_axis(),
            y_values: acc.to_vec(),
            labels: vec!["Accuracy".to_string()],
            title: "Training Accuracy".to_string(),
            x_label: "Training Step".to_string(),
            y_label: "Accuracy".to_string(),
        };
        let accuracy_name = "training_accuracy".to_string();
        self.plots.insert(accuracy_name.clone(), accuracy_plot);
        self.render_line_plot(&accuracy_name)?;
    }
    Ok(loss_file)
}
/// Plot a 2D tensor as a heatmap.
///
/// Output is targeted at `"{output_directory}/{tensor_name}_heatmap.{ext}"`,
/// and that path is returned. `tensor_2d` is treated as a list of rows.
pub fn plot_tensor_heatmap(
    &mut self,
    tensor_name: &str,
    tensor_2d: &[Vec<f64>],
) -> Result<String> {
    self.ensure_output_directory()?;
    let extension = self.format_extension();
    let target = format!(
        "{}/{}_heatmap.{}",
        self.config.output_directory, tensor_name, extension
    );
    let title = format!("Heatmap - {}", tensor_name);
    self.render_heatmap(tensor_2d, &target, &title)?;
    Ok(target)
}
/// Plot a correlation matrix between layers.
///
/// `layer_names` label the rows and columns of `correlation_matrix`. Output
/// is targeted at `"{output_directory}/correlation_matrix.{ext}"`, and that
/// path is returned.
pub fn plot_correlation_matrix(
    &mut self,
    layer_names: &[String],
    correlation_matrix: &[Vec<f64>],
) -> Result<String> {
    self.ensure_output_directory()?;
    let output_file = [
        self.config.output_directory.as_str(),
        "/correlation_matrix.",
        self.format_extension(),
    ]
    .concat();
    self.render_correlation_heatmap(correlation_matrix, layer_names, &output_file)?;
    Ok(output_file)
}
/// Plot the mean activation of `layer_name` against the training step.
///
/// The plot is registered as `"{layer_name}_activations"`. `std_activations`
/// is forwarded as error bars to `render_line_plot_with_error_bars`.
/// NOTE(review): the downstream writer currently ignores the error bars.
pub fn plot_activation_patterns(
    &mut self,
    layer_name: &str,
    steps: &[usize],
    mean_activations: &[f64],
    std_activations: &[f64],
) -> Result<String> {
    self.ensure_output_directory()?;
    let plot_name = format!("{}_activations", layer_name);
    let plot_data = PlotData {
        title: format!("Activation Patterns - {}", layer_name),
        labels: vec![format!("{} Mean Activation", layer_name)],
        x_label: String::from("Training Step"),
        y_label: String::from("Mean Activation"),
        x_values: steps.iter().map(|&step| step as f64).collect(),
        y_values: Vec::from(mean_activations),
    };
    self.plots.insert(plot_name.clone(), plot_data);
    self.render_line_plot_with_error_bars(&plot_name, std_activations)
}
/// Write `debug_dashboard.html` into the output directory.
///
/// The dashboard embeds one `<img>` per name in `plot_names` that is
/// registered as a 2D plot; unregistered names are skipped. Image sources
/// are `"{plot_name}.{ext}"`, relative to the dashboard file.
///
/// Plot names come from caller-supplied tensor and layer names. They are
/// HTML-escaped before being embedded, so characters such as `<` or `"`
/// cannot break the markup.
///
/// Returns the dashboard path.
pub fn create_dashboard(&self, plot_names: &[String]) -> Result<String> {
    /// Escape the characters that are significant in HTML text and attribute values.
    fn escape_html(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#39;"),
                _ => out.push(ch),
            }
        }
        out
    }

    self.ensure_output_directory()?;
    let dashboard_path = format!("{}/debug_dashboard.html", self.config.output_directory);
    let mut html_content = String::new();
    // HTML header
    html_content.push_str(&format!(
        r#"
<!DOCTYPE html>
<html>
<head>
<title>TrustformeRS Debug Dashboard</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.plot-container {{ margin: 20px 0; text-align: center; }}
.plot-container img {{ max-width: 100%; height: auto; border: 1px solid #ddd; }}
h1 {{ color: #333; }}
h2 {{ color: #666; }}
</style>
</head>
<body>
<h1>TrustformeRS Debug Dashboard</h1>
<p>Generated on: {}</p>
"#,
        chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
    ));
    // Add plots
    for plot_name in plot_names.iter().filter(|name| self.plots.contains_key(name.as_str())) {
        let escaped_name = escape_html(plot_name);
        let plot_path = escape_html(&format!("{}.{}", plot_name, self.format_extension()));
        html_content.push_str(&format!(
            r#"
<div class="plot-container">
<h2>{}</h2>
<img src="{}" alt="{}">
</div>
"#,
            escaped_name, plot_path, escaped_name
        ));
    }
    // HTML footer
    html_content.push_str("</body>\n</html>");
    std::fs::write(&dashboard_path, html_content)?;
    Ok(dashboard_path)
}
/// Serialize the registered 2D plot `plot_name` as pretty-printed JSON to `path`.
///
/// # Errors
/// Fails if the plot is not registered, serialization fails, or the file
/// cannot be written.
pub fn export_plot_data(&self, plot_name: &str, path: &str) -> Result<()> {
    let plot_data = self
        .plots
        .get(plot_name)
        .ok_or_else(|| anyhow::anyhow!("Plot '{}' not found", plot_name))?;
    std::fs::write(path, serde_json::to_string_pretty(plot_data)?)?;
    Ok(())
}
/// Names of all registered 2D plots, in the map's unspecified iteration order.
pub fn get_plot_names(&self) -> Vec<String> {
    self.plots.keys().map(ToOwned::to_owned).collect::<Vec<_>>()
}
/// Names of all registered 3D plots, in the map's unspecified iteration order.
pub fn get_3d_plot_names(&self) -> Vec<String> {
    self.plots_3d.keys().map(ToOwned::to_owned).collect::<Vec<_>>()
}
/// Plot a sampled loss landscape in 3D.
///
/// `x_params[i]`, `y_params[i]` and `loss_values[i]` form one point, and each
/// point is colored by its loss. `param_names` labels the x and y axes. The
/// plot is registered as `"loss_landscape"` and rendered with
/// `render_3d_surface`.
pub fn plot_loss_landscape(
    &mut self,
    x_params: &[f64],
    y_params: &[f64],
    loss_values: &[f64],
    param_names: (&str, &str),
) -> Result<String> {
    self.ensure_output_directory()?;
    let (x_name, y_name) = param_names;
    let plot_data = Plot3DData {
        title: String::from("Loss Landscape"),
        x_label: x_name.to_owned(),
        y_label: y_name.to_owned(),
        z_label: String::from("Loss"),
        x_values: Vec::from(x_params),
        y_values: Vec::from(y_params),
        z_values: Vec::from(loss_values),
        color_values: Some(Vec::from(loss_values)),
        point_labels: Vec::new(),
        size_values: None,
    };
    let plot_name = String::from("loss_landscape");
    self.plots_3d.insert(plot_name.clone(), plot_data);
    self.render_3d_surface(&plot_name)
}
/// Plot an optimization trajectory as a path through 3D space.
///
/// Each point's color value is its index (0 for the first step), and
/// `step_labels` become the per-point labels. The plot is registered as
/// `"{trajectory_name}_trajectory"`. Because `render_3d_trajectory` appends
/// `_trajectory` again, the file name ends in `_trajectory_trajectory.{ext}`.
pub fn plot_optimization_trajectory(
    &mut self,
    trajectory_points: &[(f64, f64, f64)],
    step_labels: &[String],
    trajectory_name: &str,
) -> Result<String> {
    self.ensure_output_directory()?;
    let point_count = trajectory_points.len();
    let plot_data = Plot3DData {
        x_values: trajectory_points.iter().map(|p| p.0).collect(),
        y_values: trajectory_points.iter().map(|p| p.1).collect(),
        z_values: trajectory_points.iter().map(|p| p.2).collect(),
        title: format!("Optimization Trajectory - {}", trajectory_name),
        x_label: String::from("Parameter 1"),
        y_label: String::from("Parameter 2"),
        z_label: String::from("Loss"),
        point_labels: step_labels.to_owned(),
        color_values: Some((0..point_count).map(|step| step as f64).collect()),
        size_values: None,
    };
    let plot_name = format!("{}_trajectory", trajectory_name);
    self.plots_3d.insert(plot_name.clone(), plot_data);
    self.render_3d_trajectory(&plot_name)
}
/// Project `weight_vectors` to 3D with `reduction_method` and plot them as a 3D scatter.
///
/// `reduction_method` is dispatched by `reduce_dimensions` ("pca", "tsne",
/// otherwise a random projection). It also becomes part of the plot name,
/// `"weight_space_{reduction_method}"`. Points are colored by label index.
///
/// # Errors
/// Fails if `weight_vectors` is empty; this check runs after the output
/// directory is created. Also fails if rendering fails.
pub fn plot_weight_space_exploration(
    &mut self,
    weight_vectors: &[Vec<f64>],
    layer_names: &[String],
    reduction_method: &str,
) -> Result<String> {
    self.ensure_output_directory()?;
    if weight_vectors.is_empty() {
        return Err(anyhow::anyhow!("No weight vectors provided"));
    }
    let reduced = self.reduce_dimensions(weight_vectors, reduction_method)?;
    let component = |axis: usize| -> Vec<f64> { reduced.iter().map(|point| point[axis]).collect() };
    let plot_data = Plot3DData {
        x_values: component(0),
        y_values: component(1),
        z_values: component(2),
        title: format!("Weight Space Exploration ({})", reduction_method),
        x_label: String::from("Component 1"),
        y_label: String::from("Component 2"),
        z_label: String::from("Component 3"),
        point_labels: layer_names.to_owned(),
        color_values: Some((0..layer_names.len()).map(|idx| idx as f64).collect()),
        size_values: None,
    };
    let plot_name = format!("weight_space_{}", reduction_method);
    self.plots_3d.insert(plot_name.clone(), plot_data);
    self.render_3d_scatter(&plot_name)
}
/// Project `embeddings` to 3D via the "tsne" reducer and plot them as a 3D scatter.
///
/// Points with the same label share a color; `assign_cluster_colors` assigns
/// the colors. The plot is registered as `"{embedding_name}_embedding"`.
///
/// # Errors
/// Fails if `embeddings` is empty; this check runs after the output directory
/// is created. Also fails if rendering fails.
pub fn plot_embedding_projection(
    &mut self,
    embeddings: &[Vec<f64>],
    labels: &[String],
    embedding_name: &str,
) -> Result<String> {
    self.ensure_output_directory()?;
    if embeddings.is_empty() {
        return Err(anyhow::anyhow!("No embeddings provided"));
    }
    let projected = self.reduce_dimensions(embeddings, "tsne")?;
    let component = |axis: usize| -> Vec<f64> { projected.iter().map(|point| point[axis]).collect() };
    let plot_data = Plot3DData {
        x_values: component(0),
        y_values: component(1),
        z_values: component(2),
        title: format!("Embedding Projection - {}", embedding_name),
        x_label: String::from("t-SNE 1"),
        y_label: String::from("t-SNE 2"),
        z_label: String::from("t-SNE 3"),
        point_labels: labels.to_owned(),
        color_values: Some(self.assign_cluster_colors(labels)),
        size_values: None,
    };
    let plot_name = format!("{}_embedding", embedding_name);
    self.plots_3d.insert(plot_name.clone(), plot_data);
    self.render_3d_scatter(&plot_name)
}
/// Build and render an architecture diagram from `layers` and `connections`.
///
/// Previously registered nodes and connections are discarded first.
///
/// Each layer becomes a node:
/// - Nodes are stacked along the y axis, 2.0 units apart, in input order.
/// - Size comes from `calculate_node_size`; color comes from `get_layer_color`.
/// - Metadata stores the shape under `"shape"` and the product of its
///   dimensions under `"parameters"`.
/// - A later layer with a duplicate name overwrites the earlier node.
///
/// Each connection is stored as a `"forward"` edge keyed `"conn_{index}"`.
///
/// Returns the path produced by `render_architecture_diagram`.
pub fn create_architecture_diagram(
    &mut self,
    layers: &[(&str, &str, Vec<usize>)],   // (name, type, shape)
    connections: &[(String, String, f64)], // (from, to, weight)
) -> Result<String> {
    self.ensure_output_directory()?;
    self.architecture_nodes.clear();
    self.architecture_connections.clear();
    for (index, (name, layer_type, shape)) in layers.iter().enumerate() {
        let mut metadata = HashMap::new();
        metadata.insert("shape".to_string(), format!("{:?}", shape));
        metadata.insert(
            "parameters".to_string(),
            shape.iter().product::<usize>().to_string(),
        );
        let node = ArchitectureNode {
            id: name.to_string(),
            name: name.to_string(),
            node_type: layer_type.to_string(),
            // Stack layers vertically, 2.0 units apart, in input order.
            position: (0.0, index as f64 * 2.0, 0.0),
            size: self.calculate_node_size(shape),
            color: self.get_layer_color(layer_type),
            metadata,
        };
        self.architecture_nodes.insert(name.to_string(), node);
    }
    for (index, (from, to, weight)) in connections.iter().enumerate() {
        let connection = ArchitectureConnection {
            from_node: from.clone(),
            to_node: to.clone(),
            connection_type: "forward".to_string(),
            weight: *weight,
            color: self.get_connection_color(*weight),
            metadata: HashMap::new(),
        };
        self.architecture_connections.insert(format!("conn_{}", index), connection);
    }
    self.render_architecture_diagram()
}
// Private rendering methods (simplified implementations)
/// Render the registered 2D plot `plot_name` as a line plot.
///
/// Always returns `"{output_directory}/{plot_name}.{ext}"`. If no such plot is
/// registered, nothing is written.
fn render_line_plot(&self, plot_name: &str) -> Result<String> {
    let target = format!(
        "{}/{}.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots
        .get(plot_name)
        .map(|data| self.create_line_plot_file(data, &target))
        .transpose()?;
    Ok(target)
}
/// Render the registered 2D plot `plot_name` as a histogram.
///
/// Always returns `"{output_directory}/{plot_name}.{ext}"`. If no such plot is
/// registered, nothing is written.
fn render_histogram(&self, plot_name: &str) -> Result<String> {
    let target = format!(
        "{}/{}.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots
        .get(plot_name)
        .map(|data| self.create_histogram_file(data, &target))
        .transpose()?;
    Ok(target)
}
/// Render the registered 2D plot `plot_name` as a line plot with `error_bars`.
///
/// Always returns `"{output_directory}/{plot_name}.{ext}"`. If no such plot is
/// registered, nothing is written.
fn render_line_plot_with_error_bars(
    &self,
    plot_name: &str,
    error_bars: &[f64],
) -> Result<String> {
    let target = format!(
        "{}/{}.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots
        .get(plot_name)
        .map(|data| self.create_line_plot_with_errors(data, error_bars, &target))
        .transpose()?;
    Ok(target)
}
fn render_heatmap(&self, data: &[Vec<f64>], filename: &str, title: &str) -> Result<()> {
// Simplified heatmap rendering
self.create_heatmap_file(data, filename, title)
}
fn render_correlation_heatmap(
&self,
data: &[Vec<f64>],
labels: &[String],
filename: &str,
) -> Result<()> {
// Simplified correlation heatmap
self.create_correlation_heatmap_file(data, labels, filename)
}
// Actual plotting implementations using plotters crate
fn create_line_plot_file(&self, plot_data: &PlotData, filename: &str) -> Result<()> {
#[cfg(feature = "visual")]
{
use plotters::prelude::*;
match self.config.image_format {
ImageFormat::PNG => {
let root = BitMapBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
self.create_line_plot_with_backend(&root, plot_data)?;
},
ImageFormat::SVG => {
let root = SVGBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
self.create_line_plot_with_backend(&root, plot_data)?;
},
_ => return self.create_fallback_line_plot(plot_data, filename),
};
}
#[cfg(not(feature = "visual"))]
{
self.create_fallback_line_plot(plot_data, filename)?;
}
// Also create CSV backup
let csv_data = self.create_csv_data(plot_data)?;
std::fs::write(filename.replace(&self.format_extension(), "csv"), csv_data)?;
Ok(())
}
fn create_histogram_file(&self, plot_data: &PlotData, filename: &str) -> Result<()> {
#[cfg(feature = "visual")]
{
use plotters::prelude::*;
match self.config.image_format {
ImageFormat::PNG => {
let root = BitMapBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
self.create_histogram_with_backend(&root, plot_data)?;
},
ImageFormat::SVG => {
let root = SVGBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
self.create_histogram_with_backend(&root, plot_data)?;
},
_ => return self.create_fallback_histogram(plot_data, filename),
};
}
#[cfg(not(feature = "visual"))]
{
self.create_fallback_histogram(plot_data, filename)?;
}
// Also create CSV backup
let csv_data = self.create_csv_data(plot_data)?;
std::fs::write(filename.replace(&self.format_extension(), "csv"), csv_data)?;
Ok(())
}
fn create_line_plot_with_errors(
&self,
plot_data: &PlotData,
_error_bars: &[f64],
filename: &str,
) -> Result<()> {
// Simplified implementation
self.create_line_plot_file(plot_data, filename)
}
fn create_heatmap_file(&self, data: &[Vec<f64>], filename: &str, title: &str) -> Result<()> {
let mut csv_content = String::new();
for row in data {
let row_str: Vec<String> = row.iter().map(|x| x.to_string()).collect();
csv_content.push_str(&row_str.join(","));
csv_content.push('\n');
}
std::fs::write(
filename.replace(&self.format_extension(), "csv"),
csv_content,
)?;
let text_info = format!(
"Heatmap: {}\nDimensions: {}x{}\n",
title,
data.len(),
data.first().map_or(0, |row| row.len())
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), text_info)?;
Ok(())
}
fn create_correlation_heatmap_file(
&self,
data: &[Vec<f64>],
labels: &[String],
filename: &str,
) -> Result<()> {
let mut csv_content = String::new();
// Header row
csv_content.push_str(&format!(",{}\n", labels.join(",")));
// Data rows with labels
for (i, row) in data.iter().enumerate() {
let row_str: Vec<String> = row.iter().map(|x| format!("{:.4}", x)).collect();
csv_content.push_str(&format!(
"{},{}\n",
labels.get(i).unwrap_or(&i.to_string()),
row_str.join(",")
));
}
std::fs::write(
filename.replace(&self.format_extension(), "csv"),
csv_content,
)?;
Ok(())
}
/// Serialize `plot_data` as two-column CSV.
///
/// The header is `x_label,y_label`, followed by one `x,y` line per paired
/// value. Pairing stops at the shorter of the two vectors.
fn create_csv_data(&self, plot_data: &PlotData) -> Result<String> {
    let header = format!("{},{}\n", plot_data.x_label, plot_data.y_label);
    let rows = plot_data.x_values.iter().zip(&plot_data.y_values);
    let csv = rows.fold(header, |mut acc, (x, y)| {
        acc.push_str(&format!("{},{}\n", x, y));
        acc
    });
    Ok(csv)
}
fn format_extension(&self) -> &str {
match self.config.image_format {
ImageFormat::PNG => "png",
ImageFormat::SVG => "svg",
ImageFormat::PDF => "pdf",
ImageFormat::HTML => "html",
ImageFormat::LaTeX => "tex",
ImageFormat::JSON => "json",
ImageFormat::MP4 => "mp4",
ImageFormat::GIF => "gif",
ImageFormat::WebM => "webm",
}
}
// 3D rendering methods
/// Render the registered 3D plot `plot_name` as a surface.
///
/// Always returns `"{output_directory}/{plot_name}_surface.{ext}"`. If no such
/// 3D plot is registered, nothing is written.
fn render_3d_surface(&self, plot_name: &str) -> Result<String> {
    let target = format!(
        "{}/{}_surface.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots_3d
        .get(plot_name)
        .map(|data| self.create_3d_surface_file(data, &target))
        .transpose()?;
    Ok(target)
}
/// Render the registered 3D plot `plot_name` as a trajectory.
///
/// Always returns `"{output_directory}/{plot_name}_trajectory.{ext}"`. If no
/// such 3D plot is registered, nothing is written.
fn render_3d_trajectory(&self, plot_name: &str) -> Result<String> {
    let target = format!(
        "{}/{}_trajectory.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots_3d
        .get(plot_name)
        .map(|data| self.create_3d_trajectory_file(data, &target))
        .transpose()?;
    Ok(target)
}
/// Render the registered 3D plot `plot_name` as a scatter plot.
///
/// Always returns `"{output_directory}/{plot_name}_scatter.{ext}"`. If no such
/// 3D plot is registered, nothing is written.
fn render_3d_scatter(&self, plot_name: &str) -> Result<String> {
    let target = format!(
        "{}/{}_scatter.{}",
        self.config.output_directory,
        plot_name,
        self.format_extension()
    );
    self.plots_3d
        .get(plot_name)
        .map(|data| self.create_3d_scatter_file(data, &target))
        .transpose()?;
    Ok(target)
}
/// Write the current architecture diagram.
///
/// Returns `"{output_directory}/architecture_diagram.{ext}"`; sibling files
/// are derived from that path.
fn render_architecture_diagram(&self) -> Result<String> {
    let target = [
        self.config.output_directory.as_str(),
        "/architecture_diagram.",
        self.format_extension(),
    ]
    .concat();
    self.create_architecture_diagram_file(&target).map(|()| target)
}
// 3D file creation methods
/// Render `plot_data` as a 3D point cloud at `filename`.
///
/// With the `visual` feature and PNG output, every point is drawn as a small
/// circle with plotters, with axes passed in `(x, z, y)` order. The fallback
/// writer then runs once as a backup (presumably CSV — confirm). In every
/// other configuration, `create_fallback_3d_surface` is the only writer and
/// is invoked exactly once.
fn create_3d_surface_file(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
    #[cfg(feature = "visual")]
    {
        match self.config.image_format {
            ImageFormat::PNG => {
                use plotters::prelude::*;
                let root = BitMapBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
                root.fill(&WHITE)?;
                let x_min = plot_data.x_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let x_max = plot_data.x_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let y_min = plot_data.y_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let y_max = plot_data.y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let z_min = plot_data.z_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let z_max = plot_data.z_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let mut chart = ChartBuilder::on(&root)
                    .caption(&plot_data.title, ("sans-serif", self.config.font_size))
                    .margin(5)
                    .build_cartesian_3d(x_min..x_max, z_min..z_max, y_min..y_max)?;
                chart.configure_axes().draw()?;
                let surface_data: Vec<_> = plot_data.x_values.iter()
                    .zip(plot_data.y_values.iter())
                    .zip(plot_data.z_values.iter())
                    .map(|((&x, &y), &z)| (x, z, y))
                    .collect();
                chart.draw_series(
                    surface_data.iter().map(|&(x, z, y)| {
                        Circle::new((x, z, y), 2, BLUE.filled())
                    })
                )?
                .label("Surface Points")
                .legend(|(x, y)| Circle::new((x, y), 2, BLUE.filled()));
                chart.configure_series_labels().draw()?;
                root.present()?;
            }
            _ => return self.create_fallback_3d_surface(plot_data, filename),
        }
    }
    // Single fallback call: the backup after a plotters render, and the only
    // writer when the `visual` feature is disabled.
    self.create_fallback_3d_surface(plot_data, filename)
}
/// Render `plot_data` as a 3D trajectory at `filename`.
///
/// With the `visual` feature and PNG output, the points are drawn with
/// plotters, with axes passed in `(x, z, y)` order. The start point is green,
/// the end point red, and points in between blue. Consecutive points are
/// joined by a single blue polyline. The fallback writer then runs once as a
/// backup. In every other configuration, `create_fallback_3d_trajectory` is
/// the only writer and is invoked exactly once.
fn create_3d_trajectory_file(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
    #[cfg(feature = "visual")]
    {
        match self.config.image_format {
            ImageFormat::PNG => {
                use plotters::prelude::*;
                let root = BitMapBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
                root.fill(&WHITE)?;
                let x_min = plot_data.x_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let x_max = plot_data.x_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let y_min = plot_data.y_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let y_max = plot_data.y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let z_min = plot_data.z_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let z_max = plot_data.z_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let mut chart = ChartBuilder::on(&root)
                    .caption(&plot_data.title, ("sans-serif", self.config.font_size))
                    .margin(5)
                    .build_cartesian_3d(x_min..x_max, z_min..z_max, y_min..y_max)?;
                chart.configure_axes().draw()?;
                let trajectory_data: Vec<_> = plot_data.x_values.iter()
                    .zip(plot_data.y_values.iter())
                    .zip(plot_data.z_values.iter())
                    .map(|((&x, &y), &z)| (x, z, y))
                    .collect();
                let last = trajectory_data.len().saturating_sub(1);
                // Draw points: start green, end red, everything else blue.
                chart.draw_series(
                    trajectory_data.iter().enumerate().map(|(i, &(x, z, y))| {
                        let color = if i == 0 { &GREEN } else if i == last { &RED } else { &BLUE };
                        Circle::new((x, z, y), 3, color.filled())
                    })
                )?;
                // Connect consecutive points with one polyline instead of one series per segment.
                if trajectory_data.len() > 1 {
                    chart.draw_series(std::iter::once(PathElement::new(trajectory_data, &BLUE)))?;
                }
                root.present()?;
            }
            _ => return self.create_fallback_3d_trajectory(plot_data, filename),
        }
    }
    // Single fallback call: the backup after a plotters render, and the only
    // writer when the `visual` feature is disabled.
    self.create_fallback_3d_trajectory(plot_data, filename)
}
/// Render `plot_data` as a 3D scatter plot at `filename`.
///
/// With the `visual` feature and PNG output, points are drawn with plotters,
/// with axes passed in `(x, z, y)` order.
/// - Color: each point's color value truncated to an integer, modulo the
///   5-color palette. Without color values, the point index is used instead.
/// - Size: twice its size value. Without size values, the size is 4.
///
/// The fallback writer then runs once as a backup. In every other
/// configuration, `create_fallback_3d_scatter` is the only writer and is
/// invoked exactly once.
fn create_3d_scatter_file(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
    #[cfg(feature = "visual")]
    {
        match self.config.image_format {
            ImageFormat::PNG => {
                use plotters::prelude::*;
                let root = BitMapBackend::new(filename, (self.config.plot_width, self.config.plot_height)).into_drawing_area();
                root.fill(&WHITE)?;
                let x_min = plot_data.x_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let x_max = plot_data.x_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let y_min = plot_data.y_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let y_max = plot_data.y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let z_min = plot_data.z_values.iter().cloned().fold(f64::INFINITY, f64::min);
                let z_max = plot_data.z_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let mut chart = ChartBuilder::on(&root)
                    .caption(&plot_data.title, ("sans-serif", self.config.font_size))
                    .margin(5)
                    .build_cartesian_3d(x_min..x_max, z_min..z_max, y_min..y_max)?;
                chart.configure_axes().draw()?;
                // Color-coded scatter points based on color values
                let colors = [&RED, &BLUE, &GREEN, &MAGENTA, &CYAN];
                chart.draw_series(
                    plot_data.x_values.iter()
                        .zip(plot_data.y_values.iter())
                        .zip(plot_data.z_values.iter())
                        .enumerate()
                        .map(|(i, ((&x, &y), &z))| {
                            let color_idx = if let Some(color_values) = &plot_data.color_values {
                                (*color_values.get(i).unwrap_or(&0.0) as usize) % colors.len()
                            } else {
                                i % colors.len()
                            };
                            let size = if let Some(size_values) = &plot_data.size_values {
                                (*size_values.get(i).unwrap_or(&2.0) * 2.0) as i32
                            } else {
                                4
                            };
                            Circle::new((x, z, y), size, colors[color_idx].filled())
                        })
                )?
                .label("Data Points")
                .legend(|(x, y)| Circle::new((x, y), 3, BLUE.filled()));
                chart.configure_series_labels().draw()?;
                root.present()?;
            }
            _ => return self.create_fallback_3d_scatter(plot_data, filename),
        }
    }
    // Single fallback call: the backup after a plotters render, and the only
    // writer when the `visual` feature is disabled.
    self.create_fallback_3d_scatter(plot_data, filename)
}
fn create_architecture_diagram_file(&self, filename: &str) -> Result<()> {
// Create JSON representation of architecture
let architecture = serde_json::json!({
"nodes": self.architecture_nodes,
"connections": self.architecture_connections
});
std::fs::write(
filename.replace(&self.format_extension(), "json"),
serde_json::to_string_pretty(&architecture)?,
)?;
// Create summary file
let summary = format!(
"Architecture Diagram\nNodes: {}\nConnections: {}\n",
self.architecture_nodes.len(),
self.architecture_connections.len()
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), summary)?;
Ok(())
}
// Helper methods
fn reduce_dimensions(&self, vectors: &[Vec<f64>], method: &str) -> Result<Vec<Vec<f64>>> {
if vectors.is_empty() {
return Ok(vec![]);
}
match method {
"pca" => self.pca_reduction(vectors),
"tsne" => self.tsne_reduction(vectors),
_ => self.random_projection(vectors),
}
}
/// Reduce each vector to 3 components by keeping its first three coordinates.
///
/// Vectors shorter than 3 are zero-padded. Despite the name, this is a
/// placeholder and not a real PCA: no covariance or projection is computed.
fn pca_reduction(&self, vectors: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let reduced = vectors
        .iter()
        .map(|vector| {
            let mut head: Vec<f64> = vector.iter().copied().take(3).collect();
            head.resize(3, 0.0);
            head
        })
        .collect();
    Ok(reduced)
}
/// Placeholder "t-SNE": not a real t-SNE embedding.
///
/// Vectors with at least 3 coordinates keep their first three, each offset by
/// the vector's index `i` scaled by 0.1, 0.05 and 0.02. Shorter vectors map
/// to `[i, 0.5 * i, 0.25 * i]`.
fn tsne_reduction(&self, vectors: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let projected = vectors
        .iter()
        .enumerate()
        .map(|(idx, vector)| {
            let t = idx as f64;
            match vector.as_slice() {
                [a, b, c, ..] => vec![a + t * 0.1, b + t * 0.05, c + t * 0.02],
                _ => vec![t, t * 0.5, t * 0.25],
            }
        })
        .collect();
    Ok(projected)
}
/// Deterministic stand-in for a random projection to 3D.
///
/// Each vector collapses to its coordinate sum `s`. The point becomes
/// `[0.5s + 0.1i, 0.3s + 0.05i, 0.2s + 0.02i]`, where `i` is the vector's index.
fn random_projection(&self, vectors: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let projected = vectors
        .iter()
        .enumerate()
        .map(|(idx, vector)| {
            let total: f64 = vector.iter().sum();
            let t = idx as f64;
            vec![total * 0.5 + t * 0.1, total * 0.3 + t * 0.05, total * 0.2 + t * 0.02]
        })
        .collect();
    Ok(projected)
}
/// Map each label to the rank of that label among the distinct labels, in
/// sorted order, as an `f64` color value.
///
/// Equal labels share a value. The lookup uses a binary search over the
/// sorted, deduplicated labels, which makes it O(n log u) instead of the
/// O(n·u) of a linear scan per label.
fn assign_cluster_colors(&self, labels: &[String]) -> Vec<f64> {
    let mut unique_labels: Vec<&str> = labels.iter().map(String::as_str).collect();
    unique_labels.sort_unstable();
    unique_labels.dedup();
    labels
        .iter()
        // Every label is present in `unique_labels`, so the search always succeeds.
        .map(|label| unique_labels.binary_search(&label.as_str()).unwrap_or(0) as f64)
        .collect()
}
/// Size a layer node from its tensor `shape`.
///
/// The base size is `log10(number of elements)`, clamped to at least `1.0`. The
/// returned triple is `(base, base * 0.5, base * 0.3)`.
fn calculate_node_size(&self, shape: &[usize]) -> (f64, f64, f64) {
    // Multiply in f64: an integer product of large dimensions can overflow usize
    // (a panic in debug builds, a silently wrapped size in release builds).
    let param_count: f64 = shape.iter().map(|&dim| dim as f64).product();
    // log10 of 0 (a zero-sized dimension) is -inf and of 1 is 0; both clamp to 1.0.
    let base_size = param_count.log10().max(1.0);
    (base_size, base_size * 0.5, base_size * 0.3)
}
/// Hex colour for a layer type (case-insensitive), with grey for unknown types.
fn get_layer_color(&self, layer_type: &str) -> String {
    let normalized = layer_type.to_lowercase();
    let hex = match normalized.as_str() {
        "linear" | "dense" => "#4CAF50",
        "conv" | "conv2d" => "#2196F3",
        "attention" => "#FF9800",
        "embedding" => "#9C27B0",
        "norm" | "normalization" => "#FFC107",
        "activation" => "#F44336",
        _ => "#607D8B",
    };
    String::from(hex)
}
/// CSS `rgba` colour for a connection weight.
///
/// Positive weights are green and all others red. Opacity is `|weight|` capped
/// at 1.0.
fn get_connection_color(&self, weight: f64) -> String {
    let alpha = weight.abs().min(1.0);
    let (red, green) = if weight > 0.0 { (0, 255) } else { (255, 0) };
    format!("rgba({}, {}, 0, {:.2})", red, green, alpha)
}
/// Total Euclidean path length through consecutive `(x, y, z)` points.
///
/// Only complete triples are used: if the coordinate series differ in length,
/// the extra values of the longer ones are ignored rather than indexed out of
/// bounds. Fewer than two points yields `0.0`.
fn calculate_trajectory_distance(&self, plot_data: &Plot3DData) -> f64 {
    let points: Vec<(f64, f64, f64)> = plot_data
        .x_values
        .iter()
        .zip(&plot_data.y_values)
        .zip(&plot_data.z_values)
        .map(|((&x, &y), &z)| (x, y, z))
        .collect();
    // fold from +0.0 (not `sum`) so an empty trajectory reports 0.0 rather than -0.0.
    points.windows(2).fold(0.0, |total, pair| {
        let (x0, y0, z0) = pair[0];
        let (x1, y1, z1) = pair[1];
        let (dx, dy, dz) = (x1 - x0, y1 - y0, z1 - z0);
        total + (dx * dx + dy * dy + dz * dz).sqrt()
    })
}
// Generic plotting methods for different backends
#[cfg(feature = "visual")]
/// Draw `plot_data` as a single blue line series on `root` using plotters.
///
/// Axis ranges cover the finite x/y values. An empty series gets a `0..1` axis,
/// and a constant series is widened so the axis never has zero width.
fn create_line_plot_with_backend<DB: plotters::prelude::DrawingBackend>(
    &self,
    root: &plotters::prelude::DrawingArea<DB, plotters::coord::Shift>,
    plot_data: &PlotData,
) -> Result<()>
where
    <DB as plotters::prelude::DrawingBackend>::ErrorType: 'static,
{
    use plotters::prelude::*;

    // Finite [min, max] of `values`, widened so the axis is never empty or zero-width.
    fn axis_range(values: &[f64]) -> (f64, f64) {
        let (lo, hi) = values
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        if lo > hi {
            (0.0, 1.0)
        } else if lo == hi {
            let pad = if lo == 0.0 { 1.0 } else { lo.abs() * 0.1 };
            (lo - pad, hi + pad)
        } else {
            (lo, hi)
        }
    }

    root.fill(&WHITE)?;
    let (x_min, x_max) = axis_range(&plot_data.x_values);
    let (y_min, y_max) = axis_range(&plot_data.y_values);
    let mut chart = ChartBuilder::on(root)
        .caption(&plot_data.title, ("sans-serif", self.config.font_size))
        .margin(5)
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d(x_min..x_max, y_min..y_max)?;
    chart
        .configure_mesh()
        .x_desc(&plot_data.x_label)
        .y_desc(&plot_data.y_label)
        .draw()?;
    let line_data: Vec<(f64, f64)> = plot_data
        .x_values
        .iter()
        .zip(plot_data.y_values.iter())
        .map(|(&x, &y)| (x, y))
        .collect();
    // Borrow the first label when present; only fall back to a literal otherwise.
    let series_label = plot_data.labels.first().map(String::as_str).unwrap_or("Data");
    chart
        .draw_series(LineSeries::new(line_data, &BLUE))?
        .label(series_label)
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 10, y)], &BLUE));
    chart.configure_series_labels().draw()?;
    root.present()?;
    Ok(())
}
#[cfg(feature = "visual")]
/// Draw `plot_data` as a bar histogram on `root` using plotters.
///
/// Each bar is centred on its x value. The bar width is the absolute spacing of
/// the first two x values, falling back to `1.0` when that is unavailable, zero,
/// or non-finite. The x axis is padded by half a bin on each side so the outer
/// bars are fully visible. The y axis starts at zero and ends at the largest
/// finite frequency, or `1.0` when there is no positive frequency.
fn create_histogram_with_backend<DB: plotters::prelude::DrawingBackend>(
    &self,
    root: &plotters::prelude::DrawingArea<DB, plotters::coord::Shift>,
    plot_data: &PlotData,
) -> Result<()>
where
    <DB as plotters::prelude::DrawingBackend>::ErrorType: 'static,
{
    use plotters::prelude::*;

    // Finite (min, max) of `values`, or None when no value is finite.
    fn finite_bounds(values: &[f64]) -> Option<(f64, f64)> {
        values
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold(None, |acc, v| match acc {
                None => Some((v, v)),
                Some((lo, hi)) => Some((lo.min(v), hi.max(v))),
            })
    }

    root.fill(&WHITE)?;
    let bin_width = match plot_data.x_values.as_slice() {
        [first, second, ..] if (second - first).is_finite() && second != first => {
            (second - first).abs()
        }
        _ => 1.0,
    };
    let half_bin = bin_width / 2.0;
    let (x_lo, x_hi) = finite_bounds(&plot_data.x_values).unwrap_or((0.0, 0.0));
    let y_top = finite_bounds(&plot_data.y_values)
        .map(|(_, hi)| hi)
        .filter(|&hi| hi > 0.0)
        .unwrap_or(1.0);
    let mut chart = ChartBuilder::on(root)
        .caption(&plot_data.title, ("sans-serif", self.config.font_size))
        .margin(5)
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d((x_lo - half_bin)..(x_hi + half_bin), 0f64..y_top)?;
    chart
        .configure_mesh()
        .x_desc(&plot_data.x_label)
        .y_desc(&plot_data.y_label)
        .draw()?;
    let histogram_data: Vec<Rectangle<(f64, f64)>> = plot_data
        .x_values
        .iter()
        .zip(plot_data.y_values.iter())
        .map(|(&x, &y)| Rectangle::new([(x - half_bin, 0.0), (x + half_bin, y)], BLUE.filled()))
        .collect();
    chart
        .draw_series(histogram_data)?
        .label("Frequency")
        .legend(|(x, y)| Rectangle::new([(x, y), (x + 10, y + 10)], BLUE.filled()));
    chart.configure_series_labels().draw()?;
    root.present()?;
    Ok(())
}
// Fallback methods for when visual features are not available
fn create_fallback_line_plot(&self, plot_data: &PlotData, filename: &str) -> Result<()> {
let csv_data = self.create_csv_data(plot_data)?;
std::fs::write(filename.replace(&self.format_extension(), "csv"), csv_data)?;
let text_plot = format!(
"Line Plot: {}\nX: {}\nY: {}\nData points: {}\n",
plot_data.title,
plot_data.x_label,
plot_data.y_label,
plot_data.x_values.len()
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), text_plot)?;
Ok(())
}
fn create_fallback_histogram(&self, plot_data: &PlotData, filename: &str) -> Result<()> {
let csv_data = self.create_csv_data(plot_data)?;
std::fs::write(filename.replace(&self.format_extension(), "csv"), csv_data)?;
let text_plot = format!(
"Histogram: {}\nBins: {}\nTotal frequency: {}\n",
plot_data.title,
plot_data.x_values.len(),
plot_data.y_values.iter().sum::<f64>()
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), text_plot)?;
Ok(())
}
// 3D fallback methods
fn create_fallback_3d_surface(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
// Create CSV data for 3D surface
let mut csv_content = format!(
"{},{},{}\n",
plot_data.x_label, plot_data.y_label, plot_data.z_label
);
for ((x, y), z) in plot_data
.x_values
.iter()
.zip(plot_data.y_values.iter())
.zip(plot_data.z_values.iter())
{
csv_content.push_str(&format!("{},{},{}\n", x, y, z));
}
std::fs::write(
filename.replace(&self.format_extension(), "csv"),
csv_content,
)?;
// Create metadata file
let metadata = format!(
"3D Surface Plot: {}\nPoints: {}\nX Range: [{:.4}, {:.4}]\nY Range: [{:.4}, {:.4}]\nZ Range: [{:.4}, {:.4}]\n",
plot_data.title,
plot_data.x_values.len(),
plot_data.x_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
plot_data.x_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
plot_data.y_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
plot_data.y_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
plot_data.z_values.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
plot_data.z_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), metadata)?;
Ok(())
}
fn create_fallback_3d_trajectory(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
// Create CSV data for 3D trajectory
let mut csv_content = format!(
"step,{},{},{}\n",
plot_data.x_label, plot_data.y_label, plot_data.z_label
);
for (i, ((x, y), z)) in plot_data
.x_values
.iter()
.zip(plot_data.y_values.iter())
.zip(plot_data.z_values.iter())
.enumerate()
{
csv_content.push_str(&format!("{},{},{},{}\n", i, x, y, z));
}
std::fs::write(
filename.replace(&self.format_extension(), "csv"),
csv_content,
)?;
// Create metadata file
let metadata = format!(
"3D Trajectory Plot: {}\nSteps: {}\nTotal Distance: {:.4}\nStart: ({:.4}, {:.4}, {:.4})\nEnd: ({:.4}, {:.4}, {:.4})\n",
plot_data.title,
plot_data.x_values.len(),
self.calculate_trajectory_distance(plot_data),
plot_data.x_values.first().unwrap_or(&0.0),
plot_data.y_values.first().unwrap_or(&0.0),
plot_data.z_values.first().unwrap_or(&0.0),
plot_data.x_values.last().unwrap_or(&0.0),
plot_data.y_values.last().unwrap_or(&0.0),
plot_data.z_values.last().unwrap_or(&0.0)
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), metadata)?;
Ok(())
}
fn create_fallback_3d_scatter(&self, plot_data: &Plot3DData, filename: &str) -> Result<()> {
// Create CSV data for 3D scatter
let mut csv_content = format!(
"label,{},{},{}\n",
plot_data.x_label, plot_data.y_label, plot_data.z_label
);
for (i, ((x, y), z)) in plot_data
.x_values
.iter()
.zip(plot_data.y_values.iter())
.zip(plot_data.z_values.iter())
.enumerate()
{
let default_label = format!("Point_{}", i);
let label = plot_data.point_labels.get(i).unwrap_or(&default_label);
csv_content.push_str(&format!("{},{},{},{}\n", label, x, y, z));
}
std::fs::write(
filename.replace(&self.format_extension(), "csv"),
csv_content,
)?;
// Create metadata file
let metadata = format!(
"3D Scatter Plot: {}\nPoints: {}\nLabels: {}\n",
plot_data.title,
plot_data.x_values.len(),
plot_data.point_labels.len()
);
std::fs::write(filename.replace(&self.format_extension(), "txt"), metadata)?;
Ok(())
}
// Export format methods
/// Save plot as PNG format
pub fn save_plot_png(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
self.create_line_plot_file(plot_data, file_path.to_str().unwrap_or("plot.png"))
}
/// Save plot as SVG format
///
/// Renders the plot with `generate_svg_plot` and writes the document to `file_path`.
pub fn save_plot_svg(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
    let document = self.generate_svg_plot(plot_data)?;
    Ok(std::fs::write(file_path, document)?)
}
/// Save plot as PDF format
pub fn save_plot_pdf(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
// For PDF, we'll create a simple fallback
self.create_fallback_line_plot(plot_data, file_path.to_str().unwrap_or("plot.pdf"))
}
/// Save plot as HTML format
///
/// Builds an interactive page with `generate_html_plot` and writes it to `file_path`.
pub fn save_plot_html(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
    let page = self.generate_html_plot(plot_data)?;
    std::fs::write(file_path, page).map_err(Into::into)
}
/// Save plot as LaTeX format
///
/// Produces a standalone pgfplots document via `generate_latex_plot` and writes it
/// to `file_path`.
pub fn save_plot_latex(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
    let source = self.generate_latex_plot(plot_data)?;
    std::fs::write(file_path, source)?;
    Ok(())
}
/// Save plot as JSON format
///
/// Serialises `plot_data` as pretty-printed JSON and writes it to `file_path`.
pub fn save_plot_json(&self, plot_data: &PlotData, file_path: &std::path::Path) -> Result<()> {
    let serialized = serde_json::to_string_pretty(plot_data)?;
    Ok(std::fs::write(file_path, serialized)?)
}
/// Generate SVG plot content
///
/// Draws one line series from `plot_data` on a `config.plot_width` x
/// `config.plot_height` canvas with a fixed 60px margin, a 5x5 background grid
/// and black axes. The title and axis labels are XML-escaped, so arbitrary text
/// cannot produce malformed SVG.
///
/// Points with a non-finite coordinate are skipped. A series with no spread in x
/// (or y) is drawn through the middle of the plot area instead of producing NaN
/// coordinates. Canvases smaller than the margins yield an empty plot area rather
/// than an integer underflow.
fn generate_svg_plot(&self, plot_data: &PlotData) -> Result<String> {
    use std::fmt::Write as _;

    // Escape the five XML special characters for text nodes and attribute values.
    fn xml_escape(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&apos;"),
                _ => out.push(ch),
            }
        }
        out
    }

    // Map `value` from [lo, hi] onto [0, extent]; a degenerate range maps to the centre.
    fn project(value: f64, lo: f64, hi: f64, extent: f64) -> f64 {
        if hi > lo {
            (value - lo) / (hi - lo) * extent
        } else {
            extent / 2.0
        }
    }

    const MARGIN: u32 = 60;
    let width = self.config.plot_width;
    let height = self.config.plot_height;
    let plot_width = width.saturating_sub(2 * MARGIN);
    let plot_height = height.saturating_sub(2 * MARGIN);
    let plot_bottom = MARGIN + plot_height;

    let mut svg = String::new();
    write!(
        svg,
        r##"<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}" viewBox="0 0 {w} {h}">
  <defs>
    <style>
      .title {{ font-family: Arial, sans-serif; font-size: {title_px}px; font-weight: bold; text-anchor: middle; }}
      .axis-label {{ font-family: Arial, sans-serif; font-size: {axis_px}px; text-anchor: middle; }}
      .tick-label {{ font-family: Arial, sans-serif; font-size: {tick_px}px; text-anchor: middle; }}
      .grid-line {{ stroke: #e0e0e0; stroke-width: 1; }}
      .axis-line {{ stroke: #000; stroke-width: 2; }}
      .plot-line {{ stroke: #2196F3; stroke-width: 2; fill: none; }}
    </style>
  </defs>
  <rect width="{w}" height="{h}" fill="white"/>
  <text x="{cx}" y="30" class="title">{title}</text>
  <rect x="{m}" y="{m}" width="{pw}" height="{ph}" fill="#fafafa" stroke="#ddd"/>
"##,
        w = width,
        h = height,
        title_px = self.config.font_size.saturating_add(4),
        axis_px = self.config.font_size,
        tick_px = self.config.font_size.saturating_sub(2),
        cx = width / 2,
        title = xml_escape(&plot_data.title),
        m = MARGIN,
        pw = plot_width,
        ph = plot_height,
    )?;

    // Background grid: six vertical lines, then six horizontal lines.
    for i in 0..=5u32 {
        let x = MARGIN + i * plot_width / 5;
        writeln!(
            svg,
            r#"  <line x1="{x}" y1="{top}" x2="{x}" y2="{bottom}" class="grid-line"/>"#,
            x = x,
            top = MARGIN,
            bottom = plot_bottom
        )?;
    }
    for i in 0..=5u32 {
        let y = MARGIN + i * plot_height / 5;
        writeln!(
            svg,
            r#"  <line x1="{left}" y1="{y}" x2="{right}" y2="{y}" class="grid-line"/>"#,
            left = MARGIN,
            y = y,
            right = MARGIN + plot_width
        )?;
    }

    // Axes: x along the bottom edge, y along the left edge.
    writeln!(
        svg,
        r#"  <line x1="{l}" y1="{b}" x2="{r}" y2="{b}" class="axis-line"/>"#,
        l = MARGIN,
        b = plot_bottom,
        r = MARGIN + plot_width
    )?;
    writeln!(
        svg,
        r#"  <line x1="{l}" y1="{t}" x2="{l}" y2="{b}" class="axis-line"/>"#,
        l = MARGIN,
        t = MARGIN,
        b = plot_bottom
    )?;

    // Data path over complete, finite (x, y) pairs.
    let points: Vec<(f64, f64)> = plot_data
        .x_values
        .iter()
        .copied()
        .zip(plot_data.y_values.iter().copied())
        .filter(|(x, y)| x.is_finite() && y.is_finite())
        .collect();
    if !points.is_empty() {
        let (x_lo, x_hi, y_lo, y_hi) = points.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY, f64::INFINITY, f64::NEG_INFINITY),
            |(xl, xh, yl, yh), &(x, y)| (xl.min(x), xh.max(x), yl.min(y), yh.max(y)),
        );
        let mut path = String::new();
        for (i, &(x, y)) in points.iter().enumerate() {
            let px = f64::from(MARGIN) + project(x, x_lo, x_hi, f64::from(plot_width));
            // SVG y grows downwards, so larger values are drawn closer to the top.
            let py = f64::from(plot_bottom) - project(y, y_lo, y_hi, f64::from(plot_height));
            let command = if i == 0 { "M" } else { " L" };
            write!(path, "{} {} {}", command, px, py)?;
        }
        writeln!(svg, r#"  <path d="{}" class="plot-line"/>"#, path)?;
    }

    // Axis labels: x centred below the plot, y rotated along the left edge.
    writeln!(
        svg,
        r#"  <text x="{}" y="{}" class="axis-label">{}</text>"#,
        MARGIN + plot_width / 2,
        height.saturating_sub(10),
        xml_escape(&plot_data.x_label)
    )?;
    writeln!(
        svg,
        r#"  <text x="20" y="{mid}" class="axis-label" transform="rotate(-90, 20, {mid})">{label}</text>"#,
        mid = MARGIN + plot_height / 2,
        label = xml_escape(&plot_data.y_label)
    )?;
    svg.push_str("</svg>");
    Ok(svg)
}
/// Generate HTML plot content
///
/// Produces a standalone page that renders `plot_data` with Plotly (loaded from
/// the Plotly CDN), plus a metadata table and a generation timestamp (UTC).
///
/// The title and axis labels are HTML-escaped before being placed in markup. The
/// serialised plot data is embedded as a JavaScript literal with `<`, U+2028 and
/// U+2029 escaped, so label text such as `</script>` cannot end the script element
/// early.
fn generate_html_plot(&self, plot_data: &PlotData) -> Result<String> {
    // Escape text for HTML element content and attribute values.
    fn html_escape(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#39;"),
                _ => out.push(ch),
            }
        }
        out
    }

    // In JSON, `<` can only occur inside strings, where `\u003c` is an equivalent
    // escape. U+2028/U+2029 are rejected raw in string literals by older JS engines.
    let json_data = serde_json::to_string(plot_data)?
        .replace('<', "\\u003c")
        .replace('\u{2028}', "\\u2028")
        .replace('\u{2029}', "\\u2029");
    // NOTE(review): `plotly-latest.min.js` is a frozen alias on the Plotly CDN;
    // consider pinning an explicit Plotly version.
    let html = format!(
        r##"<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{title}</title>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 20px;
            background-color: #f5f5f5;
        }}
        .container {{
            background-color: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }}
        h1 {{
            color: #333;
            text-align: center;
            margin-bottom: 30px;
        }}
        #plotDiv {{
            width: 100%;
            height: 600px;
        }}
        .metadata {{
            margin-top: 20px;
            padding: 15px;
            background-color: #f8f9fa;
            border-radius: 4px;
            border-left: 4px solid #2196F3;
        }}
        .stats-table {{
            width: 100%;
            border-collapse: collapse;
            margin-top: 10px;
        }}
        .stats-table th, .stats-table td {{
            padding: 8px 12px;
            text-align: left;
            border-bottom: 1px solid #ddd;
        }}
        .stats-table th {{
            background-color: #f5f5f5;
            font-weight: bold;
        }}
    </style>
</head>
<body>
    <div class="container">
        <h1>{title}</h1>
        <div id="plotDiv"></div>
        <div class="metadata">
            <h3>Plot Metadata</h3>
            <table class="stats-table">
                <tr><th>Property</th><th>Value</th></tr>
                <tr><td>Data Points</td><td>{points}</td></tr>
                <tr><td>X Label</td><td>{x_label}</td></tr>
                <tr><td>Y Label</td><td>{y_label}</td></tr>
                <tr><td>Generated</td><td>{generated}</td></tr>
            </table>
        </div>
    </div>
    <script>
        const plotData = {json};
        const trace = {{
            x: plotData.x_values,
            y: plotData.y_values,
            type: 'scatter',
            mode: 'lines+markers',
            name: plotData.labels[0] || 'Data',
            line: {{ color: '#2196F3', width: 2 }},
            marker: {{ color: '#2196F3', size: 4 }}
        }};
        const layout = {{
            title: {{
                text: plotData.title,
                font: {{ size: 18, color: '#333' }}
            }},
            xaxis: {{
                title: {{ text: plotData.x_label, font: {{ size: 14 }} }},
                gridcolor: '#e0e0e0',
                showgrid: true
            }},
            yaxis: {{
                title: {{ text: plotData.y_label, font: {{ size: 14 }} }},
                gridcolor: '#e0e0e0',
                showgrid: true
            }},
            plot_bgcolor: '#fafafa',
            paper_bgcolor: 'white',
            showlegend: true,
            legend: {{
                x: 0.7,
                y: 0.95,
                bgcolor: 'rgba(255,255,255,0.8)',
                bordercolor: '#ddd',
                borderwidth: 1
            }}
        }};
        const config = {{
            responsive: true,
            displayModeBar: true,
            modeBarButtonsToRemove: ['pan2d', 'lasso2d', 'select2d'],
            displaylogo: false
        }};
        Plotly.newPlot('plotDiv', [trace], layout, config);
    </script>
</body>
</html>"##,
        title = html_escape(&plot_data.title),
        points = plot_data.x_values.len(),
        x_label = html_escape(&plot_data.x_label),
        y_label = html_escape(&plot_data.y_label),
        generated = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC"),
        json = json_data,
    );
    Ok(html)
}
/// Generate LaTeX plot content
///
/// Produces a standalone pgfplots document containing:
/// - a line plot of the complete, finite (x, y) pairs;
/// - a data summary with point count, ranges and means;
/// - a table of the first 20 pairs.
///
/// Titles and labels are fully LaTeX-escaped (not just `_`). Figure and table
/// `\label` keys are derived from the title using only lower-case ASCII
/// alphanumerics and `_`. Statistics with no finite values print `n/a`.
fn generate_latex_plot(&self, plot_data: &PlotData) -> Result<String> {
    use std::fmt::Write as _;

    // Escape LaTeX special characters so arbitrary titles/labels compile.
    fn latex_escape(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '\\' => out.push_str(r"\textbackslash{}"),
                '~' => out.push_str(r"\textasciitilde{}"),
                '^' => out.push_str(r"\textasciicircum{}"),
                '&' | '%' | '$' | '#' | '_' | '{' | '}' => {
                    out.push('\\');
                    out.push(ch);
                }
                _ => out.push(ch),
            }
        }
        out
    }

    // Cross-reference keys must not contain backslashes or braces.
    fn latex_key(text: &str) -> String {
        text.chars()
            .map(|ch| if ch.is_ascii_alphanumeric() { ch.to_ascii_lowercase() } else { '_' })
            .collect()
    }

    // ("[min, max]", "mean") over the finite values, or "n/a" when there are none.
    fn describe(values: &[f64]) -> (String, String) {
        let finite: Vec<f64> = values.iter().copied().filter(|v| v.is_finite()).collect();
        if finite.is_empty() {
            return ("n/a".to_string(), "n/a".to_string());
        }
        let lo = finite.iter().copied().fold(f64::INFINITY, f64::min);
        let hi = finite.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mean = finite.iter().sum::<f64>() / finite.len() as f64;
        (format!("[{:.6}, {:.6}]", lo, hi), format!("{:.6}", mean))
    }

    // Cap the raw-data table to keep the document short.
    const TABLE_ROWS: usize = 20;

    let title = latex_escape(&plot_data.title);
    let x_label = latex_escape(&plot_data.x_label);
    let y_label = latex_escape(&plot_data.y_label);
    let legend = latex_escape(plot_data.labels.first().map(String::as_str).unwrap_or("Data"));
    let key = latex_key(&plot_data.title);
    // Only complete pairs are plotted and tabulated; a shorter y series cannot panic.
    let points: Vec<(f64, f64)> = plot_data
        .x_values
        .iter()
        .copied()
        .zip(plot_data.y_values.iter().copied())
        .collect();
    let (x_range, x_mean) = describe(&plot_data.x_values);
    let (y_range, y_mean) = describe(&plot_data.y_values);

    let mut latex = String::new();
    write!(
        latex,
        r#"\documentclass{{article}}
\usepackage{{pgfplots}}
\usepackage{{amsmath}}
\usepackage{{geometry}}
\geometry{{a4paper, margin=1in}}
\pgfplotsset{{compat=1.17}}
\begin{{document}}
\title{{{title} - TrustformeRS Debug Visualization}}
\author{{Generated by TrustformeRS}}
\date{{\today}}
\maketitle
\section{{Plot: {title}}}
\begin{{figure}}[h]
\centering
\begin{{tikzpicture}}
\begin{{axis}}[
    title={{{title}}},
    xlabel={{{x_label}}},
    ylabel={{{y_label}}},
    width=0.8\textwidth,
    height=0.6\textwidth,
    grid=major,
    grid style={{dashed,gray!30}},
    legend pos=north west,
    line width=1.5pt
]
\addplot[
    color=blue,
    mark=*,
    mark size=1.5pt
] coordinates {{
"#,
        title = title,
        x_label = x_label,
        y_label = y_label
    )?;
    // pgfplots cannot parse NaN/inf coordinates, so only finite pairs are plotted.
    for &(x, y) in points.iter().filter(|(x, y)| x.is_finite() && y.is_finite()) {
        writeln!(latex, "    ({:.6},{:.6})", x, y)?;
    }
    write!(
        latex,
        r#"}};
\legend{{{legend}}}
\end{{axis}}
\end{{tikzpicture}}
\caption{{Plot showing the relationship between {x_label} and {y_label}.}}
\label{{fig:{key}}}
\end{{figure}}
\section{{Data Summary}}
\begin{{itemize}}
\item Total data points: {count}
\item X-axis range: {x_range}
\item Y-axis range: {y_range}
\item Mean X: {x_mean}
\item Mean Y: {y_mean}
\end{{itemize}}
\section{{Raw Data}}
\begin{{table}}[h]
\centering
\begin{{tabular}}{{|c|c|}}
\hline
{x_label} & {y_label} \\
\hline
"#,
        legend = legend,
        x_label = x_label,
        y_label = y_label,
        key = key,
        count = points.len(),
        x_range = x_range,
        y_range = y_range,
        x_mean = x_mean,
        y_mean = y_mean
    )?;
    for &(x, y) in points.iter().take(TABLE_ROWS) {
        writeln!(latex, r"{:.6} & {:.6} \\", x, y)?;
    }
    if points.len() > TABLE_ROWS {
        writeln!(
            latex,
            r"\multicolumn{{2}}{{|c|}}{{... and {} more data points}} \\",
            points.len() - TABLE_ROWS
        )?;
    }
    write!(
        latex,
        r#"\hline
\end{{tabular}}
\caption{{Data table for {title}.}}
\label{{tab:{key}}}
\end{{table}}
\end{{document}}
"#,
        title = title,
        key = key
    )?;
    Ok(latex)
}
/// Create a simple animation sequence from a single plot for video export
///
/// Builds between 1 and 50 frames (about one per five points). Each frame reveals
/// a longer prefix of the plot's data, and the last frame shows every point. Only
/// complete (x, y) pairs are used, so the shorter of `x_values`/`y_values` bounds
/// the animation.
///
/// # Errors
/// Returns an error when the plot has no complete (x, y) pair.
fn create_simple_animation_sequence(&self, plot: &PlotData) -> Result<AnimationSequence> {
    // Slicing the longer series past the shorter one's length would panic.
    let total_points = plot.x_values.len().min(plot.y_values.len());
    if total_points == 0 {
        return Err(anyhow::anyhow!("Cannot create animation from empty plot data"));
    }
    let frame_count = (total_points / 5).max(1).min(50);
    let mut frames = Vec::with_capacity(frame_count);
    for i in 1..=frame_count {
        let end_index = (i * total_points / frame_count).min(total_points);
        // Reuse the plot's labels when there are enough; otherwise use a generic one.
        let labels = if plot.labels.len() >= end_index {
            plot.labels[..end_index].to_vec()
        } else {
            vec!["Data".to_string(); end_index]
        };
        let frame_plot_data = PlotData {
            x_values: plot.x_values[..end_index].to_vec(),
            y_values: plot.y_values[..end_index].to_vec(),
            labels,
            title: format!("{} (Frame {} of {})", plot.title, i, frame_count),
            x_label: plot.x_label.clone(),
            y_label: plot.y_label.clone(),
        };
        let mut metadata = HashMap::new();
        metadata.insert("frame".to_string(), i.to_string());
        metadata.insert("total_frames".to_string(), frame_count.to_string());
        metadata.insert("data_points".to_string(), end_index.to_string());
        metadata.insert(
            "progress".to_string(),
            format!("{:.1}%", (i as f64 / frame_count as f64) * 100.0),
        );
        frames.push(VideoFrame {
            plot_data: frame_plot_data,
            plot_3d_data: None,
            timestamp: i as f64,
            metadata,
        });
    }
    Ok(AnimationSequence {
        title: format!("Animation: {}", plot.title),
        frames,
        config: VideoConfig::default(),
        animation_type: AnimationType::Custom("Progressive Plot".to_string()),
    })
}
/// Create video animation from a sequence of plots
pub fn create_video_animation_from_plots(&self, plot_ids: &[String], title: &str) -> Result<AnimationSequence> {
let mut frames = Vec::new();
for (i, plot_id) in plot_ids.iter().enumerate() {
let plot = self.plots.get(plot_id).ok_or_else(|| {
anyhow::anyhow!("Plot with id '{}' not found", plot_id)
})?;
let mut metadata = HashMap::new();
metadata.insert("frame".to_string(), (i + 1).to_string());
metadata.insert("total_frames".to_string(), plot_ids.len().to_string());
metadata.insert("plot_id".to_string(), plot_id.clone());
frames.push(VideoFrame {
plot_data: plot.clone(),
plot_3d_data: self.plots_3d.get(plot_id).cloned(),
timestamp: i as f64,
metadata,
});
}
Ok(AnimationSequence {
title: title.to_string(),
frames,
config: VideoConfig::default(),
animation_type: AnimationType::Custom("Plot Sequence".to_string()),
})
}
/// Generate video export for training progress
pub fn export_training_progress_video(
&self,
output_path: &str,
format: ImageFormat,
loss_history: &[f64],
accuracy_history: &[f64],
timestamps: &[f64],
) -> Result<()> {
let video_generator = VideoGenerator::new(
VideoConfig::default(),
self.config.output_directory.clone(),
);
let animation = video_generator.create_training_animation(
loss_history,
accuracy_history,
timestamps,
)?;
match format {
ImageFormat::MP4 => video_generator.generate_mp4(&animation, output_path),
ImageFormat::GIF => video_generator.generate_gif(&animation, output_path),
ImageFormat::WebM => video_generator.generate_webm(&animation, output_path),
_ => Err(anyhow::anyhow!("Unsupported video format: {:?}", format)),
}
}
/// Generate video export for gradient flow
pub fn export_gradient_flow_video(
&self,
output_path: &str,
format: ImageFormat,
layer_names: &[String],
gradient_norms_history: &[Vec<f64>],
timestamps: &[f64],
) -> Result<()> {
let video_generator = VideoGenerator::new(
VideoConfig::default(),
self.config.output_directory.clone(),
);
let animation = video_generator.create_gradient_flow_animation(
layer_names,
gradient_norms_history,
timestamps,
)?;
match format {
ImageFormat::MP4 => video_generator.generate_mp4(&animation, output_path),
ImageFormat::GIF => video_generator.generate_gif(&animation, output_path),
ImageFormat::WebM => video_generator.generate_webm(&animation, output_path),
_ => Err(anyhow::anyhow!("Unsupported video format: {:?}", format)),
}
}
/// Generate video export for optimization trajectory
pub fn export_optimization_trajectory_video(
&self,
output_path: &str,
format: ImageFormat,
trajectory_points: &[(f64, f64, f64)],
timestamps: &[f64],
) -> Result<()> {
let video_generator = VideoGenerator::new(
VideoConfig::default(),
self.config.output_directory.clone(),
);
let animation = video_generator.create_optimization_trajectory_animation(
trajectory_points,
timestamps,
)?;
match format {
ImageFormat::MP4 => video_generator.generate_mp4(&animation, output_path),
ImageFormat::GIF => video_generator.generate_gif(&animation, output_path),
ImageFormat::WebM => video_generator.generate_webm(&animation, output_path),
_ => Err(anyhow::anyhow!("Unsupported video format: {:?}", format)),
}
}
/// Create a custom video generator with specific configuration
///
/// The generator uses `config` and writes under this visualizer's configured
/// output directory.
pub fn create_video_generator(&self, config: VideoConfig) -> VideoGenerator {
    let output_directory = self.config.output_directory.to_owned();
    VideoGenerator::new(config, output_directory)
}
}
/// Terminal-based visualization for headless environments
///
/// Renders plain-text histograms, line plots and tensor summaries for a terminal
/// of `width` columns by `height` rows.
#[derive(Debug)]
pub struct TerminalVisualizer {
    /// Terminal width in columns; histograms reserve 20 of them for bin labels.
    width: usize,
    /// Terminal height in rows; line plots use between 2 and 20 rows.
    height: usize,
}

impl TerminalVisualizer {
    /// Create a visualizer for a terminal of `width` columns and `height` rows.
    pub fn new(width: usize, height: usize) -> Self {
        Self { width, height }
    }

    /// Create a simple ASCII histogram
    ///
    /// Buckets the finite values into `num_bins` equal-width bins spanning their
    /// minimum and maximum. NaN and infinite values are not counted. Bars are
    /// scaled so the fullest bin spans `width - 20` columns; on terminals narrower
    /// than 20 columns the bars are empty.
    ///
    /// Returns a one-line message instead of a chart in these cases:
    /// - `values` is empty;
    /// - `num_bins` is 0;
    /// - no value is finite;
    /// - all finite values are equal.
    pub fn ascii_histogram(&self, values: &[f64], num_bins: usize) -> String {
        use std::fmt::Write as _;

        if values.is_empty() {
            return "No data to display".to_string();
        }
        if num_bins == 0 {
            return "Cannot build a histogram with 0 bins".to_string();
        }
        let (min_val, max_val) = values
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        if min_val > max_val {
            return "No finite values to display".to_string();
        }
        if min_val == max_val {
            return format!("All values equal to {:.4}", min_val);
        }

        let bin_width = (max_val - min_val) / num_bins as f64;
        let mut bins = vec![0usize; num_bins];
        for value in values.iter().copied().filter(|v| v.is_finite()) {
            // The maximum lands exactly on the upper edge; clamp it into the last bin.
            let bin_idx = (((value - min_val) / bin_width).floor() as usize).min(num_bins - 1);
            bins[bin_idx] += 1;
        }
        // At least one finite value was binned, so max_count >= 1.
        let max_count = bins.iter().copied().max().unwrap_or(1);
        let scale = self.width.saturating_sub(20) as f64 / max_count as f64;

        // Writing into a String cannot fail, so the fmt::Result values are ignored.
        let mut result = String::new();
        let _ = writeln!(result, "Histogram ({} bins, {} values)", num_bins, values.len());
        let _ = write!(result, "Range: [{:.4}, {:.4}]\n\n", min_val, max_val);
        for (i, &count) in bins.iter().enumerate() {
            let bin_start = min_val + i as f64 * bin_width;
            let bin_end = bin_start + bin_width;
            let bar = "█".repeat((count as f64 * scale) as usize);
            let _ = writeln!(
                result,
                "[{:8.4}, {:8.4}): {:6} |{}",
                bin_start, bin_end, count, bar
            );
        }
        result
    }

    /// Create a simple ASCII line plot
    ///
    /// Prints `title` underlined with one `=` per character, then one column per
    /// point on a grid of `min(height, 20)` rows (at least 2). Each row is
    /// prefixed by its y level. `x_values` only has to match `y_values` in length;
    /// points are placed in index order. Non-finite y values are left blank.
    ///
    /// Returns `"Invalid data for line plot"` for empty or mismatched inputs.
    pub fn ascii_line_plot(&self, x_values: &[f64], y_values: &[f64], title: &str) -> String {
        if x_values.len() != y_values.len() || x_values.is_empty() {
            return "Invalid data for line plot".to_string();
        }

        let mut result = String::new();
        result.push_str(title);
        result.push('\n');
        // Count characters, not bytes, so multi-byte titles get a matching underline.
        result.push_str(&"=".repeat(title.chars().count()));
        result.push('\n');

        let (y_min, y_max) = y_values
            .iter()
            .copied()
            .filter(|y| y.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), y| (lo.min(y), hi.max(y)));
        if y_min > y_max {
            result.push_str("No finite Y values to plot\n");
            return result;
        }
        if y_min == y_max {
            result.push_str(&format!("All Y values equal to {:.4}\n", y_min));
            return result;
        }

        // Two rows minimum: a single row would make the vertical scale zero.
        let plot_height = self.height.min(20).max(2);
        let top_row = plot_height - 1;
        let y_scale = top_row as f64 / (y_max - y_min);
        // Row index (0 = top) of each point, computed once rather than per row.
        let rows: Vec<Option<usize>> = y_values
            .iter()
            .map(|&y| {
                if y.is_finite() {
                    Some(top_row.saturating_sub(((y - y_min) * y_scale) as usize))
                } else {
                    None
                }
            })
            .collect();
        for i in 0..plot_height {
            let y_level = y_max - (i as f64 / y_scale);
            result.push_str(&format!("{:8.2} |", y_level));
            result.extend(rows.iter().map(|&row| if row == Some(i) { '*' } else { ' ' }));
            result.push('\n');
        }

        // X-axis: the '+' sits under the '|' of the "{:8.2} |" row prefix.
        result.push_str("         +");
        result.push_str(&"-".repeat(y_values.len()));
        result.push('\n');
        result
    }

    /// Show tensor statistics in a formatted table
    ///
    /// Memory usage is reported in MiB (`memory_usage_bytes / 1024^2`).
    pub fn format_tensor_stats(
        &self,
        name: &str,
        stats: &crate::tensor_inspector::TensorStats,
    ) -> String {
        format!(
            "┌─ Tensor: {} ─┐\n\
             │ Shape: {:?}\n\
             │ Elements: {}\n\
             │ Mean: {:.6}\n\
             │ Std: {:.6}\n\
             │ Min: {:.6}\n\
             │ Max: {:.6}\n\
             │ L2 Norm: {:.6}\n\
             │ NaN count: {}\n\
             │ Inf count: {}\n\
             │ Memory: {:.2} MB\n\
             └─────────────────┘\n",
            name,
            stats.shape,
            stats.total_elements,
            stats.mean,
            stats.std,
            stats.min,
            stats.max,
            stats.l2_norm,
            stats.nan_count,
            stats.inf_count,
            stats.memory_usage_bytes as f64 / (1024.0 * 1024.0)
        )
    }
}
impl Default for TerminalVisualizer {
fn default() -> Self {
Self::new(80, 24)
}
}
/// Video generation configuration
///
/// Stored on every `AnimationSequence`. The MP4/WebM encoders receive the
/// sequence's copy (`&sequence.config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoConfig {
/// Frame rate for the video (default 30)
pub fps: u32,
/// Duration of each frame in seconds (default 0.1)
// NOTE(review): how this interacts with `fps` is decided by the encoder
// implementations, which are not shown here — confirm.
pub frame_duration: f64,
/// Quality setting (1-100, default 85); the range is not validated by this type
pub quality: u32,
/// Loop the animation (default true)
pub loop_animation: bool,
/// Transition effects between frames (default false)
pub transitions: bool,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
fps: 30,
frame_duration: 0.1,
quality: 85,
loop_animation: true,
transitions: false,
}
}
}
/// Video frame data for animation sequences
///
/// One snapshot of an animation: the 2D plot to render, an optional 3D plot,
/// a timestamp, and free-form string metadata (keys such as "frame" and "step"
/// are set by the animation builders).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoFrame {
/// Plot data for this frame
pub plot_data: PlotData,
/// 3D plot data if applicable
pub plot_3d_data: Option<Plot3DData>,
/// Frame timestamp (a step index or a caller-supplied timestamp, depending on the builder)
pub timestamp: f64,
/// Frame metadata
pub metadata: HashMap<String, String>,
}
/// Animation sequence data
///
/// An ordered list of frames plus the configuration and kind of animation. It is
/// the input to `VideoGenerator::generate_mp4`, `generate_gif` and `generate_webm`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnimationSequence {
/// Title of the animation
pub title: String,
/// Sequence of frames, in playback order
pub frames: Vec<VideoFrame>,
/// Video configuration
pub config: VideoConfig,
/// Animation type
pub animation_type: AnimationType,
}
/// Types of animations supported
///
/// Recorded on each `AnimationSequence`. `TrainingProgress`, `GradientFlow` and
/// `Custom` are produced by the builders in this module.
// NOTE(review): the remaining variants are not constructed by the builders
// visible alongside this enum — confirm whether they are used elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnimationType {
/// Training progress animation
TrainingProgress,
/// Loss landscape evolution
LossLandscapeEvolution,
/// Gradient flow animation
GradientFlow,
/// Weight evolution over time
WeightEvolution,
/// Optimization trajectory
OptimizationTrajectory,
/// Model convergence visualization
ConvergenceVisualization,
/// Custom animation with a free-form name (e.g. "Progressive Plot", "Plot Sequence")
Custom(String),
}
/// Video generator for creating animated visualizations
///
/// Builds `AnimationSequence`s from training data and encodes them to MP4, GIF
/// or WebM, or to an HTML fallback when the corresponding feature is disabled.
pub struct VideoGenerator {
// Configuration copied into every sequence this generator builds.
config: VideoConfig,
// Directory under which temporary `frames_<uuid>` folders are created.
output_directory: String,
}
impl VideoGenerator {
/// Create a new video generator
pub fn new(config: VideoConfig, output_directory: String) -> Self {
Self {
config,
output_directory,
}
}
/// Generate MP4 video from animation sequence
///
/// With the `video` feature, every frame is rendered to PNG in a temporary
/// `frames_<uuid>` directory under the output directory and then passed to
/// `create_mp4_with_ffmpeg`. The frame directory is removed afterwards, even when
/// rendering or encoding fails. Without the feature, an HTML animation is written
/// to `output_path` instead and no frames are rendered, matching `generate_webm`.
pub fn generate_mp4(&self, sequence: &AnimationSequence, output_path: &str) -> Result<()> {
    self.ensure_output_directory()?;
    #[cfg(feature = "video")]
    {
        let frame_dir = self.render_frames_to_temp_dir(sequence)?;
        let encoded = self.create_mp4_with_ffmpeg(&frame_dir, output_path, &sequence.config);
        // Always clean up; an encoding error is reported in preference to a cleanup error.
        let cleanup = std::fs::remove_dir_all(&frame_dir);
        encoded?;
        cleanup?;
    }
    #[cfg(not(feature = "video"))]
    {
        // Fallback: create HTML animation
        self.create_html_animation(sequence, output_path)?;
    }
    Ok(())
}
/// Generate GIF animation from sequence
///
/// Uses `create_gif_animation` with the `gif` feature, or otherwise writes an HTML
/// animation to `output_path`.
pub fn generate_gif(&self, sequence: &AnimationSequence, output_path: &str) -> Result<()> {
    self.ensure_output_directory()?;
    #[cfg(feature = "gif")]
    {
        self.create_gif_animation(sequence, output_path)?;
    }
    #[cfg(not(feature = "gif"))]
    {
        // Fallback: create HTML animation
        self.create_html_animation(sequence, output_path)?;
    }
    Ok(())
}
/// Generate WebM video from sequence
///
/// With the `video` feature, frames are rendered to a temporary directory (removed
/// afterwards, even on failure) and encoded by `create_webm_with_ffmpeg`. Without
/// it, an HTML animation is written to `output_path`.
pub fn generate_webm(&self, sequence: &AnimationSequence, output_path: &str) -> Result<()> {
    self.ensure_output_directory()?;
    #[cfg(feature = "video")]
    {
        let frame_dir = self.render_frames_to_temp_dir(sequence)?;
        let encoded = self.create_webm_with_ffmpeg(&frame_dir, output_path, &sequence.config);
        // Always clean up; an encoding error is reported in preference to a cleanup error.
        let cleanup = std::fs::remove_dir_all(&frame_dir);
        encoded?;
        cleanup?;
    }
    #[cfg(not(feature = "video"))]
    {
        // Fallback: create HTML animation
        self.create_html_animation(sequence, output_path)?;
    }
    Ok(())
}
/// Render every frame of `sequence` as `frame_NNNNNN.png` into a fresh
/// `frames_<uuid>` directory under the output directory, and return its path.
///
/// If any frame fails to render, the directory is removed (best effort) before the
/// error is returned, so failed exports leave no partial frame sets behind.
#[cfg(feature = "video")]
fn render_frames_to_temp_dir(&self, sequence: &AnimationSequence) -> Result<String> {
    let frame_dir = format!("{}/frames_{}", self.output_directory, uuid::Uuid::new_v4());
    std::fs::create_dir_all(&frame_dir)?;
    for (i, frame) in sequence.frames.iter().enumerate() {
        let frame_path = format!("{}/frame_{:06}.png", frame_dir, i);
        if let Err(err) = self.generate_frame_image(frame, &frame_path) {
            // Best effort: the render error is the one worth reporting.
            let _ = std::fs::remove_dir_all(&frame_dir);
            return Err(err.into());
        }
    }
    Ok(frame_dir)
}
/// Create training progress animation
///
/// Produces one frame per step that has both a loss value and a timestamp, so
/// extra entries in the longer slice are ignored instead of causing a panic. Frame
/// `n` plots the first `n` losses against their timestamps. It also records the
/// step, the current loss and, when `accuracy_history` has an entry for that step,
/// the current accuracy.
pub fn create_training_animation(
    &self,
    loss_history: &[f64],
    accuracy_history: &[f64],
    timestamps: &[f64],
) -> Result<AnimationSequence> {
    let steps = loss_history.len().min(timestamps.len());
    let mut frames = Vec::with_capacity(steps);
    for step in 1..=steps {
        let frame_loss = &loss_history[..step];
        let frame_timestamps = &timestamps[..step];
        let plot_data = PlotData {
            x_values: frame_timestamps.to_vec(),
            y_values: frame_loss.to_vec(),
            labels: vec!["Loss".to_string()],
            title: format!("Training Progress - Step {}", step),
            x_label: "Training Step".to_string(),
            y_label: "Loss".to_string(),
        };
        let mut metadata = HashMap::new();
        metadata.insert("step".to_string(), step.to_string());
        metadata.insert("current_loss".to_string(), format!("{:.6}", loss_history[step - 1]));
        // Accuracy may be recorded less often than loss; only report it when present.
        if let Some(accuracy) = accuracy_history.get(step - 1) {
            metadata.insert("current_accuracy".to_string(), format!("{:.4}", accuracy));
        }
        frames.push(VideoFrame {
            plot_data,
            plot_3d_data: None,
            timestamp: timestamps[step - 1],
            metadata,
        });
    }
    Ok(AnimationSequence {
        title: "Training Progress Animation".to_string(),
        frames,
        config: self.config.clone(),
        animation_type: AnimationType::TrainingProgress,
    })
}
/// Build an animation of per-layer gradient norms over training.
///
/// One frame is produced per `(gradient_norms, timestamp)` pair. If the two slices
/// differ in length, the surplus entries of the longer one are ignored.
///
/// Each frame plots index (x) against gradient norm (y), labelled with
/// `layer_names`. Its metadata holds:
/// - `step` (1-based)
/// - `total_gradient_norm`, the sum of the norms
/// - `max_gradient_norm`
pub fn create_gradient_flow_animation(
    &self,
    layer_names: &[String],
    gradient_norms_history: &[Vec<f64>],
    timestamps: &[f64],
) -> Result<AnimationSequence> {
    let mut frames = Vec::with_capacity(gradient_norms_history.len().min(timestamps.len()));
    for (i, (gradient_norms, &timestamp)) in gradient_norms_history.iter().zip(timestamps).enumerate() {
        let step = i + 1;
        // Size the x-axis from the norms actually plotted so x and y always line up,
        // even if a snapshot's length differs from `layer_names`.
        let x_values = (0..gradient_norms.len()).map(|x| x as f64).collect();
        let total: f64 = gradient_norms.iter().sum();
        // Gradient norms are expected to be non-negative, so 0.0 serves as the fold
        // identity (and is what an empty snapshot reports).
        let max = gradient_norms.iter().copied().fold(0.0_f64, f64::max);

        let mut metadata = HashMap::new();
        metadata.insert("step".to_string(), step.to_string());
        metadata.insert("total_gradient_norm".to_string(), format!("{:.6}", total));
        metadata.insert("max_gradient_norm".to_string(), format!("{:.6}", max));

        frames.push(VideoFrame {
            plot_data: PlotData {
                x_values,
                y_values: gradient_norms.clone(),
                labels: layer_names.to_vec(),
                title: format!("Gradient Flow - Step {}", step),
                x_label: "Layer Index".to_string(),
                y_label: "Gradient Norm".to_string(),
            },
            plot_3d_data: None,
            timestamp,
            metadata,
        });
    }
    Ok(AnimationSequence {
        title: "Gradient Flow Animation".to_string(),
        frames,
        config: self.config.clone(),
        animation_type: AnimationType::GradientFlow,
    })
}
/// Create optimization trajectory animation
pub fn create_optimization_trajectory_animation(
&self,
trajectory_points: &[(f64, f64, f64)], // (x, y, loss)
timestamps: &[f64],
) -> Result<AnimationSequence> {
let mut frames = Vec::new();
for i in 1..=trajectory_points.len() {
let current_trajectory = &trajectory_points[0..i];
let plot_3d_data = Plot3DData {
x_values: current_trajectory.iter().map(|p| p.0).collect(),
y_values: current_trajectory.iter().map(|p| p.1).collect(),
z_values: current_trajectory.iter().map(|p| p.2).collect(),
title: format!("Optimization Trajectory - Step {}", i),
x_label: "Parameter 1".to_string(),
y_label: "Parameter 2".to_string(),
z_label: "Loss".to_string(),
point_labels: (1..=i).map(|j| format!("Step {}", j)).collect(),
color_values: Some((1..=i).map(|j| j as f64).collect()),
size_values: None, // Use default sizes
};
// Create a 2D projection as well
let plot_data = PlotData {
x_values: current_trajectory.iter().map(|p| p.0).collect(),
y_values: current_trajectory.iter().map(|p| p.1).collect(),
labels: (1..=i).map(|j| format!("Step {}", j)).collect(),
title: format!("Optimization Trajectory (2D) - Step {}", i),
x_label: "Parameter 1".to_string(),
y_label: "Parameter 2".to_string(),
};
let mut metadata = HashMap::new();
metadata.insert("step".to_string(), i.to_string());
metadata.insert("current_loss".to_string(), format!("{:.6}", trajectory_points[i-1].2));
metadata.insert("distance_from_start".to_string(), {
let start = trajectory_points[0];
let current = trajectory_points[i-1];
let distance = ((current.0 - start.0).powi(2) + (current.1 - start.1).powi(2)).sqrt();
format!("{:.6}", distance)
});
frames.push(VideoFrame {
plot_data,
plot_3d_data: Some(plot_3d_data),
timestamp: timestamps[i-1],
metadata,
});
}
Ok(AnimationSequence {
title: "Optimization Trajectory Animation".to_string(),
frames,
config: self.config.clone(),
animation_type: AnimationType::OptimizationTrajectory,
})
}
/// Write one animation frame to disk.
///
/// With the `visual` feature, the frame is drawn to `output_path` via
/// `generate_frame_with_plotters`. Without it, a plain-text summary is written to
/// `<output_path>.txt` instead. The summary contains the step, timestamp,
/// data-point count and title.
fn generate_frame_image(&self, frame: &VideoFrame, output_path: &str) -> Result<()> {
    #[cfg(not(feature = "visual"))]
    {
        // A missing "step" entry is shown as "?".
        let step_label = frame.metadata.get("step").map(String::as_str).unwrap_or("?");
        let summary = format!(
            "Frame: {}\nTimestamp: {:.3}\nData points: {}\nTitle: {}",
            step_label,
            frame.timestamp,
            frame.plot_data.x_values.len(),
            frame.plot_data.title
        );
        let text_path = format!("{}.txt", output_path);
        std::fs::write(text_path, summary)?;
    }
    #[cfg(feature = "visual")]
    {
        self.generate_frame_with_plotters(frame, output_path)?;
    }
    Ok(())
}
/// Draw a frame's 2D `plot_data` as a blue line chart into an 800x600 bitmap at
/// `output_path`.
///
/// Axis ranges are derived from the finite data values. Empty, single-point and
/// flat series get a padded non-degenerate range, so chart construction does not
/// fail on them.
///
/// # Errors
/// Propagates plotters drawing/encoding errors.
#[cfg(feature = "visual")]
fn generate_frame_with_plotters(&self, frame: &VideoFrame, output_path: &str) -> Result<()> {
    use plotters::prelude::*;

    /// Finite `min..max` span of `values`, widened so it is never empty or inverted.
    fn axis_range(values: &[f64]) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        if lo > hi {
            // No finite data at all: fall back to a unit range.
            return 0.0..1.0;
        }
        if lo == hi {
            // A single value (or flat series): pad around it so the span is non-zero.
            let pad = if lo == 0.0 { 1.0 } else { lo.abs() * 0.1 };
            return (lo - pad)..(hi + pad);
        }
        lo..hi
    }

    let plot = &frame.plot_data;
    let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
    root.fill(&WHITE)?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&plot.title, ("sans-serif", 30))
        .margin(10)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(axis_range(&plot.x_values), axis_range(&plot.y_values))?;
    chart
        .configure_mesh()
        .x_desc(&plot.x_label)
        .y_desc(&plot.y_label)
        .draw()?;
    chart.draw_series(LineSeries::new(
        plot.x_values.iter().zip(plot.y_values.iter()).map(|(&x, &y)| (x, y)),
        &BLUE,
    ))?;
    root.present()?;
    Ok(())
}
/// Encode the PNG frames in `frame_dir` into an H.264 MP4 at `output_path` with `ffmpeg`.
///
/// Frames are read as `frame_%06d.png` at `config.fps`. Any existing output is
/// overwritten (`-y`).
///
/// `config.quality` becomes a CRF via `(100 - quality) / 4`, with quality clamped
/// to at most 100 first. Higher quality therefore gives a lower CRF.
///
/// # Errors
/// Fails if `ffmpeg` cannot be started, e.g. it is not on `PATH`. Also fails if
/// ffmpeg exits unsuccessfully; in that case its stderr is included.
#[cfg(feature = "video")]
fn create_mp4_with_ffmpeg(&self, frame_dir: &str, output_path: &str, config: &VideoConfig) -> Result<()> {
    use std::process::Command;
    // Clamp before subtracting so an out-of-range quality cannot underflow.
    let crf = ((100 - config.quality.min(100)) / 4).to_string();
    let framerate = config.fps.to_string();
    let input_pattern = format!("{}/frame_%06d.png", frame_dir);
    let output = Command::new("ffmpeg")
        .args(&[
            "-y", // Overwrite output file
            "-framerate", &framerate,
            "-i", &input_pattern,
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            "-crf", &crf,
            output_path,
        ])
        .output()
        .map_err(|e| anyhow::anyhow!("failed to launch ffmpeg (is it installed and on PATH?): {}", e))?;
    if !output.status.success() {
        return Err(anyhow::anyhow!("ffmpeg failed: {}", String::from_utf8_lossy(&output.stderr)));
    }
    Ok(())
}
/// Encode the PNG frames in `frame_dir` into a VP9 WebM at `output_path` with `ffmpeg`.
///
/// Frames are read as `frame_%06d.png` at `config.fps`. Any existing output is
/// overwritten (`-y`).
///
/// `config.quality` becomes a CRF via `(100 - quality) / 3`, with quality clamped
/// to at most 100 first. `-b:v 0` selects constant-quality mode.
///
/// # Errors
/// Fails if `ffmpeg` cannot be started, e.g. it is not on `PATH`. Also fails if
/// ffmpeg exits unsuccessfully; in that case its stderr is included.
#[cfg(feature = "video")]
fn create_webm_with_ffmpeg(&self, frame_dir: &str, output_path: &str, config: &VideoConfig) -> Result<()> {
    use std::process::Command;
    // Clamp before subtracting so an out-of-range quality cannot underflow.
    let crf = ((100 - config.quality.min(100)) / 3).to_string();
    let framerate = config.fps.to_string();
    let input_pattern = format!("{}/frame_%06d.png", frame_dir);
    let output = Command::new("ffmpeg")
        .args(&[
            "-y", // Overwrite output file
            "-framerate", &framerate,
            "-i", &input_pattern,
            "-c:v", "libvpx-vp9",
            "-pix_fmt", "yuv420p",
            "-crf", &crf,
            "-b:v", "0", // Use constant quality mode
            output_path,
        ])
        .output()
        .map_err(|e| anyhow::anyhow!("failed to launch ffmpeg (is it installed and on PATH?): {}", e))?;
    if !output.status.success() {
        return Err(anyhow::anyhow!("ffmpeg failed: {}", String::from_utf8_lossy(&output.stderr)));
    }
    Ok(())
}
/// Placeholder GIF encoder used when the `gif` feature is enabled.
///
/// No GIF encoding is implemented yet. The call is forwarded to
/// `create_html_animation`, so `output_path` receives an HTML page rather than
/// GIF bytes.
#[cfg(feature = "gif")]
fn create_gif_animation(&self, sequence: &AnimationSequence, output_path: &str) -> Result<()> {
    Self::create_html_animation(self, sequence, output_path)
}
/// Write a self-contained HTML page to `output_path` that replays `sequence` with
/// Plotly. This is the fallback whenever no native video/GIF encoder is used.
///
/// The page loads Plotly from the public CDN, so viewing it needs network access.
///
/// Playback:
/// - Each frame's `plot_data` is drawn as a 2D line+marker trace; `plot_3d_data`
///   is serialized but not rendered.
/// - Frames advance every `config.frame_duration * 1000` ms.
/// - With `config.loop_animation` false, playback stops on the last frame.
///
/// Empty sequences produce a valid page that simply shows no plot.
///
/// # Errors
/// Fails if the frames cannot be serialized to JSON or the file cannot be written.
fn create_html_animation(&self, sequence: &AnimationSequence, output_path: &str) -> Result<()> {
    /// Minimal escaping for text interpolated into HTML element content.
    fn escape_html(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#39;"),
                _ => out.push(ch),
            }
        }
        out
    }

    // A literal `</` inside the inline <script> could close it early (e.g. a label
    // containing "</script>"). `<\/` is an equivalent JSON string escape.
    let frames_json = serde_json::to_string(&sequence.frames)?.replace("</", "<\\/");
    let frame_count = sequence.frames.len();
    let frame_info = if frame_count == 0 {
        "No frames".to_string()
    } else {
        format!("Frame 1 of {}", frame_count)
    };
    let html_content = format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{title}</title>
    <script src="https://cdn.plot.ly/plotly-2.26.0.min.js"></script>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 20px;
            background-color: #f5f5f5;
        }}
        .animation-container {{
            background: white;
            border-radius: 8px;
            padding: 20px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}
        .controls {{
            margin: 20px 0;
            text-align: center;
        }}
        button {{
            margin: 0 5px;
            padding: 10px 20px;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
        }}
        button:hover {{
            background: #0056b3;
        }}
        .frame-info {{
            margin: 10px 0;
            font-size: 14px;
            color: #666;
        }}
    </style>
</head>
<body>
    <div class="animation-container">
        <h1>{title}</h1>
        <div id="plot"></div>
        <div class="controls">
            <button onclick="playAnimation()">Play</button>
            <button onclick="pauseAnimation()">Pause</button>
            <button onclick="resetAnimation()">Reset</button>
            <input type="range" id="frameSlider" min="0" max="{last_index}" value="0" oninput="showFrame(this.value)">
        </div>
        <div class="frame-info">
            <span id="frameInfo">{frame_info}</span>
        </div>
    </div>
    <script>
        const frames = {frames_json};
        const loopAnimation = {loop_animation};
        let currentFrame = 0;
        let isPlaying = false;
        let animationInterval;
        function plotFrame(frameIndex) {{
            if (!(frameIndex >= 0 && frameIndex < frames.length)) return;
            const frame = frames[frameIndex];
            const trace = {{
                x: frame.plot_data.x_values,
                y: frame.plot_data.y_values,
                type: 'scatter',
                mode: 'lines+markers',
                name: 'Data',
                line: {{ color: '#007bff', width: 2 }},
                marker: {{ size: 6 }}
            }};
            const layout = {{
                title: frame.plot_data.title,
                xaxis: {{ title: frame.plot_data.x_label }},
                yaxis: {{ title: frame.plot_data.y_label }},
                showlegend: false,
                margin: {{ t: 50, l: 60, r: 30, b: 60 }}
            }};
            Plotly.newPlot('plot', [trace], layout, {{ responsive: true }});
            document.getElementById('frameInfo').textContent =
                `Frame ${{frameIndex + 1}} of ${{frames.length}} - Step: ${{frame.metadata.step || 'N/A'}}`;
            document.getElementById('frameSlider').value = frameIndex;
        }}
        function playAnimation() {{
            if (isPlaying || frames.length === 0) return;
            // A finished non-looping run restarts from the first frame.
            if (!loopAnimation && currentFrame >= frames.length - 1) {{
                currentFrame = 0;
                plotFrame(currentFrame);
            }}
            isPlaying = true;
            animationInterval = setInterval(() => {{
                if (currentFrame >= frames.length - 1) {{
                    if (!loopAnimation) {{
                        // Stay on the last frame instead of wrapping back to the first.
                        pauseAnimation();
                        return;
                    }}
                    currentFrame = 0;
                }} else {{
                    currentFrame += 1;
                }}
                plotFrame(currentFrame);
            }}, {interval_ms});
        }}
        function pauseAnimation() {{
            isPlaying = false;
            if (animationInterval) {{
                clearInterval(animationInterval);
            }}
        }}
        function resetAnimation() {{
            pauseAnimation();
            currentFrame = 0;
            plotFrame(currentFrame);
        }}
        function showFrame(frameIndex) {{
            currentFrame = parseInt(frameIndex, 10);
            plotFrame(currentFrame);
        }}
        if (frames.length > 0) {{
            plotFrame(0);
        }}
    </script>
</body>
</html>"#,
        title = escape_html(&sequence.title),
        // saturating_sub: an empty sequence must not underflow the slider maximum.
        last_index = frame_count.saturating_sub(1),
        frame_info = frame_info,
        frames_json = frames_json,
        loop_animation = sequence.config.loop_animation.to_string().to_lowercase(),
        interval_ms = (sequence.config.frame_duration * 1000.0) as u32,
    );
    std::fs::write(output_path, html_content)?;
    Ok(())
}
/// Create `self.output_directory`, including any missing parent directories.
/// This is a no-op when it already exists.
fn ensure_output_directory(&self) -> Result<()> {
    std::fs::create_dir_all(&self.output_directory).map_err(Into::into)
}
}
impl Default for VideoGenerator {
    /// Build a generator from `VideoConfig::default()` and the path `./debug_videos`.
    /// That path is presumably used as the output directory; confirm in `new`.
    fn default() -> Self {
        let config = VideoConfig::default();
        let directory = String::from("./debug_videos");
        Self::new(config, directory)
    }
}