Skip to main content

react_compiler_hir/
dominator.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
//! Dominator and post-dominator tree computation.
//!
//! Port of Dominator.ts and ComputeUnconditionalBlocks.ts. Dominators are
//! computed with the iterative algorithm of Cooper, Harvey and Kennedy,
//! "A Simple, Fast Dominance Algorithm":
//! <https://www.cs.rice.edu/~keith/Embed/dom.pdf>
11
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
15
16use crate::visitors::each_terminal_successor;
17use crate::{BlockId, HirFunction, Terminal};
18
19// =============================================================================
20// Public types
21// =============================================================================
22
23/// Stores the immediate post-dominator for each block.
24pub struct PostDominator {
25    /// The exit node (synthetic node representing function exit).
26    pub exit: BlockId,
27    nodes: FxHashMap<BlockId, BlockId>,
28}
29
30impl PostDominator {
31    /// Returns the immediate post-dominator of the given block, or None if
32    /// the block post-dominates itself (i.e., it is the exit node).
33    pub fn get(&self, id: BlockId) -> Option<BlockId> {
34        let dominator = self
35            .nodes
36            .get(&id)
37            .expect("Unknown node in post-dominator tree");
38        if *dominator == id {
39            None
40        } else {
41            Some(*dominator)
42        }
43    }
44}
45
46// =============================================================================
47// Graph representation
48// =============================================================================
49
50struct Node {
51    id: BlockId,
52    index: usize,
53    preds: FxHashSet<BlockId>,
54    succs: FxHashSet<BlockId>,
55}
56
57struct Graph {
58    entry: BlockId,
59    /// Nodes stored in iteration order (RPO for reverse graph).
60    nodes: Vec<Node>,
61    /// Map from BlockId to index in the nodes vec.
62    node_index: FxHashMap<BlockId, usize>,
63}
64
65impl Graph {
66    fn get_node(&self, id: BlockId) -> &Node {
67        let idx = self.node_index[&id];
68        &self.nodes[idx]
69    }
70}
71
72// =============================================================================
73// Post-dominator tree computation
74// =============================================================================
75
76/// Compute the post-dominator tree for a function.
77///
78/// If `include_throws_as_exit_node` is true, throw terminals are treated as
79/// exit nodes (like return). Otherwise, only return terminals feed into exit.
80pub fn compute_post_dominator_tree(
81    func: &HirFunction,
82    next_block_id_counter: u32,
83    include_throws_as_exit_node: bool,
84) -> Result<PostDominator, CompilerDiagnostic> {
85    let graph = build_reverse_graph(func, next_block_id_counter, include_throws_as_exit_node);
86    let mut nodes = compute_immediate_dominators(&graph)?;
87
88    // When include_throws_as_exit_node is false, nodes that flow into a throw
89    // terminal and don't reach the exit won't be in the node map. Add them
90    // with themselves as dominator.
91    if !include_throws_as_exit_node {
92        for (id, _) in &func.body.blocks {
93            nodes.entry(*id).or_insert(*id);
94        }
95    }
96
97    Ok(PostDominator {
98        exit: graph.entry,
99        nodes,
100    })
101}
102
103/// Build the reverse graph from the HIR function.
104///
105/// Reverses all edges and adds a synthetic exit node that receives edges from
106/// return (and optionally throw) terminals. The result is put into RPO order.
107fn build_reverse_graph(
108    func: &HirFunction,
109    next_block_id_counter: u32,
110    include_throws_as_exit_node: bool,
111) -> Graph {
112    let exit_id = BlockId(next_block_id_counter);
113
114    // Build initial nodes with reversed edges
115    let mut raw_nodes: FxHashMap<BlockId, Node> = FxHashMap::default();
116
117    // Create exit node
118    raw_nodes.insert(
119        exit_id,
120        Node {
121            id: exit_id,
122            index: 0,
123            preds: FxHashSet::default(),
124            succs: FxHashSet::default(),
125        },
126    );
127
128    for (id, block) in &func.body.blocks {
129        let successors = each_terminal_successor(&block.terminal);
130        let mut preds_set: FxHashSet<BlockId> = successors.into_iter().collect();
131        let succs_set: FxHashSet<BlockId> = block.preds.iter().copied().collect();
132
133        let is_return = matches!(&block.terminal, Terminal::Return { .. });
134        let is_throw = matches!(&block.terminal, Terminal::Throw { .. });
135
136        if is_return || (is_throw && include_throws_as_exit_node) {
137            preds_set.insert(exit_id);
138            raw_nodes.get_mut(&exit_id).unwrap().succs.insert(*id);
139        }
140
141        raw_nodes.insert(
142            *id,
143            Node {
144                id: *id,
145                index: 0,
146                preds: preds_set,
147                succs: succs_set,
148            },
149        );
150    }
151
152    // DFS from exit to compute RPO
153    let mut visited = FxHashSet::default();
154    let mut postorder = Vec::new();
155    dfs_postorder(exit_id, &raw_nodes, &mut visited, &mut postorder);
156
157    // Reverse postorder
158    postorder.reverse();
159
160    let mut nodes = Vec::with_capacity(postorder.len());
161    let mut node_index = FxHashMap::default();
162    for (idx, id) in postorder.into_iter().enumerate() {
163        let mut node = raw_nodes.remove(&id).unwrap();
164        node.index = idx;
165        node_index.insert(id, idx);
166        nodes.push(node);
167    }
168
169    Graph {
170        entry: exit_id,
171        nodes,
172        node_index,
173    }
174}
175
176fn dfs_postorder(
177    id: BlockId,
178    nodes: &FxHashMap<BlockId, Node>,
179    visited: &mut FxHashSet<BlockId>,
180    postorder: &mut Vec<BlockId>,
181) {
182    if !visited.insert(id) {
183        return;
184    }
185    if let Some(node) = nodes.get(&id) {
186        for &succ in &node.succs {
187            dfs_postorder(succ, nodes, visited, postorder);
188        }
189    }
190    postorder.push(id);
191}
192
193// =============================================================================
194// Dominator fixpoint (Cooper/Harvey/Kennedy)
195// =============================================================================
196
197fn compute_immediate_dominators(
198    graph: &Graph,
199) -> Result<FxHashMap<BlockId, BlockId>, CompilerDiagnostic> {
200    let mut doms: FxHashMap<BlockId, BlockId> = FxHashMap::default();
201    doms.insert(graph.entry, graph.entry);
202
203    let mut changed = true;
204    while changed {
205        changed = false;
206        for node in &graph.nodes {
207            if node.id == graph.entry {
208                continue;
209            }
210
211            // Find first processed predecessor
212            let mut new_idom: Option<BlockId> = None;
213            for &pred in &node.preds {
214                if doms.contains_key(&pred) {
215                    new_idom = Some(pred);
216                    break;
217                }
218            }
219            let mut new_idom = match new_idom {
220                Some(idom) => idom,
221                None => {
222                    return Err(CompilerDiagnostic::new(
223                        ErrorCategory::Invariant,
224                        format!(
225                            "At least one predecessor must have been visited for block {:?}",
226                            node.id
227                        ),
228                        None,
229                    ));
230                }
231            };
232
233            // Intersect with other processed predecessors
234            for &pred in &node.preds {
235                if pred == new_idom {
236                    continue;
237                }
238                if doms.contains_key(&pred) {
239                    new_idom = intersect(pred, new_idom, graph, &doms);
240                }
241            }
242
243            if doms.get(&node.id) != Some(&new_idom) {
244                doms.insert(node.id, new_idom);
245                changed = true;
246            }
247        }
248    }
249    Ok(doms)
250}
251
252fn intersect(a: BlockId, b: BlockId, graph: &Graph, doms: &FxHashMap<BlockId, BlockId>) -> BlockId {
253    let mut block1 = graph.get_node(a);
254    let mut block2 = graph.get_node(b);
255    while block1.id != block2.id {
256        while block1.index > block2.index {
257            let dom = doms[&block1.id];
258            block1 = graph.get_node(dom);
259        }
260        while block2.index > block1.index {
261            let dom = doms[&block2.id];
262            block2 = graph.get_node(dom);
263        }
264    }
265    block1.id
266}
267
268// =============================================================================
269// Post-dominator frontier
270// =============================================================================
271
272/// Computes the post-dominator frontier of `target_id`. These are immediate
273/// predecessors of nodes that post-dominate `target_id` from which execution may
274/// not reach `target_id`. Intuitively, these are the earliest blocks from which
275/// execution branches such that it may or may not reach the target block.
276pub fn post_dominator_frontier(
277    func: &HirFunction,
278    post_dominators: &PostDominator,
279    target_id: BlockId,
280) -> FxHashSet<BlockId> {
281    let target_post_dominators = post_dominators_of(func, post_dominators, target_id);
282    let mut visited = FxHashSet::default();
283    let mut frontier = FxHashSet::default();
284
285    let mut to_visit: Vec<BlockId> = target_post_dominators.iter().copied().collect();
286    to_visit.push(target_id);
287
288    for block_id in to_visit {
289        if !visited.insert(block_id) {
290            continue;
291        }
292        if let Some(block) = func.body.blocks.get(&block_id) {
293            for &pred in &block.preds {
294                if !target_post_dominators.contains(&pred) {
295                    frontier.insert(pred);
296                }
297            }
298        }
299    }
300    frontier
301}
302
303/// Walks up the post-dominator tree to collect all blocks that post-dominate `target_id`.
304pub fn post_dominators_of(
305    func: &HirFunction,
306    post_dominators: &PostDominator,
307    target_id: BlockId,
308) -> FxHashSet<BlockId> {
309    let mut result = FxHashSet::default();
310    let mut visited = FxHashSet::default();
311    let mut queue = vec![target_id];
312
313    while let Some(current_id) = queue.pop() {
314        if !visited.insert(current_id) {
315            continue;
316        }
317        if let Some(block) = func.body.blocks.get(&current_id) {
318            for &pred in &block.preds {
319                let pred_post_dom = post_dominators.get(pred).unwrap_or(pred);
320                if pred_post_dom == target_id || result.contains(&pred_post_dom) {
321                    result.insert(pred);
322                }
323                queue.push(pred);
324            }
325        }
326    }
327    result
328}
329
330// =============================================================================
331// Unconditional blocks
332// =============================================================================
333
334/// Compute the set of blocks that are unconditionally executed from the entry.
335///
336/// Port of ComputeUnconditionalBlocks.ts. Walks the immediate post-dominator
337/// chain starting from the function entry. A block is unconditional if it lies
338/// on this chain (meaning every path through the function must pass through it).
339pub fn compute_unconditional_blocks(
340    func: &HirFunction,
341    next_block_id_counter: u32,
342) -> Result<FxHashSet<BlockId>, CompilerDiagnostic> {
343    let mut unconditional = FxHashSet::default();
344    let dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?;
345    let exit = dominators.exit;
346    let mut current: Option<BlockId> = Some(func.body.entry);
347
348    while let Some(block_id) = current {
349        if block_id == exit {
350            break;
351        }
352        assert!(
353            !unconditional.contains(&block_id),
354            "Internal error: non-terminating loop in ComputeUnconditionalBlocks"
355        );
356        unconditional.insert(block_id);
357        current = dominators.get(block_id);
358    }
359
360    Ok(unconditional)
361}