react_compiler_optimization/
merge_consecutive_blocks.rs1use std::collections::{HashMap, HashSet};
17
18use react_compiler_hir::visitors;
19use react_compiler_hir::{
20 AliasingEffect, BlockId, BlockKind, Effect, GENERATED_SOURCE, HirFunction, Instruction,
21 InstructionId, InstructionValue, Place, Terminal,
22};
23use react_compiler_lowering::mark_predecessors;
24use react_compiler_ssa::enter_ssa::placeholder_function;
25
26pub fn merge_consecutive_blocks(func: &mut HirFunction, functions: &mut [HirFunction]) {
28 let inner_func_ids: Vec<usize> = func
30 .body
31 .blocks
32 .values()
33 .flat_map(|block| block.instructions.iter())
34 .filter_map(|instr_id| {
35 let instr = &func.instructions[instr_id.0 as usize];
36 match &instr.value {
37 InstructionValue::FunctionExpression { lowered_func, .. }
38 | InstructionValue::ObjectMethod { lowered_func, .. } => {
39 Some(lowered_func.func.0 as usize)
40 }
41 _ => None,
42 }
43 })
44 .collect();
45
46 for func_id in inner_func_ids {
48 let mut inner_func = std::mem::replace(
51 &mut functions[func_id],
52 placeholder_function(),
53 );
54 merge_consecutive_blocks(&mut inner_func, functions);
55 functions[func_id] = inner_func;
56 }
57
58 let mut fallthrough_blocks: HashSet<BlockId> = HashSet::new();
60 for block in func.body.blocks.values() {
61 if let Some(ft) = visitors::terminal_fallthrough(&block.terminal) {
62 fallthrough_blocks.insert(ft);
63 }
64 }
65
66 let mut merged = MergedBlocks::new();
67
68 let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
70
71 for block_id in &block_ids {
72 let block = match func.body.blocks.get(block_id) {
73 Some(b) => b,
74 None => continue, };
76
77 if block.preds.len() != 1
78 || block.kind != BlockKind::Block
79 || fallthrough_blocks.contains(block_id)
80 {
81 continue;
82 }
83
84 let original_pred_id = *block.preds.iter().next().unwrap();
85 let pred_id = merged.get(original_pred_id);
86
87 let pred_is_mergeable = func
89 .body
90 .blocks
91 .get(&pred_id)
92 .map(|p| matches!(p.terminal, Terminal::Goto { .. }) && p.kind == BlockKind::Block)
93 .unwrap_or(false);
94
95 if !pred_is_mergeable {
96 continue;
97 }
98
99 let eval_order = func.body.blocks[&pred_id].terminal.evaluation_order();
101
102 let phis: Vec<_> = block
104 .phis
105 .iter()
106 .map(|phi| {
107 assert_eq!(
108 phi.operands.len(),
109 1,
110 "Found a block with a single predecessor but where a phi has multiple ({}) operands",
111 phi.operands.len()
112 );
113 let operand = phi.operands.values().next().unwrap().clone();
114 (phi.place.identifier, operand)
115 })
116 .collect();
117 let block_instr_ids = block.instructions.clone();
118 let block_terminal = block.terminal.clone();
119
120 let mut new_instr_ids = Vec::new();
122 for (identifier, operand) in phis {
123 let lvalue = Place {
124 identifier,
125 effect: Effect::ConditionallyMutate,
126 reactive: false,
127 loc: GENERATED_SOURCE,
128 };
129 let instr = Instruction {
130 id: eval_order,
131 lvalue: lvalue.clone(),
132 value: InstructionValue::LoadLocal {
133 place: operand.clone(),
134 loc: GENERATED_SOURCE,
135 },
136 loc: GENERATED_SOURCE,
137 effects: Some(vec![AliasingEffect::Alias {
138 from: operand,
139 into: lvalue,
140 }]),
141 };
142 let instr_id = InstructionId(func.instructions.len() as u32);
143 func.instructions.push(instr);
144 new_instr_ids.push(instr_id);
145 }
146
147 let pred = func.body.blocks.get_mut(&pred_id).unwrap();
149 pred.instructions.extend(new_instr_ids);
150 pred.instructions.extend(block_instr_ids);
151 pred.terminal = block_terminal;
152
153 merged.merge(*block_id, pred_id);
155 func.body.blocks.shift_remove(block_id);
156 }
157
158 for block in func.body.blocks.values_mut() {
160 for phi in &mut block.phis {
161 let updates: Vec<_> = phi
162 .operands
163 .iter()
164 .filter_map(|(pred_id, operand)| {
165 let mapped = merged.get(*pred_id);
166 if mapped != *pred_id {
167 Some((*pred_id, mapped, operand.clone()))
168 } else {
169 None
170 }
171 })
172 .collect();
173 for (old_id, new_id, operand) in updates {
174 phi.operands.shift_remove(&old_id);
175 phi.operands.insert(new_id, operand);
176 }
177 }
178 }
179
180 mark_predecessors(&mut func.body);
181
182 for block in func.body.blocks.values_mut() {
184 visitors::map_terminal_successors(&mut block.terminal, &mut |block_id| {
185 merged.get(block_id)
186 });
187 }
188}
189
190struct MergedBlocks {
192 map: HashMap<BlockId, BlockId>,
193}
194
195impl MergedBlocks {
196 fn new() -> Self {
197 Self {
198 map: HashMap::new(),
199 }
200 }
201
202 fn merge(&mut self, block: BlockId, into: BlockId) {
204 let target = self.get(into);
205 self.map.insert(block, target);
206 }
207
208 fn get(&self, block: BlockId) -> BlockId {
211 let mut current = block;
212 while let Some(&target) = self.map.get(¤t) {
213 current = target;
214 }
215 current
216 }
217}