use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use num_traits::Float;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use super::{
ArrowType, EdgeStyle, GraphEdge, GraphLayout, GraphNode, LineType, NodeShape, NodeStyle,
NodeType,
};
/// Holds a computation graph (nodes + edges), the layout strategy used to
/// position the nodes, and the canvas size used when exporting to SVG.
#[derive(Debug)]
pub struct GraphVisualizer {
/// Nodes to draw; cleared at the start of every `build_graph*` call.
pub nodes: Vec<GraphNode>,
/// Edges to draw; endpoints are matched against `GraphNode::id` when rendering,
/// and an edge whose endpoint id is missing from `nodes` is not drawn.
pub edges: Vec<GraphEdge>,
/// Layout strategy dispatched on by `apply_layout` after the graph is built.
pub layout: GraphLayout,
/// `(width, height)` of the SVG canvas; both constructors default it to `(800.0, 600.0)`.
pub canvas_size: (f32, f32),
}
impl GraphVisualizer {
    /// Creates an empty visualizer using [`GraphLayout::Hierarchical`] and an
    /// 800x600 canvas.
    pub fn new() -> Self {
        // Single source of truth for the default field values.
        Self::with_layout(GraphLayout::Hierarchical)
    }

    /// Creates an empty visualizer that will arrange nodes with `layout`,
    /// on an 800x600 canvas.
    pub fn with_layout(layout: GraphLayout) -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            layout,
            canvas_size: (800.0, 600.0),
        }
    }

    /// Rebuilds the graph starting from a single root `variable`.
    ///
    /// Any previously built nodes and edges are discarded. The variable is then
    /// traversed and the configured layout is applied. This is exactly
    /// [`build_graph_multi`](Self::build_graph_multi) with a one-element slice.
    ///
    /// # Errors
    /// Propagates any error returned by traversal or by the layout step.
    pub fn build_graph<T>(&mut self, variable: &Variable<T>) -> RusTorchResult<()>
    where
        T: Float
            + Debug
            + std::fmt::Display
            + Send
            + Sync
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive,
    {
        self.build_graph_multi(&[variable])
    }

    /// Rebuilds the graph starting from several root `variables`.
    ///
    /// Previously built nodes and edges are discarded. Every root is traversed
    /// with one shared `visited` set and one shared node counter. The layout is
    /// applied once, after all roots have been traversed.
    ///
    /// # Errors
    /// Returns the first error produced by traversal (remaining roots are
    /// skipped) or any error from the layout step.
    pub fn build_graph_multi<T>(&mut self, variables: &[&Variable<T>]) -> RusTorchResult<()>
    where
        T: Float
            + Debug
            + std::fmt::Display
            + Send
            + Sync
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive,
    {
        self.nodes.clear();
        self.edges.clear();

        let mut visited = HashSet::new();
        let mut node_counter = 0;
        for &variable in variables {
            self.traverse_variable(variable, &mut visited, &mut node_counter)?;
        }

        self.apply_layout()
    }

    /// Renders the current graph as a standalone SVG document string.
    ///
    /// Each edge is drawn as a straight `<line>` and each node as a `<circle>`.
    /// Edges are emitted before nodes.
    pub fn to_svg(&self) -> String {
        let mut svg = format!(
            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
            self.canvas_size.0, self.canvas_size.1
        );
        // SVG paints in document order, so emitting edges first keeps the
        // node circles on top of the lines that connect them.
        svg.extend(self.edges.iter().map(|edge| self.render_edge(edge)));
        svg.extend(self.nodes.iter().map(|node| self.render_node(node)));
        svg.push_str("</svg>");
        svg
    }

    /// Walks the autograd graph reachable from `variable`, appending nodes/edges.
    ///
    /// NOTE(review): currently a stub that adds nothing and always returns
    /// `Ok(())`, so `build_graph*` yields an empty graph. The parameters are
    /// underscore-prefixed only to silence `unused_variables` until the
    /// traversal is implemented.
    fn traverse_variable<T>(
        &mut self,
        _variable: &Variable<T>,
        _visited: &mut HashSet<String>,
        _node_counter: &mut usize,
    ) -> RusTorchResult<()>
    where
        T: Float
            + Debug
            + std::fmt::Display
            + Send
            + Sync
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive,
    {
        Ok(())
    }

    /// Positions the nodes according to `self.layout`.
    fn apply_layout(&mut self) -> RusTorchResult<()> {
        match self.layout {
            GraphLayout::Hierarchical => self.apply_hierarchical_layout(),
            GraphLayout::Circular => self.apply_circular_layout(),
            GraphLayout::ForceDirected => self.apply_force_directed_layout(),
            GraphLayout::Grid => self.apply_grid_layout(),
            GraphLayout::LeftToRight => self.apply_left_to_right_layout(),
        }
    }

    // NOTE(review): every layout below is currently a no-op stub that leaves
    // node positions untouched and returns `Ok(())`.

    /// Hierarchical layout (not yet implemented; no-op).
    fn apply_hierarchical_layout(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Circular layout (not yet implemented; no-op).
    fn apply_circular_layout(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Force-directed layout (not yet implemented; no-op).
    fn apply_force_directed_layout(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Grid layout (not yet implemented; no-op).
    fn apply_grid_layout(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Left-to-right layout (not yet implemented; no-op).
    fn apply_left_to_right_layout(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Renders one node as an SVG `<circle>` of fixed radius 20 at its position.
    ///
    /// Fill, stroke colour and stroke width come from the node's style; the
    /// node's shape is not consulted here.
    fn render_node(&self, node: &GraphNode) -> String {
        format!(
            r#"<circle cx="{}" cy="{}" r="20" fill="rgb({},{},{})" stroke="rgb({},{},{})" stroke-width="{}"/>"#,
            node.position.0,
            node.position.1,
            node.style.background_color.0,
            node.style.background_color.1,
            node.style.background_color.2,
            node.style.border_color.0,
            node.style.border_color.1,
            node.style.border_color.2,
            node.style.border_width
        )
    }

    /// Renders one edge as an SVG `<line>` between its endpoint nodes' positions.
    ///
    /// Returns an empty string when either `edge.from` or `edge.to` does not
    /// match the id of any node in `self.nodes` (first match wins).
    fn render_edge(&self, edge: &GraphEdge) -> String {
        if let (Some(from_node), Some(to_node)) = (
            self.nodes.iter().find(|n| n.id == edge.from),
            self.nodes.iter().find(|n| n.id == edge.to),
        ) {
            format!(
                r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="rgb({},{},{})" stroke-width="{}"/>"#,
                from_node.position.0,
                from_node.position.1,
                to_node.position.0,
                to_node.position.1,
                edge.style.color.0,
                edge.style.color.1,
                edge.style.color.2,
                edge.style.thickness
            )
        } else {
            String::new()
        }
    }
}
impl Default for GraphVisualizer {
fn default() -> Self {
Self::new()
}
}