Skip to main content

dissolve_python/
ast_visitor.rs

1// Copyright (C) 2024 Jelmer Vernooij <jelmer@samba.org>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Unified AST visitor traits and implementations
16
17use crate::domain_types::{FunctionName, ModuleName, QualifiedName};
18use crate::error::{DissolveError, Result};
19use rustpython_ast as ast;
20// Removed unused import
21
22/// Trait for visiting AST nodes and collecting information
23pub trait AstVisitor<T> {
24    /// Visit a module and return collected information
25    fn visit_module(&mut self, module: &ast::Mod) -> Result<T>;
26
27    /// Visit a function definition
28    fn visit_function_def(&mut self, func: &ast::StmtFunctionDef) -> Result<()>;
29
30    /// Visit a class definition
31    fn visit_class_def(&mut self, class: &ast::StmtClassDef) -> Result<()>;
32
33    /// Visit a function call
34    fn visit_call(&mut self, call: &ast::ExprCall) -> Result<()>;
35
36    /// Get the current module name
37    fn module_name(&self) -> &ModuleName;
38}
39
40/// Trait for transforming AST nodes
41pub trait AstTransformer {
42    /// Transform a function call expression
43    fn transform_call(&mut self, call: &ast::ExprCall) -> Result<Option<String>>;
44
45    /// Check if a function should be transformed
46    fn should_transform(&self, qualified_name: &QualifiedName) -> bool;
47}
48
49/// Common functionality for AST visitors
50pub struct VisitorContext {
51    pub module_name: ModuleName,
52    pub file_path: String,
53    pub current_class: Option<String>,
54    pub nested_level: usize,
55}
56
57impl VisitorContext {
58    pub fn new(module_name: ModuleName, file_path: String) -> Self {
59        Self {
60            module_name,
61            file_path,
62            current_class: None,
63            nested_level: 0,
64        }
65    }
66
67    /// Enter a class scope
68    pub fn enter_class(&mut self, class_name: &str) {
69        self.current_class = Some(class_name.to_string());
70        self.nested_level += 1;
71    }
72
73    /// Exit a class scope
74    pub fn exit_class(&mut self) {
75        self.current_class = None;
76        if self.nested_level > 0 {
77            self.nested_level -= 1;
78        }
79    }
80
81    /// Get the current qualified context
82    pub fn current_context(&self) -> String {
83        match &self.current_class {
84            Some(class) => format!("{}.{}", self.module_name, class),
85            None => self.module_name.to_string(),
86        }
87    }
88
89    /// Create a qualified name for a function
90    pub fn qualify_function(&self, function_name: &FunctionName) -> QualifiedName {
91        let context = match &self.current_class {
92            Some(class) => format!("{}.{}", self.module_name, class),
93            None => self.module_name.to_string(),
94        };
95
96        QualifiedName::from_string(&format!("{}.{}", context, function_name.as_str()))
97            .unwrap_or_else(|_| QualifiedName::new(self.module_name.clone(), function_name.clone()))
98    }
99}
100
101/// Helper functions for AST traversal
102pub mod ast_helpers {
103    use super::*;
104    use rustpython_ast as ast;
105
106    /// Extract function name from various expression types
107    pub fn extract_function_name(expr: &ast::Expr) -> Option<String> {
108        match expr {
109            ast::Expr::Name(name) => Some(name.id.to_string()),
110            ast::Expr::Attribute(attr) => {
111                let base = extract_function_name(&attr.value)?;
112                Some(format!("{}.{}", base, attr.attr))
113            }
114            _ => None,
115        }
116    }
117
118    /// Check if an expression is a simple name (not a complex expression)
119    pub fn is_simple_name(expr: &ast::Expr) -> bool {
120        matches!(expr, ast::Expr::Name(_))
121    }
122
123    /// Extract decorator names from a function definition
124    pub fn extract_decorator_names(decorators: &[ast::Expr]) -> Vec<String> {
125        decorators
126            .iter()
127            .filter_map(|dec| match dec {
128                ast::Expr::Name(name) => Some(name.id.to_string()),
129                ast::Expr::Call(call) => match &*call.func {
130                    ast::Expr::Name(name) => Some(name.id.to_string()),
131                    _ => None,
132                },
133                _ => None,
134            })
135            .collect()
136    }
137
138    /// Check if a function has a specific decorator
139    pub fn has_decorator(decorators: &[ast::Expr], decorator_name: &str) -> bool {
140        extract_decorator_names(decorators).contains(&decorator_name.to_string())
141    }
142
143    /// Extract string literal value
144    pub fn extract_string_literal(expr: &ast::Expr) -> Option<String> {
145        match expr {
146            ast::Expr::Constant(constant) => match &constant.value {
147                ast::Constant::Str(s) => Some(s.to_string()),
148                _ => None,
149            },
150            _ => None,
151        }
152    }
153
154    /// Walk through all statements in a module recursively
155    pub fn walk_statements<F>(statements: &[ast::Stmt], mut callback: F) -> Result<()>
156    where
157        F: FnMut(&ast::Stmt) -> Result<()>,
158    {
159        for stmt in statements {
160            callback(stmt)?;
161
162            // Recursively visit nested statements
163            match stmt {
164                ast::Stmt::FunctionDef(func) => {
165                    walk_statements(&func.body, &mut callback)?;
166                }
167                ast::Stmt::AsyncFunctionDef(func) => {
168                    walk_statements(&func.body, &mut callback)?;
169                }
170                ast::Stmt::ClassDef(class) => {
171                    walk_statements(&class.body, &mut callback)?;
172                }
173                ast::Stmt::If(if_stmt) => {
174                    walk_statements(&if_stmt.body, &mut callback)?;
175                    walk_statements(&if_stmt.orelse, &mut callback)?;
176                }
177                ast::Stmt::While(while_stmt) => {
178                    walk_statements(&while_stmt.body, &mut callback)?;
179                    walk_statements(&while_stmt.orelse, &mut callback)?;
180                }
181                ast::Stmt::For(for_stmt) => {
182                    walk_statements(&for_stmt.body, &mut callback)?;
183                    walk_statements(&for_stmt.orelse, &mut callback)?;
184                }
185                ast::Stmt::With(with_stmt) => {
186                    walk_statements(&with_stmt.body, &mut callback)?;
187                }
188                ast::Stmt::AsyncWith(with_stmt) => {
189                    walk_statements(&with_stmt.body, &mut callback)?;
190                }
191                ast::Stmt::Try(try_stmt) => {
192                    walk_statements(&try_stmt.body, &mut callback)?;
193                    walk_statements(&try_stmt.orelse, &mut callback)?;
194                    walk_statements(&try_stmt.finalbody, &mut callback)?;
195                    for handler in &try_stmt.handlers {
196                        match handler {
197                            ast::ExceptHandler::ExceptHandler(exc) => {
198                                walk_statements(&exc.body, &mut callback)?;
199                            }
200                        }
201                    }
202                }
203                _ => {}
204            }
205        }
206        Ok(())
207    }
208}
209
210/// Base implementation of common visitor patterns
211pub struct BaseVisitor {
212    pub context: VisitorContext,
213}
214
215impl BaseVisitor {
216    pub fn new(module_name: ModuleName, file_path: String) -> Self {
217        Self {
218            context: VisitorContext::new(module_name, file_path),
219        }
220    }
221
222    /// Generic module traversal
223    pub fn traverse_module<T, F>(&mut self, module: &ast::Mod, mut visitor_fn: F) -> Result<T>
224    where
225        F: FnMut(&mut Self, &ast::Stmt) -> Result<Option<T>>,
226        T: Default,
227    {
228        match module {
229            ast::Mod::Module(module) => {
230                for stmt in &module.body {
231                    if let Some(result) = visitor_fn(self, stmt)? {
232                        return Ok(result);
233                    }
234                }
235                Ok(T::default())
236            }
237            _ => Err(DissolveError::invalid_input(
238                "Only module AST nodes are supported",
239            )),
240        }
241    }
242}
243
244#[cfg(test)]
245mod tests {
246    use super::*;
247    use rustpython_parser::{parse, Mode};
248
249    #[test]
250    fn test_visitor_context() {
251        let module_name = ModuleName::new("test_module");
252        let mut context = VisitorContext::new(module_name.clone(), "test.py".to_string());
253
254        assert_eq!(context.current_context(), "test_module");
255
256        context.enter_class("TestClass");
257        assert_eq!(context.current_context(), "test_module.TestClass");
258
259        context.exit_class();
260        assert_eq!(context.current_context(), "test_module");
261    }
262
263    #[test]
264    fn test_ast_helpers() {
265        let source = r#"
266@decorator
267def test_func():
268    pass
269"#;
270
271        let parsed = parse(source, Mode::Module, "<test>").unwrap();
272        if let ast::Mod::Module(module) = parsed {
273            if let Some(ast::Stmt::FunctionDef(func)) = module.body.first() {
274                let decorators = ast_helpers::extract_decorator_names(&func.decorator_list);
275                assert_eq!(decorators, vec!["decorator"]);
276                assert!(ast_helpers::has_decorator(
277                    &func.decorator_list,
278                    "decorator"
279                ));
280            }
281        }
282    }
283}