use std::collections::HashMap;
use std::fmt::Write;
use std::process::Command;
use crate::tensor::TensorError;
use super::profile::Profile;
use super::Graph;
impl Graph {
pub fn dot(&self) -> String {
self.build_dot(None)
}
pub fn dot_with_profile(&self) -> String {
let profile = self.last_profile.borrow().clone();
self.build_dot(profile.as_ref())
}
pub fn svg(&self, path: Option<&str>) -> crate::tensor::Result<Vec<u8>> {
render_svg(&self.dot(), path)
}
pub fn svg_with_profile(&self, path: Option<&str>) -> crate::tensor::Result<Vec<u8>> {
render_svg(&self.dot_with_profile(), path)
}
fn build_dot(&self, profile: Option<&Profile>) -> String {
let mut b = String::with_capacity(4096);
b.push_str("digraph G {\n");
b.push_str(" rankdir=TB;\n");
b.push_str(" fontname=\"Helvetica\";\n");
b.push_str(" node [fontname=\"Helvetica\" fontsize=11 style=filled];\n");
b.push_str(" edge [fontname=\"Helvetica\" fontsize=9];\n");
b.push_str(" compound=true;\n\n");
let mut node_timings: HashMap<&str, f64> = HashMap::new();
let mut max_duration: f64 = 0.0;
if let Some(p) = profile {
for n in &p.nodes {
let secs = n.duration.as_secs_f64();
node_timings.insert(&n.id, secs);
if secs > max_duration {
max_duration = secs;
}
}
}
let mut node_tags: HashMap<usize, Vec<String>> = HashMap::new();
for (name, &(ni, _)) in &self.tag_names {
node_tags.entry(ni).or_default().push(name.clone());
}
let mut input_nodes: HashMap<usize, bool> = HashMap::new();
let mut output_nodes: HashMap<usize, bool> = HashMap::new();
for ep in &self.inputs {
if let Some(&ni) = self.node_index.get(&ep.node_id) {
input_nodes.insert(ni, true);
}
}
for ep in &self.outputs {
if let Some(&ni) = self.node_index.get(&ep.node_id) {
output_nodes.insert(ni, true);
}
}
for (i, level) in self.levels.iter().enumerate() {
let _ = writeln!(b, " subgraph cluster_level_{} {{", i);
if let Some(p) = profile
&& let Some(lt) = p.levels.get(i)
{
let mut label = format!("level {} {:.0?}", i, lt.wall_clock);
if lt.num_nodes > 1 {
let _ = write!(label, " x{:.1}", lt.parallelism());
}
let _ = writeln!(b, " label=\"{}\";", label);
} else {
let _ = writeln!(b, " label=\"level {}\";", i);
}
b.push_str(" style=dashed; color=\"#999999\"; fontcolor=\"#999999\";\n");
b.push_str(" rank=same;\n");
for &ni in level {
let node = &self.nodes[ni];
let tags = node_tags.get(&ni).cloned().unwrap_or_default();
let mut label = node_label(&node.id, &tags);
let is_input = input_nodes.contains_key(&ni);
let is_output = output_nodes.contains_key(&ni);
let (shape, mut fill) = node_style(&node.id, is_input, is_output);
if let Some(&secs) = node_timings.get(node.id.as_str()) {
let _ = write!(label, "\\n{}", format_duration_us(secs));
if max_duration > 0.0 {
fill = heat_color(secs / max_duration);
}
}
if let Some(ref module) = node.module {
let params = module.parameters();
let count: i64 = params
.iter()
.map(|p| p.variable.shape().iter().product::<i64>())
.sum();
if count > 0 {
let _ = write!(label, "\\n[{}]", format_count(count));
}
}
let _ = writeln!(
b,
" \"{}\" [label=\"{}\" shape={} fillcolor=\"{}\"];",
node.id, label, shape, fill
);
}
b.push_str(" }\n\n");
}
for edge in &self.edges {
let (style, color, elabel) = edge_style(edge);
let mut attrs = format!("style={} color=\"{}\"", style, color);
if !elabel.is_empty() {
let _ = write!(attrs, " label=\"{}\" fontcolor=\"{}\"", elabel, color);
}
let _ = writeln!(
b,
" \"{}\" -> \"{}\" [{}];",
edge.from_node, edge.to_node, attrs
);
}
for entry in &self.state {
let writer_id = &self.nodes[entry.writer_ni].id;
let _ = writeln!(
b,
" \"{}\" -> \"state_read_{}\" [style=dotted color=\"#e67e22\" label=\"state\" fontcolor=\"#e67e22\" constraint=false];",
writer_id,
writer_id
);
}
if let Some(p) = profile {
let _ = writeln!(b);
let _ = writeln!(b, " label=\"Forward: {:.0?}\";", p.total);
let _ = writeln!(b, " labelloc=t;");
let _ = writeln!(b, " fontsize=14;");
}
b.push_str("}\n");
b
}
}
fn clean_id(id: &str) -> &str {
if let Some(pos) = id.rfind('_') {
let suffix = &id[pos + 1..];
if !suffix.is_empty() && suffix.chars().all(|c| c.is_ascii_digit()) {
return &id[..pos];
}
}
id
}
fn node_label(id: &str, tags: &[String]) -> String {
let mut label = clean_id(id).to_string();
if !tags.is_empty() {
label += "\\n";
label += &tags.iter().map(|t| format!("#{}", t)).collect::<Vec<_>>().join(" ");
}
label
}
fn node_style(id: &str, is_input: bool, is_output: bool) -> (&'static str, String) {
match (is_input, is_output) {
(true, true) => return ("doubleoctagon", "#aed6f1".into()),
(true, false) => return ("invhouse", "#aed6f1".into()),
(false, true) => return ("house", "#a9dfbf".into()),
_ => {}
}
let prefix = clean_id(id);
if id.starts_with("state_read_") {
("diamond", "#f9e79f".into())
} else if prefix == "add" || prefix == "merge" {
("circle", "#d5dbdb".into())
} else if prefix == "map" {
("parallelogram", "#a9cce3".into())
} else if prefix == "loop" {
("box3d", "#d7bde2".into())
} else if prefix == "switch" || prefix == "gate" {
("diamond", "#f5cba7".into())
} else if is_activation(prefix) {
("ellipse", "#fdebd0".into())
} else if is_norm(prefix) {
("box", "#d5f5e3".into())
} else {
("box", "#eaecee".into())
}
}
fn edge_style(edge: &super::node::Edge) -> (&'static str, &'static str, String) {
if edge.to_port.starts_with("ref_") {
let name = edge.to_port.strip_prefix("ref_").unwrap_or("");
return ("dashed", "#2980b9", name.to_string());
}
("solid", "#2c3e50", String::new())
}
fn is_activation(prefix: &str) -> bool {
matches!(prefix, "relu" | "sigmoid" | "tanh" | "gelu" | "silu")
}
fn is_norm(prefix: &str) -> bool {
matches!(prefix, "layernorm" | "batchnorm" | "rmsnorm")
}
fn heat_color(ratio: f64) -> String {
let ratio = ratio.clamp(0.0, 1.0);
let (r, g, b) = if ratio < 0.5 {
let t = ratio * 2.0;
(
0x27 as f64 + t * (0xf3 - 0x27) as f64,
0xae as f64 + t * (0x9c_u8 as f64 - 0xae as f64),
0x60 as f64 + t * (0x12 as f64 - 0x60 as f64),
)
} else {
let t = (ratio - 0.5) * 2.0;
(
0xf3 as f64 + t * (0xe7 as f64 - 0xf3 as f64),
0x9c as f64 + t * (0x4c as f64 - 0x9c as f64),
0x12 as f64 + t * (0x3c as f64 - 0x12 as f64),
)
};
format!("#{:02x}{:02x}{:02x}", r as u8, g as u8, b as u8)
}
fn format_count(n: i64) -> String {
if n >= 1_000_000 {
format!("{:.1}M params", n as f64 / 1_000_000.0)
} else if n >= 1_000 {
format!("{:.1}K params", n as f64 / 1_000.0)
} else {
format!("{} params", n)
}
}
fn format_duration_us(secs: f64) -> String {
let us = secs * 1_000_000.0;
if us < 1000.0 {
format!("{:.0}us", us)
} else {
format!("{:.2}ms", us / 1000.0)
}
}
fn render_svg(dot: &str, path: Option<&str>) -> crate::tensor::Result<Vec<u8>> {
let dot_bin = which_dot()?;
let mut child = Command::new(&dot_bin)
.arg("-Tsvg")
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| TensorError::new(&format!("failed to spawn {}: {}", dot_bin, e)))?;
{
use std::io::Write as IoWrite;
let stdin = child.stdin.as_mut().unwrap();
stdin.write_all(dot.as_bytes())
.map_err(|e| TensorError::new(&format!("failed to write DOT to stdin: {}", e)))?;
}
let output = child.wait_with_output()
.map_err(|e| TensorError::new(&format!("dot process failed: {}", e)))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(TensorError::new(&format!("dot exited with {}: {}", output.status, stderr)));
}
if let Some(p) = path {
std::fs::write(p, &output.stdout)
.map_err(|e| TensorError::new(&format!("failed to write {}: {}", p, e)))?;
}
Ok(output.stdout)
}
fn which_dot() -> crate::tensor::Result<String> {
if let Ok(output) = Command::new("which").arg("dot").output()
&& output.status.success()
{
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Ok(path);
}
}
Err(TensorError::new(
"Graphviz `dot` binary not found. Install with:\n \
apt install graphviz (Debian/Ubuntu)\n \
brew install graphviz (macOS)\n\n\
Or use .dot() to get the DOT string and render online at:\n \
https://dreampuf.github.io/GraphvizOnline"
))
}