use crate::{
binder::{SymbolKey, SymbolTable},
document::Document,
idx::InternIdent,
};
use petgraph::graph::{Graph, NodeIndex};
use wat_syntax::{
GreenToken, SyntaxKind, SyntaxNode, SyntaxNodePtr,
ast::{AstNode, BlockInstr, Cat, Instr, support},
};
#[salsa::tracked(returns(ref))]
pub fn analyze(db: &dyn salsa::Database, document: Document, ptr: SyntaxNodePtr) -> ControlFlowGraph {
let root = document.root_tree(db);
Builder::new(db, document).build(ptr.to_node(&root).expect("invalid ptr in control flow analysis"))
}
struct Builder<'db> {
db: &'db dyn salsa::Database,
symbol_table: &'db SymbolTable<'db>,
graph: Graph<FlowNode, FlowEdge>,
block_stack: Vec<(NodeIndex, Option<InternIdent<'db>>)>,
current: Option<NodeIndex>,
bb_instrs: Vec<BasicBlockInstr>,
unreachable: bool,
}
impl<'db> Builder<'db> {
fn new(db: &'db dyn salsa::Database, document: Document) -> Self {
Self {
db,
symbol_table: SymbolTable::of(db, document),
graph: Graph::new(),
block_stack: Vec::new(),
current: None,
bb_instrs: Vec::with_capacity(2),
unreachable: false,
}
}
fn build(mut self, node: SyntaxNode) -> ControlFlowGraph {
let entry = self.graph.add_node(FlowNode {
kind: FlowNodeKind::Entry,
unreachable: false,
});
self.current = Some(entry);
let exit = self.graph.add_node(FlowNode {
kind: FlowNodeKind::Exit,
unreachable: false,
});
self.block_stack.push((exit, None));
self.visit_block_like(&node, exit);
self.finish_exit(exit);
ControlFlowGraph {
graph: self.graph,
entry,
}
}
fn visit_instrs(&mut self, node: &SyntaxNode) {
support::children::<Instr>(node).for_each(|instr| match instr {
Instr::Plain(plain) => {
self.visit_instrs(plain.syntax());
let Some(instr_name) = plain.instr_name() else {
return;
};
let instr_name = instr_name.green();
self.bb_instrs.push(BasicBlockInstr {
ptr: SyntaxNodePtr::new(plain.syntax()),
name: instr_name.to_owned(),
immediates: plain
.immediates()
.map(|immediate| SyntaxNodePtr::new(immediate.syntax()))
.collect(),
});
let unreachable = match instr_name.text() {
"unreachable" | "throw" | "throw_ref" => {
self.add_basic_block();
true
}
"return" | "return_call" | "return_call_indirect" | "return_call_ref" => {
if let Some((bb, (exit, _))) = self.add_basic_block().zip(self.block_stack.first()) {
self.graph.add_edge(bb, *exit, FlowEdge { loop_back: false });
}
true
}
"br" | "br_table" => {
if let Some(bb) = self.add_basic_block() {
for immediate in plain.immediates() {
if let Some(target) = self.find_jump_target(immediate) {
self.graph.add_edge(
bb,
target,
FlowEdge {
loop_back: self.is_loop_entry(target),
},
);
}
}
}
true
}
"br_if" | "br_on_null" | "br_on_non_null" | "br_on_cast" | "br_on_cast_fail" => {
let target = plain
.immediates()
.next()
.and_then(|immediate| self.find_jump_target(immediate));
if let Some((bb, target)) = self.add_basic_block().zip(target) {
self.graph.add_edge(
bb,
target,
FlowEdge {
loop_back: self.is_loop_entry(target),
},
);
}
false
}
"resume" | "resume_throw" | "resume_throw_ref" => {
if let Some(bb) = self.add_basic_block() {
for index in plain
.immediates()
.filter_map(|immediate| immediate.on_clause())
.filter_map(|on_clause| on_clause.label_index())
{
if let Some(target) = self.find_jump_target(index) {
self.graph.add_edge(
bb,
target,
FlowEdge {
loop_back: self.is_loop_entry(target),
},
);
}
}
}
false
}
_ => false,
};
if unreachable {
self.current = None;
}
self.unreachable |= unreachable;
}
Instr::Block(block) => {
self.add_basic_block();
let block_entry = self.graph.add_node(FlowNode {
kind: FlowNodeKind::BlockEntry(SyntaxNodePtr::new(block.syntax())),
unreachable: self.unreachable,
});
self.connect_current_to(block_entry);
self.current = Some(block_entry);
let block_exit = self.graph.add_node(FlowNode {
kind: FlowNodeKind::BlockExit,
unreachable: false,
});
let ident = support::token(block.syntax(), SyntaxKind::IDENT)
.map(|token| InternIdent::new(self.db, token.text()));
match block {
BlockInstr::Block(block_block) => {
self.block_stack.push((block_exit, ident));
self.visit_block_like(block_block.syntax(), block_exit);
}
BlockInstr::Loop(block_loop) => {
self.block_stack.push((block_entry, ident));
self.visit_block_like(block_loop.syntax(), block_exit);
}
BlockInstr::If(block_if) => {
self.visit_instrs(block_if.syntax());
self.add_basic_block();
self.block_stack.push((block_exit, ident));
let condition = self.current;
let saved_unreachable = self.unreachable;
if let Some(then_block) = block_if.then_block() {
self.visit_block_like(then_block.syntax(), block_exit);
self.current = condition;
self.unreachable = saved_unreachable;
}
if let Some(else_block) = block_if.else_block() {
self.visit_block_like(else_block.syntax(), block_exit);
} else if let Some(condition) = condition {
self.graph
.add_edge(condition, block_exit, FlowEdge { loop_back: false });
}
}
BlockInstr::TryTable(block_try_table) => {
for index in block_try_table.catches().filter_map(|cat| match cat {
Cat::Catch(catch) => catch.label_index(),
Cat::CatchAll(catch_all) => catch_all.label_index(),
}) {
if let Some(target) = self.find_jump_target(index) {
self.graph.add_edge(block_entry, target, FlowEdge { loop_back: false });
}
}
self.block_stack.push((block_exit, ident));
self.visit_block_like(block_try_table.syntax(), block_exit);
}
}
self.finish_exit(block_exit);
self.current = Some(block_exit);
self.block_stack.pop();
}
});
}
fn visit_block_like(&mut self, node: &SyntaxNode, exit: NodeIndex) {
self.visit_instrs(node);
self.add_basic_block();
self.connect_current_to(exit);
}
fn add_basic_block(&mut self) -> Option<NodeIndex> {
if self.bb_instrs.is_empty() {
None
} else {
let bb = self.graph.add_node(FlowNode {
kind: FlowNodeKind::BasicBlock(BasicBlock(self.bb_instrs.drain(..).collect())),
unreachable: self.unreachable,
});
self.connect_current_to(bb);
self.current = Some(bb);
Some(bb)
}
}
fn connect_current_to(&mut self, node_index: NodeIndex) {
if let Some(current) = self.current {
self.graph.add_edge(current, node_index, FlowEdge { loop_back: false });
}
}
fn find_jump_target<N: AstNode>(&self, node: N) -> Option<NodeIndex> {
self.symbol_table
.symbols
.get(&SymbolKey::new(node.syntax()))
.and_then(|symbol| {
if let Some(num) = symbol.idx.num {
let num = num as usize;
let total = self.block_stack.len();
if num < total {
self.block_stack.get(total - 1 - num)
} else {
None
}
} else {
self.block_stack.iter().rev().find(|(_, it)| *it == symbol.idx.name)
}
})
.map(|(target, _)| *target)
}
fn finish_exit(&mut self, exit: NodeIndex) {
self.unreachable = self
.graph
.neighbors_directed(exit, petgraph::Direction::Incoming)
.all(|node_index| {
matches!(
self.graph.node_weight(node_index),
Some(FlowNode { unreachable: true, .. }) | None
)
});
if let Some(flow_node) = self.graph.node_weight_mut(exit) {
flow_node.unreachable = self.unreachable;
}
}
fn is_loop_entry(&self, node_index: NodeIndex) -> bool {
if let Some(FlowNode {
kind: FlowNodeKind::BlockEntry(ptr),
..
}) = self.graph.node_weight(node_index)
{
ptr.kind() == SyntaxKind::BLOCK_LOOP
} else {
false
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FlowNode {
pub kind: FlowNodeKind,
pub unreachable: bool,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FlowNodeKind {
Entry,
Exit,
BasicBlock(BasicBlock),
BlockEntry(SyntaxNodePtr),
BlockExit,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BasicBlock(pub Vec<BasicBlockInstr>);
impl BasicBlock {
pub fn contains_instr(&self, node: &SyntaxNode) -> bool {
let end = node.text_range().end();
self.0
.first()
.zip(self.0.last())
.is_some_and(|(first, last)| first.ptr.text_range().start() < end && end <= last.ptr.text_range().end())
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BasicBlockInstr {
pub ptr: SyntaxNodePtr,
pub name: GreenToken,
pub immediates: Vec<SyntaxNodePtr>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FlowEdge {
pub loop_back: bool,
}
#[derive(Clone)]
pub struct ControlFlowGraph {
pub graph: Graph<FlowNode, FlowEdge>,
pub entry: NodeIndex,
}
impl PartialEq for ControlFlowGraph {
fn eq(&self, other: &Self) -> bool {
self.entry == other.entry
&& self.graph.raw_nodes().iter().map(|node| &node.weight).eq(other
.graph
.raw_nodes()
.iter()
.map(|node| &node.weight))
&& self
.graph
.raw_edges()
.iter()
.map(|edge| (edge.source(), edge.target(), &edge.weight))
.eq(other
.graph
.raw_edges()
.iter()
.map(|edge| (edge.source(), edge.target(), &edge.weight)))
}
}