Skip to main content

react_compiler_hir/
dominator.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Dominator and post-dominator tree computation.
//!
//! Port of Dominator.ts and ComputeUnconditionalBlocks.ts.
//! Uses the Cooper/Harvey/Kennedy algorithm ("A Simple, Fast Dominance
//! Algorithm") from <https://www.cs.rice.edu/~keith/Embed/dom.pdf>
11
12use std::collections::{HashMap, HashSet};
13
14use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
15
16use crate::visitors::each_terminal_successor;
17use crate::{BlockId, HirFunction, Terminal};
18
19// =============================================================================
20// Public types
21// =============================================================================
22
23/// Stores the immediate post-dominator for each block.
24pub struct PostDominator {
25    /// The exit node (synthetic node representing function exit).
26    pub exit: BlockId,
27    nodes: HashMap<BlockId, BlockId>,
28}
29
30impl PostDominator {
31    /// Returns the immediate post-dominator of the given block, or None if
32    /// the block post-dominates itself (i.e., it is the exit node).
33    pub fn get(&self, id: BlockId) -> Option<BlockId> {
34        let dominator = self.nodes.get(&id).expect("Unknown node in post-dominator tree");
35        if *dominator == id {
36            None
37        } else {
38            Some(*dominator)
39        }
40    }
41}
42
43// =============================================================================
44// Graph representation
45// =============================================================================
46
47struct Node {
48    id: BlockId,
49    index: usize,
50    preds: HashSet<BlockId>,
51    succs: HashSet<BlockId>,
52}
53
54struct Graph {
55    entry: BlockId,
56    /// Nodes stored in iteration order (RPO for reverse graph).
57    nodes: Vec<Node>,
58    /// Map from BlockId to index in the nodes vec.
59    node_index: HashMap<BlockId, usize>,
60}
61
62impl Graph {
63    fn get_node(&self, id: BlockId) -> &Node {
64        let idx = self.node_index[&id];
65        &self.nodes[idx]
66    }
67}
68
69// =============================================================================
70// Post-dominator tree computation
71// =============================================================================
72
73/// Compute the post-dominator tree for a function.
74///
75/// If `include_throws_as_exit_node` is true, throw terminals are treated as
76/// exit nodes (like return). Otherwise, only return terminals feed into exit.
77pub fn compute_post_dominator_tree(
78    func: &HirFunction,
79    next_block_id_counter: u32,
80    include_throws_as_exit_node: bool,
81) -> Result<PostDominator, CompilerDiagnostic> {
82    let graph = build_reverse_graph(func, next_block_id_counter, include_throws_as_exit_node);
83    let mut nodes = compute_immediate_dominators(&graph)?;
84
85    // When include_throws_as_exit_node is false, nodes that flow into a throw
86    // terminal and don't reach the exit won't be in the node map. Add them
87    // with themselves as dominator.
88    if !include_throws_as_exit_node {
89        for (id, _) in &func.body.blocks {
90            nodes.entry(*id).or_insert(*id);
91        }
92    }
93
94    Ok(PostDominator {
95        exit: graph.entry,
96        nodes,
97    })
98}
99
100/// Build the reverse graph from the HIR function.
101///
102/// Reverses all edges and adds a synthetic exit node that receives edges from
103/// return (and optionally throw) terminals. The result is put into RPO order.
104fn build_reverse_graph(
105    func: &HirFunction,
106    next_block_id_counter: u32,
107    include_throws_as_exit_node: bool,
108) -> Graph {
109    let exit_id = BlockId(next_block_id_counter);
110
111    // Build initial nodes with reversed edges
112    let mut raw_nodes: HashMap<BlockId, Node> = HashMap::new();
113
114    // Create exit node
115    raw_nodes.insert(exit_id, Node {
116        id: exit_id,
117        index: 0,
118        preds: HashSet::new(),
119        succs: HashSet::new(),
120    });
121
122    for (id, block) in &func.body.blocks {
123        let successors = each_terminal_successor(&block.terminal);
124        let mut preds_set: HashSet<BlockId> = successors.into_iter().collect();
125        let succs_set: HashSet<BlockId> = block.preds.iter().copied().collect();
126
127        let is_return = matches!(&block.terminal, Terminal::Return { .. });
128        let is_throw = matches!(&block.terminal, Terminal::Throw { .. });
129
130        if is_return || (is_throw && include_throws_as_exit_node) {
131            preds_set.insert(exit_id);
132            raw_nodes.get_mut(&exit_id).unwrap().succs.insert(*id);
133        }
134
135        raw_nodes.insert(*id, Node {
136            id: *id,
137            index: 0,
138            preds: preds_set,
139            succs: succs_set,
140        });
141    }
142
143    // DFS from exit to compute RPO
144    let mut visited = HashSet::new();
145    let mut postorder = Vec::new();
146    dfs_postorder(exit_id, &raw_nodes, &mut visited, &mut postorder);
147
148    // Reverse postorder
149    postorder.reverse();
150
151    let mut nodes = Vec::with_capacity(postorder.len());
152    let mut node_index = HashMap::new();
153    for (idx, id) in postorder.into_iter().enumerate() {
154        let mut node = raw_nodes.remove(&id).unwrap();
155        node.index = idx;
156        node_index.insert(id, idx);
157        nodes.push(node);
158    }
159
160    Graph {
161        entry: exit_id,
162        nodes,
163        node_index,
164    }
165}
166
167fn dfs_postorder(
168    id: BlockId,
169    nodes: &HashMap<BlockId, Node>,
170    visited: &mut HashSet<BlockId>,
171    postorder: &mut Vec<BlockId>,
172) {
173    if !visited.insert(id) {
174        return;
175    }
176    if let Some(node) = nodes.get(&id) {
177        for &succ in &node.succs {
178            dfs_postorder(succ, nodes, visited, postorder);
179        }
180    }
181    postorder.push(id);
182}
183
184// =============================================================================
185// Dominator fixpoint (Cooper/Harvey/Kennedy)
186// =============================================================================
187
188fn compute_immediate_dominators(graph: &Graph) -> Result<HashMap<BlockId, BlockId>, CompilerDiagnostic> {
189    let mut doms: HashMap<BlockId, BlockId> = HashMap::new();
190    doms.insert(graph.entry, graph.entry);
191
192    let mut changed = true;
193    while changed {
194        changed = false;
195        for node in &graph.nodes {
196            if node.id == graph.entry {
197                continue;
198            }
199
200            // Find first processed predecessor
201            let mut new_idom: Option<BlockId> = None;
202            for &pred in &node.preds {
203                if doms.contains_key(&pred) {
204                    new_idom = Some(pred);
205                    break;
206                }
207            }
208            let mut new_idom = match new_idom {
209                Some(idom) => idom,
210                None => {
211                    return Err(CompilerDiagnostic::new(
212                        ErrorCategory::Invariant,
213                        format!(
214                            "At least one predecessor must have been visited for block {:?}",
215                            node.id
216                        ),
217                        None,
218                    ));
219                }
220            };
221
222            // Intersect with other processed predecessors
223            for &pred in &node.preds {
224                if pred == new_idom {
225                    continue;
226                }
227                if doms.contains_key(&pred) {
228                    new_idom = intersect(pred, new_idom, graph, &doms);
229                }
230            }
231
232            if doms.get(&node.id) != Some(&new_idom) {
233                doms.insert(node.id, new_idom);
234                changed = true;
235            }
236        }
237    }
238    Ok(doms)
239}
240
241fn intersect(
242    a: BlockId,
243    b: BlockId,
244    graph: &Graph,
245    doms: &HashMap<BlockId, BlockId>,
246) -> BlockId {
247    let mut block1 = graph.get_node(a);
248    let mut block2 = graph.get_node(b);
249    while block1.id != block2.id {
250        while block1.index > block2.index {
251            let dom = doms[&block1.id];
252            block1 = graph.get_node(dom);
253        }
254        while block2.index > block1.index {
255            let dom = doms[&block2.id];
256            block2 = graph.get_node(dom);
257        }
258    }
259    block1.id
260}
261
262// =============================================================================
263// Post-dominator frontier
264// =============================================================================
265
266/// Computes the post-dominator frontier of `target_id`. These are immediate
267/// predecessors of nodes that post-dominate `target_id` from which execution may
268/// not reach `target_id`. Intuitively, these are the earliest blocks from which
269/// execution branches such that it may or may not reach the target block.
270pub fn post_dominator_frontier(
271    func: &HirFunction,
272    post_dominators: &PostDominator,
273    target_id: BlockId,
274) -> HashSet<BlockId> {
275    let target_post_dominators = post_dominators_of(func, post_dominators, target_id);
276    let mut visited = HashSet::new();
277    let mut frontier = HashSet::new();
278
279    let mut to_visit: Vec<BlockId> = target_post_dominators.iter().copied().collect();
280    to_visit.push(target_id);
281
282    for block_id in to_visit {
283        if !visited.insert(block_id) {
284            continue;
285        }
286        if let Some(block) = func.body.blocks.get(&block_id) {
287            for &pred in &block.preds {
288                if !target_post_dominators.contains(&pred) {
289                    frontier.insert(pred);
290                }
291            }
292        }
293    }
294    frontier
295}
296
297/// Walks up the post-dominator tree to collect all blocks that post-dominate `target_id`.
298pub fn post_dominators_of(
299    func: &HirFunction,
300    post_dominators: &PostDominator,
301    target_id: BlockId,
302) -> HashSet<BlockId> {
303    let mut result = HashSet::new();
304    let mut visited = HashSet::new();
305    let mut queue = vec![target_id];
306
307    while let Some(current_id) = queue.pop() {
308        if !visited.insert(current_id) {
309            continue;
310        }
311        if let Some(block) = func.body.blocks.get(&current_id) {
312            for &pred in &block.preds {
313                let pred_post_dom = post_dominators.get(pred).unwrap_or(pred);
314                if pred_post_dom == target_id || result.contains(&pred_post_dom) {
315                    result.insert(pred);
316                }
317                queue.push(pred);
318            }
319        }
320    }
321    result
322}
323
324// =============================================================================
325// Unconditional blocks
326// =============================================================================
327
328/// Compute the set of blocks that are unconditionally executed from the entry.
329///
330/// Port of ComputeUnconditionalBlocks.ts. Walks the immediate post-dominator
331/// chain starting from the function entry. A block is unconditional if it lies
332/// on this chain (meaning every path through the function must pass through it).
333pub fn compute_unconditional_blocks(
334    func: &HirFunction,
335    next_block_id_counter: u32,
336) -> Result<HashSet<BlockId>, CompilerDiagnostic> {
337    let mut unconditional = HashSet::new();
338    let dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?;
339    let exit = dominators.exit;
340    let mut current: Option<BlockId> = Some(func.body.entry);
341
342    while let Some(block_id) = current {
343        if block_id == exit {
344            break;
345        }
346        assert!(
347            !unconditional.contains(&block_id),
348            "Internal error: non-terminating loop in ComputeUnconditionalBlocks"
349        );
350        unconditional.insert(block_id);
351        current = dominators.get(block_id);
352    }
353
354    Ok(unconditional)
355}