use std::{borrow::Cow, mem};
use oxc_allocator::{Box as ArenaBox, StringBuilder as ArenaStringBuilder, TakeIn};
use oxc_ast::{NONE, ast::*};
use oxc_ast_visit::Visit;
use oxc_semantic::{ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags};
use oxc_span::{GetSpan, SPAN};
use oxc_str::Ident;
use oxc_syntax::{
identifier::{is_identifier_name, is_identifier_part, is_identifier_start},
keyword::is_reserved_keyword,
};
use oxc_traverse::{Ancestor, BoundIdentifier, Traverse};
use crate::{
common::helper_loader::{Helper, helper_call_expr},
context::TraverseCtx,
state::TransformState,
utils::sync_function_symbol_flags,
};
/// ES2017 transform that lowers `async` functions into generator functions driven by the
/// `asyncToGenerator` runtime helper.
pub struct AsyncToGenerator<'a> {
    /// Performs the actual rewriting of functions, methods and arrow functions.
    executor: AsyncGeneratorExecutor<'a>,
}

impl AsyncToGenerator<'_> {
    /// Creates the transform; generated code calls `Helper::AsyncToGenerator`.
    pub fn new() -> Self {
        let executor = AsyncGeneratorExecutor::new(Helper::AsyncToGenerator);
        Self { executor }
    }
}
/// Traversal hooks: every rewrite happens on exit, after the node's children (including any
/// nested `await` expressions) have already been processed.
impl<'a> Traverse<'a, TransformState<'a>> for AsyncToGenerator<'a> {
    /// Rewrites `await` expressions inside async functions to `yield`, and replaces async
    /// function expressions / async arrow functions with their helper-driven equivalents.
    fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) {
        let new_expr = match expr {
            Expression::AwaitExpression(await_expr) => {
                Self::transform_await_expression(await_expr, ctx)
            }
            Expression::FunctionExpression(func) => {
                // Async generators (`async function* () {}`) are not handled by this transform.
                if func.r#async && !func.generator && !func.is_typescript_syntax() {
                    Some(self.executor.transform_function_expression(func, ctx))
                } else {
                    None
                }
            }
            Expression::ArrowFunctionExpression(arrow) => {
                // Arrow functions cannot be generators, so `async` alone decides.
                if arrow.r#async {
                    Some(self.executor.transform_arrow_function(arrow, ctx))
                } else {
                    None
                }
            }
            _ => None,
        };
        if let Some(new_expr) = new_expr {
            *expr = new_expr;
        }
    }

    /// Rewrites async function declarations, including ones that are exported (`export
    /// function` / `export default function`). The declaration is transformed in place and a
    /// companion declaration is inserted after the statement.
    fn exit_statement(&mut self, stmt: &mut Statement<'a>, ctx: &mut TraverseCtx<'a>) {
        let function = match stmt {
            Statement::FunctionDeclaration(func) => Some(func),
            Statement::ExportDefaultDeclaration(decl) => {
                if let ExportDefaultDeclarationKind::FunctionDeclaration(func) =
                    &mut decl.declaration
                {
                    Some(func)
                } else {
                    None
                }
            }
            Statement::ExportNamedDeclaration(decl) => {
                if let Some(Declaration::FunctionDeclaration(func)) = &mut decl.declaration {
                    Some(func)
                } else {
                    None
                }
            }
            _ => None,
        };
        if let Some(function) = function
            && function.r#async
            && !function.generator
            && !function.is_typescript_syntax()
        {
            let new_statement = self.executor.transform_function_declaration(function, ctx);
            ctx.state.statement_injector.insert_after(stmt, new_statement);
        }
    }

    /// Rewrites async class methods and async object methods in place.
    fn exit_function(&mut self, func: &mut Function<'a>, ctx: &mut TraverseCtx<'a>) {
        // Async generator methods (`async *m() {}`) are skipped, exactly like the async
        // generator expressions/declarations skipped in `exit_expression`/`exit_statement`.
        // Rewriting them here would turn an async-iterator-returning method into a
        // promise-returning one.
        if func.r#async
            && !func.generator
            && !func.is_typescript_syntax()
            && AsyncGeneratorExecutor::is_class_method_like_ancestor(ctx.parent())
        {
            self.executor.transform_function_for_method_definition(func, ctx);
        }
    }
}
impl<'a> AsyncToGenerator<'a> {
    /// Returns `true` when the nearest enclosing function body belongs to an async
    /// non-generator function, or to an async arrow function.
    ///
    /// Code at the top level (outside any function) always returns `false`.
    fn is_inside_async_function(ctx: &TraverseCtx<'a>) -> bool {
        // Top-level `await` is never transformed, so skip the ancestor walk entirely.
        if ctx.current_scope_flags().is_top() {
            return false;
        }
        for ancestor in ctx.ancestors() {
            match ancestor {
                // Async generators are never wrapped by this transform (`exit_expression` and
                // `exit_statement` skip them), so their `await`s must not be turned into bare
                // `yield`s here either.
                Ancestor::FunctionBody(func) => return *func.r#async() && !*func.generator(),
                Ancestor::ArrowFunctionExpressionBody(func) => {
                    return *func.r#async();
                }
                _ => {}
            }
        }
        false
    }

    /// Converts `await x` into `yield x` when it sits inside a function this transform
    /// wraps. Returns `None` (leave the expression unchanged) otherwise.
    fn transform_await_expression(
        expr: &mut AwaitExpression<'a>,
        ctx: &TraverseCtx<'a>,
    ) -> Option<Expression<'a>> {
        if Self::is_inside_async_function(ctx) {
            Some(ctx.ast.expression_yield(expr.span, false, Some(expr.argument.take_in(ctx.ast))))
        } else {
            None
        }
    }
}
pub struct AsyncGeneratorExecutor<'a> {
helper: Helper,
_marker: std::marker::PhantomData<&'a ()>,
}
impl<'a> AsyncGeneratorExecutor<'a> {
pub fn new(helper: Helper) -> Self {
Self { helper, _marker: std::marker::PhantomData }
}
pub fn transform_function_for_method_definition(
&self,
func: &mut Function<'a>,
ctx: &mut TraverseCtx<'a>,
) {
let Some(body) = func.body.take() else {
return;
};
let needs_move_parameters_to_inner_function =
Self::could_throw_errors_parameters(&func.params);
let (generator_scope_id, wrapper_scope_id) = {
let new_scope_id = ctx.create_child_scope(ctx.current_scope_id(), ScopeFlags::Function);
let scope_id = func.scope_id.replace(Some(new_scope_id)).unwrap();
ctx.scoping_mut().change_scope_parent_id(scope_id, Some(new_scope_id));
if !needs_move_parameters_to_inner_function {
Self::move_formal_parameters_to_target_scope(new_scope_id, &func.params, ctx);
}
(scope_id, new_scope_id)
};
let params = if needs_move_parameters_to_inner_function {
let new_params = Self::create_placeholder_params(&func.params, wrapper_scope_id, ctx);
mem::replace(&mut func.params, new_params)
} else {
Self::create_empty_params(ctx)
};
let callee = self.create_async_to_generator_call(params, body, generator_scope_id, ctx);
let (callee, arguments) = if needs_move_parameters_to_inner_function {
let property = ctx.ast.identifier_name(SPAN, "apply");
let callee =
Expression::from(ctx.ast.member_expression_static(SPAN, callee, property, false));
let this_argument = Argument::from(ctx.ast.expression_this(SPAN));
let arguments_argument = Argument::from(ctx.create_unbound_ident_expr(
SPAN,
ctx.ast.ident("arguments"),
ReferenceFlags::Read,
));
(callee, ctx.ast.vec_from_array([this_argument, arguments_argument]))
} else {
(callee, ctx.ast.vec())
};
let expression = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false);
let statement = ctx.ast.statement_return(SPAN, Some(expression));
func.r#async = false;
func.generator = false;
func.body = Some(ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), ctx.ast.vec1(statement)));
func.scope_id.set(Some(wrapper_scope_id));
sync_function_symbol_flags(func, ctx);
}
pub fn transform_function_expression(
&self,
wrapper_function: &mut Function<'a>,
ctx: &mut TraverseCtx<'a>,
) -> Expression<'a> {
let span = wrapper_function.span;
let body = wrapper_function.body.take().unwrap();
let params = wrapper_function.params.take_in_box(ctx.ast);
let id = wrapper_function.id.take();
let has_function_id = id.is_some();
if !has_function_id && !Self::is_function_length_affected(¶ms) {
return self.create_async_to_generator_call(
params,
body,
wrapper_function.scope_id.take().unwrap(),
ctx,
);
}
let (generator_scope_id, wrapper_scope_id) = {
let wrapper_scope_id =
ctx.create_child_scope(ctx.current_scope_id(), ScopeFlags::Function);
let scope_id = wrapper_function.scope_id.replace(Some(wrapper_scope_id)).unwrap();
ctx.scoping_mut().change_scope_parent_id(scope_id, Some(wrapper_scope_id));
if let Some(id) = id.as_ref() {
Self::move_binding_identifier_to_target_scope(wrapper_scope_id, id, ctx);
let symbol_id = id.symbol_id();
*ctx.scoping_mut().symbol_flags_mut(symbol_id) = SymbolFlags::Function;
}
(scope_id, wrapper_scope_id)
};
let bound_ident = Self::create_bound_identifier(
id.as_ref(),
wrapper_scope_id,
SymbolFlags::FunctionScopedVariable,
ctx,
);
let caller_function = {
let scope_id = ctx.create_child_scope(wrapper_scope_id, ScopeFlags::Function);
let params = Self::create_placeholder_params(¶ms, scope_id, ctx);
let statements = ctx.ast.vec1(Self::create_apply_call_statement(&bound_ident, ctx));
let body = ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), statements);
let (r#type, id) = if id.is_some() {
(FunctionType::FunctionDeclaration, id)
} else {
(
FunctionType::FunctionExpression,
Self::infer_function_id_from_parent_node(scope_id, ctx),
)
};
Self::create_function(r#type, id, params, body, scope_id, ctx)
};
{
let async_to_gen_decl = self.create_async_to_generator_declaration(
&bound_ident,
params,
body,
generator_scope_id,
ctx,
);
let statements = if has_function_id {
let id = caller_function.id.as_ref().unwrap();
let reference = ctx.create_bound_ident_expr(
SPAN,
id.name,
id.symbol_id(),
ReferenceFlags::Read,
);
let func_decl = Statement::FunctionDeclaration(caller_function);
let statement_return = ctx.ast.statement_return(SPAN, Some(reference));
ctx.ast.vec_from_array([async_to_gen_decl, func_decl, statement_return])
} else {
let statement_return = ctx
.ast
.statement_return(SPAN, Some(Expression::FunctionExpression(caller_function)));
ctx.ast.vec_from_array([async_to_gen_decl, statement_return])
};
debug_assert!(wrapper_function.body.is_none());
wrapper_function.r#async = false;
wrapper_function.generator = false;
wrapper_function.body.replace(ctx.ast.alloc_function_body(
SPAN,
ctx.ast.vec(),
statements,
));
}
let callee = Expression::FunctionExpression(wrapper_function.take_in_box(ctx.ast));
ctx.ast.expression_call_with_pure(span, callee, NONE, ctx.ast.vec(), false, true)
}
pub fn transform_function_declaration(
&self,
wrapper_function: &mut Function<'a>,
ctx: &mut TraverseCtx<'a>,
) -> Statement<'a> {
let (generator_scope_id, wrapper_scope_id) = {
let wrapper_scope_id =
ctx.create_child_scope(ctx.current_scope_id(), ScopeFlags::Function);
let scope_id = wrapper_function.scope_id.replace(Some(wrapper_scope_id)).unwrap();
ctx.scoping_mut().change_scope_parent_id(scope_id, Some(wrapper_scope_id));
(scope_id, wrapper_scope_id)
};
let body = wrapper_function.body.take().unwrap();
let params =
Self::create_placeholder_params(&wrapper_function.params, wrapper_scope_id, ctx);
let params = mem::replace(&mut wrapper_function.params, params);
let bound_ident = Self::create_bound_identifier(
wrapper_function.id.as_ref(),
ctx.current_scope_id(),
SymbolFlags::Function,
ctx,
);
{
wrapper_function.r#async = false;
wrapper_function.generator = false;
sync_function_symbol_flags(wrapper_function, ctx);
let statements = ctx.ast.vec1(Self::create_apply_call_statement(&bound_ident, ctx));
debug_assert!(wrapper_function.body.is_none());
wrapper_function.body.replace(ctx.ast.alloc_function_body(
SPAN,
ctx.ast.vec(),
statements,
));
}
{
let statements = ctx.ast.vec_from_array([
self.create_async_to_generator_assignment(
&bound_ident,
params,
body,
generator_scope_id,
ctx,
),
Self::create_apply_call_statement(&bound_ident, ctx),
]);
let body = ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), statements);
let scope_id = ctx.create_child_scope(ctx.current_scope_id(), ScopeFlags::Function);
ctx.scoping_mut().change_scope_parent_id(generator_scope_id, Some(scope_id));
let params = Self::create_empty_params(ctx);
let id = Some(bound_ident.create_binding_identifier(ctx));
let caller_function = Self::create_function(
FunctionType::FunctionDeclaration,
id,
params,
body,
scope_id,
ctx,
);
Statement::FunctionDeclaration(caller_function)
}
}
pub(self) fn transform_arrow_function(
&self,
arrow: &mut ArrowFunctionExpression<'a>,
ctx: &mut TraverseCtx<'a>,
) -> Expression<'a> {
let arrow_span = arrow.span;
let mut body = arrow.body.take_in_box(ctx.ast);
if arrow.expression {
let statement = body.statements.first_mut().unwrap();
let expression = match statement {
Statement::ExpressionStatement(es) => es.expression.take_in(ctx.ast),
_ => unreachable!(),
};
*statement = ctx.ast.statement_return(expression.span(), Some(expression));
}
let params = arrow.params.take_in_box(ctx.ast);
let generator_function_id = arrow.scope_id();
ctx.scoping_mut().scope_flags_mut(generator_function_id).remove(ScopeFlags::Arrow);
let function_name = Self::infer_function_name_from_parent_node(ctx);
if function_name.is_none() && !Self::is_function_length_affected(¶ms) {
return self.create_async_to_generator_call(params, body, generator_function_id, ctx);
}
let wrapper_scope_id = ctx.create_child_scope(ctx.current_scope_id(), ScopeFlags::Function);
ctx.scoping_mut().change_scope_parent_id(generator_function_id, Some(wrapper_scope_id));
let bound_ident = Self::create_bound_identifier(
None,
wrapper_scope_id,
SymbolFlags::FunctionScopedVariable,
ctx,
);
let caller_function = {
let scope_id = ctx.create_child_scope(wrapper_scope_id, ScopeFlags::Function);
let params = Self::create_placeholder_params(¶ms, scope_id, ctx);
let statements = ctx.ast.vec1(Self::create_apply_call_statement(&bound_ident, ctx));
let body = ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), statements);
let id = function_name.map(|name| {
ctx.generate_binding(name, scope_id, SymbolFlags::Function)
.create_binding_identifier(ctx)
});
let function = Self::create_function(
FunctionType::FunctionExpression,
id,
params,
body,
scope_id,
ctx,
);
let argument = Some(Expression::FunctionExpression(function));
ctx.ast.statement_return(SPAN, argument)
};
{
let statement = self.create_async_to_generator_declaration(
&bound_ident,
params,
body,
generator_function_id,
ctx,
);
let statements = ctx.ast.vec_from_array([statement, caller_function]);
let body = ctx.ast.alloc_function_body(SPAN, ctx.ast.vec(), statements);
let params = Self::create_empty_params(ctx);
let wrapper_function = Self::create_function(
FunctionType::FunctionExpression,
None,
params,
body,
wrapper_scope_id,
ctx,
);
let callee = Expression::FunctionExpression(wrapper_function);
ctx.ast.expression_call(arrow_span, callee, NONE, ctx.ast.vec(), false)
}
}
fn infer_function_id_from_parent_node(
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> Option<BindingIdentifier<'a>> {
let name = Self::infer_function_name_from_parent_node(ctx)?;
Some(
ctx.generate_binding(name, scope_id, SymbolFlags::Function)
.create_binding_identifier(ctx),
)
}
fn infer_function_name_from_parent_node(ctx: &TraverseCtx<'a>) -> Option<Ident<'a>> {
match ctx.parent() {
Ancestor::VariableDeclaratorInit(declarator) => {
declarator.id().get_binding_identifier().map(|id| id.name)
}
Ancestor::ObjectPropertyValue(property) if !*property.method() => {
property.key().static_name().map(|key| Self::normalize_function_name(&key, ctx))
}
_ => None,
}
}
fn normalize_function_name(input: &Cow<'a, str>, ctx: &TraverseCtx<'a>) -> Ident<'a> {
let input_str = input.as_ref();
if !is_reserved_keyword(input_str) && is_identifier_name(input_str) {
return ctx.ast.ident_from_cow(input);
}
let mut name = ArenaStringBuilder::with_capacity_in(input_str.len() + 1, ctx.ast.allocator);
let mut capitalize_next = false;
let mut chars = input_str.chars();
if let Some(first) = chars.next()
&& is_identifier_start(first)
{
name.push(first);
}
for c in chars {
if c == ' ' {
name.push('_');
} else if !is_identifier_part(c) {
capitalize_next = true;
} else if capitalize_next {
name.push(c.to_ascii_uppercase());
capitalize_next = false;
} else {
name.push(c);
}
}
if name.is_empty() {
return ctx.ast.ident("_");
}
if is_reserved_keyword(name.as_str()) {
name.push_ascii_byte_start(b'_');
}
Ident::from(name)
}
#[inline]
fn create_function(
r#type: FunctionType,
id: Option<BindingIdentifier<'a>>,
params: ArenaBox<'a, FormalParameters<'a>>,
body: ArenaBox<'a, FunctionBody<'a>>,
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> ArenaBox<'a, Function<'a>> {
let function = ctx.ast.alloc_function_with_scope_id(
SPAN,
r#type,
id,
false,
false,
false,
NONE,
NONE,
params,
NONE,
Some(body),
scope_id,
);
sync_function_symbol_flags(&function, ctx);
function
}
fn create_apply_call_statement(
bound_ident: &BoundIdentifier<'a>,
ctx: &mut TraverseCtx<'a>,
) -> Statement<'a> {
let arguments = ctx.ast.ident("arguments");
let symbol_id = ctx.scoping().find_binding(ctx.current_scope_id(), arguments);
let arguments_ident =
Argument::from(ctx.create_ident_expr(SPAN, arguments, symbol_id, ReferenceFlags::Read));
let this = Argument::from(ctx.ast.expression_this(SPAN));
let arguments = ctx.ast.vec_from_array([this, arguments_ident]);
let callee = Expression::from(ctx.ast.member_expression_static(
SPAN,
bound_ident.create_read_expression(ctx),
ctx.ast.identifier_name(SPAN, "apply"),
false,
));
let argument = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false);
ctx.ast.statement_return(SPAN, Some(argument))
}
fn create_async_to_generator_call(
&self,
params: ArenaBox<'a, FormalParameters<'a>>,
body: ArenaBox<'a, FunctionBody<'a>>,
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> Expression<'a> {
let mut function = Self::create_function(
FunctionType::FunctionExpression,
None,
params,
body,
scope_id,
ctx,
);
function.generator = true;
let arguments = ctx.ast.vec1(Argument::FunctionExpression(function));
helper_call_expr(self.helper, arguments, ctx)
}
fn create_async_to_generator_declaration(
&self,
bound_ident: &BoundIdentifier<'a>,
params: ArenaBox<'a, FormalParameters<'a>>,
body: ArenaBox<'a, FunctionBody<'a>>,
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> Statement<'a> {
let init = self.create_async_to_generator_call(params, body, scope_id, ctx);
let declarations = ctx.ast.vec1(ctx.ast.variable_declarator(
SPAN,
VariableDeclarationKind::Var,
bound_ident.create_binding_pattern(ctx),
NONE,
Some(init),
false,
));
Statement::from(ctx.ast.declaration_variable(
SPAN,
VariableDeclarationKind::Var,
declarations,
false,
))
}
fn create_async_to_generator_assignment(
&self,
bound: &BoundIdentifier<'a>,
params: ArenaBox<'a, FormalParameters<'a>>,
body: ArenaBox<'a, FunctionBody<'a>>,
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> Statement<'a> {
let right = self.create_async_to_generator_call(params, body, scope_id, ctx);
let expression = ctx.ast.expression_assignment(
SPAN,
AssignmentOperator::Assign,
bound.create_write_target(ctx),
right,
);
ctx.ast.statement_expression(SPAN, expression)
}
fn create_placeholder_params(
params: &FormalParameters<'a>,
scope_id: ScopeId,
ctx: &mut TraverseCtx<'a>,
) -> ArenaBox<'a, FormalParameters<'a>> {
let mut parameters = ctx.ast.vec_with_capacity(params.items.len());
for param in ¶ms.items {
if param.initializer.is_some() {
break;
}
let binding = ctx.generate_uid("x", scope_id, SymbolFlags::FunctionScopedVariable);
parameters.push(
ctx.ast.plain_formal_parameter(param.span(), binding.create_binding_pattern(ctx)),
);
}
ctx.ast.alloc_formal_parameters(
SPAN,
FormalParameterKind::FormalParameter,
parameters,
NONE,
)
}
#[inline]
fn create_empty_params(ctx: &TraverseCtx<'a>) -> ArenaBox<'a, FormalParameters<'a>> {
ctx.ast.alloc_formal_parameters(
SPAN,
FormalParameterKind::FormalParameter,
ctx.ast.vec(),
NONE,
)
}
#[inline]
fn create_bound_identifier(
id: Option<&BindingIdentifier<'a>>,
scope_id: ScopeId,
flags: SymbolFlags,
ctx: &mut TraverseCtx<'a>,
) -> BoundIdentifier<'a> {
ctx.generate_uid(id.as_ref().map_or_else(|| "ref", |id| id.name.as_str()), scope_id, flags)
}
pub(crate) fn is_class_method_like_ancestor(ancestor: Ancestor) -> bool {
match ancestor {
Ancestor::MethodDefinitionValue(_) => true,
Ancestor::ObjectPropertyValue(property) => *property.method(),
_ => false,
}
}
#[inline]
fn is_function_length_affected(params: &FormalParameters<'_>) -> bool {
params.items.first().is_some_and(|param| param.initializer.is_none())
}
#[inline]
fn could_throw_errors_parameters(params: &FormalParameters<'a>) -> bool {
params.items.iter().any(|param| {
param
.initializer
.as_ref()
.is_some_and(|init| Self::could_potentially_throw_error_expression(init))
})
}
#[inline]
fn could_potentially_throw_error_expression(expr: &Expression<'a>) -> bool {
!(matches!(
expr,
Expression::NullLiteral(_)
| Expression::BooleanLiteral(_)
| Expression::NumericLiteral(_)
| Expression::StringLiteral(_)
| Expression::BigIntLiteral(_)
| Expression::ArrowFunctionExpression(_)
| Expression::FunctionExpression(_)
) || expr.is_undefined())
}
#[inline]
fn move_formal_parameters_to_target_scope(
target_scope_id: ScopeId,
params: &FormalParameters<'a>,
ctx: &mut TraverseCtx<'a>,
) {
BindingMover::new(target_scope_id, ctx).visit_formal_parameters(params);
}
#[inline]
fn move_binding_identifier_to_target_scope(
target_scope_id: ScopeId,
ident: &BindingIdentifier<'a>,
ctx: &mut TraverseCtx<'a>,
) {
BindingMover::new(target_scope_id, ctx).visit_binding_identifier(ident);
}
}
/// Visitor that moves every binding it visits into `target_scope_id`.
///
/// It does not descend into nested functions, arrow functions or classes. Instead, their
/// scopes are re-parented under the target scope, so the bindings inside them stay where
/// they are.
struct BindingMover<'a, 'ctx> {
    ctx: &'ctx mut TraverseCtx<'a>,
    target_scope_id: ScopeId,
}

impl<'a, 'ctx> BindingMover<'a, 'ctx> {
    fn new(target_scope_id: ScopeId, ctx: &'ctx mut TraverseCtx<'a>) -> Self {
        Self { ctx, target_scope_id }
    }

    /// Makes `scope_id` a child of the target scope.
    fn move_scope_to_target(&mut self, scope_id: ScopeId) {
        self.ctx.scoping_mut().change_scope_parent_id(scope_id, Some(self.target_scope_id));
    }
}

impl<'a> Visit<'a> for BindingMover<'a, '_> {
    fn visit_formal_parameter(&mut self, param: &FormalParameter<'a>) {
        self.visit_binding_pattern(&param.pattern);
        // Default values may contain functions/classes whose scopes must be re-parented.
        if let Some(initializer) = &param.initializer {
            self.visit_expression(initializer);
        }
    }

    fn visit_formal_parameter_rest(&mut self, param: &FormalParameterRest<'a>) {
        self.visit_binding_rest_element(&param.rest);
    }

    fn visit_assignment_pattern(&mut self, pattern: &AssignmentPattern<'a>) {
        self.visit_binding_pattern(&pattern.left);
        self.visit_expression(&pattern.right);
    }

    fn visit_binding_property(&mut self, property: &BindingProperty<'a>) {
        // Only computed keys are expressions; a static key declares no binding.
        if property.computed {
            self.visit_property_key(&property.key);
        }
        self.visit_binding_pattern(&property.value);
    }

    #[inline]
    fn visit_function(&mut self, func: &Function<'a>, _flags: ScopeFlags) {
        self.move_scope_to_target(func.scope_id());
    }

    #[inline]
    fn visit_arrow_function_expression(&mut self, func: &ArrowFunctionExpression<'a>) {
        self.move_scope_to_target(func.scope_id());
    }

    #[inline]
    fn visit_class(&mut self, class: &Class<'a>) {
        self.visit_decorators(&class.decorators);
        self.move_scope_to_target(class.scope_id());
    }

    /// Moves the symbol from whichever scope currently owns it to the target scope.
    fn visit_binding_identifier(&mut self, ident: &BindingIdentifier<'a>) {
        let symbol_id = ident.symbol_id();
        let current_scope_id = self.ctx.scoping().symbol_scope_id(symbol_id);
        self.ctx.scoping_mut().move_binding_by_symbol_id(
            current_scope_id,
            self.target_scope_id,
            symbol_id,
        );
    }
}