use std::collections::{HashMap, HashSet};
use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
use crate::visitors::each_terminal_successor;
use crate::{BlockId, HirFunction, Terminal};
/// Immediate post-dominator table for the blocks of a function.
///
/// Values are produced by `compute_post_dominator_tree`; use `get` to query
/// the immediate post-dominator of a block.
pub struct PostDominator {
/// Id of the synthetic exit node that the reversed graph is rooted at.
pub exit: BlockId,
/// Maps each block to its immediate post-dominator. A block mapped to
/// itself has no post-dominator (`get` reports it as `None`).
nodes: HashMap<BlockId, BlockId>,
}
impl PostDominator {
    /// Returns the immediate post-dominator of `id`.
    ///
    /// Yields `None` when the block is recorded as its own post-dominator,
    /// which is how entries without a real post-dominator are stored.
    ///
    /// # Panics
    ///
    /// Panics if `id` has no entry in the tree.
    pub fn get(&self, id: BlockId) -> Option<BlockId> {
        let recorded = *self
            .nodes
            .get(&id)
            .expect("Unknown node in post-dominator tree");
        // A self-mapping is the sentinel for "no post-dominator".
        Some(recorded).filter(|dominator| *dominator != id)
    }
}
/// A vertex of the reversed control-flow graph used for the post-dominator
/// computation.
struct Node {
/// Block this node stands for (or the synthetic exit id).
id: BlockId,
/// Position of the node in the reverse-postorder list; starts at 0 and is
/// assigned once the traversal order is known.
index: usize,
/// Predecessors in the reversed graph: populated from the block's terminal
/// successors, plus the exit node for blocks that flow into it.
preds: HashSet<BlockId>,
/// Successors in the reversed graph: populated from the block's CFG
/// predecessors (for the exit node, the blocks that flow into it).
succs: HashSet<BlockId>,
}
/// Reversed control-flow graph with its nodes laid out in reverse postorder.
struct Graph {
/// Root of the reversed graph, i.e. the synthetic exit node.
entry: BlockId,
/// Nodes in reverse postorder of a depth-first walk from `entry`; each
/// node's `index` equals its position here.
nodes: Vec<Node>,
/// Lookup from block id to the node's position in `nodes`.
node_index: HashMap<BlockId, usize>,
}
impl Graph {
    /// Returns the node registered for `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not part of the graph.
    fn get_node(&self, id: BlockId) -> &Node {
        &self.nodes[self.node_index[&id]]
    }
}
/// Builds the post-dominator tree of `func`.
///
/// * `next_block_id_counter` — used as the id of the synthetic exit node
///   (presumably the next unused block id; it should not collide with a real
///   block).
/// * `include_throws_as_exit_node` — when `true`, blocks ending in a throw are
///   connected to the exit node just like returning blocks.
///
/// # Errors
///
/// Propagates the diagnostic raised by the immediate-dominator computation.
pub fn compute_post_dominator_tree(
    func: &HirFunction,
    next_block_id_counter: u32,
    include_throws_as_exit_node: bool,
) -> Result<PostDominator, CompilerDiagnostic> {
    let reversed = build_reverse_graph(func, next_block_id_counter, include_throws_as_exit_node);
    let mut immediate = compute_immediate_dominators(&reversed)?;

    // Without throws wired to the exit, some blocks are never reached from it
    // and are therefore absent from the map. Record each of them as its own
    // post-dominator so later lookups still succeed.
    if !include_throws_as_exit_node {
        for (block_id, _) in &func.body.blocks {
            immediate.entry(*block_id).or_insert(*block_id);
        }
    }

    Ok(PostDominator {
        exit: reversed.entry,
        nodes: immediate,
    })
}
/// Builds the reversed control-flow graph of `func`, rooted at a synthetic
/// exit node whose id is `BlockId(next_block_id_counter)`.
///
/// Every edge is flipped: a node's `preds` come from the block's terminal
/// successors and its `succs` from the block's predecessors. Returning blocks
/// (and throwing blocks, if `include_throws_as_exit_node` is set) are linked
/// to the exit node. Only nodes reached by a depth-first walk from the exit
/// end up in the resulting graph, stored in reverse postorder.
fn build_reverse_graph(
    func: &HirFunction,
    next_block_id_counter: u32,
    include_throws_as_exit_node: bool,
) -> Graph {
    let exit_id = BlockId(next_block_id_counter);
    let mut pending: HashMap<BlockId, Node> = HashMap::new();
    pending.insert(
        exit_id,
        Node {
            id: exit_id,
            index: 0,
            preds: HashSet::new(),
            succs: HashSet::new(),
        },
    );

    for (block_id, block) in &func.body.blocks {
        let mut reversed_preds: HashSet<BlockId> = each_terminal_successor(&block.terminal)
            .into_iter()
            .collect();
        let reversed_succs: HashSet<BlockId> = block.preds.iter().copied().collect();

        let flows_to_exit = match &block.terminal {
            Terminal::Return { .. } => true,
            Terminal::Throw { .. } => include_throws_as_exit_node,
            _ => false,
        };
        if flows_to_exit {
            reversed_preds.insert(exit_id);
            pending
                .get_mut(&exit_id)
                .unwrap()
                .succs
                .insert(*block_id);
        }

        pending.insert(
            *block_id,
            Node {
                id: *block_id,
                index: 0,
                preds: reversed_preds,
                succs: reversed_succs,
            },
        );
    }

    // Depth-first walk from the exit; reversing its postorder yields the
    // reverse-postorder numbering the dominator algorithm relies on.
    let mut seen = HashSet::new();
    let mut postorder = Vec::new();
    dfs_postorder(exit_id, &pending, &mut seen, &mut postorder);

    let mut nodes = Vec::with_capacity(postorder.len());
    let mut node_index = HashMap::new();
    for (rpo_index, block_id) in postorder.into_iter().rev().enumerate() {
        let mut node = pending.remove(&block_id).unwrap();
        node.index = rpo_index;
        node_index.insert(block_id, rpo_index);
        nodes.push(node);
    }

    Graph {
        entry: exit_id,
        nodes,
        node_index,
    }
}
/// Appends to `postorder` the depth-first postorder of everything reachable
/// from `id` through `succs` edges, skipping ids already in `visited`.
///
/// An id without an entry in `nodes` is still recorded, as a leaf.
fn dfs_postorder(
    id: BlockId,
    nodes: &HashMap<BlockId, Node>,
    visited: &mut HashSet<BlockId>,
    postorder: &mut Vec<BlockId>,
) {
    if !visited.insert(id) {
        return;
    }

    // Each frame pairs a block with the iterator over its not-yet-explored
    // successors, mirroring what a recursive walk keeps on the call stack.
    let mut stack = vec![(id, nodes.get(&id).map(|node| node.succs.iter()))];
    while let Some(frame) = stack.last_mut() {
        let current = frame.0;
        let next = frame.1.as_mut().and_then(|remaining| remaining.next());
        match next {
            Some(&succ) => {
                if visited.insert(succ) {
                    stack.push((succ, nodes.get(&succ).map(|node| node.succs.iter())));
                }
            }
            None => {
                // All successors handled: the block is finished.
                postorder.push(current);
                stack.pop();
            }
        }
    }
}
/// Computes the immediate dominator of every node in `graph` by iterating to
/// a fixed point over the nodes in their stored order.
///
/// The entry node is recorded as its own dominator.
///
/// # Errors
///
/// Returns an invariant diagnostic if a non-entry node is reached before any
/// of its predecessors has been assigned a dominator.
fn compute_immediate_dominators(
    graph: &Graph,
) -> Result<HashMap<BlockId, BlockId>, CompilerDiagnostic> {
    let mut doms: HashMap<BlockId, BlockId> = HashMap::new();
    doms.insert(graph.entry, graph.entry);

    loop {
        let mut changed = false;

        for node in graph.nodes.iter().filter(|node| node.id != graph.entry) {
            // Seed with the first predecessor that already has a dominator.
            let seed = node
                .preds
                .iter()
                .copied()
                .find(|pred| doms.contains_key(pred))
                .ok_or_else(|| {
                    CompilerDiagnostic::new(
                        ErrorCategory::Invariant,
                        format!(
                            "At least one predecessor must have been visited for block {:?}",
                            node.id
                        ),
                        None,
                    )
                })?;

            // Fold every other processed predecessor into the candidate.
            let idom = node.preds.iter().copied().fold(seed, |candidate, pred| {
                if pred != candidate && doms.contains_key(&pred) {
                    intersect(pred, candidate, graph, &doms)
                } else {
                    candidate
                }
            });

            if doms.get(&node.id) != Some(&idom) {
                doms.insert(node.id, idom);
                changed = true;
            }
        }

        if !changed {
            break;
        }
    }

    Ok(doms)
}
/// Finds the closest common ancestor of `a` and `b` in the (partially built)
/// dominator map `doms`.
///
/// Whichever side currently has the larger node index is moved to its
/// recorded dominator until both sides land on the same node.
///
/// # Panics
///
/// Panics if a node encountered on the way is missing from `graph` or `doms`.
fn intersect(
    a: BlockId,
    b: BlockId,
    graph: &Graph,
    doms: &HashMap<BlockId, BlockId>,
) -> BlockId {
    let mut left = graph.get_node(a);
    let mut right = graph.get_node(b);
    while left.id != right.id {
        if left.index > right.index {
            left = graph.get_node(doms[&left.id]);
        } else if right.index > left.index {
            right = graph.get_node(doms[&right.id]);
        }
    }
    left.id
}
/// Computes the post-dominator frontier of `target_id`.
///
/// The frontier consists of the predecessors of `target_id`, and of every
/// block that `target_id` post-dominates, which are not themselves
/// post-dominated by `target_id`. Block ids missing from `func` are skipped.
pub fn post_dominator_frontier(
    func: &HirFunction,
    post_dominators: &PostDominator,
    target_id: BlockId,
) -> HashSet<BlockId> {
    let dominated = post_dominators_of(func, post_dominators, target_id);
    let mut seen = HashSet::new();
    let mut frontier = HashSet::new();

    // Inspect every dominated block first, then the target itself.
    let candidates = dominated.iter().copied().chain(std::iter::once(target_id));
    for block_id in candidates {
        if !seen.insert(block_id) {
            continue;
        }
        if let Some(block) = func.body.blocks.get(&block_id) {
            frontier.extend(
                block
                    .preds
                    .iter()
                    .copied()
                    .filter(|pred| !dominated.contains(pred)),
            );
        }
    }

    frontier
}
pub fn post_dominators_of(
func: &HirFunction,
post_dominators: &PostDominator,
target_id: BlockId,
) -> HashSet<BlockId> {
let mut result = HashSet::new();
let mut visited = HashSet::new();
let mut queue = vec![target_id];
while let Some(current_id) = queue.pop() {
if !visited.insert(current_id) {
continue;
}
if let Some(block) = func.body.blocks.get(¤t_id) {
for &pred in &block.preds {
let pred_post_dom = post_dominators.get(pred).unwrap_or(pred);
if pred_post_dom == target_id || result.contains(&pred_post_dom) {
result.insert(pred);
}
queue.push(pred);
}
}
}
result
}
/// Collects the blocks lying on the post-dominator chain of the function
/// entry.
///
/// Starting at `func.body.entry`, the walk repeatedly moves to the current
/// block's immediate post-dominator, recording each block, until it reaches
/// the synthetic exit node or a block without a post-dominator. Throwing
/// blocks are not treated as exits for this computation.
///
/// # Errors
///
/// Propagates any diagnostic from building the post-dominator tree.
///
/// # Panics
///
/// Panics if the chain revisits a block, which would otherwise loop forever.
pub fn compute_unconditional_blocks(
    func: &HirFunction,
    next_block_id_counter: u32,
) -> Result<HashSet<BlockId>, CompilerDiagnostic> {
    let mut unconditional = HashSet::new();
    let post_dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?;

    let mut cursor = func.body.entry;
    while cursor != post_dominators.exit {
        let newly_added = unconditional.insert(cursor);
        assert!(
            newly_added,
            "Internal error: non-terminating loop in ComputeUnconditionalBlocks"
        );
        match post_dominators.get(cursor) {
            Some(next) => cursor = next,
            None => break,
        }
    }

    Ok(unconditional)
}