use std::iter;
use oxc_allocator::TakeIn;
use rustc_hash::FxHashMap;
use oxc_ast::{NONE, ast::*};
use oxc_ast_visit::{VisitMut, walk_mut};
use oxc_span::SPAN;
use oxc_str::Ident;
use oxc_syntax::{
node::NodeId,
scope::{ScopeFlags, ScopeId},
symbol::{SymbolFlags, SymbolId},
};
use oxc_traverse::BoundIdentifier;
use crate::{
context::TraverseCtx,
utils::ast_builder::{
create_assignment, create_class_constructor_with_params, create_super_call,
},
};
use super::{ClassProperties, utils::exprs_into_stmts};
/// Where instance property initializers are to be inserted for a class.
pub(super) enum InstanceInitsInsertLocation<'a> {
    /// No usable existing constructor; presumably a new one is created (see
    /// `insert_constructor`) — NOTE(review): not constructed in the code visible here.
    NewConstructor,
    /// Insert initializers as statements into the existing constructor body at this index.
    /// Produced as `index + 1` of a top-level `super(...)` statement, i.e. directly after it.
    ExistingConstructor(usize),
    /// Initializers go in a `_super` arrow function declared at the top of the constructor;
    /// the binding is that function's name.
    SuperFnInsideConstructor(BoundIdentifier<'a>),
    /// Initializers go in a `_super` function created outside the constructor (used when
    /// `super()` appears in constructor params); the binding is that function's name.
    SuperFnOutsideClass(BoundIdentifier<'a>),
}
/// Scopes relevant to inserting instance property initializers.
pub(super) struct InstanceInitScopes {
    /// Scope the initializers will live in: either a newly-created `_super` function scope,
    /// or the constructor's own scope when inserting into the existing constructor body.
    pub insert_in_scope_id: ScopeId,
    /// The constructor's scope, set only if `super()` was not found in the constructor's
    /// params AND the constructor's scope has at least one binding; otherwise `None`.
    pub constructor_scope_id: Option<ScopeId>,
}
impl<'a> ClassProperties<'a> {
    /// Replace `super()` calls in a class constructor, and decide where instance property
    /// initializers should be inserted.
    ///
    /// * If the constructor's params contain a `super()` call, `ConstructorParamsSuperReplacer`
    ///   wraps the `super()` calls in params and body. Initializers are then to go in a `_super`
    ///   function outside the constructor (`SuperFnOutsideClass`), whose new scope is returned
    ///   as `insert_in_scope_id`. `constructor_scope_id` is `None` in this case.
    /// * Otherwise `ConstructorBodySuperReplacer` processes the body. It yields either
    ///   `ExistingConstructor(index)` (with the constructor's scope) or
    ///   `SuperFnInsideConstructor(..)` (with a new arrow-function scope).
    ///
    /// Returns the scopes relevant to the insertion, and the insert location.
    pub(super) fn replace_super_in_constructor(
        constructor: &mut Function<'a>,
        ctx: &mut TraverseCtx<'a>,
    ) -> (InstanceInitScopes, InstanceInitsInsertLocation<'a>) {
        let replacer = ConstructorParamsSuperReplacer::new(ctx);
        if let Some((super_func_scope_id, insert_location)) = replacer.replace(constructor) {
            // `super()` found in params: initializers will live in the new `_super` function
            // scope, so the constructor's scope is not reported.
            let insert_scopes = InstanceInitScopes {
                insert_in_scope_id: super_func_scope_id,
                constructor_scope_id: None,
            };
            return (insert_scopes, insert_location);
        }
        // No `super()` in params - process the constructor body instead.
        let constructor_scope_id = constructor.scope_id();
        let replacer = ConstructorBodySuperReplacer::new(constructor_scope_id, ctx);
        let (super_func_scope_id, insert_location) = replacer.replace(constructor);
        // Report the constructor's scope only if it has any bindings (checked after the
        // replacer ran, so a `_super` binding it created counts). Presumably an empty scope
        // has nothing for initializer references to clash with.
        let constructor_scope_id = if ctx.scoping().get_bindings(constructor_scope_id).is_empty() {
            None
        } else {
            Some(constructor_scope_id)
        };
        let insert_scopes =
            InstanceInitScopes { insert_in_scope_id: super_func_scope_id, constructor_scope_id };
        (insert_scopes, insert_location)
    }

    /// Create a new constructor containing `inits`, and insert it as the first element of the
    /// class body.
    ///
    /// If `has_super_class`, the constructor gets a rest param bound to a UID generated from
    /// `"args"` in `constructor_scope_id`. Its body starts with a `super` call built by
    /// `create_super_call` from that binding (presumably `super(...args)`), followed by `inits`
    /// as expression statements. Otherwise it has no params and its body is just `inits`.
    pub(super) fn insert_constructor(
        body: &mut ClassBody<'a>,
        inits: Vec<Expression<'a>>,
        has_super_class: bool,
        constructor_scope_id: ScopeId,
        ctx: &mut TraverseCtx<'a>,
    ) {
        // One statement per init, plus one for the `super` call when there is a super class.
        let mut stmts = ctx.ast.vec_with_capacity(inits.len() + usize::from(has_super_class));
        let mut params_rest = None;
        if has_super_class {
            let args_binding =
                ctx.generate_uid("args", constructor_scope_id, SymbolFlags::FunctionScopedVariable);
            let rest_element =
                ctx.ast.binding_rest_element(SPAN, args_binding.create_binding_pattern(ctx));
            params_rest =
                Some(ctx.ast.alloc_formal_parameter_rest(SPAN, ctx.ast.vec(), rest_element, NONE));
            stmts.push(ctx.ast.statement_expression(SPAN, create_super_call(&args_binding, ctx)));
        }
        stmts.extend(exprs_into_stmts(inits, ctx.ast));
        let params = ctx.ast.alloc_formal_parameters(
            SPAN,
            FormalParameterKind::FormalParameter,
            ctx.ast.vec(),
            params_rest,
        );
        let ctor = create_class_constructor_with_params(stmts, params, constructor_scope_id, ctx);
        body.body.insert(0, ctor);
    }

    /// Insert `inits` as expression statements into an existing constructor's body, at
    /// `insertion_index` (before the statement currently at that index).
    ///
    /// Symbols recorded in `self.clashing_constructor_symbols` are renamed first.
    ///
    /// # Panics
    /// Panics if `constructor` has no body. `insertion_index` is assumed to be within
    /// `0..=statements.len()` — presumably `splice` panics otherwise, as `Vec::splice` does.
    pub(super) fn insert_inits_into_constructor_as_statements(
        &mut self,
        constructor: &mut Function<'a>,
        inits: Vec<Expression<'a>>,
        insertion_index: usize,
        ctx: &mut TraverseCtx<'a>,
    ) {
        self.rename_clashing_symbols(constructor, ctx);
        let body_stmts = &mut constructor.body.as_mut().unwrap().statements;
        // Empty range: pure insertion, nothing removed.
        body_stmts.splice(insertion_index..insertion_index, exprs_into_stmts(inits, ctx.ast));
    }

    /// Declare a `_super` arrow function as the first statement of the constructor body:
    /// `var <super_binding> = (...<args>) => (<super call>, <inits>, this);`
    ///
    /// The arrow function uses `super_func_scope_id` as its scope, and the `args` UID is
    /// generated in that scope. The `super` call is built by `create_super_call` (presumably
    /// `super(...args)`). Clashing constructor symbols are renamed first.
    ///
    /// # Panics
    /// Panics if `constructor` has no body.
    pub(super) fn create_super_function_inside_constructor(
        &mut self,
        constructor: &mut Function<'a>,
        inits: Vec<Expression<'a>>,
        super_binding: &BoundIdentifier<'a>,
        super_func_scope_id: ScopeId,
        ctx: &mut TraverseCtx<'a>,
    ) {
        self.rename_clashing_symbols(constructor, ctx);
        let args_binding =
            ctx.generate_uid("args", super_func_scope_id, SymbolFlags::FunctionScopedVariable);
        let super_call = create_super_call(&args_binding, ctx);
        let this_expr = ctx.ast.expression_this(SPAN);
        // `(<super call>, <inits>, this)`
        let body_exprs = ctx.ast.expression_sequence(
            SPAN,
            ctx.ast.vec_from_iter(iter::once(super_call).chain(inits).chain(iter::once(this_expr))),
        );
        let body = ctx.ast.vec1(ctx.ast.statement_expression(SPAN, body_exprs));
        // NOTE(review): the `true` / `false` after `SPAN` presumably mean an expression-bodied,
        // non-async arrow — confirm against the AST builder's signature.
        let super_func = ctx.ast.expression_arrow_function_with_scope_id_and_pure_and_pife(
            SPAN,
            true,
            false,
            NONE,
            {
                // `(...<args>)`
                let rest_element =
                    ctx.ast.binding_rest_element(SPAN, args_binding.create_binding_pattern(ctx));
                let rest =
                    ctx.ast.alloc_formal_parameter_rest(SPAN, ctx.ast.vec(), rest_element, NONE);
                ctx.ast.alloc_formal_parameters(
                    SPAN,
                    FormalParameterKind::ArrowFormalParameters,
                    ctx.ast.vec(),
                    Some(rest),
                )
            },
            NONE,
            ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), body),
            super_func_scope_id,
            false,
            false,
        );
        // `var <super_binding> = <super_func>;`
        let super_func_decl = Statement::from(ctx.ast.declaration_variable(
            SPAN,
            VariableDeclarationKind::Var,
            ctx.ast.vec1(ctx.ast.variable_declarator(
                SPAN,
                VariableDeclarationKind::Var,
                super_binding.create_binding_pattern(ctx),
                NONE,
                Some(super_func),
                false,
            )),
            false,
        ));
        let body_stmts = &mut constructor.body.as_mut().unwrap().statements;
        body_stmts.insert(0, super_func_decl);
    }

    /// Create a `_super` function, to be declared outside the constructor, which runs `inits`
    /// and returns `this`: `function() { "use strict"; <inits>; return this; }`.
    ///
    /// The `"use strict"` directive is added only if the current block scope is not already
    /// strict mode. The function is declared via `ctx.state.var_declarations.insert_let`:
    /// * For a class declaration, the function is the `let` binding's initializer.
    /// * Otherwise the binding is declared without an initializer. An assignment built by
    ///   `create_assignment` (presumably `<super_binding> = function() {...}`) is pushed onto
    ///   `self.insert_after_exprs`.
    pub(super) fn create_super_function_outside_constructor(
        &mut self,
        inits: Vec<Expression<'a>>,
        super_binding: &BoundIdentifier<'a>,
        super_func_scope_id: ScopeId,
        ctx: &mut TraverseCtx<'a>,
    ) {
        let outer_scope_id = ctx.current_block_scope_id();
        // Function bodies must be strict to match class (constructor) semantics.
        let directives = if ctx.scoping().scope_flags(outer_scope_id).is_strict_mode() {
            ctx.ast.vec()
        } else {
            ctx.ast.vec1(ctx.ast.use_strict_directive())
        };
        // `<inits>; return this;`
        let return_stmt = ctx.ast.statement_return(SPAN, Some(ctx.ast.expression_this(SPAN)));
        let body_stmts =
            ctx.ast.vec_from_iter(exprs_into_stmts(inits, ctx.ast).chain([return_stmt]));
        // Anonymous `function() { ... }` with no params, using `super_func_scope_id`.
        let super_func = ctx.ast.expression_function_with_scope_id_and_pure_and_pife(
            SPAN,
            FunctionType::FunctionExpression,
            None,
            false,
            false,
            false,
            NONE,
            NONE,
            ctx.ast.alloc_formal_parameters(
                SPAN,
                FormalParameterKind::FormalParameter,
                ctx.ast.vec(),
                NONE,
            ),
            NONE,
            Some(ctx.ast.alloc_function_body(SPAN, directives, body_stmts)),
            super_func_scope_id,
            false,
            false,
        );
        let init = if self.current_class().is_declaration {
            Some(super_func)
        } else {
            let assignment = create_assignment(super_binding, super_func, SPAN, ctx);
            self.insert_after_exprs.push(assignment);
            None
        };
        ctx.state.var_declarations.insert_let(super_binding, init, ctx.ast);
    }

    /// Rename the constructor symbols recorded in `self.clashing_constructor_symbols`
    /// (presumably names which clash with references in the initializers) to fresh UIDs.
    /// Then rewrite their binding/reference identifiers within `constructor`.
    ///
    /// Each map entry's name is overwritten with its new name before the renamer visits the
    /// constructor, and the map is cleared afterwards. Does nothing if the map is empty.
    fn rename_clashing_symbols(
        &mut self,
        constructor: &mut Function<'a>,
        ctx: &mut TraverseCtx<'a>,
    ) {
        let clashing_symbols = &mut self.clashing_constructor_symbols;
        if clashing_symbols.is_empty() {
            return;
        }
        let constructor_scope_id = constructor.scope_id();
        for (&symbol_id, name) in clashing_symbols.iter_mut() {
            let new_name = ctx.generate_uid_name(name);
            // Store the new name so the renamer below can look it up by `SymbolId`.
            *name = new_name;
            ctx.scoping_mut().rename_symbol(symbol_id, constructor_scope_id, new_name);
        }
        let mut renamer = ConstructorSymbolRenamer::new(clashing_symbols, ctx);
        renamer.visit_function(constructor, ScopeFlags::empty());
        clashing_symbols.clear();
    }
}
/// Visitor which finds `super(...)` calls in a constructor's params (and, once any are found,
/// in its body), and wraps each one as `<super_binding>.call(super(...))`.
struct ConstructorParamsSuperReplacer<'a, 'ctx> {
    /// Binding for the `_super` function. Created lazily on the first `super()` found, as a
    /// block-scoped UID generated from `"super"` in the current block scope.
    super_binding: Option<BoundIdentifier<'a>>,
    ctx: &'ctx mut TraverseCtx<'a>,
}
impl<'a, 'ctx> ConstructorParamsSuperReplacer<'a, 'ctx> {
    /// Create a replacer that has not yet seen any `super()` call.
    fn new(ctx: &'ctx mut TraverseCtx<'a>) -> Self {
        Self { ctx, super_binding: None }
    }

    /// Wrap `super()` calls in the constructor's params.
    ///
    /// Returns `None` (after visiting only the params) if the params contain no `super()`.
    /// Otherwise it also wraps `super()` calls in the body, creates a strict-mode function
    /// scope parented to the current block scope, and returns that scope with
    /// `SuperFnOutsideClass`.
    ///
    /// # Panics
    /// Panics if `super()` is found in params and `constructor` has no body.
    fn replace(
        mut self,
        constructor: &mut Function<'a>,
    ) -> Option<(ScopeId, InstanceInitsInsertLocation<'a>)> {
        self.visit_formal_parameters(&mut constructor.params);

        #[expect(clippy::question_mark)]
        if self.super_binding.is_none() {
            return None;
        }

        // Params referenced `super()`, so every `super()` in the body must be wrapped as well.
        self.visit_statements(&mut constructor.body.as_mut().unwrap().statements);

        let Self { super_binding, ctx } = self;
        let location = InstanceInitsInsertLocation::SuperFnOutsideClass(super_binding.unwrap());

        let parent_scope_id = ctx.current_block_scope_id();
        let scope_flags = ScopeFlags::Function | ScopeFlags::StrictMode;
        let new_scope_id =
            ctx.scoping_mut().add_scope(Some(parent_scope_id), NodeId::DUMMY, scope_flags);

        Some((new_scope_id, location))
    }
}
impl<'a> VisitMut<'a> for ConstructorParamsSuperReplacer<'a, '_> {
    /// Wrap each `super(...)` call found; walk every other expression normally.
    #[inline]
    fn visit_expression(&mut self, expr: &mut Expression<'a>) {
        match expr {
            Expression::CallExpression(call_expr) if call_expr.callee.is_super() => {
                // Handle nested `super()`s in the arguments first. `walk_expression` is
                // deliberately NOT called for this case, so the wrapped `super(...)` is not
                // visited (and wrapped) again.
                self.visit_arguments(&mut call_expr.arguments);
                let call_span = call_expr.span;
                self.wrap_super(expr, call_span);
            }
            _ => walk_mut::walk_expression(self, expr),
        }
    }

    // `super()` inside nested functions, static blocks and TS module blocks does not belong
    // to this constructor, so don't descend into them.
    #[inline]
    fn visit_function(&mut self, _func: &mut Function<'a>, _flags: ScopeFlags) {}

    #[inline]
    fn visit_static_block(&mut self, _block: &mut StaticBlock) {}

    #[inline]
    fn visit_ts_module_block(&mut self, _block: &mut TSModuleBlock<'a>) {}

    /// For fields of nested classes, only decorators and computed keys are visited; the
    /// value is skipped (presumably because only the former are evaluated in this scope).
    #[inline]
    fn visit_property_definition(&mut self, prop: &mut PropertyDefinition<'a>) {
        self.visit_decorators(&mut prop.decorators);
        if !prop.computed {
            return;
        }
        self.visit_property_key(&mut prop.key);
    }

    /// Same treatment as `visit_property_definition`, for accessor properties.
    #[inline]
    fn visit_accessor_property(&mut self, prop: &mut AccessorProperty<'a>) {
        self.visit_decorators(&mut prop.decorators);
        if !prop.computed {
            return;
        }
        self.visit_property_key(&mut prop.key);
    }
}
impl<'a> ConstructorParamsSuperReplacer<'a, '_> {
    /// Replace `expr` (a `super(...)` call) with `<super_binding>.call(super(...))`.
    /// The outer call gets `span`.
    ///
    /// The `_super` binding is created on first use, as a block-scoped UID generated from
    /// `"super"` in the current block scope.
    fn wrap_super(&mut self, expr: &mut Expression<'a>, span: Span) {
        let ctx = &mut *self.ctx;
        let binding = self.super_binding.get_or_insert_with(|| {
            let scope_id = ctx.current_block_scope_id();
            ctx.generate_uid("super", scope_id, SymbolFlags::BlockScopedVariable)
        });

        // Move the original `super(...)` out, leaving a placeholder until `expr` is overwritten.
        let original_call = expr.take_in(ctx.ast);

        // `<binding>.call`
        let call_method = Expression::from(ctx.ast.member_expression_static(
            SPAN,
            binding.create_read_expression(ctx),
            ctx.ast.identifier_name(SPAN, Str::from("call")),
            false,
        ));

        let call_args = ctx.ast.vec1(Argument::from(original_call));
        *expr = ctx.ast.expression_call(span, call_method, NONE, call_args, false);
    }
}
/// Visitor which finds `super(...)` calls in a constructor body and, where needed, redirects
/// them to a `_super` binding (`super(...)` -> `_super(...)`).
struct ConstructorBodySuperReplacer<'a, 'ctx> {
    /// Scope of the constructor being processed.
    constructor_scope_id: ScopeId,
    /// Binding for the `_super` function, created lazily as a function-scoped UID generated
    /// from `"super"` in the constructor's scope.
    super_binding: Option<BoundIdentifier<'a>>,
    ctx: &'ctx mut TraverseCtx<'a>,
}
impl<'a, 'ctx> ConstructorBodySuperReplacer<'a, 'ctx> {
    /// Create a replacer for the constructor whose scope is `constructor_scope_id`.
    fn new(constructor_scope_id: ScopeId, ctx: &'ctx mut TraverseCtx<'a>) -> Self {
        Self { constructor_scope_id, super_binding: None, ctx }
    }

    /// Find `super()` calls in the constructor body and decide where instance property
    /// initializers go.
    ///
    /// Top-level statements are scanned in order:
    /// * If a top-level `super(...);` expression statement is reached and no other `super()`
    ///   has been seen (in earlier statements or in its own arguments), this returns the
    ///   constructor's scope and `ExistingConstructor(index + 1)`. Initializers can then go
    ///   directly after that statement, and no `_super` function is needed.
    /// * Otherwise each `super(...)` visited is redirected to the `_super` binding. If a
    ///   top-level `super()` statement is redirected, later statements are not visited.
    ///   The `_super` binding is created even if no `super()` was found at all. A new
    ///   strict-mode arrow-function scope, child of the constructor scope, is created and
    ///   returned with `SuperFnInsideConstructor`.
    ///
    /// # Panics
    /// Panics if `constructor` has no body.
    fn replace(
        mut self,
        constructor: &mut Function<'a>,
    ) -> (ScopeId, InstanceInitsInsertLocation<'a>) {
        // A labeled block lets the "top-level `super()` redirected" case skip the fallback
        // below, without the never-looping `loop` + `#[expect(clippy::never_loop)]` hack.
        'scan: {
            let body_stmts = &mut constructor.body.as_mut().unwrap().statements;
            for (index, stmt) in body_stmts.iter_mut().enumerate() {
                if let Statement::ExpressionStatement(expr_stmt) = stmt
                    && let Expression::CallExpression(call_expr) = &mut expr_stmt.expression
                    && let Expression::Super(super_) = &call_expr.callee
                {
                    let span = super_.span;
                    // `super()` nested in the arguments must be found before deciding.
                    self.visit_arguments(&mut call_expr.arguments);
                    if self.super_binding.is_none() {
                        // Common case: a lone top-level `super()`. Insert after it.
                        let insert_location =
                            InstanceInitsInsertLocation::ExistingConstructor(index + 1);
                        return (self.constructor_scope_id, insert_location);
                    }
                    self.replace_super(call_expr, span);
                    break 'scan;
                }
                self.visit_statement(stmt);
            }
            // No top-level `super()` statement: the `_super` function is needed regardless.
            if self.super_binding.is_none() {
                self.super_binding = Some(self.create_super_binding());
            }
        }

        let super_func_scope_id = self.ctx.scoping_mut().add_scope(
            Some(self.constructor_scope_id),
            NodeId::DUMMY,
            ScopeFlags::Function | ScopeFlags::Arrow | ScopeFlags::StrictMode,
        );
        // Both paths out of `'scan` leave `super_binding` set.
        let super_binding = self.super_binding.unwrap();
        let insert_location = InstanceInitsInsertLocation::SuperFnInsideConstructor(super_binding);
        (super_func_scope_id, insert_location)
    }
}
impl<'a> VisitMut<'a> for ConstructorBodySuperReplacer<'a, '_> {
    /// Redirect `super(...)` to `_super(...)`, then walk the call (its arguments included).
    #[inline]
    fn visit_call_expression(&mut self, call_expr: &mut CallExpression<'a>) {
        let super_span = match &call_expr.callee {
            Expression::Super(super_) => Some(super_.span),
            _ => None,
        };
        if let Some(span) = super_span {
            self.replace_super(call_expr, span);
        }
        walk_mut::walk_call_expression(self, call_expr);
    }

    // `super()` inside nested functions, static blocks and TS module blocks does not belong
    // to this constructor, so don't descend into them.
    #[inline]
    fn visit_function(&mut self, _func: &mut Function<'a>, _flags: ScopeFlags) {}

    #[inline]
    fn visit_static_block(&mut self, _block: &mut StaticBlock) {}

    #[inline]
    fn visit_ts_module_block(&mut self, _block: &mut TSModuleBlock<'a>) {}

    /// For fields of nested classes, only decorators and computed keys are visited; the
    /// value is skipped (presumably because only the former are evaluated in this scope).
    #[inline]
    fn visit_property_definition(&mut self, prop: &mut PropertyDefinition<'a>) {
        self.visit_decorators(&mut prop.decorators);
        if !prop.computed {
            return;
        }
        self.visit_property_key(&mut prop.key);
    }

    /// Same treatment as `visit_property_definition`, for accessor properties.
    #[inline]
    fn visit_accessor_property(&mut self, prop: &mut AccessorProperty<'a>) {
        self.visit_decorators(&mut prop.decorators);
        if !prop.computed {
            return;
        }
        self.visit_property_key(&mut prop.key);
    }
}
impl<'a> ConstructorBodySuperReplacer<'a, '_> {
    /// Point `call_expr`'s callee at the `_super` binding (a read reference spanning `span`).
    /// The binding is created on first use.
    fn replace_super(&mut self, call_expr: &mut CallExpression<'a>, span: Span) {
        // Temporarily take the binding out; it is always put back below.
        let binding = match self.super_binding.take() {
            Some(existing) => existing,
            None => self.create_super_binding(),
        };
        call_expr.callee = binding.create_spanned_read_expression(span, self.ctx);
        self.super_binding = Some(binding);
    }

    /// Generate a function-scoped UID from `"super"` in the constructor's scope.
    fn create_super_binding(&mut self) -> BoundIdentifier<'a> {
        let scope_id = self.constructor_scope_id;
        let flags = SymbolFlags::FunctionScopedVariable;
        self.ctx.generate_uid("super", scope_id, flags)
    }
}
/// Visitor which rewrites the names of binding and reference identifiers whose symbol has
/// been renamed.
///
/// `clashing_symbols` maps each renamed `SymbolId` to its new name. The visitor only reads
/// the map, so it holds a shared reference. Callers holding `&mut` can still pass it, since
/// it coerces to `&`.
struct ConstructorSymbolRenamer<'a, 'v> {
    clashing_symbols: &'v FxHashMap<SymbolId, Ident<'a>>,
    ctx: &'v TraverseCtx<'a>,
}

impl<'a, 'v> ConstructorSymbolRenamer<'a, 'v> {
    /// Create a renamer using the `SymbolId` -> new-name map `clashing_symbols`.
    fn new(
        clashing_symbols: &'v FxHashMap<SymbolId, Ident<'a>>,
        ctx: &'v TraverseCtx<'a>,
    ) -> Self {
        Self { clashing_symbols, ctx }
    }
}
impl<'a> VisitMut<'a> for ConstructorSymbolRenamer<'a, '_> {
    /// Give a binding its new name if its symbol is one of the renamed ones.
    fn visit_binding_identifier(&mut self, ident: &mut BindingIdentifier<'a>) {
        if let Some(&renamed) = self.clashing_symbols.get(&ident.symbol_id()) {
            ident.name = renamed;
        }
    }

    /// Give a reference its new name if it resolves to one of the renamed symbols.
    /// Unresolved references are left alone.
    fn visit_identifier_reference(&mut self, ident: &mut IdentifierReference<'a>) {
        let reference = self.ctx.scoping().get_reference(ident.reference_id());
        let Some(symbol_id) = reference.symbol_id() else {
            return;
        };
        if let Some(&renamed) = self.clashing_symbols.get(&symbol_id) {
            ident.name = renamed;
        }
    }
}