1use crate::ast::{BindingPattern, DictEntry, MatchArm, Node, SNode, SelectCase, TypedParam};
25
26pub fn walk_program(program: &[SNode], visitor: &mut impl FnMut(&SNode)) {
29 let mut stack = Vec::with_capacity(program.len());
30 push_nodes_reversed(program, &mut stack);
31 walk_stack(&mut stack, visitor);
32}
33
34pub fn walk_node(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
36 let mut stack = vec![node];
37 walk_stack(&mut stack, visitor);
38}
39
40pub fn walk_children(node: &SNode, visitor: &mut impl FnMut(&SNode)) {
44 let mut stack = Vec::new();
45 push_children_reversed(node, &mut stack);
46 walk_stack(&mut stack, visitor);
47}
48
49fn walk_stack(stack: &mut Vec<&SNode>, visitor: &mut impl FnMut(&SNode)) {
50 while let Some(node) = stack.pop() {
51 visitor(node);
52 push_children_reversed(node, stack);
53 }
54}
55
56fn push_children_reversed<'a>(node: &'a SNode, stack: &mut Vec<&'a SNode>) {
57 let mut children = Vec::new();
58 collect_children(node, &mut children);
59 stack.extend(children.into_iter().rev());
60}
61
62fn push_nodes_reversed<'a>(nodes: &'a [SNode], stack: &mut Vec<&'a SNode>) {
63 stack.extend(nodes.iter().rev());
64}
65
66pub fn immediate_children(node: &SNode) -> Vec<&SNode> {
70 let mut children = Vec::new();
71 collect_children(node, &mut children);
72 children
73}
74
75fn collect_children<'a>(node: &'a SNode, children: &mut Vec<&'a SNode>) {
76 match &node.node {
77 Node::AttributedDecl { attributes, inner } => {
78 for attr in attributes {
79 for arg in &attr.args {
80 children.push(&arg.value);
81 }
82 }
83 children.push(inner);
84 }
85 Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
86 collect_nodes(body, children);
87 }
88 Node::LetBinding { pattern, value, .. } | Node::VarBinding { pattern, value, .. } => {
89 collect_binding_pattern(pattern, children);
90 children.push(value);
91 }
92 Node::ConstBinding { value, .. } => {
93 children.push(value);
94 }
95 Node::EnumDecl { variants, .. } => {
96 for variant in variants {
97 collect_typed_param_defaults(&variant.fields, children);
98 }
99 }
100 Node::StructDecl { .. }
101 | Node::ImportDecl { .. }
102 | Node::SelectiveImport { .. }
103 | Node::TypeDecl { .. }
104 | Node::BreakStmt
105 | Node::ContinueStmt => {}
106 Node::InterfaceDecl { methods, .. } => {
107 for method in methods {
108 collect_typed_param_defaults(&method.params, children);
109 }
110 }
111 Node::ImplBlock { methods, .. } => collect_nodes(methods, children),
112 Node::IfElse {
113 condition,
114 then_body,
115 else_body,
116 } => {
117 children.push(condition);
118 collect_nodes(then_body, children);
119 if let Some(body) = else_body {
120 collect_nodes(body, children);
121 }
122 }
123 Node::ForIn {
124 pattern,
125 iterable,
126 body,
127 } => {
128 collect_binding_pattern(pattern, children);
129 children.push(iterable);
130 collect_nodes(body, children);
131 }
132 Node::MatchExpr { value, arms } => {
133 children.push(value);
134 for arm in arms {
135 collect_match_arm(arm, children);
136 }
137 }
138 Node::WhileLoop { condition, body } => {
139 children.push(condition);
140 collect_nodes(body, children);
141 }
142 Node::Retry { count, body } => {
143 children.push(count);
144 collect_nodes(body, children);
145 }
146 Node::CostRoute { options, body } => {
147 collect_option_values(options, children);
148 collect_nodes(body, children);
149 }
150 Node::ReturnStmt { value } | Node::YieldExpr { value } => {
151 if let Some(value) = value {
152 children.push(value);
153 }
154 }
155 Node::TryCatch {
156 has_catch: _,
157 body,
158 catch_body,
159 finally_body,
160 ..
161 } => {
162 collect_nodes(body, children);
163 collect_nodes(catch_body, children);
164 if let Some(body) = finally_body {
165 collect_nodes(body, children);
166 }
167 }
168 Node::TryExpr { body }
169 | Node::SpawnExpr { body }
170 | Node::ScopeBlock { body }
171 | Node::DeferStmt { body }
172 | Node::Block(body) => collect_nodes(body, children),
173 Node::Closure { params, body, .. } => {
174 collect_typed_param_defaults(params, children);
175 collect_nodes(body, children);
176 }
177 Node::MutexBlock { key, body } => {
178 if let Some(key) = key {
179 children.push(key);
180 }
181 collect_nodes(body, children);
182 }
183 Node::FnDecl { params, body, .. } | Node::ToolDecl { params, body, .. } => {
184 collect_typed_param_defaults(params, children);
185 collect_nodes(body, children);
186 }
187 Node::SkillDecl { fields, .. } => collect_field_values(fields, children),
188 Node::EvalPackDecl {
189 fields,
190 body,
191 summarize,
192 ..
193 } => {
194 collect_field_values(fields, children);
195 collect_nodes(body, children);
196 if let Some(body) = summarize {
197 collect_nodes(body, children);
198 }
199 }
200 Node::RangeExpr { start, end, .. } => {
201 children.push(start);
202 children.push(end);
203 }
204 Node::GuardStmt {
205 condition,
206 else_body,
207 } => {
208 children.push(condition);
209 collect_nodes(else_body, children);
210 }
211 Node::RequireStmt { condition, message } => {
212 children.push(condition);
213 if let Some(message) = message {
214 children.push(message);
215 }
216 }
217 Node::DeadlineBlock { duration, body } => {
218 children.push(duration);
219 collect_nodes(body, children);
220 }
221 Node::EmitExpr { value }
222 | Node::ThrowStmt { value }
223 | Node::Spread(value)
224 | Node::TryOperator { operand: value }
225 | Node::TryStar { operand: value }
226 | Node::UnaryOp { operand: value, .. } => children.push(value),
227 Node::HitlExpr { args, .. } => {
228 for arg in args {
229 children.push(&arg.value);
230 }
231 }
232 Node::Parallel {
233 expr,
234 body,
235 options,
236 ..
237 } => {
238 children.push(expr);
239 collect_option_values(options, children);
240 collect_nodes(body, children);
241 }
242 Node::SelectExpr {
243 cases,
244 timeout,
245 default_body,
246 } => {
247 for case in cases {
248 collect_select_case(case, children);
249 }
250 if let Some((duration, body)) = timeout {
251 children.push(duration);
252 collect_nodes(body, children);
253 }
254 if let Some(body) = default_body {
255 collect_nodes(body, children);
256 }
257 }
258 Node::FunctionCall { args, .. } | Node::EnumConstruct { args, .. } => {
259 collect_nodes(args, children);
260 }
261 Node::MethodCall { object, args, .. } | Node::OptionalMethodCall { object, args, .. } => {
262 children.push(object);
263 collect_nodes(args, children);
264 }
265 Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
266 children.push(object);
267 }
268 Node::SubscriptAccess { object, index }
269 | Node::OptionalSubscriptAccess { object, index } => {
270 children.push(object);
271 children.push(index);
272 }
273 Node::SliceAccess { object, start, end } => {
274 children.push(object);
275 if let Some(start) = start {
276 children.push(start);
277 }
278 if let Some(end) = end {
279 children.push(end);
280 }
281 }
282 Node::BinaryOp { left, right, .. } => {
283 children.push(left);
284 children.push(right);
285 }
286 Node::Ternary {
287 condition,
288 true_expr,
289 false_expr,
290 } => {
291 children.push(condition);
292 children.push(true_expr);
293 children.push(false_expr);
294 }
295 Node::Assignment { target, value, .. } => {
296 children.push(target);
297 children.push(value);
298 }
299 Node::StructConstruct { fields, .. } | Node::DictLiteral(fields) => {
300 collect_dict_entries(fields, children);
301 }
302 Node::ListLiteral(items) | Node::OrPattern(items) => collect_nodes(items, children),
303 Node::InterpolatedString(_)
304 | Node::StringLiteral(_)
305 | Node::RawStringLiteral(_)
306 | Node::IntLiteral(_)
307 | Node::FloatLiteral(_)
308 | Node::BoolLiteral(_)
309 | Node::NilLiteral
310 | Node::Identifier(_)
311 | Node::DurationLiteral(_) => {}
312 }
313}
314
315fn collect_nodes<'a>(nodes: &'a [SNode], children: &mut Vec<&'a SNode>) {
316 children.extend(nodes.iter());
317}
318
319fn collect_dict_entries<'a>(entries: &'a [DictEntry], children: &mut Vec<&'a SNode>) {
320 for entry in entries {
321 children.push(&entry.key);
322 children.push(&entry.value);
323 }
324}
325
326fn collect_field_values<'a>(fields: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
327 for (_, value) in fields {
328 children.push(value);
329 }
330}
331
332fn collect_option_values<'a>(options: &'a [(String, SNode)], children: &mut Vec<&'a SNode>) {
333 for (_, value) in options {
334 children.push(value);
335 }
336}
337
338fn collect_typed_param_defaults<'a>(params: &'a [TypedParam], children: &mut Vec<&'a SNode>) {
339 for param in params {
340 if let Some(default) = ¶m.default_value {
341 children.push(default);
342 }
343 }
344}
345
346fn collect_match_arm<'a>(arm: &'a MatchArm, children: &mut Vec<&'a SNode>) {
347 children.push(&arm.pattern);
348 if let Some(guard) = &arm.guard {
349 children.push(guard);
350 }
351 collect_nodes(&arm.body, children);
352}
353
354fn collect_select_case<'a>(case: &'a SelectCase, children: &mut Vec<&'a SNode>) {
355 children.push(&case.channel);
356 collect_nodes(&case.body, children);
357}
358
359fn collect_binding_pattern<'a>(pattern: &'a BindingPattern, children: &mut Vec<&'a SNode>) {
360 match pattern {
361 BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
362 BindingPattern::Dict(fields) => {
363 for field in fields {
364 if let Some(default) = &field.default_value {
365 children.push(default);
366 }
367 }
368 }
369 BindingPattern::List(items) => {
370 for item in items {
371 if let Some(default) = &item.default_value {
372 children.push(default);
373 }
374 }
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{spanned, Node, TypedParam};
    use harn_lexer::Span;

    /// Wraps `node` with a dummy span for test fixtures.
    fn dummy(node: Node) -> SNode {
        spanned(node, Span::dummy())
    }

    #[test]
    fn walk_program_preserves_preorder() {
        let program = vec![dummy(Node::LetBinding {
            pattern: BindingPattern::Identifier("x".to_string()),
            type_ann: None,
            value: Box::new(dummy(Node::BinaryOp {
                op: "+".to_string(),
                left: Box::new(dummy(Node::IntLiteral(1))),
                right: Box::new(dummy(Node::IntLiteral(2))),
            })),
        })];
        let mut seen = Vec::new();

        walk_program(&program, &mut |node| {
            seen.push(match &node.node {
                Node::LetBinding { .. } => "let",
                Node::BinaryOp { .. } => "binary",
                Node::IntLiteral(1) => "one",
                Node::IntLiteral(2) => "two",
                other => panic!("unexpected node {other:?}"),
            });
        });

        assert_eq!(seen, vec!["let", "binary", "one", "two"]);
    }

    #[test]
    fn walk_node_handles_deep_unary_chain_iteratively() {
        let mut node = dummy(Node::IntLiteral(0));
        for _ in 0..10_000 {
            node = dummy(Node::UnaryOp {
                op: "!".to_string(),
                operand: Box::new(node),
            });
        }

        let mut count = 0usize;
        walk_node(&node, &mut |_| count += 1);

        assert_eq!(count, 10_001);
    }

    #[test]
    fn walk_node_visits_typed_param_defaults() {
        let default = dummy(Node::Identifier("fallback".to_string()));
        let node = dummy(Node::FnDecl {
            name: "load".to_string(),
            type_params: Vec::new(),
            params: vec![TypedParam {
                name: "root".to_string(),
                type_expr: None,
                default_value: Some(Box::new(default)),
                rest: false,
            }],
            return_type: None,
            where_clauses: Vec::new(),
            body: Vec::new(),
            is_pub: false,
            is_stream: false,
        });
        let mut seen = Vec::new();

        walk_node(&node, &mut |node| {
            if let Node::Identifier(name) = &node.node {
                seen.push(name.clone());
            }
        });

        assert_eq!(seen, vec!["fallback"]);
    }

    #[test]
    fn walk_children_skips_root_and_keeps_preorder() {
        let node = dummy(Node::ListLiteral(vec![
            dummy(Node::IntLiteral(1)),
            dummy(Node::UnaryOp {
                op: "-".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(2))),
            }),
            dummy(Node::IntLiteral(3)),
        ]));
        let mut seen = Vec::new();

        walk_children(&node, &mut |node| {
            seen.push(match &node.node {
                Node::IntLiteral(value) => value.to_string(),
                Node::UnaryOp { .. } => "neg".to_string(),
                other => panic!("unexpected node {other:?}"),
            });
        });

        // The list root is not visited; the nested operand follows its parent.
        assert_eq!(seen, vec!["1", "neg", "2", "3"]);
    }

    #[test]
    fn immediate_children_does_not_descend() {
        let node = dummy(Node::ListLiteral(vec![
            dummy(Node::IntLiteral(1)),
            dummy(Node::UnaryOp {
                op: "-".to_string(),
                operand: Box::new(dummy(Node::IntLiteral(2))),
            }),
        ]));

        let children = immediate_children(&node);

        assert_eq!(children.len(), 2);
        assert!(matches!(children[0].node, Node::IntLiteral(1)));
        assert!(matches!(children[1].node, Node::UnaryOp { .. }));
    }
}