use crate::error::RusTorchResult; use std::fs;
use std::path::Path;
/// Output formats a plot can be saved as or converted between.
///
/// The enum carries no data, so it is `Copy` and can be used as a key in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotFormat {
    /// Scalable Vector Graphics (`image/svg+xml`).
    Svg,
    /// PNG image (`image/png`).
    Png,
    /// PDF document (`application/pdf`).
    Pdf,
    /// HTML page (`text/html`).
    Html,
    /// Graphviz DOT description (`text/vnd.graphviz`).
    Dot,
}

impl PlotFormat {
    /// Returns the file extension for this format, without a leading dot.
    pub fn extension(&self) -> &'static str {
        match self {
            PlotFormat::Svg => "svg",
            PlotFormat::Png => "png",
            PlotFormat::Pdf => "pdf",
            PlotFormat::Html => "html",
            PlotFormat::Dot => "dot",
        }
    }

    /// Returns the MIME type string associated with this format.
    pub fn mime_type(&self) -> &'static str {
        match self {
            PlotFormat::Svg => "image/svg+xml",
            PlotFormat::Png => "image/png",
            PlotFormat::Pdf => "application/pdf",
            PlotFormat::Html => "text/html",
            PlotFormat::Dot => "text/vnd.graphviz",
        }
    }
}
/// Writes plot `content` to `path`, creating missing parent directories first.
///
/// For [`PlotFormat::Html`] the content is passed through [`wrap_in_html`]
/// before writing; for every other format it is written exactly as given.
///
/// # Errors
///
/// Propagates any error from creating the parent directory, from
/// [`wrap_in_html`], or from writing the file.
pub fn save_plot<P: AsRef<Path>>(content: &str, path: P, format: PlotFormat) -> RusTorchResult<()> {
    let target = path.as_ref();

    // Make sure the destination directory exists before attempting the write.
    if let Some(directory) = target.parent() {
        fs::create_dir_all(directory)?;
    }

    if format == PlotFormat::Html {
        let document = wrap_in_html(content, format)?;
        fs::write(target, document)?;
    } else {
        fs::write(target, content)?;
    }

    Ok(())
}
pub fn wrap_in_html(svg_content: &str, format: PlotFormat) -> RusTorchResult<String> {
match format {
PlotFormat::Html => Ok(format!(
r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>RusTorch Visualization</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 20px;
background-color: #f5f5f5;
}}
.container {{
background-color: white;
padding: 20px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
text-align: center;
}}
h1 {{
color: #333;
margin-bottom: 20px;
}}
.plot-container {{
display: inline-block;
margin: 10px;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
background-color: #fafafa;
}}
</style>
</head>
<body>
<div class="container">
<h1>RusTorch Visualization</h1>
<div class="plot-container">
{}
</div>
<p style="color: #666; font-size: 12px; margin-top: 20px;">
Generated by RusTorch v0.3.5
</p>
</div>
</body>
</html>"#,
svg_content
)),
_ => Err(crate::error::RusTorchError::visualization(
"Invalid format for HTML wrapping",
)),
}
}
/// Builds a single HTML dashboard page from several titled plots.
///
/// Each entry of `plots` is a `(title, svg_content)` pair. Entries are rendered
/// in order as numbered sections (starting at 1), and the footer reports the
/// total number of plots. Titles and content are inserted verbatim.
///
/// The visible code path always returns `Ok`.
pub fn create_dashboard(plots: Vec<(&str, &str)>) -> RusTorchResult<String> {
    let plot_count = plots.len();

    // One `<div class="plot-section">` per plot, concatenated in input order.
    let sections: String = plots
        .iter()
        .enumerate()
        .map(|(position, (heading, svg))| {
            format!(
                r#"
<div class="plot-section">
<h2>{}. {}</h2>
<div class="plot-container">
{}
</div>
</div>"#,
                position + 1,
                heading,
                svg
            )
        })
        .collect();

    // Braces belonging to the CSS are doubled because this is a format string.
    let page = format!(
        r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>RusTorch Visualization Dashboard</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}}
.header {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
border-radius: 8px;
margin-bottom: 30px;
text-align: center;
}}
.header h1 {{
margin: 0;
font-size: 2.5em;
font-weight: 300;
}}
.header p {{
margin: 10px 0 0 0;
opacity: 0.9;
}}
.dashboard {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(400px, 1fr));
gap: 20px;
}}
.plot-section {{
background-color: white;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
overflow: hidden;
}}
.plot-section h2 {{
background-color: #f8f9fa;
margin: 0;
padding: 15px 20px;
color: #333;
border-bottom: 1px solid #e9ecef;
font-size: 1.2em;
}}
.plot-container {{
padding: 20px;
text-align: center;
}}
.footer {{
text-align: center;
margin-top: 40px;
padding: 20px;
color: #666;
font-size: 14px;
background-color: white;
border-radius: 8px;
}}
</style>
</head>
<body>
<div class="header">
<h1>RusTorch Visualization Dashboard</h1>
<p>Comprehensive visualization of training progress and model analysis</p>
</div>
<div class="dashboard">
{}
</div>
<div class="footer">
Generated by RusTorch v0.3.5 • {} plots
</div>
</body>
</html>"#,
        sections, plot_count
    );

    Ok(page)
}
/// Namespace for the predefined color palettes (hex strings such as `#1f77b4`).
pub struct ColorPalette;

impl ColorPalette {
    /// Ten distinct colors for categorical data.
    const CATEGORICAL: [&'static str; 10] = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
        "#bcbd22", "#17becf",
    ];

    /// Nine blues ordered from lightest to darkest.
    const SEQUENTIAL_BLUES: [&'static str; 9] = [
        "#f7fbff", "#deebf7", "#c6dbef", "#9ecae1", "#6baed6", "#4292c6", "#2171b5", "#08519c",
        "#08306b",
    ];

    /// Eleven colors running from dark red through near-white to dark blue.
    const DIVERGING_RED_BLUE: [&'static str; 11] = [
        "#67001f", "#b2182b", "#d6604d", "#f4a582", "#fddbc7", "#f7f7f7", "#d1e5f0", "#92c5de",
        "#4393c3", "#2166ac", "#053061",
    ];

    /// Copies a static palette into the owned form the public API returns.
    fn to_owned_colors(colors: &[&str]) -> Vec<String> {
        colors.iter().map(|&color| color.to_owned()).collect()
    }

    /// Returns the ten-color categorical palette.
    pub fn categorical() -> Vec<String> {
        Self::to_owned_colors(&Self::CATEGORICAL)
    }

    /// Returns the nine-step sequential blue palette, lightest first.
    pub fn sequential_blues() -> Vec<String> {
        Self::to_owned_colors(&Self::SEQUENTIAL_BLUES)
    }

    /// Returns the eleven-step diverging red-to-blue palette.
    pub fn diverging_red_blue() -> Vec<String> {
        Self::to_owned_colors(&Self::DIVERGING_RED_BLUE)
    }

    /// Returns the categorical color for `index`, wrapping around when the
    /// index exceeds the palette size.
    pub fn get_categorical_color(index: usize) -> String {
        // Index the static table directly instead of allocating the whole palette.
        Self::CATEGORICAL[index % Self::CATEGORICAL.len()].to_owned()
    }

    /// Returns the sequential blue nearest to `value`.
    ///
    /// `value` is clamped to `[0.0, 1.0]`, scaled onto the palette and rounded
    /// to the nearest entry. A NaN input stays NaN through the clamp and the
    /// saturating float-to-integer cast turns it into index 0 (the lightest
    /// color).
    pub fn get_sequential_color(value: f32) -> String {
        let last = Self::SEQUENTIAL_BLUES.len() - 1;
        let clamped = value.clamp(0.0, 1.0);
        // `clamped` is within [0, 1] (or NaN -> 0), so the index never exceeds `last`.
        let index = (clamped * last as f32).round() as usize;
        Self::SEQUENTIAL_BLUES[index].to_owned()
    }
}
pub fn export_format(
content: &str,
from_format: PlotFormat,
to_format: PlotFormat,
) -> RusTorchResult<String> {
match (&from_format, &to_format) {
(PlotFormat::Svg, PlotFormat::Html) => wrap_in_html(content, PlotFormat::Html),
(PlotFormat::Dot, PlotFormat::Svg) => {
Err(crate::error::RusTorchError::visualization(
"DOT to SVG conversion requires Graphviz",
))
}
(from, to) if from == to => Ok(content.to_string()),
_ => Err(crate::error::RusTorchError::visualization(format!(
"Conversion from {:?} to {:?} is not supported",
from_format, to_format
))),
}
}
pub fn resize_svg(svg_content: &str, new_width: u32, new_height: u32) -> RusTorchResult<String> {
if let Some(start) = svg_content.find("<svg") {
if let Some(end) = svg_content[start..].find(">") {
let svg_tag_end = start + end + 1;
let new_svg_tag = format!(
r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
new_width, new_height
);
let mut result = String::new();
result.push_str(&svg_content[..start]);
result.push_str(&new_svg_tag);
result.push_str(&svg_content[svg_tag_end..]);
Ok(result)
} else {
Err(crate::error::RusTorchError::plotting_error(
"Invalid SVG format",
))
}
} else {
Err(crate::error::RusTorchError::plotting_error(
"No SVG tag found",
))
}
}
/// Summary figures describing one generated plot.
#[derive(Debug, Clone)]
pub struct PlotStatistics {
    /// Number of `rect`, `circle`, `ellipse`, `line`, `path`, `text` and
    /// `polygon` opening tags found in the content.
    pub total_elements: usize,
    /// Length of the content in bytes.
    pub file_size_bytes: usize,
    /// Generation time in milliseconds, as supplied by the caller.
    pub generation_time_ms: u64,
}

impl PlotStatistics {
    /// Computes statistics for `content`, recording the caller-supplied
    /// `generation_time_ms` unchanged.
    pub fn new(content: &str, generation_time_ms: u64) -> Self {
        Self {
            total_elements: Self::count_svg_elements(content),
            file_size_bytes: content.len(),
            generation_time_ms,
        }
    }

    /// Counts opening tags whose name is exactly one of the tracked elements.
    ///
    /// The tag name is compared in full, so `<linearGradient>` is not counted
    /// as a `line` and `<textPath>` is not counted as a `text`. Closing tags
    /// (`</rect>`) are never counted.
    fn count_svg_elements(content: &str) -> usize {
        const ELEMENTS: [&str; 7] = [
            "rect", "circle", "ellipse", "line", "path", "text", "polygon",
        ];

        content
            .split('<')
            // The first piece is whatever precedes the first `<`, not a tag.
            .skip(1)
            .filter(|fragment| {
                // The tag name ends at the first whitespace, `>` or `/`.
                let name_len = fragment
                    .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
                    .unwrap_or(fragment.len());
                let tag_name = &fragment[..name_len];
                ELEMENTS.iter().any(|&element| element == tag_name)
            })
            .count()
    }

    /// Formats the statistics as a single human-readable line.
    pub fn format(&self) -> String {
        format!(
            "Elements: {}, Size: {} bytes, Generation: {}ms",
            self.total_elements, self.file_size_bytes, self.generation_time_ms
        )
    }
}
/// Builds a file name of the form `base_name[_<unix-seconds>].<extension>`.
///
/// When `timestamp` is true, the current time in whole seconds since the Unix
/// epoch is appended after an underscore. The extension comes from
/// [`PlotFormat::extension`].
///
/// If the system clock reports a time before the Unix epoch, the timestamp
/// falls back to `0` instead of panicking.
pub fn generate_filename(base_name: &str, format: PlotFormat, timestamp: bool) -> String {
    let extension = format.extension();

    if timestamp {
        // `duration_since` fails when the clock is earlier than the epoch.
        let seconds = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());
        format!("{}_{}.{}", base_name, seconds, extension)
    } else {
        format!("{}.{}", base_name, extension)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Minimal SVG used by the tests that only need some well-formed content.
    const SAMPLE_SVG: &str = r#"<svg><rect x="0" y="0" width="100" height="100"/></svg>"#;

    /// Each checked format maps to its expected file extension.
    #[test]
    fn test_plot_format_extension() {
        let expected = [
            (PlotFormat::Svg, "svg"),
            (PlotFormat::Png, "png"),
            (PlotFormat::Html, "html"),
            (PlotFormat::Dot, "dot"),
        ];
        for (format, extension) in expected.iter() {
            assert_eq!(format.extension(), *extension);
        }
    }

    /// Each checked format maps to its expected MIME type.
    #[test]
    fn test_plot_format_mime_type() {
        let expected = [
            (PlotFormat::Svg, "image/svg+xml"),
            (PlotFormat::Html, "text/html"),
        ];
        for (format, mime) in expected.iter() {
            assert_eq!(format.mime_type(), *mime);
        }
    }

    /// The categorical palette holds ten hex colors.
    #[test]
    fn test_color_palette_categorical() {
        let palette = ColorPalette::categorical();
        assert!(!palette.is_empty());
        assert!(palette[0].starts_with('#'));
        assert_eq!(palette.len(), 10);
    }

    /// Neighbouring indices differ and large indices wrap around.
    #[test]
    fn test_get_categorical_color() {
        let first = ColorPalette::get_categorical_color(0);
        let second = ColorPalette::get_categorical_color(1);
        assert_ne!(first, second);

        let palette = ColorPalette::categorical();
        let wrapped = ColorPalette::get_categorical_color(100);
        assert_eq!(wrapped, palette[100 % palette.len()]);
    }

    /// The low end, midpoint and high end map to three different colors.
    #[test]
    fn test_sequential_color() {
        let low = ColorPalette::get_sequential_color(0.0);
        let high = ColorPalette::get_sequential_color(1.0);
        let middle = ColorPalette::get_sequential_color(0.5);
        assert_ne!(low, high);
        assert_ne!(low, middle);
        assert_ne!(high, middle);
    }

    /// The wrapped page is an HTML document containing the SVG verbatim.
    #[test]
    fn test_wrap_in_html() {
        let page = wrap_in_html(SAMPLE_SVG, PlotFormat::Html).unwrap();
        assert!(page.contains("<!DOCTYPE html>"));
        assert!(page.contains("RusTorch Visualization"));
        assert!(page.contains(SAMPLE_SVG));
    }

    /// The dashboard contains every plot title.
    #[test]
    fn test_create_dashboard() {
        let circle_svg = r#"<svg><circle cx="50" cy="50" r="25"/></svg>"#;
        let plots = vec![("Training Loss", SAMPLE_SVG), ("Accuracy", circle_svg)];

        let page = create_dashboard(plots).unwrap();
        for needle in ["Training Loss", "Accuracy", "dashboard"].iter() {
            assert!(page.contains(needle));
        }
    }

    /// Element count, byte size and generation time are all recorded.
    #[test]
    fn test_plot_statistics() {
        let svg_content = r#"<svg><rect x="0" y="0"/><circle cx="50" cy="50"/></svg>"#;
        let stats = PlotStatistics::new(svg_content, 100);
        assert_eq!(stats.total_elements, 2);
        assert_eq!(stats.file_size_bytes, svg_content.len());
        assert_eq!(stats.generation_time_ms, 100);
    }

    /// File names carry the extension and, on request, a timestamp suffix.
    #[test]
    fn test_generate_filename() {
        let plain = generate_filename("test_plot", PlotFormat::Svg, false);
        assert_eq!(plain, "test_plot.svg");

        let stamped = generate_filename("test_plot", PlotFormat::Html, true);
        assert!(stamped.starts_with("test_plot_"));
        assert!(stamped.ends_with(".html"));
    }

    /// Non-HTML content is written to disk unchanged.
    #[test]
    fn test_save_plot() -> std::io::Result<()> {
        let workspace = tempdir()?;
        let target = workspace.path().join("test.svg");

        save_plot(SAMPLE_SVG, &target, PlotFormat::Svg).unwrap();

        let written = fs::read_to_string(&target)?;
        assert_eq!(written, SAMPLE_SVG);
        Ok(())
    }

    /// Resizing rewrites the dimensions and keeps the drawing content.
    #[test]
    fn test_resize_svg() {
        let original_svg = r#"<svg width="400" height="300" xmlns="http://www.w3.org/2000/svg"><rect x="0" y="0" width="100" height="100"/></svg>"#;
        let resized = resize_svg(original_svg, 800, 600).unwrap();
        for needle in ["width=\"800\"", "height=\"600\"", "rect"].iter() {
            assert!(resized.contains(needle));
        }
    }
}