use std::collections::HashMap;
use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory};
use react_compiler_hir::environment::Environment;
use react_compiler_hir::visitors;
use react_compiler_hir::{
DeclarationId, EvaluationOrder, HirFunction, IdentifierId,
InstructionValue, Pattern, Position, SourceLocation,
};
use react_compiler_utils::DisjointSet;
pub fn infer_reactive_scope_variables(func: &mut HirFunction, env: &mut Environment) -> Result<(), CompilerDiagnostic> {
let mut scope_identifiers = find_disjoint_mutable_values(func, env);
let mut scopes: HashMap<IdentifierId, ScopeState> = HashMap::new();
scope_identifiers.for_each(|identifier_id, group_id| {
let ident_range = env.identifiers[identifier_id.0 as usize].mutable_range.clone();
let ident_loc = env.identifiers[identifier_id.0 as usize].loc;
let state = scopes.entry(group_id).or_insert_with(|| {
let scope_id = env.next_scope_id();
let scope = &mut env.scopes[scope_id.0 as usize];
scope.range = ident_range.clone();
ScopeState {
scope_id,
loc: ident_loc,
}
});
let scope = &mut env.scopes[state.scope_id.0 as usize];
if scope.range.start != ident_range.start || scope.range.end != ident_range.end {
if scope.range.start == EvaluationOrder(0) {
scope.range.start = ident_range.start;
} else if ident_range.start != EvaluationOrder(0) {
scope.range.start =
EvaluationOrder(scope.range.start.0.min(ident_range.start.0));
}
scope.range.end = EvaluationOrder(scope.range.end.0.max(ident_range.end.0));
}
state.loc = merge_location(state.loc, ident_loc);
let scope_id = state.scope_id;
env.identifiers[identifier_id.0 as usize].scope = Some(scope_id);
});
for (_group_id, state) in &scopes {
env.scopes[state.scope_id.0 as usize].loc = state.loc;
}
for (&_identifier_id, state) in &scopes {
let scope_range = env.scopes[state.scope_id.0 as usize].range.clone();
for ident in &mut env.identifiers {
if ident.scope == Some(state.scope_id) {
ident.mutable_range = scope_range.clone();
}
}
}
let mut max_instruction = EvaluationOrder(0);
for (_block_id, block) in &func.body.blocks {
for instr_id in &block.instructions {
let instr = &func.instructions[instr_id.0 as usize];
max_instruction = EvaluationOrder(max_instruction.0.max(instr.id.0));
}
max_instruction = EvaluationOrder(max_instruction.0.max(block.terminal.evaluation_order().0));
}
for (_group_id, state) in &scopes {
let scope = &env.scopes[state.scope_id.0 as usize];
if scope.range.start == EvaluationOrder(0)
|| scope.range.end == EvaluationOrder(0)
|| max_instruction == EvaluationOrder(0)
|| scope.range.end.0 > max_instruction.0 + 1
{
return Err(CompilerDiagnostic::new(
ErrorCategory::Invariant,
&format!(
"Invalid mutable range for scope: Scope @{} has range [{}:{}] but the valid range is [1:{}]",
scope.id.0,
scope.range.start.0,
scope.range.end.0,
max_instruction.0 + 1,
),
None,
));
}
}
Ok(())
}
struct ScopeState {
scope_id: react_compiler_hir::ScopeId,
loc: Option<SourceLocation>,
}
fn merge_location(
l: Option<SourceLocation>,
r: Option<SourceLocation>,
) -> Option<SourceLocation> {
match (l, r) {
(None, r) => r,
(l, None) => l,
(Some(l), Some(r)) => Some(SourceLocation {
start: Position {
line: l.start.line.min(r.start.line),
column: l.start.column.min(r.start.column),
index: match (l.start.index, r.start.index) {
(Some(a), Some(b)) => Some(a.min(b)),
(a, b) => a.or(b),
},
},
end: Position {
line: l.end.line.max(r.end.line),
column: l.end.column.max(r.end.column),
index: match (l.end.index, r.end.index) {
(Some(a), Some(b)) => Some(a.max(b)),
(a, b) => a.or(b),
},
},
}),
}
}
fn may_allocate(value: &InstructionValue, lvalue_type_is_primitive: bool) -> bool {
match value {
InstructionValue::Destructure { lvalue, .. } => {
visitors::does_pattern_contain_spread_element(&lvalue.pattern)
}
InstructionValue::PostfixUpdate { .. }
| InstructionValue::PrefixUpdate { .. }
| InstructionValue::Await { .. }
| InstructionValue::DeclareLocal { .. }
| InstructionValue::DeclareContext { .. }
| InstructionValue::StoreLocal { .. }
| InstructionValue::LoadGlobal { .. }
| InstructionValue::MetaProperty { .. }
| InstructionValue::TypeCastExpression { .. }
| InstructionValue::LoadLocal { .. }
| InstructionValue::LoadContext { .. }
| InstructionValue::StoreContext { .. }
| InstructionValue::PropertyDelete { .. }
| InstructionValue::ComputedLoad { .. }
| InstructionValue::ComputedDelete { .. }
| InstructionValue::JSXText { .. }
| InstructionValue::TemplateLiteral { .. }
| InstructionValue::Primitive { .. }
| InstructionValue::GetIterator { .. }
| InstructionValue::IteratorNext { .. }
| InstructionValue::NextPropertyOf { .. }
| InstructionValue::Debugger { .. }
| InstructionValue::StartMemoize { .. }
| InstructionValue::FinishMemoize { .. }
| InstructionValue::UnaryExpression { .. }
| InstructionValue::BinaryExpression { .. }
| InstructionValue::PropertyLoad { .. }
| InstructionValue::StoreGlobal { .. } => false,
InstructionValue::TaggedTemplateExpression { .. }
| InstructionValue::CallExpression { .. }
| InstructionValue::MethodCall { .. } => !lvalue_type_is_primitive,
InstructionValue::RegExpLiteral { .. }
| InstructionValue::PropertyStore { .. }
| InstructionValue::ComputedStore { .. }
| InstructionValue::ArrayExpression { .. }
| InstructionValue::JsxExpression { .. }
| InstructionValue::JsxFragment { .. }
| InstructionValue::NewExpression { .. }
| InstructionValue::ObjectExpression { .. }
| InstructionValue::UnsupportedNode { .. }
| InstructionValue::ObjectMethod { .. }
| InstructionValue::FunctionExpression { .. } => true,
}
}
fn each_pattern_operand(pattern: &Pattern) -> Vec<IdentifierId> {
visitors::each_pattern_operand(pattern)
.into_iter()
.map(|p| p.identifier)
.collect()
}
fn each_instruction_value_operand(
value: &InstructionValue,
env: &Environment,
) -> Vec<IdentifierId> {
visitors::each_instruction_value_operand(value, env)
.into_iter()
.map(|p| p.identifier)
.collect()
}
pub(crate) fn find_disjoint_mutable_values(func: &HirFunction, env: &Environment) -> DisjointSet<IdentifierId> {
let mut scope_identifiers = DisjointSet::<IdentifierId>::new();
let mut declarations: HashMap<DeclarationId, IdentifierId> = HashMap::new();
let enable_forest = env.config.enable_forest;
for (_block_id, block) in &func.body.blocks {
for phi in &block.phis {
let phi_id = phi.place.identifier;
let phi_range = &env.identifiers[phi_id.0 as usize].mutable_range;
let phi_decl_id = env.identifiers[phi_id.0 as usize].declaration_id;
let first_instr_id = block
.instructions
.first()
.map(|iid| func.instructions[iid.0 as usize].id)
.unwrap_or(block.terminal.evaluation_order());
if phi_range.start.0 + 1 != phi_range.end.0
&& phi_range.end > first_instr_id
{
let mut operands = vec![phi_id];
if let Some(&decl_id) = declarations.get(&phi_decl_id) {
operands.push(decl_id);
}
for (_pred_id, phi_operand) in &phi.operands {
operands.push(phi_operand.identifier);
}
scope_identifiers.union(&operands);
} else if enable_forest {
for (_pred_id, phi_operand) in &phi.operands {
scope_identifiers.union(&[phi_id, phi_operand.identifier]);
}
}
}
for instr_id in &block.instructions {
let instr = &func.instructions[instr_id.0 as usize];
let mut operands: Vec<IdentifierId> = Vec::new();
let lvalue_id = instr.lvalue.identifier;
let lvalue_range = &env.identifiers[lvalue_id.0 as usize].mutable_range;
let lvalue_type = &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize];
let lvalue_type_is_primitive = react_compiler_hir::is_primitive_type(lvalue_type);
if lvalue_range.end.0 > lvalue_range.start.0 + 1
|| may_allocate(&instr.value, lvalue_type_is_primitive)
{
operands.push(lvalue_id);
}
match &instr.value {
InstructionValue::DeclareLocal { lvalue, .. }
| InstructionValue::DeclareContext { lvalue, .. } => {
let place_id = lvalue.place.identifier;
let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
declarations.entry(decl_id).or_insert(place_id);
}
InstructionValue::StoreLocal { lvalue, value, .. }
| InstructionValue::StoreContext { lvalue, value, .. } => {
let place_id = lvalue.place.identifier;
let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
declarations.entry(decl_id).or_insert(place_id);
let place_range =
&env.identifiers[place_id.0 as usize].mutable_range;
if place_range.end.0 > place_range.start.0 + 1 {
operands.push(place_id);
}
let value_range =
&env.identifiers[value.identifier.0 as usize].mutable_range;
if value_range.contains(instr.id)
&& value_range.start.0 > 0
{
operands.push(value.identifier);
}
}
InstructionValue::Destructure { lvalue, value, .. } => {
let pattern_places = each_pattern_operand(&lvalue.pattern);
for place_id in &pattern_places {
let decl_id = env.identifiers[place_id.0 as usize].declaration_id;
declarations.entry(decl_id).or_insert(*place_id);
let place_range =
&env.identifiers[place_id.0 as usize].mutable_range;
if place_range.end.0 > place_range.start.0 + 1 {
operands.push(*place_id);
}
}
let value_range =
&env.identifiers[value.identifier.0 as usize].mutable_range;
if value_range.contains(instr.id)
&& value_range.start.0 > 0
{
operands.push(value.identifier);
}
}
InstructionValue::MethodCall { property, .. } => {
let all_operands =
each_instruction_value_operand(&instr.value, env);
for op_id in &all_operands {
let op_range =
&env.identifiers[op_id.0 as usize].mutable_range;
if op_range.contains(instr.id) && op_range.start.0 > 0 {
operands.push(*op_id);
}
}
operands.push(property.identifier);
}
_ => {
let all_operands =
each_instruction_value_operand(&instr.value, env);
for op_id in &all_operands {
let op_range =
&env.identifiers[op_id.0 as usize].mutable_range;
if op_range.contains(instr.id) && op_range.start.0 > 0 {
operands.push(*op_id);
}
}
}
}
if !operands.is_empty() {
scope_identifiers.union(&operands);
}
}
}
scope_identifiers
}