Skip to main content

harn_parser/
visit.rs

//! Generic AST visitor used by the linter, formatter, and any other
//! crate that needs to walk every `SNode` in a parsed program.
//!
//! Centralizing this here keeps a single source of truth for which
//! children each `Node` variant has — adding a new variant requires
//! one edit (in [`walk_children`]) and every consumer benefits.
//!
//! # Usage
//!
//! ```ignore
//! use harn_parser::visit::walk_program;
//! let mut count = 0;
//! walk_program(&program, &mut |node| {
//!     if matches!(&node.node, harn_parser::Node::FunctionCall { .. }) {
//!         count += 1;
//!     }
//! });
//! ```
//!
//! The walker invokes the closure on each node *before* recursing
//! into its children (pre-order). The closure has no way to stop the
//! descent: [`walk_children`] only skips the visit of the node it is
//! given and still reports every descendant to the closure. Callers
//! that need to prune a subtree must write their own recursion over
//! `Node`.
23
24use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase};
25
26/// Walk every node in `program` in pre-order, invoking `visitor` on
27/// each.
28pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29    for node in program {
30        walk_node(node, visitor);
31    }
32}
33
34/// Visit `node`, then recurse into its children.
35pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36    visitor(node);
37    walk_children(node, visitor);
38}
39
40/// Recurse into `node`'s children without re-visiting `node` itself.
41/// Useful when a caller wants to handle the parent specially and then
42/// continue the default traversal.
43pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44    match &node.node {
45        Node::AttributedDecl { attributes, inner } => {
46            for attr in attributes {
47                for arg in &attr.args {
48                    walk_node(&arg.value, visitor);
49                }
50            }
51            walk_node(inner, visitor);
52        }
53        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
54            walk_nodes(body, visitor);
55        }
56        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
57            walk_binding_pattern(pattern, visitor);
58            walk_node(value, visitor);
59        }
60        Node::EnumDecl { .. }
61        | Node::StructDecl { .. }
62        | Node::InterfaceDecl { .. }
63        | Node::ImportDecl { .. }
64        | Node::SelectiveImport { .. }
65        | Node::TypeDecl { .. }
66        | Node::BreakStmt
67        | Node::ContinueStmt => {}
68        Node::ImplBlock { methods, .. } => walk_nodes(methods, visitor),
69        Node::IfElse {
70            condition,
71            then_body,
72            else_body,
73        } => {
74            walk_node(condition, visitor);
75            walk_nodes(then_body, visitor);
76            if let Some(body) = else_body {
77                walk_nodes(body, visitor);
78            }
79        }
80        Node::ForIn {
81            pattern,
82            iterable,
83            body,
84        } => {
85            walk_binding_pattern(pattern, visitor);
86            walk_node(iterable, visitor);
87            walk_nodes(body, visitor);
88        }
89        Node::MatchExpr { value, arms } => {
90            walk_node(value, visitor);
91            for arm in arms {
92                walk_match_arm(arm, visitor);
93            }
94        }
95        Node::WhileLoop { condition, body } => {
96            walk_node(condition, visitor);
97            walk_nodes(body, visitor);
98        }
99        Node::Retry { count, body } => {
100            walk_node(count, visitor);
101            walk_nodes(body, visitor);
102        }
103        Node::CostRoute { options, body } => {
104            walk_option_values(options, visitor);
105            walk_nodes(body, visitor);
106        }
107        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
108            if let Some(value) = value {
109                walk_node(value, visitor);
110            }
111        }
112        Node::TryCatch {
113            has_catch: _,
114            body,
115            catch_body,
116            finally_body,
117            ..
118        } => {
119            walk_nodes(body, visitor);
120            walk_nodes(catch_body, visitor);
121            if let Some(body) = finally_body {
122                walk_nodes(body, visitor);
123            }
124        }
125        Node::TryExpr { body }
126        | Node::SpawnExpr { body }
127        | Node::DeferStmt { body }
128        | Node::MutexBlock { body }
129        | Node::Block(body)
130        | Node::Closure { body, .. } => walk_nodes(body, visitor),
131        Node::FnDecl { body, .. } | Node::ToolDecl { body, .. } => {
132            walk_nodes(body, visitor);
133        }
134        Node::SkillDecl { fields, .. } => walk_field_values(fields, visitor),
135        Node::EvalPackDecl {
136            fields,
137            body,
138            summarize,
139            ..
140        } => {
141            walk_field_values(fields, visitor);
142            walk_nodes(body, visitor);
143            if let Some(body) = summarize {
144                walk_nodes(body, visitor);
145            }
146        }
147        Node::RangeExpr { start, end, .. } => {
148            walk_node(start, visitor);
149            walk_node(end, visitor);
150        }
151        Node::GuardStmt {
152            condition,
153            else_body,
154        } => {
155            walk_node(condition, visitor);
156            walk_nodes(else_body, visitor);
157        }
158        Node::RequireStmt { condition, message } => {
159            walk_node(condition, visitor);
160            if let Some(message) = message {
161                walk_node(message, visitor);
162            }
163        }
164        Node::DeadlineBlock { duration, body } => {
165            walk_node(duration, visitor);
166            walk_nodes(body, visitor);
167        }
168        Node::EmitExpr { value }
169        | Node::ThrowStmt { value }
170        | Node::Spread(value)
171        | Node::TryOperator { operand: value }
172        | Node::TryStar { operand: value }
173        | Node::UnaryOp { operand: value, .. } => walk_node(value, visitor),
174        Node::HitlExpr { args, .. } => {
175            for arg in args {
176                walk_node(&arg.value, visitor);
177            }
178        }
179        Node::Parallel {
180            expr,
181            body,
182            options,
183            ..
184        } => {
185            walk_node(expr, visitor);
186            walk_option_values(options, visitor);
187            walk_nodes(body, visitor);
188        }
189        Node::SelectExpr {
190            cases,
191            timeout,
192            default_body,
193        } => {
194            for case in cases {
195                walk_select_case(case, visitor);
196            }
197            if let Some((duration, body)) = timeout {
198                walk_node(duration, visitor);
199                walk_nodes(body, visitor);
200            }
201            if let Some(body) = default_body {
202                walk_nodes(body, visitor);
203            }
204        }
205        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
206            walk_nodes(args, visitor);
207        }
208        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
209            walk_node(object, visitor);
210            walk_nodes(args, visitor);
211        }
212        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
213            walk_node(object, visitor);
214        }
215        Node::SubscriptAccess { object, index }
216        | Node::OptionalSubscriptAccess { object, index } => {
217            walk_node(object, visitor);
218            walk_node(index, visitor);
219        }
220        Node::SliceAccess { object, start, end } => {
221            walk_node(object, visitor);
222            if let Some(start) = start {
223                walk_node(start, visitor);
224            }
225            if let Some(end) = end {
226                walk_node(end, visitor);
227            }
228        }
229        Node::BinaryOp { left, right, .. } => {
230            walk_node(left, visitor);
231            walk_node(right, visitor);
232        }
233        Node::Ternary {
234            condition,
235            true_expr,
236            false_expr,
237        } => {
238            walk_node(condition, visitor);
239            walk_node(true_expr, visitor);
240            walk_node(false_expr, visitor);
241        }
242        Node::Assignment { target, value, .. } => {
243            walk_node(target, visitor);
244            walk_node(value, visitor);
245        }
246        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
247            walk_dict_entries(fields, visitor);
248        }
249        Node::ListLiteral(items) | Node::OrPattern(items) => walk_nodes(items, visitor),
250        Node::InterpolatedString(_)
251        | Node::StringLiteral(_)
252        | Node::RawStringLiteral(_)
253        | Node::IntLiteral(_)
254        | Node::FloatLiteral(_)
255        | Node::BoolLiteral(_)
256        | Node::NilLiteral
257        | Node::Identifier(_)
258        | Node::DurationLiteral(_) => {}
259    }
260}
261
262fn walk_nodes(nodes: &[SNode], visitor: &mut impl FnMut(&SNode)) {
263    for node in nodes {
264        walk_node(node, visitor);
265    }
266}
267
268fn walk_dict_entries(entries: &[DictEntry], visitor: &mut impl FnMut(&SNode)) {
269    for entry in entries {
270        walk_node(&entry.key, visitor);
271        walk_node(&entry.value, visitor);
272    }
273}
274
275fn walk_field_values(fields: &[(String, SNode)], visitor: &mut impl FnMut(&SNode)) {
276    for (_, value) in fields {
277        walk_node(value, visitor);
278    }
279}
280
281fn walk_option_values(options: &[(String, SNode)], visitor: &mut impl FnMut(&SNode)) {
282    for (_, value) in options {
283        walk_node(value, visitor);
284    }
285}
286
287fn walk_match_arm(arm: &MatchArm, visitor: &mut impl FnMut(&SNode)) {
288    walk_node(&arm.pattern, visitor);
289    if let Some(guard) = &arm.guard {
290        walk_node(guard, visitor);
291    }
292    walk_nodes(&arm.body, visitor);
293}
294
295fn walk_select_case(case: &SelectCase, visitor: &mut impl FnMut(&SNode)) {
296    walk_node(&case.channel, visitor);
297    walk_nodes(&case.body, visitor);
298}
299
300fn walk_binding_pattern(pattern: &BindingPattern, visitor: &mut impl FnMut(&SNode)) {
301    match pattern {
302        BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
303        BindingPattern::Dict(fields) => {
304            for field in fields {
305                if let Some(default) = &field.default_value {
306                    walk_node(default, visitor);
307                }
308            }
309        }
310        BindingPattern::List(items) => {
311            for item in items {
312                if let Some(default) = &item.default_value {
313                    walk_node(default, visitor);
314                }
315            }
316        }
317    }
318}