1use std::collections::{HashMap, HashSet};
13
14use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
15
16use crate::visitors::each_terminal_successor;
17use crate::{BlockId, HirFunction, Terminal};
18
19pub struct PostDominator {
25 pub exit: BlockId,
27 nodes: HashMap<BlockId, BlockId>,
28}
29
30impl PostDominator {
31 pub fn get(&self, id: BlockId) -> Option<BlockId> {
34 let dominator = self.nodes.get(&id).expect("Unknown node in post-dominator tree");
35 if *dominator == id {
36 None
37 } else {
38 Some(*dominator)
39 }
40 }
41}
42
43struct Node {
48 id: BlockId,
49 index: usize,
50 preds: HashSet<BlockId>,
51 succs: HashSet<BlockId>,
52}
53
54struct Graph {
55 entry: BlockId,
56 nodes: Vec<Node>,
58 node_index: HashMap<BlockId, usize>,
60}
61
62impl Graph {
63 fn get_node(&self, id: BlockId) -> &Node {
64 let idx = self.node_index[&id];
65 &self.nodes[idx]
66 }
67}
68
69pub fn compute_post_dominator_tree(
78 func: &HirFunction,
79 next_block_id_counter: u32,
80 include_throws_as_exit_node: bool,
81) -> Result<PostDominator, CompilerDiagnostic> {
82 let graph = build_reverse_graph(func, next_block_id_counter, include_throws_as_exit_node);
83 let mut nodes = compute_immediate_dominators(&graph)?;
84
85 if !include_throws_as_exit_node {
89 for (id, _) in &func.body.blocks {
90 nodes.entry(*id).or_insert(*id);
91 }
92 }
93
94 Ok(PostDominator {
95 exit: graph.entry,
96 nodes,
97 })
98}
99
100fn build_reverse_graph(
105 func: &HirFunction,
106 next_block_id_counter: u32,
107 include_throws_as_exit_node: bool,
108) -> Graph {
109 let exit_id = BlockId(next_block_id_counter);
110
111 let mut raw_nodes: HashMap<BlockId, Node> = HashMap::new();
113
114 raw_nodes.insert(exit_id, Node {
116 id: exit_id,
117 index: 0,
118 preds: HashSet::new(),
119 succs: HashSet::new(),
120 });
121
122 for (id, block) in &func.body.blocks {
123 let successors = each_terminal_successor(&block.terminal);
124 let mut preds_set: HashSet<BlockId> = successors.into_iter().collect();
125 let succs_set: HashSet<BlockId> = block.preds.iter().copied().collect();
126
127 let is_return = matches!(&block.terminal, Terminal::Return { .. });
128 let is_throw = matches!(&block.terminal, Terminal::Throw { .. });
129
130 if is_return || (is_throw && include_throws_as_exit_node) {
131 preds_set.insert(exit_id);
132 raw_nodes.get_mut(&exit_id).unwrap().succs.insert(*id);
133 }
134
135 raw_nodes.insert(*id, Node {
136 id: *id,
137 index: 0,
138 preds: preds_set,
139 succs: succs_set,
140 });
141 }
142
143 let mut visited = HashSet::new();
145 let mut postorder = Vec::new();
146 dfs_postorder(exit_id, &raw_nodes, &mut visited, &mut postorder);
147
148 postorder.reverse();
150
151 let mut nodes = Vec::with_capacity(postorder.len());
152 let mut node_index = HashMap::new();
153 for (idx, id) in postorder.into_iter().enumerate() {
154 let mut node = raw_nodes.remove(&id).unwrap();
155 node.index = idx;
156 node_index.insert(id, idx);
157 nodes.push(node);
158 }
159
160 Graph {
161 entry: exit_id,
162 nodes,
163 node_index,
164 }
165}
166
167fn dfs_postorder(
168 id: BlockId,
169 nodes: &HashMap<BlockId, Node>,
170 visited: &mut HashSet<BlockId>,
171 postorder: &mut Vec<BlockId>,
172) {
173 if !visited.insert(id) {
174 return;
175 }
176 if let Some(node) = nodes.get(&id) {
177 for &succ in &node.succs {
178 dfs_postorder(succ, nodes, visited, postorder);
179 }
180 }
181 postorder.push(id);
182}
183
184fn compute_immediate_dominators(graph: &Graph) -> Result<HashMap<BlockId, BlockId>, CompilerDiagnostic> {
189 let mut doms: HashMap<BlockId, BlockId> = HashMap::new();
190 doms.insert(graph.entry, graph.entry);
191
192 let mut changed = true;
193 while changed {
194 changed = false;
195 for node in &graph.nodes {
196 if node.id == graph.entry {
197 continue;
198 }
199
200 let mut new_idom: Option<BlockId> = None;
202 for &pred in &node.preds {
203 if doms.contains_key(&pred) {
204 new_idom = Some(pred);
205 break;
206 }
207 }
208 let mut new_idom = match new_idom {
209 Some(idom) => idom,
210 None => {
211 return Err(CompilerDiagnostic::new(
212 ErrorCategory::Invariant,
213 format!(
214 "At least one predecessor must have been visited for block {:?}",
215 node.id
216 ),
217 None,
218 ));
219 }
220 };
221
222 for &pred in &node.preds {
224 if pred == new_idom {
225 continue;
226 }
227 if doms.contains_key(&pred) {
228 new_idom = intersect(pred, new_idom, graph, &doms);
229 }
230 }
231
232 if doms.get(&node.id) != Some(&new_idom) {
233 doms.insert(node.id, new_idom);
234 changed = true;
235 }
236 }
237 }
238 Ok(doms)
239}
240
241fn intersect(
242 a: BlockId,
243 b: BlockId,
244 graph: &Graph,
245 doms: &HashMap<BlockId, BlockId>,
246) -> BlockId {
247 let mut block1 = graph.get_node(a);
248 let mut block2 = graph.get_node(b);
249 while block1.id != block2.id {
250 while block1.index > block2.index {
251 let dom = doms[&block1.id];
252 block1 = graph.get_node(dom);
253 }
254 while block2.index > block1.index {
255 let dom = doms[&block2.id];
256 block2 = graph.get_node(dom);
257 }
258 }
259 block1.id
260}
261
262pub fn post_dominator_frontier(
271 func: &HirFunction,
272 post_dominators: &PostDominator,
273 target_id: BlockId,
274) -> HashSet<BlockId> {
275 let target_post_dominators = post_dominators_of(func, post_dominators, target_id);
276 let mut visited = HashSet::new();
277 let mut frontier = HashSet::new();
278
279 let mut to_visit: Vec<BlockId> = target_post_dominators.iter().copied().collect();
280 to_visit.push(target_id);
281
282 for block_id in to_visit {
283 if !visited.insert(block_id) {
284 continue;
285 }
286 if let Some(block) = func.body.blocks.get(&block_id) {
287 for &pred in &block.preds {
288 if !target_post_dominators.contains(&pred) {
289 frontier.insert(pred);
290 }
291 }
292 }
293 }
294 frontier
295}
296
297pub fn post_dominators_of(
299 func: &HirFunction,
300 post_dominators: &PostDominator,
301 target_id: BlockId,
302) -> HashSet<BlockId> {
303 let mut result = HashSet::new();
304 let mut visited = HashSet::new();
305 let mut queue = vec![target_id];
306
307 while let Some(current_id) = queue.pop() {
308 if !visited.insert(current_id) {
309 continue;
310 }
311 if let Some(block) = func.body.blocks.get(¤t_id) {
312 for &pred in &block.preds {
313 let pred_post_dom = post_dominators.get(pred).unwrap_or(pred);
314 if pred_post_dom == target_id || result.contains(&pred_post_dom) {
315 result.insert(pred);
316 }
317 queue.push(pred);
318 }
319 }
320 }
321 result
322}
323
324pub fn compute_unconditional_blocks(
334 func: &HirFunction,
335 next_block_id_counter: u32,
336) -> Result<HashSet<BlockId>, CompilerDiagnostic> {
337 let mut unconditional = HashSet::new();
338 let dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?;
339 let exit = dominators.exit;
340 let mut current: Option<BlockId> = Some(func.body.entry);
341
342 while let Some(block_id) = current {
343 if block_id == exit {
344 break;
345 }
346 assert!(
347 !unconditional.contains(&block_id),
348 "Internal error: non-terminating loop in ComputeUnconditionalBlocks"
349 );
350 unconditional.insert(block_id);
351 current = dominators.get(block_id);
352 }
353
354 Ok(unconditional)
355}