use std::collections::BTreeMap;
use std::fmt::Write as _;
use crate::Graph;
/// Renders `g` as a multi-line, human-readable listing.
///
/// The output consists of, in order:
/// 1. the header line (see `header_line`),
/// 2. the op-kind histogram line (see `op_kinds_line`),
/// 3. an empty line,
/// 4. one line per node: the padded node tag, ` = `, the op, the
///    parenthesised comma-separated inputs (omitted when there are none),
///    ` : ` and the shape, then ` // <origin>` when the node has an origin
///    and ` ← output` when the node id is listed in `g.outputs`,
/// 5. a trailing ` return …` line listing `g.outputs`, emitted only when the
///    graph has at least one output.
///
/// Every line, including the last, is terminated by `\n`.
pub fn pretty_print(g: &Graph) -> String {
    let mut out = String::new();
    // Formatting into a `String` cannot fail, so the `unwrap`s below never fire.
    writeln!(out, "{}", header_line(g)).unwrap();
    writeln!(out, "{}", op_kinds_line(g)).unwrap();
    writeln!(out).unwrap();

    // Build every tag exactly once; the same strings drive both the column
    // width and the rendered lines.
    let tags: Vec<String> = g
        .nodes()
        .iter()
        .map(|n| node_tag(n.id, n.name.as_deref(), &n.op))
        .collect();
    // `{:<width$}` pads by `char` count, not by byte length, so the column
    // width has to be measured in chars to avoid over-padding non-ASCII names.
    let tag_w = tags.iter().map(|t| t.chars().count()).max().unwrap_or(0);

    for (n, tag) in g.nodes().iter().zip(&tags) {
        write!(out, " {tag:<width$} = {}", n.op, width = tag_w).unwrap();
        if !n.inputs.is_empty() {
            out.push('(');
            for (i, inp) in n.inputs.iter().enumerate() {
                if i > 0 {
                    out.push_str(", ");
                }
                write!(out, "{inp}").unwrap();
            }
            out.push(')');
        }
        write!(out, " : {}", n.shape).unwrap();
        if let Some(o) = &n.origin {
            write!(out, " // {}", o).unwrap();
        }
        if g.outputs.contains(&n.id) {
            out.push_str(" ← output");
        }
        out.push('\n');
    }

    if !g.outputs.is_empty() {
        out.push_str(" return ");
        for (i, o) in g.outputs.iter().enumerate() {
            if i > 0 {
                out.push_str(", ");
            }
            write!(out, "{o}").unwrap();
        }
        out.push('\n');
    }
    out
}
/// Returns only the summary of `g`: the header line followed by the op-kind
/// histogram line, separated by a single `\n` and without a trailing newline.
pub fn pretty_stats(g: &Graph) -> String {
    let mut summary = header_line(g);
    summary.push('\n');
    summary.push_str(&op_kinds_line(g));
    summary
}
/// Builds the one-line graph header: name, node count, output count and the
/// total arena size in human-readable form.
///
/// The arena size is the sum of `shape.size_bytes()` over all nodes; nodes
/// whose shape reports `None` contribute nothing.
pub(crate) fn header_line(g: &Graph) -> String {
    let mut arena_bytes = 0usize;
    for node in g.nodes() {
        if let Some(bytes) = node.shape.size_bytes() {
            arena_bytes += bytes;
        }
    }

    let name = &g.name;
    let node_count = g.len();
    let output_count = g.outputs.len();
    let arena = human_bytes(arena_bytes);
    format!("graph @{name} ({node_count} nodes, {output_count} outputs, {arena} arena)")
}
pub(crate) fn op_kinds_line(g: &Graph) -> String {
let mut hist: BTreeMap<String, usize> = BTreeMap::new();
for n in g.nodes() {
*hist.entry(format!("{:?}", n.op.kind())).or_insert(0) += 1;
}
let mut entries: Vec<(String, usize)> = hist.into_iter().collect();
entries.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
let parts: Vec<String> = entries.iter().map(|(k, c)| format!("{k}={c}")).collect();
format!(" op kinds: {}", parts.join(", "))
}
/// Builds the left-hand tag shown for a node in the listing.
///
/// The tag is the node id, optionally followed by a bracketed annotation:
/// * `Op::Input` / `Op::Param` nodes are annotated with `input "<name>"` /
///   `param "<name>"`, using the name stored in the op itself;
/// * any other op is annotated with the quoted node `name`, if one is given;
/// * otherwise the tag is just the id.
fn node_tag(id: crate::NodeId, name: Option<&str>, op: &crate::Op) -> String {
    use crate::Op;

    let annotation = match op {
        Op::Input { name: input_name } => Some(format!("input \"{input_name}\"")),
        Op::Param { name: param_name } => Some(format!("param \"{param_name}\"")),
        _ => name.map(|node_name| format!("\"{node_name}\"")),
    };

    if let Some(text) = annotation {
        format!("{id} [{text}]")
    } else {
        format!("{id}")
    }
}
/// Formats a byte count using binary (1024-based) units.
///
/// Values below 1024 are printed exactly as `"<n> B"`. Larger values are
/// printed with one decimal place in `KB`, `MB` or `GB`; `GB` is the largest
/// unit, so very large values stay in `GB`.
///
/// The unit is chosen from the value *as it will be printed*: a value that
/// rounds up to `1024.0` in one unit is shown as `1.0` of the next unit
/// (e.g. `1024 * 1024 - 1` bytes is `"1.0 MB"`, not `"1024.0 KB"`).
fn human_bytes(b: usize) -> String {
    const K: f64 = 1024.0;
    const UNITS: [&str; 3] = ["KB", "MB", "GB"];

    if b < 1024 {
        return format!("{b} B");
    }

    // The cast may lose precision above 2^53, which is irrelevant at one
    // decimal place of a GB-scaled value.
    let mut value = b as f64 / K;
    let mut unit = 0;
    // Promote to the next unit while the one-decimal rounding of `value`
    // would reach 1024.0, so the printed number never hits the next unit's
    // threshold (except in the final unit, which has nowhere to go).
    while unit + 1 < UNITS.len() && (value * 10.0).round() / 10.0 >= K {
        value /= K;
        unit += 1;
    }
    format!("{value:.1} {}", UNITS[unit])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Graph, Shape, op::BinaryOp};

    /// The full listing contains the header, the histogram, the output
    /// marker and the final return line.
    #[test]
    fn pretty_print_basic() {
        let mut graph = Graph::new("basic");
        let lhs = graph.input("x", Shape::new(&[4, 4], DType::F32));
        let rhs = graph.input("y", Shape::new(&[4, 4], DType::F32));
        let sum = graph.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[4, 4], DType::F32));
        graph.set_outputs(vec![sum]);

        let rendered = pretty_print(&graph);
        let expected_fragments = [
            "graph @basic",
            "nodes",
            "Input=2",
            "Binary=1",
            "← output",
            "return %2",
        ];
        for fragment in expected_fragments {
            assert!(
                rendered.contains(fragment),
                "missing {fragment:?} in:\n{rendered}"
            );
        }
    }

    /// The stats summary reports the node count but no per-node lines.
    #[test]
    fn pretty_stats_no_body() {
        let mut graph = Graph::new("s");
        let lhs = graph.input("x", Shape::new(&[4], DType::F32));
        let rhs = graph.input("y", Shape::new(&[4], DType::F32));
        let _ = graph.binary(BinaryOp::Mul, lhs, rhs, Shape::new(&[4], DType::F32));

        let summary = pretty_stats(&graph);
        assert!(summary.contains("3 nodes"));
        assert!(!summary.contains("%0 = input"));
    }

    /// Each unit threshold switches to the next suffix.
    #[test]
    fn human_bytes_scales() {
        let cases: [(usize, &str); 5] = [
            (0, "0 B"),
            (1023, "1023 B"),
            (1024, "1.0 KB"),
            (1024 * 1024, "1.0 MB"),
            (2 * 1024 * 1024 * 1024, "2.0 GB"),
        ];
        for (bytes, expected) in cases {
            assert_eq!(human_bytes(bytes), expected);
        }
    }

    /// Only the node listed in the graph outputs carries the output marker.
    #[test]
    fn outputs_marker_present() {
        let mut graph = Graph::new("o");
        let a = graph.input("a", Shape::new(&[2], DType::F32));
        let b = graph.input("b", Shape::new(&[2], DType::F32));
        let first_sum = graph.binary(BinaryOp::Add, a, b, Shape::new(&[2], DType::F32));
        let second_sum = graph.binary(BinaryOp::Add, first_sum, a, Shape::new(&[2], DType::F32));
        graph.set_outputs(vec![second_sum]);

        let rendered = pretty_print(&graph);
        let count = rendered
            .lines()
            .filter(|line| line.contains("← output"))
            .count();
        assert_eq!(count, 1, "expected exactly one output marker, got {count}");
    }
}