Skip to main content

php_ast/
visitor.rs

1use std::ops::ControlFlow;
2
3use crate::ast::*;
4
5/// Visitor trait for immutable AST traversal.
6///
7/// All methods return `ControlFlow<()>`:
8/// - `ControlFlow::Continue(())` — keep walking.
9/// - `ControlFlow::Break(())` — stop the entire traversal immediately.
10///
11/// Default implementations recursively walk child nodes, so implementors
12/// only need to override the node types they care about.
13///
14/// To **skip** a subtree, override the method and return `Continue(())`
15/// without calling the corresponding `walk_*` function.
16///
17/// # Example
18///
19/// ```
20/// use php_ast::visitor::{Visitor, walk_expr};
21/// use php_ast::ast::*;
22/// use std::ops::ControlFlow;
23///
24/// struct VarCounter { count: usize }
25///
26/// impl<'arena, 'src> Visitor<'arena, 'src> for VarCounter {
27///     fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
28///         if matches!(&expr.kind, ExprKind::Variable(_)) {
29///             self.count += 1;
30///         }
31///         walk_expr(self, expr)
32///     }
33/// }
34/// ```
35pub trait Visitor<'arena, 'src> {
36    fn visit_program(&mut self, program: &Program<'arena, 'src>) -> ControlFlow<()> {
37        walk_program(self, program)
38    }
39
40    fn visit_stmt(&mut self, stmt: &Stmt<'arena, 'src>) -> ControlFlow<()> {
41        walk_stmt(self, stmt)
42    }
43
44    fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
45        walk_expr(self, expr)
46    }
47
48    fn visit_param(&mut self, param: &Param<'arena, 'src>) -> ControlFlow<()> {
49        walk_param(self, param)
50    }
51
52    fn visit_arg(&mut self, arg: &Arg<'arena, 'src>) -> ControlFlow<()> {
53        walk_arg(self, arg)
54    }
55
56    fn visit_class_member(&mut self, member: &ClassMember<'arena, 'src>) -> ControlFlow<()> {
57        walk_class_member(self, member)
58    }
59
60    fn visit_enum_member(&mut self, member: &EnumMember<'arena, 'src>) -> ControlFlow<()> {
61        walk_enum_member(self, member)
62    }
63
64    fn visit_property_hook(&mut self, hook: &PropertyHook<'arena, 'src>) -> ControlFlow<()> {
65        walk_property_hook(self, hook)
66    }
67
68    fn visit_type_hint(&mut self, type_hint: &TypeHint<'arena, 'src>) -> ControlFlow<()> {
69        walk_type_hint(self, type_hint)
70    }
71
72    fn visit_attribute(&mut self, attribute: &Attribute<'arena, 'src>) -> ControlFlow<()> {
73        walk_attribute(self, attribute)
74    }
75
76    fn visit_catch_clause(&mut self, catch: &CatchClause<'arena, 'src>) -> ControlFlow<()> {
77        walk_catch_clause(self, catch)
78    }
79
80    fn visit_match_arm(&mut self, arm: &MatchArm<'arena, 'src>) -> ControlFlow<()> {
81        walk_match_arm(self, arm)
82    }
83
84    fn visit_closure_use_var(&mut self, _var: &ClosureUseVar<'src>) -> ControlFlow<()> {
85        ControlFlow::Continue(())
86    }
87}
88
89// =============================================================================
90// Walk functions
91// =============================================================================
92
93pub fn walk_program<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
94    visitor: &mut V,
95    program: &Program<'arena, 'src>,
96) -> ControlFlow<()> {
97    for stmt in program.stmts.iter() {
98        visitor.visit_stmt(stmt)?;
99    }
100    ControlFlow::Continue(())
101}
102
103pub fn walk_stmt<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
104    visitor: &mut V,
105    stmt: &Stmt<'arena, 'src>,
106) -> ControlFlow<()> {
107    match &stmt.kind {
108        StmtKind::Expression(expr) => {
109            visitor.visit_expr(expr)?;
110        }
111        StmtKind::Echo(exprs) => {
112            for expr in exprs.iter() {
113                visitor.visit_expr(expr)?;
114            }
115        }
116        StmtKind::Return(expr) => {
117            if let Some(expr) = expr {
118                visitor.visit_expr(expr)?;
119            }
120        }
121        StmtKind::Block(stmts) => {
122            for stmt in stmts.iter() {
123                visitor.visit_stmt(stmt)?;
124            }
125        }
126        StmtKind::If(if_stmt) => {
127            visitor.visit_expr(&if_stmt.condition)?;
128            visitor.visit_stmt(if_stmt.then_branch)?;
129            for elseif in if_stmt.elseif_branches.iter() {
130                visitor.visit_expr(&elseif.condition)?;
131                visitor.visit_stmt(&elseif.body)?;
132            }
133            if let Some(else_branch) = &if_stmt.else_branch {
134                visitor.visit_stmt(else_branch)?;
135            }
136        }
137        StmtKind::While(while_stmt) => {
138            visitor.visit_expr(&while_stmt.condition)?;
139            visitor.visit_stmt(while_stmt.body)?;
140        }
141        StmtKind::For(for_stmt) => {
142            for expr in for_stmt.init.iter() {
143                visitor.visit_expr(expr)?;
144            }
145            for expr in for_stmt.condition.iter() {
146                visitor.visit_expr(expr)?;
147            }
148            for expr in for_stmt.update.iter() {
149                visitor.visit_expr(expr)?;
150            }
151            visitor.visit_stmt(for_stmt.body)?;
152        }
153        StmtKind::Foreach(foreach_stmt) => {
154            visitor.visit_expr(&foreach_stmt.expr)?;
155            if let Some(key) = &foreach_stmt.key {
156                visitor.visit_expr(key)?;
157            }
158            visitor.visit_expr(&foreach_stmt.value)?;
159            visitor.visit_stmt(foreach_stmt.body)?;
160        }
161        StmtKind::DoWhile(do_while) => {
162            visitor.visit_stmt(do_while.body)?;
163            visitor.visit_expr(&do_while.condition)?;
164        }
165        StmtKind::Function(func) => {
166            walk_function_like(visitor, &func.attributes, &func.params, &func.return_type)?;
167            for stmt in func.body.iter() {
168                visitor.visit_stmt(stmt)?;
169            }
170        }
171        StmtKind::Break(expr) | StmtKind::Continue(expr) => {
172            if let Some(expr) = expr {
173                visitor.visit_expr(expr)?;
174            }
175        }
176        StmtKind::Switch(switch_stmt) => {
177            visitor.visit_expr(&switch_stmt.expr)?;
178            for case in switch_stmt.cases.iter() {
179                if let Some(value) = &case.value {
180                    visitor.visit_expr(value)?;
181                }
182                for stmt in case.body.iter() {
183                    visitor.visit_stmt(stmt)?;
184                }
185            }
186        }
187        StmtKind::Throw(expr) => {
188            visitor.visit_expr(expr)?;
189        }
190        StmtKind::TryCatch(tc) => {
191            for stmt in tc.body.iter() {
192                visitor.visit_stmt(stmt)?;
193            }
194            for catch in tc.catches.iter() {
195                visitor.visit_catch_clause(catch)?;
196            }
197            if let Some(finally) = &tc.finally {
198                for stmt in finally.iter() {
199                    visitor.visit_stmt(stmt)?;
200                }
201            }
202        }
203        StmtKind::Declare(decl) => {
204            for (_, expr) in decl.directives.iter() {
205                visitor.visit_expr(expr)?;
206            }
207            if let Some(body) = decl.body {
208                visitor.visit_stmt(body)?;
209            }
210        }
211        StmtKind::Unset(exprs) | StmtKind::Global(exprs) => {
212            for expr in exprs.iter() {
213                visitor.visit_expr(expr)?;
214            }
215        }
216        StmtKind::Class(class) => {
217            walk_attributes(visitor, &class.attributes)?;
218            for member in class.members.iter() {
219                visitor.visit_class_member(member)?;
220            }
221        }
222        StmtKind::Interface(iface) => {
223            walk_attributes(visitor, &iface.attributes)?;
224            for member in iface.members.iter() {
225                visitor.visit_class_member(member)?;
226            }
227        }
228        StmtKind::Trait(trait_decl) => {
229            walk_attributes(visitor, &trait_decl.attributes)?;
230            for member in trait_decl.members.iter() {
231                visitor.visit_class_member(member)?;
232            }
233        }
234        StmtKind::Enum(enum_decl) => {
235            walk_attributes(visitor, &enum_decl.attributes)?;
236            for member in enum_decl.members.iter() {
237                visitor.visit_enum_member(member)?;
238            }
239        }
240        StmtKind::Namespace(ns) => {
241            if let NamespaceBody::Braced(stmts) = &ns.body {
242                for stmt in stmts.iter() {
243                    visitor.visit_stmt(stmt)?;
244                }
245            }
246        }
247        StmtKind::Const(items) => {
248            for item in items.iter() {
249                walk_attributes(visitor, &item.attributes)?;
250                visitor.visit_expr(&item.value)?;
251            }
252        }
253        StmtKind::StaticVar(vars) => {
254            for var in vars.iter() {
255                if let Some(default) = &var.default {
256                    visitor.visit_expr(default)?;
257                }
258            }
259        }
260        StmtKind::Use(_)
261        | StmtKind::Goto(_)
262        | StmtKind::Label(_)
263        | StmtKind::Nop
264        | StmtKind::InlineHtml(_)
265        | StmtKind::HaltCompiler(_)
266        | StmtKind::Error => {}
267    }
268    ControlFlow::Continue(())
269}
270
271pub fn walk_expr<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
272    visitor: &mut V,
273    expr: &Expr<'arena, 'src>,
274) -> ControlFlow<()> {
275    match &expr.kind {
276        ExprKind::Assign(assign) => {
277            visitor.visit_expr(assign.target)?;
278            visitor.visit_expr(assign.value)?;
279        }
280        ExprKind::Binary(binary) => {
281            visitor.visit_expr(binary.left)?;
282            visitor.visit_expr(binary.right)?;
283        }
284        ExprKind::UnaryPrefix(unary) => {
285            visitor.visit_expr(unary.operand)?;
286        }
287        ExprKind::UnaryPostfix(unary) => {
288            visitor.visit_expr(unary.operand)?;
289        }
290        ExprKind::Ternary(ternary) => {
291            visitor.visit_expr(ternary.condition)?;
292            if let Some(then_expr) = &ternary.then_expr {
293                visitor.visit_expr(then_expr)?;
294            }
295            visitor.visit_expr(ternary.else_expr)?;
296        }
297        ExprKind::NullCoalesce(nc) => {
298            visitor.visit_expr(nc.left)?;
299            visitor.visit_expr(nc.right)?;
300        }
301        ExprKind::FunctionCall(call) => {
302            visitor.visit_expr(call.name)?;
303            for arg in call.args.iter() {
304                visitor.visit_arg(arg)?;
305            }
306        }
307        ExprKind::Array(elements) => {
308            for elem in elements.iter() {
309                if let Some(key) = &elem.key {
310                    visitor.visit_expr(key)?;
311                }
312                visitor.visit_expr(&elem.value)?;
313            }
314        }
315        ExprKind::ArrayAccess(access) => {
316            visitor.visit_expr(access.array)?;
317            if let Some(index) = &access.index {
318                visitor.visit_expr(index)?;
319            }
320        }
321        ExprKind::Print(expr) => {
322            visitor.visit_expr(expr)?;
323        }
324        ExprKind::Parenthesized(expr) => {
325            visitor.visit_expr(expr)?;
326        }
327        ExprKind::Cast(_, expr) => {
328            visitor.visit_expr(expr)?;
329        }
330        ExprKind::ErrorSuppress(expr) => {
331            visitor.visit_expr(expr)?;
332        }
333        ExprKind::Isset(exprs) => {
334            for expr in exprs.iter() {
335                visitor.visit_expr(expr)?;
336            }
337        }
338        ExprKind::Empty(expr) => {
339            visitor.visit_expr(expr)?;
340        }
341        ExprKind::Include(_, expr) => {
342            visitor.visit_expr(expr)?;
343        }
344        ExprKind::Eval(expr) => {
345            visitor.visit_expr(expr)?;
346        }
347        ExprKind::Exit(expr) => {
348            if let Some(expr) = expr {
349                visitor.visit_expr(expr)?;
350            }
351        }
352        ExprKind::Clone(expr) => {
353            visitor.visit_expr(expr)?;
354        }
355        ExprKind::CloneWith(object, overrides) => {
356            visitor.visit_expr(object)?;
357            visitor.visit_expr(overrides)?;
358        }
359        ExprKind::New(new_expr) => {
360            visitor.visit_expr(new_expr.class)?;
361            for arg in new_expr.args.iter() {
362                visitor.visit_arg(arg)?;
363            }
364        }
365        ExprKind::PropertyAccess(access) | ExprKind::NullsafePropertyAccess(access) => {
366            visitor.visit_expr(access.object)?;
367            visitor.visit_expr(access.property)?;
368        }
369        ExprKind::MethodCall(call) | ExprKind::NullsafeMethodCall(call) => {
370            visitor.visit_expr(call.object)?;
371            visitor.visit_expr(call.method)?;
372            for arg in call.args.iter() {
373                visitor.visit_arg(arg)?;
374            }
375        }
376        ExprKind::StaticPropertyAccess(access) | ExprKind::ClassConstAccess(access) => {
377            visitor.visit_expr(access.class)?;
378        }
379        ExprKind::ClassConstAccessDynamic { class, member }
380        | ExprKind::StaticPropertyAccessDynamic { class, member } => {
381            visitor.visit_expr(class)?;
382            visitor.visit_expr(member)?;
383        }
384        ExprKind::StaticMethodCall(call) => {
385            visitor.visit_expr(call.class)?;
386            for arg in call.args.iter() {
387                visitor.visit_arg(arg)?;
388            }
389        }
390        ExprKind::Closure(closure) => {
391            walk_function_like(
392                visitor,
393                &closure.attributes,
394                &closure.params,
395                &closure.return_type,
396            )?;
397            for use_var in closure.use_vars.iter() {
398                visitor.visit_closure_use_var(use_var)?;
399            }
400            for stmt in closure.body.iter() {
401                visitor.visit_stmt(stmt)?;
402            }
403        }
404        ExprKind::ArrowFunction(arrow) => {
405            walk_function_like(
406                visitor,
407                &arrow.attributes,
408                &arrow.params,
409                &arrow.return_type,
410            )?;
411            visitor.visit_expr(arrow.body)?;
412        }
413        ExprKind::Match(match_expr) => {
414            visitor.visit_expr(match_expr.subject)?;
415            for arm in match_expr.arms.iter() {
416                visitor.visit_match_arm(arm)?;
417            }
418        }
419        ExprKind::ThrowExpr(expr) => {
420            visitor.visit_expr(expr)?;
421        }
422        ExprKind::Yield(yield_expr) => {
423            if let Some(key) = &yield_expr.key {
424                visitor.visit_expr(key)?;
425            }
426            if let Some(value) = &yield_expr.value {
427                visitor.visit_expr(value)?;
428            }
429        }
430        ExprKind::AnonymousClass(class) => {
431            walk_attributes(visitor, &class.attributes)?;
432            for member in class.members.iter() {
433                visitor.visit_class_member(member)?;
434            }
435        }
436        ExprKind::InterpolatedString(parts)
437        | ExprKind::Heredoc { parts, .. }
438        | ExprKind::ShellExec(parts) => {
439            for part in parts.iter() {
440                if let StringPart::Expr(e) = part {
441                    visitor.visit_expr(e)?;
442                }
443            }
444        }
445        ExprKind::VariableVariable(inner) => {
446            visitor.visit_expr(inner)?;
447        }
448        ExprKind::CallableCreate(cc) => match &cc.kind {
449            CallableCreateKind::Function(name) => visitor.visit_expr(name)?,
450            CallableCreateKind::Method { object, method }
451            | CallableCreateKind::NullsafeMethod { object, method } => {
452                visitor.visit_expr(object)?;
453                visitor.visit_expr(method)?;
454            }
455            CallableCreateKind::StaticMethod { class, .. } => {
456                visitor.visit_expr(class)?;
457            }
458        },
459        ExprKind::Int(_)
460        | ExprKind::Float(_)
461        | ExprKind::String(_)
462        | ExprKind::Bool(_)
463        | ExprKind::Null
464        | ExprKind::Omit
465        | ExprKind::Variable(_)
466        | ExprKind::Identifier(_)
467        | ExprKind::MagicConst(_)
468        | ExprKind::Nowdoc { .. }
469        | ExprKind::Error => {}
470    }
471    ControlFlow::Continue(())
472}
473
474pub fn walk_param<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
475    visitor: &mut V,
476    param: &Param<'arena, 'src>,
477) -> ControlFlow<()> {
478    walk_attributes(visitor, &param.attributes)?;
479    if let Some(type_hint) = &param.type_hint {
480        visitor.visit_type_hint(type_hint)?;
481    }
482    if let Some(default) = &param.default {
483        visitor.visit_expr(default)?;
484    }
485    for hook in param.hooks.iter() {
486        visitor.visit_property_hook(hook)?;
487    }
488    ControlFlow::Continue(())
489}
490
491pub fn walk_arg<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
492    visitor: &mut V,
493    arg: &Arg<'arena, 'src>,
494) -> ControlFlow<()> {
495    visitor.visit_expr(&arg.value)
496}
497
498pub fn walk_class_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
499    visitor: &mut V,
500    member: &ClassMember<'arena, 'src>,
501) -> ControlFlow<()> {
502    match &member.kind {
503        ClassMemberKind::Property(prop) => {
504            walk_property_decl(visitor, prop)?;
505        }
506        ClassMemberKind::Method(method) => {
507            walk_method_decl(visitor, method)?;
508        }
509        ClassMemberKind::ClassConst(cc) => {
510            walk_class_const_decl(visitor, cc)?;
511        }
512        ClassMemberKind::TraitUse(_) => {}
513    }
514    ControlFlow::Continue(())
515}
516
517pub fn walk_property_hook<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
518    visitor: &mut V,
519    hook: &PropertyHook<'arena, 'src>,
520) -> ControlFlow<()> {
521    walk_attributes(visitor, &hook.attributes)?;
522    for param in hook.params.iter() {
523        visitor.visit_param(param)?;
524    }
525    match &hook.body {
526        PropertyHookBody::Block(stmts) => {
527            for stmt in stmts.iter() {
528                visitor.visit_stmt(stmt)?;
529            }
530        }
531        PropertyHookBody::Expression(expr) => {
532            visitor.visit_expr(expr)?;
533        }
534        PropertyHookBody::Abstract => {}
535    }
536    ControlFlow::Continue(())
537}
538
539pub fn walk_enum_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
540    visitor: &mut V,
541    member: &EnumMember<'arena, 'src>,
542) -> ControlFlow<()> {
543    match &member.kind {
544        EnumMemberKind::Case(case) => {
545            walk_attributes(visitor, &case.attributes)?;
546            if let Some(value) = &case.value {
547                visitor.visit_expr(value)?;
548            }
549        }
550        EnumMemberKind::Method(method) => {
551            walk_method_decl(visitor, method)?;
552        }
553        EnumMemberKind::ClassConst(cc) => {
554            walk_class_const_decl(visitor, cc)?;
555        }
556        EnumMemberKind::TraitUse(_) => {}
557    }
558    ControlFlow::Continue(())
559}
560
561pub fn walk_type_hint<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
562    visitor: &mut V,
563    type_hint: &TypeHint<'arena, 'src>,
564) -> ControlFlow<()> {
565    match &type_hint.kind {
566        TypeHintKind::Nullable(inner) => {
567            visitor.visit_type_hint(inner)?;
568        }
569        TypeHintKind::Union(types) | TypeHintKind::Intersection(types) => {
570            for ty in types.iter() {
571                visitor.visit_type_hint(ty)?;
572            }
573        }
574        TypeHintKind::Named(_) | TypeHintKind::Keyword(_, _) => {}
575    }
576    ControlFlow::Continue(())
577}
578
579pub fn walk_attribute<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
580    visitor: &mut V,
581    attribute: &Attribute<'arena, 'src>,
582) -> ControlFlow<()> {
583    for arg in attribute.args.iter() {
584        visitor.visit_arg(arg)?;
585    }
586    ControlFlow::Continue(())
587}
588
589pub fn walk_catch_clause<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
590    visitor: &mut V,
591    catch: &CatchClause<'arena, 'src>,
592) -> ControlFlow<()> {
593    for stmt in catch.body.iter() {
594        visitor.visit_stmt(stmt)?;
595    }
596    ControlFlow::Continue(())
597}
598
599pub fn walk_match_arm<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
600    visitor: &mut V,
601    arm: &MatchArm<'arena, 'src>,
602) -> ControlFlow<()> {
603    if let Some(conditions) = &arm.conditions {
604        for cond in conditions.iter() {
605            visitor.visit_expr(cond)?;
606        }
607    }
608    visitor.visit_expr(&arm.body)
609}
610
611// =============================================================================
612// Internal helpers — shared walking logic to avoid duplication
613// =============================================================================
614
615/// Walks the common parts of any function-like construct:
616/// attributes → params → optional return type.
617fn walk_function_like<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
618    visitor: &mut V,
619    attributes: &[Attribute<'arena, 'src>],
620    params: &[Param<'arena, 'src>],
621    return_type: &Option<TypeHint<'arena, 'src>>,
622) -> ControlFlow<()> {
623    walk_attributes(visitor, attributes)?;
624    for param in params.iter() {
625        visitor.visit_param(param)?;
626    }
627    if let Some(ret) = return_type {
628        visitor.visit_type_hint(ret)?;
629    }
630    ControlFlow::Continue(())
631}
632
633/// Walks a method declaration (shared by ClassMember and EnumMember).
634fn walk_method_decl<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
635    visitor: &mut V,
636    method: &MethodDecl<'arena, 'src>,
637) -> ControlFlow<()> {
638    walk_function_like(
639        visitor,
640        &method.attributes,
641        &method.params,
642        &method.return_type,
643    )?;
644    if let Some(body) = &method.body {
645        for stmt in body.iter() {
646            visitor.visit_stmt(stmt)?;
647        }
648    }
649    ControlFlow::Continue(())
650}
651
652/// Walks a class constant declaration (shared by ClassMember and EnumMember).
653fn walk_class_const_decl<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
654    visitor: &mut V,
655    cc: &ClassConstDecl<'arena, 'src>,
656) -> ControlFlow<()> {
657    walk_attributes(visitor, &cc.attributes)?;
658    if let Some(type_hint) = &cc.type_hint {
659        visitor.visit_type_hint(type_hint)?;
660    }
661    visitor.visit_expr(&cc.value)
662}
663
664/// Walks a property declaration.
665fn walk_property_decl<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
666    visitor: &mut V,
667    prop: &PropertyDecl<'arena, 'src>,
668) -> ControlFlow<()> {
669    walk_attributes(visitor, &prop.attributes)?;
670    if let Some(type_hint) = &prop.type_hint {
671        visitor.visit_type_hint(type_hint)?;
672    }
673    if let Some(default) = &prop.default {
674        visitor.visit_expr(default)?;
675    }
676    for hook in prop.hooks.iter() {
677        visitor.visit_property_hook(hook)?;
678    }
679    ControlFlow::Continue(())
680}
681
682fn walk_attributes<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
683    visitor: &mut V,
684    attributes: &[Attribute<'arena, 'src>],
685) -> ControlFlow<()> {
686    for attr in attributes.iter() {
687        visitor.visit_attribute(attr)?;
688    }
689    ControlFlow::Continue(())
690}
691
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Span;
    use std::borrow::Cow;

    // =========================================================================
    // Unit tests with hand-built ASTs
    // =========================================================================

    struct VarCounter {
        count: usize,
    }

    impl<'arena, 'src> Visitor<'arena, 'src> for VarCounter {
        fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
            if matches!(&expr.kind, ExprKind::Variable(_)) {
                self.count += 1;
            }
            walk_expr(self, expr)
        }
    }

    #[test]
    fn counts_variables() {
        let arena = bumpalo::Bump::new();
        let var_x = arena.alloc(Expr {
            kind: ExprKind::Variable(Cow::Borrowed("x")),
            span: Span::DUMMY,
        });
        let var_y = arena.alloc(Expr {
            kind: ExprKind::Variable(Cow::Borrowed("y")),
            span: Span::DUMMY,
        });
        let var_z = arena.alloc(Expr {
            kind: ExprKind::Variable(Cow::Borrowed("z")),
            span: Span::DUMMY,
        });
        let binary = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: var_y,
                op: BinaryOp::Add,
                right: var_z,
            }),
            span: Span::DUMMY,
        });
        let assign = arena.alloc(Expr {
            kind: ExprKind::Assign(AssignExpr {
                target: var_x,
                op: AssignOp::Assign,
                value: binary,
                by_ref: false,
            }),
            span: Span::DUMMY,
        });
        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(Stmt {
            kind: StmtKind::Expression(assign),
            span: Span::DUMMY,
        });
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        let mut v = VarCounter { count: 0 };
        // A visitor that never breaks must let the traversal run to completion.
        assert!(v.visit_program(&program).is_continue());
        assert_eq!(v.count, 3);
    }

    #[test]
    fn early_termination() {
        let arena = bumpalo::Bump::new();
        let var_a = arena.alloc(Expr {
            kind: ExprKind::Variable(Cow::Borrowed("a")),
            span: Span::DUMMY,
        });
        let var_b = arena.alloc(Expr {
            kind: ExprKind::Variable(Cow::Borrowed("b")),
            span: Span::DUMMY,
        });
        let binary = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: var_a,
                op: BinaryOp::Add,
                right: var_b,
            }),
            span: Span::DUMMY,
        });
        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(Stmt {
            kind: StmtKind::Expression(binary),
            span: Span::DUMMY,
        });
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        struct FindFirst {
            found: Option<String>,
        }
        impl<'arena, 'src> Visitor<'arena, 'src> for FindFirst {
            fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
                if let ExprKind::Variable(name) = &expr.kind {
                    self.found = Some(name.to_string());
                    return ControlFlow::Break(());
                }
                walk_expr(self, expr)
            }
        }

        let mut finder = FindFirst { found: None };
        let result = finder.visit_program(&program);
        assert!(result.is_break());
        assert_eq!(finder.found.as_deref(), Some("a"));
    }

    #[test]
    fn break_inside_function_body_stops_later_statements() {
        let arena = bumpalo::Bump::new();
        // function foo() { 3 + 4; } 1 + 2;
        let three = arena.alloc(Expr {
            kind: ExprKind::Int(3),
            span: Span::DUMMY,
        });
        let four = arena.alloc(Expr {
            kind: ExprKind::Int(4),
            span: Span::DUMMY,
        });
        let inner = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: three,
                op: BinaryOp::Add,
                right: four,
            }),
            span: Span::DUMMY,
        });
        let mut func_body = ArenaVec::new_in(&arena);
        func_body.push(Stmt {
            kind: StmtKind::Expression(inner),
            span: Span::DUMMY,
        });
        let func = arena.alloc(FunctionDecl {
            name: "foo",
            params: ArenaVec::new_in(&arena),
            body: func_body,
            return_type: None,
            by_ref: false,
            attributes: ArenaVec::new_in(&arena),
            doc_comment: None,
        });
        let one = arena.alloc(Expr {
            kind: ExprKind::Int(1),
            span: Span::DUMMY,
        });
        let two = arena.alloc(Expr {
            kind: ExprKind::Int(2),
            span: Span::DUMMY,
        });
        let top = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: one,
                op: BinaryOp::Add,
                right: two,
            }),
            span: Span::DUMMY,
        });
        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(Stmt {
            kind: StmtKind::Function(func),
            span: Span::DUMMY,
        });
        stmts.push(Stmt {
            kind: StmtKind::Expression(top),
            span: Span::DUMMY,
        });
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        struct BreakOnThree {
            visited: usize,
        }
        impl<'arena, 'src> Visitor<'arena, 'src> for BreakOnThree {
            fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
                self.visited += 1;
                if matches!(&expr.kind, ExprKind::Int(3)) {
                    return ControlFlow::Break(());
                }
                walk_expr(self, expr)
            }
        }

        let mut v = BreakOnThree { visited: 0 };
        assert!(v.visit_program(&program).is_break());
        // Only `3 + 4` and `3` are seen; the later `1 + 2` statement is never reached.
        assert_eq!(v.visited, 2);
    }

    #[test]
    fn skip_subtree() {
        let arena = bumpalo::Bump::new();
        // 1 + 2; function foo() { 3 + 4; }
        let one = arena.alloc(Expr {
            kind: ExprKind::Int(1),
            span: Span::DUMMY,
        });
        let two = arena.alloc(Expr {
            kind: ExprKind::Int(2),
            span: Span::DUMMY,
        });
        let top = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: one,
                op: BinaryOp::Add,
                right: two,
            }),
            span: Span::DUMMY,
        });
        let three = arena.alloc(Expr {
            kind: ExprKind::Int(3),
            span: Span::DUMMY,
        });
        let four = arena.alloc(Expr {
            kind: ExprKind::Int(4),
            span: Span::DUMMY,
        });
        let inner = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: three,
                op: BinaryOp::Add,
                right: four,
            }),
            span: Span::DUMMY,
        });
        let mut func_body = ArenaVec::new_in(&arena);
        func_body.push(Stmt {
            kind: StmtKind::Expression(inner),
            span: Span::DUMMY,
        });
        let func = arena.alloc(FunctionDecl {
            name: "foo",
            params: ArenaVec::new_in(&arena),
            body: func_body,
            return_type: None,
            by_ref: false,
            attributes: ArenaVec::new_in(&arena),
            doc_comment: None,
        });
        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(Stmt {
            kind: StmtKind::Expression(top),
            span: Span::DUMMY,
        });
        stmts.push(Stmt {
            kind: StmtKind::Function(func),
            span: Span::DUMMY,
        });
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        struct SkipFunctions {
            expr_count: usize,
        }
        impl<'arena, 'src> Visitor<'arena, 'src> for SkipFunctions {
            fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) -> ControlFlow<()> {
                self.expr_count += 1;
                walk_expr(self, expr)
            }
            fn visit_stmt(&mut self, stmt: &Stmt<'arena, 'src>) -> ControlFlow<()> {
                if matches!(&stmt.kind, StmtKind::Function(_)) {
                    return ControlFlow::Continue(());
                }
                walk_stmt(self, stmt)
            }
        }

        let mut v = SkipFunctions { expr_count: 0 };
        // Skipping a subtree is not a break: the traversal must still complete.
        assert!(v.visit_program(&program).is_continue());
        // Only top-level: binary(1, 2) = 3 exprs
        assert_eq!(v.expr_count, 3);
    }
}