use crate::{
arena::Handle,
valid::{FunctionInfo, ModuleInfo},
};
use std::{
borrow::Cow,
fmt::{Error as FmtError, Write as _},
};
/// Configuration for the Graphviz (DOT) module writer.
#[derive(Clone, Copy, Debug, Default)]
pub struct Options {
    /// When `true`, only the statement control-flow graph is written: the
    /// globals cluster, local variables, expression nodes, and the
    /// expression dependency/emit edges are all omitted. Call edges between
    /// functions are still written.
    pub cfg_only: bool,
}
/// Index of a node in the statement graph's `nodes` list (and hence the `N`
/// in the emitted `{prefix}_sN` Graphviz node name).
type NodeId = usize;
/// Destination nodes for `Continue` and `Break` statements inside the block
/// currently being added to the statement graph.
///
/// `None` means no enclosing construct provides a target; such jumps are drawn
/// back to the block's own entry node and labelled "Broken".
#[derive(Default, Clone, Copy)]
struct Targets {
    // Node a `Continue` jumps to (the enclosing loop's continuing block).
    continue_target: Option<usize>,
    // Node a `Break` jumps to (a switch's merge node or the enclosing loop node).
    break_target: Option<usize>,
}
/// Statement-level graph of a function body, accumulated by `add` and then
/// rendered to DOT.
#[derive(Default)]
struct StatementGraph {
    // Label of every node, indexed by `NodeId`.
    nodes: Vec<&'static str>,
    // Control-flow edges: (from, to, edge label).
    flow: Vec<(NodeId, NodeId, &'static str)>,
    // Break/Continue edges: (from, to, edge label, index into `COLORS`).
    jumps: Vec<(NodeId, NodeId, &'static str, usize)>,
    // Expressions feeding a statement: (statement node, expression, role label).
    dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>,
    // (statement node, expression) pairs drawn as dotted statement -> expression edges.
    emits: Vec<(NodeId, Handle<crate::Expression>)>,
    // Call sites: (statement node, called function).
    calls: Vec<(NodeId, Handle<crate::Function>)>,
}
impl StatementGraph {
fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) {
use crate::Statement as S;
let root = self.nodes.len();
self.nodes.push(if root == 0 { "Root" } else { "Node" });
let mut last_node = root;
for statement in block {
let id = self.nodes.len();
self.flow.push((last_node, id, ""));
self.nodes.push("");
let mut merge_id = id;
self.nodes[id] = match *statement {
S::Emit(ref range) => {
for handle in range.clone() {
self.emits.push((id, handle));
}
"Emit"
}
S::Kill => "Kill", S::Break => {
if let Some(target) = targets.break_target {
self.jumps.push((id, target, "Break", 5))
} else {
self.jumps.push((id, root, "Broken", 7))
}
"Break"
}
S::Continue => {
if let Some(target) = targets.continue_target {
self.jumps.push((id, target, "Continue", 5))
} else {
self.jumps.push((id, root, "Broken", 7))
}
"Continue"
}
S::Barrier(_flags) => "Barrier",
S::Block(ref b) => {
let (other, last) = self.add(b, targets);
self.flow.push((id, other, ""));
merge_id = last;
"Block"
}
S::If {
condition,
ref accept,
ref reject,
} => {
self.dependencies.push((id, condition, "condition"));
let (accept_id, accept_last) = self.add(accept, targets);
self.flow.push((id, accept_id, "accept"));
let (reject_id, reject_last) = self.add(reject, targets);
self.flow.push((id, reject_id, "reject"));
merge_id = self.nodes.len();
self.nodes.push("Merge");
self.flow.push((accept_last, merge_id, ""));
self.flow.push((reject_last, merge_id, ""));
"If"
}
S::Switch {
selector,
ref cases,
} => {
self.dependencies.push((id, selector, "selector"));
merge_id = self.nodes.len();
self.nodes.push("Merge");
let mut targets = targets;
targets.break_target = Some(merge_id);
for case in cases {
let (case_id, case_last) = self.add(&case.body, targets);
let label = match case.value {
crate::SwitchValue::Integer(_) => "case",
crate::SwitchValue::Default => "default",
};
self.flow.push((id, case_id, label));
self.flow.push((case_last, merge_id, ""));
}
"Switch"
}
S::Loop {
ref body,
ref continuing,
break_if,
} => {
let mut targets = targets;
targets.break_target = Some(id);
let (continuing_id, continuing_last) = self.add(continuing, targets);
targets.continue_target = Some(continuing_id);
let (body_id, body_last) = self.add(body, targets);
self.flow.push((id, body_id, "body"));
self.flow.push((body_last, continuing_id, "continuing"));
self.flow.push((continuing_last, body_id, "continuing"));
if let Some(expr) = break_if {
self.dependencies.push((continuing_id, expr, "break if"));
}
"Loop"
}
S::Return { value } => {
if let Some(expr) = value {
self.dependencies.push((id, expr, "value"));
}
"Return"
}
S::Store { pointer, value } => {
self.dependencies.push((id, value, "value"));
self.emits.push((id, pointer));
"Store"
}
S::ImageStore {
image,
coordinate,
array_index,
value,
} => {
self.dependencies.push((id, image, "image"));
self.dependencies.push((id, coordinate, "coordinate"));
if let Some(expr) = array_index {
self.dependencies.push((id, expr, "array_index"));
}
self.dependencies.push((id, value, "value"));
"ImageStore"
}
S::Call {
function,
ref arguments,
result,
} => {
for &arg in arguments {
self.dependencies.push((id, arg, "arg"));
}
if let Some(expr) = result {
self.emits.push((id, expr));
}
self.calls.push((id, function));
"Call"
}
S::Atomic {
pointer,
ref fun,
value,
result,
} => {
self.emits.push((id, result));
self.dependencies.push((id, pointer, "pointer"));
self.dependencies.push((id, value, "value"));
if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
self.dependencies.push((id, cmp, "cmp"));
}
"Atomic"
}
};
last_node = merge_id;
}
(root, last_node)
}
}
/// Returns the debug name of an IR entity, ready to embed inside a
/// double-quoted Graphviz string, or `""` when the entity is unnamed.
///
/// Names may come from frontends that allow arbitrary strings, so `"` and `\`
/// are backslash-escaped. Without this, such characters would terminate the
/// quoted DOT string early or start a label escape sequence. Names that need
/// no escaping are borrowed without allocating.
fn name(option: &Option<String>) -> Cow<'_, str> {
    let raw = option.as_deref().unwrap_or("");
    if !raw.contains(|c: char| c == '"' || c == '\\') {
        return Cow::Borrowed(raw);
    }
    let mut escaped = String::with_capacity(raw.len() + 2);
    for c in raw.chars() {
        if matches!(c, '"' | '\\') {
            escaped.push('\\');
        }
        escaped.push(c);
    }
    Cow::Owned(escaped)
}
/// Graphviz color palette, indexed by the numeric `color_id`s chosen for
/// expression nodes and for Break/Continue jump edges.
const COLORS: &[&str] = &[
    "white", "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5",
    "#d9d9d9",
];
fn write_fun(
output: &mut String,
prefix: String,
fun: &crate::Function,
info: Option<&FunctionInfo>,
options: &Options,
) -> Result<(), FmtError> {
writeln!(output, "\t\tnode [ style=filled ]")?;
if !options.cfg_only {
for (handle, var) in fun.local_variables.iter() {
writeln!(
output,
"\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]",
prefix,
handle.index(),
handle,
name(&var.name),
)?;
}
write_function_expressions(output, &prefix, fun, info)?;
}
let mut sg = StatementGraph::default();
sg.add(&fun.body, Targets::default());
for (index, label) in sg.nodes.into_iter().enumerate() {
writeln!(
output,
"\t\t{}_s{} [ shape=square label=\"{}\" ]",
prefix, index, label,
)?;
}
for (from, to, label) in sg.flow {
writeln!(
output,
"\t\t{}_s{} -> {}_s{} [ arrowhead=tee label=\"{}\" ]",
prefix, from, prefix, to, label,
)?;
}
for (from, to, label, color_id) in sg.jumps {
writeln!(
output,
"\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]",
prefix, from, prefix, to, COLORS[color_id], label,
)?;
}
if !options.cfg_only {
for (to, expr, label) in sg.dependencies {
writeln!(
output,
"\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]",
prefix,
expr.index(),
prefix,
to,
label,
)?;
}
for (from, to) in sg.emits {
writeln!(
output,
"\t\t{}_s{} -> {}_e{} [ style=dotted ]",
prefix,
from,
prefix,
to.index(),
)?;
}
}
for (from, function) in sg.calls {
writeln!(
output,
"\t\t{}_s{} -> f{}_s0",
prefix,
from,
function.index(),
)?;
}
Ok(())
}
/// Writes one Graphviz node per expression of `fun`, followed by the edges
/// from each expression's operands into it.
///
/// Expression nodes are named `{prefix}_e{index}` and labelled with the
/// expression handle plus a short description. Operand edges point from the
/// operand expression to the expression that uses it and are labelled with the
/// operand's role (e.g. `base`, `left`, `coordinate`). References to local or
/// global variables get an extra edge from the `{prefix}_l{index}` / `g{index}`
/// node, and `Compose` gets one edge from the group of its components.
///
/// When `info` is provided, expressions whose `non_uniform_result` is `None`
/// set their palette color through `fillcolor`; every other expression (and
/// all of them when `info` is `None`) sets it through `color`.
fn write_function_expressions(
    output: &mut String,
    prefix: &str,
    fun: &crate::Function,
    info: Option<&FunctionInfo>,
) -> Result<(), FmtError> {
    // Incoming edges that are not a single labelled operand edge.
    enum Payload<'a> {
        Arguments(&'a [Handle<crate::Expression>]),
        Local(Handle<crate::LocalVariable>),
        Global(Handle<crate::GlobalVariable>),
    }
    // Operand role -> operand expression for the expression currently being
    // written. It is emptied with `drain` after each expression, so one map is
    // reused for the whole function.
    // NOTE(review): operand edges are written in the map's iteration order,
    // not in the order they were inserted below.
    let mut edges = crate::FastHashMap::<&str, _>::default();
    // Set by at most one arm per expression; taken (reset to `None`) after use.
    let mut payload = None;
    for (handle, expression) in fun.expressions.iter() {
        use crate::Expression as E;
        // `label` is a `Cow<str>`: fixed names are borrowed, formatted ones are
        // owned. `color_id` indexes into `COLORS`.
        let (label, color_id) = match *expression {
            E::Access { base, index } => {
                edges.insert("base", base);
                edges.insert("index", index);
                ("Access".into(), 1)
            }
            E::AccessIndex { base, index } => {
                edges.insert("base", base);
                (format!("AccessIndex[{}]", index).into(), 1)
            }
            E::Constant(_) => ("Constant".into(), 2),
            E::Splat { size, value } => {
                edges.insert("value", value);
                (format!("Splat{:?}", size).into(), 3)
            }
            E::Swizzle {
                size,
                vector,
                pattern,
            } => {
                edges.insert("vector", vector);
                // Only the first `size` pattern entries are part of the swizzle.
                (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3)
            }
            E::Compose { ref components, .. } => {
                payload = Some(Payload::Arguments(components));
                ("Compose".into(), 3)
            }
            E::FunctionArgument(index) => (format!("Argument[{}]", index).into(), 1),
            E::GlobalVariable(h) => {
                payload = Some(Payload::Global(h));
                ("Global".into(), 2)
            }
            E::LocalVariable(h) => {
                payload = Some(Payload::Local(h));
                ("Local".into(), 1)
            }
            E::Load { pointer } => {
                edges.insert("pointer", pointer);
                ("Load".into(), 4)
            }
            E::ImageSample {
                image,
                sampler,
                gather,
                coordinate,
                array_index,
                // The offset is not drawn.
                offset: _,
                level,
                depth_ref,
            } => {
                edges.insert("image", image);
                edges.insert("sampler", sampler);
                edges.insert("coordinate", coordinate);
                if let Some(expr) = array_index {
                    edges.insert("array_index", expr);
                }
                match level {
                    crate::SampleLevel::Auto => {}
                    crate::SampleLevel::Zero => {}
                    crate::SampleLevel::Exact(expr) => {
                        edges.insert("level", expr);
                    }
                    crate::SampleLevel::Bias(expr) => {
                        edges.insert("bias", expr);
                    }
                    crate::SampleLevel::Gradient { x, y } => {
                        edges.insert("grad_x", x);
                        edges.insert("grad_y", y);
                    }
                }
                if let Some(expr) = depth_ref {
                    edges.insert("depth_ref", expr);
                }
                // Gathers are labelled with the gathered component.
                let string = match gather {
                    Some(component) => Cow::Owned(format!("ImageGather{:?}", component)),
                    _ => Cow::Borrowed("ImageSample"),
                };
                (string, 5)
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                edges.insert("image", image);
                edges.insert("coordinate", coordinate);
                if let Some(expr) = array_index {
                    edges.insert("array_index", expr);
                }
                if let Some(sample) = sample {
                    edges.insert("sample", sample);
                }
                if let Some(level) = level {
                    edges.insert("level", level);
                }
                ("ImageLoad".into(), 5)
            }
            E::ImageQuery { image, query } => {
                edges.insert("image", image);
                // Only `Size` carries an operand; other queries are labelled
                // with their `Debug` form.
                let args = match query {
                    crate::ImageQuery::Size { level } => {
                        if let Some(expr) = level {
                            edges.insert("level", expr);
                        }
                        Cow::from("ImageSize")
                    }
                    _ => Cow::Owned(format!("{:?}", query)),
                };
                (args, 7)
            }
            E::Unary { op, expr } => {
                edges.insert("expr", expr);
                (format!("{:?}", op).into(), 6)
            }
            E::Binary { op, left, right } => {
                edges.insert("left", left);
                edges.insert("right", right);
                (format!("{:?}", op).into(), 6)
            }
            E::Select {
                condition,
                accept,
                reject,
            } => {
                edges.insert("condition", condition);
                edges.insert("accept", accept);
                edges.insert("reject", reject);
                ("Select".into(), 3)
            }
            E::Derivative { axis, expr } => {
                // Single operand, drawn as an unlabelled edge.
                edges.insert("", expr);
                (format!("d{:?}", axis).into(), 8)
            }
            E::Relational { fun, argument } => {
                edges.insert("arg", argument);
                (format!("{:?}", fun).into(), 6)
            }
            E::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                edges.insert("arg", arg);
                if let Some(expr) = arg1 {
                    edges.insert("arg1", expr);
                }
                if let Some(expr) = arg2 {
                    edges.insert("arg2", expr);
                }
                if let Some(expr) = arg3 {
                    edges.insert("arg3", expr);
                }
                (format!("{:?}", fun).into(), 7)
            }
            E::As {
                kind,
                expr,
                convert,
            } => {
                edges.insert("", expr);
                // A target width means a value conversion; no width means a bitcast.
                let string = match convert {
                    Some(width) => format!("Convert<{:?},{}>", kind, width),
                    None => format!("Bitcast<{:?}>", kind),
                };
                (string.into(), 3)
            }
            E::CallResult(_function) => ("CallResult".into(), 4),
            E::AtomicResult { .. } => ("AtomicResult".into(), 4),
            E::ArrayLength(expr) => {
                edges.insert("", expr);
                ("ArrayLength".into(), 7)
            }
        };
        let color_attr = match info {
            Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor",
            _ => "color",
        };
        writeln!(
            output,
            "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]",
            prefix,
            handle.index(),
            color_attr,
            COLORS[color_id],
            handle,
            label,
        )?;
        // One labelled edge per operand; `drain` also empties the map for the
        // next expression.
        for (key, edge) in edges.drain() {
            writeln!(
                output,
                "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]",
                prefix,
                edge.index(),
                prefix,
                handle.index(),
                key,
            )?;
        }
        match payload.take() {
            // All components as one `{ ... } -> node` edge group.
            Some(Payload::Arguments(list)) => {
                write!(output, "\t\t{{")?;
                for &comp in list {
                    write!(output, " {}_e{}", prefix, comp.index())?;
                }
                writeln!(output, " }} -> {}_e{}", prefix, handle.index())?;
            }
            Some(Payload::Local(h)) => {
                writeln!(
                    output,
                    "\t\t{}_l{} -> {}_e{}",
                    prefix,
                    h.index(),
                    prefix,
                    handle.index(),
                )?;
            }
            // Global variable nodes are unprefixed (`g{index}`).
            Some(Payload::Global(h)) => {
                writeln!(
                    output,
                    "\t\tg{} -> {}_e{} [fillcolor=gray]",
                    h.index(),
                    prefix,
                    handle.index(),
                )?;
            }
            None => {}
        }
    }
    Ok(())
}
pub fn write(
module: &crate::Module,
mod_info: Option<&ModuleInfo>,
options: Options,
) -> Result<String, FmtError> {
use std::fmt::Write as _;
let mut output = String::new();
output += "digraph Module {\n";
if !options.cfg_only {
writeln!(output, "\tsubgraph cluster_globals {{")?;
writeln!(output, "\t\tlabel=\"Globals\"")?;
for (handle, var) in module.global_variables.iter() {
writeln!(
output,
"\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]",
handle.index(),
handle,
var.space,
name(&var.name),
)?;
}
writeln!(output, "\t}}")?;
}
for (handle, fun) in module.functions.iter() {
let prefix = format!("f{}", handle.index());
writeln!(output, "\tsubgraph cluster_{} {{", prefix)?;
writeln!(
output,
"\t\tlabel=\"Function{:?}/'{}'\"",
handle,
name(&fun.name)
)?;
let info = mod_info.map(|a| &a[handle]);
write_fun(&mut output, prefix, fun, info, &options)?;
writeln!(output, "\t}}")?;
}
for (ep_index, ep) in module.entry_points.iter().enumerate() {
let prefix = format!("ep{}", ep_index);
writeln!(output, "\tsubgraph cluster_{} {{", prefix)?;
writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?;
let info = mod_info.map(|a| a.get_entry_point(ep_index));
write_fun(&mut output, prefix, &ep.function, info, &options)?;
writeln!(output, "\t}}")?;
}
output += "}\n";
Ok(output)
}