1use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase, TypedParam};
25
26pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29 let mut stack = Vec::with_capacity(program.len());
30 push_nodes_reversed(program, &mut stack);
31 walk_stack(&mut stack, visitor);
32}
33
34pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36 let mut stack = vec![node];
37 walk_stack(&mut stack, visitor);
38}
39
40pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44 let mut stack = Vec::new();
45 push_children_reversed(node, &mut stack);
46 walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50 while let Some(node) = stack.pop() {
51 visitor(node);
52 push_children_reversed(node, stack);
53 }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57 let mut children = Vec::new();
58 collect_children(node, &mut children);
59 stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63 stack.extend(nodes.iter().rev());
64}
65
/// Pushes every direct child `SNode` of `node` onto `children`.
///
/// The walkers in this module reverse the collected children onto their stack
/// before popping. The push order here is therefore the order in which the
/// children are visited. Each arm lists its sub-nodes in the order a pre-order
/// traversal should see them.
///
/// Only `SNode` children are collected. Non-node payloads matched away with
/// `..` (names, operator strings, type annotations, flags) are skipped. Binding
/// patterns and typed parameters contribute only their default-value
/// expressions. Literal and identifier leaves push nothing.
fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
    match &node.node {
        Node::AttributedDecl { attributes, inner } => {
            // Attribute argument expressions first, then the decorated item.
            for attr in attributes {
                for arg in &attr.args {
                    children.push(&arg.value);
                }
            }
            children.push(inner);
        }
        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
            collect_nodes(body, children);
        }
        Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
            // Pattern default expressions are visited before the bound value.
            collect_binding_pattern(pattern, children);
            children.push(value);
        }
        Node::ConstBinding { value, .. } => {
            children.push(value);
        }
        Node::EnumDecl { variants, .. } => {
            // Only default expressions of variant fields are nodes.
            for variant in variants {
                collect_typed_param_defaults(&variant.fields, children);
            }
        }
        // No child nodes are walked for these variants.
        // NOTE(review): StructDecl fields are not visited here; if struct
        // fields can carry default expressions, confirm skipping them is intended.
        Node::StructDecl { .. }
        | Node::ImportDecl { .. }
        | Node::SelectiveImport { .. }
        | Node::TypeDecl { .. }
        | Node::BreakStmt
        | Node::ContinueStmt => {}
        Node::InterfaceDecl { methods, .. } => {
            // Interface methods are signatures; only parameter defaults are nodes.
            for method in methods {
                collect_typed_param_defaults(&method.params, children);
            }
        }
        Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
        Node::IfElse {
            condition,
            then_body,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(then_body, children);
            if let Some(body) = else_body {
                collect_nodes(body, children);
            }
        }
        Node::ForIn {
            pattern,
            iterable,
            body,
        } => {
            collect_binding_pattern(pattern, children);
            children.push(iterable);
            collect_nodes(body, children);
        }
        Node::MatchExpr { value, arms } => {
            // Scrutinee first, then each arm's pattern, guard and body in turn.
            children.push(value);
            for arm in arms {
                collect_match_arm(arm, children);
            }
        }
        Node::WhileLoop { condition, body } => {
            children.push(condition);
            collect_nodes(body, children);
        }
        Node::Retry { count, body } => {
            children.push(count);
            collect_nodes(body, children);
        }
        Node::CostRoute { options, body } => {
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        Node::ReturnStmt { value } | Node::YieldExpr { value } => {
            if let Some(value) = value {
                children.push(value);
            }
        }
        Node::TryCatch {
            // The flag itself is not a node; only the bodies are walked.
            has_catch: _,
            body,
            catch_body,
            finally_body,
            ..
        } => {
            collect_nodes(body, children);
            collect_nodes(catch_body, children);
            if let Some(body) = finally_body {
                collect_nodes(body, children);
            }
        }
        // Variants whose only children are a single statement list.
        Node::TryExpr { body }
        | Node::SpawnExpr { body }
        | Node::ScopeBlock { body }
        | Node::DeferStmt { body }
        | Node::Block(body) => collect_nodes(body, children),
        Node::Closure { params, body, .. } => {
            collect_typed_param_defaults(params, children);
            collect_nodes(body, children);
        }
        Node::MutexBlock { key, body } => {
            if let Some(key) = key {
                children.push(key);
            }
            collect_nodes(body, children);
        }
        Node::FnDecl { params, body, .. } | Node::ToolDecl { params, body, .. } => {
            collect_typed_param_defaults(params, children);
            collect_nodes(body, children);
        }
        Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
        Node::EvalPackDecl {
            fields,
            body,
            summarize,
            ..
        } => {
            collect_field_values(fields, children);
            collect_nodes(body, children);
            if let Some(body) = summarize {
                collect_nodes(body, children);
            }
        }
        Node::RangeExpr { start, end, .. } => {
            children.push(start);
            children.push(end);
        }
        Node::GuardStmt {
            condition,
            else_body,
        } => {
            children.push(condition);
            collect_nodes(else_body, children);
        }
        Node::RequireStmt { condition, message } => {
            children.push(condition);
            if let Some(message) = message {
                children.push(message);
            }
        }
        Node::DeadlineBlock { duration, body } => {
            children.push(duration);
            collect_nodes(body, children);
        }
        // Variants wrapping exactly one child expression.
        Node::EmitExpr { value }
        | Node::ThrowStmt { value }
        | Node::Spread(value)
        | Node::TryOperator { operand: value }
        | Node::TryStar { operand: value }
        | Node::UnaryOp { operand: value, .. } => children.push(value),
        Node::HitlExpr { args, .. } => {
            for arg in args {
                children.push(&arg.value);
            }
        }
        Node::Parallel {
            expr,
            body,
            options,
            ..
        } => {
            // Visit order: the driving expression, then option values, then body.
            children.push(expr);
            collect_option_values(options, children);
            collect_nodes(body, children);
        }
        Node::SelectExpr {
            cases,
            timeout,
            default_body,
        } => {
            for case in cases {
                collect_select_case(case, children);
            }
            if let Some((duration, body)) = timeout {
                children.push(duration);
                collect_nodes(body, children);
            }
            if let Some(body) = default_body {
                collect_nodes(body, children);
            }
        }
        Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
            collect_nodes(args, children);
        }
        Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
            // Receiver before arguments.
            children.push(object);
            collect_nodes(args, children);
        }
        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
            children.push(object);
        }
        Node::SubscriptAccess { object, index }
        | Node::OptionalSubscriptAccess { object, index } => {
            children.push(object);
            children.push(index);
        }
        Node::SliceAccess { object, start, end } => {
            children.push(object);
            if let Some(start) = start {
                children.push(start);
            }
            if let Some(end) = end {
                children.push(end);
            }
        }
        Node::BinaryOp { left, right, .. } => {
            children.push(left);
            children.push(right);
        }
        Node::Ternary {
            condition,
            true_expr,
            false_expr,
        } => {
            children.push(condition);
            children.push(true_expr);
            children.push(false_expr);
        }
        Node::Assignment { target, value, .. } => {
            children.push(target);
            children.push(value);
        }
        Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
            collect_dict_entries(fields, children);
        }
        Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
        // Leaves: no child nodes.
        Node::InterpolatedString(_)
        | Node::StringLiteral(_)
        | Node::RawStringLiteral(_)
        | Node::IntLiteral(_)
        | Node::FloatLiteral(_)
        | Node::BoolLiteral(_)
        | Node::NilLiteral
        | Node::Identifier(_)
        | Node::DurationLiteral(_) => {}
    }
}
305
306fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
307 children.extend(nodes.iter());
308}
309
310fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
311 for entry in entries {
312 children.push(&entry.key);
313 children.push(&entry.value);
314 }
315}
316
317fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
318 for (_, value) in fields {
319 children.push(value);
320 }
321}
322
323fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
324 for (_, value) in options {
325 children.push(value);
326 }
327}
328
329fn collect_typed_param_defaults<'a>(params: &'a [TypedParam], children: &mut Vec<&'a SNode>) {
330 for param in params {
331 if let Some(default) = ¶m.default_value {
332 children.push(default);
333 }
334 }
335}
336
337fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
338 children.push(&arm.pattern);
339 if let Some(guard) = &arm.guard {
340 children.push(guard);
341 }
342 collect_nodes(&arm.body, children);
343}
344
345fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
346 children.push(&case.channel);
347 collect_nodes(&case.body, children);
348}
349
350fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
351 match pattern {
352 BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
353 BindingPattern::Dict(fields) => {
354 for field in fields {
355 if let Some(default) = &field.default_value {
356 children.push(default);
357 }
358 }
359 }
360 BindingPattern::List(items) => {
361 for item in items {
362 if let Some(default) = &item.default_value {
363 children.push(default);
364 }
365 }
366 }
367 }
368}
369
#[cfg(test)]
mod tests {
    //! Traversal-order and depth tests for the iterative AST walkers.

    use super::*;
    use crate::ast::{spanned, Node, TypedParam};
    use harn_lexer::Span;

    /// Wraps `node` with a dummy span so tests can build trees tersely.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    #[test]
    fn walk_program_preserves_preorder() {
        let program = vec![dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(dummy(Node::BinaryOp {
                op: "+".to_string(),
                left: Box::new(dummy(Node::IntLiteral(1))),
                right: Box::new(dummy(Node::IntLiteral(2))),
            })),
        })];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["let", "binary", "one", "two"]);
    }

    #[test]
    fn walk_program_visits_top_level_nodes_in_order() {
        // Each top-level statement must be fully walked before the next starts.
        let program = vec![
            dummy(Node::IntLiteral(1)),
            dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(2))),
            }),
            dummy(Node::IntLiteral(3)),
        ];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::IntLiteral(1) => "one",
                Node::UnaryOp { .. } => "unary",
                Node::IntLiteral(2) => "two",
                Node::IntLiteral(3) => "three",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["one", "unary", "two", "three"]);
    }

    #[test]
    fn walk_program_on_empty_program_visits_nothing() {
        let mut count = 0usize;
        walk_program(&[], &mut |_| count += 1);
        assert_eq!(count, 0);
    }

    #[test]
    fn walk_children_skips_root_and_preserves_order() {
        let node = dummy(Node::BinaryOp {
            op: "+".to_string(),
            left: Box::new(dummy(Node::IntLiteral(1))),
            right: Box::new(dummy(Node::UnaryOp {
                op: "-".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(2))),
            })),
        });
        let mut seen = Vec::new();

        walk_children(&node, &mut |node| {
            seen.push(match &node.node {
                Node::IntLiteral(1) => "one",
                Node::UnaryOp { .. } => "unary",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        // The BinaryOp root itself must not be reported.
        assert_eq!(seen, vec!["one", "unary", "two"]);
    }

    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let mut node = dummy(Node::IntLiteral(0));
        for _ in 0..10_000 {
            node = dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(node),
            });
        }

        let mut count = 0usize;
        walk_node(&node, &mut |_| count += 1);

        assert_eq!(count, 10_001);
    }

    #[test]
    fn walk_node_visits_typed_param_defaults() {
        let default = dummy(Node::Identifier("fallback".to_string()));
        let node = dummy(Node::FnDecl {
            name: "load".to_string(),
            type_params: Vec::new(),
            params: vec![TypedParam {
                name: "root".to_string(),
                type_expr: None,
                default_value: Some(Box::new(default)),
                rest: false,
            }],
            return_type: None,
            where_clauses: Vec::new(),
            body: Vec::new(),
            is_pub: false,
            is_stream: false,
        });
        let mut seen = Vec::new();

        walk_node(&node, &mut |node| {
            if let Node::Identifier(name) = &node.node {
                seen.push(name.clone());
            }
        });

        assert_eq!(seen, vec!["fallback"]);
    }
}