Skip to main content

php_ast/
visitor.rs

1use crate::ast::*;
2
3/// Visitor trait for AST traversal. All methods have default implementations
4/// that recursively walk child nodes, so implementors only need to override
5/// the node types they care about.
6pub trait Visitor<'arena, 'src> {
7    fn visit_program(&mut self, program: &Program<'arena, 'src>) {
8        walk_program(self, program);
9    }
10
11    fn visit_stmt(&mut self, stmt: &Stmt<'arena, 'src>) {
12        walk_stmt(self, stmt);
13    }
14
15    fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) {
16        walk_expr(self, expr);
17    }
18
19    fn visit_param(&mut self, param: &Param<'arena, 'src>) {
20        walk_param(self, param);
21    }
22
23    fn visit_arg(&mut self, arg: &Arg<'arena, 'src>) {
24        walk_arg(self, arg);
25    }
26
27    fn visit_class_member(&mut self, member: &ClassMember<'arena, 'src>) {
28        walk_class_member(self, member);
29    }
30
31    fn visit_enum_member(&mut self, member: &EnumMember<'arena, 'src>) {
32        walk_enum_member(self, member);
33    }
34
35    fn visit_property_hook(&mut self, hook: &PropertyHook<'arena, 'src>) {
36        walk_property_hook(self, hook);
37    }
38}
39
40pub fn walk_program<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
41    visitor: &mut V,
42    program: &Program<'arena, 'src>,
43) {
44    for stmt in program.stmts.iter() {
45        visitor.visit_stmt(stmt);
46    }
47}
48
49pub fn walk_stmt<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
50    visitor: &mut V,
51    stmt: &Stmt<'arena, 'src>,
52) {
53    match &stmt.kind {
54        StmtKind::Expression(expr) => {
55            visitor.visit_expr(expr);
56        }
57        StmtKind::Echo(exprs) => {
58            for expr in exprs.iter() {
59                visitor.visit_expr(expr);
60            }
61        }
62        StmtKind::Return(expr) => {
63            if let Some(expr) = expr {
64                visitor.visit_expr(expr);
65            }
66        }
67        StmtKind::Block(stmts) => {
68            for stmt in stmts.iter() {
69                visitor.visit_stmt(stmt);
70            }
71        }
72        StmtKind::If(if_stmt) => {
73            visitor.visit_expr(&if_stmt.condition);
74            visitor.visit_stmt(if_stmt.then_branch);
75            for elseif in if_stmt.elseif_branches.iter() {
76                visitor.visit_expr(&elseif.condition);
77                visitor.visit_stmt(&elseif.body);
78            }
79            if let Some(else_branch) = &if_stmt.else_branch {
80                visitor.visit_stmt(else_branch);
81            }
82        }
83        StmtKind::While(while_stmt) => {
84            visitor.visit_expr(&while_stmt.condition);
85            visitor.visit_stmt(while_stmt.body);
86        }
87        StmtKind::For(for_stmt) => {
88            for expr in for_stmt.init.iter() {
89                visitor.visit_expr(expr);
90            }
91            for expr in for_stmt.condition.iter() {
92                visitor.visit_expr(expr);
93            }
94            for expr in for_stmt.update.iter() {
95                visitor.visit_expr(expr);
96            }
97            visitor.visit_stmt(for_stmt.body);
98        }
99        StmtKind::Foreach(foreach_stmt) => {
100            visitor.visit_expr(&foreach_stmt.expr);
101            if let Some(key) = &foreach_stmt.key {
102                visitor.visit_expr(key);
103            }
104            visitor.visit_expr(&foreach_stmt.value);
105            visitor.visit_stmt(foreach_stmt.body);
106        }
107        StmtKind::DoWhile(do_while) => {
108            visitor.visit_stmt(do_while.body);
109            visitor.visit_expr(&do_while.condition);
110        }
111        StmtKind::Function(func) => {
112            for param in func.params.iter() {
113                visitor.visit_param(param);
114            }
115            for stmt in func.body.iter() {
116                visitor.visit_stmt(stmt);
117            }
118        }
119        StmtKind::Break(expr) | StmtKind::Continue(expr) => {
120            if let Some(expr) = expr {
121                visitor.visit_expr(expr);
122            }
123        }
124        StmtKind::Switch(switch_stmt) => {
125            visitor.visit_expr(&switch_stmt.expr);
126            for case in switch_stmt.cases.iter() {
127                if let Some(value) = &case.value {
128                    visitor.visit_expr(value);
129                }
130                for stmt in case.body.iter() {
131                    visitor.visit_stmt(stmt);
132                }
133            }
134        }
135        StmtKind::Throw(expr) => {
136            visitor.visit_expr(expr);
137        }
138        StmtKind::TryCatch(tc) => {
139            for stmt in tc.body.iter() {
140                visitor.visit_stmt(stmt);
141            }
142            for catch in tc.catches.iter() {
143                for stmt in catch.body.iter() {
144                    visitor.visit_stmt(stmt);
145                }
146            }
147            if let Some(finally) = &tc.finally {
148                for stmt in finally.iter() {
149                    visitor.visit_stmt(stmt);
150                }
151            }
152        }
153        StmtKind::Declare(decl) => {
154            if let Some(body) = decl.body {
155                visitor.visit_stmt(body);
156            }
157        }
158        StmtKind::Unset(exprs) | StmtKind::Global(exprs) => {
159            for expr in exprs.iter() {
160                visitor.visit_expr(expr);
161            }
162        }
163        StmtKind::Class(class) => {
164            for member in class.members.iter() {
165                visitor.visit_class_member(member);
166            }
167        }
168        StmtKind::Interface(iface) => {
169            for member in iface.members.iter() {
170                visitor.visit_class_member(member);
171            }
172        }
173        StmtKind::Trait(trait_decl) => {
174            for member in trait_decl.members.iter() {
175                visitor.visit_class_member(member);
176            }
177        }
178        StmtKind::Enum(enum_decl) => {
179            for member in enum_decl.members.iter() {
180                visitor.visit_enum_member(member);
181            }
182        }
183        StmtKind::Namespace(ns) => {
184            if let NamespaceBody::Braced(stmts) = &ns.body {
185                for stmt in stmts.iter() {
186                    visitor.visit_stmt(stmt);
187                }
188            }
189        }
190        StmtKind::Const(items) => {
191            for item in items.iter() {
192                visitor.visit_expr(&item.value);
193            }
194        }
195        StmtKind::StaticVar(vars) => {
196            for var in vars.iter() {
197                if let Some(default) = &var.default {
198                    visitor.visit_expr(default);
199                }
200            }
201        }
202        StmtKind::Use(_)
203        | StmtKind::Goto(_)
204        | StmtKind::Label(_)
205        | StmtKind::Nop
206        | StmtKind::InlineHtml(_)
207        | StmtKind::HaltCompiler(_)
208        | StmtKind::Error => {}
209    }
210}
211
212pub fn walk_expr<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
213    visitor: &mut V,
214    expr: &Expr<'arena, 'src>,
215) {
216    match &expr.kind {
217        ExprKind::Assign(assign) => {
218            visitor.visit_expr(assign.target);
219            visitor.visit_expr(assign.value);
220        }
221        ExprKind::Binary(binary) => {
222            visitor.visit_expr(binary.left);
223            visitor.visit_expr(binary.right);
224        }
225        ExprKind::UnaryPrefix(unary) => {
226            visitor.visit_expr(unary.operand);
227        }
228        ExprKind::UnaryPostfix(unary) => {
229            visitor.visit_expr(unary.operand);
230        }
231        ExprKind::Ternary(ternary) => {
232            visitor.visit_expr(ternary.condition);
233            if let Some(then_expr) = &ternary.then_expr {
234                visitor.visit_expr(then_expr);
235            }
236            visitor.visit_expr(ternary.else_expr);
237        }
238        ExprKind::NullCoalesce(nc) => {
239            visitor.visit_expr(nc.left);
240            visitor.visit_expr(nc.right);
241        }
242        ExprKind::FunctionCall(call) => {
243            visitor.visit_expr(call.name);
244            for arg in call.args.iter() {
245                visitor.visit_arg(arg);
246            }
247        }
248        ExprKind::Array(elements) => {
249            for elem in elements.iter() {
250                if let Some(key) = &elem.key {
251                    visitor.visit_expr(key);
252                }
253                visitor.visit_expr(&elem.value);
254            }
255        }
256        ExprKind::ArrayAccess(access) => {
257            visitor.visit_expr(access.array);
258            if let Some(index) = &access.index {
259                visitor.visit_expr(index);
260            }
261        }
262        ExprKind::Print(expr) => {
263            visitor.visit_expr(expr);
264        }
265        ExprKind::Parenthesized(expr) => {
266            visitor.visit_expr(expr);
267        }
268        ExprKind::Cast(_, expr) => {
269            visitor.visit_expr(expr);
270        }
271        ExprKind::ErrorSuppress(expr) => {
272            visitor.visit_expr(expr);
273        }
274        ExprKind::Isset(exprs) => {
275            for expr in exprs.iter() {
276                visitor.visit_expr(expr);
277            }
278        }
279        ExprKind::Empty(expr) => {
280            visitor.visit_expr(expr);
281        }
282        ExprKind::Include(_, expr) => {
283            visitor.visit_expr(expr);
284        }
285        ExprKind::Eval(expr) => {
286            visitor.visit_expr(expr);
287        }
288        ExprKind::Exit(expr) => {
289            if let Some(expr) = expr {
290                visitor.visit_expr(expr);
291            }
292        }
293        ExprKind::Clone(expr) => {
294            visitor.visit_expr(expr);
295        }
296        ExprKind::New(new_expr) => {
297            visitor.visit_expr(new_expr.class);
298            for arg in new_expr.args.iter() {
299                visitor.visit_arg(arg);
300            }
301        }
302        ExprKind::PropertyAccess(access) | ExprKind::NullsafePropertyAccess(access) => {
303            visitor.visit_expr(access.object);
304            visitor.visit_expr(access.property);
305        }
306        ExprKind::MethodCall(call) | ExprKind::NullsafeMethodCall(call) => {
307            visitor.visit_expr(call.object);
308            visitor.visit_expr(call.method);
309            for arg in call.args.iter() {
310                visitor.visit_arg(arg);
311            }
312        }
313        ExprKind::StaticPropertyAccess(access) | ExprKind::ClassConstAccess(access) => {
314            visitor.visit_expr(access.class);
315        }
316        ExprKind::ClassConstAccessDynamic { class, member }
317        | ExprKind::StaticPropertyAccessDynamic { class, member } => {
318            visitor.visit_expr(class);
319            visitor.visit_expr(member);
320        }
321        ExprKind::StaticMethodCall(call) => {
322            visitor.visit_expr(call.class);
323            for arg in call.args.iter() {
324                visitor.visit_arg(arg);
325            }
326        }
327        ExprKind::Closure(closure) => {
328            for param in closure.params.iter() {
329                visitor.visit_param(param);
330            }
331            for stmt in closure.body.iter() {
332                visitor.visit_stmt(stmt);
333            }
334        }
335        ExprKind::ArrowFunction(arrow) => {
336            for param in arrow.params.iter() {
337                visitor.visit_param(param);
338            }
339            visitor.visit_expr(arrow.body);
340        }
341        ExprKind::Match(match_expr) => {
342            visitor.visit_expr(match_expr.subject);
343            for arm in match_expr.arms.iter() {
344                if let Some(conditions) = &arm.conditions {
345                    for cond in conditions.iter() {
346                        visitor.visit_expr(cond);
347                    }
348                }
349                visitor.visit_expr(&arm.body);
350            }
351        }
352        ExprKind::ThrowExpr(expr) => {
353            visitor.visit_expr(expr);
354        }
355        ExprKind::Yield(yield_expr) => {
356            if let Some(key) = &yield_expr.key {
357                visitor.visit_expr(key);
358            }
359            if let Some(value) = &yield_expr.value {
360                visitor.visit_expr(value);
361            }
362        }
363        ExprKind::AnonymousClass(class) => {
364            for member in class.members.iter() {
365                visitor.visit_class_member(member);
366            }
367        }
368        ExprKind::InterpolatedString(parts)
369        | ExprKind::Heredoc { parts, .. }
370        | ExprKind::ShellExec(parts) => {
371            for part in parts.iter() {
372                if let StringPart::Expr(e) = part {
373                    visitor.visit_expr(e);
374                }
375            }
376        }
377        ExprKind::VariableVariable(inner) => {
378            visitor.visit_expr(inner);
379        }
380        ExprKind::CallableCreate(cc) => match &cc.kind {
381            CallableCreateKind::Function(name) => visitor.visit_expr(name),
382            CallableCreateKind::Method { object, method }
383            | CallableCreateKind::NullsafeMethod { object, method } => {
384                visitor.visit_expr(object);
385                visitor.visit_expr(method);
386            }
387            CallableCreateKind::StaticMethod { class, .. } => {
388                visitor.visit_expr(class);
389            }
390        },
391        ExprKind::Int(_)
392        | ExprKind::Float(_)
393        | ExprKind::String(_)
394        | ExprKind::Bool(_)
395        | ExprKind::Null
396        | ExprKind::Variable(_)
397        | ExprKind::Identifier(_)
398        | ExprKind::MagicConst(_)
399        | ExprKind::Nowdoc { .. }
400        | ExprKind::Error => {}
401    }
402}
403
404pub fn walk_param<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
405    visitor: &mut V,
406    param: &Param<'arena, 'src>,
407) {
408    if let Some(default) = &param.default {
409        visitor.visit_expr(default);
410    }
411    for hook in param.hooks.iter() {
412        visitor.visit_property_hook(hook);
413    }
414}
415
416pub fn walk_arg<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
417    visitor: &mut V,
418    arg: &Arg<'arena, 'src>,
419) {
420    visitor.visit_expr(&arg.value);
421}
422
423pub fn walk_class_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
424    visitor: &mut V,
425    member: &ClassMember<'arena, 'src>,
426) {
427    match &member.kind {
428        ClassMemberKind::Property(prop) => {
429            if let Some(default) = &prop.default {
430                visitor.visit_expr(default);
431            }
432            for hook in prop.hooks.iter() {
433                visitor.visit_property_hook(hook);
434            }
435        }
436        ClassMemberKind::Method(method) => {
437            for param in method.params.iter() {
438                visitor.visit_param(param);
439            }
440            if let Some(body) = &method.body {
441                for stmt in body.iter() {
442                    visitor.visit_stmt(stmt);
443                }
444            }
445        }
446        ClassMemberKind::ClassConst(cc) => {
447            visitor.visit_expr(&cc.value);
448        }
449        ClassMemberKind::TraitUse(_) => {}
450    }
451}
452
453pub fn walk_property_hook<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
454    visitor: &mut V,
455    hook: &PropertyHook<'arena, 'src>,
456) {
457    for param in hook.params.iter() {
458        visitor.visit_param(param);
459    }
460    match &hook.body {
461        PropertyHookBody::Block(stmts) => {
462            for stmt in stmts.iter() {
463                visitor.visit_stmt(stmt);
464            }
465        }
466        PropertyHookBody::Expression(expr) => {
467            visitor.visit_expr(expr);
468        }
469        PropertyHookBody::Abstract => {}
470    }
471}
472
473pub fn walk_enum_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
474    visitor: &mut V,
475    member: &EnumMember<'arena, 'src>,
476) {
477    match &member.kind {
478        EnumMemberKind::Case(case) => {
479            if let Some(value) = &case.value {
480                visitor.visit_expr(value);
481            }
482        }
483        EnumMemberKind::Method(method) => {
484            for param in method.params.iter() {
485                visitor.visit_param(param);
486            }
487            if let Some(body) = &method.body {
488                for stmt in body.iter() {
489                    visitor.visit_stmt(stmt);
490                }
491            }
492        }
493        EnumMemberKind::ClassConst(cc) => {
494            visitor.visit_expr(&cc.value);
495        }
496        EnumMemberKind::TraitUse(_) => {}
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Span;
    use std::borrow::Cow;

    /// Test visitor that tallies every `ExprKind::Variable` node it reaches.
    struct VariableTally {
        count: usize,
    }

    impl<'arena, 'src> Visitor<'arena, 'src> for VariableTally {
        fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) {
            if let ExprKind::Variable(_) = &expr.kind {
                self.count += 1;
            }
            // Keep descending so variables nested below this node are counted too.
            walk_expr(self, expr);
        }
    }

    /// `$x = $y + $z;` holds three variables, and all of them must be reached.
    #[test]
    fn test_visitor_counts_variables() {
        let arena = bumpalo::Bump::new();

        // Allocates a bare `$name` variable node in the arena.
        let variable = |name: &'static str| {
            arena.alloc(Expr {
                kind: ExprKind::Variable(Cow::Borrowed(name)),
                span: Span::DUMMY,
            })
        };
        let (x, y, z) = (variable("x"), variable("y"), variable("z"));

        // $y + $z
        let sum = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: y,
                op: BinaryOp::Add,
                right: z,
            }),
            span: Span::DUMMY,
        });

        // $x = $y + $z
        let assignment = arena.alloc(Expr {
            kind: ExprKind::Assign(AssignExpr {
                target: x,
                op: AssignOp::Assign,
                value: sum,
            }),
            span: Span::DUMMY,
        });

        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(Stmt {
            kind: StmtKind::Expression(assignment),
            span: Span::DUMMY,
        });
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        let mut tally = VariableTally { count: 0 };
        tally.visit_program(&program);
        assert_eq!(tally.count, 3);
    }
}
569}