use crate::graph::{Graph, TensorID};
use crate::tensor::Tensor;
use crate::{Context, Float};
use std::collections::{HashMap, HashSet};
use std::fmt::Write as FmtWrite;
/// Returns every tensor reachable from `root` (including `root` itself)
/// in post-order: each node appears after all of its inputs, so the result
/// is a topological ordering of the subgraph.
fn collect_reachable<F: Float>(root: TensorID, graph: &Graph<F>) -> Vec<TensorID> {
    let mut seen = HashSet::new();
    let mut topo = Vec::new();
    dfs_collect(root, graph, &mut seen, &mut topo);
    topo
}
/// Post-order depth-first traversal of the dependency graph rooted at `id`.
///
/// Each node not already in `visited` is appended to `order` after all of
/// its inputs, yielding a topological ordering (inputs before consumers).
///
/// Implemented with an explicit work stack instead of recursion so that very
/// deep graphs (e.g. long unrolled op chains) cannot overflow the call stack;
/// the emitted order is identical to the previous recursive formulation.
fn dfs_collect<F: Float>(
    id: TensorID,
    graph: &Graph<F>,
    visited: &mut HashSet<TensorID>,
    order: &mut Vec<TensorID>,
) {
    // A frame either expands a node's inputs (`Enter`) or, once all inputs
    // have been processed, emits the node in post-order (`Exit`).
    enum Frame {
        Enter(TensorID),
        Exit(TensorID),
    }
    let mut stack = vec![Frame::Enter(id)];
    while let Some(frame) = stack.pop() {
        match frame {
            Frame::Enter(nid) => {
                // Skip nodes we have already scheduled/visited.
                if !visited.insert(nid) {
                    continue;
                }
                // Emit this node only after all of its inputs are done.
                stack.push(Frame::Exit(nid));
                let node = graph.access_inner(nid);
                let incoming: Vec<TensorID> =
                    node.incoming_nodes.iter().map(|inc| inc.id).collect();
                // Release the graph lock before scheduling children.
                drop(node);
                // Push inputs in reverse so they are popped (and therefore
                // visited) in their original left-to-right order.
                for dep in incoming.into_iter().rev() {
                    stack.push(Frame::Enter(dep));
                }
            }
            Frame::Exit(nid) => order.push(nid),
        }
    }
}
/// An owned, graph-independent snapshot of a single tensor node.
///
/// Captured once by `describe_node` so that rendering and reporting code can
/// work without holding the graph lock.
#[derive(Debug, Clone)]
pub struct NodeDesc {
// Tensor id within the source graph.
id: TensorID,
// Op name with any leading module path stripped; "Source" for op-less nodes.
op_name: String,
// Topological rank as recorded on the graph node.
topo_rank: usize,
// Whether gradients flow through this node.
is_differentiable: bool,
// True when the node is a placeholder (graph input).
is_placeholder: bool,
// The placeholder's user-supplied name, when present.
placeholder_name: Option<String>,
// True when the node is a variable (has a variable_id).
is_variable: bool,
// Static shape if known; non-positive entries are clamped to 1 by the
// parameter/FLOP counters, presumably marking unknown dims — TODO confirm.
known_shape: Option<Vec<isize>>,
// Ids of this node's direct inputs.
input_ids: Vec<TensorID>,
}
/// Snapshots the metadata of node `id` into an owned `NodeDesc`, releasing
/// the graph lock before returning.
fn describe_node<F: Float>(id: TensorID, graph: &Graph<F>) -> NodeDesc {
    let node = graph.access_inner(id);
    // Strip any module path from the op name ("a::b::MatMul" -> "MatMul");
    // nodes without an op (sources) are labelled "Source".
    let op_name = match node.op.as_ref() {
        Some(op) => {
            let full = op.name();
            full.rsplit("::").next().unwrap_or(full).to_string()
        }
        None => "Source".to_string(),
    };
    let desc = NodeDesc {
        id,
        op_name,
        topo_rank: node.topo_rank,
        is_differentiable: node.is_differentiable,
        is_placeholder: node.placeholder_name.is_some(),
        placeholder_name: node.placeholder_name.map(|s| s.to_string()),
        is_variable: node.variable_id.is_some(),
        known_shape: node.knownshape.as_ref().map(|ks| ks.get().to_vec()),
        input_ids: node.incoming_nodes.iter().map(|inc| inc.id).collect(),
    };
    drop(node);
    desc
}
/// Renders a captured computation graph as Graphviz DOT or Mermaid text.
#[derive(Debug, Clone)]
pub struct ComputationGraphViz {
// Snapshots of all reachable nodes, in post-order (inputs before consumers).
nodes: Vec<NodeDesc>,
}
impl ComputationGraphViz {
pub fn from_tensor<'a, 'g, F: Float>(root: &Tensor<'g, F>, ctx: &'a Context<'g, F>) -> Self {
let graph: &Graph<F> = std::ops::Deref::deref(ctx);
let ids = collect_reachable(root.id(), graph);
let nodes = ids.iter().map(|&id| describe_node(id, graph)).collect();
Self { nodes }
}
pub fn to_dot(&self) -> String {
let node_set: HashSet<TensorID> = self.nodes.iter().map(|n| n.id).collect();
let mut out = String::new();
let _ = writeln!(out, "digraph computation_graph {{");
let _ = writeln!(out, " rankdir=BT;");
let _ = writeln!(
out,
" node [shape=record, style=\"rounded,filled\", fontname=\"Helvetica\", fontsize=10];"
);
let _ = writeln!(out, " edge [color=gray50, fontsize=8];");
let _ = writeln!(out);
for nd in &self.nodes {
let label = self.node_to_dot(nd);
let color = dot_color(nd);
let _ = writeln!(out, " n{} [label=\"{}\", {}];", nd.id, label, color);
}
let _ = writeln!(out);
for nd in &self.nodes {
for &src in &nd.input_ids {
if node_set.contains(&src) {
let src_desc = self.nodes.iter().find(|n| n.id == src);
let edge_label = if let Some(s) = src_desc {
s.known_shape
.as_ref()
.map(|sh| format!("{sh:?}"))
.unwrap_or_default()
} else {
String::new()
};
if edge_label.is_empty() {
let _ = writeln!(out, " n{src} -> n{};", nd.id);
} else {
let _ = writeln!(
out,
" n{src} -> n{} [label=\"{}\"];",
nd.id, edge_label
);
}
}
}
}
let _ = writeln!(out, "}}");
out
}
pub fn node_to_dot(&self, nd: &NodeDesc) -> String {
let mut parts = Vec::new();
let mut header = nd.op_name.clone();
if let Some(ref pname) = nd.placeholder_name {
header = format!("{pname} | {header}");
} else if nd.is_variable {
header = format!("var | {header}");
}
parts.push(header);
if let Some(ref sh) = nd.known_shape {
parts.push(format!("shape: {sh:?}"));
}
if nd.is_differentiable {
parts.push("grad".to_string());
}
parts.push(format!("id={}", nd.id));
parts.join(" | ")
}
pub fn to_mermaid(&self) -> String {
let node_set: HashSet<TensorID> = self.nodes.iter().map(|n| n.id).collect();
let mut out = String::new();
let _ = writeln!(out, "graph BT");
for nd in &self.nodes {
let label = if let Some(ref pname) = nd.placeholder_name {
format!("{pname}: {}", nd.op_name)
} else if nd.is_variable {
format!("var: {}", nd.op_name)
} else {
nd.op_name.clone()
};
let shape_str = nd
.known_shape
.as_ref()
.map(|sh| format!("<br/>{sh:?}"))
.unwrap_or_default();
let _ = writeln!(out, " N{}[\"{}{}\"]", nd.id, label, shape_str);
}
for nd in &self.nodes {
for &src in &nd.input_ids {
if node_set.contains(&src) {
let _ = writeln!(out, " N{src} --> N{}", nd.id);
}
}
}
out
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn nodes(&self) -> &[NodeDesc] {
&self.nodes
}
}
/// Writes the Graphviz DOT rendering of the graph rooted at `root` to
/// `filename`, returning any I/O error from the write.
pub fn export_dot<'g, F: Float>(
    root: &Tensor<'g, F>,
    ctx: &'g Context<'g, F>,
    filename: &str,
) -> std::io::Result<()> {
    let dot = ComputationGraphViz::from_tensor(root, ctx).to_dot();
    std::fs::write(filename, dot)
}
/// Writes the Mermaid rendering of the graph rooted at `root` to `filename`,
/// wrapped in a ```mermaid fenced code block so it previews in Markdown.
pub fn export_mermaid<'g, F: Float>(
    root: &Tensor<'g, F>,
    ctx: &'g Context<'g, F>,
    filename: &str,
) -> std::io::Result<()> {
    let body = ComputationGraphViz::from_tensor(root, ctx).to_mermaid();
    let fenced = format!("```mermaid\n{body}```\n");
    std::fs::write(filename, fenced)
}
/// Produces a Keras-style text summary of the graph rooted at `root`:
/// one table row per node (op, shape, parameter count, grad flag) followed
/// by parameter and node totals.
pub fn print_graph_summary<'a, 'g, F: Float>(
    root: &Tensor<'g, F>,
    ctx: &'a Context<'g, F>,
) -> String {
    let graph: &Graph<F> = std::ops::Deref::deref(ctx);
    let ids = collect_reachable(root.id(), graph);
    let rule = "─".repeat(70);
    let mut buf = String::new();
    let _ = writeln!(buf, "{rule}");
    let _ = writeln!(
        buf,
        " {:<30} {:<20} {:<10} {}",
        "Layer (op)", "Output Shape", "Params", "Grad"
    );
    let _ = writeln!(buf, "{rule}");
    let mut total: usize = 0;
    let mut trainable: usize = 0;
    for &id in &ids {
        let nd = describe_node(id, graph);
        let layer = match nd.placeholder_name {
            Some(ref pname) => format!("{pname} ({})", nd.op_name),
            None => nd.op_name.clone(),
        };
        let shape = nd
            .known_shape
            .as_ref()
            .map_or_else(|| "?".to_string(), |sh| format!("{sh:?}"));
        // Parameters live only on variable nodes; non-positive dims are
        // clamped to 1 and a shapeless variable contributes 0.
        let params: usize = if nd.is_variable {
            nd.known_shape
                .as_ref()
                .map_or(0, |sh| sh.iter().map(|&d| d.max(1) as usize).product())
        } else {
            0
        };
        total += params;
        if nd.is_differentiable {
            trainable += params;
        }
        let _ = writeln!(
            buf,
            " {:<30} {:<20} {:<10} {}",
            truncate(&layer, 29),
            truncate(&shape, 19),
            params,
            if nd.is_differentiable { "yes" } else { "no" }
        );
    }
    let _ = writeln!(buf, "{rule}");
    let _ = writeln!(buf, " Total parameters: {total}");
    let _ = writeln!(buf, " Trainable parameters: {trainable}");
    let _ = writeln!(
        buf,
        " Non-trainable params: {}",
        total.saturating_sub(trainable)
    );
    let _ = writeln!(buf, " Total nodes: {}", ids.len());
    let _ = writeln!(buf, "{rule}");
    buf
}
/// Truncates `s` to at most `max_len` characters, appending `…` when the
/// string was shortened.
///
/// Operates on `char` boundaries rather than raw byte indices: the previous
/// version sliced `&s[..max_len - 1]`, which panics when that byte offset
/// falls inside a multi-byte UTF-8 sequence (e.g. `truncate("éabc", 2)`).
/// For ASCII input the behavior is unchanged.
fn truncate(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        s.to_string()
    } else {
        // Keep max_len - 1 characters and spend the last slot on "…".
        let keep = max_len.saturating_sub(1);
        let mut out: String = s.chars().take(keep).collect();
        out.push('…');
        out
    }
}
/// Sums the element counts of all variable nodes reachable from `root`.
/// Non-positive dimensions are clamped to 1; variables without a known
/// shape contribute 0.
pub fn count_parameters<'a, 'g, F: Float>(root: &Tensor<'g, F>, ctx: &'a Context<'g, F>) -> usize {
    let graph: &Graph<F> = std::ops::Deref::deref(ctx);
    let mut total = 0usize;
    for &id in &collect_reachable(root.id(), graph) {
        let nd = describe_node(id, graph);
        if !nd.is_variable {
            continue;
        }
        if let Some(sh) = nd.known_shape.as_ref() {
            total += sh.iter().map(|&d| d.max(1) as usize).product::<usize>();
        }
    }
    total
}
/// Estimates the FLOP count of one pass over the graph rooted at `root`
/// using per-op heuristics.
///
/// Every node contributes a small multiple of its element count, chosen by a
/// case-insensitive substring match on its op name; the result is a rough
/// order-of-magnitude estimate rather than an exact count. Placeholders and
/// variables that match no op category contribute 0.
///
/// `_input_shape` is currently unused (reserved for future shape
/// propagation).
pub fn count_flops<'a, 'g, F: Float>(
    root: &Tensor<'g, F>,
    ctx: &'a Context<'g, F>,
    _input_shape: Option<&[usize]>,
) -> u64 {
    let graph: &Graph<F> = std::ops::Deref::deref(ctx);
    let ids = collect_reachable(root.id(), graph);
    let mut total: u64 = 0;
    for &id in &ids {
        let nd = describe_node(id, graph);
        // Non-positive dims count as 1; shapeless nodes count as 1 element.
        let elem_count: u64 = nd
            .known_shape
            .as_ref()
            .map(|sh| sh.iter().map(|&d| d.max(1) as u64).product())
            .unwrap_or(1);
        let op_lower = nd.op_name.to_lowercase();
        // NOTE: activations are tested BEFORE reductions so that "softmax"
        // is classified as an activation (4x) instead of matching the "max"
        // substring of the reduction branch — the previous ordering
        // silently under-counted softmax.
        let flops = if op_lower.contains("matmul")
            || op_lower.contains("dot")
            || op_lower.contains("gemm")
        {
            2 * elem_count
        } else if op_lower.contains("conv") {
            // Assume a 3x3 kernel: ~9 multiply-adds per output element.
            9 * elem_count
        } else if op_lower.contains("relu")
            || op_lower.contains("sigmoid")
            || op_lower.contains("tanh")
            || op_lower.contains("gelu")
            || op_lower.contains("softmax")
        {
            4 * elem_count
        } else if op_lower.contains("sum")
            || op_lower.contains("mean")
            || op_lower.contains("max")
            || op_lower.contains("min")
            || op_lower.contains("reduce")
        {
            elem_count
        } else if op_lower.contains("norm") || op_lower.contains("batch") {
            8 * elem_count
        } else if op_lower.contains("add")
            || op_lower.contains("sub")
            || op_lower.contains("mul")
            || op_lower.contains("div")
            || op_lower.contains("pow")
        {
            elem_count
        } else if nd.is_placeholder || nd.is_variable {
            0
        } else {
            elem_count
        };
        total = total.saturating_add(flops);
    }
    total
}
/// Public, flat record of a single graph node, suitable for programmatic
/// inspection (unlike `NodeDesc`, all fields are public).
#[derive(Debug, Clone)]
pub struct NodeRecord {
// Tensor id within the source graph.
pub id: TensorID,
// Op name with any module path stripped; "Source" for op-less nodes.
pub op_name: String,
// Static shape if known; non-positive entries are clamped to 1 by
// `total_parameters`, presumably marking unknown dims — TODO confirm.
pub shape: Option<Vec<isize>>,
// Whether gradients flow through this node.
pub differentiable: bool,
// True when the node is a variable (trainable parameter).
pub is_variable: bool,
// True when the node is a placeholder (graph input).
pub is_placeholder: bool,
// Topological rank as recorded on the graph node.
pub topo_rank: usize,
// Ids of this node's direct inputs.
pub input_ids: Vec<TensorID>,
}
/// A flat table of `NodeRecord`s for every node reachable from some root,
/// with simple aggregate queries (counts, parameter totals, op histogram).
#[derive(Debug, Clone)]
pub struct GraphNodeTable {
// Records in post-order (inputs before consumers).
pub records: Vec<NodeRecord>,
}
impl GraphNodeTable {
    /// Captures every node reachable from `root` as a flat record table.
    pub fn from_tensor<'a, 'g, F: Float>(root: &Tensor<'g, F>, ctx: &'a Context<'g, F>) -> Self {
        let graph: &Graph<F> = std::ops::Deref::deref(ctx);
        let mut records = Vec::new();
        for &id in &collect_reachable(root.id(), graph) {
            let nd = describe_node(id, graph);
            records.push(NodeRecord {
                id: nd.id,
                op_name: nd.op_name,
                shape: nd.known_shape,
                differentiable: nd.is_differentiable,
                is_variable: nd.is_variable,
                is_placeholder: nd.is_placeholder,
                topo_rank: nd.topo_rank,
                input_ids: nd.input_ids,
            });
        }
        Self { records }
    }

    /// Number of records in the table.
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// True when the table holds no records.
    pub fn is_empty(&self) -> bool {
        self.records.is_empty()
    }

    /// Count of variable nodes.
    pub fn variable_count(&self) -> usize {
        self.records.iter().filter(|r| r.is_variable).count()
    }

    /// Count of placeholder (input) nodes.
    pub fn placeholder_count(&self) -> usize {
        self.records.iter().filter(|r| r.is_placeholder).count()
    }

    /// Total element count over all variable nodes; non-positive dims are
    /// clamped to 1 and shapeless variables contribute 0.
    pub fn total_parameters(&self) -> usize {
        let mut sum = 0usize;
        for rec in self.records.iter().filter(|r| r.is_variable) {
            if let Some(sh) = rec.shape.as_ref() {
                sum += sh.iter().map(|&d| d.max(1) as usize).product::<usize>();
            }
        }
        sum
    }

    /// Histogram of op names across the table.
    pub fn op_frequencies(&self) -> HashMap<String, usize> {
        let mut freq: HashMap<String, usize> = HashMap::new();
        for rec in &self.records {
            *freq.entry(rec.op_name.clone()).or_default() += 1;
        }
        freq
    }
}
/// Picks DOT fill/border colors for a node: green for placeholders, yellow
/// for variables, blue for differentiable ops, gray for everything else.
fn dot_color(nd: &NodeDesc) -> String {
    let (fill, border) = if nd.is_placeholder {
        ("#d5f5d5", "#6aaf6a")
    } else if nd.is_variable {
        ("#fff8d5", "#c8a830")
    } else if nd.is_differentiable {
        ("#d5e8f5", "#4a8fc0")
    } else {
        ("#e8e8e8", "#999999")
    };
    format!("fillcolor=\"{fill}\", color=\"{border}\"")
}
// Unit tests exercise the visualization/summary helpers end-to-end on tiny
// graphs built through the crate's `run` entry point.
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_ops;
// Builds a small graph (sum_all(x * 2 + 1)) and renders it as DOT.
fn build_simple_graph_dot() -> String {
let mut out = String::new();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0 + 1.0;
let loss = tensor_ops::reduction::sum_all(y);
let viz = ComputationGraphViz::from_tensor(&loss, ctx);
out = viz.to_dot();
});
out
}
// DOT output contains the digraph header, at least one edge, and a closer.
#[test]
fn test_to_dot_contains_digraph() {
let dot = build_simple_graph_dot();
assert!(dot.contains("digraph computation_graph"));
assert!(dot.contains("->"));
assert!(dot.contains('}'));
}
// Mermaid output starts with the flowchart header and has edges.
#[test]
fn test_to_mermaid_format() {
let mut mermaid = String::new();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let y = x + x;
let viz = ComputationGraphViz::from_tensor(&y, ctx);
mermaid = viz.to_mermaid();
});
assert!(mermaid.contains("graph BT"));
assert!(mermaid.contains("-->"));
}
// The text summary must include its footer totals.
#[test]
fn test_print_graph_summary_contains_totals() {
let mut summary = String::new();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0 + 1.0;
let loss = tensor_ops::reduction::sum_all(y);
summary = print_graph_summary(&loss, ctx);
});
assert!(summary.contains("Total parameters"));
assert!(summary.contains("Total nodes"));
assert!(summary.contains("Trainable parameters"));
}
// A graph with only a placeholder (no variables) has zero parameters.
#[test]
fn test_count_parameters_no_variables() {
let mut params = 99usize;
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0;
params = count_parameters(&y, ctx);
});
assert_eq!(params, 0);
}
// Any graph with compute ops should report a positive FLOP estimate.
#[test]
fn test_count_flops_positive() {
let mut flops = 0u64;
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[4, 8]);
let y = x * 2.0;
let loss = tensor_ops::reduction::sum_all(y);
flops = count_flops(&loss, ctx, None);
});
assert!(flops > 0);
}
// A lone placeholder performs no compute, so its FLOP estimate is zero.
#[test]
fn test_count_flops_placeholder_zero() {
let mut flops = 999u64;
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[4]);
flops = count_flops(&x, ctx, None);
});
assert_eq!(flops, 0);
}
// The record table sees every node and exactly one placeholder.
#[test]
fn test_graph_node_table_basics() {
let mut node_count = 0usize;
let mut placeholder_count = 0usize;
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0 + 1.0;
let loss = tensor_ops::reduction::sum_all(y);
let table = GraphNodeTable::from_tensor(&loss, ctx);
node_count = table.len();
placeholder_count = table.placeholder_count();
});
assert!(node_count > 0);
assert_eq!(placeholder_count, 1);
}
// export_dot writes a readable DOT file to disk (cleaned up afterwards).
#[test]
fn test_export_dot_writes_file() {
let path = std::env::temp_dir().join("test_export.dot");
let path_str = path.to_string_lossy().to_string();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let y = x * 3.0;
export_dot(&y, ctx, &path_str).expect("export_dot failed");
});
let content = std::fs::read_to_string(&path).expect("read failed");
assert!(content.contains("digraph"));
let _ = std::fs::remove_file(&path);
}
// export_mermaid wraps the diagram in a ```mermaid fence on disk.
#[test]
fn test_export_mermaid_writes_file() {
let path = std::env::temp_dir().join("test_export_mermaid.md");
let path_str = path.to_string_lossy().to_string();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let y = x * 3.0;
export_mermaid(&y, ctx, &path_str).expect("export_mermaid failed");
});
let content = std::fs::read_to_string(&path).expect("read failed");
assert!(content.contains("mermaid"));
let _ = std::fs::remove_file(&path);
}
// A placeholder's DOT label includes its user-supplied name.
#[test]
fn test_node_to_dot_placeholder() {
let mut label = String::new();
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let viz = ComputationGraphViz::from_tensor(&x, ctx);
let nd = &viz.nodes()[0];
label = viz.node_to_dot(nd);
});
assert!(label.contains("x"));
}
// Any non-trivial graph yields a non-empty op-name histogram.
#[test]
fn test_op_frequencies() {
let mut has_ops = false;
crate::run(|ctx: &mut crate::Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0 + 1.0;
let loss = tensor_ops::reduction::sum_all(y);
let table = GraphNodeTable::from_tensor(&loss, ctx);
let freqs = table.op_frequencies();
has_ops = !freqs.is_empty();
});
assert!(has_ops);
}
}