use rustorch::models::high_level::TrainingHistory;
use rustorch::prelude::*;
use rustorch::visualization::utils::{create_dashboard, ColorPalette};
use rustorch::visualization::*;
use std::collections::HashMap;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🎨 RusTorch 可視化機能デモ");
println!("🎨 RusTorch Visualization Demo\n");
println!("📈 1. 学習曲線の可視化 / Training Curves Visualization");
let mut history = TrainingHistory::<f32>::new();
for epoch in 1..=10 {
let train_loss = 1.0 - (epoch as f32 * 0.08);
let val_loss = train_loss + 0.05;
let mut metrics = HashMap::new();
metrics.insert(
"accuracy".to_string(),
vec![(epoch as f64 * 0.07 + 0.3).min(0.95)],
);
metrics.insert(
"precision".to_string(),
vec![(epoch as f64 * 0.06 + 0.4).min(0.92)],
);
history.train_loss.push(train_loss);
history.val_loss.push(val_loss);
for (key, value) in metrics {
history.metrics.entry(key).or_default().extend(value);
}
}
let plotter = TrainingPlotter::with_config(PlotConfig {
width: 800,
height: 600,
dpi: 150,
chart_type: ChartType::Single,
background_color: "#ffffff".to_string(),
font_size: 14,
line_width: 2.5,
marker_size: 5.0,
});
match plotter.plot_training_curves(&history) {
Ok(svg) => {
println!("✓ 学習曲線SVG生成成功 ({} bytes)", svg.len());
if let Err(e) = plotter.save_plot(&svg, "training_curves.svg") {
println!("⚠ ファイル保存に失敗: {}", e);
} else {
println!("✓ ファイルに保存: training_curves.svg");
}
}
Err(e) => println!("✗ 学習曲線生成エラー: {}", e),
}
match plotter.plot_metrics_timeline(&history, "accuracy") {
Ok(svg) => {
println!("✓ 精度時系列SVG生成成功 ({} bytes)", svg.len());
}
Err(e) => println!("⚠ 精度時系列生成エラー: {}", e),
}
println!();
println!("🔢 2. テンソルの可視化 / Tensor Visualization");
let tensor_viz = TensorVisualizer::with_config(TensorPlotConfig {
colormap: ColorMap::Viridis,
normalize: true,
aspect: "equal".to_string(),
title: Some("Sample Heatmap".to_string()),
show_colorbar: true,
show_axes: false,
figsize: (8.0, 6.0),
dpi: 100,
});
let heat_data: Vec<f32> = (0..16).map(|i| (i as f32 / 4.0).sin()).collect();
let heat_tensor = Tensor::from_vec(heat_data, vec![4, 4]);
match tensor_viz.plot_heatmap(&heat_tensor) {
Ok(svg) => {
println!("✓ ヒートマップSVG生成成功 ({} bytes)", svg.len());
if let Err(e) = save_plot(&svg, "heatmap.svg", PlotFormat::Svg) {
println!("⚠ ファイル保存に失敗: {}", e);
} else {
println!("✓ ファイルに保存: heatmap.svg");
}
}
Err(e) => println!("✗ ヒートマップ生成エラー: {}", e),
}
let bar_data = vec![1.0, 3.0, 2.0, 4.0, 1.5, 3.5, 2.5];
let bar_tensor = Tensor::from_vec(bar_data, vec![7]);
match tensor_viz.plot_bar_chart(&bar_tensor) {
Ok(svg) => {
println!("✓ 棒グラフSVG生成成功 ({} bytes)", svg.len());
}
Err(e) => println!("✗ 棒グラフ生成エラー: {}", e),
}
let slice_data: Vec<f32> = (0..24).map(|i| (i as f32).cos()).collect();
let slice_tensor = Tensor::from_vec(slice_data, vec![2, 3, 4]);
match tensor_viz.plot_3d_slices(&slice_tensor) {
Ok(svg) => {
println!("✓ 3Dスライス可視化SVG生成成功 ({} bytes)", svg.len());
}
Err(e) => println!("✗ 3Dスライス生成エラー: {}", e),
}
println!();
println!("🕸️ 3. 計算グラフの可視化 / Computation Graph Visualization");
let mut graph_viz = GraphVisualizer::with_layout(GraphLayout::Hierarchical);
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let variable = Variable::new(tensor, true);
match graph_viz.build_graph(&variable) {
Ok(()) => {
println!("✓ 計算グラフ構築成功");
let svg = graph_viz.to_svg();
println!("✓ 計算グラフSVG生成成功 ({} bytes)", svg.len());
if let Err(e) = save_plot(&svg, "computation_graph.svg", PlotFormat::Svg) {
println!("⚠ ファイル保存に失敗: {}", e);
} else {
println!("✓ ファイルに保存: computation_graph.svg");
}
println!("📊 グラフ統計:");
println!(" ノード数: {}", graph_viz.nodes.len());
println!(" エッジ数: {}", graph_viz.edges.len());
println!(" レイアウト: {:?}", graph_viz.layout);
println!(
" キャンバスサイズ: {}x{}",
graph_viz.canvas_size.0, graph_viz.canvas_size.1
);
}
Err(e) => println!("✗ 計算グラフ構築エラー: {}", e),
}
println!();
println!("🎨 4. カラーパレット / Color Palette Demo");
let categorical_colors = ColorPalette::categorical();
println!("✓ カテゴリカル色数: {}", categorical_colors.len());
for i in 0..3 {
let color = ColorPalette::get_categorical_color(i);
println!(" 色 {}: {}", i, color);
}
let sequential_colors = ColorPalette::sequential_blues();
println!("✓ シーケンシャル色数: {}", sequential_colors.len());
for &value in &[0.0, 0.5, 1.0] {
let color = ColorPalette::get_sequential_color(value);
println!(" 値 {}: {}", value, color);
}
println!();
println!("📊 5. ダッシュボード作成 / Dashboard Creation");
let sample_plots = vec![
("Training Loss", "<svg width=\"300\" height=\"200\"><rect x=\"10\" y=\"10\" width=\"280\" height=\"180\" fill=\"#e3f2fd\" stroke=\"#1976d2\"/><text x=\"150\" y=\"100\" text-anchor=\"middle\" font-family=\"Arial\">Training Loss Chart</text></svg>"),
("Validation Accuracy", "<svg width=\"300\" height=\"200\"><rect x=\"10\" y=\"10\" width=\"280\" height=\"180\" fill=\"#e8f5e8\" stroke=\"#4caf50\"/><text x=\"150\" y=\"100\" text-anchor=\"middle\" font-family=\"Arial\">Accuracy Chart</text></svg>"),
("Model Architecture", "<svg width=\"300\" height=\"200\"><rect x=\"10\" y=\"10\" width=\"280\" height=\"180\" fill=\"#fff3e0\" stroke=\"#ff9800\"/><text x=\"150\" y=\"100\" text-anchor=\"middle\" font-family=\"Arial\">Model Diagram</text></svg>"),
];
match create_dashboard(sample_plots) {
Ok(dashboard_html) => {
println!(
"✓ ダッシュボードHTML生成成功 ({} bytes)",
dashboard_html.len()
);
if let Err(e) = save_plot(&dashboard_html, "dashboard.html", PlotFormat::Html) {
println!("⚠ ダッシュボード保存に失敗: {}", e);
} else {
println!("✓ ダッシュボード保存: dashboard.html");
println!(" ブラウザで開いて表示可能です");
}
}
Err(e) => println!("✗ ダッシュボード生成エラー: {}", e),
}
println!();
println!("🎉 可視化機能デモ完了!");
println!("🎉 Visualization Demo Complete!\n");
println!("生成されたファイル:");
println!("Generated files:");
println!(" - training_curves.svg : 学習曲線");
println!(" - heatmap.svg : テンソルヒートマップ");
println!(" - computation_graph.svg : 計算グラフ (SVG)");
println!(" - dashboard.html : 可視化ダッシュボード");
println!();
println!("💡 使用方法:");
println!("💡 Usage:");
println!(" - SVGファイルはブラウザやベクター画像エディタで表示");
println!(" - HTMLダッシュボードはブラウザで直接表示");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_visualization_demo_components() {
let plotter = TrainingPlotter::new();
assert_eq!(plotter.config.width, 800);
let visualizer = TensorVisualizer::new();
assert_eq!(visualizer.config().colormap, ColorMap::Viridis);
let graph_viz = GraphVisualizer::new();
assert_eq!(graph_viz.layout, GraphLayout::Hierarchical);
let colors = ColorPalette::categorical();
assert!(!colors.is_empty());
println!("✓ All visualization components initialized successfully");
}
#[test]
fn test_tensor_creation_and_visualization() {
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::from_vec(data.clone(), vec![2, 2]);
assert_eq!(tensor.shape(), &vec![2, 2]);
if let Some(slice) = tensor.as_slice() {
assert_eq!(slice, &data[..]);
}
let visualizer = TensorVisualizer::new();
let result = visualizer.plot_heatmap(&tensor);
assert!(result.is_ok(), "Tensor visualization should succeed");
println!("✓ Tensor creation and visualization test passed");
}
}