// dissolve-python 0.3.0
//
// A tool to dissolve deprecated calls in Python codebases
// Documentation
// Copyright (C) 2024 Jelmer Vernooij <jelmer@samba.org>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Unified AST visitor traits and implementations

use crate::domain_types::{FunctionName, ModuleName, QualifiedName};
use crate::error::{DissolveError, Result};
use rustpython_ast as ast;
// Removed unused import

/// Trait for visiting AST nodes and collecting information
///
/// `T` is the type of the value produced by [`AstVisitor::visit_module`]. The
/// per-node hooks return `Result<()>`, so they can only signal failure; any
/// data they gather has to live in the implementor's own state.
pub trait AstVisitor<T> {
    /// Visit a module and return collected information
    ///
    /// # Errors
    ///
    /// Failures are reported through the crate's `Result` type.
    fn visit_module(&mut self, module: &ast::Mod) -> Result<T>;

    /// Visit a function definition
    ///
    /// Takes `&mut self` so the implementor can update its state.
    fn visit_function_def(&mut self, func: &ast::StmtFunctionDef) -> Result<()>;

    /// Visit a class definition
    ///
    /// Takes `&mut self` so the implementor can update its state.
    fn visit_class_def(&mut self, class: &ast::StmtClassDef) -> Result<()>;

    /// Visit a function call
    ///
    /// Takes `&mut self` so the implementor can update its state.
    fn visit_call(&mut self, call: &ast::ExprCall) -> Result<()>;

    /// Get the current module name
    ///
    /// Returns a borrow; callers that need an owned value clone it themselves.
    fn module_name(&self) -> &ModuleName;
}

/// Trait for transforming AST nodes
pub trait AstTransformer {
    /// Transform a function call expression
    ///
    /// NOTE(review): the `Option<String>` payload is presumably the replacement
    /// source text, with `None` meaning "leave the call unchanged" — confirm
    /// against the implementors, which are not visible from this file.
    ///
    /// # Errors
    ///
    /// Failures are reported through the crate's `Result` type.
    fn transform_call(&mut self, call: &ast::ExprCall) -> Result<Option<String>>;

    /// Check if a function should be transformed
    ///
    /// Returns `true` when the function identified by `qualified_name` should
    /// be transformed. Takes `&self`, so the check cannot mutate the
    /// transformer.
    fn should_transform(&self, qualified_name: &QualifiedName) -> bool;
}

/// Common functionality for AST visitors
///
/// Tracks which module is being visited and, between calls to
/// [`VisitorContext::enter_class`] and [`VisitorContext::exit_class`], which
/// class the traversal is currently inside.
pub struct VisitorContext {
    /// Name of the module being visited; used as the prefix of every context
    /// string and qualified name built by this type.
    pub module_name: ModuleName,
    /// File path handed to [`VisitorContext::new`]. Presumably the source file
    /// of the module; it is stored but not read by the methods in this file.
    pub file_path: String,
    /// Class name most recently passed to `enter_class`, or `None` at
    /// construction and after any `exit_class` call.
    pub current_class: Option<String>,
    /// Incremented by `enter_class` and decremented (never below zero) by
    /// `exit_class`.
    pub nested_level: usize,
}

impl VisitorContext {
    /// Build a context for `module_name` / `file_path` that starts outside of
    /// any class, at nesting level zero.
    pub fn new(module_name: ModuleName, file_path: String) -> Self {
        let current_class = None;
        let nested_level = 0;
        Self {
            module_name,
            file_path,
            current_class,
            nested_level,
        }
    }

    /// Record that the traversal entered the class `class_name`.
    ///
    /// Any previously recorded class name is overwritten and the nesting
    /// level goes up by one.
    pub fn enter_class(&mut self, class_name: &str) {
        self.current_class = Some(class_name.to_owned());
        self.nested_level += 1;
    }

    /// Record that the traversal left a class.
    ///
    /// The nesting level goes down by one but never drops below zero, so an
    /// unbalanced call is harmless.
    pub fn exit_class(&mut self) {
        // NOTE(review): the class name is cleared unconditionally, so after
        // leaving an inner class the enclosing class name is not restored
        // even though `nested_level` is still non-zero.
        self.current_class = None;
        self.nested_level = self.nested_level.saturating_sub(1);
    }

    /// Dotted prefix for names defined at the current position:
    /// `module.Class` while a class is recorded, otherwise just `module`.
    pub fn current_context(&self) -> String {
        if let Some(class) = &self.current_class {
            format!("{}.{}", self.module_name, class)
        } else {
            self.module_name.to_string()
        }
    }

    /// Build the qualified name of `function_name` at the current position.
    ///
    /// The dotted string `<current context>.<function>` is parsed with
    /// `QualifiedName::from_string`; if parsing fails, the name is assembled
    /// from the module name and function name alone (dropping the class).
    pub fn qualify_function(&self, function_name: &FunctionName) -> QualifiedName {
        let dotted = format!("{}.{}", self.current_context(), function_name.as_str());

        match QualifiedName::from_string(&dotted) {
            Ok(qualified) => qualified,
            Err(_) => QualifiedName::new(self.module_name.clone(), function_name.clone()),
        }
    }
}

/// Helper functions for AST traversal
pub mod ast_helpers {
    use super::*;
    use rustpython_ast as ast;

    /// Extract function name from various expression types
    ///
    /// A bare name `foo` yields `"foo"`; an attribute chain `a.b.c` yields
    /// `"a.b.c"`. Any other expression — including an attribute whose base is
    /// not itself a name or attribute chain, such as `f().g` — yields `None`.
    pub fn extract_function_name(expr: &ast::Expr) -> Option<String> {
        match expr {
            ast::Expr::Name(name) => Some(name.id.to_string()),
            ast::Expr::Attribute(attr) => {
                // `?` propagates `None` when the base cannot be named.
                let base = extract_function_name(&attr.value)?;
                Some(format!("{}.{}", base, attr.attr))
            }
            _ => None,
        }
    }

    /// Check if an expression is a simple name (not a complex expression)
    pub fn is_simple_name(expr: &ast::Expr) -> bool {
        matches!(expr, ast::Expr::Name(_))
    }

    /// Extract decorator names from a function definition
    ///
    /// Recognises `@name` and `@name(...)`. Every other decorator shape,
    /// including dotted ones such as `@pkg.name`, is silently skipped, so the
    /// result may be shorter than `decorators`.
    pub fn extract_decorator_names(decorators: &[ast::Expr]) -> Vec<String> {
        decorators
            .iter()
            .filter_map(|dec| match dec {
                ast::Expr::Name(name) => Some(name.id.to_string()),
                ast::Expr::Call(call) => match &*call.func {
                    ast::Expr::Name(name) => Some(name.id.to_string()),
                    _ => None,
                },
                _ => None,
            })
            .collect()
    }

    /// Check if a function has a specific decorator
    ///
    /// Only the decorator shapes understood by [`extract_decorator_names`]
    /// can match.
    pub fn has_decorator(decorators: &[ast::Expr], decorator_name: &str) -> bool {
        // Compare as `&str` instead of allocating a `String` for the needle.
        extract_decorator_names(decorators)
            .iter()
            .any(|name| name.as_str() == decorator_name)
    }

    /// Extract string literal value
    ///
    /// Returns the contents of a string constant and `None` for every other
    /// expression, including non-string constants.
    pub fn extract_string_literal(expr: &ast::Expr) -> Option<String> {
        match expr {
            ast::Expr::Constant(constant) => match &constant.value {
                ast::Constant::Str(s) => Some(s.to_string()),
                _ => None,
            },
            _ => None,
        }
    }

    /// Walk through all statements in a module recursively
    ///
    /// `callback` is invoked on each statement before the statements nested
    /// inside it (pre-order). Nested bodies are followed for function and
    /// class definitions, `if`, `while`, `for`/`async for`,
    /// `with`/`async with`, `match` and `try`/`try*`. For `try` statements the
    /// order is body, `else`, `finally`, then the handlers.
    ///
    /// # Errors
    ///
    /// The first error returned by `callback` aborts the walk and is returned
    /// unchanged.
    pub fn walk_statements<F>(statements: &[ast::Stmt], mut callback: F) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        // Recurse through a helper that borrows the callback so the generic
        // parameter stays `F` at every depth. Recursing into
        // `walk_statements(.., &mut callback)` would instantiate
        // `walk_statements::<&mut F>`, `<&mut &mut F>`, ... without end and
        // fail with a recursion-limit error as soon as the function is called.
        walk_statements_inner(statements, &mut callback)
    }

    /// Recursive worker behind [`walk_statements`]; borrows the callback.
    fn walk_statements_inner<F>(statements: &[ast::Stmt], callback: &mut F) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        for stmt in statements {
            callback(stmt)?;

            // Recursively visit nested statements
            match stmt {
                ast::Stmt::FunctionDef(func) => {
                    walk_statements_inner(&func.body, callback)?;
                }
                ast::Stmt::AsyncFunctionDef(func) => {
                    walk_statements_inner(&func.body, callback)?;
                }
                ast::Stmt::ClassDef(class) => {
                    walk_statements_inner(&class.body, callback)?;
                }
                ast::Stmt::If(if_stmt) => {
                    walk_statements_inner(&if_stmt.body, callback)?;
                    walk_statements_inner(&if_stmt.orelse, callback)?;
                }
                ast::Stmt::While(while_stmt) => {
                    walk_statements_inner(&while_stmt.body, callback)?;
                    walk_statements_inner(&while_stmt.orelse, callback)?;
                }
                ast::Stmt::For(for_stmt) => {
                    walk_statements_inner(&for_stmt.body, callback)?;
                    walk_statements_inner(&for_stmt.orelse, callback)?;
                }
                ast::Stmt::AsyncFor(for_stmt) => {
                    walk_statements_inner(&for_stmt.body, callback)?;
                    walk_statements_inner(&for_stmt.orelse, callback)?;
                }
                ast::Stmt::With(with_stmt) => {
                    walk_statements_inner(&with_stmt.body, callback)?;
                }
                ast::Stmt::AsyncWith(with_stmt) => {
                    walk_statements_inner(&with_stmt.body, callback)?;
                }
                ast::Stmt::Match(match_stmt) => {
                    for case in &match_stmt.cases {
                        walk_statements_inner(&case.body, callback)?;
                    }
                }
                ast::Stmt::Try(try_stmt) => {
                    walk_try_blocks(
                        &try_stmt.body,
                        &try_stmt.orelse,
                        &try_stmt.finalbody,
                        &try_stmt.handlers,
                        callback,
                    )?;
                }
                ast::Stmt::TryStar(try_stmt) => {
                    walk_try_blocks(
                        &try_stmt.body,
                        &try_stmt.orelse,
                        &try_stmt.finalbody,
                        &try_stmt.handlers,
                        callback,
                    )?;
                }
                // Simple statements have no nested statement lists.
                _ => {}
            }
        }
        Ok(())
    }

    /// Walk the four statement lists of a `try` / `try*` statement.
    fn walk_try_blocks<F>(
        body: &[ast::Stmt],
        orelse: &[ast::Stmt],
        finalbody: &[ast::Stmt],
        handlers: &[ast::ExceptHandler],
        callback: &mut F,
    ) -> Result<()>
    where
        F: FnMut(&ast::Stmt) -> Result<()>,
    {
        walk_statements_inner(body, callback)?;
        walk_statements_inner(orelse, callback)?;
        walk_statements_inner(finalbody, callback)?;
        for handler in handlers {
            // `ExceptHandler` has a single variant, so the pattern is irrefutable.
            let ast::ExceptHandler::ExceptHandler(exc) = handler;
            walk_statements_inner(&exc.body, callback)?;
        }
        Ok(())
    }
}

/// Base implementation of common visitor patterns
///
/// Wraps a [`VisitorContext`] and offers a generic top-level traversal via
/// `traverse_module`.
pub struct BaseVisitor {
    /// Module / class position tracking; public so the closure passed to
    /// `traverse_module` can read and update it through `&mut Self`.
    pub context: VisitorContext,
}

impl BaseVisitor {
    /// Create a visitor with a fresh [`VisitorContext`] for the given module
    /// name and file path.
    pub fn new(module_name: ModuleName, file_path: String) -> Self {
        let context = VisitorContext::new(module_name, file_path);
        Self { context }
    }

    /// Generic module traversal
    ///
    /// Calls `visitor_fn` on each top-level statement of `module`, in order.
    /// The first `Some(value)` it produces ends the traversal and is returned;
    /// if every call yields `None`, `T::default()` is returned. Nested
    /// statements are not visited here.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `module` is not `ast::Mod::Module`,
    /// and propagates the first error returned by `visitor_fn`.
    pub fn traverse_module<T, F>(&mut self, module: &ast::Mod, mut visitor_fn: F) -> Result<T>
    where
        F: FnMut(&mut Self, &ast::Stmt) -> Result<Option<T>>,
        T: Default,
    {
        // Reject every non-module node up front so the loop below stays flat.
        let top_level = match module {
            ast::Mod::Module(inner) => &inner.body,
            _ => {
                return Err(DissolveError::invalid_input(
                    "Only module AST nodes are supported",
                ));
            }
        };

        for statement in top_level {
            match visitor_fn(self, statement)? {
                Some(found) => return Ok(found),
                None => continue,
            }
        }

        Ok(T::default())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rustpython_parser::{parse, Mode};

    #[test]
    fn test_visitor_context() {
        let module_name = ModuleName::new("test_module");
        let mut context = VisitorContext::new(module_name.clone(), "test.py".to_string());

        assert_eq!(context.current_context(), "test_module");

        context.enter_class("TestClass");
        assert_eq!(context.current_context(), "test_module.TestClass");

        context.exit_class();
        assert_eq!(context.current_context(), "test_module");
    }

    #[test]
    fn test_ast_helpers() {
        let source = r#"
@decorator
def test_func():
    pass
"#;

        let parsed = parse(source, Mode::Module, "<test>").unwrap();
        if let ast::Mod::Module(module) = parsed {
            if let Some(ast::Stmt::FunctionDef(func)) = module.body.first() {
                let decorators = ast_helpers::extract_decorator_names(&func.decorator_list);
                assert_eq!(decorators, vec!["decorator"]);
                assert!(ast_helpers::has_decorator(
                    &func.decorator_list,
                    "decorator"
                ));
            }
        }
    }
}