Skip to main content

react_compiler_inference/
build_reactive_scope_terminals_hir.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! Builds reactive scope terminals in the HIR.
7//!
8//! Given a function whose reactive scope ranges have been correctly aligned and
9//! merged, this pass rewrites blocks to introduce ReactiveScopeTerminals and
10//! their fallthrough blocks.
11//!
12//! Ported from TypeScript `src/HIR/BuildReactiveScopeTerminalsHIR.ts`.
13
14use std::collections::HashMap;
15use std::collections::HashSet;
16
17use indexmap::IndexMap;
18use react_compiler_hir::BasicBlock;
19use react_compiler_hir::BlockId;
20use react_compiler_hir::EvaluationOrder;
21use react_compiler_hir::GotoVariant;
22use react_compiler_hir::HirFunction;
23use react_compiler_hir::IdentifierId;
24use react_compiler_hir::ScopeId;
25use react_compiler_hir::Terminal;
26use react_compiler_hir::environment::Environment;
27use react_compiler_hir::visitors::each_instruction_lvalue_ids;
28use react_compiler_hir::visitors::each_instruction_operand_ids;
29use react_compiler_hir::visitors::each_terminal_operand_ids;
30use react_compiler_lowering::get_reverse_postordered_blocks;
31use react_compiler_lowering::mark_instruction_ids;
32use react_compiler_lowering::mark_predecessors;
33
34// =============================================================================
35// getScopes
36// =============================================================================
37
38/// Collect all unique scopes from places in the function that have non-empty ranges.
39/// Corresponds to TS `getScopes(fn)`.
40fn get_scopes(func: &HirFunction, env: &Environment) -> Vec<ScopeId> {
41    let mut scope_ids: HashSet<ScopeId> = HashSet::new();
42
43    let mut visit_place = |identifier_id: IdentifierId| {
44        if let Some(scope_id) = env.identifiers[identifier_id.0 as usize].scope {
45            let range = &env.scopes[scope_id.0 as usize].range;
46            if range.start != range.end {
47                scope_ids.insert(scope_id);
48            }
49        }
50    };
51
52    for (_block_id, block) in &func.body.blocks {
53        for &instr_id in &block.instructions {
54            let instr = &func.instructions[instr_id.0 as usize];
55            // lvalues
56            for id in each_instruction_lvalue_ids(instr) {
57                visit_place(id);
58            }
59            // operands
60            for id in each_instruction_operand_ids(instr, env) {
61                visit_place(id);
62            }
63        }
64        // terminal operands
65        for id in each_terminal_operand_ids(&block.terminal) {
66            visit_place(id);
67        }
68    }
69
70    scope_ids.into_iter().collect()
71}
72
73// =============================================================================
74// TerminalRewriteInfo
75// =============================================================================
76
77enum TerminalRewriteInfo {
78    StartScope {
79        block_id: BlockId,
80        fallthrough_id: BlockId,
81        instr_id: EvaluationOrder,
82        scope_id: ScopeId,
83    },
84    EndScope {
85        instr_id: EvaluationOrder,
86        fallthrough_id: BlockId,
87    },
88}
89
90impl TerminalRewriteInfo {
91    fn instr_id(&self) -> EvaluationOrder {
92        match self {
93            TerminalRewriteInfo::StartScope { instr_id, .. } => *instr_id,
94            TerminalRewriteInfo::EndScope { instr_id, .. } => *instr_id,
95        }
96    }
97}
98
99// =============================================================================
100// collectScopeRewrites
101// =============================================================================
102
103/// Collect all scope rewrites by traversing scopes in pre-order.
104fn collect_scope_rewrites(func: &HirFunction, env: &mut Environment) -> Vec<TerminalRewriteInfo> {
105    let scope_ids = get_scopes(func, env);
106
107    // Sort: ascending by start, descending by end for ties
108    let mut items: Vec<ScopeId> = scope_ids;
109    items.sort_by(|a, b| {
110        let a_range = &env.scopes[a.0 as usize].range;
111        let b_range = &env.scopes[b.0 as usize].range;
112        let start_diff = a_range.start.0.cmp(&b_range.start.0);
113        if start_diff != std::cmp::Ordering::Equal {
114            return start_diff;
115        }
116        b_range.end.0.cmp(&a_range.end.0)
117    });
118
119    let mut rewrites: Vec<TerminalRewriteInfo> = Vec::new();
120    let mut fallthroughs: HashMap<ScopeId, BlockId> = HashMap::new();
121    let mut active_items: Vec<ScopeId> = Vec::new();
122
123    for i in 0..items.len() {
124        let curr = items[i];
125        let curr_start = env.scopes[curr.0 as usize].range.start;
126        let curr_end = env.scopes[curr.0 as usize].range.end;
127
128        // Pop active items that are disjoint with current
129        let mut j = active_items.len();
130        while j > 0 {
131            j -= 1;
132            let maybe_parent = active_items[j];
133            let parent_end = env.scopes[maybe_parent.0 as usize].range.end;
134            let disjoint = curr_start >= parent_end;
135            let nested = curr_end <= parent_end;
136            assert!(
137                disjoint || nested,
138                "Invalid nesting in program blocks or scopes"
139            );
140            if disjoint {
141                // Exit this scope
142                let fallthrough_id = *fallthroughs
143                    .get(&maybe_parent)
144                    .expect("Expected scope to exist");
145                let end_instr_id = env.scopes[maybe_parent.0 as usize].range.end;
146                rewrites.push(TerminalRewriteInfo::EndScope {
147                    instr_id: end_instr_id,
148                    fallthrough_id,
149                });
150                active_items.truncate(j);
151            } else {
152                break;
153            }
154        }
155
156        // Enter scope
157        let block_id = env.next_block_id();
158        let fallthrough_id = env.next_block_id();
159        let start_instr_id = env.scopes[curr.0 as usize].range.start;
160        rewrites.push(TerminalRewriteInfo::StartScope {
161            block_id,
162            fallthrough_id,
163            instr_id: start_instr_id,
164            scope_id: curr,
165        });
166        fallthroughs.insert(curr, fallthrough_id);
167        active_items.push(curr);
168    }
169
170    // Exit remaining active items
171    while let Some(curr) = active_items.pop() {
172        let fallthrough_id = *fallthroughs.get(&curr).expect("Expected scope to exist");
173        let end_instr_id = env.scopes[curr.0 as usize].range.end;
174        rewrites.push(TerminalRewriteInfo::EndScope {
175            instr_id: end_instr_id,
176            fallthrough_id,
177        });
178    }
179
180    rewrites
181}
182
183// =============================================================================
184// handleRewrite
185// =============================================================================
186
187struct RewriteContext {
188    next_block_id: BlockId,
189    next_preds: Vec<BlockId>,
190    instr_slice_idx: usize,
191    rewrites: Vec<BasicBlock>,
192}
193
194fn handle_rewrite(
195    terminal_info: &TerminalRewriteInfo,
196    idx: usize,
197    source_block: &BasicBlock,
198    context: &mut RewriteContext,
199) {
200    let terminal: Terminal = match terminal_info {
201        TerminalRewriteInfo::StartScope {
202            block_id,
203            fallthrough_id,
204            instr_id,
205            scope_id,
206        } => Terminal::Scope {
207            fallthrough: *fallthrough_id,
208            block: *block_id,
209            scope: *scope_id,
210            id: *instr_id,
211            loc: None,
212        },
213        TerminalRewriteInfo::EndScope {
214            instr_id,
215            fallthrough_id,
216        } => Terminal::Goto {
217            variant: GotoVariant::Break,
218            block: *fallthrough_id,
219            id: *instr_id,
220            loc: None,
221        },
222    };
223
224    let curr_block_id = context.next_block_id;
225    let mut preds = indexmap::IndexSet::new();
226    for &p in &context.next_preds {
227        preds.insert(p);
228    }
229
230    context.rewrites.push(BasicBlock {
231        kind: source_block.kind,
232        id: curr_block_id,
233        instructions: source_block.instructions[context.instr_slice_idx..idx].to_vec(),
234        preds,
235        // Only the first rewrite should reuse source block phis
236        phis: if context.rewrites.is_empty() {
237            source_block.phis.clone()
238        } else {
239            Vec::new()
240        },
241        terminal,
242    });
243
244    context.next_preds = vec![curr_block_id];
245    context.next_block_id = match terminal_info {
246        TerminalRewriteInfo::StartScope { block_id, .. } => *block_id,
247        TerminalRewriteInfo::EndScope { fallthrough_id, .. } => *fallthrough_id,
248    };
249    context.instr_slice_idx = idx;
250}
251
252// =============================================================================
253// Public API
254// =============================================================================
255
256/// Builds reactive scope terminals in the HIR.
257///
258/// This pass assumes that all program blocks are properly nested with respect
259/// to fallthroughs. Given a function whose reactive scope ranges have been
260/// correctly aligned and merged, this pass rewrites blocks to introduce
261/// ReactiveScopeTerminals and their fallthrough blocks.
262pub fn build_reactive_scope_terminals_hir(func: &mut HirFunction, env: &mut Environment) {
263    // Step 1: Collect rewrites
264    let mut queued_rewrites = collect_scope_rewrites(func, env);
265
266    // Step 2: Apply rewrites by splitting blocks
267    let mut rewritten_final_blocks: HashMap<BlockId, BlockId> = HashMap::new();
268    let mut next_blocks: IndexMap<BlockId, BasicBlock> = IndexMap::new();
269
270    // Reverse so we can pop from the end while traversing in ascending order
271    queued_rewrites.reverse();
272
273    for (_block_id, block) in &func.body.blocks {
274        let preds_vec: Vec<BlockId> = block.preds.iter().copied().collect();
275        let mut context = RewriteContext {
276            next_block_id: block.id,
277            rewrites: Vec::new(),
278            next_preds: preds_vec,
279            instr_slice_idx: 0,
280        };
281
282        // Handle queued terminal rewrites at their nearest instruction ID
283        for i in 0..block.instructions.len() + 1 {
284            let instr_id = if i < block.instructions.len() {
285                let instr_idx = block.instructions[i];
286                func.instructions[instr_idx.0 as usize].id
287            } else {
288                block.terminal.evaluation_order()
289            };
290
291            while let Some(rewrite) = queued_rewrites.last() {
292                if rewrite.instr_id() <= instr_id {
293                    // Need to pop before calling handle_rewrite
294                    let rewrite = queued_rewrites.pop().unwrap();
295                    handle_rewrite(&rewrite, i, block, &mut context);
296                } else {
297                    break;
298                }
299            }
300        }
301
302        if !context.rewrites.is_empty() {
303            let mut final_preds = indexmap::IndexSet::new();
304            for &p in &context.next_preds {
305                final_preds.insert(p);
306            }
307            let final_block = BasicBlock {
308                id: context.next_block_id,
309                kind: block.kind,
310                preds: final_preds,
311                terminal: block.terminal.clone(),
312                instructions: block.instructions[context.instr_slice_idx..].to_vec(),
313                phis: Vec::new(),
314            };
315            let final_block_id = final_block.id;
316            context.rewrites.push(final_block);
317            for b in context.rewrites {
318                next_blocks.insert(b.id, b);
319            }
320            rewritten_final_blocks.insert(block.id, final_block_id);
321        } else {
322            next_blocks.insert(block.id, block.clone());
323        }
324    }
325
326    func.body.blocks = next_blocks;
327
328    // Step 3: Repoint phis when they refer to a rewritten block
329    for block in func.body.blocks.values_mut() {
330        for phi in &mut block.phis {
331            let updates: Vec<(BlockId, BlockId)> = phi
332                .operands
333                .keys()
334                .filter_map(|original_id| {
335                    rewritten_final_blocks
336                        .get(original_id)
337                        .map(|new_id| (*original_id, *new_id))
338                })
339                .collect();
340            for (old_id, new_id) in updates {
341                if let Some(value) = phi.operands.shift_remove(&old_id) {
342                    phi.operands.insert(new_id, value);
343                }
344            }
345        }
346    }
347
348    // Step 4: Fixup HIR to restore RPO, correct predecessors, renumber instructions
349    func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions);
350    mark_predecessors(&mut func.body);
351    mark_instruction_ids(&mut func.body, &mut func.instructions);
352
353    // Step 5: Fix scope and identifier ranges to account for renumbered instructions
354    fix_scope_and_identifier_ranges(func, env);
355}
356
357/// Fix scope ranges after instruction renumbering.
358/// Scope ranges should always align to start at the 'scope' terminal
359/// and end at the first instruction of the fallthrough block.
360///
361/// In TS, `identifier.mutableRange` and `scope.range` are the same object
362/// reference (after InferReactiveScopeVariables). When scope.range is updated,
363/// all identifiers with that scope automatically see the new range.
364/// BUT: after MergeOverlappingReactiveScopesHIR, repointed identifiers have
365/// mutableRange pointing to the OLD scope's range, NOT the root scope's range.
366/// So only identifiers whose mutableRange matches their scope's pre-renumbering
367/// range should be updated.
368///
369/// Corresponds to TS `fixScopeAndIdentifierRanges`.
370fn fix_scope_and_identifier_ranges(func: &HirFunction, env: &mut Environment) {
371    // Save original scope ranges before updating them. In TS,
372    // identifier.mutableRange and scope.range may or may not be the same
373    // JS object. Only identifiers whose mutableRange shares the same object
374    // reference as scope.range see the update automatically. We simulate
375    // this by only syncing identifiers whose mutableRange matches the
376    // scope's pre-update range.
377    let original_scope_ranges: Vec<react_compiler_hir::MutableRange> = env
378        .scopes
379        .iter()
380        .map(|s| s.range.clone())
381        .collect();
382
383    for (_block_id, block) in &func.body.blocks {
384        match &block.terminal {
385            Terminal::Scope {
386                fallthrough,
387                scope,
388                id,
389                ..
390            }
391            | Terminal::PrunedScope {
392                fallthrough,
393                scope,
394                id,
395                ..
396            } => {
397                let fallthrough_block = func.body.blocks.get(fallthrough).unwrap();
398                let first_id = if !fallthrough_block.instructions.is_empty() {
399                    func.instructions[fallthrough_block.instructions[0].0 as usize].id
400                } else {
401                    fallthrough_block.terminal.evaluation_order()
402                };
403                env.scopes[scope.0 as usize].range.start = *id;
404                env.scopes[scope.0 as usize].range.end = first_id;
405            }
406            _ => {}
407        }
408    }
409
410    // Sync identifier mutable ranges with their scope ranges, but ONLY
411    // for identifiers whose mutableRange has the same identity as their
412    // scope's ORIGINAL range (before the updates above). In TS,
413    // identifier.mutableRange and scope.range are only the same JS object
414    // for identifiers that were the canonical representative when the scope
415    // was created. After MergeOverlappingReactiveScopesHIR, repointed
416    // identifiers have mutableRange pointing to the OLD scope's range,
417    // not the root scope's range — so they should NOT be synced here.
418    for ident in &mut env.identifiers {
419        if let Some(scope_id) = ident.scope {
420            let original = &original_scope_ranges[scope_id.0 as usize];
421            if ident.mutable_range.same_range(original) {
422                let scope_range = &env.scopes[scope_id.0 as usize].range;
423                ident.mutable_range.start = scope_range.start;
424                ident.mutable_range.end = scope_range.end;
425            }
426        }
427    }
428}