kn_graph/
dot.rs

1use std::fmt::Write as _;
2use std::fs::File;
3use std::io::{BufWriter, Write};
4use std::path::Path;
5use std::process::Command;
6
7use crate::graph::{Graph, Operation};
8
9/// Render the given graph as an svg file.
10///
11/// This assumes that graphviz is installed and available on the path as `dot`.
12pub fn graph_to_svg(path: impl AsRef<Path>, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
13    let path = path.as_ref();
14
15    let path_gv = path.with_extension("gv");
16    let path_svg = path.with_extension("svg");
17
18    let output = BufWriter::new(File::create(&path_gv)?);
19    graph_to_dot(output, graph, hide_const, show_ids)?;
20
21    let result = Command::new("dot")
22        .arg("-Tsvg")
23        .arg(path_gv)
24        .arg("-o")
25        .arg(path_svg)
26        .status()?;
27    assert!(result.success(), "Running 'dot' failed with status {:?}", result);
28
29    Ok(())
30}
31
32/// Render the given graph as a graphviz string.
33///
34/// This makes no assumptions about the environment.
35pub fn graph_to_dot(mut f: impl Write, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
36    writeln!(f, "digraph {{")?;
37    writeln!(f)?;
38
39    for value in graph.values() {
40        if hide_const && graph.is_const(value) {
41            continue;
42        }
43
44        let info = &graph[value];
45
46        let (color, op, attrs_operation) = match info.operation {
47            Operation::Input { index } => ("gray", "Input", vec![("index", format!("{}", index))]),
48            Operation::Constant { ref tensor } => {
49                let mut attrs = vec![];
50                if let Some(single) = tensor.single() {
51                    attrs.push(("value", format!("{:?}", single)));
52                }
53                ("gray", "Constant", attrs)
54            }
55            Operation::View { input: _ } => ("brown", "View", vec![]),
56            Operation::Broadcast { input: _ } => ("brown", "Broadcast", vec![]),
57            Operation::Permute {
58                input: _,
59                ref permutation,
60            } => {
61                let attrs = vec![("Permute", format!("{:?}", permutation))];
62                ("brown", "permute", attrs)
63            }
64            Operation::Slice { input: _, axis, range } => {
65                let attrs = vec![("axis", format!("{}", axis)), ("range", format!("{}", range))];
66                ("brown", "Slice", attrs)
67            }
68            Operation::Flip { input: _, axis } => ("brown", "Flip", vec![("axis", format!("{}", axis))]),
69            Operation::Gather {
70                input: _,
71                axis,
72                indices: _,
73            } => ("yellow", "Gather", vec![("axis", format!("{}", axis))]),
74            Operation::Concat { inputs: _, axis } => ("yellow", "Concat", vec![("axis", format!("{}", axis))]),
75            Operation::Conv {
76                input: _,
77                filter: _,
78                details,
79            } => {
80                let mut attrs = vec![("kernel", format!("{}x{}", details.kernel_h, details.kernel_w))];
81                if details.has_stride() {
82                    attrs.push(("stride", format!("{}x{}", details.stride_y, details.stride_x)));
83                }
84                if !details.keeps_spatial_shape() {
85                    attrs.push(("padding", format!("{}x{}", details.padding_y, details.padding_x)));
86                }
87                ("blue", "Conv", attrs)
88            }
89            Operation::MatMul { left: _, right: _ } => ("blue", "MatMul", vec![]),
90            Operation::Unary { input: _, op } => ("green", "Unary", vec![("op", format!("{:?}", op))]),
91            Operation::Binary { left: _, right: _, op } => ("green", "Binary", vec![("op", format!("{:?}", op))]),
92            Operation::Softmax { input: _, axis } => ("purple", "Softmax", vec![("axis", format!("{}", axis))]),
93            Operation::Layernorm { input: _, axis, eps: _ } => {
94                ("purple", "Layernorm", vec![("axis", format!("{}", axis))])
95            }
96            Operation::Reduce { input: _, ref axes, op } => (
97                "purple",
98                "Reduce",
99                vec![("op", format!("{:?}", op)), ("axes", format!("{:?}", axes))],
100            ),
101        };
102
103        let mut attrs_general = vec![];
104        attrs_general.push(("shape", format!("{}", info.shape)));
105        if let Some(output_index) = graph.outputs().iter().position(|&v| v == value) {
106            attrs_general.push(("output", format!("{}", output_index)));
107        }
108
109        if show_ids {
110            let debug_id = &graph[value].debug_id;
111            if !debug_id.is_empty() {
112                attrs_general.push(("debug_id", format!("{:?}", debug_id)));
113            }
114        }
115
116        let mut attrs = attrs_general;
117        attrs.extend(attrs_operation.into_iter());
118
119        let mut table = String::new();
120        writeln!(&mut table, "<TABLE BORDER=\"0\">").unwrap();
121        writeln!(&mut table, "<TR><TD>{:?}</TD><TD><B>{}</B></TD></TR>", value, op).unwrap();
122        for (key, value) in attrs {
123            writeln!(&mut table, "<TR><TD>{}</TD><TD>{}</TD></TR>", key, value).unwrap();
124        }
125        writeln!(&mut table, "</TABLE>").unwrap();
126
127        let label = table;
128        writeln!(
129            f,
130            "{} [label=<{}>, color={:?}, shape=box, width=2]",
131            value.index(),
132            label,
133            color,
134        )?;
135    }
136
137    writeln!(f)?;
138
139    for value in graph.values() {
140        for operand in graph[value].operation.inputs() {
141            if hide_const && graph.is_const(operand) {
142                continue;
143            }
144
145            writeln!(f, "{} -> {}", operand.index(), value.index())?;
146        }
147    }
148
149    writeln!(f)?;
150    writeln!(f, "}}")?;
151    Ok(())
152}