Skip to main content

react_compiler_inference/
infer_reactive_scope_variables.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! Infers which variables belong to reactive scopes.
//!
//! Ported from TypeScript `src/ReactiveScopes/InferReactiveScopeVariables.ts`.
//!
//! This is the 1st of 4 passes that determine how to break a function into
//! discrete reactive scopes (independently memoizable units of code):
//! 1. InferReactiveScopeVariables (this pass, on HIR) determines which operands
//!    mutate together and assigns each such group its own reactive scope.
//! 2. AlignReactiveScopesToBlockScopes aligns reactive scopes to block scopes.
//! 3. MergeOverlappingReactiveScopes ensures scopes do not overlap.
//! 4. BuildReactiveBlocks groups the statements for each scope.

18use std::collections::HashMap;
19
20use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
21use react_compiler_hir::environment::Environment;
22use react_compiler_hir::visitors;
23use react_compiler_hir::{
24    DeclarationId, EvaluationOrder, HirFunction, IdentifierId,
25    InstructionValue, Pattern, Position, SourceLocation,
26};
27use react_compiler_utils::DisjointSet;

// =============================================================================
// Public API
// =============================================================================

33/// Infer reactive scope variables for a function.
34///
35/// For each mutable variable, infers a reactive scope which will construct that
36/// variable. Variables that co-mutate are assigned to the same reactive scope.
37///
38/// Corresponds to TS `inferReactiveScopeVariables(fn: HIRFunction): void`.
39pub fn infer_reactive_scope_variables(func: &mut HirFunction, env: &mut Environment) -> Result<(), CompilerDiagnostic> {
40    // Phase 1: find disjoint sets of co-mutating identifiers
41    let mut scope_identifiers = find_disjoint_mutable_values(func, env);
42
43    // Phase 2: assign scopes
44    // Maps each group root identifier to the ScopeId assigned to that group.
45    let mut scopes: HashMap<IdentifierId, ScopeState> = HashMap::new();
46
47    scope_identifiers.for_each(|identifier_id, group_id| {
48        let ident_range = env.identifiers[identifier_id.0 as usize].mutable_range.clone();
49        let ident_loc = env.identifiers[identifier_id.0 as usize].loc;
50
51        let state = scopes.entry(group_id).or_insert_with(|| {
52            let scope_id = env.next_scope_id();
53            // Initialize scope range from the first member
54            let scope = &mut env.scopes[scope_id.0 as usize];
55            scope.range = ident_range.clone();
56            ScopeState {
57                scope_id,
58                loc: ident_loc,
59            }
60        });
61
62        // Update scope range
63        let scope = &mut env.scopes[state.scope_id.0 as usize];
64
65        // If this is not the first identifier (scope was already created), merge ranges
66        if scope.range.start != ident_range.start || scope.range.end != ident_range.end {
67            if scope.range.start == EvaluationOrder(0) {
68                scope.range.start = ident_range.start;
69            } else if ident_range.start != EvaluationOrder(0) {
70                scope.range.start =
71                    EvaluationOrder(scope.range.start.0.min(ident_range.start.0));
72            }
73            scope.range.end = EvaluationOrder(scope.range.end.0.max(ident_range.end.0));
74        }
75
76        // Merge location
77        state.loc = merge_location(state.loc, ident_loc);
78
79        // Assign the scope to this identifier
80        let scope_id = state.scope_id;
81        env.identifiers[identifier_id.0 as usize].scope = Some(scope_id);
82    });
83
84    // Set loc on each scope
85    for (_group_id, state) in &scopes {
86        env.scopes[state.scope_id.0 as usize].loc = state.loc;
87    }
88
89    // Update each identifier's mutable_range to match its scope's range
90    for (&_identifier_id, state) in &scopes {
91        let scope_range = env.scopes[state.scope_id.0 as usize].range.clone();
92        // Find all identifiers with this scope and update their mutable_range
93        // We iterate through all identifiers and check their scope
94        for ident in &mut env.identifiers {
95            if ident.scope == Some(state.scope_id) {
96                ident.mutable_range = scope_range.clone();
97            }
98        }
99    }
100
101    // Validate scope ranges
102    let mut max_instruction = EvaluationOrder(0);
103    for (_block_id, block) in &func.body.blocks {
104        for instr_id in &block.instructions {
105            let instr = &func.instructions[instr_id.0 as usize];
106            max_instruction = EvaluationOrder(max_instruction.0.max(instr.id.0));
107        }
108        max_instruction = EvaluationOrder(max_instruction.0.max(block.terminal.evaluation_order().0));
109    }
110
111    for (_group_id, state) in &scopes {
112        let scope = &env.scopes[state.scope_id.0 as usize];
113        if scope.range.start == EvaluationOrder(0)
114            || scope.range.end == EvaluationOrder(0)
115            || max_instruction == EvaluationOrder(0)
116            || scope.range.end.0 > max_instruction.0 + 1
117        {
118            return Err(CompilerDiagnostic::new(
119                ErrorCategory::Invariant,
120                &format!(
121                    "Invalid mutable range for scope: Scope @{} has range [{}:{}] but the valid range is [1:{}]",
122                    scope.id.0,
123                    scope.range.start.0,
124                    scope.range.end.0,
125                    max_instruction.0 + 1,
126                ),
127                None,
128            ));
129        }
130    }
131
132    Ok(())
133}
134
135struct ScopeState {
136    scope_id: react_compiler_hir::ScopeId,
137    loc: Option<SourceLocation>,
138}
139
140/// Merge two source locations, preferring non-None values.
141/// Corresponds to TS `mergeLocation`.
142fn merge_location(
143    l: Option<SourceLocation>,
144    r: Option<SourceLocation>,
145) -> Option<SourceLocation> {
146    match (l, r) {
147        (None, r) => r,
148        (l, None) => l,
149        (Some(l), Some(r)) => Some(SourceLocation {
150            start: Position {
151                line: l.start.line.min(r.start.line),
152                column: l.start.column.min(r.start.column),
153                index: match (l.start.index, r.start.index) {
154                    (Some(a), Some(b)) => Some(a.min(b)),
155                    (a, b) => a.or(b),
156                },
157            },
158            end: Position {
159                line: l.end.line.max(r.end.line),
160                column: l.end.column.max(r.end.column),
161                index: match (l.end.index, r.end.index) {
162                    (Some(a), Some(b)) => Some(a.max(b)),
163                    (a, b) => a.or(b),
164                },
165            },
166        }),
167    }
168}

// =============================================================================
// is_mutable / in_range helpers
// =============================================================================
//
// These have no standalone port. The equivalent checks are written as
// `mutable_range.contains(..)` tests inside `find_disjoint_mutable_values`.

// =============================================================================
// may_allocate
// =============================================================================

179/// Check if an instruction may allocate. Corresponds to TS `mayAllocate`.
180fn may_allocate(value: &InstructionValue, lvalue_type_is_primitive: bool) -> bool {
181    match value {
182        InstructionValue::Destructure { lvalue, .. } => {
183            visitors::does_pattern_contain_spread_element(&lvalue.pattern)
184        }
185        InstructionValue::PostfixUpdate { .. }
186        | InstructionValue::PrefixUpdate { .. }
187        | InstructionValue::Await { .. }
188        | InstructionValue::DeclareLocal { .. }
189        | InstructionValue::DeclareContext { .. }
190        | InstructionValue::StoreLocal { .. }
191        | InstructionValue::LoadGlobal { .. }
192        | InstructionValue::MetaProperty { .. }
193        | InstructionValue::TypeCastExpression { .. }
194        | InstructionValue::LoadLocal { .. }
195        | InstructionValue::LoadContext { .. }
196        | InstructionValue::StoreContext { .. }
197        | InstructionValue::PropertyDelete { .. }
198        | InstructionValue::ComputedLoad { .. }
199        | InstructionValue::ComputedDelete { .. }
200        | InstructionValue::JSXText { .. }
201        | InstructionValue::TemplateLiteral { .. }
202        | InstructionValue::Primitive { .. }
203        | InstructionValue::GetIterator { .. }
204        | InstructionValue::IteratorNext { .. }
205        | InstructionValue::NextPropertyOf { .. }
206        | InstructionValue::Debugger { .. }
207        | InstructionValue::StartMemoize { .. }
208        | InstructionValue::FinishMemoize { .. }
209        | InstructionValue::UnaryExpression { .. }
210        | InstructionValue::BinaryExpression { .. }
211        | InstructionValue::PropertyLoad { .. }
212        | InstructionValue::StoreGlobal { .. } => false,
213
214        InstructionValue::TaggedTemplateExpression { .. }
215        | InstructionValue::CallExpression { .. }
216        | InstructionValue::MethodCall { .. } => !lvalue_type_is_primitive,
217
218        InstructionValue::RegExpLiteral { .. }
219        | InstructionValue::PropertyStore { .. }
220        | InstructionValue::ComputedStore { .. }
221        | InstructionValue::ArrayExpression { .. }
222        | InstructionValue::JsxExpression { .. }
223        | InstructionValue::JsxFragment { .. }
224        | InstructionValue::NewExpression { .. }
225        | InstructionValue::ObjectExpression { .. }
226        | InstructionValue::UnsupportedNode { .. }
227        | InstructionValue::ObjectMethod { .. }
228        | InstructionValue::FunctionExpression { .. } => true,
229    }
230}

// =============================================================================
// Operand collection helpers
// =============================================================================

236/// Collect all Place identifiers from a destructure pattern.
237/// Corresponds to TS `eachPatternOperand`.
238fn each_pattern_operand(pattern: &Pattern) -> Vec<IdentifierId> {
239    visitors::each_pattern_operand(pattern)
240        .into_iter()
241        .map(|p| p.identifier)
242        .collect()
243}
244
245/// Collect all operand identifiers from an instruction value.
246/// Corresponds to TS `eachInstructionValueOperand`.
247fn each_instruction_value_operand(
248    value: &InstructionValue,
249    env: &Environment,
250) -> Vec<IdentifierId> {
251    visitors::each_instruction_value_operand(value, env)
252        .into_iter()
253        .map(|p| p.identifier)
254        .collect()
255}

// =============================================================================
// findDisjointMutableValues
// =============================================================================

261/// Find disjoint sets of co-mutating identifier IDs.
262///
263/// Corresponds to TS `findDisjointMutableValues(fn: HIRFunction): DisjointSet<Identifier>`.
264pub(crate) fn find_disjoint_mutable_values(func: &HirFunction, env: &Environment) -> DisjointSet<IdentifierId> {
265    let mut scope_identifiers = DisjointSet::<IdentifierId>::new();
266    let mut declarations: HashMap<DeclarationId, IdentifierId> = HashMap::new();
267
268    let enable_forest = env.config.enable_forest;
269
270    for (_block_id, block) in &func.body.blocks {
271        // Handle phi nodes
272        for phi in &block.phis {
273            let phi_id = phi.place.identifier;
274            let phi_range = &env.identifiers[phi_id.0 as usize].mutable_range;
275            let phi_decl_id = env.identifiers[phi_id.0 as usize].declaration_id;
276
277            let first_instr_id = block
278                .instructions
279                .first()
280                .map(|iid| func.instructions[iid.0 as usize].id)
281                .unwrap_or(block.terminal.evaluation_order());
282
283            if phi_range.start.0 + 1 != phi_range.end.0
284                && phi_range.end > first_instr_id
285            {
286                let mut operands = vec![phi_id];
287                if let Some(&decl_id) = declarations.get(&phi_decl_id) {
288                    operands.push(decl_id);
289                }
290                for (_pred_id, phi_operand) in &phi.operands {
291                    operands.push(phi_operand.identifier);
292                }
293                scope_identifiers.union(&operands);
294            } else if enable_forest {
295                for (_pred_id, phi_operand) in &phi.operands {
296                    scope_identifiers.union(&[phi_id, phi_operand.identifier]);
297                }
298            }
299        }
300
301        // Handle instructions
302        for instr_id in &block.instructions {
303            let instr = &func.instructions[instr_id.0 as usize];
304            let mut operands: Vec<IdentifierId> = Vec::new();
305
306            let lvalue_id = instr.lvalue.identifier;
307            let lvalue_range = &env.identifiers[lvalue_id.0 as usize].mutable_range;
308            let lvalue_type = &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize];
309            let lvalue_type_is_primitive = react_compiler_hir::is_primitive_type(lvalue_type);
310
311            if lvalue_range.end.0 > lvalue_range.start.0 + 1
312                || may_allocate(&instr.value, lvalue_type_is_primitive)
313            {
314                operands.push(lvalue_id);
315            }
316
317            match &instr.value {
318                InstructionValue::DeclareLocal { lvalue, .. }
319                | InstructionValue::DeclareContext { lvalue, .. } => {
320                    let place_id = lvalue.place.identifier;
321                    let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
322                    declarations.entry(decl_id).or_insert(place_id);
323                }
324                InstructionValue::StoreLocal { lvalue, value, .. }
325                | InstructionValue::StoreContext { lvalue, value, .. } => {
326                    let place_id = lvalue.place.identifier;
327                    let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
328                    declarations.entry(decl_id).or_insert(place_id);
329
330                    let place_range =
331                        &env.identifiers[place_id.0 as usize].mutable_range;
332                    if place_range.end.0 > place_range.start.0 + 1 {
333                        operands.push(place_id);
334                    }
335
336                    let value_range =
337                        &env.identifiers[value.identifier.0 as usize].mutable_range;
338                    if value_range.contains(instr.id)
339                        && value_range.start.0 > 0
340                    {
341                        operands.push(value.identifier);
342                    }
343                }
344                InstructionValue::Destructure { lvalue, value, .. } => {
345                    let pattern_places = each_pattern_operand(&lvalue.pattern);
346                    for place_id in &pattern_places {
347                        let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
348                        declarations.entry(decl_id).or_insert(*place_id);
349
350                        let place_range =
351                            &env.identifiers[place_id.0 as usize].mutable_range;
352                        if place_range.end.0 > place_range.start.0 + 1 {
353                            operands.push(*place_id);
354                        }
355                    }
356
357                    let value_range =
358                        &env.identifiers[value.identifier.0 as usize].mutable_range;
359                    if value_range.contains(instr.id)
360                        && value_range.start.0 > 0
361                    {
362                        operands.push(value.identifier);
363                    }
364                }
365                InstructionValue::MethodCall { property, .. } => {
366                    // For MethodCall: include all mutable operands plus the computed property
367                    let all_operands =
368                        each_instruction_value_operand(&instr.value, env);
369                    for op_id in &all_operands {
370                        let op_range =
371                            &env.identifiers[op_id.0 as usize].mutable_range;
372                        if op_range.contains(instr.id) && op_range.start.0 > 0 {
373                            operands.push(*op_id);
374                        }
375                    }
376                    // Ensure method property is in the same scope as the call
377                    operands.push(property.identifier);
378                }
379                _ => {
380                    // For all other instructions: include mutable operands
381                    let all_operands =
382                        each_instruction_value_operand(&instr.value, env);
383                    for op_id in &all_operands {
384                        let op_range =
385                            &env.identifiers[op_id.0 as usize].mutable_range;
386                        if op_range.contains(instr.id) && op_range.start.0 > 0 {
387                            operands.push(*op_id);
388                        }
389                    }
390                }
391            }
392
393            if !operands.is_empty() {
394                scope_identifiers.union(&operands);
395            }
396        }
397    }
398    scope_identifiers
399}