oxc_transformer/jsx/
refresh.rs

1use std::{collections::hash_map::Entry, iter, str};
2
3use base64::{
4    encoded_len as base64_encoded_len,
5    prelude::{BASE64_STANDARD, Engine},
6};
7use rustc_hash::{FxHashMap, FxHashSet};
8use sha1::{Digest, Sha1};
9
10use oxc_allocator::{
11    Address, CloneIn, GetAddress, StringBuilder as ArenaStringBuilder, TakeIn, Vec as ArenaVec,
12};
13use oxc_ast::{AstBuilder, NONE, ast::*, match_expression};
14use oxc_ast_visit::{
15    Visit,
16    walk::{walk_call_expression, walk_declaration},
17};
18use oxc_semantic::{ReferenceFlags, ScopeFlags, ScopeId, SymbolFlags, SymbolId};
19use oxc_span::{Atom, GetSpan, SPAN};
20use oxc_syntax::operator::AssignmentOperator;
21use oxc_traverse::{Ancestor, BoundIdentifier, Traverse};
22
23use crate::{
24    context::{TransformCtx, TraverseCtx},
25    state::TransformState,
26};
27
28use super::options::ReactRefreshOptions;
29
30/// Parse a string into a `RefreshIdentifierResolver` and convert it into an `Expression`
31#[derive(Debug)]
32enum RefreshIdentifierResolver<'a> {
33    /// Simple IdentifierReference (e.g. `$RefreshReg$`)
34    Identifier(IdentifierReference<'a>),
35    /// StaticMemberExpression (object, property) (e.g. `window.$RefreshReg$`)
36    Member((IdentifierReference<'a>, IdentifierName<'a>)),
37    /// Used for `import.meta` expression (e.g. `import.meta.$RefreshReg$`)
38    Expression(Expression<'a>),
39}
40
41impl<'a> RefreshIdentifierResolver<'a> {
42    /// Parses a string into a RefreshIdentifierResolver
43    pub fn parse(input: &str, ast: AstBuilder<'a>) -> Self {
44        let mut parts = input.split('.');
45
46        let first_part = parts.next().unwrap();
47        let Some(second_part) = parts.next() else {
48            // Handle simple identifier reference
49            return Self::Identifier(ast.identifier_reference(SPAN, ast.atom(input)));
50        };
51
52        if first_part == "import" {
53            // Handle `import.meta.$RefreshReg$` expression
54            let mut expr = ast.expression_meta_property(
55                SPAN,
56                ast.identifier_name(SPAN, "import"),
57                ast.identifier_name(SPAN, ast.atom(second_part)),
58            );
59            if let Some(property) = parts.next() {
60                expr = Expression::from(ast.member_expression_static(
61                    SPAN,
62                    expr,
63                    ast.identifier_name(SPAN, ast.atom(property)),
64                    false,
65                ));
66            }
67            return Self::Expression(expr);
68        }
69
70        // Handle `window.$RefreshReg$` member expression
71        let object = ast.identifier_reference(SPAN, ast.atom(first_part));
72        let property = ast.identifier_name(SPAN, ast.atom(second_part));
73        Self::Member((object, property))
74    }
75
76    /// Converts the RefreshIdentifierResolver into an Expression
77    pub fn to_expression(&self, ctx: &mut TraverseCtx<'a>) -> Expression<'a> {
78        match self {
79            Self::Identifier(ident) => {
80                let reference_id = ctx.create_unbound_reference(&ident.name, ReferenceFlags::Read);
81                ctx.ast.expression_identifier_with_reference_id(
82                    ident.span,
83                    ident.name,
84                    reference_id,
85                )
86            }
87            Self::Member((ident, property)) => {
88                let reference_id = ctx.create_unbound_reference(&ident.name, ReferenceFlags::Read);
89                let ident = ctx.ast.expression_identifier_with_reference_id(
90                    ident.span,
91                    ident.name,
92                    reference_id,
93                );
94                Expression::from(ctx.ast.member_expression_static(
95                    SPAN,
96                    ident,
97                    property.clone(),
98                    false,
99                ))
100            }
101            Self::Expression(expr) => expr.clone_in(ctx.ast.allocator),
102        }
103    }
104}
105
106/// React Fast Refresh
107///
108/// Transform React functional components to integrate Fast Refresh.
109///
110/// References:
111///
112/// * <https://github.com/facebook/react/issues/16604#issuecomment-528663101>
113/// * <https://github.com/facebook/react/blob/v18.3.1/packages/react-refresh/src/ReactFreshBabelPlugin.js>
114pub struct ReactRefresh<'a, 'ctx> {
115    refresh_reg: RefreshIdentifierResolver<'a>,
116    refresh_sig: RefreshIdentifierResolver<'a>,
117    emit_full_signatures: bool,
118    ctx: &'ctx TransformCtx<'a>,
119    // States
120    registrations: Vec<(BoundIdentifier<'a>, Atom<'a>)>,
121    /// Used to wrap call expression with signature.
122    /// (eg: hoc(() => {}) -> _s1(hoc(_s1(() => {}))))
123    last_signature: Option<(BindingIdentifier<'a>, ArenaVec<'a, Argument<'a>>)>,
124    // (function_scope_id, key)
125    function_signature_keys: FxHashMap<ScopeId, String>,
126    non_builtin_hooks_callee: FxHashMap<ScopeId, Vec<Option<Expression<'a>>>>,
127    /// Used to determine which bindings are used in JSX calls.
128    used_in_jsx_bindings: FxHashSet<SymbolId>,
129}
130
131impl<'a, 'ctx> ReactRefresh<'a, 'ctx> {
132    pub fn new(
133        options: &ReactRefreshOptions,
134        ast: AstBuilder<'a>,
135        ctx: &'ctx TransformCtx<'a>,
136    ) -> Self {
137        Self {
138            refresh_reg: RefreshIdentifierResolver::parse(&options.refresh_reg, ast),
139            refresh_sig: RefreshIdentifierResolver::parse(&options.refresh_sig, ast),
140            emit_full_signatures: options.emit_full_signatures,
141            registrations: Vec::default(),
142            ctx,
143            last_signature: None,
144            function_signature_keys: FxHashMap::default(),
145            non_builtin_hooks_callee: FxHashMap::default(),
146            used_in_jsx_bindings: FxHashSet::default(),
147        }
148    }
149}
150
151impl<'a> Traverse<'a, TransformState<'a>> for ReactRefresh<'a, '_> {
152    fn enter_program(&mut self, program: &mut Program<'a>, ctx: &mut TraverseCtx<'a>) {
153        self.used_in_jsx_bindings = UsedInJSXBindingsCollector::collect(program, ctx);
154
155        let mut new_statements = ctx.ast.vec_with_capacity(program.body.len() * 2);
156        for mut statement in program.body.take_in(ctx.ast) {
157            let next_statement = self.process_statement(&mut statement, ctx);
158            new_statements.push(statement);
159            if let Some(assignment_expression) = next_statement {
160                new_statements.push(assignment_expression);
161            }
162        }
163        program.body = new_statements;
164    }
165
166    fn exit_program(&mut self, program: &mut Program<'a>, ctx: &mut TraverseCtx<'a>) {
167        if self.registrations.is_empty() {
168            return;
169        }
170
171        let var_decl = Statement::from(ctx.ast.declaration_variable(
172            SPAN,
173            VariableDeclarationKind::Var,
174            ctx.ast.vec(), // This is replaced at the end
175            false,
176        ));
177
178        let mut variable_declarator_items = ctx.ast.vec_with_capacity(self.registrations.len());
179        let calls = self.registrations.iter().map(|(binding, persistent_id)| {
180            variable_declarator_items.push(ctx.ast.variable_declarator(
181                SPAN,
182                VariableDeclarationKind::Var,
183                binding.create_binding_pattern(ctx),
184                None,
185                false,
186            ));
187
188            let callee = self.refresh_reg.to_expression(ctx);
189            let arguments = ctx.ast.vec_from_array([
190                Argument::from(binding.create_read_expression(ctx)),
191                Argument::from(ctx.ast.expression_string_literal(SPAN, *persistent_id, None)),
192            ]);
193            ctx.ast.statement_expression(
194                SPAN,
195                ctx.ast.expression_call(SPAN, callee, NONE, arguments, false),
196            )
197        });
198
199        let var_decl_index = program.body.len();
200        program.body.extend(iter::once(var_decl).chain(calls));
201
202        let Statement::VariableDeclaration(var_decl) = &mut program.body[var_decl_index] else {
203            unreachable!()
204        };
205        var_decl.declarations = variable_declarator_items;
206    }
207
208    fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) {
209        let signature = match expr {
210            Expression::FunctionExpression(func) => self.create_signature_call_expression(
211                func.scope_id(),
212                func.body.as_mut().unwrap(),
213                ctx,
214            ),
215            Expression::ArrowFunctionExpression(arrow) => {
216                let call_fn =
217                    self.create_signature_call_expression(arrow.scope_id(), &mut arrow.body, ctx);
218
219                // If the signature is found, we will push a new statement to the arrow function body. So it's not an expression anymore.
220                if call_fn.is_some() {
221                    Self::transform_arrow_function_to_block(arrow, ctx);
222                }
223                call_fn
224            }
225            // hoc(_c = function() { })
226            Expression::AssignmentExpression(_) => return,
227            // hoc1(hoc2(...))
228            Expression::CallExpression(_) => self.last_signature.take(),
229            _ => None,
230        };
231
232        let Some((binding_identifier, mut arguments)) = signature else {
233            return;
234        };
235        let binding = BoundIdentifier::from_binding_ident(&binding_identifier);
236
237        if !matches!(expr, Expression::CallExpression(_)) {
238            // Try to get binding from parent VariableDeclarator
239            if let Ancestor::VariableDeclaratorInit(declarator) = ctx.parent()
240                && let Some(ident) = declarator.id().get_binding_identifier()
241            {
242                let id_binding = BoundIdentifier::from_binding_ident(ident);
243                self.handle_function_in_variable_declarator(&id_binding, &binding, arguments, ctx);
244                return;
245            }
246        }
247
248        let mut found_call_expression = false;
249        for ancestor in ctx.ancestors() {
250            if ancestor.is_assignment_expression() {
251                continue;
252            }
253            if ancestor.is_call_expression() {
254                found_call_expression = true;
255            }
256            break;
257        }
258
259        if found_call_expression {
260            self.last_signature =
261                Some((binding_identifier.clone(), arguments.clone_in(ctx.ast.allocator)));
262        }
263
264        arguments.insert(0, Argument::from(expr.take_in(ctx.ast)));
265        *expr = ctx.ast.expression_call(
266            SPAN,
267            binding.create_read_expression(ctx),
268            NONE,
269            arguments,
270            false,
271        );
272    }
273
274    fn exit_function(&mut self, func: &mut Function<'a>, ctx: &mut TraverseCtx<'a>) {
275        if !func.is_function_declaration() {
276            return;
277        }
278
279        let Some((binding_identifier, mut arguments)) = self.create_signature_call_expression(
280            func.scope_id(),
281            func.body.as_mut().unwrap(),
282            ctx,
283        ) else {
284            return;
285        };
286
287        let Some(id) = func.id.as_ref() else {
288            return;
289        };
290        let id_binding = BoundIdentifier::from_binding_ident(id);
291
292        arguments.insert(0, Argument::from(id_binding.create_read_expression(ctx)));
293
294        let binding = BoundIdentifier::from_binding_ident(&binding_identifier);
295        let callee = binding.create_read_expression(ctx);
296        let expr = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false);
297        let statement = ctx.ast.statement_expression(SPAN, expr);
298
299        // Get the address of the statement containing this `FunctionDeclaration`
300        let address = match ctx.parent() {
301            // For `export function Foo() {}`
302            // which is a `Statement::ExportNamedDeclaration`
303            Ancestor::ExportNamedDeclarationDeclaration(decl) => decl.address(),
304            // For `export default function() {}`
305            // which is a `Statement::ExportDefaultDeclaration`
306            Ancestor::ExportDefaultDeclarationDeclaration(decl) => decl.address(),
307            // Otherwise just a `function Foo() {}`
308            // which is a `Statement::FunctionDeclaration`.
309            // `Function` is always stored in a `Box`, so has a stable memory address.
310            _ => Address::from_ptr(func),
311        };
312        self.ctx.statement_injector.insert_after(&address, statement);
313    }
314
315    fn enter_call_expression(
316        &mut self,
317        call_expr: &mut CallExpression<'a>,
318        ctx: &mut TraverseCtx<'a>,
319    ) {
320        let current_scope_id = ctx.current_scope_id();
321        if !ctx.scoping().scope_flags(current_scope_id).is_function() {
322            return;
323        }
324
325        let hook_name = match &call_expr.callee {
326            Expression::Identifier(ident) => ident.name,
327            Expression::StaticMemberExpression(member) => member.property.name,
328            _ => return,
329        };
330
331        if !is_use_hook_name(&hook_name) {
332            return;
333        }
334
335        if !is_builtin_hook(&hook_name) {
336            // Check if a corresponding binding exists where we emit the signature.
337            let (binding_name, is_member_expression) = match &call_expr.callee {
338                Expression::Identifier(ident) => (Some(ident.name), false),
339                Expression::StaticMemberExpression(member) => {
340                    if let Expression::Identifier(object) = &member.object {
341                        (Some(object.name), true)
342                    } else {
343                        (None, false)
344                    }
345                }
346                _ => unreachable!(),
347            };
348
349            if let Some(binding_name) = binding_name {
350                self.non_builtin_hooks_callee.entry(current_scope_id).or_default().push(
351                    ctx.scoping()
352                        .find_binding(
353                            ctx.scoping().scope_parent_id(ctx.current_scope_id()).unwrap(),
354                            binding_name.as_str(),
355                        )
356                        .map(|symbol_id| {
357                            let mut expr = ctx.create_bound_ident_expr(
358                                SPAN,
359                                binding_name,
360                                symbol_id,
361                                ReferenceFlags::Read,
362                            );
363
364                            if is_member_expression {
365                                // binding_name.hook_name
366                                expr = Expression::from(ctx.ast.member_expression_static(
367                                    SPAN,
368                                    expr,
369                                    ctx.ast.identifier_name(SPAN, hook_name),
370                                    false,
371                                ));
372                            }
373                            expr
374                        }),
375                );
376            }
377        }
378
379        let declarator_id = if let Ancestor::VariableDeclaratorInit(declarator) = ctx.parent() {
380            // TODO: if there is no LHS, consider some other heuristic.
381            declarator.id().span().source_text(self.ctx.source_text)
382        } else {
383            ""
384        };
385
386        let args = &call_expr.arguments;
387        let (args_key, mut key_len) = if hook_name == "useState" && !args.is_empty() {
388            let args_key = args[0].span().source_text(self.ctx.source_text);
389            (args_key, args_key.len() + 4)
390        } else if hook_name == "useReducer" && args.len() > 1 {
391            let args_key = args[1].span().source_text(self.ctx.source_text);
392            (args_key, args_key.len() + 4)
393        } else {
394            ("", 2)
395        };
396
397        key_len += hook_name.len() + declarator_id.len();
398
399        let string = match self.function_signature_keys.entry(current_scope_id) {
400            Entry::Occupied(entry) => {
401                let string = entry.into_mut();
402                string.reserve(key_len + 2);
403                string.push_str("\\n");
404                string
405            }
406            Entry::Vacant(entry) => entry.insert(String::with_capacity(key_len)),
407        };
408
409        // `hook_name{{declarator_id(args_key)}}` or `hook_name{{declarator_id}}`
410        let old_len = string.len();
411
412        string.push_str(&hook_name);
413        string.push('{');
414        string.push_str(declarator_id);
415        if !args_key.is_empty() {
416            string.push('(');
417            string.push_str(args_key);
418            string.push(')');
419        }
420        string.push('}');
421
422        debug_assert_eq!(key_len, string.len() - old_len);
423    }
424}
425
426// Internal Methods
427impl<'a> ReactRefresh<'a, '_> {
428    fn create_registration(
429        &mut self,
430        persistent_id: Atom<'a>,
431        ctx: &mut TraverseCtx<'a>,
432    ) -> AssignmentTarget<'a> {
433        let binding = ctx.generate_uid_in_root_scope("c", SymbolFlags::FunctionScopedVariable);
434        let target = binding.create_target(ReferenceFlags::Write, ctx);
435        self.registrations.push((binding, persistent_id));
436        target
437    }
438
439    /// Similar to the `findInnerComponents` function in `react-refresh/babel`.
440    fn replace_inner_components(
441        &mut self,
442        inferred_name: &str,
443        expr: &mut Expression<'a>,
444        is_variable_declarator: bool,
445        ctx: &mut TraverseCtx<'a>,
446    ) -> bool {
447        match expr {
448            Expression::Identifier(ident) => {
449                // For case like:
450                // export const Something = hoc(Foo)
451                // we don't want to wrap Foo inside the call.
452                // Instead we assume it's registered at definition.
453                return is_componentish_name(&ident.name);
454            }
455            Expression::FunctionExpression(_) => {}
456            Expression::ArrowFunctionExpression(arrow) => {
457                // Don't transform `() => () => {}`
458                if arrow
459                    .get_expression()
460                    .is_some_and(|expr| matches!(expr, Expression::ArrowFunctionExpression(_)))
461                {
462                    return false;
463                }
464            }
465            Expression::CallExpression(call_expr) => {
466                let allowed_callee = matches!(
467                    call_expr.callee,
468                    Expression::Identifier(_)
469                        | Expression::ComputedMemberExpression(_)
470                        | Expression::StaticMemberExpression(_)
471                );
472
473                if allowed_callee {
474                    let callee_span = call_expr.callee.span();
475
476                    let Some(argument_expr) =
477                        call_expr.arguments.first_mut().and_then(|e| e.as_expression_mut())
478                    else {
479                        return false;
480                    };
481
482                    let found_inside = self.replace_inner_components(
483                        format!(
484                            "{}${}",
485                            inferred_name,
486                            callee_span.source_text(self.ctx.source_text)
487                        )
488                        .as_str(),
489                        argument_expr,
490                        /* is_variable_declarator */ false,
491                        ctx,
492                    );
493
494                    if !found_inside {
495                        return false;
496                    }
497
498                    // const Foo = hoc1(hoc2(() => {}))
499                    // export default memo(React.forwardRef(function() {}))
500                    if is_variable_declarator {
501                        return true;
502                    }
503                } else {
504                    return false;
505                }
506            }
507            _ => {
508                return false;
509            }
510        }
511
512        if !is_variable_declarator {
513            *expr = ctx.ast.expression_assignment(
514                SPAN,
515                AssignmentOperator::Assign,
516                self.create_registration(ctx.ast.atom(inferred_name), ctx),
517                expr.take_in(ctx.ast),
518            );
519        }
520
521        true
522    }
523
524    /// _c = id.name;
525    fn create_assignment_expression(
526        &mut self,
527        id: &BindingIdentifier<'a>,
528        ctx: &mut TraverseCtx<'a>,
529    ) -> Statement<'a> {
530        let left = self.create_registration(id.name, ctx);
531        let right =
532            ctx.create_bound_ident_expr(SPAN, id.name, id.symbol_id(), ReferenceFlags::Read);
533        let expr = ctx.ast.expression_assignment(SPAN, AssignmentOperator::Assign, left, right);
534        ctx.ast.statement_expression(SPAN, expr)
535    }
536
537    fn create_signature_call_expression(
538        &mut self,
539        scope_id: ScopeId,
540        body: &mut FunctionBody<'a>,
541        ctx: &mut TraverseCtx<'a>,
542    ) -> Option<(BindingIdentifier<'a>, ArenaVec<'a, Argument<'a>>)> {
543        let key = self.function_signature_keys.remove(&scope_id)?;
544
545        let key = if self.emit_full_signatures {
546            ctx.ast.atom(&key)
547        } else {
548            // Prefer to hash when we can (e.g. outside of ASTExplorer).
549            // This makes it deterministically compact, even if there's
550            // e.g. a useState initializer with some code inside.
551            // We also need it for www that has transforms like cx()
552            // that don't understand if something is part of a string.
553            const SHA1_HASH_LEN: usize = 20;
554            const ENCODED_LEN: usize = {
555                let len = base64_encoded_len(SHA1_HASH_LEN, true);
556                match len {
557                    Some(l) => l,
558                    None => panic!("Invalid base64 length"),
559                }
560            };
561
562            let mut hasher = Sha1::new();
563            hasher.update(&key);
564            let hash = hasher.finalize();
565            debug_assert_eq!(hash.len(), SHA1_HASH_LEN);
566
567            // Encode to base64 string directly in arena, without an intermediate string allocation
568            #[expect(clippy::items_after_statements)]
569            const ZEROS_STR: &str = {
570                const ZEROS_BYTES: [u8; ENCODED_LEN] = [0; ENCODED_LEN];
571                match str::from_utf8(&ZEROS_BYTES) {
572                    Ok(s) => s,
573                    Err(_) => unreachable!(),
574                }
575            };
576
577            let mut hashed_key = ArenaStringBuilder::from_str_in(ZEROS_STR, ctx.ast.allocator);
578            // SAFETY: Base64 encoding only produces ASCII bytes. Even if our assumptions are incorrect,
579            // and Base64 bytes do not fill `hashed_key` completely, the remaining bytes are 0, so also ASCII.
580            let hashed_key_bytes = unsafe { hashed_key.as_mut_str().as_bytes_mut() };
581            let encoded_bytes = BASE64_STANDARD.encode_slice(hash, hashed_key_bytes).unwrap();
582            debug_assert_eq!(encoded_bytes, ENCODED_LEN);
583            Atom::from(hashed_key)
584        };
585
586        let callee_list = self.non_builtin_hooks_callee.remove(&scope_id).unwrap_or_default();
587        let callee_len = callee_list.len();
588        let custom_hooks_in_scope = ctx.ast.vec_from_iter(
589            callee_list.into_iter().filter_map(|e| e.map(ArrayExpressionElement::from)),
590        );
591
592        let force_reset = custom_hooks_in_scope.len() != callee_len;
593
594        let mut arguments = ctx.ast.vec();
595        arguments.push(Argument::from(ctx.ast.expression_string_literal(SPAN, key, None)));
596
597        if force_reset || !custom_hooks_in_scope.is_empty() {
598            arguments.push(Argument::from(ctx.ast.expression_boolean_literal(SPAN, force_reset)));
599        }
600
601        if !custom_hooks_in_scope.is_empty() {
602            // function () { return custom_hooks_in_scope }
603            let formal_parameters = ctx.ast.formal_parameters(
604                SPAN,
605                FormalParameterKind::FormalParameter,
606                ctx.ast.vec(),
607                NONE,
608            );
609            let function_body = ctx.ast.function_body(
610                SPAN,
611                ctx.ast.vec(),
612                ctx.ast.vec1(ctx.ast.statement_return(
613                    SPAN,
614                    Some(ctx.ast.expression_array(SPAN, custom_hooks_in_scope)),
615                )),
616            );
617            let scope_id = ctx.create_child_scope_of_current(ScopeFlags::Function);
618            let function =
619                Argument::from(ctx.ast.expression_function_with_scope_id_and_pure_and_pife(
620                    SPAN,
621                    FunctionType::FunctionExpression,
622                    None,
623                    false,
624                    false,
625                    false,
626                    NONE,
627                    NONE,
628                    formal_parameters,
629                    NONE,
630                    Some(function_body),
631                    scope_id,
632                    false,
633                    false,
634                ));
635            arguments.push(function);
636        }
637
638        // _s = refresh_sig();
639        let init = ctx.ast.expression_call(
640            SPAN,
641            self.refresh_sig.to_expression(ctx),
642            NONE,
643            ctx.ast.vec(),
644            false,
645        );
646        let binding = self.ctx.var_declarations.create_uid_var_with_init("s", init, ctx);
647
648        // _s();
649        let call_expression = ctx.ast.statement_expression(
650            SPAN,
651            ctx.ast.expression_call(
652                SPAN,
653                binding.create_read_expression(ctx),
654                NONE,
655                ctx.ast.vec(),
656                false,
657            ),
658        );
659
660        body.statements.insert(0, call_expression);
661
662        // Following is the signature call expression, will be generated in call site.
663        // _s(App, signature_key, false, function() { return [] });
664        //                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ custom hooks only
665        let binding_identifier = binding.create_binding_identifier(ctx);
666        Some((binding_identifier, arguments))
667    }
668
669    fn process_statement(
670        &mut self,
671        statement: &mut Statement<'a>,
672        ctx: &mut TraverseCtx<'a>,
673    ) -> Option<Statement<'a>> {
674        match statement {
675            Statement::VariableDeclaration(variable) => {
676                self.handle_variable_declaration(variable, ctx)
677            }
678            Statement::FunctionDeclaration(func) => self.handle_function_declaration(func, ctx),
679            Statement::ExportNamedDeclaration(export_decl) => {
680                if let Some(declaration) = &mut export_decl.declaration {
681                    match declaration {
682                        Declaration::FunctionDeclaration(func) => {
683                            self.handle_function_declaration(func, ctx)
684                        }
685                        Declaration::VariableDeclaration(variable) => {
686                            self.handle_variable_declaration(variable, ctx)
687                        }
688                        _ => None,
689                    }
690                } else {
691                    None
692                }
693            }
694            Statement::ExportDefaultDeclaration(stmt_decl) => {
695                match &mut stmt_decl.declaration {
696                    declaration @ match_expression!(ExportDefaultDeclarationKind) => {
697                        let expression = declaration.to_expression_mut();
698                        if !matches!(expression, Expression::CallExpression(_)) {
699                            // For now, we only support possible HOC calls here.
700                            // Named function declarations are handled in FunctionDeclaration.
701                            // Anonymous direct exports like export default function() {}
702                            // are currently ignored.
703                            return None;
704                        }
705
706                        // This code path handles nested cases like:
707                        // export default memo(() => {})
708                        // In those cases it is more plausible people will omit names
709                        // so they're worth handling despite possible false positives.
710                        // More importantly, it handles the named case:
711                        // export default memo(function Named() {})
712                        self.replace_inner_components(
713                            "%default%",
714                            expression,
715                            /* is_variable_declarator */ false,
716                            ctx,
717                        );
718
719                        None
720                    }
721                    ExportDefaultDeclarationKind::FunctionDeclaration(func) => {
722                        if let Some(id) = &func.id {
723                            if func.is_typescript_syntax() || !is_componentish_name(&id.name) {
724                                return None;
725                            }
726
727                            return Some(self.create_assignment_expression(id, ctx));
728                        }
729                        None
730                    }
731                    _ => None,
732                }
733            }
734            _ => None,
735        }
736    }
737
738    fn handle_function_declaration(
739        &mut self,
740        func: &Function<'a>,
741        ctx: &mut TraverseCtx<'a>,
742    ) -> Option<Statement<'a>> {
743        let Some(id) = &func.id else {
744            return None;
745        };
746
747        if func.is_typescript_syntax() || !is_componentish_name(&id.name) {
748            return None;
749        }
750
751        Some(self.create_assignment_expression(id, ctx))
752    }
753
754    fn handle_variable_declaration(
755        &mut self,
756        decl: &mut VariableDeclaration<'a>,
757        ctx: &mut TraverseCtx<'a>,
758    ) -> Option<Statement<'a>> {
759        if decl.declarations.len() != 1 {
760            return None;
761        }
762
763        let declarator = decl.declarations.first_mut().unwrap_or_else(|| unreachable!());
764        let init = declarator.init.as_mut()?;
765        let id = declarator.id.get_binding_identifier()?;
766        let symbol_id = id.symbol_id();
767
768        if !is_componentish_name(&id.name) {
769            return None;
770        }
771
772        match init {
773            // Likely component definitions.
774            Expression::ArrowFunctionExpression(arrow) => {
775                // () => () => {}
776                if arrow.get_expression().is_some_and(|expr| matches!(expr, Expression::ArrowFunctionExpression(_))) {
777                    return None;
778                }
779            }
780            Expression::FunctionExpression(_)
781            // Maybe something like styled.div`...`
782            | Expression::TaggedTemplateExpression(_) => {
783                // Special case when a variable would get an inferred name:
784                // let Foo = () => {}
785                // let Foo = function() {}
786                // let Foo = styled.div``;
787                // We'll register it on next line so that
788                // we don't mess up the inferred 'Foo' function name.
789                // (eg: with @babel/plugin-transform-react-display-name or
790                // babel-plugin-styled-components)
791            }
792            Expression::CallExpression(call_expr) => {
793                let is_import_expression = match call_expr.callee.get_inner_expression() {
794                    Expression::ImportExpression(_) => {
795                        true
796                    }
797                    Expression::Identifier(ident) => {
798                        ident.name.starts_with("require")
799                    },
800                    _ => false
801                };
802
803                if is_import_expression {
804                    return None;
805                }
806            }
807            _ => {
808                return None;
809            }
810        }
811
812        // Maybe a HOC.
813        // Try to determine if this is some form of import.
814        let found_inside = self
815            .replace_inner_components(&id.name, init, /* is_variable_declarator */ true, ctx);
816
817        if !found_inside && !self.used_in_jsx_bindings.contains(&symbol_id) {
818            return None;
819        }
820
821        Some(self.create_assignment_expression(id, ctx))
822    }
823
824    /// Handle `export const Foo = () => {}` or `const Foo = function() {}`
825    fn handle_function_in_variable_declarator(
826        &self,
827        id_binding: &BoundIdentifier<'a>,
828        binding: &BoundIdentifier<'a>,
829        mut arguments: ArenaVec<'a, Argument<'a>>,
830        ctx: &mut TraverseCtx<'a>,
831    ) {
832        // Special case when a function would get an inferred name:
833        // let Foo = () => {}
834        // let Foo = function() {}
835        // We'll add signature it on next line so that
836        // we don't mess up the inferred 'Foo' function name.
837
838        // Result: let Foo = () => {}; __signature(Foo, ...);
839        arguments.insert(0, Argument::from(id_binding.create_read_expression(ctx)));
840        let statement = ctx.ast.statement_expression(
841            SPAN,
842            ctx.ast.expression_call(
843                SPAN,
844                binding.create_read_expression(ctx),
845                NONE,
846                arguments,
847                false,
848            ),
849        );
850
851        // Get the address of the statement containing this `VariableDeclarator`
852        let address =
853            if let Ancestor::ExportNamedDeclarationDeclaration(export_decl) = ctx.ancestor(2) {
854                // For `export const Foo = () => {}`
855                // which is a `VariableDeclaration` inside a `Statement::ExportNamedDeclaration`
856                export_decl.address()
857            } else {
858                // Otherwise just a `const Foo = () => {}` which is a `Statement::VariableDeclaration`
859                let var_decl = ctx.ancestor(1);
860                debug_assert!(matches!(var_decl, Ancestor::VariableDeclarationDeclarations(_)));
861                var_decl.address()
862            };
863        self.ctx.statement_injector.insert_after(&address, statement);
864    }
865
866    /// Convert arrow function expression to normal arrow function
867    ///
868    /// ```js
869    /// () => 1
870    /// ```
871    /// to
872    /// ```js
873    /// () => { return 1 }
874    /// ```
875    fn transform_arrow_function_to_block(
876        arrow: &mut ArrowFunctionExpression<'a>,
877        ctx: &TraverseCtx<'a>,
878    ) {
879        if !arrow.expression {
880            return;
881        }
882
883        arrow.expression = false;
884
885        let Some(Statement::ExpressionStatement(statement)) = arrow.body.statements.pop() else {
886            unreachable!("arrow function body is never empty")
887        };
888
889        arrow
890            .body
891            .statements
892            .push(ctx.ast.statement_return(SPAN, Some(statement.unbox().expression)));
893    }
894}
895
/// Returns `true` when `name` starts with an ASCII uppercase letter (`A`-`Z`),
/// React's naming convention for components.
fn is_componentish_name(name: &str) -> bool {
    matches!(name.as_bytes().first(), Some(b'A'..=b'Z'))
}
899
/// Returns `true` for hook-like names: `use` by itself, or `use` followed by an
/// ASCII uppercase letter (e.g. `useState`, but not `user`).
fn is_use_hook_name(name: &str) -> bool {
    name.strip_prefix("use")
        .is_some_and(|rest| rest.bytes().next().is_none_or(|b| b.is_ascii_uppercase()))
}
903
/// Returns `true` if `hook_name` is exactly one of React's built-in hooks.
fn is_builtin_hook(hook_name: &str) -> bool {
    const BUILTIN_HOOKS: [&str; 19] = [
        "useState",
        "useReducer",
        "useEffect",
        "useLayoutEffect",
        "useMemo",
        "useCallback",
        "useRef",
        "useContext",
        "useImperativeHandle",
        "useDebugValue",
        "useId",
        "useDeferredValue",
        "useTransition",
        "useInsertionEffect",
        "useSyncExternalStore",
        "useFormStatus",
        "useFormState",
        "useActionState",
        "useOptimistic",
    ];
    BUILTIN_HOOKS.contains(&hook_name)
}
917
918/// Collects all bindings that are used in JSX elements or JSX-like calls.
919///
920/// For <https://github.com/facebook/react/blob/ba6a9e94edf0db3ad96432804f9931ce9dc89fec/packages/react-refresh/src/ReactFreshBabelPlugin.js#L161-L199>
921struct UsedInJSXBindingsCollector<'a, 'b> {
922    ctx: &'b TraverseCtx<'a>,
923    bindings: FxHashSet<SymbolId>,
924}
925
926impl<'a, 'b> UsedInJSXBindingsCollector<'a, 'b> {
927    fn collect(program: &Program<'a>, ctx: &'b TraverseCtx<'a>) -> FxHashSet<SymbolId> {
928        let mut visitor = Self { ctx, bindings: FxHashSet::default() };
929        visitor.visit_program(program);
930        visitor.bindings
931    }
932
933    fn is_jsx_like_call(name: &str) -> bool {
934        matches!(name, "createElement" | "jsx" | "jsxDEV" | "jsxs")
935    }
936}
937
938impl<'a> Visit<'a> for UsedInJSXBindingsCollector<'a, '_> {
939    fn visit_call_expression(&mut self, it: &CallExpression<'a>) {
940        walk_call_expression(self, it);
941
942        let is_jsx_call = match &it.callee {
943            Expression::Identifier(ident) => Self::is_jsx_like_call(&ident.name),
944            Expression::StaticMemberExpression(member) => {
945                Self::is_jsx_like_call(&member.property.name)
946            }
947            _ => false,
948        };
949
950        if is_jsx_call
951            && let Some(Argument::Identifier(ident)) = it.arguments.first()
952            && let Some(symbol_id) =
953                self.ctx.scoping().get_reference(ident.reference_id()).symbol_id()
954        {
955            self.bindings.insert(symbol_id);
956        }
957    }
958
959    fn visit_jsx_opening_element(&mut self, it: &JSXOpeningElement<'_>) {
960        if let Some(ident) = it.name.get_identifier()
961            && let Some(symbol_id) =
962                self.ctx.scoping().get_reference(ident.reference_id()).symbol_id()
963        {
964            self.bindings.insert(symbol_id);
965        }
966    }
967
968    #[inline]
969    fn visit_ts_type_annotation(&mut self, _it: &TSTypeAnnotation<'a>) {
970        // Skip type annotations because it definitely doesn't have any JSX bindings
971    }
972
973    #[inline]
974    fn visit_declaration(&mut self, it: &Declaration<'a>) {
975        if matches!(
976            it,
977            Declaration::TSTypeAliasDeclaration(_) | Declaration::TSInterfaceDeclaration(_)
978        ) {
979            // Skip type-only declarations because it definitely doesn't have any JSX bindings
980            return;
981        }
982        walk_declaration(self, it);
983    }
984
985    #[inline]
986    fn visit_import_declaration(&mut self, _it: &ImportDeclaration<'a>) {
987        // Skip import declarations because it definitely doesn't have any JSX bindings
988    }
989
990    #[inline]
991    fn visit_export_all_declaration(&mut self, _it: &ExportAllDeclaration<'a>) {
992        // Skip export all declarations because it definitely doesn't have any JSX bindings
993    }
994}