Skip to main content

harn_parser/
visit.rs

1//! Generic AST visitor used by the linter, formatter, and any other
2//! crate that needs to walk every `SNode` in a parsed program.
3//!
4//! Centralizing this here keeps a single source of truth for which
5//! children each `Node` variant has — adding a new variant requires
6//! one edit (in `collect_children`) and every consumer benefits.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use harn_parser::visit::walk_program;
12//! let mut count = 0;
13//! walk_program(&program, &mut |node| {
14//!     if matches!(&node.node, harn_parser::Node::FunctionCall { .. }) {
15//!         count += 1;
16//!     }
17//! });
18//! ```
19//!
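//! The visitor invokes the closure on each node *before* recursing
//! into its children (pre-order), and every child is always visited:
//! the closure cannot prune the walk. To stop descending at particular
//! nodes, drive the traversal yourself with [`immediate_children`]. Use
//! [`walk_children`] when you want to handle a parent specially and then
//! resume the default, unpruned traversal beneath it.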
23
24use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase, TypedParam};
25
26/// Walk every node in `program` in pre-order, invoking `visitor` on
27/// each.
28pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29    let mut stack = Vec::with_capacity(program.len());
30    push_nodes_reversed(program, &mut stack);
31    walk_stack(&mut stack, visitor);
32}
33
34/// Visit `node`, then recurse into its children.
35pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36    let mut stack = vec![node];
37    walk_stack(&mut stack, visitor);
38}
39
40/// Recurse into `node`'s children without re-visiting `node` itself.
41/// Useful when a caller wants to handle the parent specially and then
42/// continue the default traversal.
43pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44    let mut stack = Vec::new();
45    push_children_reversed(node, &mut stack);
46    walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50    while let Some(node) = stack.pop() {
51        visitor(node);
52        push_children_reversed(node, stack);
53    }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57    let mut children = Vec::new();
58    collect_children(node, &mut children);
59    stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63    stack.extend(nodes.iter().rev());
64}
65
66/// Collect `node`'s immediate children without recursing. Lets callers walk
67/// selectively (e.g. stop descending at nested loops) while still relying on
68/// this module's single source of truth for each variant's children.
69pub fn immediate_children(node: &SNode) -> Vec<&SNode> {
70    let mut children = Vec::new();
71    collect_children(node, &mut children);
72    children
73}
74
/// Append `node`'s immediate children to `children`, without recursing.
///
/// This match is the single place that knows which fields of each `Node`
/// variant hold sub-nodes. It has no wildcard arm, so adding a variant to
/// `Node` fails to compile until it is handled here.
///
/// The order of the pushes below is the order the walkers visit children
/// (first pushed = first visited). Reordering them changes traversal order
/// for every consumer.
fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
    match &node.node {
        Node::AttributedDecl { attributes, inner } => {
            // Attribute argument expressions come before the decorated
            // declaration itself.
            for attr in attributes {
                for arg in &attr.args {
                    children.push(&arg.value);
                }
            }
            children.push(inner);
        }
        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
            collect_nodes(body, children);
        }
        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
            // Default expressions inside a destructuring pattern precede the
            // bound value.
            collect_binding_pattern(pattern, children);
            children.push(value);
        }
        Node::ConstBinding { value, .. } => {
            children.push(value);
        }
        Node::EnumDecl { variants, .. } => {
            for variant in variants {
                collect_typed_param_defaults(&variant.fields, children);
            }
        }
        // No sub-nodes are visited for these variants.
        // NOTE(review): unlike `EnumDecl`, no field defaults are walked for
        // `StructDecl`; confirm struct fields cannot carry default-value
        // expressions.
        Node::StructDecl { .. }
        | Node::ImportDecl { .. }
        | Node::SelectiveImport { .. }
        | Node::TypeDecl { .. }
        | Node::BreakStmt
        | Node::ContinueStmt => {}
        Node::InterfaceDecl { methods, .. } => {
            // Interface methods are visited only for their parameter defaults.
            for method in methods {
                collect_typed_param_defaults(&method.params, children);
            }
        }
        Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
        Node::IfElse {
            condition,
            then_body,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(then_body, children);
            if let Some(body) = else_body {
                collect_nodes(body, children);
            }
        }
        Node::ForIn {
            pattern,
            iterable,
            body,
        } => {
            collect_binding_pattern(pattern, children);
            children.push(iterable);
            collect_nodes(body, children);
        }
        Node::MatchExpr { value, arms } => {
            children.push(value);
            for arm in arms {
                collect_match_arm(arm, children);
            }
        }
        Node::WhileLoop { condition, body } => {
            children.push(condition);
            collect_nodes(body, children);
        }
        Node::Retry { count, body } => {
            children.push(count);
            collect_nodes(body, children);
        }
        Node::CostRoute { options, body } => {
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
            // The value is optional.
            if let Some(value) = value {
                children.push(value);
            }
        }
        // `has_catch` is ignored: `catch_body` is walked unconditionally
        // (presumably it is empty when there is no catch clause).
        Node::TryCatch {
            has_catch: _,
            body,
            catch_body,
            finally_body,
            ..
        } => {
            collect_nodes(body, children);
            collect_nodes(catch_body, children);
            if let Some(body) = finally_body {
                collect_nodes(body, children);
            }
        }
        Node::TryExpr { body }
        | Node::SpawnExpr { body }
        | Node::ScopeBlock { body }
        | Node::DeferStmt { body }
        | Node::Block(body) => collect_nodes(body, children),
        Node::Closure { params, body, .. } => {
            collect_typed_param_defaults(params, children);
            collect_nodes(body, children);
        }
        Node::MutexBlock { key, body } => {
            if let Some(key) = key {
                children.push(key);
            }
            collect_nodes(body, children);
        }
        Node::FnDecl { params, body, .. } | Node::ToolDecl { params, body, .. } => {
            collect_typed_param_defaults(params, children);
            collect_nodes(body, children);
        }
        Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
        Node::EvalPackDecl {
            fields,
            body,
            summarize,
            ..
        } => {
            collect_field_values(fields, children);
            collect_nodes(body, children);
            if let Some(body) = summarize {
                collect_nodes(body, children);
            }
        }
        Node::RangeExpr { start, end, .. } => {
            children.push(start);
            children.push(end);
        }
        Node::GuardStmt {
            condition,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(else_body, children);
        }
        Node::RequireStmt { condition, message } => {
            children.push(condition);
            if let Some(message) = message {
                children.push(message);
            }
        }
        Node::DeadlineBlock { duration, body } => {
            children.push(duration);
            collect_nodes(body, children);
        }
        // Single-operand variants, all bound to `value` so they share one arm.
        Node::EmitExpr { value }
        | Node::ThrowStmt { value }
        | Node::Spread(value)
        | Node::TryOperator { operand: value }
        | Node::TryStar { operand: value }
        | Node::UnaryOp { operand: value, .. } => children.push(value),
        Node::HitlExpr { args, .. } => {
            for arg in args {
                children.push(&arg.value);
            }
        }
        Node::Parallel {
            expr,
            body,
            options,
            ..
        } => {
            children.push(expr);
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        Node::SelectExpr {
            cases,
            timeout,
            default_body,
        } => {
            for case in cases {
                collect_select_case(case, children);
            }
            // The timeout clause is a `(duration, body)` pair.
            if let Some((duration, body)) = timeout {
                children.push(duration);
                collect_nodes(body, children);
            }
            if let Some(body) = default_body {
                collect_nodes(body, children);
            }
        }
        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
            collect_nodes(args, children);
        }
        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
            children.push(object);
            collect_nodes(args, children);
        }
        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
            children.push(object);
        }
        Node::SubscriptAccess { object, index }
        | Node::OptionalSubscriptAccess { object, index } => {
            children.push(object);
            children.push(index);
        }
        Node::SliceAccess { object, start, end } => {
            // Either slice bound may be omitted.
            children.push(object);
            if let Some(start) = start {
                children.push(start);
            }
            if let Some(end) = end {
                children.push(end);
            }
        }
        Node::BinaryOp { left, right, .. } => {
            children.push(left);
            children.push(right);
        }
        Node::Ternary {
            condition,
            true_expr,
            false_expr,
        } => {
            children.push(condition);
            children.push(true_expr);
            children.push(false_expr);
        }
        Node::Assignment { target, value, .. } => {
            children.push(target);
            children.push(value);
        }
        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
            collect_dict_entries(fields, children);
        }
        Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
        // Leaves: nothing beneath them is visited.
        // NOTE(review): `InterpolatedString` segments are not walked; confirm
        // they never contain `SNode` expressions.
        Node::InterpolatedString(_)
        | Node::StringLiteral(_)
        | Node::RawStringLiteral(_)
        | Node::IntLiteral(_)
        | Node::FloatLiteral(_)
        | Node::BoolLiteral(_)
        | Node::NilLiteral
        | Node::Identifier(_)
        | Node::DurationLiteral(_) => {}
    }
}
314
315fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
316    children.extend(nodes.iter());
317}
318
319fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
320    for entry in entries {
321        children.push(&entry.key);
322        children.push(&entry.value);
323    }
324}
325
326fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
327    for (_, value) in fields {
328        children.push(value);
329    }
330}
331
332fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
333    for (_, value) in options {
334        children.push(value);
335    }
336}
337
338fn collect_typed_param_defaults<'a>(params: &'a [TypedParam], children: &mut Vec<&'a SNode>) {
339    for param in params {
340        if let Some(default) = &param.default_value {
341            children.push(default);
342        }
343    }
344}
345
346fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
347    children.push(&arm.pattern);
348    if let Some(guard) = &arm.guard {
349        children.push(guard);
350    }
351    collect_nodes(&arm.body, children);
352}
353
354fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
355    children.push(&case.channel);
356    collect_nodes(&case.body, children);
357}
358
359fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
360    match pattern {
361        BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
362        BindingPattern::Dict(fields) => {
363            for field in fields {
364                if let Some(default) = &field.default_value {
365                    children.push(default);
366                }
367            }
368        }
369        BindingPattern::List(items) => {
370            for item in items {
371                if let Some(default) = &item.default_value {
372                    children.push(default);
373                }
374            }
375        }
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{spanned, Node, TypedParam};
    use harn_lexer::Span;

    /// Wrap `node` with a dummy span so tests can build trees tersely.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    #[test]
    fn walk_program_preserves_preorder() {
        let program = vec![dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(dummy(Node::BinaryOp {
                op: "+".to_string(),
                left: Box::new(dummy(Node::IntLiteral(1))),
                right: Box::new(dummy(Node::IntLiteral(2))),
            })),
        })];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["let", "binary", "one", "two"]);
    }

    #[test]
    fn walk_program_visits_top_level_nodes_in_source_order() {
        let program = vec![
            dummy(Node::IntLiteral(1)),
            dummy(Node::IntLiteral(2)),
            dummy(Node::IntLiteral(3)),
        ];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::IntLiteral(value) => *value,
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec![1, 2, 3]);
    }

    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let mut node = dummy(Node::IntLiteral(0));
        for _ in 0..10_000 {
            node = dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(node),
            });
        }

        let mut count = 0usize;
        walk_node(&node, &mut |_| count += 1);

        assert_eq!(count, 10_001);
    }

    #[test]
    fn walk_node_visits_typed_param_defaults() {
        let default = dummy(Node::Identifier("fallback".to_string()));
        let node = dummy(Node::FnDecl {
            name: "load".to_string(),
            type_params: Vec::new(),
            params: vec![TypedParam {
                name: "root".to_string(),
                type_expr: None,
                default_value: Some(Box::new(default)),
                rest: false,
            }],
            return_type: None,
            where_clauses: Vec::new(),
            body: Vec::new(),
            is_pub: false,
            is_stream: false,
        });
        let mut seen = Vec::new();

        walk_node(&node, &mut |node| {
            if let Node::Identifier(name) = &node.node {
                seen.push(name.clone());
            }
        });

        assert_eq!(seen, vec!["fallback"]);
    }

    #[test]
    fn walk_children_skips_the_root_but_visits_all_descendants() {
        let node = dummy(Node::UnaryOp {
            op: "-".to_string(),
            operand: Box::new(dummy(Node::BinaryOp {
                op: "+".to_string(),
                left: Box::new(dummy(Node::IntLiteral(1))),
                right: Box::new(dummy(Node::IntLiteral(2))),
            })),
        });
        let mut seen = Vec::new();

        walk_children(&node, &mut |node| {
            seen.push(match &node.node {
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["binary", "one", "two"]);
    }

    #[test]
    fn immediate_children_does_not_recurse() {
        let node = dummy(Node::BinaryOp {
            op: "+".to_string(),
            left: Box::new(dummy(Node::UnaryOp {
                op: "-".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(1))),
            })),
            right: Box::new(dummy(Node::IntLiteral(2))),
        });

        let children = immediate_children(&node);

        assert_eq!(children.len(), 2);
        assert!(matches!(children[0].node, Node::UnaryOp { .. }));
        assert!(matches!(children[1].node, Node::IntLiteral(2)));
    }
}