// react_compiler_optimization/prune_unused_labels_hir.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Removes unused labels from the HIR.
//!
//! A label terminal whose body block immediately breaks to the label's
//! fallthrough is effectively a no-op label, provided the body's only
//! predecessor is the label and the fallthrough's only predecessor is the body.
//! This pass merges such label/body/fallthrough triples into a single block.
//!
//! Analogous to TS `PruneUnusedLabelsHIR.ts`.

use react_compiler_hir::{BlockId, BlockKind, GotoVariant, HirFunction, Terminal};
use std::collections::HashMap;

17pub fn prune_unused_labels_hir(func: &mut HirFunction) {
18    // Phase 1: Identify label terminals whose body block immediately breaks
19    // to the fallthrough, and both body and fallthrough are normal blocks.
20    let mut merged: Vec<(BlockId, BlockId, BlockId)> = Vec::new(); // (label, next, fallthrough)
21
22    for (&block_id, block) in &func.body.blocks {
23        if let Terminal::Label {
24            block: next_id,
25            fallthrough: fallthrough_id,
26            ..
27        } = &block.terminal
28        {
29            let next = &func.body.blocks[next_id];
30            let fallthrough = &func.body.blocks[fallthrough_id];
31            if let Terminal::Goto {
32                block: goto_target,
33                variant: GotoVariant::Break,
34                ..
35            } = &next.terminal
36            {
37                if goto_target == fallthrough_id
38                    && next.kind == BlockKind::Block
39                    && fallthrough.kind == BlockKind::Block
40                {
41                    merged.push((block_id, *next_id, *fallthrough_id));
42                }
43            }
44        }
45    }
46
47    // Phase 2: Apply merges
48    let mut rewrites: HashMap<BlockId, BlockId> = HashMap::new();
49
50    for (original_label_id, next_id, fallthrough_id) in &merged {
51        let label_id = rewrites.get(original_label_id).copied().unwrap_or(*original_label_id);
52
53        // Validate: no phis in next or fallthrough
54        let next_phis_empty = func.body.blocks[next_id].phis.is_empty();
55        let fallthrough_phis_empty = func.body.blocks[fallthrough_id].phis.is_empty();
56        assert!(
57            next_phis_empty && fallthrough_phis_empty,
58            "Unexpected phis when merging label blocks"
59        );
60
61        // Validate: single predecessors
62        let next_preds_ok = func.body.blocks[next_id].preds.len() == 1
63            && func.body.blocks[next_id].preds.contains(original_label_id);
64        let fallthrough_preds_ok = func.body.blocks[fallthrough_id].preds.len() == 1
65            && func.body.blocks[fallthrough_id].preds.contains(next_id);
66        assert!(
67            next_preds_ok && fallthrough_preds_ok,
68            "Unexpected block predecessors when merging label blocks"
69        );
70
71        // Collect instructions from next and fallthrough
72        let next_instructions = func.body.blocks[next_id].instructions.clone();
73        let fallthrough_instructions = func.body.blocks[fallthrough_id].instructions.clone();
74        let fallthrough_terminal = func.body.blocks[fallthrough_id].terminal.clone();
75
76        // Merge into the label block
77        let label_block = func.body.blocks.get_mut(&label_id).unwrap();
78        label_block.instructions.extend(next_instructions);
79        label_block.instructions.extend(fallthrough_instructions);
80        label_block.terminal = fallthrough_terminal;
81
82        // Remove merged blocks
83        func.body.blocks.shift_remove(next_id);
84        func.body.blocks.shift_remove(fallthrough_id);
85
86        rewrites.insert(*fallthrough_id, label_id);
87    }
88
89    // Phase 3: Rewrite predecessor sets
90    for block in func.body.blocks.values_mut() {
91        let preds_to_rewrite: Vec<(BlockId, BlockId)> = block
92            .preds
93            .iter()
94            .filter_map(|pred| rewrites.get(pred).map(|rewritten| (*pred, *rewritten)))
95            .collect();
96        for (old, new) in preds_to_rewrite {
97            block.preds.shift_remove(&old);
98            block.preds.insert(new);
99        }
100    }
101}