Skip to main content

harn_parser/
visit.rs

1//! Generic AST visitor used by the linter, formatter, and any other
2//! crate that needs to walk every `SNode` in a parsed program.
3//!
4//! Centralizing this here keeps a single source of truth for which
5//! children each `Node` variant has — adding a new variant requires
6//! one edit (in `collect_children`) and every consumer benefits.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use harn_parser::visit::walk_program;
12//! let mut count = 0;
13//! walk_program(&program, &mut |node| {
14//!     if matches!(&node.node, harn_parser::Node::FunctionCall { .. }) {
15//!         count += 1;
16//!     }
17//! });
18//! ```
19//!
20//! The visitor invokes the closure on each node *before* recursing
21//! into its children (pre-order). To stop recursion at a particular
22//! node, prefer using [`walk_children`] directly.
23
24use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase, TypedParam};
25
26/// Walk every node in `program` in pre-order, invoking `visitor` on
27/// each.
28pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29    let mut stack = Vec::with_capacity(program.len());
30    push_nodes_reversed(program, &mut stack);
31    walk_stack(&mut stack, visitor);
32}
33
34/// Visit `node`, then recurse into its children.
35pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36    let mut stack = vec![node];
37    walk_stack(&mut stack, visitor);
38}
39
40/// Recurse into `node`'s children without re-visiting `node` itself.
41/// Useful when a caller wants to handle the parent specially and then
42/// continue the default traversal.
43pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44    let mut stack = Vec::new();
45    push_children_reversed(node, &mut stack);
46    walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50    while let Some(node) = stack.pop() {
51        visitor(node);
52        push_children_reversed(node, stack);
53    }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57    let mut children = Vec::new();
58    collect_children(node, &mut children);
59    stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63    stack.extend(nodes.iter().rev());
64}
65
66fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
67    match &node.node {
68        Node::AttributedDecl { attributes, inner } => {
69            for attr in attributes {
70                for arg in &attr.args {
71                    children.push(&arg.value);
72                }
73            }
74            children.push(inner);
75        }
76        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
77            collect_nodes(body, children);
78        }
79        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
80            collect_binding_pattern(pattern, children);
81            children.push(value);
82        }
83        Node::ConstBinding { value, .. } => {
84            children.push(value);
85        }
86        Node::EnumDecl { variants, .. } => {
87            for variant in variants {
88                collect_typed_param_defaults(&variant.fields, children);
89            }
90        }
91        Node::StructDecl { .. }
92        | Node::ImportDecl { .. }
93        | Node::SelectiveImport { .. }
94        | Node::TypeDecl { .. }
95        | Node::BreakStmt
96        | Node::ContinueStmt => {}
97        Node::InterfaceDecl { methods, .. } => {
98            for method in methods {
99                collect_typed_param_defaults(&method.params, children);
100            }
101        }
102        Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
103        Node::IfElse {
104            condition,
105            then_body,
106            else_body,
107        } => {
108            children.push(condition);
109            collect_nodes(then_body, children);
110            if let Some(body) = else_body {
111                collect_nodes(body, children);
112            }
113        }
114        Node::ForIn {
115            pattern,
116            iterable,
117            body,
118        } => {
119            collect_binding_pattern(pattern, children);
120            children.push(iterable);
121            collect_nodes(body, children);
122        }
123        Node::MatchExpr { value, arms } => {
124            children.push(value);
125            for arm in arms {
126                collect_match_arm(arm, children);
127            }
128        }
129        Node::WhileLoop { condition, body } => {
130            children.push(condition);
131            collect_nodes(body, children);
132        }
133        Node::Retry { count, body } => {
134            children.push(count);
135            collect_nodes(body, children);
136        }
137        Node::CostRoute { options, body } => {
138            collect_option_values(options, children);
139            collect_nodes(body, children);
140        }
141        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
142            if let Some(value) = value {
143                children.push(value);
144            }
145        }
146        Node::TryCatch {
147            has_catch: _,
148            body,
149            catch_body,
150            finally_body,
151            ..
152        } => {
153            collect_nodes(body, children);
154            collect_nodes(catch_body, children);
155            if let Some(body) = finally_body {
156                collect_nodes(body, children);
157            }
158        }
159        Node::TryExpr { body }
160        | Node::SpawnExpr { body }
161        | Node::ScopeBlock { body }
162        | Node::DeferStmt { body }
163        | Node::Block(body) => collect_nodes(body, children),
164        Node::Closure { params, body, .. } => {
165            collect_typed_param_defaults(params, children);
166            collect_nodes(body, children);
167        }
168        Node::MutexBlock { key, body } => {
169            if let Some(key) = key {
170                children.push(key);
171            }
172            collect_nodes(body, children);
173        }
174        Node::FnDecl { params, body, .. } | Node::ToolDecl { params, body, .. } => {
175            collect_typed_param_defaults(params, children);
176            collect_nodes(body, children);
177        }
178        Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
179        Node::EvalPackDecl {
180            fields,
181            body,
182            summarize,
183            ..
184        } => {
185            collect_field_values(fields, children);
186            collect_nodes(body, children);
187            if let Some(body) = summarize {
188                collect_nodes(body, children);
189            }
190        }
191        Node::RangeExpr { start, end, .. } => {
192            children.push(start);
193            children.push(end);
194        }
195        Node::GuardStmt {
196            condition,
197            else_body,
198        } => {
199            children.push(condition);
200            collect_nodes(else_body, children);
201        }
202        Node::RequireStmt { condition, message } => {
203            children.push(condition);
204            if let Some(message) = message {
205                children.push(message);
206            }
207        }
208        Node::DeadlineBlock { duration, body } => {
209            children.push(duration);
210            collect_nodes(body, children);
211        }
212        Node::EmitExpr { value }
213        | Node::ThrowStmt { value }
214        | Node::Spread(value)
215        | Node::TryOperator { operand: value }
216        | Node::TryStar { operand: value }
217        | Node::UnaryOp { operand: value, .. } => children.push(value),
218        Node::HitlExpr { args, .. } => {
219            for arg in args {
220                children.push(&arg.value);
221            }
222        }
223        Node::Parallel {
224            expr,
225            body,
226            options,
227            ..
228        } => {
229            children.push(expr);
230            collect_option_values(options, children);
231            collect_nodes(body, children);
232        }
233        Node::SelectExpr {
234            cases,
235            timeout,
236            default_body,
237        } => {
238            for case in cases {
239                collect_select_case(case, children);
240            }
241            if let Some((duration, body)) = timeout {
242                children.push(duration);
243                collect_nodes(body, children);
244            }
245            if let Some(body) = default_body {
246                collect_nodes(body, children);
247            }
248        }
249        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
250            collect_nodes(args, children);
251        }
252        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
253            children.push(object);
254            collect_nodes(args, children);
255        }
256        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
257            children.push(object);
258        }
259        Node::SubscriptAccess { object, index }
260        | Node::OptionalSubscriptAccess { object, index } => {
261            children.push(object);
262            children.push(index);
263        }
264        Node::SliceAccess { object, start, end } => {
265            children.push(object);
266            if let Some(start) = start {
267                children.push(start);
268            }
269            if let Some(end) = end {
270                children.push(end);
271            }
272        }
273        Node::BinaryOp { left, right, .. } => {
274            children.push(left);
275            children.push(right);
276        }
277        Node::Ternary {
278            condition,
279            true_expr,
280            false_expr,
281        } => {
282            children.push(condition);
283            children.push(true_expr);
284            children.push(false_expr);
285        }
286        Node::Assignment { target, value, .. } => {
287            children.push(target);
288            children.push(value);
289        }
290        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
291            collect_dict_entries(fields, children);
292        }
293        Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
294        Node::InterpolatedString(_)
295        | Node::StringLiteral(_)
296        | Node::RawStringLiteral(_)
297        | Node::IntLiteral(_)
298        | Node::FloatLiteral(_)
299        | Node::BoolLiteral(_)
300        | Node::NilLiteral
301        | Node::Identifier(_)
302        | Node::DurationLiteral(_) => {}
303    }
304}
305
306fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
307    children.extend(nodes.iter());
308}
309
310fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
311    for entry in entries {
312        children.push(&entry.key);
313        children.push(&entry.value);
314    }
315}
316
317fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
318    for (_, value) in fields {
319        children.push(value);
320    }
321}
322
323fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
324    for (_, value) in options {
325        children.push(value);
326    }
327}
328
329fn collect_typed_param_defaults<'a>(params: &'a [TypedParam], children: &mut Vec<&'a SNode>) {
330    for param in params {
331        if let Some(default) = &param.default_value {
332            children.push(default);
333        }
334    }
335}
336
337fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
338    children.push(&arm.pattern);
339    if let Some(guard) = &arm.guard {
340        children.push(guard);
341    }
342    collect_nodes(&arm.body, children);
343}
344
345fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
346    children.push(&case.channel);
347    collect_nodes(&case.body, children);
348}
349
350fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
351    match pattern {
352        BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
353        BindingPattern::Dict(fields) => {
354            for field in fields {
355                if let Some(default) = &field.default_value {
356                    children.push(default);
357                }
358            }
359        }
360        BindingPattern::List(items) => {
361            for item in items {
362                if let Some(default) = &item.default_value {
363                    children.push(default);
364                }
365            }
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{spanned, Node, TypedParam};
    use harn_lexer::Span;

    /// Wrap `node` in an `SNode` carrying a placeholder span.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    /// A parent must be reported before its children, and children in
    /// left-to-right order.
    #[test]
    fn walk_program_preserves_preorder() {
        let sum = dummy(Node::BinaryOp {
            op: "+".to_string(),
            left: Box::new(dummy(Node::IntLiteral(1))),
            right: Box::new(dummy(Node::IntLiteral(2))),
        });
        let program = vec![dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(sum),
        })];

        let mut labels = Vec::new();
        walk_program(&program, &mut |visited| {
            let label = match &visited.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            };
            labels.push(label);
        });

        assert_eq!(labels, vec!["let", "binary", "one", "two"]);
    }

    /// A 10,000-deep chain of unary operators must be walked without
    /// exhausting the call stack, visiting every wrapper plus the literal.
    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let chain = (0..10_000).fold(dummy(Node::IntLiteral(0)), |inner, _| {
            dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(inner),
            })
        });

        let mut visited = 0usize;
        walk_node(&chain, &mut |_| visited += 1);

        assert_eq!(visited, 10_001);
    }

    /// Default-value expressions on function parameters are part of the
    /// traversal.
    #[test]
    fn walk_node_visits_typed_param_defaults() {
        let param = TypedParam {
            name: "root".to_string(),
            type_expr: None,
            default_value: Some(Box::new(dummy(Node::Identifier("fallback".to_string())))),
            rest: false,
        };
        let decl = dummy(Node::FnDecl {
            name: "load".to_string(),
            type_params: Vec::new(),
            params: vec![param],
            return_type: None,
            where_clauses: Vec::new(),
            body: Vec::new(),
            is_pub: false,
            is_stream: false,
        });

        let mut identifiers = Vec::new();
        walk_node(&decl, &mut |visited| {
            if let Node::Identifier(name) = &visited.node {
                identifiers.push(name.clone());
            }
        });

        assert_eq!(identifiers, vec!["fallback"]);
    }
}