use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Basic plot categories.
///
/// A fieldless enum: values can be cloned, compared with `==`, and
/// (de)serialized through the serde derives.
///
/// NOTE(review): `Histogram` and `Heatmap` also appear in
/// [`VisualizationType`]; how the two enums relate is not established in this
/// file — confirm before mapping between them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlotType {
    /// Line plot.
    Line,
    /// Scatter plot.
    Scatter,
    /// Bar chart.
    Bar,
    /// Histogram.
    Histogram,
    /// Heatmap.
    Heatmap,
    /// Three-dimensional plot.
    ThreeDimensional,
}
/// Settings bundle for producing visualizations.
///
/// The `Default` implementation in this file supplies `./debug_plots`,
/// `ImageFormat::PNG`, an 800 by 600 plot, font size 12 and
/// `ColorScheme::Default`.
///
/// NOTE(review): the units of the numeric fields are not stated anywhere in
/// this file (presumably pixels for the plot dimensions and points or pixels
/// for the font) — confirm against the code that consumes this config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Output directory path, held as a plain `String`.
    pub output_directory: String,
    /// Selected output format; see [`ImageFormat`].
    pub image_format: ImageFormat,
    /// Plot width (unit not specified here; presumably pixels).
    pub plot_width: u32,
    /// Plot height (unit not specified here; presumably pixels).
    pub plot_height: u32,
    /// Font size (unit not specified here).
    pub font_size: u32,
    /// Selected color scheme; see [`ColorScheme`].
    pub color_scheme: ColorScheme,
}
impl Default for VisualizationConfig {
fn default() -> Self {
Self {
output_directory: "./debug_plots".to_string(),
image_format: ImageFormat::PNG,
plot_width: 800,
plot_height: 600,
font_size: 12,
color_scheme: ColorScheme::Default,
}
}
}
/// Output format selector, used by [`VisualizationConfig::image_format`].
///
/// Despite the name, the variants are not limited to still images: the list
/// also contains markup/document (`HTML`, `LaTeX`), data (`JSON`) and
/// animation/video (`MP4`, `GIF`, `WebM`) targets.
///
/// Derives `PartialEq`/`Eq`, matching [`PlotType`], so formats can be compared
/// with `==` (for example when checking a configuration's `image_format`).
///
/// NOTE(review): which of these formats a renderer actually supports is not
/// visible in this file — confirm against the export code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageFormat {
    /// PNG output.
    PNG,
    /// SVG output.
    SVG,
    /// PDF output.
    PDF,
    /// HTML output.
    HTML,
    /// LaTeX output.
    LaTeX,
    /// JSON output.
    JSON,
    /// MP4 output.
    MP4,
    /// GIF output.
    GIF,
    /// WebM output.
    WebM,
}
/// Color scheme selector, used by [`VisualizationConfig::color_scheme`].
///
/// Derives `PartialEq`/`Eq`, matching [`PlotType`], so schemes can be compared
/// with `==`.
///
/// Note that [`ColorScheme::Default`] is an ordinary variant; this enum does
/// not itself implement the `Default` trait in this file.
///
/// NOTE(review): the concrete palettes behind each variant are defined
/// elsewhere (`Viridis` and `Plasma` presumably refer to the well-known
/// colormaps of the same names) — confirm against the rendering code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ColorScheme {
    /// The scheme chosen by `VisualizationConfig::default()`.
    Default,
    /// Dark scheme.
    Dark,
    /// Colorblind scheme (presumably a colorblind-friendly palette).
    Colorblind,
    /// Viridis scheme.
    Viridis,
    /// Plasma scheme.
    Plasma,
}
/// Kinds of visualization, ranging from basic charts to model-analysis views.
///
/// Derives `PartialEq`/`Eq`, matching [`PlotType`], so values can be compared
/// with `==`.
///
/// NOTE(review): this enum overlaps with [`PlotType`] (`Histogram` and
/// `Heatmap` exist in both); the relationship between the two is not
/// established in this file — confirm before converting between them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum VisualizationType {
    /// Line plot.
    LinePlot,
    /// Histogram.
    Histogram,
    /// Heatmap.
    Heatmap,
    /// Scatter plot.
    ScatterPlot,
    /// Box plot.
    BoxPlot,
    /// Surface plot.
    SurfacePlot,
    /// Loss landscape view.
    LossLandscape,
    /// Optimization trajectory view.
    OptimizationTrajectory,
    /// Weight-space exploration view.
    WeightSpaceExploration,
    /// Embedding projection view.
    EmbeddingProjection,
    /// Architecture diagram; see [`ArchitectureNode`] and
    /// [`ArchitectureConnection`] for the diagram's data types.
    ArchitectureDiagram,
}
/// Input data and text labels for a two-dimensional plot.
///
/// NOTE(review): nothing in this type enforces that `x_values` and `y_values`
/// have the same length; presumably they are paired element by element —
/// confirm with the code that consumes this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlotData {
    /// X coordinates.
    pub x_values: Vec<f64>,
    /// Y coordinates.
    pub y_values: Vec<f64>,
    /// Text labels. The in-file test supplies a single label for three
    /// points, so these are presumably per-series rather than per-point —
    /// TODO confirm.
    pub labels: Vec<String>,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
}
/// Input data and text labels for a three-dimensional plot.
///
/// NOTE(review): nothing in this type enforces that the coordinate vectors
/// (or the optional per-value vectors) share a length; presumably index `i`
/// of each vector describes the same point — confirm with consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plot3DData {
    /// X coordinates.
    pub x_values: Vec<f64>,
    /// Y coordinates.
    pub y_values: Vec<f64>,
    /// Z coordinates.
    pub z_values: Vec<f64>,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
    /// Z-axis label.
    pub z_label: String,
    /// Labels for points (presumably one per point, though the in-file test
    /// supplies fewer labels than points — TODO confirm).
    pub point_labels: Vec<String>,
    /// Optional numeric values for coloring; `None` when not supplied.
    pub color_values: Option<Vec<f64>>,
    /// Optional numeric values for sizing; `None` when not supplied.
    pub size_values: Option<Vec<f64>>,
}
/// A node in an architecture diagram.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureNode {
    /// Node identifier (presumably what `ArchitectureConnection::from_node` /
    /// `to_node` refer to — TODO confirm).
    pub id: String,
    /// Display name.
    pub name: String,
    /// Free-form node category string (the in-file test uses `"layer"`).
    pub node_type: String,
    /// Position as a triple of `f64` (presumably `(x, y, z)` — TODO confirm).
    pub position: (f64, f64, f64),
    /// Size as a triple of `f64` (presumably extents along the same axes as
    /// `position` — TODO confirm).
    pub size: (f64, f64, f64),
    /// Color as a string; the in-file test uses a `#rrggbb` hex value, but
    /// no format is enforced here.
    pub color: String,
    /// Arbitrary string key/value metadata.
    pub metadata: HashMap<String, String>,
}
/// A connection between two nodes of an architecture diagram.
///
/// NOTE(review): `from_node` and `to_node` are plain strings; presumably they
/// hold `ArchitectureNode::id` values, but nothing in this file validates
/// that — confirm with the code that builds diagrams.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureConnection {
    /// Source node reference.
    pub from_node: String,
    /// Destination node reference.
    pub to_node: String,
    /// Free-form connection category string (the in-file test uses
    /// `"forward"`).
    pub connection_type: String,
    /// Numeric weight of the connection (meaning not specified here).
    pub weight: f64,
    /// Color as a string; the in-file test uses a `#rrggbb` hex value, but
    /// no format is enforced here.
    pub color: String,
    /// Drawing style; see [`ConnectionStyle`].
    pub style: ConnectionStyle,
}
/// Drawing style selector, used by [`ArchitectureConnection::style`].
///
/// Derives `PartialEq`/`Eq`, matching [`PlotType`], so styles can be compared
/// with `==`.
///
/// NOTE(review): how each style is actually rendered is decided elsewhere —
/// confirm against the drawing code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConnectionStyle {
    /// Solid style.
    Solid,
    /// Dashed style.
    Dashed,
    /// Dotted style.
    Dotted,
    /// Thick style.
    Thick,
    /// Arrow style.
    Arrow,
}
/// Input data and text labels for a heatmap.
///
/// NOTE(review): the in-file test uses a 2x2 `values` grid with two
/// `x_labels` and two `y_labels`, so presumably the outer `Vec` holds rows
/// (one per `y_labels` entry) and each inner `Vec` holds one value per
/// `x_labels` entry. Nothing here enforces that shape — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeatmapData {
    /// Grid of cell values as nested vectors.
    pub values: Vec<Vec<f64>>,
    /// Labels along the x axis.
    pub x_labels: Vec<String>,
    /// Labels along the y axis.
    pub y_labels: Vec<String>,
    /// Plot title.
    pub title: String,
    /// Label for the color bar.
    pub color_bar_label: String,
}
/// Input data and settings for a histogram.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramData {
    /// Raw sample values.
    pub values: Vec<f64>,
    /// Bin setting (presumably the number of bins — TODO confirm; a value of
    /// zero is not rejected by this type).
    pub bins: usize,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
    /// Density flag (presumably selects a normalized density instead of raw
    /// counts — TODO confirm with the plotting code).
    pub density: bool,
}
#[cfg(test)]
mod tests {
    //! Construction and smoke tests for the visualization data types.
    //!
    //! Only `PlotType` is compared with `==` here; the other enums are checked
    //! with `matches!` or through their `Debug` output, so these tests do not
    //! require them to implement `PartialEq`.

    use super::*;
    use std::collections::HashMap;

    /// Formats every item with `Debug` so that enum variants can be checked by
    /// name without needing `PartialEq` on the enum.
    fn debug_names<T: std::fmt::Debug>(items: &[T]) -> Vec<String> {
        items.iter().map(|item| format!("{:?}", item)).collect()
    }

    #[test]
    fn test_visualization_config_default() {
        let cfg = VisualizationConfig::default();
        assert_eq!(cfg.output_directory, "./debug_plots");
        assert!(matches!(cfg.image_format, ImageFormat::PNG));
        assert_eq!(cfg.plot_width, 800);
        assert_eq!(cfg.plot_height, 600);
        assert_eq!(cfg.font_size, 12);
        assert!(matches!(cfg.color_scheme, ColorScheme::Default));
    }

    #[test]
    fn test_visualization_config_clone() {
        let cfg = VisualizationConfig::default();
        let cloned = cfg.clone();
        assert_eq!(cloned.plot_width, cfg.plot_width);
        // The struct has no `PartialEq`; comparing the `Debug` renderings
        // checks every field of the clone at once.
        assert_eq!(format!("{:?}", cloned), format!("{:?}", cfg));
    }

    #[test]
    fn test_plot_type_equality() {
        assert_eq!(PlotType::Line, PlotType::Line);
        assert_ne!(PlotType::Line, PlotType::Bar);
        assert_eq!(PlotType::Heatmap.clone(), PlotType::Heatmap);
    }

    #[test]
    fn test_plot_type_all_variants() {
        let variants = [
            PlotType::Line,
            PlotType::Scatter,
            PlotType::Bar,
            PlotType::Histogram,
            PlotType::Heatmap,
            PlotType::ThreeDimensional,
        ];
        // Each variant equals itself and nothing else.
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                assert_eq!(a == b, i == j);
            }
        }
    }

    #[test]
    fn test_plot_data_creation() {
        let data = PlotData {
            x_values: vec![1.0, 2.0, 3.0],
            y_values: vec![10.0, 20.0, 30.0],
            labels: vec!["s1".to_string()],
            title: "Test".to_string(),
            x_label: "X".to_string(),
            y_label: "Y".to_string(),
        };
        assert_eq!(data.x_values.len(), 3);
        assert_eq!(data.y_values.len(), 3);
        assert_eq!(data.labels, ["s1"]);
        assert_eq!(data.title, "Test");
    }

    #[test]
    fn test_heatmap_data_creation() {
        let hm = HeatmapData {
            values: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            x_labels: vec!["a".to_string(), "b".to_string()],
            y_labels: vec!["r1".to_string(), "r2".to_string()],
            title: "HM".to_string(),
            color_bar_label: "val".to_string(),
        };
        assert_eq!(hm.values.len(), 2);
        assert_eq!(hm.x_labels.len(), 2);
        assert_eq!(hm.y_labels.len(), 2);
        assert_eq!(hm.values[1], [3.0, 4.0]);
    }

    #[test]
    fn test_histogram_data_creation() {
        let hd = HistogramData {
            values: vec![1.0, 2.0, 3.0],
            bins: 10,
            title: "Dist".to_string(),
            x_label: "Value".to_string(),
            y_label: "Count".to_string(),
            density: false,
        };
        assert_eq!(hd.bins, 10);
        assert_eq!(hd.values.len(), 3);
        assert!(!hd.density);
    }

    #[test]
    fn test_architecture_node_creation() {
        let node = ArchitectureNode {
            id: "n1".to_string(),
            name: "Linear".to_string(),
            node_type: "layer".to_string(),
            position: (0.0, 0.0, 0.0),
            size: (1.0, 1.0, 1.0),
            color: "#ff0000".to_string(),
            metadata: HashMap::new(),
        };
        assert_eq!(node.id, "n1");
        assert_eq!(node.name, "Linear");
        assert_eq!(node.size, (1.0, 1.0, 1.0));
        assert!(node.metadata.is_empty());
    }

    #[test]
    fn test_architecture_connection_creation() {
        let conn = ArchitectureConnection {
            from_node: "n1".to_string(),
            to_node: "n2".to_string(),
            connection_type: "forward".to_string(),
            weight: 1.0,
            color: "#00ff00".to_string(),
            style: ConnectionStyle::Solid,
        };
        assert_eq!(conn.from_node, "n1");
        assert_eq!(conn.to_node, "n2");
        assert!(matches!(conn.style, ConnectionStyle::Solid));
    }

    #[test]
    fn test_connection_style_variants() {
        let names = debug_names(&[
            ConnectionStyle::Solid,
            ConnectionStyle::Dashed,
            ConnectionStyle::Dotted,
            ConnectionStyle::Thick,
            ConnectionStyle::Arrow,
        ]);
        assert_eq!(names, ["Solid", "Dashed", "Dotted", "Thick", "Arrow"]);
    }

    #[test]
    fn test_image_format_variants() {
        let names = debug_names(&[
            ImageFormat::PNG,
            ImageFormat::SVG,
            ImageFormat::PDF,
            ImageFormat::HTML,
            ImageFormat::LaTeX,
            ImageFormat::JSON,
            ImageFormat::MP4,
            ImageFormat::GIF,
            ImageFormat::WebM,
        ]);
        assert_eq!(
            names,
            ["PNG", "SVG", "PDF", "HTML", "LaTeX", "JSON", "MP4", "GIF", "WebM"]
        );
    }

    #[test]
    fn test_color_scheme_variants() {
        let names = debug_names(&[
            ColorScheme::Default,
            ColorScheme::Dark,
            ColorScheme::Colorblind,
            ColorScheme::Viridis,
            ColorScheme::Plasma,
        ]);
        assert_eq!(
            names,
            ["Default", "Dark", "Colorblind", "Viridis", "Plasma"]
        );
    }

    #[test]
    fn test_visualization_type_variants() {
        let names = debug_names(&[
            VisualizationType::LinePlot,
            VisualizationType::Histogram,
            VisualizationType::Heatmap,
            VisualizationType::ScatterPlot,
            VisualizationType::BoxPlot,
            VisualizationType::SurfacePlot,
            VisualizationType::LossLandscape,
            VisualizationType::OptimizationTrajectory,
            VisualizationType::WeightSpaceExploration,
            VisualizationType::EmbeddingProjection,
            VisualizationType::ArchitectureDiagram,
        ]);
        assert_eq!(
            names,
            [
                "LinePlot",
                "Histogram",
                "Heatmap",
                "ScatterPlot",
                "BoxPlot",
                "SurfacePlot",
                "LossLandscape",
                "OptimizationTrajectory",
                "WeightSpaceExploration",
                "EmbeddingProjection",
                "ArchitectureDiagram",
            ]
        );
    }

    #[test]
    fn test_plot_3d_data_creation() {
        let data = Plot3DData {
            x_values: vec![0.0, 1.0],
            y_values: vec![0.0, 1.0],
            z_values: vec![0.0, 2.0],
            title: "3D".to_string(),
            x_label: "X".to_string(),
            y_label: "Y".to_string(),
            z_label: "Z".to_string(),
            point_labels: vec!["p1".to_string()],
            color_values: None,
            size_values: None,
        };
        assert_eq!(data.x_values.len(), 2);
        assert_eq!(data.y_values.len(), 2);
        assert_eq!(data.z_values.len(), 2);
        assert!(data.color_values.is_none());
        assert!(data.size_values.is_none());
    }
}