Skip to main content

php_ast/
visitor.rs

1use crate::ast::*;
2
3/// Visitor trait for AST traversal. All methods have default implementations
4/// that recursively walk child nodes, so implementors only need to override
5/// the node types they care about.
6pub trait Visitor<'arena, 'src> {
7    fn visit_program(&mut self, program: &Program<'arena, 'src>) {
8        walk_program(self, program);
9    }
10
11    fn visit_stmt(&mut self, stmt: &Stmt<'arena, 'src>) {
12        walk_stmt(self, stmt);
13    }
14
15    fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) {
16        walk_expr(self, expr);
17    }
18
19    fn visit_param(&mut self, param: &Param<'arena, 'src>) {
20        walk_param(self, param);
21    }
22
23    fn visit_arg(&mut self, arg: &Arg<'arena, 'src>) {
24        walk_arg(self, arg);
25    }
26
27    fn visit_class_member(&mut self, member: &ClassMember<'arena, 'src>) {
28        walk_class_member(self, member);
29    }
30
31    fn visit_enum_member(&mut self, member: &EnumMember<'arena, 'src>) {
32        walk_enum_member(self, member);
33    }
34
35    fn visit_property_hook(&mut self, hook: &PropertyHook<'arena, 'src>) {
36        walk_property_hook(self, hook);
37    }
38}
39
40pub fn walk_program<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
41    visitor: &mut V,
42    program: &Program<'arena, 'src>,
43) {
44    for stmt in program.stmts.iter() {
45        visitor.visit_stmt(stmt);
46    }
47}
48
/// Default traversal for a single statement.
///
/// Dispatches on `stmt.kind` and passes each *direct* child to the matching
/// `visit_*` hook on `visitor`. Children include expressions, nested
/// statements, function parameters, class members and enum members.
///
/// Deeper descent happens only because the default `visit_*` hooks call
/// back into the `walk_*` functions. A visitor that overrides a hook
/// without calling the walker prunes that subtree.
///
/// The order of the calls below is the traversal order visitors observe.
/// For example, `For` visits init, condition, update and then body, while
/// `DoWhile` visits its body before its condition.
///
/// Variants in the final arm (`Use`, `Goto`, `Label`, `Nop`, `InlineHtml`,
/// `HaltCompiler`, `Error`) have nothing walked.
pub fn walk_stmt<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
    visitor: &mut V,
    stmt: &Stmt<'arena, 'src>,
) {
    match &stmt.kind {
        StmtKind::Expression(expr) => {
            visitor.visit_expr(expr);
        }
        // Each echoed expression, in order.
        StmtKind::Echo(exprs) => {
            for expr in exprs.iter() {
                visitor.visit_expr(expr);
            }
        }
        // The returned expression is optional.
        StmtKind::Return(expr) => {
            if let Some(expr) = expr {
                visitor.visit_expr(expr);
            }
        }
        StmtKind::Block(stmts) => {
            for stmt in stmts.iter() {
                visitor.visit_stmt(stmt);
            }
        }
        // Visits the condition and then-branch, then each `elseif`
        // (condition before body), then the optional `else` branch.
        StmtKind::If(if_stmt) => {
            visitor.visit_expr(&if_stmt.condition);
            visitor.visit_stmt(if_stmt.then_branch);
            for elseif in if_stmt.elseif_branches.iter() {
                visitor.visit_expr(&elseif.condition);
                visitor.visit_stmt(&elseif.body);
            }
            if let Some(else_branch) = &if_stmt.else_branch {
                visitor.visit_stmt(else_branch);
            }
        }
        StmtKind::While(while_stmt) => {
            visitor.visit_expr(&while_stmt.condition);
            visitor.visit_stmt(while_stmt.body);
        }
        // Each header section holds a list of expressions; all of them
        // are visited before the body.
        StmtKind::For(for_stmt) => {
            for expr in for_stmt.init.iter() {
                visitor.visit_expr(expr);
            }
            for expr in for_stmt.condition.iter() {
                visitor.visit_expr(expr);
            }
            for expr in for_stmt.update.iter() {
                visitor.visit_expr(expr);
            }
            visitor.visit_stmt(for_stmt.body);
        }
        // Visits the iterated expression, the optional key target, the
        // value target, and then the body.
        StmtKind::Foreach(foreach_stmt) => {
            visitor.visit_expr(&foreach_stmt.expr);
            if let Some(key) = &foreach_stmt.key {
                visitor.visit_expr(key);
            }
            visitor.visit_expr(&foreach_stmt.value);
            visitor.visit_stmt(foreach_stmt.body);
        }
        // Body first, matching `do { ... } while (cond);`.
        StmtKind::DoWhile(do_while) => {
            visitor.visit_stmt(do_while.body);
            visitor.visit_expr(&do_while.condition);
        }
        // Parameters first (their defaults and hooks are reached through
        // `visit_param`), then the body statements.
        StmtKind::Function(func) => {
            for param in func.params.iter() {
                visitor.visit_param(param);
            }
            for stmt in func.body.iter() {
                visitor.visit_stmt(stmt);
            }
        }
        // The optional operand is presumably the nesting level
        // (e.g. `break 2;`).
        StmtKind::Break(expr) | StmtKind::Continue(expr) => {
            if let Some(expr) = expr {
                visitor.visit_expr(expr);
            }
        }
        // Visits the subject first. Then, for each case, its optional value
        // (presumably `None` for `default:`) followed by its body statements.
        StmtKind::Switch(switch_stmt) => {
            visitor.visit_expr(&switch_stmt.expr);
            for case in switch_stmt.cases.iter() {
                if let Some(value) = &case.value {
                    visitor.visit_expr(value);
                }
                for stmt in case.body.iter() {
                    visitor.visit_stmt(stmt);
                }
            }
        }
        StmtKind::Throw(expr) => {
            visitor.visit_expr(expr);
        }
        // Visits the try body, then each catch body, then the optional
        // finally block. Only the catch bodies are walked; the other fields
        // of a catch clause are not visited.
        StmtKind::TryCatch(tc) => {
            for stmt in tc.body.iter() {
                visitor.visit_stmt(stmt);
            }
            for catch in tc.catches.iter() {
                for stmt in catch.body.iter() {
                    visitor.visit_stmt(stmt);
                }
            }
            if let Some(finally) = &tc.finally {
                for stmt in finally.iter() {
                    visitor.visit_stmt(stmt);
                }
            }
        }
        // Only the optional body is walked.
        // NOTE(review): any other fields of the declare statement (its
        // directives) are not visited. Confirm they never hold expressions
        // that visitors need to see.
        StmtKind::Declare(decl) => {
            if let Some(body) = decl.body {
                visitor.visit_stmt(body);
            }
        }
        // Each operand expression of `unset(...)` / `global ...`.
        StmtKind::Unset(exprs) | StmtKind::Global(exprs) => {
            for expr in exprs.iter() {
                visitor.visit_expr(expr);
            }
        }
        // For class, interface and trait declarations, only the members are
        // walked. The declaration's other fields are not visited here.
        StmtKind::Class(class) => {
            for member in class.members.iter() {
                visitor.visit_class_member(member);
            }
        }
        StmtKind::Interface(iface) => {
            for member in iface.members.iter() {
                visitor.visit_class_member(member);
            }
        }
        StmtKind::Trait(trait_decl) => {
            for member in trait_decl.members.iter() {
                visitor.visit_class_member(member);
            }
        }
        // Enum members go through the dedicated `visit_enum_member` hook.
        StmtKind::Enum(enum_decl) => {
            for member in enum_decl.members.iter() {
                visitor.visit_enum_member(member);
            }
        }
        // Only braced namespaces own statements here.
        // NOTE(review): other `NamespaceBody` variants visit nothing.
        // Presumably their statements appear as program-level siblings;
        // confirm.
        StmtKind::Namespace(ns) => {
            if let NamespaceBody::Braced(stmts) = &ns.body {
                for stmt in stmts.iter() {
                    visitor.visit_stmt(stmt);
                }
            }
        }
        // The value expression of each constant item.
        StmtKind::Const(items) => {
            for item in items.iter() {
                visitor.visit_expr(&item.value);
            }
        }
        // Only the optional initializer of each static variable is walked.
        StmtKind::StaticVar(vars) => {
            for var in vars.iter() {
                if let Some(default) = &var.default {
                    visitor.visit_expr(default);
                }
            }
        }
        // No child nodes are walked for these kinds.
        StmtKind::Use(_)
        | StmtKind::Goto(_)
        | StmtKind::Label(_)
        | StmtKind::Nop
        | StmtKind::InlineHtml(_)
        | StmtKind::HaltCompiler(_)
        | StmtKind::Error => {}
    }
}
211
/// Default traversal for a single expression.
///
/// Dispatches on `expr.kind` and passes each *direct* child to the matching
/// `visit_*` hook on `visitor`. Children include sub-expressions, call
/// arguments, closure / arrow-function parameters, closure body statements
/// and anonymous-class members. Deeper descent relies on the default hooks
/// calling back into the `walk_*` functions.
///
/// The order of the calls below is the traversal order visitors observe.
/// For example, an assignment visits its target before its value, and a
/// call visits its callee before its arguments.
///
/// Variants in the final arm (scalar literals, `Null`, `Variable`,
/// `Identifier`, `MagicConst`, `Nowdoc`, `Error`) have nothing walked.
pub fn walk_expr<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
    visitor: &mut V,
    expr: &Expr<'arena, 'src>,
) {
    match &expr.kind {
        ExprKind::Assign(assign) => {
            visitor.visit_expr(assign.target);
            visitor.visit_expr(assign.value);
        }
        ExprKind::Binary(binary) => {
            visitor.visit_expr(binary.left);
            visitor.visit_expr(binary.right);
        }
        ExprKind::UnaryPrefix(unary) => {
            visitor.visit_expr(unary.operand);
        }
        ExprKind::UnaryPostfix(unary) => {
            visitor.visit_expr(unary.operand);
        }
        // `then_expr` is optional (presumably `None` for the short
        // `a ?: b` form).
        ExprKind::Ternary(ternary) => {
            visitor.visit_expr(ternary.condition);
            if let Some(then_expr) = &ternary.then_expr {
                visitor.visit_expr(then_expr);
            }
            visitor.visit_expr(ternary.else_expr);
        }
        ExprKind::NullCoalesce(nc) => {
            visitor.visit_expr(nc.left);
            visitor.visit_expr(nc.right);
        }
        // The callee is itself an expression and is visited before the
        // arguments.
        ExprKind::FunctionCall(call) => {
            visitor.visit_expr(call.name);
            for arg in call.args.iter() {
                visitor.visit_arg(arg);
            }
        }
        // For each element: its optional key, then its value.
        ExprKind::Array(elements) => {
            for elem in elements.iter() {
                if let Some(key) = &elem.key {
                    visitor.visit_expr(key);
                }
                visitor.visit_expr(&elem.value);
            }
        }
        // `index` is optional (presumably `None` for the append form `$a[]`).
        ExprKind::ArrayAccess(access) => {
            visitor.visit_expr(access.array);
            if let Some(index) = &access.index {
                visitor.visit_expr(index);
            }
        }
        ExprKind::Print(expr) => {
            visitor.visit_expr(expr);
        }
        ExprKind::Parenthesized(expr) => {
            visitor.visit_expr(expr);
        }
        // Only the operand is walked; the cast kind is not visited.
        ExprKind::Cast(_, expr) => {
            visitor.visit_expr(expr);
        }
        ExprKind::ErrorSuppress(expr) => {
            visitor.visit_expr(expr);
        }
        ExprKind::Isset(exprs) => {
            for expr in exprs.iter() {
                visitor.visit_expr(expr);
            }
        }
        ExprKind::Empty(expr) => {
            visitor.visit_expr(expr);
        }
        // Only the path expression is walked; the include kind is not
        // visited.
        ExprKind::Include(_, expr) => {
            visitor.visit_expr(expr);
        }
        ExprKind::Eval(expr) => {
            visitor.visit_expr(expr);
        }
        // The exit operand is optional.
        ExprKind::Exit(expr) => {
            if let Some(expr) = expr {
                visitor.visit_expr(expr);
            }
        }
        ExprKind::Clone(expr) => {
            visitor.visit_expr(expr);
        }
        // The cloned object first, then the overrides expression.
        ExprKind::CloneWith(object, overrides) => {
            visitor.visit_expr(object);
            visitor.visit_expr(overrides);
        }
        // The class expression, then the constructor arguments.
        ExprKind::New(new_expr) => {
            visitor.visit_expr(new_expr.class);
            for arg in new_expr.args.iter() {
                visitor.visit_arg(arg);
            }
        }
        // The property name is an expression too, visited after the object.
        ExprKind::PropertyAccess(access) | ExprKind::NullsafePropertyAccess(access) => {
            visitor.visit_expr(access.object);
            visitor.visit_expr(access.property);
        }
        // Visits the object, then the method-name expression, then the
        // arguments.
        ExprKind::MethodCall(call) | ExprKind::NullsafeMethodCall(call) => {
            visitor.visit_expr(call.object);
            visitor.visit_expr(call.method);
            for arg in call.args.iter() {
                visitor.visit_arg(arg);
            }
        }
        // Only the class side is walked; the member is not visited here.
        ExprKind::StaticPropertyAccess(access) | ExprKind::ClassConstAccess(access) => {
            visitor.visit_expr(access.class);
        }
        // The dynamic forms carry the member as an expression, so both
        // sides are walked.
        ExprKind::ClassConstAccessDynamic { class, member }
        | ExprKind::StaticPropertyAccessDynamic { class, member } => {
            visitor.visit_expr(class);
            visitor.visit_expr(member);
        }
        // Visits the class, then the arguments.
        // NOTE(review): unlike `MethodCall`, the method name is not visited.
        // Confirm this is intended if the AST can hold a dynamic static
        // method name.
        ExprKind::StaticMethodCall(call) => {
            visitor.visit_expr(call.class);
            for arg in call.args.iter() {
                visitor.visit_arg(arg);
            }
        }
        // Parameters, then body statements. Other closure fields are not
        // visited.
        ExprKind::Closure(closure) => {
            for param in closure.params.iter() {
                visitor.visit_param(param);
            }
            for stmt in closure.body.iter() {
                visitor.visit_stmt(stmt);
            }
        }
        // Parameters, then the single body expression.
        ExprKind::ArrowFunction(arrow) => {
            for param in arrow.params.iter() {
                visitor.visit_param(param);
            }
            visitor.visit_expr(arrow.body);
        }
        // Visits the subject first. Then, for each arm, its optional
        // conditions (presumably `None` for the `default` arm) followed by
        // its result expression.
        ExprKind::Match(match_expr) => {
            visitor.visit_expr(match_expr.subject);
            for arm in match_expr.arms.iter() {
                if let Some(conditions) = &arm.conditions {
                    for cond in conditions.iter() {
                        visitor.visit_expr(cond);
                    }
                }
                visitor.visit_expr(&arm.body);
            }
        }
        ExprKind::ThrowExpr(expr) => {
            visitor.visit_expr(expr);
        }
        // Key and value are both optional; the key is visited first.
        ExprKind::Yield(yield_expr) => {
            if let Some(key) = &yield_expr.key {
                visitor.visit_expr(key);
            }
            if let Some(value) = &yield_expr.value {
                visitor.visit_expr(value);
            }
        }
        // Only members are walked; the anonymous class's other fields are
        // not visited.
        ExprKind::AnonymousClass(class) => {
            for member in class.members.iter() {
                visitor.visit_class_member(member);
            }
        }
        // Only `StringPart::Expr` parts are visited; all other parts are
        // skipped.
        ExprKind::InterpolatedString(parts)
        | ExprKind::Heredoc { parts, .. }
        | ExprKind::ShellExec(parts) => {
            for part in parts.iter() {
                if let StringPart::Expr(e) = part {
                    visitor.visit_expr(e);
                }
            }
        }
        ExprKind::VariableVariable(inner) => {
            visitor.visit_expr(inner);
        }
        // Presumably first-class callable syntax (`f(...)`). Static methods
        // walk only their class, mirroring `StaticMethodCall`.
        ExprKind::CallableCreate(cc) => match &cc.kind {
            CallableCreateKind::Function(name) => visitor.visit_expr(name),
            CallableCreateKind::Method { object, method }
            | CallableCreateKind::NullsafeMethod { object, method } => {
                visitor.visit_expr(object);
                visitor.visit_expr(method);
            }
            CallableCreateKind::StaticMethod { class, .. } => {
                visitor.visit_expr(class);
            }
        },
        // No child nodes are walked for these kinds.
        ExprKind::Int(_)
        | ExprKind::Float(_)
        | ExprKind::String(_)
        | ExprKind::Bool(_)
        | ExprKind::Null
        | ExprKind::Variable(_)
        | ExprKind::Identifier(_)
        | ExprKind::MagicConst(_)
        | ExprKind::Nowdoc { .. }
        | ExprKind::Error => {}
    }
}
407
408pub fn walk_param<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
409    visitor: &mut V,
410    param: &Param<'arena, 'src>,
411) {
412    if let Some(default) = &param.default {
413        visitor.visit_expr(default);
414    }
415    for hook in param.hooks.iter() {
416        visitor.visit_property_hook(hook);
417    }
418}
419
420pub fn walk_arg<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
421    visitor: &mut V,
422    arg: &Arg<'arena, 'src>,
423) {
424    visitor.visit_expr(&arg.value);
425}
426
427pub fn walk_class_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
428    visitor: &mut V,
429    member: &ClassMember<'arena, 'src>,
430) {
431    match &member.kind {
432        ClassMemberKind::Property(prop) => {
433            if let Some(default) = &prop.default {
434                visitor.visit_expr(default);
435            }
436            for hook in prop.hooks.iter() {
437                visitor.visit_property_hook(hook);
438            }
439        }
440        ClassMemberKind::Method(method) => {
441            for param in method.params.iter() {
442                visitor.visit_param(param);
443            }
444            if let Some(body) = &method.body {
445                for stmt in body.iter() {
446                    visitor.visit_stmt(stmt);
447                }
448            }
449        }
450        ClassMemberKind::ClassConst(cc) => {
451            visitor.visit_expr(&cc.value);
452        }
453        ClassMemberKind::TraitUse(_) => {}
454    }
455}
456
457pub fn walk_property_hook<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
458    visitor: &mut V,
459    hook: &PropertyHook<'arena, 'src>,
460) {
461    for param in hook.params.iter() {
462        visitor.visit_param(param);
463    }
464    match &hook.body {
465        PropertyHookBody::Block(stmts) => {
466            for stmt in stmts.iter() {
467                visitor.visit_stmt(stmt);
468            }
469        }
470        PropertyHookBody::Expression(expr) => {
471            visitor.visit_expr(expr);
472        }
473        PropertyHookBody::Abstract => {}
474    }
475}
476
477pub fn walk_enum_member<'arena, 'src, V: Visitor<'arena, 'src> + ?Sized>(
478    visitor: &mut V,
479    member: &EnumMember<'arena, 'src>,
480) {
481    match &member.kind {
482        EnumMemberKind::Case(case) => {
483            if let Some(value) = &case.value {
484                visitor.visit_expr(value);
485            }
486        }
487        EnumMemberKind::Method(method) => {
488            for param in method.params.iter() {
489                visitor.visit_param(param);
490            }
491            if let Some(body) = &method.body {
492                for stmt in body.iter() {
493                    visitor.visit_stmt(stmt);
494                }
495            }
496        }
497        EnumMemberKind::ClassConst(cc) => {
498            visitor.visit_expr(&cc.value);
499        }
500        EnumMemberKind::TraitUse(_) => {}
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Span;

    /// Test visitor that tallies every `ExprKind::Variable` node it reaches.
    struct VariableCounter {
        seen: usize,
    }

    impl<'arena, 'src> Visitor<'arena, 'src> for VariableCounter {
        fn visit_expr(&mut self, expr: &Expr<'arena, 'src>) {
            if let ExprKind::Variable(_) = &expr.kind {
                self.seen += 1;
            }
            // Keep descending so variables nested below this node are counted too.
            walk_expr(self, expr);
        }
    }

    #[test]
    fn test_visitor_counts_variables() {
        let arena = bumpalo::Bump::new();

        // Builds the AST for `$x = $y + $z;`, which contains three variables.
        let [var_x, var_y, var_z] = ["x", "y", "z"].map(|name| {
            arena.alloc(Expr {
                kind: ExprKind::Variable(std::borrow::Cow::Borrowed(name)),
                span: Span::DUMMY,
            })
        });

        let sum = arena.alloc(Expr {
            kind: ExprKind::Binary(BinaryExpr {
                left: var_y,
                op: BinaryOp::Add,
                right: var_z,
            }),
            span: Span::DUMMY,
        });

        let statement = Stmt {
            kind: StmtKind::Expression(arena.alloc(Expr {
                kind: ExprKind::Assign(AssignExpr {
                    target: var_x,
                    op: AssignOp::Assign,
                    value: sum,
                }),
                span: Span::DUMMY,
            })),
            span: Span::DUMMY,
        };

        let mut stmts = ArenaVec::new_in(&arena);
        stmts.push(statement);
        let program = Program {
            stmts,
            span: Span::DUMMY,
        };

        let mut counter = VariableCounter { seen: 0 };
        counter.visit_program(&program);
        assert_eq!(counter.seen, 3, "expected $x, $y and $z to be counted");
    }
}