react_compiler_optimization/
prune_unused_labels_hir.rs1use react_compiler_hir::{BlockId, BlockKind, GotoVariant, HirFunction, Terminal};
15use rustc_hash::FxHashMap;
16
17pub fn prune_unused_labels_hir(func: &mut HirFunction) {
18 let mut merged: Vec<(BlockId, BlockId, BlockId)> = Vec::new(); for (&block_id, block) in &func.body.blocks {
23 if let Terminal::Label {
24 block: next_id,
25 fallthrough: fallthrough_id,
26 ..
27 } = &block.terminal
28 {
29 let next = &func.body.blocks[next_id];
30 let fallthrough = &func.body.blocks[fallthrough_id];
31 if let Terminal::Goto {
32 block: goto_target,
33 variant: GotoVariant::Break,
34 ..
35 } = &next.terminal
36 {
37 if goto_target == fallthrough_id
38 && next.kind == BlockKind::Block
39 && fallthrough.kind == BlockKind::Block
40 {
41 merged.push((block_id, *next_id, *fallthrough_id));
42 }
43 }
44 }
45 }
46
47 let mut rewrites: FxHashMap<BlockId, BlockId> = FxHashMap::default();
49
50 for (original_label_id, next_id, fallthrough_id) in &merged {
51 let label_id = rewrites
52 .get(original_label_id)
53 .copied()
54 .unwrap_or(*original_label_id);
55
56 let next_phis_empty = func.body.blocks[next_id].phis.is_empty();
58 let fallthrough_phis_empty = func.body.blocks[fallthrough_id].phis.is_empty();
59 assert!(
60 next_phis_empty && fallthrough_phis_empty,
61 "Unexpected phis when merging label blocks"
62 );
63
64 let next_preds_ok = func.body.blocks[next_id].preds.len() == 1
66 && func.body.blocks[next_id].preds.contains(original_label_id);
67 let fallthrough_preds_ok = func.body.blocks[fallthrough_id].preds.len() == 1
68 && func.body.blocks[fallthrough_id].preds.contains(next_id);
69 assert!(
70 next_preds_ok && fallthrough_preds_ok,
71 "Unexpected block predecessors when merging label blocks"
72 );
73
74 let next_instructions = func.body.blocks[next_id].instructions.clone();
76 let fallthrough_instructions = func.body.blocks[fallthrough_id].instructions.clone();
77 let fallthrough_terminal = func.body.blocks[fallthrough_id].terminal.clone();
78
79 let label_block = func.body.blocks.get_mut(&label_id).unwrap();
81 label_block.instructions.extend(next_instructions);
82 label_block.instructions.extend(fallthrough_instructions);
83 label_block.terminal = fallthrough_terminal;
84
85 func.body.blocks.shift_remove(next_id);
87 func.body.blocks.shift_remove(fallthrough_id);
88
89 rewrites.insert(*fallthrough_id, label_id);
90 }
91
92 for block in func.body.blocks.values_mut() {
94 let preds_to_rewrite: Vec<(BlockId, BlockId)> = block
95 .preds
96 .iter()
97 .filter_map(|pred| rewrites.get(pred).map(|rewritten| (*pred, *rewritten)))
98 .collect();
99 for (old, new) in preds_to_rewrite {
100 block.preds.shift_remove(&old);
101 block.preds.insert(new);
102 }
103 }
104}