react_compiler_inference/
build_reactive_scope_terminals_hir.rs1use std::collections::HashMap;
15use std::collections::HashSet;
16
17use indexmap::IndexMap;
18use react_compiler_hir::BasicBlock;
19use react_compiler_hir::BlockId;
20use react_compiler_hir::EvaluationOrder;
21use react_compiler_hir::GotoVariant;
22use react_compiler_hir::HirFunction;
23use react_compiler_hir::IdentifierId;
24use react_compiler_hir::ScopeId;
25use react_compiler_hir::Terminal;
26use react_compiler_hir::environment::Environment;
27use react_compiler_hir::visitors::each_instruction_lvalue_ids;
28use react_compiler_hir::visitors::each_instruction_operand_ids;
29use react_compiler_hir::visitors::each_terminal_operand_ids;
30use react_compiler_lowering::get_reverse_postordered_blocks;
31use react_compiler_lowering::mark_instruction_ids;
32use react_compiler_lowering::mark_predecessors;
33
34fn get_scopes(func: &HirFunction, env: &Environment) -> Vec<ScopeId> {
41 let mut scope_ids: HashSet<ScopeId> = HashSet::new();
42
43 let mut visit_place = |identifier_id: IdentifierId| {
44 if let Some(scope_id) = env.identifiers[identifier_id.0 as usize].scope {
45 let range = &env.scopes[scope_id.0 as usize].range;
46 if range.start != range.end {
47 scope_ids.insert(scope_id);
48 }
49 }
50 };
51
52 for (_block_id, block) in &func.body.blocks {
53 for &instr_id in &block.instructions {
54 let instr = &func.instructions[instr_id.0 as usize];
55 for id in each_instruction_lvalue_ids(instr) {
57 visit_place(id);
58 }
59 for id in each_instruction_operand_ids(instr, env) {
61 visit_place(id);
62 }
63 }
64 for id in each_terminal_operand_ids(&block.terminal) {
66 visit_place(id);
67 }
68 }
69
70 scope_ids.into_iter().collect()
71}
72
73enum TerminalRewriteInfo {
78 StartScope {
79 block_id: BlockId,
80 fallthrough_id: BlockId,
81 instr_id: EvaluationOrder,
82 scope_id: ScopeId,
83 },
84 EndScope {
85 instr_id: EvaluationOrder,
86 fallthrough_id: BlockId,
87 },
88}
89
90impl TerminalRewriteInfo {
91 fn instr_id(&self) -> EvaluationOrder {
92 match self {
93 TerminalRewriteInfo::StartScope { instr_id, .. } => *instr_id,
94 TerminalRewriteInfo::EndScope { instr_id, .. } => *instr_id,
95 }
96 }
97}
98
99fn collect_scope_rewrites(func: &HirFunction, env: &mut Environment) -> Vec<TerminalRewriteInfo> {
105 let scope_ids = get_scopes(func, env);
106
107 let mut items: Vec<ScopeId> = scope_ids;
109 items.sort_by(|a, b| {
110 let a_range = &env.scopes[a.0 as usize].range;
111 let b_range = &env.scopes[b.0 as usize].range;
112 let start_diff = a_range.start.0.cmp(&b_range.start.0);
113 if start_diff != std::cmp::Ordering::Equal {
114 return start_diff;
115 }
116 b_range.end.0.cmp(&a_range.end.0)
117 });
118
119 let mut rewrites: Vec<TerminalRewriteInfo> = Vec::new();
120 let mut fallthroughs: HashMap<ScopeId, BlockId> = HashMap::new();
121 let mut active_items: Vec<ScopeId> = Vec::new();
122
123 for i in 0..items.len() {
124 let curr = items[i];
125 let curr_start = env.scopes[curr.0 as usize].range.start;
126 let curr_end = env.scopes[curr.0 as usize].range.end;
127
128 let mut j = active_items.len();
130 while j > 0 {
131 j -= 1;
132 let maybe_parent = active_items[j];
133 let parent_end = env.scopes[maybe_parent.0 as usize].range.end;
134 let disjoint = curr_start >= parent_end;
135 let nested = curr_end <= parent_end;
136 assert!(
137 disjoint || nested,
138 "Invalid nesting in program blocks or scopes"
139 );
140 if disjoint {
141 let fallthrough_id = *fallthroughs
143 .get(&maybe_parent)
144 .expect("Expected scope to exist");
145 let end_instr_id = env.scopes[maybe_parent.0 as usize].range.end;
146 rewrites.push(TerminalRewriteInfo::EndScope {
147 instr_id: end_instr_id,
148 fallthrough_id,
149 });
150 active_items.truncate(j);
151 } else {
152 break;
153 }
154 }
155
156 let block_id = env.next_block_id();
158 let fallthrough_id = env.next_block_id();
159 let start_instr_id = env.scopes[curr.0 as usize].range.start;
160 rewrites.push(TerminalRewriteInfo::StartScope {
161 block_id,
162 fallthrough_id,
163 instr_id: start_instr_id,
164 scope_id: curr,
165 });
166 fallthroughs.insert(curr, fallthrough_id);
167 active_items.push(curr);
168 }
169
170 while let Some(curr) = active_items.pop() {
172 let fallthrough_id = *fallthroughs.get(&curr).expect("Expected scope to exist");
173 let end_instr_id = env.scopes[curr.0 as usize].range.end;
174 rewrites.push(TerminalRewriteInfo::EndScope {
175 instr_id: end_instr_id,
176 fallthrough_id,
177 });
178 }
179
180 rewrites
181}
182
183struct RewriteContext {
188 next_block_id: BlockId,
189 next_preds: Vec<BlockId>,
190 instr_slice_idx: usize,
191 rewrites: Vec<BasicBlock>,
192}
193
194fn handle_rewrite(
195 terminal_info: &TerminalRewriteInfo,
196 idx: usize,
197 source_block: &BasicBlock,
198 context: &mut RewriteContext,
199) {
200 let terminal: Terminal = match terminal_info {
201 TerminalRewriteInfo::StartScope {
202 block_id,
203 fallthrough_id,
204 instr_id,
205 scope_id,
206 } => Terminal::Scope {
207 fallthrough: *fallthrough_id,
208 block: *block_id,
209 scope: *scope_id,
210 id: *instr_id,
211 loc: None,
212 },
213 TerminalRewriteInfo::EndScope {
214 instr_id,
215 fallthrough_id,
216 } => Terminal::Goto {
217 variant: GotoVariant::Break,
218 block: *fallthrough_id,
219 id: *instr_id,
220 loc: None,
221 },
222 };
223
224 let curr_block_id = context.next_block_id;
225 let mut preds = indexmap::IndexSet::new();
226 for &p in &context.next_preds {
227 preds.insert(p);
228 }
229
230 context.rewrites.push(BasicBlock {
231 kind: source_block.kind,
232 id: curr_block_id,
233 instructions: source_block.instructions[context.instr_slice_idx..idx].to_vec(),
234 preds,
235 phis: if context.rewrites.is_empty() {
237 source_block.phis.clone()
238 } else {
239 Vec::new()
240 },
241 terminal,
242 });
243
244 context.next_preds = vec![curr_block_id];
245 context.next_block_id = match terminal_info {
246 TerminalRewriteInfo::StartScope { block_id, .. } => *block_id,
247 TerminalRewriteInfo::EndScope { fallthrough_id, .. } => *fallthrough_id,
248 };
249 context.instr_slice_idx = idx;
250}
251
252pub fn build_reactive_scope_terminals_hir(func: &mut HirFunction, env: &mut Environment) {
263 let mut queued_rewrites = collect_scope_rewrites(func, env);
265
266 let mut rewritten_final_blocks: HashMap<BlockId, BlockId> = HashMap::new();
268 let mut next_blocks: IndexMap<BlockId, BasicBlock> = IndexMap::new();
269
270 queued_rewrites.reverse();
272
273 for (_block_id, block) in &func.body.blocks {
274 let preds_vec: Vec<BlockId> = block.preds.iter().copied().collect();
275 let mut context = RewriteContext {
276 next_block_id: block.id,
277 rewrites: Vec::new(),
278 next_preds: preds_vec,
279 instr_slice_idx: 0,
280 };
281
282 for i in 0..block.instructions.len() + 1 {
284 let instr_id = if i < block.instructions.len() {
285 let instr_idx = block.instructions[i];
286 func.instructions[instr_idx.0 as usize].id
287 } else {
288 block.terminal.evaluation_order()
289 };
290
291 while let Some(rewrite) = queued_rewrites.last() {
292 if rewrite.instr_id() <= instr_id {
293 let rewrite = queued_rewrites.pop().unwrap();
295 handle_rewrite(&rewrite, i, block, &mut context);
296 } else {
297 break;
298 }
299 }
300 }
301
302 if !context.rewrites.is_empty() {
303 let mut final_preds = indexmap::IndexSet::new();
304 for &p in &context.next_preds {
305 final_preds.insert(p);
306 }
307 let final_block = BasicBlock {
308 id: context.next_block_id,
309 kind: block.kind,
310 preds: final_preds,
311 terminal: block.terminal.clone(),
312 instructions: block.instructions[context.instr_slice_idx..].to_vec(),
313 phis: Vec::new(),
314 };
315 let final_block_id = final_block.id;
316 context.rewrites.push(final_block);
317 for b in context.rewrites {
318 next_blocks.insert(b.id, b);
319 }
320 rewritten_final_blocks.insert(block.id, final_block_id);
321 } else {
322 next_blocks.insert(block.id, block.clone());
323 }
324 }
325
326 func.body.blocks = next_blocks;
327
328 for block in func.body.blocks.values_mut() {
330 for phi in &mut block.phis {
331 let updates: Vec<(BlockId, BlockId)> = phi
332 .operands
333 .keys()
334 .filter_map(|original_id| {
335 rewritten_final_blocks
336 .get(original_id)
337 .map(|new_id| (*original_id, *new_id))
338 })
339 .collect();
340 for (old_id, new_id) in updates {
341 if let Some(value) = phi.operands.shift_remove(&old_id) {
342 phi.operands.insert(new_id, value);
343 }
344 }
345 }
346 }
347
348 func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions);
350 mark_predecessors(&mut func.body);
351 mark_instruction_ids(&mut func.body, &mut func.instructions);
352
353 fix_scope_and_identifier_ranges(func, env);
355}
356
357fn fix_scope_and_identifier_ranges(func: &HirFunction, env: &mut Environment) {
371 let original_scope_ranges: Vec<react_compiler_hir::MutableRange> = env
378 .scopes
379 .iter()
380 .map(|s| s.range.clone())
381 .collect();
382
383 for (_block_id, block) in &func.body.blocks {
384 match &block.terminal {
385 Terminal::Scope {
386 fallthrough,
387 scope,
388 id,
389 ..
390 }
391 | Terminal::PrunedScope {
392 fallthrough,
393 scope,
394 id,
395 ..
396 } => {
397 let fallthrough_block = func.body.blocks.get(fallthrough).unwrap();
398 let first_id = if !fallthrough_block.instructions.is_empty() {
399 func.instructions[fallthrough_block.instructions[0].0 as usize].id
400 } else {
401 fallthrough_block.terminal.evaluation_order()
402 };
403 env.scopes[scope.0 as usize].range.start = *id;
404 env.scopes[scope.0 as usize].range.end = first_id;
405 }
406 _ => {}
407 }
408 }
409
410 for ident in &mut env.identifiers {
419 if let Some(scope_id) = ident.scope {
420 let original = &original_scope_ranges[scope_id.0 as usize];
421 if ident.mutable_range.same_range(original) {
422 let scope_range = &env.scopes[scope_id.0 as usize].range;
423 ident.mutable_range.start = scope_range.start;
424 ident.mutable_range.end = scope_range.end;
425 }
426 }
427 }
428}