use crate::domain_types::{FunctionName, ModuleName, QualifiedName};
use crate::error::{DissolveError, Result};
use rustpython_ast as ast;
/// Visitor over a parsed Python module (`rustpython_ast` nodes).
///
/// Implementors get one hook per node kind below, plus access to the module
/// name they are working on. Nothing in this trait drives the walk itself;
/// traversal order and recursion are left to the implementation.
pub trait AstVisitor<T> {
/// Visits the whole module and returns the visitor's result of type `T`.
fn visit_module(&mut self, module: &ast::Mod) -> Result<T>;
/// Hook for a `def` statement. `async def` is a distinct node type
/// (`StmtAsyncFunctionDef`) and has no hook in this trait.
fn visit_function_def(&mut self, func: &ast::StmtFunctionDef) -> Result<()>;
/// Hook for a `class` statement.
fn visit_class_def(&mut self, class: &ast::StmtClassDef) -> Result<()>;
/// Hook for a call expression.
fn visit_call(&mut self, call: &ast::ExprCall) -> Result<()>;
/// Name of the module being visited.
fn module_name(&self) -> &ModuleName;
}
/// Produces rewrites for call expressions.
pub trait AstTransformer {
/// Tries to rewrite `call`.
///
/// NOTE(review): presumably `Ok(Some(text))` carries replacement source text
/// and `Ok(None)` means "leave this call unchanged" — confirm with implementors.
fn transform_call(&mut self, call: &ast::ExprCall) -> Result<Option<String>>;
/// Decides whether `qualified_name` is a target for rewriting; the exact
/// policy is up to the implementor.
fn should_transform(&self, qualified_name: &QualifiedName) -> bool;
}
/// Mutable position state for a visitor walking a single module.
pub struct VisitorContext {
/// Module being visited; used as the prefix of every name this context qualifies.
pub module_name: ModuleName,
/// Path of the source file, stored exactly as passed to `VisitorContext::new`.
pub file_path: String,
/// Class context currently being visited, or `None` at module level.
/// Maintained by `VisitorContext::enter_class` / `exit_class`.
pub current_class: Option<String>,
/// Incremented by `enter_class`, decremented (never below zero) by `exit_class`.
pub nested_level: usize,
}
impl VisitorContext {
    /// Creates a context positioned at module level (no enclosing class).
    ///
    /// * `module_name` – prefix of every qualified name this context produces.
    /// * `file_path` – path of the source file, stored verbatim.
    pub fn new(module_name: ModuleName, file_path: String) -> Self {
        Self {
            module_name,
            file_path,
            current_class: None,
            nested_level: 0,
        }
    }

    /// Records entry into the body of class `class_name`.
    ///
    /// Nested classes are kept as a dotted path in `current_class`: entering
    /// `Outer` and then `Inner` yields `Some("Outer.Inner")`. Leaving `Inner`
    /// can therefore restore `Outer` instead of discarding the class context.
    pub fn enter_class(&mut self, class_name: &str) {
        self.current_class = Some(match self.current_class.take() {
            Some(outer) => format!("{}.{}", outer, class_name),
            None => class_name.to_string(),
        });
        self.nested_level += 1;
    }

    /// Records exit from the innermost class entered with
    /// [`enter_class`](Self::enter_class).
    ///
    /// Drops the last segment of the dotted class path. This restores the
    /// enclosing class, or `None` when leaving the outermost class.
    /// `nested_level` saturates at zero, so an unbalanced call cannot underflow.
    pub fn exit_class(&mut self) {
        self.current_class = self
            .current_class
            .take()
            .and_then(|path| path.rsplit_once('.').map(|(outer, _)| outer.to_string()));
        self.nested_level = self.nested_level.saturating_sub(1);
    }

    /// Returns the dotted scope currently being visited: `module`, or
    /// `module.Class[.Nested...]` while inside a class body.
    pub fn current_context(&self) -> String {
        match &self.current_class {
            Some(class) => format!("{}.{}", self.module_name, class),
            None => self.module_name.to_string(),
        }
    }

    /// Qualifies `function_name` with the current scope, e.g. `module.Class.func`.
    ///
    /// If `QualifiedName::from_string` rejects the dotted string, this falls back
    /// to `QualifiedName::new(module_name, function_name)`. In that case the
    /// class part of the scope is dropped.
    pub fn qualify_function(&self, function_name: &FunctionName) -> QualifiedName {
        let dotted = format!("{}.{}", self.current_context(), function_name.as_str());
        QualifiedName::from_string(&dotted)
            .unwrap_or_else(|_| QualifiedName::new(self.module_name.clone(), function_name.clone()))
    }
}
/// Free-standing helpers for inspecting `rustpython_ast` nodes.
pub mod ast_helpers {
    use super::*;
    use rustpython_ast as ast;

    /// Returns the dotted name of a callee-like expression.
    ///
    /// `foo` yields `"foo"` and `a.b.c` yields `"a.b.c"`. Returns `None` if any
    /// link of the chain is something other than a name or an attribute access
    /// (for example `f().g` or `x[0].y`).
    pub fn extract_function_name(expr: &ast::Expr) -> Option<String> {
        match expr {
            ast::Expr::Name(name) => Some(name.id.to_string()),
            ast::Expr::Attribute(attr) => {
                let base = extract_function_name(&attr.value)?;
                Some(format!("{}.{}", base, attr.attr))
            }
            _ => None,
        }
    }

    /// Returns `true` if `expr` is a bare identifier such as `foo`.
    pub fn is_simple_name(expr: &ast::Expr) -> bool {
        matches!(expr, ast::Expr::Name(_))
    }

    /// Collects decorator names in source order.
    ///
    /// Only `@name` and `@name(...)` are recognised. Attribute decorators
    /// (`@mod.name`, `@mod.name(...)`) and any other expression are skipped.
    pub fn extract_decorator_names(decorators: &[ast::Expr]) -> Vec<String> {
        decorators.iter().filter_map(simple_decorator_name).collect()
    }

    /// Returns `true` if any decorator is named `decorator_name`, using the same
    /// matching rules as [`extract_decorator_names`].
    pub fn has_decorator(decorators: &[ast::Expr], decorator_name: &str) -> bool {
        // Compare lazily instead of materialising a Vec and an owned needle.
        decorators
            .iter()
            .filter_map(simple_decorator_name)
            .any(|name| name == decorator_name)
    }

    /// Name of a single `@name` / `@name(...)` decorator, or `None` for any other form.
    fn simple_decorator_name(decorator: &ast::Expr) -> Option<String> {
        // Unwrap at most one call layer: `@name(...)` -> `name`.
        let target = match decorator {
            ast::Expr::Call(call) => &*call.func,
            other => other,
        };
        match target {
            ast::Expr::Name(name) => Some(name.id.to_string()),
            _ => None,
        }
    }

    /// Returns the value of a string constant (`"..."`), or `None` for any
    /// other expression, including non-string constants.
    pub fn extract_string_literal(expr: &ast::Expr) -> Option<String> {
        match expr {
            ast::Expr::Constant(constant) => match &constant.value {
                ast::Constant::Str(s) => Some(s.to_string()),
                _ => None,
            },
            _ => None,
        }
    }

    /// Calls `callback` on every statement in `statements` and in all nested
    /// statement bodies, depth-first and in source order.
    ///
    /// A compound statement is passed to `callback` before any statement in its
    /// bodies. The walk descends into:
    /// - function, async function and class bodies;
    /// - `if` / `while` / `for` / `async for` bodies and their `else` blocks;
    /// - `with` / `async with` bodies;
    /// - `try` / `try*` bodies, every `except` handler, `else` and `finally`;
    /// - every `case` of a `match`.
    ///
    /// Expressions (lambdas, comprehensions) are not entered.
    ///
    /// # Errors
    /// Stops at, and returns, the first error produced by `callback`.
    pub fn walk_statements<F>(statements: &[ast::Stmt], mut callback: F) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        walk_statements_with(statements, &mut callback)
    }

    /// Recursive worker for [`walk_statements`].
    ///
    /// Takes the callback as `&mut F` and recurses with that same `F`.
    /// Recursing through the public entry point with `&mut callback` would
    /// instead instantiate `walk_statements::<&mut F>`, then
    /// `::<&mut &mut F>`, and so on, hitting the compiler's instantiation
    /// recursion limit.
    fn walk_statements_with<F>(statements: &[ast::Stmt], callback: &mut F) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        for stmt in statements {
            callback(stmt)?;
            match stmt {
                ast::Stmt::FunctionDef(func) => walk_statements_with(&func.body, callback)?,
                ast::Stmt::AsyncFunctionDef(func) => walk_statements_with(&func.body, callback)?,
                ast::Stmt::ClassDef(class) => walk_statements_with(&class.body, callback)?,
                ast::Stmt::If(if_stmt) => {
                    walk_statements_with(&if_stmt.body, callback)?;
                    walk_statements_with(&if_stmt.orelse, callback)?;
                }
                ast::Stmt::While(while_stmt) => {
                    walk_statements_with(&while_stmt.body, callback)?;
                    walk_statements_with(&while_stmt.orelse, callback)?;
                }
                ast::Stmt::For(for_stmt) => {
                    walk_statements_with(&for_stmt.body, callback)?;
                    walk_statements_with(&for_stmt.orelse, callback)?;
                }
                ast::Stmt::AsyncFor(for_stmt) => {
                    walk_statements_with(&for_stmt.body, callback)?;
                    walk_statements_with(&for_stmt.orelse, callback)?;
                }
                ast::Stmt::With(with_stmt) => walk_statements_with(&with_stmt.body, callback)?,
                ast::Stmt::AsyncWith(with_stmt) => {
                    walk_statements_with(&with_stmt.body, callback)?
                }
                ast::Stmt::Match(match_stmt) => {
                    for case in &match_stmt.cases {
                        walk_statements_with(&case.body, callback)?;
                    }
                }
                ast::Stmt::Try(try_stmt) => walk_try(
                    &try_stmt.body,
                    &try_stmt.handlers,
                    &try_stmt.orelse,
                    &try_stmt.finalbody,
                    callback,
                )?,
                ast::Stmt::TryStar(try_stmt) => walk_try(
                    &try_stmt.body,
                    &try_stmt.handlers,
                    &try_stmt.orelse,
                    &try_stmt.finalbody,
                    callback,
                )?,
                _ => {}
            }
        }
        Ok(())
    }

    /// Walks the parts of a `try` / `try*` statement in source order:
    /// body, each `except` handler, `else`, then `finally`.
    fn walk_try<F>(
        body: &[ast::Stmt],
        handlers: &[ast::ExceptHandler],
        orelse: &[ast::Stmt],
        finalbody: &[ast::Stmt],
        callback: &mut F,
    ) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        walk_statements_with(body, callback)?;
        for handler in handlers {
            match handler {
                ast::ExceptHandler::ExceptHandler(exc) => {
                    walk_statements_with(&exc.body, callback)?
                }
            }
        }
        walk_statements_with(orelse, callback)?;
        walk_statements_with(finalbody, callback)
    }
}
/// Shared state for concrete visitors; currently just a [`VisitorContext`].
pub struct BaseVisitor {
/// Current module/class position of the visitor.
pub context: VisitorContext,
}
impl BaseVisitor {
    /// Builds a visitor whose context starts at module level for `module_name`,
    /// with `file_path` stored in the context.
    pub fn new(module_name: ModuleName, file_path: String) -> Self {
        let context = VisitorContext::new(module_name, file_path);
        Self { context }
    }

    /// Feeds each top-level statement of `module` to `visitor_fn`, in order.
    ///
    /// The first `Some(value)` returned by `visitor_fn` stops the scan and
    /// becomes the result. If every call returns `None`, `T::default()` is
    /// returned. Only direct children of the module are visited; nested bodies
    /// are not entered.
    ///
    /// # Errors
    /// Propagates the first error from `visitor_fn`. Returns an invalid-input
    /// error when `module` is any `ast::Mod` variant other than `Module`.
    pub fn traverse_module<T, F>(&mut self, module: &ast::Mod, mut visitor_fn: F) -> Result<T>
    where
        F: FnMut(&mut Self, &ast::Stmt) -> Result<Option<T>>,
        T: Default,
    {
        let statements = match module {
            ast::Mod::Module(parsed) => &parsed.body,
            _ => {
                return Err(DissolveError::invalid_input(
                    "Only module AST nodes are supported",
                ))
            }
        };
        for statement in statements.iter() {
            let outcome = visitor_fn(self, statement)?;
            if let Some(found) = outcome {
                return Ok(found);
            }
        }
        Ok(T::default())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rustpython_parser::{parse, Mode};
#[test]
fn test_visitor_context() {
let module_name = ModuleName::new("test_module");
let mut context = VisitorContext::new(module_name.clone(), "test.py".to_string());
assert_eq!(context.current_context(), "test_module");
context.enter_class("TestClass");
assert_eq!(context.current_context(), "test_module.TestClass");
context.exit_class();
assert_eq!(context.current_context(), "test_module");
}
#[test]
fn test_ast_helpers() {
let source = r#"
@decorator
def test_func():
pass
"#;
let parsed = parse(source, Mode::Module, "<test>").unwrap();
if let ast::Mod::Module(module) = parsed {
if let Some(ast::Stmt::FunctionDef(func)) = module.body.first() {
let decorators = ast_helpers::extract_decorator_names(&func.decorator_list);
assert_eq!(decorators, vec!["decorator"]);
assert!(ast_helpers::has_decorator(
&func.decorator_list,
"decorator"
));
}
}
}
}