Skip to main content

react_compiler_optimization/
prune_unused_labels_hir.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Removes unused labels from the HIR.
//!
//! A label terminal whose body block immediately breaks to the label's
//! fallthrough is effectively a no-op label. When both the body and the
//! fallthrough are normal blocks, this pass merges the label/body/fallthrough
//! triple into a single block; the body and fallthrough are expected to have
//! no other predecessors.
//!
//! Analogous to TS `PruneUnusedLabelsHIR.ts`.
13
14use react_compiler_hir::{BlockId, BlockKind, GotoVariant, HirFunction, Terminal};
15use rustc_hash::FxHashMap;
16
17pub fn prune_unused_labels_hir(func: &mut HirFunction) {
18    // Phase 1: Identify label terminals whose body block immediately breaks
19    // to the fallthrough, and both body and fallthrough are normal blocks.
20    let mut merged: Vec<(BlockId, BlockId, BlockId)> = Vec::new(); // (label, next, fallthrough)
21
22    for (&block_id, block) in &func.body.blocks {
23        if let Terminal::Label {
24            block: next_id,
25            fallthrough: fallthrough_id,
26            ..
27        } = &block.terminal
28        {
29            let next = &func.body.blocks[next_id];
30            let fallthrough = &func.body.blocks[fallthrough_id];
31            if let Terminal::Goto {
32                block: goto_target,
33                variant: GotoVariant::Break,
34                ..
35            } = &next.terminal
36            {
37                if goto_target == fallthrough_id
38                    && next.kind == BlockKind::Block
39                    && fallthrough.kind == BlockKind::Block
40                {
41                    merged.push((block_id, *next_id, *fallthrough_id));
42                }
43            }
44        }
45    }
46
47    // Phase 2: Apply merges
48    let mut rewrites: FxHashMap<BlockId, BlockId> = FxHashMap::default();
49
50    for (original_label_id, next_id, fallthrough_id) in &merged {
51        let label_id = rewrites
52            .get(original_label_id)
53            .copied()
54            .unwrap_or(*original_label_id);
55
56        // Validate: no phis in next or fallthrough
57        let next_phis_empty = func.body.blocks[next_id].phis.is_empty();
58        let fallthrough_phis_empty = func.body.blocks[fallthrough_id].phis.is_empty();
59        assert!(
60            next_phis_empty && fallthrough_phis_empty,
61            "Unexpected phis when merging label blocks"
62        );
63
64        // Validate: single predecessors
65        let next_preds_ok = func.body.blocks[next_id].preds.len() == 1
66            && func.body.blocks[next_id].preds.contains(original_label_id);
67        let fallthrough_preds_ok = func.body.blocks[fallthrough_id].preds.len() == 1
68            && func.body.blocks[fallthrough_id].preds.contains(next_id);
69        assert!(
70            next_preds_ok && fallthrough_preds_ok,
71            "Unexpected block predecessors when merging label blocks"
72        );
73
74        // Collect instructions from next and fallthrough
75        let next_instructions = func.body.blocks[next_id].instructions.clone();
76        let fallthrough_instructions = func.body.blocks[fallthrough_id].instructions.clone();
77        let fallthrough_terminal = func.body.blocks[fallthrough_id].terminal.clone();
78
79        // Merge into the label block
80        let label_block = func.body.blocks.get_mut(&label_id).unwrap();
81        label_block.instructions.extend(next_instructions);
82        label_block.instructions.extend(fallthrough_instructions);
83        label_block.terminal = fallthrough_terminal;
84
85        // Remove merged blocks
86        func.body.blocks.shift_remove(next_id);
87        func.body.blocks.shift_remove(fallthrough_id);
88
89        rewrites.insert(*fallthrough_id, label_id);
90    }
91
92    // Phase 3: Rewrite predecessor sets
93    for block in func.body.blocks.values_mut() {
94        let preds_to_rewrite: Vec<(BlockId, BlockId)> = block
95            .preds
96            .iter()
97            .filter_map(|pred| rewrites.get(pred).map(|rewritten| (*pred, *rewritten)))
98            .collect();
99        for (old, new) in preds_to_rewrite {
100            block.preds.shift_remove(&old);
101            block.preds.insert(new);
102        }
103    }
104}