use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use tensorlogic_ir::{DotExportOptions, EinsumGraph, OpType};
/// Tunable options shared by the DOT exporter and the ASCII renderer.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Show per-node detail: input/output tensor lists in ASCII output,
    /// and forwarded as `show_metadata` to the DOT exporter.
    pub show_details: bool,
    /// Forwarded to the DOT exporter's `show_shapes` option.
    pub show_shapes: bool,
    /// Maximum number of nodes rendered in ASCII output; 0 means unlimited.
    pub max_depth: usize,
    /// Keep fill colors in DOT output; when false, `style=filled` and
    /// `fillcolor=...` attributes are stripped from the exported text.
    pub use_color: bool,
    /// Indentation unit used by the ASCII renderer.
    pub indent: String,
    /// Forwarded to the DOT exporter's `show_tensor_ids` option.
    pub show_tensor_ids: bool,
    /// Forwarded to the DOT exporter's `show_node_ids` option.
    pub show_node_ids: bool,
    /// Forwarded to the DOT exporter's `horizontal_layout` option.
    pub horizontal_layout: bool,
    /// Forwarded to the DOT exporter's `cluster_by_operation` option.
    pub cluster_by_operation: bool,
}
impl Default for VisualizationConfig {
fn default() -> Self {
VisualizationConfig {
show_details: true,
show_shapes: true,
max_depth: 0,
use_color: true,
indent: " ".to_string(),
show_tensor_ids: false,
show_node_ids: true,
horizontal_layout: false,
cluster_by_operation: false,
}
}
}
impl VisualizationConfig {
    /// Create a config with the default settings (see [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set whether per-node details are shown.
    pub fn with_details(self, v: bool) -> Self {
        Self {
            show_details: v,
            ..self
        }
    }

    /// Builder: set whether shapes appear in DOT output.
    pub fn with_shapes(self, v: bool) -> Self {
        Self {
            show_shapes: v,
            ..self
        }
    }

    /// Builder: cap the number of nodes in ASCII output (0 = unlimited).
    pub fn with_max_depth(self, d: usize) -> Self {
        Self {
            max_depth: d,
            ..self
        }
    }

    /// Builder: keep (true) or strip (false) DOT fill colors.
    pub fn with_color(self, v: bool) -> Self {
        Self {
            use_color: v,
            ..self
        }
    }

    /// Builder: set whether tensor ids appear in DOT output.
    pub fn with_tensor_ids(self, v: bool) -> Self {
        Self {
            show_tensor_ids: v,
            ..self
        }
    }

    /// Builder: set whether node ids appear in DOT output.
    pub fn with_node_ids(self, v: bool) -> Self {
        Self {
            show_node_ids: v,
            ..self
        }
    }

    /// Builder: lay the DOT graph out left-to-right instead of top-down.
    pub fn with_horizontal_layout(self, v: bool) -> Self {
        Self {
            horizontal_layout: v,
            ..self
        }
    }

    /// Builder: group DOT nodes into clusters by operation.
    pub fn with_clustering(self, v: bool) -> Self {
        Self {
            cluster_by_operation: v,
            ..self
        }
    }

    /// A stripped-down config: no details, shapes, or ids, but otherwise
    /// identical to the defaults (color on, no node limit).
    pub fn minimal() -> Self {
        let mut cfg = Self::default();
        cfg.show_details = false;
        cfg.show_shapes = false;
        cfg.show_tensor_ids = false;
        cfg.show_node_ids = false;
        cfg
    }

    /// Translate this config into the IR crate's DOT export options.
    /// The highlight lists are always empty; highlighting is not exposed
    /// through this config.
    fn to_dot_options(&self) -> DotExportOptions {
        DotExportOptions {
            show_tensor_ids: self.show_tensor_ids,
            show_node_ids: self.show_node_ids,
            show_metadata: self.show_details,
            show_shapes: self.show_shapes,
            cluster_by_operation: self.cluster_by_operation,
            horizontal_layout: self.horizontal_layout,
            highlight_tensors: Vec::new(),
            highlight_nodes: Vec::new(),
        }
    }
}
/// Stateless facade over the IR crate's DOT exporter that additionally
/// honors [`VisualizationConfig::use_color`].
pub struct DotExporter;

impl DotExporter {
    /// Render `graph` as DOT text according to `config`. When color is
    /// disabled, fill attributes are stripped from the exporter's output.
    pub fn export(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
        let options = config.to_dot_options();
        let dot = tensorlogic_ir::export_to_dot_with_options(graph, &options);
        if !config.use_color {
            Self::strip_fill_colors(&dot)
        } else {
            dot
        }
    }

    /// Remove `style=filled` and `fillcolor=...` attributes line by line,
    /// then tidy any separator punctuation left behind. Every emitted line
    /// is newline-terminated.
    fn strip_fill_colors(dot: &str) -> String {
        let mut out = String::with_capacity(dot.len());
        for raw in dot.lines() {
            // Drop `style=filled` in any separator position.
            let mut line = raw
                .replace(", style=filled", "")
                .replace("style=filled, ", "")
                .replace("style=filled", "");
            line = strip_attr(&line, "fillcolor=");
            // Clean up dangling commas before the closing bracket.
            line = line.replace(", ];", "];").replace(",];", "];");
            out.push_str(&line);
            out.push('\n');
        }
        out
    }
}
/// Remove a `key=value` attribute from a DOT line.
///
/// `prefix` is the attribute key including the `=` (e.g. `"fillcolor="`).
/// An unquoted value ends at the first `,`, `;`, `]`, or space; a
/// double-quoted value extends to its closing quote, so quoted values
/// containing spaces or delimiters (e.g. `fillcolor="light blue"`) are
/// removed intact instead of being cut mid-value. Separator punctuation
/// around the removed attribute is cleaned up so the remaining attribute
/// list stays well-formed. Lines that do not contain the key are returned
/// unchanged.
fn strip_attr(line: &str, prefix: &str) -> String {
    let Some(start) = line.find(prefix) else {
        return line.to_string();
    };
    let before = &line[..start];
    let after_key = &line[start + prefix.len()..];
    // Length of the value: quoted values run to the closing quote
    // (inclusive); unquoted values stop at the first DOT delimiter.
    let end = if let Some(tail) = after_key.strip_prefix('"') {
        // +2 accounts for both quote characters; an unterminated quote
        // consumes the rest of the line.
        tail.find('"').map(|i| i + 2).unwrap_or(after_key.len())
    } else {
        after_key
            .find([',', ';', ']', ' '])
            .unwrap_or(after_key.len())
    };
    let rest = &after_key[end..];
    // Drop the separator that followed the removed attribute, if any.
    let rest = rest.strip_prefix(", ").unwrap_or(rest);
    let rest = rest.strip_prefix(',').unwrap_or(rest);
    // Also drop the separator that preceded it.
    format!("{}{}", before.trim_end_matches(", "), rest)
}
/// Render `graph` to DOT text (per `config`) and write it to `path`.
///
/// Generalized to accept anything path-like (`&Path`, `&PathBuf`, `&str`,
/// ...); existing callers passing `&Path` continue to work unchanged.
///
/// # Errors
/// Returns any I/O error produced while writing the file.
pub fn write_dot_file(
    path: impl AsRef<std::path::Path>,
    graph: &EinsumGraph,
    config: &VisualizationConfig,
) -> std::io::Result<()> {
    let dot = DotExporter::export(graph, config);
    std::fs::write(path, dot)
}
/// Plain-text renderer for an [`EinsumGraph`].
pub struct AsciiRenderer;

impl AsciiRenderer {
    /// Produce a text summary: a header with node/tensor counts and output
    /// names, one line per node (plus tensor lists when
    /// `config.show_details` is set), truncated after `config.max_depth`
    /// nodes (0 = unlimited), and a closing rule.
    pub fn render(graph: &EinsumGraph, config: &VisualizationConfig) -> String {
        let mut out = String::new();
        let _ = writeln!(out, "=== EinsumGraph ===");
        let _ = writeln!(out, "Nodes: {}", graph.nodes.len());
        let _ = writeln!(
            out,
            "Tensors: {} ({} inputs, {} outputs)",
            graph.tensors.len(),
            graph.inputs.len(),
            graph.outputs.len()
        );
        if !graph.outputs.is_empty() {
            // Output tensor names; indices without a tensor entry are skipped.
            let names: Vec<&str> = graph
                .outputs
                .iter()
                .filter_map(|&idx| graph.tensors.get(idx).map(String::as_str))
                .collect();
            let _ = writeln!(out, "Outputs: [{}]", names.join(", "));
        }
        let _ = writeln!(out);
        // A max_depth of 0 disables truncation entirely.
        let limit = match config.max_depth {
            0 => usize::MAX,
            d => d,
        };
        for (idx, node) in graph.nodes.iter().enumerate() {
            if idx >= limit {
                let remaining = graph.nodes.len() - idx;
                let _ = writeln!(out, "{}... ({} more nodes)", config.indent, remaining);
                break;
            }
            Self::render_node(&mut out, graph, node, idx, config);
        }
        let _ = writeln!(out, "===================");
        out
    }

    /// Append one node's line and, when details are enabled, its input and
    /// output tensor names at double indentation.
    fn render_node(
        out: &mut String,
        graph: &EinsumGraph,
        node: &tensorlogic_ir::EinsumNode,
        idx: usize,
        config: &VisualizationConfig,
    ) {
        let indent = &config.indent;
        let _ = writeln!(out, "{}[{}] {}", indent, idx, node.operation_description());
        if !config.show_details {
            return;
        }
        let inputs = Self::tensor_names(graph, &node.inputs);
        let _ = writeln!(out, "{}{} inputs: [{}]", indent, indent, inputs.join(", "));
        let outputs = Self::tensor_names(graph, &node.outputs);
        let _ = writeln!(out, "{}{} outputs: [{}]", indent, indent, outputs.join(", "));
    }

    /// Resolve tensor indices to names, substituting `?<idx>` for any index
    /// with no tensor entry.
    fn tensor_names(graph: &EinsumGraph, ids: &[usize]) -> Vec<String> {
        ids.iter()
            .map(|&i| {
                graph
                    .tensors
                    .get(i)
                    .cloned()
                    .unwrap_or_else(|| format!("?{}", i))
            })
            .collect()
    }
}
/// Aggregate statistics computed from an `EinsumGraph`.
#[derive(Debug, Clone)]
pub struct GraphSummary {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of tensors in the graph.
    pub tensor_count: usize,
    /// Number of graph output tensors.
    pub output_count: usize,
    /// Number of graph input tensors.
    pub input_count: usize,
    /// Largest number of inputs on any single node (0 for an empty graph).
    pub max_fan_in: usize,
    /// Largest number of outputs on any single node (0 for an empty graph).
    pub max_fan_out: usize,
    /// Length in nodes of the longest producer chain; 0 for an empty graph.
    pub depth: usize,
    /// Node count per operation kind, keyed by variant name
    /// ("Einsum", "ElemUnary", "ElemBinary", "Reduce").
    pub op_counts: HashMap<String, usize>,
}
impl GraphSummary {
    /// Gather all summary statistics for `graph`: simple counts, fan-in/out
    /// maxima, an operation-kind histogram, and the graph depth.
    pub fn compute(graph: &EinsumGraph) -> Self {
        let node_count = graph.nodes.len();
        let tensor_count = graph.tensors.len();
        let output_count = graph.outputs.len();
        let input_count = graph.inputs.len();
        // Widest fan-in/fan-out over all nodes; empty graphs yield 0.
        let max_fan_in = graph
            .nodes
            .iter()
            .map(|n| n.inputs.len())
            .max()
            .unwrap_or(0);
        let max_fan_out = graph
            .nodes
            .iter()
            .map(|n| n.outputs.len())
            .max()
            .unwrap_or(0);
        // Histogram of operation kinds, keyed by the OpType variant name.
        let mut op_counts: HashMap<String, usize> = HashMap::new();
        for node in &graph.nodes {
            let key = match &node.op {
                OpType::Einsum { .. } => "Einsum",
                OpType::ElemUnary { .. } => "ElemUnary",
                OpType::ElemBinary { .. } => "ElemBinary",
                OpType::Reduce { .. } => "Reduce",
            };
            *op_counts.entry(key.to_string()).or_insert(0) += 1;
        }
        let depth = Self::compute_depth(graph);
        GraphSummary {
            node_count,
            tensor_count,
            output_count,
            input_count,
            max_fan_in,
            max_fan_out,
            depth,
            op_counts,
        }
    }

    /// Length in nodes of the longest producer chain.
    ///
    /// Returns 0 for an empty graph; otherwise 1 plus the maximum number of
    /// producer hops from any node back to a source node, where a node's
    /// predecessors are the producers of its input tensors.
    fn compute_depth(graph: &EinsumGraph) -> usize {
        if graph.nodes.is_empty() {
            return 0;
        }
        // Map each tensor index to the node that produces it. If multiple
        // nodes list the same output tensor, the later node wins.
        let mut producer: HashMap<usize, usize> = HashMap::new();
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            for &out_t in &node.outputs {
                producer.insert(out_t, node_idx);
            }
        }
        let num_nodes = graph.nodes.len();
        // Memoized hop count per node, shared across all starting nodes.
        let mut memo: Vec<Option<usize>> = vec![None; num_nodes];
        // Recursive DFS: hop count of `node_idx` = 1 + max over the hop
        // counts of its input tensors' producers (0 if it has none).
        fn depth_of(
            node_idx: usize,
            graph: &EinsumGraph,
            producer: &HashMap<usize, usize>,
            memo: &mut [Option<usize>],
            visited: &mut HashSet<usize>,
        ) -> usize {
            if let Some(d) = memo[node_idx] {
                return d;
            }
            // Cycle guard: re-entering a node on the current path cuts the
            // recursion by treating that node's depth as 0.
            // NOTE(review): on a cyclic graph, values memoized beneath such
            // a cut may understate the true chain length; harmless if
            // graphs are DAGs by construction — confirm that invariant.
            if !visited.insert(node_idx) {
                return 0;
            }
            let node = &graph.nodes[node_idx];
            let mut max_pred = 0usize;
            for &inp_t in &node.inputs {
                // Input tensors without a producer (graph inputs) add no hops.
                if let Some(&pred_node) = producer.get(&inp_t) {
                    let d = depth_of(pred_node, graph, producer, memo, visited);
                    if d + 1 > max_pred {
                        max_pred = d + 1;
                    }
                }
            }
            memo[node_idx] = Some(max_pred);
            max_pred
        }
        // Take the maximum hop count over every node as a starting point;
        // `visited` is reset per start so each DFS has its own path guard.
        let mut max_depth = 0usize;
        for i in 0..num_nodes {
            let mut visited = HashSet::new();
            let d = depth_of(i, graph, &producer, &mut memo, &mut visited);
            if d > max_depth {
                max_depth = d;
            }
        }
        // Convert hop count to node count.
        max_depth + 1
    }

    /// Render the summary as human-readable text. Operation counts are
    /// sorted alphabetically so the output is deterministic.
    pub fn display(&self) -> String {
        let mut out = String::new();
        let _ = writeln!(out, "Graph Summary:");
        let _ = writeln!(out, " Nodes: {}", self.node_count);
        let _ = writeln!(out, " Tensors: {}", self.tensor_count);
        let _ = writeln!(out, " Inputs: {}", self.input_count);
        let _ = writeln!(out, " Outputs: {}", self.output_count);
        let _ = writeln!(out, " Depth: {}", self.depth);
        let _ = writeln!(out, " Max fan-in: {}", self.max_fan_in);
        let _ = writeln!(out, " Max fan-out: {}", self.max_fan_out);
        if !self.op_counts.is_empty() {
            let _ = writeln!(out, " Operations:");
            let mut sorted: Vec<_> = self.op_counts.iter().collect();
            sorted.sort_by_key(|(k, _)| (*k).clone());
            for (op, count) in sorted {
                let _ = writeln!(out, " {}: {}", op, count);
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode};

    /// A graph with no nodes or tensors.
    fn empty_graph() -> EinsumGraph {
        EinsumGraph::new()
    }

    /// Two-node chain: (a, b) --add--> c --relu--> d.
    fn small_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("a".to_string());
        let b = g.add_tensor("b".to_string());
        let c = g.add_tensor("c".to_string());
        let d = g.add_tensor("d".to_string());
        g.inputs = vec![a, b];
        g.outputs = vec![d];
        g.add_node(EinsumNode::elem_binary("add", a, b, c))
            .expect("node add");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("node relu");
        g
    }

    #[test]
    fn test_dot_export_empty_graph() {
        let g = empty_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("digraph"));
    }

    #[test]
    fn test_dot_export_contains_nodes() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("op_0"));
        assert!(dot.contains("op_1"));
    }

    #[test]
    fn test_dot_export_contains_edges() {
        let g = small_graph();
        let dot = DotExporter::export(&g, &VisualizationConfig::default());
        assert!(dot.contains("tensor_0 -> op_0"));
        assert!(dot.contains("tensor_1 -> op_0"));
        assert!(dot.contains("op_0 -> tensor_2"));
        assert!(dot.contains("tensor_2 -> op_1"));
        assert!(dot.contains("op_1 -> tensor_3"));
    }

    #[test]
    fn test_dot_export_no_color() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_color(false);
        let dot = DotExporter::export(&g, &config);
        assert!(!dot.contains("fillcolor"));
    }

    #[test]
    fn test_dot_export_minimal_config() {
        let g = small_graph();
        let full = DotExporter::export(&g, &VisualizationConfig::default());
        let minimal = DotExporter::export(&g, &VisualizationConfig::minimal());
        assert!(minimal.contains("digraph"));
        assert!(minimal.len() <= full.len());
    }

    #[test]
    fn test_write_dot_file() {
        let g = small_graph();
        // Include the process id in the filename so concurrent runs of the
        // test binary don't race on one shared temp file (a fixed name lets
        // one run delete the file while another is still reading it).
        let path = std::env::temp_dir().join(format!(
            "tensorlogic_test_viz_{}.dot",
            std::process::id()
        ));
        write_dot_file(&path, &g, &VisualizationConfig::default()).expect("should write file");
        let contents = std::fs::read_to_string(&path).expect("should read file");
        assert!(contents.contains("digraph"));
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_ascii_render_header() {
        let g = empty_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.starts_with("=== EinsumGraph ==="));
    }

    #[test]
    fn test_ascii_render_node_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.contains("Nodes: 2"));
    }

    #[test]
    fn test_ascii_render_output_count() {
        let g = small_graph();
        let ascii = AsciiRenderer::render(&g, &VisualizationConfig::default());
        assert!(ascii.contains("Outputs: [d]"));
    }

    #[test]
    fn test_ascii_render_details() {
        let g = small_graph();
        let config = VisualizationConfig::new().with_details(true);
        let ascii = AsciiRenderer::render(&g, &config);
        assert!(ascii.contains("inputs:"));
        assert!(ascii.contains("outputs:"));
    }

    #[test]
    fn test_ascii_render_no_details() {
        let g = small_graph();
        let with_details =
            AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(true));
        let without = AsciiRenderer::render(&g, &VisualizationConfig::new().with_details(false));
        assert!(without.len() < with_details.len());
        assert!(!without.contains("inputs:"));
    }

    #[test]
    fn test_config_default() {
        let c = VisualizationConfig::default();
        assert!(c.show_details);
        assert!(c.show_shapes);
        assert_eq!(c.max_depth, 0);
        assert!(c.use_color);
        assert_eq!(c.indent, " ");
    }

    #[test]
    fn test_config_builder() {
        let c = VisualizationConfig::new()
            .with_details(false)
            .with_shapes(false)
            .with_max_depth(5)
            .with_color(false);
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert_eq!(c.max_depth, 5);
        assert!(!c.use_color);
    }

    #[test]
    fn test_config_minimal() {
        let c = VisualizationConfig::minimal();
        assert!(!c.show_details);
        assert!(!c.show_shapes);
        assert!(!c.show_tensor_ids);
        assert!(!c.show_node_ids);
    }

    #[test]
    fn test_graph_summary_empty() {
        let g = empty_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.tensor_count, 0);
        assert_eq!(s.output_count, 0);
        assert_eq!(s.input_count, 0);
        assert_eq!(s.max_fan_in, 0);
        assert_eq!(s.max_fan_out, 0);
        assert_eq!(s.depth, 0);
    }

    #[test]
    fn test_graph_summary_basic() {
        let g = small_graph();
        let s = GraphSummary::compute(&g);
        assert_eq!(s.node_count, 2);
        assert_eq!(s.tensor_count, 4);
        assert_eq!(s.output_count, 1);
        assert_eq!(s.input_count, 2);
        assert_eq!(s.max_fan_in, 2);
        assert_eq!(s.max_fan_out, 1);
        // Two-node chain => depth 2.
        assert_eq!(s.depth, 2);
        assert_eq!(s.op_counts.get("ElemBinary"), Some(&1));
        assert_eq!(s.op_counts.get("ElemUnary"), Some(&1));
    }

    #[test]
    fn test_dot_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = DotExporter::export(&g, &config);
        let b = DotExporter::export(&g, &config);
        assert_eq!(a, b);
    }

    #[test]
    fn test_ascii_deterministic() {
        let g = small_graph();
        let config = VisualizationConfig::default();
        let a = AsciiRenderer::render(&g, &config);
        let b = AsciiRenderer::render(&g, &config);
        assert_eq!(a, b);
    }
}