dissolve_python/
ast_visitor.rs1use crate::domain_types::{FunctionName, ModuleName, QualifiedName};
18use crate::error::{DissolveError, Result};
19use rustpython_ast as ast;
20pub trait AstVisitor<T> {
24 fn visit_module(&mut self, module: &ast::Mod) -> Result<T>;
26
27 fn visit_function_def(&mut self, func: &ast::StmtFunctionDef) -> Result<()>;
29
30 fn visit_class_def(&mut self, class: &ast::StmtClassDef) -> Result<()>;
32
33 fn visit_call(&mut self, call: &ast::ExprCall) -> Result<()>;
35
36 fn module_name(&self) -> &ModuleName;
38}
39
40pub trait AstTransformer {
42 fn transform_call(&mut self, call: &ast::ExprCall) -> Result<Option<String>>;
44
45 fn should_transform(&self, qualified_name: &QualifiedName) -> bool;
47}
48
49pub struct VisitorContext {
51 pub module_name: ModuleName,
52 pub file_path: String,
53 pub current_class: Option<String>,
54 pub nested_level: usize,
55}
56
57impl VisitorContext {
58 pub fn new(module_name: ModuleName, file_path: String) -> Self {
59 Self {
60 module_name,
61 file_path,
62 current_class: None,
63 nested_level: 0,
64 }
65 }
66
67 pub fn enter_class(&mut self, class_name: &str) {
69 self.current_class = Some(class_name.to_string());
70 self.nested_level += 1;
71 }
72
73 pub fn exit_class(&mut self) {
75 self.current_class = None;
76 if self.nested_level > 0 {
77 self.nested_level -= 1;
78 }
79 }
80
81 pub fn current_context(&self) -> String {
83 match &self.current_class {
84 Some(class) => format!("{}.{}", self.module_name, class),
85 None => self.module_name.to_string(),
86 }
87 }
88
89 pub fn qualify_function(&self, function_name: &FunctionName) -> QualifiedName {
91 let context = match &self.current_class {
92 Some(class) => format!("{}.{}", self.module_name, class),
93 None => self.module_name.to_string(),
94 };
95
96 QualifiedName::from_string(&format!("{}.{}", context, function_name.as_str()))
97 .unwrap_or_else(|_| QualifiedName::new(self.module_name.clone(), function_name.clone()))
98 }
99}
100
101pub mod ast_helpers {
103 use super::*;
104 use rustpython_ast as ast;
105
106 pub fn extract_function_name(expr: &ast::Expr) -> Option<String> {
108 match expr {
109 ast::Expr::Name(name) => Some(name.id.to_string()),
110 ast::Expr::Attribute(attr) => {
111 let base = extract_function_name(&attr.value)?;
112 Some(format!("{}.{}", base, attr.attr))
113 }
114 _ => None,
115 }
116 }
117
118 pub fn is_simple_name(expr: &ast::Expr) -> bool {
120 matches!(expr, ast::Expr::Name(_))
121 }
122
123 pub fn extract_decorator_names(decorators: &[ast::Expr]) -> Vec<String> {
125 decorators
126 .iter()
127 .filter_map(|dec| match dec {
128 ast::Expr::Name(name) => Some(name.id.to_string()),
129 ast::Expr::Call(call) => match &*call.func {
130 ast::Expr::Name(name) => Some(name.id.to_string()),
131 _ => None,
132 },
133 _ => None,
134 })
135 .collect()
136 }
137
138 pub fn has_decorator(decorators: &[ast::Expr], decorator_name: &str) -> bool {
140 extract_decorator_names(decorators).contains(&decorator_name.to_string())
141 }
142
143 pub fn extract_string_literal(expr: &ast::Expr) -> Option<String> {
145 match expr {
146 ast::Expr::Constant(constant) => match &constant.value {
147 ast::Constant::Str(s) => Some(s.to_string()),
148 _ => None,
149 },
150 _ => None,
151 }
152 }
153
154 pub fn walk_statements<F>(statements: &[ast::Stmt], mut callback: F) -> Result<()>
156 where
157 F: FnMut(&ast::Stmt) -> Result<()>,
158 {
159 for stmt in statements {
160 callback(stmt)?;
161
162 match stmt {
164 ast::Stmt::FunctionDef(func) => {
165 walk_statements(&func.body, &mut callback)?;
166 }
167 ast::Stmt::AsyncFunctionDef(func) => {
168 walk_statements(&func.body, &mut callback)?;
169 }
170 ast::Stmt::ClassDef(class) => {
171 walk_statements(&class.body, &mut callback)?;
172 }
173 ast::Stmt::If(if_stmt) => {
174 walk_statements(&if_stmt.body, &mut callback)?;
175 walk_statements(&if_stmt.orelse, &mut callback)?;
176 }
177 ast::Stmt::While(while_stmt) => {
178 walk_statements(&while_stmt.body, &mut callback)?;
179 walk_statements(&while_stmt.orelse, &mut callback)?;
180 }
181 ast::Stmt::For(for_stmt) => {
182 walk_statements(&for_stmt.body, &mut callback)?;
183 walk_statements(&for_stmt.orelse, &mut callback)?;
184 }
185 ast::Stmt::With(with_stmt) => {
186 walk_statements(&with_stmt.body, &mut callback)?;
187 }
188 ast::Stmt::AsyncWith(with_stmt) => {
189 walk_statements(&with_stmt.body, &mut callback)?;
190 }
191 ast::Stmt::Try(try_stmt) => {
192 walk_statements(&try_stmt.body, &mut callback)?;
193 walk_statements(&try_stmt.orelse, &mut callback)?;
194 walk_statements(&try_stmt.finalbody, &mut callback)?;
195 for handler in &try_stmt.handlers {
196 match handler {
197 ast::ExceptHandler::ExceptHandler(exc) => {
198 walk_statements(&exc.body, &mut callback)?;
199 }
200 }
201 }
202 }
203 _ => {}
204 }
205 }
206 Ok(())
207 }
208}
209
210pub struct BaseVisitor {
212 pub context: VisitorContext,
213}
214
215impl BaseVisitor {
216 pub fn new(module_name: ModuleName, file_path: String) -> Self {
217 Self {
218 context: VisitorContext::new(module_name, file_path),
219 }
220 }
221
222 pub fn traverse_module<T, F>(&mut self, module: &ast::Mod, mut visitor_fn: F) -> Result<T>
224 where
225 F: FnMut(&mut Self, &ast::Stmt) -> Result<Option<T>>,
226 T: Default,
227 {
228 match module {
229 ast::Mod::Module(module) => {
230 for stmt in &module.body {
231 if let Some(result) = visitor_fn(self, stmt)? {
232 return Ok(result);
233 }
234 }
235 Ok(T::default())
236 }
237 _ => Err(DissolveError::invalid_input(
238 "Only module AST nodes are supported",
239 )),
240 }
241 }
242}
243
244#[cfg(test)]
245mod tests {
246 use super::*;
247 use rustpython_parser::{parse, Mode};
248
249 #[test]
250 fn test_visitor_context() {
251 let module_name = ModuleName::new("test_module");
252 let mut context = VisitorContext::new(module_name.clone(), "test.py".to_string());
253
254 assert_eq!(context.current_context(), "test_module");
255
256 context.enter_class("TestClass");
257 assert_eq!(context.current_context(), "test_module.TestClass");
258
259 context.exit_class();
260 assert_eq!(context.current_context(), "test_module");
261 }
262
263 #[test]
264 fn test_ast_helpers() {
265 let source = r#"
266@decorator
267def test_func():
268 pass
269"#;
270
271 let parsed = parse(source, Mode::Module, "<test>").unwrap();
272 if let ast::Mod::Module(module) = parsed {
273 if let Some(ast::Stmt::FunctionDef(func)) = module.body.first() {
274 let decorators = ast_helpers::extract_decorator_names(&func.decorator_list);
275 assert_eq!(decorators, vec!["decorator"]);
276 assert!(ast_helpers::has_decorator(
277 &func.decorator_list,
278 "decorator"
279 ));
280 }
281 }
282 }
283}