1use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase};
25
26pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29 let mut stack = Vec::with_capacity(program.len());
30 push_nodes_reversed(program, &mut stack);
31 walk_stack(&mut stack, visitor);
32}
33
34pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36 let mut stack = vec![node];
37 walk_stack(&mut stack, visitor);
38}
39
40pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44 let mut stack = Vec::new();
45 push_children_reversed(node, &mut stack);
46 walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50 while let Some(node) = stack.pop() {
51 visitor(node);
52 push_children_reversed(node, stack);
53 }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57 let mut children = Vec::new();
58 collect_children(node, &mut children);
59 stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63 stack.extend(nodes.iter().rev());
64}
65
/// Appends the direct children of `node` to `children`, in source order.
///
/// The order of the pushes below is the order in which the walkers visit the
/// children. `push_children_reversed` reverses this list before it lands on the
/// LIFO stack. Reordering any pushes therefore changes the traversal order that
/// visitors observe.
///
/// Only nested `SNode` values are appended. Any payload skipped via `_` or `..`
/// in a pattern is not traversed. The match has no wildcard arm, so adding a
/// `Node` variant forces this function to be updated.
fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
    match &node.node {
        // Attribute argument values come first, then the decorated declaration.
        Node::AttributedDecl { attributes, inner } => {
            for attr in attributes {
                for arg in &attr.args {
                    children.push(&arg.value);
                }
            }
            children.push(inner);
        }
        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
            collect_nodes(body, children);
        }
        // Default-value expressions inside the destructuring pattern come
        // before the bound value.
        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
            collect_binding_pattern(pattern, children);
            children.push(value);
        }
        Node::ConstBinding { value, .. } => {
            children.push(value);
        }
        // These variants are treated as leaves: none of their fields are traversed.
        Node::EnumDecl { .. }
        | Node::StructDecl { .. }
        | Node::InterfaceDecl { .. }
        | Node::ImportDecl { .. }
        | Node::SelectiveImport { .. }
        | Node::TypeDecl { .. }
        | Node::BreakStmt
        | Node::ContinueStmt => {}
        Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
        // Condition, then-branch statements, then the optional else branch.
        Node::IfElse {
            condition,
            then_body,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(then_body, children);
            if let Some(body) = else_body {
                collect_nodes(body, children);
            }
        }
        // Pattern defaults, then the iterable expression, then the loop body.
        Node::ForIn {
            pattern,
            iterable,
            body,
        } => {
            collect_binding_pattern(pattern, children);
            children.push(iterable);
            collect_nodes(body, children);
        }
        // Scrutinee first; then, for each arm: pattern, optional guard, body.
        Node::MatchExpr { value, arms } => {
            children.push(value);
            for arm in arms {
                collect_match_arm(arm, children);
            }
        }
        Node::WhileLoop { condition, body } => {
            children.push(condition);
            collect_nodes(body, children);
        }
        Node::Retry { count, body } => {
            children.push(count);
            collect_nodes(body, children);
        }
        // Option values (names are skipped), then the body.
        Node::CostRoute { options, body } => {
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
            if let Some(value) = value {
                children.push(value);
            }
        }
        // `has_catch` is ignored: `catch_body` is traversed unconditionally
        // (an empty catch body simply contributes nothing).
        Node::TryCatch {
            has_catch: _,
            body,
            catch_body,
            finally_body,
            ..
        } => {
            collect_nodes(body, children);
            collect_nodes(catch_body, children);
            if let Some(body) = finally_body {
                collect_nodes(body, children);
            }
        }
        // Statement-list wrappers: only the body statements are children.
        // For `Closure`, any fields other than `body` are skipped via `..`.
        Node::TryExpr { body }
        | Node::SpawnExpr { body }
        | Node::ScopeBlock { body }
        | Node::DeferStmt { body }
        | Node::Block(body)
        | Node::Closure { body, .. } => collect_nodes(body, children),
        // Optional key expression, then the guarded body.
        Node::MutexBlock { key, body } => {
            if let Some(key) = key {
                children.push(key);
            }
            collect_nodes(body, children);
        }
        // Only the body is traversed; other fields are skipped via `..`.
        // NOTE(review): if parameters can carry default-value expressions,
        // they are not visited here — confirm against the AST definition.
        Node::FnDecl { body, .. } | Node::ToolDecl { body, .. } => {
            collect_nodes(body, children);
        }
        Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
        // Field values, then the body, then the optional `summarize` block.
        Node::EvalPackDecl {
            fields,
            body,
            summarize,
            ..
        } => {
            collect_field_values(fields, children);
            collect_nodes(body, children);
            if let Some(body) = summarize {
                collect_nodes(body, children);
            }
        }
        Node::RangeExpr { start, end, .. } => {
            children.push(start);
            children.push(end);
        }
        Node::GuardStmt {
            condition,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(else_body, children);
        }
        Node::RequireStmt { condition, message } => {
            children.push(condition);
            if let Some(message) = message {
                children.push(message);
            }
        }
        Node::DeadlineBlock { duration, body } => {
            children.push(duration);
            collect_nodes(body, children);
        }
        // Single-operand nodes: the one wrapped expression is the only child.
        Node::EmitExpr { value }
        | Node::ThrowStmt { value }
        | Node::Spread(value)
        | Node::TryOperator { operand: value }
        | Node::TryStar { operand: value }
        | Node::UnaryOp { operand: value, .. } => children.push(value),
        Node::HitlExpr { args, .. } => {
            for arg in args {
                children.push(&arg.value);
            }
        }
        // Source expression, then option values, then the body.
        Node::Parallel {
            expr,
            body,
            options,
            ..
        } => {
            children.push(expr);
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        // Each case (channel, then body), then the optional timeout
        // (duration, then body), then the optional default body.
        Node::SelectExpr {
            cases,
            timeout,
            default_body,
        } => {
            for case in cases {
                collect_select_case(case, children);
            }
            if let Some((duration, body)) = timeout {
                children.push(duration);
                collect_nodes(body, children);
            }
            if let Some(body) = default_body {
                collect_nodes(body, children);
            }
        }
        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
            collect_nodes(args, children);
        }
        // Receiver first, then the call arguments.
        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
            children.push(object);
            collect_nodes(args, children);
        }
        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
            children.push(object);
        }
        Node::SubscriptAccess { object, index }
        | Node::OptionalSubscriptAccess { object, index } => {
            children.push(object);
            children.push(index);
        }
        // Both slice bounds are optional; absent bounds contribute nothing.
        Node::SliceAccess { object, start, end } => {
            children.push(object);
            if let Some(start) = start {
                children.push(start);
            }
            if let Some(end) = end {
                children.push(end);
            }
        }
        Node::BinaryOp { left, right, .. } => {
            children.push(left);
            children.push(right);
        }
        Node::Ternary {
            condition,
            true_expr,
            false_expr,
        } => {
            children.push(condition);
            children.push(true_expr);
            children.push(false_expr);
        }
        // Assignment target before the assigned value.
        Node::Assignment { target, value, .. } => {
            children.push(target);
            children.push(value);
        }
        // Entry keys and values are both nodes, interleaved key/value.
        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
            collect_dict_entries(fields, children);
        }
        Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
        // Literals and identifiers are leaves.
        // NOTE(review): `InterpolatedString` is treated as a leaf; if its
        // segments can hold parsed expression nodes they are not visited —
        // confirm against the AST definition.
        Node::InterpolatedString(_)
        | Node::StringLiteral(_)
        | Node::RawStringLiteral(_)
        | Node::IntLiteral(_)
        | Node::FloatLiteral(_)
        | Node::BoolLiteral(_)
        | Node::NilLiteral
        | Node::Identifier(_)
        | Node::DurationLiteral(_) => {}
    }
}
293
294fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
295 children.extend(nodes.iter());
296}
297
298fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
299 for entry in entries {
300 children.push(&entry.key);
301 children.push(&entry.value);
302 }
303}
304
305fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
306 for (_, value) in fields {
307 children.push(value);
308 }
309}
310
311fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
312 for (_, value) in options {
313 children.push(value);
314 }
315}
316
317fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
318 children.push(&arm.pattern);
319 if let Some(guard) = &arm.guard {
320 children.push(guard);
321 }
322 collect_nodes(&arm.body, children);
323}
324
325fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
326 children.push(&case.channel);
327 collect_nodes(&case.body, children);
328}
329
330fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
331 match pattern {
332 BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
333 BindingPattern::Dict(fields) => {
334 for field in fields {
335 if let Some(default) = &field.default_value {
336 children.push(default);
337 }
338 }
339 }
340 BindingPattern::List(items) => {
341 for item in items {
342 if let Some(default) = &item.default_value {
343 children.push(default);
344 }
345 }
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{spanned, Node};
    use harn_lexer::Span;

    /// Wraps `node` with a placeholder span; spans play no part in traversal.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    #[test]
    fn walk_program_preserves_preorder() {
        let program = vec![dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(dummy(Node::BinaryOp {
                op: "+".to_string(),
                left: Box::new(dummy(Node::IntLiteral(1))),
                right: Box::new(dummy(Node::IntLiteral(2))),
            })),
        })];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["let", "binary", "one", "two"]);
    }

    /// Sibling top-level statements must be visited in source order, each
    /// with its whole subtree before the next statement.
    #[test]
    fn walk_program_visits_top_level_statements_in_order() {
        let program = vec![
            dummy(Node::IntLiteral(1)),
            dummy(Node::ListLiteral(vec![
                dummy(Node::IntLiteral(2)),
                dummy(Node::IntLiteral(3)),
            ])),
            dummy(Node::IntLiteral(4)),
        ];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::IntLiteral(n) => *n,
                Node::ListLiteral(_) => 0,
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec![1, 0, 2, 3, 4]);
    }

    #[test]
    fn walk_program_on_empty_program_visits_nothing() {
        let mut count = 0usize;
        walk_program(&[], &mut |_| count += 1);
        assert_eq!(count, 0);
    }

    /// `walk_children` must skip the root but visit every descendant in
    /// preorder.
    #[test]
    fn walk_children_skips_root_and_visits_descendants_in_preorder() {
        let node = dummy(Node::BinaryOp {
            op: "+".to_string(),
            left: Box::new(dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(1))),
            })),
            right: Box::new(dummy(Node::IntLiteral(2))),
        });
        let mut seen = Vec::new();

        walk_children(&node, &mut |child| {
            seen.push(match &child.node {
                Node::UnaryOp { .. } => "unary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["unary", "one", "two"]);
    }

    #[test]
    fn walk_children_of_leaf_visits_nothing() {
        let mut count = 0usize;
        walk_children(&dummy(Node::NilLiteral), &mut |_| count += 1);
        assert_eq!(count, 0);
    }

    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let mut node = dummy(Node::IntLiteral(0));
        for _ in 0..10_000 {
            node = dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(node),
            });
        }

        let mut count = 0usize;
        walk_node(&node, &mut |_| count += 1);

        assert_eq!(count, 10_001);
    }
}