edb_engine/analysis/
visitor.rs

1// EDB - Ethereum Debugger
2// Copyright (C) 2024 Zhuo Zhang and Wuqi Zhang
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program. If not, see <https://www.gnu.org/licenses/>.
16
17use std::any::Any;
18
19use eyre::Result;
20use foundry_compilers::artifacts::*;
21use paste::paste;
22
/// Controls the behavior of AST traversal during visitor pattern execution.
///
/// This enum is returned by the pre-node (`visit_*`) visitor methods to indicate how
/// the traversal should proceed. It lets visitors prune the walk without having to
/// implement any traversal logic themselves.
///
/// # Examples
///
/// ```ignore
/// // Not compiled as a doctest: `crate::` only resolves from inside this crate.
/// use crate::analysis::visitor::{Visitor, VisitorAction};
/// use eyre::Result;
/// use foundry_compilers::artifacts::{FunctionDefinition, VariableDeclaration};
///
/// struct MyVisitor;
///
/// impl Visitor for MyVisitor {
///     fn visit_function_definition(&mut self, _def: &FunctionDefinition) -> Result<VisitorAction> {
///         // Skip everything inside functions (parameters, modifiers, body, ...).
///         Ok(VisitorAction::SkipSubtree)
///     }
///
///     fn visit_variable_declaration(&mut self, _decl: &VariableDeclaration) -> Result<VisitorAction> {
///         // Continue normal traversal.
///         Ok(VisitorAction::Continue)
///     }
/// }
/// ```
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum VisitorAction {
    /// Continue normal AST traversal, visiting all child nodes.
    ///
    /// This is the default behavior that traverses the entire subtree
    /// starting from the current node.
    #[default]
    Continue,

    /// Skip traversing the subtree of the current node.
    ///
    /// When this variant is returned, the visitor will not traverse any
    /// child nodes of the current node, but will still call the corresponding
    /// post-visit method. This is useful for:
    ///
    /// - Performance optimization when child nodes are not needed
    /// - Avoiding traversal of large subtrees (e.g., function bodies)
    /// - Implementing conditional traversal logic
    ///
    /// # Note
    ///
    /// The post-visit method for the current node will still be called,
    /// even when skipping the subtree.
    SkipSubtree,
}

impl VisitorAction {
    /// Runs `f` only when `self` is [`VisitorAction::Continue`].
    ///
    /// This chains several traversal decisions. The first
    /// [`VisitorAction::SkipSubtree`] short-circuits: it is returned unchanged and
    /// `f` is never called. Otherwise the result of `f` is returned as-is.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`, if `f` is called and fails.
    pub fn and_then<F, E>(self, f: F) -> Result<Self, E>
    where
        F: FnOnce() -> Result<Self, E>,
    {
        match self {
            Self::Continue => f(),
            // Listed explicitly (not `_`) so that adding a variant forces a review here.
            Self::SkipSubtree => Ok(self),
        }
    }
}
86
87/// Trait for implementing AST visitors that can traverse Solidity source code.
88///
89/// This trait provides methods for visiting each type of AST node in a Solidity
90/// source unit. Implementors can override specific methods to perform custom
91/// analysis or transformations on the AST.
92///
93/// All methods have default implementations that do nothing, allowing implementors
94/// to only override the methods they need.
95#[allow(missing_docs)]
96pub trait Visitor {
97    /// Visits a source unit (the root of a Solidity file).
98    ///
99    /// # Arguments
100    ///
101    /// * `_source_unit` - The source unit to visit
102    ///
103    /// # Returns
104    ///
105    /// A Result containing a VisitorAction indicating how to proceed with traversal.
106    fn visit_source_unit(&mut self, _source_unit: &SourceUnit) -> Result<VisitorAction> {
107        Ok(VisitorAction::Continue)
108    }
109    fn visit_import_directive(&mut self, _directive: &ImportDirective) -> Result<VisitorAction> {
110        Ok(VisitorAction::Continue)
111    }
112    fn visit_pragma_directive(&mut self, _directive: &PragmaDirective) -> Result<VisitorAction> {
113        Ok(VisitorAction::Continue)
114    }
115    fn visit_block(&mut self, _block: &Block) -> Result<VisitorAction> {
116        Ok(VisitorAction::Continue)
117    }
118    fn visit_statement(&mut self, _statement: &Statement) -> Result<VisitorAction> {
119        Ok(VisitorAction::Continue)
120    }
121    fn visit_expression(&mut self, _expression: &Expression) -> Result<VisitorAction> {
122        Ok(VisitorAction::Continue)
123    }
124    fn visit_function_call(&mut self, _function_call: &FunctionCall) -> Result<VisitorAction> {
125        Ok(VisitorAction::Continue)
126    }
127    fn visit_user_defined_type_name(
128        &mut self,
129        _type_name: &UserDefinedTypeName,
130    ) -> Result<VisitorAction> {
131        Ok(VisitorAction::Continue)
132    }
133    fn visit_identifier_path(
134        &mut self,
135        _identifier_path: &IdentifierPath,
136    ) -> Result<VisitorAction> {
137        Ok(VisitorAction::Continue)
138    }
139    fn visit_type_name(&mut self, _type_name: &TypeName) -> Result<VisitorAction> {
140        Ok(VisitorAction::Continue)
141    }
142    fn visit_parameter_list(&mut self, _parameter_list: &ParameterList) -> Result<VisitorAction> {
143        Ok(VisitorAction::Continue)
144    }
145    fn visit_function_definition(
146        &mut self,
147        _definition: &FunctionDefinition,
148    ) -> Result<VisitorAction> {
149        Ok(VisitorAction::Continue)
150    }
151    fn visit_enum_definition(&mut self, _definition: &EnumDefinition) -> Result<VisitorAction> {
152        Ok(VisitorAction::Continue)
153    }
154    fn visit_error_definition(&mut self, _definition: &ErrorDefinition) -> Result<VisitorAction> {
155        Ok(VisitorAction::Continue)
156    }
157    fn visit_event_definition(&mut self, _definition: &EventDefinition) -> Result<VisitorAction> {
158        Ok(VisitorAction::Continue)
159    }
160    fn visit_struct_definition(&mut self, _definition: &StructDefinition) -> Result<VisitorAction> {
161        Ok(VisitorAction::Continue)
162    }
163    fn visit_modifier_definition(
164        &mut self,
165        _definition: &ModifierDefinition,
166    ) -> Result<VisitorAction> {
167        Ok(VisitorAction::Continue)
168    }
169    fn visit_variable_declaration(
170        &mut self,
171        _declaration: &VariableDeclaration,
172    ) -> Result<VisitorAction> {
173        Ok(VisitorAction::Continue)
174    }
175    fn visit_overrides(&mut self, _specifier: &OverrideSpecifier) -> Result<VisitorAction> {
176        Ok(VisitorAction::Continue)
177    }
178    fn visit_user_defined_value_type(
179        &mut self,
180        _value_type: &UserDefinedValueTypeDefinition,
181    ) -> Result<VisitorAction> {
182        Ok(VisitorAction::Continue)
183    }
184    fn visit_contract_definition(
185        &mut self,
186        _definition: &ContractDefinition,
187    ) -> Result<VisitorAction> {
188        Ok(VisitorAction::Continue)
189    }
190    fn visit_using_for(&mut self, _directive: &UsingForDirective) -> Result<VisitorAction> {
191        Ok(VisitorAction::Continue)
192    }
193    fn visit_unary_operation(&mut self, _unary_op: &UnaryOperation) -> Result<VisitorAction> {
194        Ok(VisitorAction::Continue)
195    }
196    fn visit_binary_operation(&mut self, _binary_op: &BinaryOperation) -> Result<VisitorAction> {
197        Ok(VisitorAction::Continue)
198    }
199    fn visit_conditional(&mut self, _conditional: &Conditional) -> Result<VisitorAction> {
200        Ok(VisitorAction::Continue)
201    }
202    fn visit_tuple_expression(
203        &mut self,
204        _tuple_expression: &TupleExpression,
205    ) -> Result<VisitorAction> {
206        Ok(VisitorAction::Continue)
207    }
208    fn visit_new_expression(&mut self, _new_expression: &NewExpression) -> Result<VisitorAction> {
209        Ok(VisitorAction::Continue)
210    }
211    fn visit_assignment(&mut self, _assignment: &Assignment) -> Result<VisitorAction> {
212        Ok(VisitorAction::Continue)
213    }
214    fn visit_identifier(&mut self, _identifier: &Identifier) -> Result<VisitorAction> {
215        Ok(VisitorAction::Continue)
216    }
217    fn visit_index_access(&mut self, _index_access: &IndexAccess) -> Result<VisitorAction> {
218        Ok(VisitorAction::Continue)
219    }
220    fn visit_index_range_access(
221        &mut self,
222        _index_range_access: &IndexRangeAccess,
223    ) -> Result<VisitorAction> {
224        Ok(VisitorAction::Continue)
225    }
226    fn visit_while_statement(
227        &mut self,
228        _while_statement: &WhileStatement,
229    ) -> Result<VisitorAction> {
230        Ok(VisitorAction::Continue)
231    }
232    fn visit_for_statement(&mut self, _for_statement: &ForStatement) -> Result<VisitorAction> {
233        Ok(VisitorAction::Continue)
234    }
235    fn visit_if_statement(&mut self, _if_statement: &IfStatement) -> Result<VisitorAction> {
236        Ok(VisitorAction::Continue)
237    }
238    fn visit_do_while_statement(
239        &mut self,
240        _do_while_statement: &DoWhileStatement,
241    ) -> Result<VisitorAction> {
242        Ok(VisitorAction::Continue)
243    }
244    fn visit_emit_statement(&mut self, _emit_statement: &EmitStatement) -> Result<VisitorAction> {
245        Ok(VisitorAction::Continue)
246    }
247    fn visit_unchecked_block(
248        &mut self,
249        _unchecked_block: &UncheckedBlock,
250    ) -> Result<VisitorAction> {
251        Ok(VisitorAction::Continue)
252    }
253    fn visit_try_statement(&mut self, _try_statement: &TryStatement) -> Result<VisitorAction> {
254        Ok(VisitorAction::Continue)
255    }
256    fn visit_revert_statement(
257        &mut self,
258        _revert_statement: &RevertStatement,
259    ) -> Result<VisitorAction> {
260        Ok(VisitorAction::Continue)
261    }
262    fn visit_member_access(&mut self, _member_access: &MemberAccess) -> Result<VisitorAction> {
263        Ok(VisitorAction::Continue)
264    }
265    fn visit_mapping(&mut self, _mapping: &Mapping) -> Result<VisitorAction> {
266        Ok(VisitorAction::Continue)
267    }
268    fn visit_elementary_type_name(
269        &mut self,
270        _elementary_type_name: &ElementaryTypeName,
271    ) -> Result<VisitorAction> {
272        Ok(VisitorAction::Continue)
273    }
274    fn visit_literal(&mut self, _literal: &Literal) -> Result<VisitorAction> {
275        Ok(VisitorAction::Continue)
276    }
277    fn visit_function_type_name(
278        &mut self,
279        _function_type_name: &FunctionTypeName,
280    ) -> Result<VisitorAction> {
281        Ok(VisitorAction::Continue)
282    }
283    fn visit_array_type_name(&mut self, _array_type_name: &ArrayTypeName) -> Result<VisitorAction> {
284        Ok(VisitorAction::Continue)
285    }
286    fn visit_function_call_options(
287        &mut self,
288        _function_call: &FunctionCallOptions,
289    ) -> Result<VisitorAction> {
290        Ok(VisitorAction::Continue)
291    }
292    fn visit_return(&mut self, _return: &Return) -> Result<VisitorAction> {
293        Ok(VisitorAction::Continue)
294    }
295    fn visit_inheritance_specifier(
296        &mut self,
297        _specifier: &InheritanceSpecifier,
298    ) -> Result<VisitorAction> {
299        Ok(VisitorAction::Continue)
300    }
301    fn visit_modifier_invocation(
302        &mut self,
303        _invocation: &ModifierInvocation,
304    ) -> Result<VisitorAction> {
305        Ok(VisitorAction::Continue)
306    }
307    fn visit_inline_assembly(&mut self, _assembly: &InlineAssembly) -> Result<VisitorAction> {
308        Ok(VisitorAction::Continue)
309    }
310    fn visit_external_assembly_reference(
311        &mut self,
312        _ref: &ExternalInlineAssemblyReference,
313    ) -> Result<VisitorAction> {
314        Ok(VisitorAction::Continue)
315    }
316
317    fn post_visit_source_unit(&mut self, _source_unit: &SourceUnit) -> Result<()> {
318        Ok(())
319    }
320    fn post_visit_import_directive(&mut self, _directive: &ImportDirective) -> Result<()> {
321        Ok(())
322    }
323    fn post_visit_pragma_directive(&mut self, _directive: &PragmaDirective) -> Result<()> {
324        Ok(())
325    }
326    fn post_visit_block(&mut self, _block: &Block) -> Result<()> {
327        Ok(())
328    }
329    fn post_visit_statement(&mut self, _statement: &Statement) -> Result<()> {
330        Ok(())
331    }
332    fn post_visit_expression(&mut self, _expression: &Expression) -> Result<()> {
333        Ok(())
334    }
335    fn post_visit_function_call(&mut self, _function_call: &FunctionCall) -> Result<()> {
336        Ok(())
337    }
338    fn post_visit_user_defined_type_name(
339        &mut self,
340        _type_name: &UserDefinedTypeName,
341    ) -> Result<()> {
342        Ok(())
343    }
344    fn post_visit_identifier_path(&mut self, _identifier_path: &IdentifierPath) -> Result<()> {
345        Ok(())
346    }
347    fn post_visit_type_name(&mut self, _type_name: &TypeName) -> Result<()> {
348        Ok(())
349    }
350    fn post_visit_parameter_list(&mut self, _parameter_list: &ParameterList) -> Result<()> {
351        Ok(())
352    }
353    fn post_visit_function_definition(&mut self, _definition: &FunctionDefinition) -> Result<()> {
354        Ok(())
355    }
356    fn post_visit_enum_definition(&mut self, _definition: &EnumDefinition) -> Result<()> {
357        Ok(())
358    }
359    fn post_visit_error_definition(&mut self, _definition: &ErrorDefinition) -> Result<()> {
360        Ok(())
361    }
362    fn post_visit_event_definition(&mut self, _definition: &EventDefinition) -> Result<()> {
363        Ok(())
364    }
365    fn post_visit_struct_definition(&mut self, _definition: &StructDefinition) -> Result<()> {
366        Ok(())
367    }
368    fn post_visit_modifier_definition(&mut self, _definition: &ModifierDefinition) -> Result<()> {
369        Ok(())
370    }
371    fn post_visit_variable_declaration(
372        &mut self,
373        _declaration: &VariableDeclaration,
374    ) -> Result<()> {
375        Ok(())
376    }
377    fn post_visit_overrides(&mut self, _specifier: &OverrideSpecifier) -> Result<()> {
378        Ok(())
379    }
380    fn post_visit_user_defined_value_type(
381        &mut self,
382        _value_type: &UserDefinedValueTypeDefinition,
383    ) -> Result<()> {
384        Ok(())
385    }
386    fn post_visit_contract_definition(&mut self, _definition: &ContractDefinition) -> Result<()> {
387        Ok(())
388    }
389    fn post_visit_using_for(&mut self, _directive: &UsingForDirective) -> Result<()> {
390        Ok(())
391    }
392    fn post_visit_unary_operation(&mut self, _unary_op: &UnaryOperation) -> Result<()> {
393        Ok(())
394    }
395    fn post_visit_binary_operation(&mut self, _binary_op: &BinaryOperation) -> Result<()> {
396        Ok(())
397    }
398    fn post_visit_conditional(&mut self, _conditional: &Conditional) -> Result<()> {
399        Ok(())
400    }
401    fn post_visit_tuple_expression(&mut self, _tuple_expression: &TupleExpression) -> Result<()> {
402        Ok(())
403    }
404    fn post_visit_new_expression(&mut self, _new_expression: &NewExpression) -> Result<()> {
405        Ok(())
406    }
407    fn post_visit_assignment(&mut self, _assignment: &Assignment) -> Result<()> {
408        Ok(())
409    }
410    fn post_visit_identifier(&mut self, _identifier: &Identifier) -> Result<()> {
411        Ok(())
412    }
413    fn post_visit_index_access(&mut self, _index_access: &IndexAccess) -> Result<()> {
414        Ok(())
415    }
416    fn post_visit_index_range_access(
417        &mut self,
418        _index_range_access: &IndexRangeAccess,
419    ) -> Result<()> {
420        Ok(())
421    }
422    fn post_visit_while_statement(&mut self, _while_statement: &WhileStatement) -> Result<()> {
423        Ok(())
424    }
425    fn post_visit_for_statement(&mut self, _for_statement: &ForStatement) -> Result<()> {
426        Ok(())
427    }
428    fn post_visit_if_statement(&mut self, _if_statement: &IfStatement) -> Result<()> {
429        Ok(())
430    }
431    fn post_visit_do_while_statement(
432        &mut self,
433        _do_while_statement: &DoWhileStatement,
434    ) -> Result<()> {
435        Ok(())
436    }
437    fn post_visit_emit_statement(&mut self, _emit_statement: &EmitStatement) -> Result<()> {
438        Ok(())
439    }
440    fn post_visit_unchecked_block(&mut self, _unchecked_block: &UncheckedBlock) -> Result<()> {
441        Ok(())
442    }
443    fn post_visit_try_statement(&mut self, _try_statement: &TryStatement) -> Result<()> {
444        Ok(())
445    }
446    fn post_visit_revert_statement(&mut self, _revert_statement: &RevertStatement) -> Result<()> {
447        Ok(())
448    }
449    fn post_visit_member_access(&mut self, _member_access: &MemberAccess) -> Result<()> {
450        Ok(())
451    }
452    fn post_visit_mapping(&mut self, _mapping: &Mapping) -> Result<()> {
453        Ok(())
454    }
455    fn post_visit_elementary_type_name(
456        &mut self,
457        _elementary_type_name: &ElementaryTypeName,
458    ) -> Result<()> {
459        Ok(())
460    }
461    fn post_visit_literal(&mut self, _literal: &Literal) -> Result<()> {
462        Ok(())
463    }
464    fn post_visit_function_type_name(
465        &mut self,
466        _function_type_name: &FunctionTypeName,
467    ) -> Result<()> {
468        Ok(())
469    }
470    fn post_visit_array_type_name(&mut self, _array_type_name: &ArrayTypeName) -> Result<()> {
471        Ok(())
472    }
473    fn post_visit_function_call_options(
474        &mut self,
475        _function_call: &FunctionCallOptions,
476    ) -> Result<()> {
477        Ok(())
478    }
479    fn post_visit_return(&mut self, _return: &Return) -> Result<()> {
480        Ok(())
481    }
482    fn post_visit_inheritance_specifier(
483        &mut self,
484        _specifier: &InheritanceSpecifier,
485    ) -> Result<()> {
486        Ok(())
487    }
488    fn post_visit_modifier_invocation(&mut self, _invocation: &ModifierInvocation) -> Result<()> {
489        Ok(())
490    }
491    fn post_visit_inline_assembly(&mut self, _assembly: &InlineAssembly) -> Result<()> {
492        Ok(())
493    }
494    fn post_visit_external_assembly_reference(
495        &mut self,
496        _ref: &ExternalInlineAssemblyReference,
497    ) -> Result<()> {
498        Ok(())
499    }
500}
501
502/// Trait for AST nodes that can be walked by a visitor.
503///
504/// This trait is implemented by AST nodes that support visitor pattern traversal.
505/// It provides a way to walk through the AST structure and apply visitor operations
506/// to each node.
507pub trait Walk: Any {
508    /// Walks this AST node with the given visitor.
509    ///
510    /// # Arguments
511    ///
512    /// * `visitor` - The visitor to apply to this node and its children
513    ///
514    /// # Returns
515    ///
516    /// A Result indicating success or failure of the walk operation.
517    fn walk(&self, visitor: &mut dyn Visitor) -> Result<()>;
518}
519
// Generates `impl Walk for $ty`. Three forms are accepted:
//
// 1. `impl_walk!(Ty, |node, visitor| body)`: `walk` evaluates `body` with `node` bound
//    to `&self` and `visitor` bound to the visitor. No hooks are invoked; this form is
//    for wrapper types that have no hook of their own.
// 2. `impl_walk!(Ty, visit_hook)`: for leaf nodes. It calls `visit_hook` and then
//    `post_visit_hook`, without walking any children.
// 3. `impl_walk!(Ty, visit_hook, |node, visitor| body)`:
//    - calls `visit_hook`;
//    - if that returns `VisitorAction::Continue`, evaluates `body` to walk the children;
//    - finally calls `post_visit_hook`.
//    The post-visit hook also runs on `SkipSubtree`, but it is not reached if the
//    pre-visit hook or `body` returns an error.
macro_rules! impl_walk {
    ($ty:ty, | $val:ident, $visitor:ident | $e:expr) => {
        impl Walk for $ty {
            fn walk(&self, visitor: &mut dyn Visitor) -> Result<()> {
                let $val = self;
                let $visitor = visitor;
                $e
            }
        }
    };
    ($ty:ty, $func:ident) => {
        // A leaf is just the general form with an empty child walk.
        impl_walk!($ty, $func, |obj, visitor| Ok(()));
    };
    ($ty:ty, $func:ident, | $val:ident, $visitor:ident | $e:expr) => {
        impl_walk!($ty, |$val, $visitor| {
            match $visitor.$func($val)? {
                VisitorAction::Continue => {
                    // The annotation pins the error type of a trailing `Ok(())` in `$e`,
                    // which `?` alone could not infer.
                    let children: Result<()> = $e;
                    children?;
                }
                VisitorAction::SkipSubtree => {
                    // Children are deliberately not walked; the post-visit hook still runs.
                }
            }
            paste! { $visitor.[<post_ $func>]($val)?; }
            Ok(())
        });
    };
}
567
568impl_walk!(SourceUnit, visit_source_unit, |source_unit, visitor| {
569    for node in &source_unit.nodes {
570        node.walk(visitor)?;
571    }
572    Ok(())
573});
574
575impl_walk!(SourceUnitPart, |part, visitor| {
576    match part {
577        SourceUnitPart::ContractDefinition(contract) => contract.walk(visitor),
578        SourceUnitPart::UsingForDirective(directive) => directive.walk(visitor),
579        SourceUnitPart::ErrorDefinition(error) => error.walk(visitor),
580        SourceUnitPart::StructDefinition(struct_) => struct_.walk(visitor),
581        SourceUnitPart::VariableDeclaration(declaration) => declaration.walk(visitor),
582        SourceUnitPart::FunctionDefinition(function) => function.walk(visitor),
583        SourceUnitPart::UserDefinedValueTypeDefinition(value_type) => value_type.walk(visitor),
584        SourceUnitPart::ImportDirective(directive) => directive.walk(visitor),
585        SourceUnitPart::EnumDefinition(enum_) => enum_.walk(visitor),
586        SourceUnitPart::PragmaDirective(directive) => directive.walk(visitor),
587        SourceUnitPart::EventDefinition(event) => event.walk(visitor),
588    }
589});
590
591impl_walk!(ContractDefinition, visit_contract_definition, |contract, visitor| {
592    for base_contract in &contract.base_contracts {
593        base_contract.walk(visitor)?;
594    }
595
596    for part in &contract.nodes {
597        match part {
598            ContractDefinitionPart::FunctionDefinition(function) => function.walk(visitor),
599            ContractDefinitionPart::ErrorDefinition(error) => error.walk(visitor),
600            ContractDefinitionPart::EventDefinition(event) => event.walk(visitor),
601            ContractDefinitionPart::StructDefinition(struct_) => struct_.walk(visitor),
602            ContractDefinitionPart::VariableDeclaration(declaration) => declaration.walk(visitor),
603            ContractDefinitionPart::ModifierDefinition(modifier) => modifier.walk(visitor),
604            ContractDefinitionPart::UserDefinedValueTypeDefinition(definition) => {
605                definition.walk(visitor)
606            }
607            ContractDefinitionPart::UsingForDirective(directive) => directive.walk(visitor),
608            ContractDefinitionPart::EnumDefinition(enum_) => enum_.walk(visitor),
609        }?;
610    }
611    Ok(())
612});
613
614impl_walk!(Expression, visit_expression, |expr, visitor| {
615    match expr {
616        Expression::FunctionCall(expression) => expression.walk(visitor),
617        Expression::MemberAccess(member_access) => member_access.walk(visitor),
618        Expression::IndexAccess(index_access) => index_access.walk(visitor),
619        Expression::UnaryOperation(unary_op) => unary_op.walk(visitor),
620        Expression::BinaryOperation(expression) => expression.walk(visitor),
621        Expression::Conditional(expression) => expression.walk(visitor),
622        Expression::TupleExpression(tuple) => tuple.walk(visitor),
623        Expression::NewExpression(expression) => expression.walk(visitor),
624        Expression::Assignment(expression) => expression.walk(visitor),
625        Expression::Identifier(identifier) => identifier.walk(visitor),
626        Expression::FunctionCallOptions(function_call) => function_call.walk(visitor),
627        Expression::IndexRangeAccess(range_access) => range_access.walk(visitor),
628        Expression::Literal(literal) => literal.walk(visitor),
629        Expression::ElementaryTypeNameExpression(type_name) => type_name.walk(visitor),
630    }
631});
632
633impl_walk!(Statement, visit_statement, |statement, visitor| {
634    match statement {
635        Statement::Block(block) => block.walk(visitor),
636        Statement::WhileStatement(statement) => statement.walk(visitor),
637        Statement::ForStatement(statement) => statement.walk(visitor),
638        Statement::IfStatement(statement) => statement.walk(visitor),
639        Statement::DoWhileStatement(statement) => statement.walk(visitor),
640        Statement::EmitStatement(statement) => statement.walk(visitor),
641        Statement::VariableDeclarationStatement(statement) => statement.walk(visitor),
642        Statement::ExpressionStatement(statement) => statement.walk(visitor),
643        Statement::UncheckedBlock(statement) => statement.walk(visitor),
644        Statement::TryStatement(statement) => statement.walk(visitor),
645        Statement::RevertStatement(statement) => statement.walk(visitor),
646        Statement::Return(statement) => statement.walk(visitor),
647        Statement::InlineAssembly(assembly) => assembly.walk(visitor),
648        Statement::Break(_) | Statement::Continue(_) | Statement::PlaceholderStatement(_) => Ok(()),
649    }
650});
651
652impl_walk!(FunctionDefinition, visit_function_definition, |function, visitor| {
653    function.parameters.walk(visitor)?;
654    function.return_parameters.walk(visitor)?;
655
656    if let Some(overrides) = &function.overrides {
657        overrides.walk(visitor)?;
658    }
659
660    if let Some(body) = &function.body {
661        body.walk(visitor)?;
662    }
663
664    for m in &function.modifiers {
665        m.walk(visitor)?;
666    }
667    Ok(())
668});
669
670impl_walk!(ErrorDefinition, visit_error_definition, |error, visitor| {
671    error.parameters.walk(visitor)
672});
673
674impl_walk!(EventDefinition, visit_event_definition, |event, visitor| {
675    event.parameters.walk(visitor)
676});
677
678impl_walk!(StructDefinition, visit_struct_definition, |struct_, visitor| {
679    for member in &struct_.members {
680        member.walk(visitor)?;
681    }
682    Ok(())
683});
684
685impl_walk!(ModifierDefinition, visit_modifier_definition, |modifier, visitor| {
686    if let Some(body) = modifier.body.as_ref() {
687        body.walk(visitor)?;
688    }
689    if let Some(override_) = &modifier.overrides {
690        override_.walk(visitor)?;
691    }
692    modifier.parameters.walk(visitor)?;
693    Ok(())
694});
695
696impl_walk!(VariableDeclaration, visit_variable_declaration, |declaration, visitor| {
697    if let Some(value) = &declaration.value {
698        value.walk(visitor)?;
699    }
700
701    if let Some(type_name) = &declaration.type_name {
702        type_name.walk(visitor)?;
703    }
704
705    Ok(())
706});
707
708impl_walk!(OverrideSpecifier, visit_overrides, |override_, visitor| {
709    for type_name in &override_.overrides {
710        type_name.walk(visitor)?;
711    }
712    Ok(())
713});
714
715impl_walk!(UserDefinedValueTypeDefinition, visit_user_defined_value_type, |value_type, visitor| {
716    value_type.underlying_type.walk(visitor)
717});
718
719impl_walk!(FunctionCallOptions, visit_function_call_options, |function_call, visitor| {
720    function_call.expression.walk(visitor)?;
721    for option in &function_call.options {
722        option.walk(visitor)?;
723    }
724    Ok(())
725});
726
727impl_walk!(Return, visit_return, |return_, visitor| {
728    if let Some(expr) = return_.expression.as_ref() {
729        expr.walk(visitor)?;
730    }
731    Ok(())
732});
733
734impl_walk!(UsingForDirective, visit_using_for, |directive, visitor| {
735    if let Some(type_name) = &directive.type_name {
736        type_name.walk(visitor)?;
737    }
738    if let Some(library_name) = &directive.library_name {
739        library_name.walk(visitor)?;
740    }
741    for function in &directive.function_list {
742        function.walk(visitor)?;
743    }
744
745    Ok(())
746});
747
748impl_walk!(UnaryOperation, visit_unary_operation, |unary_op, visitor| {
749    unary_op.sub_expression.walk(visitor)
750});
751
752impl_walk!(BinaryOperation, visit_binary_operation, |binary_op, visitor| {
753    binary_op.lhs.walk(visitor)?;
754    binary_op.rhs.walk(visitor)?;
755    Ok(())
756});
757
758impl_walk!(Conditional, visit_conditional, |conditional, visitor| {
759    conditional.condition.walk(visitor)?;
760    conditional.true_expression.walk(visitor)?;
761    conditional.false_expression.walk(visitor)?;
762    Ok(())
763});
764
765impl_walk!(TupleExpression, visit_tuple_expression, |tuple_expression, visitor| {
766    for component in tuple_expression.components.iter().filter_map(|component| component.as_ref()) {
767        component.walk(visitor)?;
768    }
769    Ok(())
770});
771
772impl_walk!(NewExpression, visit_new_expression, |new_expression, visitor| {
773    new_expression.type_name.walk(visitor)
774});
775
776impl_walk!(Assignment, visit_assignment, |assignment, visitor| {
777    assignment.lhs.walk(visitor)?;
778    assignment.rhs.walk(visitor)?;
779    Ok(())
780});
781
782impl_walk!(IfStatement, visit_if_statement, |if_statement, visitor| {
783    if_statement.condition.walk(visitor)?;
784    if_statement.true_body.walk(visitor)?;
785
786    if let Some(false_body) = &if_statement.false_body {
787        false_body.walk(visitor)?;
788    }
789
790    Ok(())
791});
792
793impl_walk!(IndexAccess, visit_index_access, |index_access, visitor| {
794    index_access.base_expression.walk(visitor)?;
795    if let Some(index_expression) = &index_access.index_expression {
796        index_expression.walk(visitor)?;
797    }
798    Ok(())
799});
800
801impl_walk!(IndexRangeAccess, visit_index_range_access, |index_range_access, visitor| {
802    index_range_access.base_expression.walk(visitor)?;
803    if let Some(start_expression) = &index_range_access.start_expression {
804        start_expression.walk(visitor)?;
805    }
806    if let Some(end_expression) = &index_range_access.end_expression {
807        end_expression.walk(visitor)?;
808    }
809    Ok(())
810});
811
812impl_walk!(WhileStatement, visit_while_statement, |while_statement, visitor| {
813    while_statement.condition.walk(visitor)?;
814    while_statement.body.walk(visitor)?;
815    Ok(())
816});
817
818impl_walk!(ForStatement, visit_for_statement, |for_statement, visitor| {
819    for_statement.body.walk(visitor)?;
820    if let Some(condition) = &for_statement.condition {
821        condition.walk(visitor)?;
822    }
823
824    if let Some(loop_expression) = &for_statement.loop_expression {
825        loop_expression.walk(visitor)?;
826    }
827
828    if let Some(initialization_expr) = &for_statement.initialization_expression {
829        initialization_expr.walk(visitor)?;
830    }
831
832    Ok(())
833});
834
835impl_walk!(DoWhileStatement, visit_do_while_statement, |do_while_statement, visitor| {
836    do_while_statement.body.walk(visitor)?;
837    do_while_statement.condition.walk(visitor)?;
838    Ok(())
839});
840
841impl_walk!(EmitStatement, visit_emit_statement, |emit_statement, visitor| {
842    emit_statement.event_call.walk(visitor)
843});
844
845impl_walk!(VariableDeclarationStatement, |stmt, visitor| {
846    for declaration in stmt.declarations.iter().filter_map(|d| d.as_ref()) {
847        declaration.walk(visitor)?;
848    }
849    if let Some(initial_value) = &stmt.initial_value {
850        initial_value.walk(visitor)?;
851    }
852    Ok(())
853});
854
855impl_walk!(UncheckedBlock, visit_unchecked_block, |unchecked_block, visitor| {
856    for statement in &unchecked_block.statements {
857        statement.walk(visitor)?;
858    }
859    Ok(())
860});
861
862impl_walk!(TryStatement, visit_try_statement, |try_statement, visitor| {
863    for clause in &try_statement.clauses {
864        clause.block.walk(visitor)?;
865
866        if let Some(parameter_list) = &clause.parameters {
867            parameter_list.walk(visitor)?;
868        }
869    }
870
871    try_statement.external_call.walk(visitor)
872});
873
874impl_walk!(RevertStatement, visit_revert_statement, |revert_statement, visitor| {
875    revert_statement.error_call.walk(visitor)
876});
877
878impl_walk!(MemberAccess, visit_member_access, |member_access, visitor| {
879    member_access.expression.walk(visitor)
880});
881
882impl_walk!(FunctionCall, visit_function_call, |function_call, visitor| {
883    function_call.expression.walk(visitor)?;
884    for argument in &function_call.arguments {
885        argument.walk(visitor)?;
886    }
887    Ok(())
888});
889
890impl_walk!(Block, visit_block, |block, visitor| {
891    for statement in &block.statements {
892        statement.walk(visitor)?;
893    }
894    Ok(())
895});
896
897impl_walk!(UserDefinedTypeName, visit_user_defined_type_name, |type_name, visitor| {
898    if let Some(path_node) = &type_name.path_node {
899        path_node.walk(visitor)?;
900    }
901    Ok(())
902});
903
904impl_walk!(TypeName, visit_type_name, |type_name, visitor| {
905    match type_name {
906        TypeName::ElementaryTypeName(type_name) => type_name.walk(visitor),
907        TypeName::UserDefinedTypeName(type_name) => type_name.walk(visitor),
908        TypeName::Mapping(mapping) => mapping.walk(visitor),
909        TypeName::ArrayTypeName(array) => array.walk(visitor),
910        TypeName::FunctionTypeName(function) => function.walk(visitor),
911    }
912});
913
914impl_walk!(FunctionTypeName, visit_function_type_name, |function, visitor| {
915    function.parameter_types.walk(visitor)?;
916    function.return_parameter_types.walk(visitor)?;
917    Ok(())
918});
919
920impl_walk!(ParameterList, visit_parameter_list, |parameter_list, visitor| {
921    for parameter in &parameter_list.parameters {
922        parameter.walk(visitor)?;
923    }
924    Ok(())
925});
926
927impl_walk!(Mapping, visit_mapping, |mapping, visitor| {
928    mapping.key_type.walk(visitor)?;
929    mapping.value_type.walk(visitor)?;
930    Ok(())
931});
932
933impl_walk!(ArrayTypeName, visit_array_type_name, |array, visitor| {
934    array.base_type.walk(visitor)?;
935    if let Some(length) = &array.length {
936        length.walk(visitor)?;
937    }
938    Ok(())
939});
940
941impl_walk!(InheritanceSpecifier, visit_inheritance_specifier, |specifier, visitor| {
942    specifier.base_name.walk(visitor)?;
943    for arg in &specifier.arguments {
944        arg.walk(visitor)?;
945    }
946    Ok(())
947});
948
949impl_walk!(ModifierInvocation, visit_modifier_invocation, |invocation, visitor| {
950    for arg in &invocation.arguments {
951        arg.walk(visitor)?;
952    }
953    invocation.modifier_name.walk(visitor)?;
954    Ok(())
955});
956
957impl_walk!(InlineAssembly, visit_inline_assembly, |assembly, visitor| {
958    for reference in &assembly.external_references {
959        reference.walk(visitor)?;
960    }
961    Ok(())
962});
963
// Leaf node types: no traversal closure is supplied, so the walker does not descend into any
// children of these types. NOTE(review): presumably only the named `visit_*` hook is
// dispatched; confirm against the two-argument arm of `impl_walk!`.
impl_walk!(ExternalInlineAssemblyReference, visit_external_assembly_reference);

// `EnumDefinition` is also walked as a leaf here, so its members are not visited.
impl_walk!(ElementaryTypeName, visit_elementary_type_name);
impl_walk!(Literal, visit_literal);
impl_walk!(ImportDirective, visit_import_directive);
impl_walk!(PragmaDirective, visit_pragma_directive);
impl_walk!(IdentifierPath, visit_identifier_path);
impl_walk!(EnumDefinition, visit_enum_definition);
impl_walk!(Identifier, visit_identifier);
973
974impl_walk!(UserDefinedTypeNameOrIdentifierPath, |type_name, visitor| {
975    match type_name {
976        UserDefinedTypeNameOrIdentifierPath::UserDefinedTypeName(type_name) => {
977            type_name.walk(visitor)
978        }
979        UserDefinedTypeNameOrIdentifierPath::IdentifierPath(identifier_path) => {
980            identifier_path.walk(visitor)
981        }
982    }
983});
984
985impl_walk!(BlockOrStatement, |block_or_statement, visitor| {
986    match block_or_statement {
987        BlockOrStatement::Block(block) => block.walk(visitor),
988        BlockOrStatement::Statement(statement) => statement.walk(visitor),
989    }
990});
991
992impl_walk!(ExpressionOrVariableDeclarationStatement, |val, visitor| {
993    match val {
994        ExpressionOrVariableDeclarationStatement::ExpressionStatement(expression) => {
995            expression.walk(visitor)
996        }
997        ExpressionOrVariableDeclarationStatement::VariableDeclarationStatement(stmt) => {
998            stmt.walk(visitor)
999        }
1000    }
1001});
1002
1003impl_walk!(IdentifierOrIdentifierPath, |val, visitor| {
1004    match val {
1005        IdentifierOrIdentifierPath::Identifier(ident) => ident.walk(visitor),
1006        IdentifierOrIdentifierPath::IdentifierPath(path) => path.walk(visitor),
1007    }
1008});
1009
1010impl_walk!(ExpressionStatement, |expression_statement, visitor| {
1011    expression_statement.expression.walk(visitor)
1012});
1013
1014impl_walk!(ElementaryTypeNameExpression, |type_name, visitor| {
1015    type_name.type_name.walk(visitor)
1016});
1017
1018impl_walk!(ElementaryOrRawTypeName, |type_name, visitor| {
1019    match type_name {
1020        ElementaryOrRawTypeName::ElementaryTypeName(type_name) => type_name.walk(visitor),
1021        ElementaryOrRawTypeName::Raw(_) => Ok(()),
1022    }
1023});
1024
1025impl_walk!(UsingForFunctionItem, |item, visitor| {
1026    match item {
1027        UsingForFunctionItem::Function(func) => func.function.walk(visitor),
1028        UsingForFunctionItem::OverloadedOperator(operator) => operator.walk(visitor),
1029    }
1030});
1031
1032impl_walk!(OverloadedOperator, |operator, visitor| operator.definition.walk(visitor));
1033
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::compile_contract_source_to_source_unit;
    use semver::Version;

    /// Solidity fixture: two state variables, two functions with small bodies, and two events.
    const TEST_CONTRACT: &str = r#"
        contract TestContract {
            uint256 public value1;
            uint256 public value2;

            function setValue1(uint256 newValue) public {
                value1 = newValue;
                emit ValueSet1(newValue);
            }

            function setValue2(uint256 newValue) public {
                value2 = newValue;
                emit ValueSet2(newValue);
            }

            event ValueSet1(uint256 value);
            event ValueSet2(uint256 value);
        }
        "#;

    /// A test visitor that counts contract definitions, variable declarations, and
    /// statements. It can optionally skip the subtree of every function definition.
    struct CountingVisitor {
        /// Contract definitions, variable declarations, and statements seen so far.
        node_count: usize,
        /// Function definitions seen so far (counted even when their subtree is skipped).
        function_count: usize,
        /// Whether `visit_function_definition` asks the walker to skip the function's subtree.
        skip_function_bodies: bool,
    }

    impl CountingVisitor {
        fn new(skip_function_bodies: bool) -> Self {
            Self { node_count: 0, function_count: 0, skip_function_bodies }
        }
    }

    impl Visitor for CountingVisitor {
        fn visit_function_definition(
            &mut self,
            _definition: &FunctionDefinition,
        ) -> Result<VisitorAction> {
            self.function_count += 1;
            // When skipping, nothing inside the function (parameters, body) is counted.
            Ok(if self.skip_function_bodies {
                VisitorAction::SkipSubtree
            } else {
                VisitorAction::Continue
            })
        }

        fn visit_contract_definition(
            &mut self,
            _definition: &ContractDefinition,
        ) -> Result<VisitorAction> {
            self.node_count += 1;
            Ok(VisitorAction::Continue)
        }

        fn visit_variable_declaration(
            &mut self,
            _declaration: &VariableDeclaration,
        ) -> Result<VisitorAction> {
            self.node_count += 1;
            Ok(VisitorAction::Continue)
        }

        fn visit_statement(&mut self, _statement: &Statement) -> Result<VisitorAction> {
            self.node_count += 1;
            Ok(VisitorAction::Continue)
        }
    }

    #[test]
    fn test_skip_subtree_functionality() {
        let version = Version::parse("0.8.19").unwrap();
        let source_unit = compile_contract_source_to_source_unit(version, TEST_CONTRACT, true)
            .expect("Failed to compile contract");

        let mut skipping = CountingVisitor::new(true);
        source_unit.walk(&mut skipping).expect("Failed to walk AST");

        // Baseline over the same AST, descending into function bodies.
        let mut full = CountingVisitor::new(false);
        source_unit.walk(&mut full).expect("Failed to walk AST");

        // The function-definition hook runs before the skip decision, so both visitors
        // observe both functions.
        assert_eq!(skipping.function_count, 2, "Should have visited 2 functions");
        assert_eq!(full.function_count, 2, "Should have visited 2 functions");

        // Nodes outside the functions are still counted while skipping...
        assert!(skipping.node_count > 0, "Should have counted some nodes");
        // ...but the parameters and statements inside both functions are not. Comparing
        // against the baseline proves the skip happened, independent of contract size.
        assert!(
            skipping.node_count < full.node_count,
            "SkipSubtree should prevent counting nodes inside function definitions \
             (skipping: {}, full: {})",
            skipping.node_count,
            full.node_count
        );
    }
}