react_compiler_optimization/
merge_consecutive_blocks.rs1use rustc_hash::{FxHashMap, FxHashSet};
17
18use react_compiler_hir::visitors;
19use react_compiler_hir::{
20 AliasingEffect, BlockId, BlockKind, Effect, GENERATED_SOURCE, HirFunction, Instruction,
21 InstructionId, InstructionValue, Place, Terminal,
22};
23use react_compiler_lowering::mark_predecessors;
24use react_compiler_ssa::enter_ssa::placeholder_function;
25
26pub fn merge_consecutive_blocks(func: &mut HirFunction, functions: &mut [HirFunction]) {
28 let inner_func_ids: Vec<usize> = func
30 .body
31 .blocks
32 .values()
33 .flat_map(|block| block.instructions.iter())
34 .filter_map(|instr_id| {
35 let instr = &func.instructions[instr_id.0 as usize];
36 match &instr.value {
37 InstructionValue::FunctionExpression { lowered_func, .. }
38 | InstructionValue::ObjectMethod { lowered_func, .. } => {
39 Some(lowered_func.func.0 as usize)
40 }
41 _ => None,
42 }
43 })
44 .collect();
45
46 for func_id in inner_func_ids {
48 let mut inner_func = std::mem::replace(&mut functions[func_id], placeholder_function());
51 merge_consecutive_blocks(&mut inner_func, functions);
52 functions[func_id] = inner_func;
53 }
54
55 let mut fallthrough_blocks: FxHashSet<BlockId> = FxHashSet::default();
57 for block in func.body.blocks.values() {
58 if let Some(ft) = visitors::terminal_fallthrough(&block.terminal) {
59 fallthrough_blocks.insert(ft);
60 }
61 }
62
63 let mut merged = MergedBlocks::new();
64
65 let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
67
68 for block_id in &block_ids {
69 let block = match func.body.blocks.get(block_id) {
70 Some(b) => b,
71 None => continue, };
73
74 if block.preds.len() != 1
75 || block.kind != BlockKind::Block
76 || fallthrough_blocks.contains(block_id)
77 {
78 continue;
79 }
80
81 let original_pred_id = *block.preds.iter().next().unwrap();
82 let pred_id = merged.get(original_pred_id);
83
84 let pred_is_mergeable = func
86 .body
87 .blocks
88 .get(&pred_id)
89 .map(|p| matches!(p.terminal, Terminal::Goto { .. }) && p.kind == BlockKind::Block)
90 .unwrap_or(false);
91
92 if !pred_is_mergeable {
93 continue;
94 }
95
96 let eval_order = func.body.blocks[&pred_id].terminal.evaluation_order();
98
99 let phis: Vec<_> = block
101 .phis
102 .iter()
103 .map(|phi| {
104 assert_eq!(
105 phi.operands.len(),
106 1,
107 "Found a block with a single predecessor but where a phi has multiple ({}) operands",
108 phi.operands.len()
109 );
110 let operand = phi.operands.values().next().unwrap().clone();
111 (phi.place.identifier, operand)
112 })
113 .collect();
114 let block_instr_ids = block.instructions.clone();
115 let block_terminal = block.terminal.clone();
116
117 let mut new_instr_ids = Vec::new();
119 for (identifier, operand) in phis {
120 let lvalue = Place {
121 identifier,
122 effect: Effect::ConditionallyMutate,
123 reactive: false,
124 loc: GENERATED_SOURCE,
125 };
126 let instr = Instruction {
127 id: eval_order,
128 lvalue: lvalue.clone(),
129 value: InstructionValue::LoadLocal {
130 place: operand.clone(),
131 loc: GENERATED_SOURCE,
132 },
133 loc: GENERATED_SOURCE,
134 effects: Some(vec![AliasingEffect::Alias {
135 from: operand,
136 into: lvalue,
137 }]),
138 };
139 let instr_id = InstructionId(func.instructions.len() as u32);
140 func.instructions.push(instr);
141 new_instr_ids.push(instr_id);
142 }
143
144 let pred = func.body.blocks.get_mut(&pred_id).unwrap();
146 pred.instructions.extend(new_instr_ids);
147 pred.instructions.extend(block_instr_ids);
148 pred.terminal = block_terminal;
149
150 merged.merge(*block_id, pred_id);
152 func.body.blocks.shift_remove(block_id);
153 }
154
155 for block in func.body.blocks.values_mut() {
157 for phi in &mut block.phis {
158 let updates: Vec<_> = phi
159 .operands
160 .iter()
161 .filter_map(|(pred_id, operand)| {
162 let mapped = merged.get(*pred_id);
163 if mapped != *pred_id {
164 Some((*pred_id, mapped, operand.clone()))
165 } else {
166 None
167 }
168 })
169 .collect();
170 for (old_id, new_id, operand) in updates {
171 phi.operands.shift_remove(&old_id);
172 phi.operands.insert(new_id, operand);
173 }
174 }
175 }
176
177 mark_predecessors(&mut func.body);
178
179 for block in func.body.blocks.values_mut() {
181 visitors::map_terminal_successors(&mut block.terminal, &mut |block_id| {
182 merged.get(block_id)
183 });
184 }
185}
186
187struct MergedBlocks {
189 map: FxHashMap<BlockId, BlockId>,
190}
191
192impl MergedBlocks {
193 fn new() -> Self {
194 Self {
195 map: FxHashMap::default(),
196 }
197 }
198
199 fn merge(&mut self, block: BlockId, into: BlockId) {
201 let target = self.get(into);
202 self.map.insert(block, target);
203 }
204
205 fn get(&self, block: BlockId) -> BlockId {
208 let mut current = block;
209 while let Some(&target) = self.map.get(¤t) {
210 current = target;
211 }
212 current
213 }
214}