Skip to main content

react_compiler_optimization/
merge_consecutive_blocks.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
5
//! Merges sequences of blocks that will always execute consecutively —
//! i.e., where the predecessor always transfers control to the successor
//! (ends in a goto) and where the predecessor is the only predecessor
//! for that successor (no other way to reach the successor).
//!
//! Value/loop blocks are left alone because they cannot be merged without
//! breaking the structure of the high-level terminals that reference them.
//!
//! Analogous to TS `HIR/MergeConsecutiveBlocks.ts`.
15
16use rustc_hash::{FxHashMap, FxHashSet};
17
18use react_compiler_hir::visitors;
19use react_compiler_hir::{
20    AliasingEffect, BlockId, BlockKind, Effect, GENERATED_SOURCE, HirFunction, Instruction,
21    InstructionId, InstructionValue, Place, Terminal,
22};
23use react_compiler_lowering::mark_predecessors;
24use react_compiler_ssa::enter_ssa::placeholder_function;
25
26/// Merge consecutive blocks in the function's CFG, including inner functions.
27pub fn merge_consecutive_blocks(func: &mut HirFunction, functions: &mut [HirFunction]) {
28    // Collect inner function IDs for recursive processing
29    let inner_func_ids: Vec<usize> = func
30        .body
31        .blocks
32        .values()
33        .flat_map(|block| block.instructions.iter())
34        .filter_map(|instr_id| {
35            let instr = &func.instructions[instr_id.0 as usize];
36            match &instr.value {
37                InstructionValue::FunctionExpression { lowered_func, .. }
38                | InstructionValue::ObjectMethod { lowered_func, .. } => {
39                    Some(lowered_func.func.0 as usize)
40                }
41                _ => None,
42            }
43        })
44        .collect();
45
46    // Recursively merge consecutive blocks in inner functions
47    for func_id in inner_func_ids {
48        // Use std::mem::replace to temporarily take the inner function out,
49        // process it, then put it back (standard borrow checker workaround)
50        let mut inner_func = std::mem::replace(&mut functions[func_id], placeholder_function());
51        merge_consecutive_blocks(&mut inner_func, functions);
52        functions[func_id] = inner_func;
53    }
54
55    // Build fallthrough set
56    let mut fallthrough_blocks: FxHashSet<BlockId> = FxHashSet::default();
57    for block in func.body.blocks.values() {
58        if let Some(ft) = visitors::terminal_fallthrough(&block.terminal) {
59            fallthrough_blocks.insert(ft);
60        }
61    }
62
63    let mut merged = MergedBlocks::new();
64
65    // Collect block IDs for iteration (since we modify during iteration)
66    let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
67
68    for block_id in &block_ids {
69        let block = match func.body.blocks.get(block_id) {
70            Some(b) => b,
71            None => continue, // already removed
72        };
73
74        if block.preds.len() != 1
75            || block.kind != BlockKind::Block
76            || fallthrough_blocks.contains(block_id)
77        {
78            continue;
79        }
80
81        let original_pred_id = *block.preds.iter().next().unwrap();
82        let pred_id = merged.get(original_pred_id);
83
84        // Check predecessor exists and ends in goto with block kind
85        let pred_is_mergeable = func
86            .body
87            .blocks
88            .get(&pred_id)
89            .map(|p| matches!(p.terminal, Terminal::Goto { .. }) && p.kind == BlockKind::Block)
90            .unwrap_or(false);
91
92        if !pred_is_mergeable {
93            continue;
94        }
95
96        // Get evaluation order from predecessor's terminal (for phi instructions)
97        let eval_order = func.body.blocks[&pred_id].terminal.evaluation_order();
98
99        // Collect phi data from the block being merged
100        let phis: Vec<_> = block
101            .phis
102            .iter()
103            .map(|phi| {
104                assert_eq!(
105                    phi.operands.len(),
106                    1,
107                    "Found a block with a single predecessor but where a phi has multiple ({}) operands",
108                    phi.operands.len()
109                );
110                let operand = phi.operands.values().next().unwrap().clone();
111                (phi.place.identifier, operand)
112            })
113            .collect();
114        let block_instr_ids = block.instructions.clone();
115        let block_terminal = block.terminal.clone();
116
117        // Create phi instructions and add to instruction table
118        let mut new_instr_ids = Vec::new();
119        for (identifier, operand) in phis {
120            let lvalue = Place {
121                identifier,
122                effect: Effect::ConditionallyMutate,
123                reactive: false,
124                loc: GENERATED_SOURCE,
125            };
126            let instr = Instruction {
127                id: eval_order,
128                lvalue: lvalue.clone(),
129                value: InstructionValue::LoadLocal {
130                    place: operand.clone(),
131                    loc: GENERATED_SOURCE,
132                },
133                loc: GENERATED_SOURCE,
134                effects: Some(vec![AliasingEffect::Alias {
135                    from: operand,
136                    into: lvalue,
137                }]),
138            };
139            let instr_id = InstructionId(func.instructions.len() as u32);
140            func.instructions.push(instr);
141            new_instr_ids.push(instr_id);
142        }
143
144        // Apply merge to predecessor
145        let pred = func.body.blocks.get_mut(&pred_id).unwrap();
146        pred.instructions.extend(new_instr_ids);
147        pred.instructions.extend(block_instr_ids);
148        pred.terminal = block_terminal;
149
150        // Record merge and remove block
151        merged.merge(*block_id, pred_id);
152        func.body.blocks.shift_remove(block_id);
153    }
154
155    // Update phi operands for merged blocks
156    for block in func.body.blocks.values_mut() {
157        for phi in &mut block.phis {
158            let updates: Vec<_> = phi
159                .operands
160                .iter()
161                .filter_map(|(pred_id, operand)| {
162                    let mapped = merged.get(*pred_id);
163                    if mapped != *pred_id {
164                        Some((*pred_id, mapped, operand.clone()))
165                    } else {
166                        None
167                    }
168                })
169                .collect();
170            for (old_id, new_id, operand) in updates {
171                phi.operands.shift_remove(&old_id);
172                phi.operands.insert(new_id, operand);
173            }
174        }
175    }
176
177    mark_predecessors(&mut func.body);
178
179    // Update terminal successors (including fallthroughs) for merged blocks
180    for block in func.body.blocks.values_mut() {
181        visitors::map_terminal_successors(&mut block.terminal, &mut |block_id| {
182            merged.get(block_id)
183        });
184    }
185}
186
187/// Tracks which blocks have been merged and into which target.
188struct MergedBlocks {
189    map: FxHashMap<BlockId, BlockId>,
190}
191
192impl MergedBlocks {
193    fn new() -> Self {
194        Self {
195            map: FxHashMap::default(),
196        }
197    }
198
199    /// Record that `block` was merged into `into`.
200    fn merge(&mut self, block: BlockId, into: BlockId) {
201        let target = self.get(into);
202        self.map.insert(block, target);
203    }
204
205    /// Get the id of the block that `block` has been merged into.
206    /// Transitive: if A merged into B which merged into C, get(A) returns C.
207    fn get(&self, block: BlockId) -> BlockId {
208        let mut current = block;
209        while let Some(&target) = self.map.get(&current) {
210            current = target;
211        }
212        current
213    }
214}