Skip to main content

harn_parser/
visit.rs

1//! Generic AST visitor used by the linter, formatter, and any other
2//! crate that needs to walk every `SNode` in a parsed program.
3//!
4//! Centralizing this here keeps a single source of truth for which
5//! children each `Node` variant has — adding a new variant requires
6//! one edit (in `collect_children`) and every consumer benefits.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use harn_parser::visit::walk_program;
12//! let mut count = 0;
13//! walk_program(&program, &mut |node| {
14//!     if matches!(&node.node, harn_parser::Node::FunctionCall { .. }) {
15//!         count += 1;
16//!     }
17//! });
18//! ```
19//!
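//! The visitor invokes the closure on each node *before* visiting its
//! children (pre-order), and it always descends into every child; the
//! closure cannot prune a subtree. To handle a node specially and then
//! continue the default traversal beneath it, call [`walk_children`]
//! on that node.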
23
24use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase};
25
26/// Walk every node in `program` in pre-order, invoking `visitor` on
27/// each.
28pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29    let mut stack = Vec::with_capacity(program.len());
30    push_nodes_reversed(program, &mut stack);
31    walk_stack(&mut stack, visitor);
32}
33
34/// Visit `node`, then recurse into its children.
35pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36    let mut stack = vec![node];
37    walk_stack(&mut stack, visitor);
38}
39
40/// Recurse into `node`'s children without re-visiting `node` itself.
41/// Useful when a caller wants to handle the parent specially and then
42/// continue the default traversal.
43pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44    let mut stack = Vec::new();
45    push_children_reversed(node, &mut stack);
46    walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50    while let Some(node) = stack.pop() {
51        visitor(node);
52        push_children_reversed(node, stack);
53    }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57    let mut children = Vec::new();
58    collect_children(node, &mut children);
59    stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63    stack.extend(nodes.iter().rev());
64}
65
66fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
67    match &node.node {
68        Node::AttributedDecl { attributes, inner } => {
69            for attr in attributes {
70                for arg in &attr.args {
71                    children.push(&arg.value);
72                }
73            }
74            children.push(inner);
75        }
76        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
77            collect_nodes(body, children);
78        }
79        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
80            collect_binding_pattern(pattern, children);
81            children.push(value);
82        }
83        Node::ConstBinding { value, .. } => {
84            children.push(value);
85        }
86        Node::EnumDecl { .. }
87        | Node::StructDecl { .. }
88        | Node::InterfaceDecl { .. }
89        | Node::ImportDecl { .. }
90        | Node::SelectiveImport { .. }
91        | Node::TypeDecl { .. }
92        | Node::BreakStmt
93        | Node::ContinueStmt => {}
94        Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
95        Node::IfElse {
96            condition,
97            then_body,
98            else_body,
99        } => {
100            children.push(condition);
101            collect_nodes(then_body, children);
102            if let Some(body) = else_body {
103                collect_nodes(body, children);
104            }
105        }
106        Node::ForIn {
107            pattern,
108            iterable,
109            body,
110        } => {
111            collect_binding_pattern(pattern, children);
112            children.push(iterable);
113            collect_nodes(body, children);
114        }
115        Node::MatchExpr { value, arms } => {
116            children.push(value);
117            for arm in arms {
118                collect_match_arm(arm, children);
119            }
120        }
121        Node::WhileLoop { condition, body } => {
122            children.push(condition);
123            collect_nodes(body, children);
124        }
125        Node::Retry { count, body } => {
126            children.push(count);
127            collect_nodes(body, children);
128        }
129        Node::CostRoute { options, body } => {
130            collect_option_values(options, children);
131            collect_nodes(body, children);
132        }
133        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
134            if let Some(value) = value {
135                children.push(value);
136            }
137        }
138        Node::TryCatch {
139            has_catch: _,
140            body,
141            catch_body,
142            finally_body,
143            ..
144        } => {
145            collect_nodes(body, children);
146            collect_nodes(catch_body, children);
147            if let Some(body) = finally_body {
148                collect_nodes(body, children);
149            }
150        }
151        Node::TryExpr { body }
152        | Node::SpawnExpr { body }
153        | Node::DeferStmt { body }
154        | Node::MutexBlock { body }
155        | Node::Block(body)
156        | Node::Closure { body, .. } => collect_nodes(body, children),
157        Node::FnDecl { body, .. } | Node::ToolDecl { body, .. } => {
158            collect_nodes(body, children);
159        }
160        Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
161        Node::EvalPackDecl {
162            fields,
163            body,
164            summarize,
165            ..
166        } => {
167            collect_field_values(fields, children);
168            collect_nodes(body, children);
169            if let Some(body) = summarize {
170                collect_nodes(body, children);
171            }
172        }
173        Node::RangeExpr { start, end, .. } => {
174            children.push(start);
175            children.push(end);
176        }
177        Node::GuardStmt {
178            condition,
179            else_body,
180        } => {
181            children.push(condition);
182            collect_nodes(else_body, children);
183        }
184        Node::RequireStmt { condition, message } => {
185            children.push(condition);
186            if let Some(message) = message {
187                children.push(message);
188            }
189        }
190        Node::DeadlineBlock { duration, body } => {
191            children.push(duration);
192            collect_nodes(body, children);
193        }
194        Node::EmitExpr { value }
195        | Node::ThrowStmt { value }
196        | Node::Spread(value)
197        | Node::TryOperator { operand: value }
198        | Node::TryStar { operand: value }
199        | Node::UnaryOp { operand: value, .. } => children.push(value),
200        Node::HitlExpr { args, .. } => {
201            for arg in args {
202                children.push(&arg.value);
203            }
204        }
205        Node::Parallel {
206            expr,
207            body,
208            options,
209            ..
210        } => {
211            children.push(expr);
212            collect_option_values(options, children);
213            collect_nodes(body, children);
214        }
215        Node::SelectExpr {
216            cases,
217            timeout,
218            default_body,
219        } => {
220            for case in cases {
221                collect_select_case(case, children);
222            }
223            if let Some((duration, body)) = timeout {
224                children.push(duration);
225                collect_nodes(body, children);
226            }
227            if let Some(body) = default_body {
228                collect_nodes(body, children);
229            }
230        }
231        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
232            collect_nodes(args, children);
233        }
234        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
235            children.push(object);
236            collect_nodes(args, children);
237        }
238        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
239            children.push(object);
240        }
241        Node::SubscriptAccess { object, index }
242        | Node::OptionalSubscriptAccess { object, index } => {
243            children.push(object);
244            children.push(index);
245        }
246        Node::SliceAccess { object, start, end } => {
247            children.push(object);
248            if let Some(start) = start {
249                children.push(start);
250            }
251            if let Some(end) = end {
252                children.push(end);
253            }
254        }
255        Node::BinaryOp { left, right, .. } => {
256            children.push(left);
257            children.push(right);
258        }
259        Node::Ternary {
260            condition,
261            true_expr,
262            false_expr,
263        } => {
264            children.push(condition);
265            children.push(true_expr);
266            children.push(false_expr);
267        }
268        Node::Assignment { target, value, .. } => {
269            children.push(target);
270            children.push(value);
271        }
272        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
273            collect_dict_entries(fields, children);
274        }
275        Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
276        Node::InterpolatedString(_)
277        | Node::StringLiteral(_)
278        | Node::RawStringLiteral(_)
279        | Node::IntLiteral(_)
280        | Node::FloatLiteral(_)
281        | Node::BoolLiteral(_)
282        | Node::NilLiteral
283        | Node::Identifier(_)
284        | Node::DurationLiteral(_) => {}
285    }
286}
287
288fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
289    children.extend(nodes.iter());
290}
291
292fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
293    for entry in entries {
294        children.push(&entry.key);
295        children.push(&entry.value);
296    }
297}
298
299fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
300    for (_, value) in fields {
301        children.push(value);
302    }
303}
304
305fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
306    for (_, value) in options {
307        children.push(value);
308    }
309}
310
311fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
312    children.push(&arm.pattern);
313    if let Some(guard) = &arm.guard {
314        children.push(guard);
315    }
316    collect_nodes(&arm.body, children);
317}
318
319fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
320    children.push(&case.channel);
321    collect_nodes(&case.body, children);
322}
323
324fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
325    match pattern {
326        BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
327        BindingPattern::Dict(fields) => {
328            for field in fields {
329                if let Some(default) = &field.default_value {
330                    children.push(default);
331                }
332            }
333        }
334        BindingPattern::List(items) => {
335            for item in items {
336                if let Some(default) = &item.default_value {
337                    children.push(default);
338                }
339            }
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{spanned, Node};
    use harn_lexer::Span;

    /// Wrap `node` in an `SNode` that carries a dummy span.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    /// `let x = 1 + 2` must be visited as let, binary op, left operand,
    /// right operand.
    #[test]
    fn walk_program_preserves_preorder() {
        let one = dummy(Node::IntLiteral(1));
        let two = dummy(Node::IntLiteral(2));
        let sum = dummy(Node::BinaryOp {
            op: "+".to_string(),
            left: Box::new(one),
            right: Box::new(two),
        });
        let binding = dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(sum),
        });
        let program = vec![binding];

        let mut labels: Vec<&str> = Vec::new();
        walk_program(&program, &mut |visited| {
            let label = match &visited.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            };
            labels.push(label);
        });

        assert_eq!(labels, vec!["let", "binary", "one", "two"]);
    }

    /// A chain of 10 000 nested unary operators around one literal must be
    /// walked completely: one visit per operator plus one for the literal.
    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let chain = (0..10_000).fold(dummy(Node::IntLiteral(0)), |inner, _| {
            dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(inner),
            })
        });

        let mut visited = 0usize;
        walk_node(&chain, &mut |_| visited += 1);

        assert_eq!(visited, 10_001);
    }
}
395}