1use std::fmt::Write as _;
2use std::fs::File;
3use std::io::{BufWriter, Write};
4use std::path::Path;
5use std::process::Command;
6
7use crate::graph::{Graph, Operation};
8
9pub fn graph_to_svg(path: impl AsRef<Path>, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
13 let path = path.as_ref();
14
15 let path_gv = path.with_extension("gv");
16 let path_svg = path.with_extension("svg");
17
18 let output = BufWriter::new(File::create(&path_gv)?);
19 graph_to_dot(output, graph, hide_const, show_ids)?;
20
21 let result = Command::new("dot")
22 .arg("-Tsvg")
23 .arg(path_gv)
24 .arg("-o")
25 .arg(path_svg)
26 .status()?;
27 assert!(result.success(), "Running 'dot' failed with status {:?}", result);
28
29 Ok(())
30}
31
32pub fn graph_to_dot(mut f: impl Write, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
36 writeln!(f, "digraph {{")?;
37 writeln!(f)?;
38
39 for value in graph.values() {
40 if hide_const && graph.is_const(value) {
41 continue;
42 }
43
44 let info = &graph[value];
45
46 let (color, op, attrs_operation) = match info.operation {
47 Operation::Input { index } => ("gray", "Input", vec![("index", format!("{}", index))]),
48 Operation::Constant { ref tensor } => {
49 let mut attrs = vec![];
50 if let Some(single) = tensor.single() {
51 attrs.push(("value", format!("{:?}", single)));
52 }
53 ("gray", "Constant", attrs)
54 }
55 Operation::View { input: _ } => ("brown", "View", vec![]),
56 Operation::Broadcast { input: _ } => ("brown", "Broadcast", vec![]),
57 Operation::Permute {
58 input: _,
59 ref permutation,
60 } => {
61 let attrs = vec![("Permute", format!("{:?}", permutation))];
62 ("brown", "permute", attrs)
63 }
64 Operation::Slice { input: _, axis, range } => {
65 let attrs = vec![("axis", format!("{}", axis)), ("range", format!("{}", range))];
66 ("brown", "Slice", attrs)
67 }
68 Operation::Flip { input: _, axis } => ("brown", "Flip", vec![("axis", format!("{}", axis))]),
69 Operation::Gather {
70 input: _,
71 axis,
72 indices: _,
73 } => ("yellow", "Gather", vec![("axis", format!("{}", axis))]),
74 Operation::Concat { inputs: _, axis } => ("yellow", "Concat", vec![("axis", format!("{}", axis))]),
75 Operation::Conv {
76 input: _,
77 filter: _,
78 details,
79 } => {
80 let mut attrs = vec![("kernel", format!("{}x{}", details.kernel_h, details.kernel_w))];
81 if details.has_stride() {
82 attrs.push(("stride", format!("{}x{}", details.stride_y, details.stride_x)));
83 }
84 if !details.keeps_spatial_shape() {
85 attrs.push(("padding", format!("{}x{}", details.padding_y, details.padding_x)));
86 }
87 ("blue", "Conv", attrs)
88 }
89 Operation::MatMul { left: _, right: _ } => ("blue", "MatMul", vec![]),
90 Operation::Unary { input: _, op } => ("green", "Unary", vec![("op", format!("{:?}", op))]),
91 Operation::Binary { left: _, right: _, op } => ("green", "Binary", vec![("op", format!("{:?}", op))]),
92 Operation::Softmax { input: _, axis } => ("purple", "Softmax", vec![("axis", format!("{}", axis))]),
93 Operation::Layernorm { input: _, axis, eps: _ } => {
94 ("purple", "Layernorm", vec![("axis", format!("{}", axis))])
95 }
96 Operation::Reduce { input: _, ref axes, op } => (
97 "purple",
98 "Reduce",
99 vec![("op", format!("{:?}", op)), ("axes", format!("{:?}", axes))],
100 ),
101 };
102
103 let mut attrs_general = vec![];
104 attrs_general.push(("shape", format!("{}", info.shape)));
105 if let Some(output_index) = graph.outputs().iter().position(|&v| v == value) {
106 attrs_general.push(("output", format!("{}", output_index)));
107 }
108
109 if show_ids {
110 let debug_id = &graph[value].debug_id;
111 if !debug_id.is_empty() {
112 attrs_general.push(("debug_id", format!("{:?}", debug_id)));
113 }
114 }
115
116 let mut attrs = attrs_general;
117 attrs.extend(attrs_operation.into_iter());
118
119 let mut table = String::new();
120 writeln!(&mut table, "<TABLE BORDER=\"0\">").unwrap();
121 writeln!(&mut table, "<TR><TD>{:?}</TD><TD><B>{}</B></TD></TR>", value, op).unwrap();
122 for (key, value) in attrs {
123 writeln!(&mut table, "<TR><TD>{}</TD><TD>{}</TD></TR>", key, value).unwrap();
124 }
125 writeln!(&mut table, "</TABLE>").unwrap();
126
127 let label = table;
128 writeln!(
129 f,
130 "{} [label=<{}>, color={:?}, shape=box, width=2]",
131 value.index(),
132 label,
133 color,
134 )?;
135 }
136
137 writeln!(f)?;
138
139 for value in graph.values() {
140 for operand in graph[value].operation.inputs() {
141 if hide_const && graph.is_const(operand) {
142 continue;
143 }
144
145 writeln!(f, "{} -> {}", operand.index(), value.index())?;
146 }
147 }
148
149 writeln!(f)?;
150 writeln!(f, "}}")?;
151 Ok(())
152}