use rustc_hash::FxHashMap;
use react_compiler_utils::FxIndexSet;
use react_compiler_hir::{
BlockId, ReactiveFunction, ReactiveScopeBlock, ReactiveTerminal, ReactiveTerminalStatement,
environment::Environment,
};
use crate::visitors::{
ReactiveFunctionTransform, ReactiveFunctionVisitor, transform_reactive_function,
visit_reactive_function,
};
/// Renumbers the block ids used as labels so they form a small, dense range.
///
/// The pass runs in two phases:
/// 1. A read-only visit collects every block id that is *explicitly* referenced:
///    - labels on terminals whose `implicit` flag is false;
///    - the `label` of any scope's `early_return_value`.
///    Each collected id is assigned `BlockId(0)`, `BlockId(1)`, … in the iteration
///    order of the collected set. That is presumably first-encounter order, if
///    `FxIndexSet` is insertion-ordered like `indexmap::IndexSet`; confirm.
/// 2. A mutating transform rewrites these sites through that mapping:
///    - terminal labels (implicit ones included);
///    - `break`/`continue` targets;
///    - scope early-return labels.
///    Any id not seen in phase 1 is given the next unused number on first sight.
pub fn stabilize_block_ids(func: &mut ReactiveFunction, env: &mut Environment) {
    let mut referenced: FxIndexSet<BlockId> = FxIndexSet::default();
    visit_reactive_function(func, &CollectReferencedLabels { env: &*env }, &mut referenced);

    // `referenced` contains no duplicates, so each id's new number is simply its
    // position in the set.
    let mut mappings: FxHashMap<BlockId, BlockId> = referenced
        .iter()
        .enumerate()
        .map(|(position, &original)| (original, BlockId(position as u32)))
        .collect();

    // The `Result` is intentionally discarded. `RewriteBlockIds` never creates an
    // error itself; it only forwards whatever `traverse_*` returns.
    // NOTE(review): confirm the default traversal cannot fail, otherwise such
    // errors are silently dropped here.
    let _ = transform_reactive_function(func, &mut RewriteBlockIds { env }, &mut mappings);
}
/// Read-only visitor that records the block ids explicitly targeted by labels.
///
/// The state is the set of referenced ids. The environment is needed because
/// scope metadata (including the early-return label) lives in `env.scopes`
/// rather than on the `ReactiveScopeBlock` itself.
struct CollectReferencedLabels<'a> {
    env: &'a Environment,
}

impl ReactiveFunctionVisitor for CollectReferencedLabels<'_> {
    type State = FxIndexSet<BlockId>;

    fn env(&self) -> &Environment {
        self.env
    }

    /// Records the early-return label of this scope, if it has one, then
    /// continues into the scope's contents.
    fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut Self::State) {
        let early_return = self.env.scopes[scope.scope.0 as usize]
            .early_return_value
            .as_ref();
        if let Some(value) = early_return {
            state.insert(value.label);
        }
        self.traverse_scope(scope, state);
    }

    /// Records the terminal's label only when it was written explicitly
    /// (`implicit == false`), then continues into the terminal's children.
    fn visit_terminal(&self, stmt: &ReactiveTerminalStatement, state: &mut Self::State) {
        match &stmt.label {
            Some(label) if !label.implicit => {
                state.insert(label.id);
            }
            _ => {}
        }
        self.traverse_terminal(stmt, state);
    }
}
/// Looks up the renumbered id for `id`.
///
/// If `id` has no mapping yet, it is assigned `BlockId(mappings.len())`, i.e. the
/// next number after all ids handed out so far. That new mapping is recorded and
/// returned.
fn get_or_insert_mapping(mappings: &mut FxHashMap<BlockId, BlockId>, id: BlockId) -> BlockId {
    use std::collections::hash_map::Entry;

    // Computed before taking the entry, since `entry` holds a mutable borrow.
    let next_free = BlockId(mappings.len() as u32);
    match mappings.entry(id) {
        Entry::Occupied(existing) => *existing.get(),
        Entry::Vacant(slot) => *slot.insert(next_free),
    }
}
/// Mutating transform that replaces block ids with their renumbered values.
///
/// The state maps original ids to new ids. Ids missing from the map are assigned
/// fresh numbers via `get_or_insert_mapping` the first time they are seen.
struct RewriteBlockIds<'a> {
    env: &'a mut Environment,
}

impl ReactiveFunctionTransform for RewriteBlockIds<'_> {
    type State = FxHashMap<BlockId, BlockId>;

    fn env(&self) -> &Environment {
        self.env
    }

    /// Rewrites the scope's early-return label (stored in `env.scopes`), then
    /// traverses into the scope.
    fn visit_scope(
        &mut self,
        scope: &mut ReactiveScopeBlock,
        state: &mut Self::State,
    ) -> Result<(), react_compiler_diagnostics::CompilerError> {
        let scope_index = scope.scope.0 as usize;
        if let Some(early_return) = self.env.scopes[scope_index].early_return_value.as_mut() {
            early_return.label = get_or_insert_mapping(state, early_return.label);
        }
        self.traverse_scope(scope, state)
    }

    /// Rewrites the terminal's own label (explicit or implicit), then any
    /// `break`/`continue` target, then traverses into the terminal.
    ///
    /// The label is rewritten before the target. This order matters because it
    /// decides which fresh number an unseen id receives.
    fn visit_terminal(
        &mut self,
        stmt: &mut ReactiveTerminalStatement,
        state: &mut Self::State,
    ) -> Result<(), react_compiler_diagnostics::CompilerError> {
        if let Some(label) = stmt.label.as_mut() {
            label.id = get_or_insert_mapping(state, label.id);
        }
        if let ReactiveTerminal::Break { target, .. } | ReactiveTerminal::Continue { target, .. } =
            &mut stmt.terminal
        {
            *target = get_or_insert_mapping(state, *target);
        }
        self.traverse_terminal(stmt, state)
    }
}