react_compiler_optimization/
prune_unused_labels_hir.rs1use react_compiler_hir::{BlockId, BlockKind, GotoVariant, HirFunction, Terminal};
15use std::collections::HashMap;
16
17pub fn prune_unused_labels_hir(func: &mut HirFunction) {
18 let mut merged: Vec<(BlockId, BlockId, BlockId)> = Vec::new(); for (&block_id, block) in &func.body.blocks {
23 if let Terminal::Label {
24 block: next_id,
25 fallthrough: fallthrough_id,
26 ..
27 } = &block.terminal
28 {
29 let next = &func.body.blocks[next_id];
30 let fallthrough = &func.body.blocks[fallthrough_id];
31 if let Terminal::Goto {
32 block: goto_target,
33 variant: GotoVariant::Break,
34 ..
35 } = &next.terminal
36 {
37 if goto_target == fallthrough_id
38 && next.kind == BlockKind::Block
39 && fallthrough.kind == BlockKind::Block
40 {
41 merged.push((block_id, *next_id, *fallthrough_id));
42 }
43 }
44 }
45 }
46
47 let mut rewrites: HashMap<BlockId, BlockId> = HashMap::new();
49
50 for (original_label_id, next_id, fallthrough_id) in &merged {
51 let label_id = rewrites.get(original_label_id).copied().unwrap_or(*original_label_id);
52
53 let next_phis_empty = func.body.blocks[next_id].phis.is_empty();
55 let fallthrough_phis_empty = func.body.blocks[fallthrough_id].phis.is_empty();
56 assert!(
57 next_phis_empty && fallthrough_phis_empty,
58 "Unexpected phis when merging label blocks"
59 );
60
61 let next_preds_ok = func.body.blocks[next_id].preds.len() == 1
63 && func.body.blocks[next_id].preds.contains(original_label_id);
64 let fallthrough_preds_ok = func.body.blocks[fallthrough_id].preds.len() == 1
65 && func.body.blocks[fallthrough_id].preds.contains(next_id);
66 assert!(
67 next_preds_ok && fallthrough_preds_ok,
68 "Unexpected block predecessors when merging label blocks"
69 );
70
71 let next_instructions = func.body.blocks[next_id].instructions.clone();
73 let fallthrough_instructions = func.body.blocks[fallthrough_id].instructions.clone();
74 let fallthrough_terminal = func.body.blocks[fallthrough_id].terminal.clone();
75
76 let label_block = func.body.blocks.get_mut(&label_id).unwrap();
78 label_block.instructions.extend(next_instructions);
79 label_block.instructions.extend(fallthrough_instructions);
80 label_block.terminal = fallthrough_terminal;
81
82 func.body.blocks.shift_remove(next_id);
84 func.body.blocks.shift_remove(fallthrough_id);
85
86 rewrites.insert(*fallthrough_id, label_id);
87 }
88
89 for block in func.body.blocks.values_mut() {
91 let preds_to_rewrite: Vec<(BlockId, BlockId)> = block
92 .preds
93 .iter()
94 .filter_map(|pred| rewrites.get(pred).map(|rewritten| (*pred, *rewritten)))
95 .collect();
96 for (old, new) in preds_to_rewrite {
97 block.preds.shift_remove(&old);
98 block.preds.insert(new);
99 }
100 }
101}