// kn-graph 0.7.3
//
// A neural network inference graph intermediate representation, with surrounding utilities.
// Documentation
use std::fmt::Write as _;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use std::process::Command;

use crate::graph::{Graph, Operation};

/// Render the given graph as an svg file.
///
/// This assumes that graphviz is installed and available on the path as `dot`.
pub fn graph_to_svg(path: impl AsRef<Path>, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
    let path = path.as_ref();

    let path_gv = path.with_extension("gv");
    let path_svg = path.with_extension("svg");

    let output = BufWriter::new(File::create(&path_gv)?);
    graph_to_dot(output, graph, hide_const, show_ids)?;

    let result = Command::new("dot")
        .arg("-Tsvg")
        .arg(path_gv)
        .arg("-o")
        .arg(path_svg)
        .status()?;
    assert!(result.success(), "Running 'dot' failed with status {:?}", result);

    Ok(())
}

/// Render the given graph as a graphviz string.
///
/// The output is a single `digraph`. Every value of the graph becomes a box-shaped node whose id is
/// the index of the value and whose label is an HTML-like table: the first row holds the value and
/// the operation name, the following rows hold the shape, the output index (if the value is a graph
/// output) and the operation-specific attributes. Every operand of every value becomes an edge
/// from the operand to the value.
///
/// * `hide_const`: skip nodes for which `graph.is_const` returns true, as well as the edges
///   starting at such values.
/// * `show_ids`: add a `debug_id` row to the nodes that have a non-empty debug id.
///
/// This makes no assumptions about the environment.
///
/// # Errors
///
/// Returns the first IO error raised while writing to `f`.
pub fn graph_to_dot(mut f: impl Write, graph: &Graph, hide_const: bool, show_ids: bool) -> std::io::Result<()> {
    writeln!(f, "digraph {{")?;
    writeln!(f)?;

    // First pass: declare one node per (visible) value.
    for value in graph.values() {
        if hide_const && graph.is_const(value) {
            continue;
        }

        let info = &graph[value];

        // Border color, operation title and operation-specific attribute rows.
        let (color, op, attrs_operation) = match info.operation {
            Operation::Input { index } => ("gray", "Input", vec![("index", index.to_string())]),
            Operation::Constant { ref tensor } => {
                let mut attrs = vec![];
                // Only show the value if `single()` yields one.
                if let Some(single) = tensor.single() {
                    attrs.push(("value", format!("{:?}", single)));
                }
                ("gray", "Constant", attrs)
            }
            Operation::View { input: _ } => ("brown", "View", vec![]),
            Operation::Broadcast { input: _ } => ("brown", "Broadcast", vec![]),
            Operation::Permute {
                input: _,
                ref permutation,
            } => {
                // Capitalized title and lowercase attribute key, matching all other operations.
                let attrs = vec![("permutation", format!("{:?}", permutation))];
                ("brown", "Permute", attrs)
            }
            Operation::Slice { input: _, axis, range } => {
                let attrs = vec![("axis", axis.to_string()), ("range", range.to_string())];
                ("brown", "Slice", attrs)
            }
            Operation::Flip { input: _, axis } => ("brown", "Flip", vec![("axis", axis.to_string())]),
            Operation::Gather {
                input: _,
                axis,
                indices: _,
            } => ("yellow", "Gather", vec![("axis", axis.to_string())]),
            Operation::Concat { inputs: _, axis } => ("yellow", "Concat", vec![("axis", axis.to_string())]),
            Operation::Conv {
                input: _,
                filter: _,
                details,
            } => {
                let mut attrs = vec![("kernel", format!("{}x{}", details.kernel_h, details.kernel_w))];
                // Stride and padding rows are only added when the corresponding check fires.
                if details.has_stride() {
                    attrs.push(("stride", format!("{}x{}", details.stride_y, details.stride_x)));
                }
                if !details.keeps_spatial_shape() {
                    attrs.push(("padding", format!("{}x{}", details.padding_y, details.padding_x)));
                }
                ("blue", "Conv", attrs)
            }
            Operation::MatMul { left: _, right: _ } => ("blue", "MatMul", vec![]),
            Operation::Unary { input: _, op } => ("green", "Unary", vec![("op", format!("{:?}", op))]),
            Operation::Binary { left: _, right: _, op } => ("green", "Binary", vec![("op", format!("{:?}", op))]),
            Operation::Softmax { input: _, axis } => ("purple", "Softmax", vec![("axis", axis.to_string())]),
            Operation::Layernorm { input: _, axis, eps: _ } => {
                ("purple", "Layernorm", vec![("axis", axis.to_string())])
            }
            Operation::Reduce { input: _, ref axes, op } => (
                "purple",
                "Reduce",
                vec![("op", format!("{:?}", op)), ("axes", format!("{:?}", axes))],
            ),
        };

        // Rows shared by all operations come first, followed by the operation-specific ones.
        let mut attrs = vec![("shape", info.shape.to_string())];
        if let Some(output_index) = graph.outputs().iter().position(|&v| v == value) {
            attrs.push(("output", output_index.to_string()));
        }
        if show_ids && !info.debug_id.is_empty() {
            attrs.push(("debug_id", format!("{:?}", info.debug_id)));
        }
        attrs.extend(attrs_operation);

        // Build the HTML-like label. Writing to a `String` cannot fail, hence the unwraps.
        // All dynamic text is escaped so that it cannot be mistaken for markup by graphviz.
        let mut table = String::new();
        writeln!(&mut table, "<TABLE BORDER=\"0\">").unwrap();
        writeln!(
            &mut table,
            "<TR><TD>{}</TD><TD><B>{}</B></TD></TR>",
            escape_html_label(&format!("{:?}", value)),
            op
        )
        .unwrap();
        for (key, text) in &attrs {
            writeln!(&mut table, "<TR><TD>{}</TD><TD>{}</TD></TR>", key, escape_html_label(text)).unwrap();
        }
        writeln!(&mut table, "</TABLE>").unwrap();

        writeln!(
            f,
            "{} [label=<{}>, color={:?}, shape=box, width=2]",
            value.index(),
            table,
            color,
        )?;
    }

    writeln!(f)?;

    // Second pass: one edge per operand, pointing from the operand to the value using it.
    for value in graph.values() {
        for operand in graph[value].operation.inputs() {
            if hide_const && graph.is_const(operand) {
                continue;
            }

            writeln!(f, "{} -> {}", operand.index(), value.index())?;
        }
    }

    writeln!(f)?;
    writeln!(f, "}}")?;
    Ok(())
}

/// Escape the characters that would be interpreted as markup inside a graphviz HTML-like label.
fn escape_html_label(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for c in text.chars() {
        match c {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            _ => escaped.push(c),
        }
    }
    escaped
}