Skip to main content

react_compiler_optimization/
merge_consecutive_blocks.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Merges sequences of blocks that will always execute consecutively —
//! i.e., where the predecessor always transfers control to the successor
//! (ends in a goto) and where the predecessor is the only predecessor
//! for that successor (no other way to reach the successor).
//!
//! Value/loop blocks are left alone because they cannot be merged without
//! breaking the structure of the high-level terminals that reference them.
//! Likewise, a block that some terminal names as its fallthrough is never
//! merged into its predecessor.
//!
//! Analogous to TS `HIR/MergeConsecutiveBlocks.ts`.
15
16use std::collections::{HashMap, HashSet};
17
18use react_compiler_hir::visitors;
19use react_compiler_hir::{
20    AliasingEffect, BlockId, BlockKind, Effect, GENERATED_SOURCE, HirFunction, Instruction,
21    InstructionId, InstructionValue, Place, Terminal,
22};
23use react_compiler_lowering::mark_predecessors;
24use react_compiler_ssa::enter_ssa::placeholder_function;
25
26/// Merge consecutive blocks in the function's CFG, including inner functions.
27pub fn merge_consecutive_blocks(func: &mut HirFunction, functions: &mut [HirFunction]) {
28    // Collect inner function IDs for recursive processing
29    let inner_func_ids: Vec<usize> = func
30        .body
31        .blocks
32        .values()
33        .flat_map(|block| block.instructions.iter())
34        .filter_map(|instr_id| {
35            let instr = &func.instructions[instr_id.0 as usize];
36            match &instr.value {
37                InstructionValue::FunctionExpression { lowered_func, .. }
38                | InstructionValue::ObjectMethod { lowered_func, .. } => {
39                    Some(lowered_func.func.0 as usize)
40                }
41                _ => None,
42            }
43        })
44        .collect();
45
46    // Recursively merge consecutive blocks in inner functions
47    for func_id in inner_func_ids {
48        // Use std::mem::replace to temporarily take the inner function out,
49        // process it, then put it back (standard borrow checker workaround)
50        let mut inner_func = std::mem::replace(
51            &mut functions[func_id],
52            placeholder_function(),
53        );
54        merge_consecutive_blocks(&mut inner_func, functions);
55        functions[func_id] = inner_func;
56    }
57
58    // Build fallthrough set
59    let mut fallthrough_blocks: HashSet<BlockId> = HashSet::new();
60    for block in func.body.blocks.values() {
61        if let Some(ft) = visitors::terminal_fallthrough(&block.terminal) {
62            fallthrough_blocks.insert(ft);
63        }
64    }
65
66    let mut merged = MergedBlocks::new();
67
68    // Collect block IDs for iteration (since we modify during iteration)
69    let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
70
71    for block_id in &block_ids {
72        let block = match func.body.blocks.get(block_id) {
73            Some(b) => b,
74            None => continue, // already removed
75        };
76
77        if block.preds.len() != 1
78            || block.kind != BlockKind::Block
79            || fallthrough_blocks.contains(block_id)
80        {
81            continue;
82        }
83
84        let original_pred_id = *block.preds.iter().next().unwrap();
85        let pred_id = merged.get(original_pred_id);
86
87        // Check predecessor exists and ends in goto with block kind
88        let pred_is_mergeable = func
89            .body
90            .blocks
91            .get(&pred_id)
92            .map(|p| matches!(p.terminal, Terminal::Goto { .. }) && p.kind == BlockKind::Block)
93            .unwrap_or(false);
94
95        if !pred_is_mergeable {
96            continue;
97        }
98
99        // Get evaluation order from predecessor's terminal (for phi instructions)
100        let eval_order = func.body.blocks[&pred_id].terminal.evaluation_order();
101
102        // Collect phi data from the block being merged
103        let phis: Vec<_> = block
104            .phis
105            .iter()
106            .map(|phi| {
107                assert_eq!(
108                    phi.operands.len(),
109                    1,
110                    "Found a block with a single predecessor but where a phi has multiple ({}) operands",
111                    phi.operands.len()
112                );
113                let operand = phi.operands.values().next().unwrap().clone();
114                (phi.place.identifier, operand)
115            })
116            .collect();
117        let block_instr_ids = block.instructions.clone();
118        let block_terminal = block.terminal.clone();
119
120        // Create phi instructions and add to instruction table
121        let mut new_instr_ids = Vec::new();
122        for (identifier, operand) in phis {
123            let lvalue = Place {
124                identifier,
125                effect: Effect::ConditionallyMutate,
126                reactive: false,
127                loc: GENERATED_SOURCE,
128            };
129            let instr = Instruction {
130                id: eval_order,
131                lvalue: lvalue.clone(),
132                value: InstructionValue::LoadLocal {
133                    place: operand.clone(),
134                    loc: GENERATED_SOURCE,
135                },
136                loc: GENERATED_SOURCE,
137                effects: Some(vec![AliasingEffect::Alias {
138                    from: operand,
139                    into: lvalue,
140                }]),
141            };
142            let instr_id = InstructionId(func.instructions.len() as u32);
143            func.instructions.push(instr);
144            new_instr_ids.push(instr_id);
145        }
146
147        // Apply merge to predecessor
148        let pred = func.body.blocks.get_mut(&pred_id).unwrap();
149        pred.instructions.extend(new_instr_ids);
150        pred.instructions.extend(block_instr_ids);
151        pred.terminal = block_terminal;
152
153        // Record merge and remove block
154        merged.merge(*block_id, pred_id);
155        func.body.blocks.shift_remove(block_id);
156    }
157
158    // Update phi operands for merged blocks
159    for block in func.body.blocks.values_mut() {
160        for phi in &mut block.phis {
161            let updates: Vec<_> = phi
162                .operands
163                .iter()
164                .filter_map(|(pred_id, operand)| {
165                    let mapped = merged.get(*pred_id);
166                    if mapped != *pred_id {
167                        Some((*pred_id, mapped, operand.clone()))
168                    } else {
169                        None
170                    }
171                })
172                .collect();
173            for (old_id, new_id, operand) in updates {
174                phi.operands.shift_remove(&old_id);
175                phi.operands.insert(new_id, operand);
176            }
177        }
178    }
179
180    mark_predecessors(&mut func.body);
181
182    // Update terminal successors (including fallthroughs) for merged blocks
183    for block in func.body.blocks.values_mut() {
184        visitors::map_terminal_successors(&mut block.terminal, &mut |block_id| {
185            merged.get(block_id)
186        });
187    }
188}
189
190/// Tracks which blocks have been merged and into which target.
191struct MergedBlocks {
192    map: HashMap<BlockId, BlockId>,
193}
194
195impl MergedBlocks {
196    fn new() -> Self {
197        Self {
198            map: HashMap::new(),
199        }
200    }
201
202    /// Record that `block` was merged into `into`.
203    fn merge(&mut self, block: BlockId, into: BlockId) {
204        let target = self.get(into);
205        self.map.insert(block, target);
206    }
207
208    /// Get the id of the block that `block` has been merged into.
209    /// Transitive: if A merged into B which merged into C, get(A) returns C.
210    fn get(&self, block: BlockId) -> BlockId {
211        let mut current = block;
212        while let Some(&target) = self.map.get(&current) {
213            current = target;
214        }
215        current
216    }
217}