use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use crate::visualization::{PlotData, Visualizable};
use num_traits::Float;
use std::collections::HashMap;
use super::{ColorMap, TensorPlotConfig};
#[derive(Debug)]
pub struct TensorVisualizer {
config: TensorPlotConfig,
}
impl TensorVisualizer {
pub fn new() -> Self {
Self {
config: TensorPlotConfig::default(),
}
}
pub fn with_config(config: TensorPlotConfig) -> Self {
Self { config }
}
pub fn config(&self) -> &TensorPlotConfig {
&self.config
}
pub fn plot_heatmap<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
let shape = tensor.shape();
match shape.len() {
1 => self.plot_1d_heatmap(tensor),
2 => self.plot_2d_heatmap(tensor),
_ => Err(RusTorchError::tensor_op(
"Heatmap visualization only supports 1D and 2D tensors",
)),
}
}
pub fn plot_bar_chart<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
if tensor.shape().len() != 1 {
return Err(RusTorchError::tensor_op(
"Bar chart only supports 1D tensors",
));
}
let shape = tensor.shape();
let size = shape[0];
let mut svg = String::new();
svg.push_str(&format!(
r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
self.config.figsize.0 * self.config.dpi as f32,
self.config.figsize.1 * self.config.dpi as f32
));
let bar_width = (self.config.figsize.0 * self.config.dpi as f32) / size as f32;
for i in 0..size.min(20) {
let height = 50.0 + (i as f32 * 10.0); let x = i as f32 * bar_width;
let y = self.config.figsize.1 * self.config.dpi as f32 - height;
svg.push_str(&format!(
r#"<rect x="{}" y="{}" width="{}" height="{}" fill="steelblue"/>"#,
x,
y,
bar_width * 0.8,
height
));
}
svg.push_str("</svg>");
Ok(svg)
}
pub fn plot_3d_slices<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
if tensor.shape().len() != 3 {
return Err(RusTorchError::tensor_op(
"3D slice visualization only supports 3D tensors",
));
}
Ok("<svg>3D slices visualization placeholder</svg>".to_string())
}
pub fn plot_statistics<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
Ok("<svg>Statistics visualization placeholder</svg>".to_string())
}
pub fn plot_variable<T>(&self, variable: &Variable<T>) -> RusTorchResult<String>
where
T: Float
+ std::fmt::Display
+ std::fmt::Debug
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive
+ Send
+ Sync,
{
Ok("<svg>Variable visualization placeholder</svg>".to_string())
}
fn plot_1d_heatmap<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
Ok("<svg>1D heatmap placeholder</svg>".to_string())
}
fn plot_2d_heatmap<T>(&self, tensor: &Tensor<T>) -> RusTorchResult<String>
where
T: Float + std::fmt::Display + std::fmt::Debug + 'static,
{
let shape = tensor.shape();
let (height, width) = (shape[0], shape[1]);
let mut svg = String::new();
svg.push_str(&format!(
r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
width * 10,
height * 10
));
svg.push_str(&format!(
r#"<rect x="0" y="0" width="{}" height="{}" fill="lightblue"/>"#,
width * 10,
height * 10
));
svg.push_str(&format!(
r#"<text x="{}" y="{}" text-anchor="middle" font-family="Arial" font-size="12">{}x{} Tensor</text>"#,
width * 5, height * 5, height, width
));
svg.push_str("</svg>");
Ok(svg)
}
}
impl Default for TensorVisualizer {
fn default() -> Self {
Self::new()
}
}