Skip to main content

shape_runtime/
visitor.rs

//! AST Visitor trait and walk functions for Shape.
//!
//! This module provides a visitor pattern for traversing the AST.
//! All variants are explicitly handled - no wildcards.
//!
//! ## Per-Variant Expression Methods
//!
//! The `Visitor` trait provides fine-grained per-variant methods for expressions.
//! Each method has a default implementation that returns `true` (continue into
//! children). Override only the variants you care about.
//!
//! The visit order for each expression is:
//! 1. `visit_expr(expr)` — coarse pre-visit hook; return `false` to skip entirely
//! 2. `visit_expr_<variant>(expr, span)` — per-variant hook; return `false` to skip children
//! 3. Walk children recursively
//! 4. `leave_expr(expr)` — post-visit hook
17
18use shape_ast::ast::*;
19
20/// A visitor trait for traversing Shape AST nodes.
21///
22/// All `visit_*` methods return `bool`:
23/// - `true`: continue visiting children
24/// - `false`: skip children
25///
26/// The `leave_*` methods are called after visiting all children.
27///
28/// ## Per-Variant Expression Methods
29///
30/// For finer granularity, override the per-variant expression methods
31/// (e.g., `visit_identifier`, `visit_binary_op`, `visit_method_call`).
32/// These are called from `walk_expr` after the coarse `visit_expr` hook.
33/// Each receives the full `&Expr` node and its `Span`.
34pub trait Visitor: Sized {
35    // ===== Coarse-grained visitors (called on every node) =====
36
37    /// Called before visiting any expression. Return `false` to skip entirely
38    /// (neither per-variant method nor children will be visited).
39    fn visit_expr(&mut self, _expr: &Expr) -> bool {
40        true
41    }
42    /// Called after visiting an expression and all its children.
43    fn leave_expr(&mut self, _expr: &Expr) {}
44
45    // Statement visitors
46    fn visit_stmt(&mut self, _stmt: &Statement) -> bool {
47        true
48    }
49    fn leave_stmt(&mut self, _stmt: &Statement) {}
50
51    // Item visitors
52    fn visit_item(&mut self, _item: &Item) -> bool {
53        true
54    }
55    fn leave_item(&mut self, _item: &Item) {}
56
57    // Function definition visitors
58    fn visit_function(&mut self, _func: &FunctionDef) -> bool {
59        true
60    }
61    fn leave_function(&mut self, _func: &FunctionDef) {}
62
63    // Literal visitors (kept for backward compat — also called from walk_expr)
64    fn visit_literal(&mut self, _lit: &Literal) -> bool {
65        true
66    }
67    fn leave_literal(&mut self, _lit: &Literal) {}
68
69    // Block visitors (kept for backward compat — also called from walk_expr)
70    fn visit_block(&mut self, _block: &BlockExpr) -> bool {
71        true
72    }
73    fn leave_block(&mut self, _block: &BlockExpr) {}
74
75    // ===== Per-variant expression visitors =====
76    //
77    // Each method receives the full &Expr and its Span. Return `true` to
78    // continue walking children, `false` to skip children.
79    //
80    // Default implementations return `true` (walk children).
81
82    fn visit_expr_literal(&mut self, _expr: &Expr, _span: Span) -> bool {
83        true
84    }
85    fn visit_expr_identifier(&mut self, _expr: &Expr, _span: Span) -> bool {
86        true
87    }
88    fn visit_expr_data_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
89        true
90    }
91    fn visit_expr_data_datetime_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
92        true
93    }
94    fn visit_expr_data_relative_access(&mut self, _expr: &Expr, _span: Span) -> bool {
95        true
96    }
97    fn visit_expr_property_access(&mut self, _expr: &Expr, _span: Span) -> bool {
98        true
99    }
100    fn visit_expr_index_access(&mut self, _expr: &Expr, _span: Span) -> bool {
101        true
102    }
103    fn visit_expr_binary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
104        true
105    }
106    fn visit_expr_fuzzy_comparison(&mut self, _expr: &Expr, _span: Span) -> bool {
107        true
108    }
109    fn visit_expr_unary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
110        true
111    }
112    fn visit_expr_function_call(&mut self, _expr: &Expr, _span: Span) -> bool {
113        true
114    }
115    fn visit_expr_enum_constructor(&mut self, _expr: &Expr, _span: Span) -> bool {
116        true
117    }
118    fn visit_expr_time_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
119        true
120    }
121    fn visit_expr_datetime(&mut self, _expr: &Expr, _span: Span) -> bool {
122        true
123    }
124    fn visit_expr_pattern_ref(&mut self, _expr: &Expr, _span: Span) -> bool {
125        true
126    }
127    fn visit_expr_conditional(&mut self, _expr: &Expr, _span: Span) -> bool {
128        true
129    }
130    fn visit_expr_object(&mut self, _expr: &Expr, _span: Span) -> bool {
131        true
132    }
133    fn visit_expr_array(&mut self, _expr: &Expr, _span: Span) -> bool {
134        true
135    }
136    fn visit_expr_list_comprehension(&mut self, _expr: &Expr, _span: Span) -> bool {
137        true
138    }
139    fn visit_expr_block(&mut self, _expr: &Expr, _span: Span) -> bool {
140        true
141    }
142    fn visit_expr_type_assertion(&mut self, _expr: &Expr, _span: Span) -> bool {
143        true
144    }
145    fn visit_expr_instance_of(&mut self, _expr: &Expr, _span: Span) -> bool {
146        true
147    }
148    fn visit_expr_function_expr(&mut self, _expr: &Expr, _span: Span) -> bool {
149        true
150    }
151    fn visit_expr_duration(&mut self, _expr: &Expr, _span: Span) -> bool {
152        true
153    }
154    fn visit_expr_spread(&mut self, _expr: &Expr, _span: Span) -> bool {
155        true
156    }
157    fn visit_expr_if(&mut self, _expr: &Expr, _span: Span) -> bool {
158        true
159    }
160    fn visit_expr_while(&mut self, _expr: &Expr, _span: Span) -> bool {
161        true
162    }
163    fn visit_expr_for(&mut self, _expr: &Expr, _span: Span) -> bool {
164        true
165    }
166    fn visit_expr_loop(&mut self, _expr: &Expr, _span: Span) -> bool {
167        true
168    }
169    fn visit_expr_let(&mut self, _expr: &Expr, _span: Span) -> bool {
170        true
171    }
172    fn visit_expr_assign(&mut self, _expr: &Expr, _span: Span) -> bool {
173        true
174    }
175    fn visit_expr_break(&mut self, _expr: &Expr, _span: Span) -> bool {
176        true
177    }
178    fn visit_expr_continue(&mut self, _expr: &Expr, _span: Span) -> bool {
179        true
180    }
181    fn visit_expr_return(&mut self, _expr: &Expr, _span: Span) -> bool {
182        true
183    }
184    fn visit_expr_method_call(&mut self, _expr: &Expr, _span: Span) -> bool {
185        true
186    }
187    fn visit_expr_match(&mut self, _expr: &Expr, _span: Span) -> bool {
188        true
189    }
190    fn visit_expr_unit(&mut self, _expr: &Expr, _span: Span) -> bool {
191        true
192    }
193    fn visit_expr_range(&mut self, _expr: &Expr, _span: Span) -> bool {
194        true
195    }
196    fn visit_expr_timeframe_context(&mut self, _expr: &Expr, _span: Span) -> bool {
197        true
198    }
199    fn visit_expr_try_operator(&mut self, _expr: &Expr, _span: Span) -> bool {
200        true
201    }
202    fn visit_expr_using_impl(&mut self, _expr: &Expr, _span: Span) -> bool {
203        true
204    }
205    fn visit_expr_simulation_call(&mut self, _expr: &Expr, _span: Span) -> bool {
206        true
207    }
208    fn visit_expr_window_expr(&mut self, _expr: &Expr, _span: Span) -> bool {
209        true
210    }
211    fn visit_expr_from_query(&mut self, _expr: &Expr, _span: Span) -> bool {
212        true
213    }
214    fn visit_expr_struct_literal(&mut self, _expr: &Expr, _span: Span) -> bool {
215        true
216    }
217    fn visit_expr_await(&mut self, _expr: &Expr, _span: Span) -> bool {
218        true
219    }
220    fn visit_expr_join(&mut self, _expr: &Expr, _span: Span) -> bool {
221        true
222    }
223    fn visit_expr_annotated(&mut self, _expr: &Expr, _span: Span) -> bool {
224        true
225    }
226    fn visit_expr_async_let(&mut self, _expr: &Expr, _span: Span) -> bool {
227        true
228    }
229    fn visit_expr_async_scope(&mut self, _expr: &Expr, _span: Span) -> bool {
230        true
231    }
232    fn visit_expr_comptime(&mut self, _expr: &Expr, _span: Span) -> bool {
233        true
234    }
235    fn visit_expr_comptime_for(&mut self, _expr: &Expr, _span: Span) -> bool {
236        true
237    }
238    fn visit_expr_reference(&mut self, _expr: &Expr, _span: Span) -> bool {
239        true
240    }
241}
242
243// ===== Walk Functions =====
244
245/// Walk a program, visiting all items
246pub fn walk_program<V: Visitor>(visitor: &mut V, program: &Program) {
247    for item in &program.items {
248        walk_item(visitor, item);
249    }
250}
251
252/// Walk an item
253pub fn walk_item<V: Visitor>(visitor: &mut V, item: &Item) {
254    if !visitor.visit_item(item) {
255        return;
256    }
257
258    match item {
259        Item::Import(_, _) => {}
260        Item::Module(module_def, _) => {
261            for inner in &module_def.items {
262                walk_item(visitor, inner);
263            }
264        }
265        Item::Export(export, _) => match &export.item {
266            ExportItem::Function(func) => walk_function(visitor, func),
267            ExportItem::BuiltinFunction(_) => {}
268            ExportItem::BuiltinType(_) => {}
269            ExportItem::TypeAlias(_) => {}
270            ExportItem::Named(_) => {}
271            ExportItem::Enum(_) => {}
272            ExportItem::Struct(_) => {}
273            ExportItem::Trait(_) => {}
274            ExportItem::Annotation(annotation_def) => {
275                for handler in &annotation_def.handlers {
276                    walk_expr(visitor, &handler.body);
277                }
278            }
279            ExportItem::ForeignFunction(_) => {} // foreign bodies are opaque
280        },
281        Item::TypeAlias(_, _) => {}
282        Item::Trait(_, _) => {}
283        Item::Enum(_, _) => {}
284        Item::Extend(extend, _) => {
285            for method in &extend.methods {
286                for stmt in &method.body {
287                    walk_stmt(visitor, stmt);
288                }
289            }
290        }
291        Item::Impl(impl_block, _) => {
292            for method in &impl_block.methods {
293                for stmt in &method.body {
294                    walk_stmt(visitor, stmt);
295                }
296            }
297        }
298        Item::Function(func, _) => walk_function(visitor, func),
299        Item::Query(query, _) => walk_query(visitor, query),
300        Item::VariableDecl(decl, _) => {
301            if let Some(value) = &decl.value {
302                walk_expr(visitor, value);
303            }
304        }
305        Item::Assignment(assign, _) => {
306            walk_expr(visitor, &assign.value);
307        }
308        Item::Expression(expr, _) => walk_expr(visitor, expr),
309        Item::Stream(stream, _) => {
310            for decl in &stream.state {
311                if let Some(value) = &decl.value {
312                    walk_expr(visitor, value);
313                }
314            }
315            if let Some(stmts) = &stream.on_connect {
316                for stmt in stmts {
317                    walk_stmt(visitor, stmt);
318                }
319            }
320            if let Some(stmts) = &stream.on_disconnect {
321                for stmt in stmts {
322                    walk_stmt(visitor, stmt);
323                }
324            }
325            if let Some(on_event) = &stream.on_event {
326                for stmt in &on_event.body {
327                    walk_stmt(visitor, stmt);
328                }
329            }
330            if let Some(on_window) = &stream.on_window {
331                for stmt in &on_window.body {
332                    walk_stmt(visitor, stmt);
333                }
334            }
335            if let Some(on_error) = &stream.on_error {
336                for stmt in &on_error.body {
337                    walk_stmt(visitor, stmt);
338                }
339            }
340        }
341        Item::Test(test, _) => {
342            if let Some(setup) = &test.setup {
343                for stmt in setup {
344                    walk_stmt(visitor, stmt);
345                }
346            }
347            if let Some(teardown) = &test.teardown {
348                for stmt in teardown {
349                    walk_stmt(visitor, stmt);
350                }
351            }
352            for case in &test.test_cases {
353                for test_stmt in &case.body {
354                    walk_test_statement(visitor, test_stmt);
355                }
356            }
357        }
358        Item::Optimize(opt, _) => {
359            walk_expr(visitor, &opt.range.0);
360            walk_expr(visitor, &opt.range.1);
361            if let OptimizationMetric::Custom(expr) = &opt.metric {
362                walk_expr(visitor, expr);
363            }
364        }
365        Item::Statement(stmt, _) => walk_stmt(visitor, stmt),
366        Item::AnnotationDef(ann_def, _) => {
367            // Walk the lifecycle handlers of the annotation definition
368            for handler in &ann_def.handlers {
369                walk_expr(visitor, &handler.body);
370            }
371        }
372        Item::StructType(_, _) => {
373            // No expressions to walk in struct type definitions
374        }
375        Item::DataSource(ds, _) => {
376            walk_expr(visitor, &ds.provider_expr);
377        }
378        Item::QueryDecl(_, _) => {
379            // Query declarations have no walkable expressions (SQL is a string literal)
380        }
381        Item::Comptime(stmts, _) => {
382            for stmt in stmts {
383                walk_stmt(visitor, stmt);
384            }
385        }
386        Item::BuiltinTypeDecl(_, _) => {
387            // Declaration-only intrinsic
388        }
389        Item::BuiltinFunctionDecl(_, _) => {
390            // Declaration-only intrinsic
391        }
392        Item::ForeignFunction(_, _) => {
393            // Foreign function bodies are opaque to the Shape visitor
394        }
395    }
396
397    visitor.leave_item(item);
398}
399
400/// Walk a function definition
401pub fn walk_function<V: Visitor>(visitor: &mut V, func: &FunctionDef) {
402    if !visitor.visit_function(func) {
403        return;
404    }
405
406    // Visit parameter default values
407    for param in &func.params {
408        if let Some(default) = &param.default_value {
409            walk_expr(visitor, default);
410        }
411    }
412
413    // Visit body statements
414    for stmt in &func.body {
415        walk_stmt(visitor, stmt);
416    }
417
418    visitor.leave_function(func);
419}
420
421/// Walk a query
422pub fn walk_query<V: Visitor>(visitor: &mut V, query: &Query) {
423    match query {
424        Query::Backtest(backtest) => {
425            for (_, expr) in &backtest.parameters {
426                walk_expr(visitor, expr);
427            }
428        }
429        Query::Alert(alert) => {
430            walk_expr(visitor, &alert.condition);
431        }
432        Query::With(with_query) => {
433            // Walk CTEs
434            for cte in &with_query.ctes {
435                walk_query(visitor, &cte.query);
436            }
437            // Walk main query
438            walk_query(visitor, &with_query.query);
439        }
440    }
441}
442
443/// Walk a statement
444pub fn walk_stmt<V: Visitor>(visitor: &mut V, stmt: &Statement) {
445    if !visitor.visit_stmt(stmt) {
446        return;
447    }
448
449    match stmt {
450        Statement::Return(expr, _) => {
451            if let Some(e) = expr {
452                walk_expr(visitor, e);
453            }
454        }
455        Statement::Break(_) => {}
456        Statement::Continue(_) => {}
457        Statement::VariableDecl(decl, _) => {
458            if let Some(value) = &decl.value {
459                walk_expr(visitor, value);
460            }
461        }
462        Statement::Assignment(assign, _) => {
463            walk_expr(visitor, &assign.value);
464        }
465        Statement::Expression(expr, _) => walk_expr(visitor, expr),
466        Statement::For(for_loop, _) => {
467            match &for_loop.init {
468                ForInit::ForIn { iter, .. } => walk_expr(visitor, iter),
469                ForInit::ForC {
470                    init,
471                    condition,
472                    update,
473                } => {
474                    walk_stmt(visitor, init);
475                    walk_expr(visitor, condition);
476                    walk_expr(visitor, update);
477                }
478            }
479            for stmt in &for_loop.body {
480                walk_stmt(visitor, stmt);
481            }
482        }
483        Statement::While(while_loop, _) => {
484            walk_expr(visitor, &while_loop.condition);
485            for stmt in &while_loop.body {
486                walk_stmt(visitor, stmt);
487            }
488        }
489        Statement::If(if_stmt, _) => {
490            walk_expr(visitor, &if_stmt.condition);
491            for stmt in &if_stmt.then_body {
492                walk_stmt(visitor, stmt);
493            }
494            if let Some(else_body) = &if_stmt.else_body {
495                for stmt in else_body {
496                    walk_stmt(visitor, stmt);
497                }
498            }
499        }
500        Statement::Extend(ext, _) => {
501            for method in &ext.methods {
502                for stmt in &method.body {
503                    walk_stmt(visitor, stmt);
504                }
505            }
506        }
507        Statement::RemoveTarget(_) => {}
508        Statement::SetParamType { .. }
509        | Statement::SetReturnType { .. }
510        | Statement::SetReturnExpr { .. } => {}
511        Statement::SetParamValue { expression, .. } => {
512            walk_expr(visitor, expression);
513        }
514        Statement::ReplaceModuleExpr { expression, .. } => {
515            walk_expr(visitor, expression);
516        }
517        Statement::ReplaceBodyExpr { expression, .. } => {
518            walk_expr(visitor, expression);
519        }
520        Statement::ReplaceBody { body, .. } => {
521            for stmt in body {
522                walk_stmt(visitor, stmt);
523            }
524        }
525    }
526
527    visitor.leave_stmt(stmt);
528}
529
530/// Walk an expression - ALL VARIANTS HANDLED EXPLICITLY
531///
532/// Visit order:
533/// 1. `visit_expr(expr)` — return `false` to skip entirely
534/// 2. `visit_expr_<variant>(expr, span)` — return `false` to skip children
535/// 3. Walk children recursively
536/// 4. `leave_expr(expr)`
537pub fn walk_expr<V: Visitor>(visitor: &mut V, expr: &Expr) {
538    if !visitor.visit_expr(expr) {
539        return;
540    }
541
542    match expr {
543        // Leaf nodes (no children)
544        Expr::Literal(lit, span) => {
545            if visitor.visit_expr_literal(expr, *span) {
546                visitor.visit_literal(lit);
547                visitor.leave_literal(lit);
548            }
549        }
550        Expr::Identifier(_, span) => {
551            visitor.visit_expr_identifier(expr, *span);
552        }
553        Expr::DataRef(data_ref, span) => {
554            if visitor.visit_expr_data_ref(expr, *span) {
555                match &data_ref.index {
556                    DataIndex::Expression(e) => walk_expr(visitor, e),
557                    DataIndex::ExpressionRange(start, end) => {
558                        walk_expr(visitor, start);
559                        walk_expr(visitor, end);
560                    }
561                    DataIndex::Single(_) | DataIndex::Range(_, _) => {}
562                }
563            }
564        }
565        Expr::DataDateTimeRef(_, span) => {
566            visitor.visit_expr_data_datetime_ref(expr, *span);
567        }
568        Expr::DataRelativeAccess {
569            reference,
570            index,
571            span,
572        } => {
573            if visitor.visit_expr_data_relative_access(expr, *span) {
574                walk_expr(visitor, reference);
575                match index {
576                    DataIndex::Expression(e) => walk_expr(visitor, e),
577                    DataIndex::ExpressionRange(start, end) => {
578                        walk_expr(visitor, start);
579                        walk_expr(visitor, end);
580                    }
581                    DataIndex::Single(_) | DataIndex::Range(_, _) => {}
582                }
583            }
584        }
585        Expr::PropertyAccess { object, span, .. } => {
586            if visitor.visit_expr_property_access(expr, *span) {
587                walk_expr(visitor, object);
588            }
589        }
590        Expr::IndexAccess {
591            object,
592            index,
593            end_index,
594            span,
595        } => {
596            if visitor.visit_expr_index_access(expr, *span) {
597                walk_expr(visitor, object);
598                walk_expr(visitor, index);
599                if let Some(end) = end_index {
600                    walk_expr(visitor, end);
601                }
602            }
603        }
604        Expr::BinaryOp {
605            left, right, span, ..
606        } => {
607            if visitor.visit_expr_binary_op(expr, *span) {
608                walk_expr(visitor, left);
609                walk_expr(visitor, right);
610            }
611        }
612        Expr::FuzzyComparison {
613            left, right, span, ..
614        } => {
615            if visitor.visit_expr_fuzzy_comparison(expr, *span) {
616                walk_expr(visitor, left);
617                walk_expr(visitor, right);
618            }
619        }
620        Expr::UnaryOp { operand, span, .. } => {
621            if visitor.visit_expr_unary_op(expr, *span) {
622                walk_expr(visitor, operand);
623            }
624        }
625        Expr::FunctionCall {
626            args,
627            named_args,
628            span,
629            ..
630        } => {
631            if visitor.visit_expr_function_call(expr, *span) {
632                for arg in args {
633                    walk_expr(visitor, arg);
634                }
635                for (_, value) in named_args {
636                    walk_expr(visitor, value);
637                }
638            }
639        }
640        Expr::QualifiedFunctionCall {
641            args,
642            named_args,
643            span,
644            ..
645        } => {
646            if visitor.visit_expr_function_call(expr, *span) {
647                for arg in args {
648                    walk_expr(visitor, arg);
649                }
650                for (_, value) in named_args {
651                    walk_expr(visitor, value);
652                }
653            }
654        }
655        Expr::EnumConstructor { payload, span, .. } => {
656            if visitor.visit_expr_enum_constructor(expr, *span) {
657                match payload {
658                    EnumConstructorPayload::Unit => {}
659                    EnumConstructorPayload::Tuple(values) => {
660                        for value in values {
661                            walk_expr(visitor, value);
662                        }
663                    }
664                    EnumConstructorPayload::Struct(fields) => {
665                        for (_, value) in fields {
666                            walk_expr(visitor, value);
667                        }
668                    }
669                }
670            }
671        }
672        Expr::TimeRef(_, span) => {
673            visitor.visit_expr_time_ref(expr, *span);
674        }
675        Expr::DateTime(_, span) => {
676            visitor.visit_expr_datetime(expr, *span);
677        }
678        Expr::PatternRef(_, span) => {
679            visitor.visit_expr_pattern_ref(expr, *span);
680        }
681        Expr::Conditional {
682            condition,
683            then_expr,
684            else_expr,
685            span,
686        } => {
687            if visitor.visit_expr_conditional(expr, *span) {
688                walk_expr(visitor, condition);
689                walk_expr(visitor, then_expr);
690                if let Some(else_e) = else_expr {
691                    walk_expr(visitor, else_e);
692                }
693            }
694        }
695        Expr::Object(entries, span) => {
696            if visitor.visit_expr_object(expr, *span) {
697                for entry in entries {
698                    match entry {
699                        ObjectEntry::Field { value, .. } => walk_expr(visitor, value),
700                        ObjectEntry::Spread(spread_expr) => walk_expr(visitor, spread_expr),
701                    }
702                }
703            }
704        }
705        Expr::Array(elements, span) => {
706            if visitor.visit_expr_array(expr, *span) {
707                for elem in elements {
708                    walk_expr(visitor, elem);
709                }
710            }
711        }
712        Expr::TableRows(rows, _span) => {
713            for row in rows {
714                for elem in row {
715                    walk_expr(visitor, elem);
716                }
717            }
718        }
719        Expr::ListComprehension(comp, span) => {
720            if visitor.visit_expr_list_comprehension(expr, *span) {
721                walk_expr(visitor, &comp.element);
722                for clause in &comp.clauses {
723                    walk_expr(visitor, &clause.iterable);
724                    if let Some(filter) = &clause.filter {
725                        walk_expr(visitor, filter);
726                    }
727                }
728            }
729        }
730        Expr::Block(block, span) => {
731            if visitor.visit_expr_block(expr, *span) {
732                if visitor.visit_block(block) {
733                    for item in &block.items {
734                        match item {
735                            BlockItem::VariableDecl(decl) => {
736                                if let Some(value) = &decl.value {
737                                    walk_expr(visitor, value);
738                                }
739                            }
740                            BlockItem::Assignment(assign) => {
741                                walk_expr(visitor, &assign.value);
742                            }
743                            BlockItem::Statement(stmt) => {
744                                walk_stmt(visitor, stmt);
745                            }
746                            BlockItem::Expression(e) => walk_expr(visitor, e),
747                        }
748                    }
749                    visitor.leave_block(block);
750                }
751            }
752        }
753        Expr::TypeAssertion {
754            expr: inner, span, ..
755        } => {
756            if visitor.visit_expr_type_assertion(expr, *span) {
757                walk_expr(visitor, inner);
758            }
759        }
760        Expr::InstanceOf {
761            expr: inner, span, ..
762        } => {
763            if visitor.visit_expr_instance_of(expr, *span) {
764                walk_expr(visitor, inner);
765            }
766        }
767        Expr::FunctionExpr {
768            params, body, span, ..
769        } => {
770            if visitor.visit_expr_function_expr(expr, *span) {
771                for param in params {
772                    if let Some(default) = &param.default_value {
773                        walk_expr(visitor, default);
774                    }
775                }
776                for stmt in body {
777                    walk_stmt(visitor, stmt);
778                }
779            }
780        }
781        Expr::Duration(_, span) => {
782            visitor.visit_expr_duration(expr, *span);
783        }
784        Expr::Spread(inner, span) => {
785            if visitor.visit_expr_spread(expr, *span) {
786                walk_expr(visitor, inner);
787            }
788        }
789        Expr::If(if_expr, span) => {
790            if visitor.visit_expr_if(expr, *span) {
791                walk_expr(visitor, &if_expr.condition);
792                walk_expr(visitor, &if_expr.then_branch);
793                if let Some(else_branch) = &if_expr.else_branch {
794                    walk_expr(visitor, else_branch);
795                }
796            }
797        }
798        Expr::While(while_expr, span) => {
799            if visitor.visit_expr_while(expr, *span) {
800                walk_expr(visitor, &while_expr.condition);
801                walk_expr(visitor, &while_expr.body);
802            }
803        }
804        Expr::For(for_expr, span) => {
805            if visitor.visit_expr_for(expr, *span) {
806                walk_expr(visitor, &for_expr.iterable);
807                walk_expr(visitor, &for_expr.body);
808            }
809        }
810        Expr::Loop(loop_expr, span) => {
811            if visitor.visit_expr_loop(expr, *span) {
812                walk_expr(visitor, &loop_expr.body);
813            }
814        }
815        Expr::Let(let_expr, span) => {
816            if visitor.visit_expr_let(expr, *span) {
817                if let Some(value) = &let_expr.value {
818                    walk_expr(visitor, value);
819                }
820                walk_expr(visitor, &let_expr.body);
821            }
822        }
823        Expr::Assign(assign, span) => {
824            if visitor.visit_expr_assign(expr, *span) {
825                walk_expr(visitor, &assign.target);
826                walk_expr(visitor, &assign.value);
827            }
828        }
829        Expr::Break(inner, span) => {
830            if visitor.visit_expr_break(expr, *span) {
831                if let Some(e) = inner {
832                    walk_expr(visitor, e);
833                }
834            }
835        }
836        Expr::Continue(span) => {
837            visitor.visit_expr_continue(expr, *span);
838        }
839        Expr::Return(inner, span) => {
840            if visitor.visit_expr_return(expr, *span) {
841                if let Some(e) = inner {
842                    walk_expr(visitor, e);
843                }
844            }
845        }
846        Expr::MethodCall {
847            receiver,
848            args,
849            named_args,
850            span,
851            ..
852        } => {
853            if visitor.visit_expr_method_call(expr, *span) {
854                walk_expr(visitor, receiver);
855                for arg in args {
856                    walk_expr(visitor, arg);
857                }
858                for (_, value) in named_args {
859                    walk_expr(visitor, value);
860                }
861            }
862        }
863        Expr::Match(match_expr, span) => {
864            if visitor.visit_expr_match(expr, *span) {
865                walk_expr(visitor, &match_expr.scrutinee);
866                for arm in &match_expr.arms {
867                    if let Some(guard) = &arm.guard {
868                        walk_expr(visitor, guard);
869                    }
870                    walk_expr(visitor, &arm.body);
871                }
872            }
873        }
874        Expr::Unit(span) => {
875            visitor.visit_expr_unit(expr, *span);
876        }
877        Expr::Range {
878            start, end, span, ..
879        } => {
880            if visitor.visit_expr_range(expr, *span) {
881                if let Some(s) = start {
882                    walk_expr(visitor, s);
883                }
884                if let Some(e) = end {
885                    walk_expr(visitor, e);
886                }
887            }
888        }
889        Expr::TimeframeContext {
890            expr: inner, span, ..
891        } => {
892            if visitor.visit_expr_timeframe_context(expr, *span) {
893                walk_expr(visitor, inner);
894            }
895        }
896        Expr::TryOperator(inner, span) => {
897            if visitor.visit_expr_try_operator(expr, *span) {
898                walk_expr(visitor, inner);
899            }
900        }
901        Expr::UsingImpl {
902            expr: inner, span, ..
903        } => {
904            if visitor.visit_expr_using_impl(expr, *span) {
905                walk_expr(visitor, inner);
906            }
907        }
908        Expr::SimulationCall { params, span, .. } => {
909            if visitor.visit_expr_simulation_call(expr, *span) {
910                for (_, value) in params {
911                    walk_expr(visitor, value);
912                }
913            }
914        }
915        Expr::WindowExpr(window_expr, span) => {
916            if visitor.visit_expr_window_expr(expr, *span) {
917                // Walk function argument expressions
918                match &window_expr.function {
919                    WindowFunction::Lead { expr, default, .. }
920                    | WindowFunction::Lag { expr, default, .. } => {
921                        walk_expr(visitor, expr);
922                        if let Some(d) = default {
923                            walk_expr(visitor, d);
924                        }
925                    }
926                    WindowFunction::FirstValue(e)
927                    | WindowFunction::LastValue(e)
928                    | WindowFunction::Sum(e)
929                    | WindowFunction::Avg(e)
930                    | WindowFunction::Min(e)
931                    | WindowFunction::Max(e) => {
932                        walk_expr(visitor, e);
933                    }
934                    WindowFunction::NthValue(e, _) => {
935                        walk_expr(visitor, e);
936                    }
937                    WindowFunction::Count(opt_e) => {
938                        if let Some(e) = opt_e {
939                            walk_expr(visitor, e);
940                        }
941                    }
942                    WindowFunction::RowNumber
943                    | WindowFunction::Rank
944                    | WindowFunction::DenseRank
945                    | WindowFunction::Ntile(_) => {}
946                }
947                // Walk partition_by expressions
948                for e in &window_expr.over.partition_by {
949                    walk_expr(visitor, e);
950                }
951                // Walk order_by expressions
952                if let Some(order_by) = &window_expr.over.order_by {
953                    for (e, _) in &order_by.columns {
954                        walk_expr(visitor, e);
955                    }
956                }
957            }
958        }
959        Expr::FromQuery(from_query, span) => {
960            if visitor.visit_expr_from_query(expr, *span) {
961                // Walk source expression
962                walk_expr(visitor, &from_query.source);
963                // Walk each clause
964                for clause in &from_query.clauses {
965                    match clause {
966                        QueryClause::Where(pred) => {
967                            walk_expr(visitor, pred);
968                        }
969                        QueryClause::OrderBy(specs) => {
970                            for spec in specs {
971                                walk_expr(visitor, &spec.key);
972                            }
973                        }
974                        QueryClause::GroupBy { element, key, .. } => {
975                            walk_expr(visitor, element);
976                            walk_expr(visitor, key);
977                        }
978                        QueryClause::Join {
979                            source,
980                            left_key,
981                            right_key,
982                            ..
983                        } => {
984                            walk_expr(visitor, source);
985                            walk_expr(visitor, left_key);
986                            walk_expr(visitor, right_key);
987                        }
988                        QueryClause::Let { value, .. } => {
989                            walk_expr(visitor, value);
990                        }
991                    }
992                }
993                // Walk select expression
994                walk_expr(visitor, &from_query.select);
995            }
996        }
997        Expr::StructLiteral { fields, span, .. } => {
998            if visitor.visit_expr_struct_literal(expr, *span) {
999                for (_, value_expr) in fields {
1000                    walk_expr(visitor, value_expr);
1001                }
1002            }
1003        }
1004        Expr::Await(inner, span) => {
1005            if visitor.visit_expr_await(expr, *span) {
1006                walk_expr(visitor, inner);
1007            }
1008        }
1009        Expr::Join(join_expr, span) => {
1010            if visitor.visit_expr_join(expr, *span) {
1011                for branch in &join_expr.branches {
1012                    walk_expr(visitor, &branch.expr);
1013                }
1014            }
1015        }
1016        Expr::Annotated { target, span, .. } => {
1017            if visitor.visit_expr_annotated(expr, *span) {
1018                walk_expr(visitor, target);
1019            }
1020        }
1021        Expr::AsyncLet(async_let, span) => {
1022            if visitor.visit_expr_async_let(expr, *span) {
1023                walk_expr(visitor, &async_let.expr);
1024            }
1025        }
1026        Expr::AsyncScope(inner, span) => {
1027            if visitor.visit_expr_async_scope(expr, *span) {
1028                walk_expr(visitor, inner);
1029            }
1030        }
1031        Expr::Comptime(stmts, span) => {
1032            if visitor.visit_expr_comptime(expr, *span) {
1033                for stmt in stmts {
1034                    walk_stmt(visitor, stmt);
1035                }
1036            }
1037        }
1038        Expr::ComptimeFor(cf, span) => {
1039            if visitor.visit_expr_comptime_for(expr, *span) {
1040                walk_expr(visitor, &cf.iterable);
1041                for stmt in &cf.body {
1042                    walk_stmt(visitor, stmt);
1043                }
1044            }
1045        }
1046        Expr::Reference {
1047            expr: inner, span, ..
1048        } => {
1049            if visitor.visit_expr_reference(expr, *span) {
1050                walk_expr(visitor, inner);
1051            }
1052        }
1053    }
1054
1055    visitor.leave_expr(expr);
1056}
1057
1058/// Walk a test statement
1059fn walk_test_statement<V: Visitor>(visitor: &mut V, test_stmt: &TestStatement) {
1060    match test_stmt {
1061        TestStatement::Statement(stmt) => walk_stmt(visitor, stmt),
1062        TestStatement::Assert(assert) => {
1063            walk_expr(visitor, &assert.condition);
1064        }
1065        TestStatement::Expect(expect) => {
1066            walk_expr(visitor, &expect.actual);
1067            match &expect.matcher {
1068                ExpectationMatcher::ToBe(e) => walk_expr(visitor, e),
1069                ExpectationMatcher::ToEqual(e) => walk_expr(visitor, e),
1070                ExpectationMatcher::ToBeCloseTo { expected, .. } => walk_expr(visitor, expected),
1071                ExpectationMatcher::ToBeGreaterThan(e) => walk_expr(visitor, e),
1072                ExpectationMatcher::ToBeLessThan(e) => walk_expr(visitor, e),
1073                ExpectationMatcher::ToContain(e) => walk_expr(visitor, e),
1074                ExpectationMatcher::ToBeTruthy => {}
1075                ExpectationMatcher::ToBeFalsy => {}
1076                ExpectationMatcher::ToThrow(_) => {}
1077                ExpectationMatcher::ToMatchPattern { .. } => {}
1078            }
1079        }
1080        TestStatement::Should(should) => {
1081            walk_expr(visitor, &should.subject);
1082            match &should.matcher {
1083                ShouldMatcher::Be(e) => walk_expr(visitor, e),
1084                ShouldMatcher::Equal(e) => walk_expr(visitor, e),
1085                ShouldMatcher::Contain(e) => walk_expr(visitor, e),
1086                ShouldMatcher::Match(_) => {}
1087                ShouldMatcher::BeCloseTo { expected, .. } => walk_expr(visitor, expected),
1088            }
1089        }
1090        TestStatement::Fixture(fixture) => match fixture {
1091            TestFixture::WithData { data, body } => {
1092                walk_expr(visitor, data);
1093                for stmt in body {
1094                    walk_stmt(visitor, stmt);
1095                }
1096            }
1097            TestFixture::WithMock {
1098                mock_value, body, ..
1099            } => {
1100                if let Some(value) = mock_value {
1101                    walk_expr(visitor, value);
1102                }
1103                for stmt in body {
1104                    walk_stmt(visitor, stmt);
1105                }
1106            }
1107        },
1108    }
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;

    // ----- AST construction helpers shared by the tests below -----

    /// Builds an identifier expression carrying a dummy span.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string(), Span::DUMMY)
    }

    /// Builds the numeric literal `1.0` carrying a dummy span.
    fn one() -> Expr {
        Expr::Literal(Literal::Number(1.0), Span::DUMMY)
    }

    /// Builds the binary expression `lhs + rhs` carrying a dummy span.
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(lhs),
            op: BinaryOp::Add,
            right: Box::new(rhs),
            span: Span::DUMMY,
        }
    }

    /// Wraps one expression into a program consisting of a single
    /// expression item and default docs.
    fn program_of(expr: Expr) -> Program {
        Program {
            items: vec![Item::Expression(expr, Span::DUMMY)],
            docs: shape_ast::ast::ProgramDocs::default(),
        }
    }

    // ----- Coarse `visit_expr` hook -----

    /// Visitor that increments a counter on every `visit_expr` call.
    struct ExprCounter {
        count: usize,
    }

    impl Visitor for ExprCounter {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.count += 1;
            true
        }
    }

    /// Walks `program` with a fresh `ExprCounter` and returns its final count.
    fn count_exprs(program: &Program) -> usize {
        let mut counter = ExprCounter { count: 0 };
        walk_program(&mut counter, program);
        counter.count
    }

    #[test]
    fn test_visitor_counts_expressions() {
        let program = program_of(add(ident("x"), one()));

        // BinaryOp + Identifier + Literal
        assert_eq!(count_exprs(&program), 3);
    }

    #[test]
    fn test_visitor_handles_try_operator() {
        let call = Expr::FunctionCall {
            name: "some_function".to_string(),
            args: vec![Expr::Literal(
                Literal::String("arg".to_string()),
                Span::DUMMY,
            )],
            named_args: vec![],
            span: Span::DUMMY,
        };
        let program = program_of(Expr::TryOperator(Box::new(call), Span::DUMMY));

        // TryOperator + FunctionCall + Literal
        assert_eq!(count_exprs(&program), 3);
    }

    // ----- Per-variant hooks -----

    /// Visitor that records the name of every identifier expression it is
    /// handed through the per-variant hook.
    struct IdentifierCollector {
        names: Vec<String>,
    }

    impl Visitor for IdentifierCollector {
        fn visit_expr_identifier(&mut self, expr: &Expr, _span: Span) -> bool {
            if let Expr::Identifier(name, _) = expr {
                self.names.push(name.clone());
            }
            true
        }
    }

    #[test]
    fn test_per_variant_visitor_identifier() {
        let program = program_of(add(ident("x"), ident("y")));

        let mut collector = IdentifierCollector { names: vec![] };
        walk_program(&mut collector, &program);

        assert_eq!(collector.names, vec!["x", "y"]);
    }

    /// Visitor that counts every expression via the coarse hook but refuses
    /// to descend into the children of a binary operation.
    struct SkippingVisitor {
        count: usize,
    }

    impl Visitor for SkippingVisitor {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.count += 1;
            true
        }

        // Returning `false` here skips the operands of the BinaryOp.
        fn visit_expr_binary_op(&mut self, _expr: &Expr, _span: Span) -> bool {
            false
        }
    }

    #[test]
    fn test_per_variant_skip_children() {
        let program = program_of(add(ident("x"), one()));

        let mut skipper = SkippingVisitor { count: 0 };
        walk_program(&mut skipper, &program);

        // Only the BinaryOp itself is counted; its operands are skipped.
        assert_eq!(skipper.count, 1);
    }

    // ----- Coarse and per-variant hooks together -----

    /// Visitor that tracks the total number of expressions (coarse hook) and
    /// the number of match expressions (per-variant hook) separately.
    struct MatchCollector {
        match_count: usize,
        total_expr_count: usize,
    }

    impl Visitor for MatchCollector {
        fn visit_expr(&mut self, _expr: &Expr) -> bool {
            self.total_expr_count += 1;
            true
        }

        fn visit_expr_match(&mut self, _expr: &Expr, _span: Span) -> bool {
            self.match_count += 1;
            true
        }
    }

    #[test]
    fn test_coarse_and_per_variant_combined() {
        let program = program_of(add(ident("x"), ident("y")));

        let mut collector = MatchCollector {
            match_count: 0,
            total_expr_count: 0,
        };
        walk_program(&mut collector, &program);

        assert_eq!(collector.total_expr_count, 3); // BinaryOp + x + y
        assert_eq!(collector.match_count, 0); // No Match expressions
    }
}