use super::{BinaryOp, Coercion, Comparison, Filter, Mir, MirEdge, MirGraph, MirNode};
use crate::viz::{Id, IdGen};
use ::petgraph::graph::NodeIndex;
use ::std::collections::HashMap;
/// Key that identifies a node uniquely across every nested MIR region.
///
/// A `NodeIndex` is only unique inside a single `MirGraph`. Regions nest further
/// graphs inside nodes, so the index is paired with the address of the graph
/// that owns it. The address is only a meaningful identity while that graph is
/// borrowed and not moved.
#[derive(Debug, PartialEq, Eq, Hash)]
struct GlobalNodeKey {
    /// Index of the node within its own region's graph.
    local_idx: NodeIndex,
    /// Address of the owning `MirGraph`, used purely as an identity token.
    global_region_key: usize,
}
impl GlobalNodeKey {
    /// Builds the key for `local_idx` inside the graph viewed by `region`.
    fn new(region: &RegionView<'_>, local_idx: NodeIndex) -> Self {
        let graph_ptr: *const MirGraph = region.graph;
        Self {
            global_region_key: graph_ptr as usize,
            local_idx,
        }
    }
}
/// Appends the DOT node statement `<id> [label = "<label>"]` followed by a newline.
///
/// The label is escaped before it goes inside the quoted DOT string. Labels
/// produced with `{:?}` formatting may contain `"` or `\`. Written raw, those
/// characters would end the string early or start an unintended escape
/// sequence, which produces invalid DOT.
fn push_node(dot: &mut String, id: &Id, label: &str) {
    id.fmt(dot);
    dot.push_str(" [label = \"");
    push_escaped_label(dot, label);
    dot.push_str("\"]\n");
}

/// Appends `label` to `out`, backslash-escaping `"` and `\` so the text is safe
/// inside a double-quoted DOT string.
fn push_escaped_label(out: &mut String, label: &str) {
    out.reserve(label.len());
    for ch in label.chars() {
        if matches!(ch, '"' | '\\') {
            out.push('\\');
        }
        out.push(ch);
    }
}
/// Rendered identity of a MIR node in the DOT output.
///
/// `Simple` nodes are plain DOT nodes. `Structural` nodes are rendered as a
/// `cluster_`-prefixed subgraph; `end_id` is the identity of the node that
/// terminates that region.
#[derive(Clone)]
enum RId {
    Simple(Id),
    Structural { id: Id, end_id: Box<RId> },
}
impl RId {
    /// Writes this identity's DOT name into `out`. Structural ids are written as
    /// their cluster name.
    fn fmt(&self, out: &mut String) {
        match *self {
            RId::Structural { ref id, .. } => Self::region_fmt(id, out),
            RId::Simple(ref id) => id.fmt(out),
        }
    }
    /// Writes the subgraph name for the region identified by `id`: the id
    /// prefixed with `cluster_`.
    fn region_fmt(id: &Id, out: &mut String) {
        const CLUSTER_PREFIX: &str = "cluster_";
        out.push_str(CLUSTER_PREFIX);
        id.fmt(out);
    }
}
/// Cheap, copyable view of one MIR region: the graph holding its nodes and the
/// index of its end node.
#[derive(Copy, Clone)]
struct RegionView<'a> {
    graph: &'a MirGraph,
    // NOTE(review): this field is never read through a `RegionView` in this file,
    // hence the lint allowance. It presumably stays so the view mirrors the region
    // it was built from — confirm before removing.
    #[allow(dead_code)]
    end: NodeIndex,
}
fn dot_inner(
mir: RegionView,
gen: &mut IdGen,
id_map: &mut HashMap<GlobalNodeKey, RId>,
dot: &mut String,
) {
for node in mir.graph.node_indices() {
macro_rules! push_node {
($val:expr) => {{
let id = gen.next();
push_node(&mut *dot, &id, &*$val);
id_map.insert(GlobalNodeKey::new(&mir, node), RId::Simple(id));
}};
}
macro_rules! push_region {
($name:expr, $region:expr) => {{
match $region {
region => {
let region_id = gen.next();
dot.push_str("subgraph ");
RId::region_fmt(®ion_id, &mut *dot);
dot.push_str(" {\n");
let view = RegionView {
graph: ®ion.graph,
end: region.end,
};
dot_inner(view, &mut *gen, &mut *id_map, &mut *dot);
let end_id = id_map[&GlobalNodeKey::new(&view, region.end)].clone();
dot.push_str("label = \"");
dot.push_str(&*$name);
dot.push_str("\"\n");
dot.push_str("}\n");
id_map.insert(
GlobalNodeKey::new(&mir, node),
RId::Structural {
id: region_id,
end_id: Box::new(end_id),
},
);
}
}
}};
}
match &mir.graph[node] {
MirNode::Integer(val) => push_node!(format!("{}", val)),
MirNode::Coerce(Coercion::FromOutputToInt) => push_node!("Coerce(FromOutputToInt)"),
MirNode::Roll => push_node!("Roll"),
MirNode::BinOp(BinaryOp::Add) => push_node!("+"),
MirNode::BinOp(BinaryOp::Subtract) => push_node!("-"),
MirNode::BinOp(BinaryOp::LogicalAnd) => push_node!("AND"),
MirNode::Filter(Filter::Simple(filter)) => push_node!(format!("Filter({:?})", filter)),
MirNode::Filter(Filter::SatisfiesPredicate) => push_node!("Filter(Satisfies)"),
MirNode::Apply => push_node!("Apply"),
MirNode::PartialApply => push_node!("PartialApply"),
MirNode::Compare(Comparison::Equal) => push_node!("Compare(Equal)"),
MirNode::Compare(Comparison::GreaterThan) => push_node!("Compare(GreaterThan)"),
MirNode::Count => push_node!("Count"),
MirNode::Loop(body, _ty) => push_region!("Loop", body),
MirNode::Decision(_) => todo!("visualizing decision points"),
MirNode::FunctionDefinition(body) => push_region!("Function", body),
MirNode::RecursiveEnvironment(body) => push_region!("Recursive Environment", body),
MirNode::RegionArgument(_) => push_node!("Region Argument"),
MirNode::End => push_node!("End"),
MirNode::Fmt(node) => push_node!(format!("Fmt({:?})", node)),
MirNode::UseFuel(_) => push_node!("UseFuel"),
}
}
for edge in mir.graph.raw_edges().iter() {
let source = &id_map[&GlobalNodeKey::new(&mir, edge.source())];
let target = &id_map[&GlobalNodeKey::new(&mir, edge.target())];
source.fmt(&mut *dot);
dot.push_str(" -> ");
match target {
RId::Simple(_) => {
target.fmt(&mut *dot);
}
RId::Structural { id: _, end_id } => {
end_id.fmt(&mut *dot);
}
};
match edge.weight {
MirEdge::IntermediateResultDependency { .. } => dot.push_str(" [color=cornflowerblue]"),
_ => (),
}
dot.push_str("\n");
}
}
/// Renders `mir` as a Graphviz `strict digraph` in DOT syntax.
///
/// `compound=true` is set so that edges may be clipped at cluster boundaries.
/// The top-level region is described by `mir.graph`, with `mir.top` as its end
/// node.
pub fn dot(mir: &Mir) -> String {
    let mut id_map: HashMap<GlobalNodeKey, RId> = HashMap::new();
    let mut gen = IdGen::new();
    let mut out = String::from("strict digraph {\n");
    out.push_str("compound=true\n");
    let root = RegionView {
        graph: &mir.graph,
        end: mir.top,
    };
    dot_inner(root, &mut gen, &mut id_map, &mut out);
    out.push('}');
    out
}