use anyhow::Result;
use rustpython_ast::{self as ast};
use rustpython_parser::{parse, Mode};
use std::fs;
/// Removes `@replace_me`-decorated definitions from Python `source`.
///
/// Top-level statements are checked with `should_remove_statement`, and the
/// members of classes are checked through `collect_removable_statements`.
/// Every match that can be located in the text has its line range dropped
/// from the output.
///
/// * `before_version` – remove items whose `since` version is lower than this.
/// * `remove_all` – remove every `replace_me` item regardless of versions.
/// * `current_version` – remove items whose `before_version` / `remove_in`
///   version has been reached by this version.
///
/// Returns the number of statements selected for removal together with the
/// rewritten source. Lines that are kept retain their original line
/// terminators, so a source from which nothing is removed comes back
/// byte-for-byte identical (including the final newline and CRLF endings).
/// When no criterion is supplied the source is returned without being parsed.
///
/// # Errors
///
/// Fails when `source` cannot be parsed as a Python module.
pub fn remove_decorators(
    source: &str,
    before_version: Option<&str>,
    remove_all: bool,
    current_version: Option<&str>,
) -> Result<(usize, String)> {
    if !remove_all && before_version.is_none() && current_version.is_none() {
        return Ok((0, source.to_string()));
    }

    let parsed = parse(source, Mode::Module, "<module>")?;
    let mut lines_to_remove: Vec<(usize, usize)> = Vec::new();
    let mut removed_count = 0;

    if let ast::Mod::Module(module) = parsed {
        // Pass 1: decorated definitions at module level.
        for (i, stmt) in module.body.iter().enumerate() {
            if should_remove_statement(stmt, before_version, remove_all, current_version) {
                removed_count += 1;
                if let Some(line_range) = find_statement_lines(source, i, &module.body) {
                    lines_to_remove.push(line_range);
                }
            }
        }
        // Pass 2: decorated members nested inside classes.
        for stmt in &module.body {
            removed_count += collect_removable_statements(
                stmt,
                source,
                before_version,
                remove_all,
                current_version,
                &mut lines_to_remove,
            );
        }
    }

    // `split_inclusive` yields the same line indices as `str::lines` (which the
    // range finders use) but keeps each line's terminator attached, so the
    // surviving lines can be concatenated without altering line endings.
    let segments: Vec<&str> = source.split_inclusive('\n').collect();
    let mut dropped = vec![false; segments.len()];
    for &(start, end) in &lines_to_remove {
        // Ranges are half-open `[start, end)`; `take` clamps to the line count.
        for flag in dropped.iter_mut().take(end).skip(start) {
            *flag = true;
        }
    }

    let result: String = segments
        .iter()
        .zip(&dropped)
        .filter(|(_, is_dropped)| !**is_dropped)
        .map(|(segment, _)| *segment)
        .collect();

    Ok((removed_count, result))
}
/// Collects removable members nested inside a class definition.
///
/// When `stmt` is a class, each member of its body that
/// `should_remove_statement` accepts is counted and, if `find_method_lines`
/// can locate it, its line range is appended to `lines_to_remove`. Every
/// member is then visited recursively so that nested classes are covered too.
/// Any other kind of statement contributes nothing.
///
/// Returns the number of members selected for removal.
fn collect_removable_statements(
    stmt: &ast::Stmt,
    source: &str,
    before_version: Option<&str>,
    remove_all: bool,
    current_version: Option<&str>,
    lines_to_remove: &mut Vec<(usize, usize)>,
) -> usize {
    let class = match stmt {
        ast::Stmt::ClassDef(class) => class,
        _ => return 0,
    };

    let mut total = 0;
    for (position, member) in class.body.iter().enumerate() {
        let removable =
            should_remove_statement(member, before_version, remove_all, current_version);
        if removable {
            total += 1;
            if let Some(span) = find_method_lines(source, class, position) {
                lines_to_remove.push(span);
            }
        }
        // Descend regardless of whether the member itself was selected.
        total += collect_removable_statements(
            member,
            source,
            before_version,
            remove_all,
            current_version,
            lines_to_remove,
        );
    }
    total
}
/// Decides whether `stmt` is a definition that should be removed.
///
/// Only function, async function and class definitions carry decorators;
/// their decorator lists are handed to `has_replace_me_decorator`. All other
/// statement kinds are never removed.
fn should_remove_statement(
    stmt: &ast::Stmt,
    before_version: Option<&str>,
    remove_all: bool,
    current_version: Option<&str>,
) -> bool {
    let decorators = match stmt {
        ast::Stmt::FunctionDef(def) => &def.decorator_list,
        ast::Stmt::AsyncFunctionDef(def) => &def.decorator_list,
        ast::Stmt::ClassDef(def) => &def.decorator_list,
        _ => return false,
    };
    has_replace_me_decorator(decorators, before_version, remove_all, current_version)
}
/// Reports whether any decorator in `decorators` is a `replace_me` marker
/// that satisfies one of the removal criteria.
///
/// * A bare `@replace_me` matches only when `remove_all` is set.
/// * A call `@replace_me(...)` matches when `remove_all` is set, or when
///   - `before_version` is given and the decorator's `since` is lower, or
///   - `current_version` is given and it has reached the decorator's
///     `before_version` or `remove_in` keyword.
///
/// Only the plain name `replace_me` is recognised; other decorator
/// expressions are ignored.
fn has_replace_me_decorator(
    decorators: &[ast::Expr],
    before_version: Option<&str>,
    remove_all: bool,
    current_version: Option<&str>,
) -> bool {
    decorators.iter().any(|decorator| match decorator {
        ast::Expr::Name(name) => remove_all && name.id.as_str() == "replace_me",
        ast::Expr::Call(call) => {
            let targets_replace_me = matches!(
                &*call.func,
                ast::Expr::Name(callee) if callee.id.as_str() == "replace_me"
            );
            if !targets_replace_me {
                return false;
            }
            if remove_all {
                return true;
            }

            let keywords = &call.keywords;

            // Criterion 1: deprecated since a version older than the cutoff.
            let older_than_cutoff = before_version.map_or(false, |cutoff| {
                extract_since_version(keywords)
                    .map_or(false, |since| compare_versions(&since, cutoff) < 0)
            });
            if older_than_cutoff {
                return true;
            }

            // Criteria 2 and 3: the current version has reached a deadline
            // recorded on the decorator itself.
            current_version.map_or(false, |current| {
                let reached = |deadline: Option<String>| {
                    deadline.map_or(false, |limit| compare_versions(current, &limit) >= 0)
                };
                reached(extract_before_version(keywords))
                    || reached(extract_remove_in_version(keywords))
            })
        }
        _ => false,
    })
}
/// Returns the string value of the first `since="..."` keyword argument, if
/// one is present. `since` keywords whose value is not a string constant are
/// skipped.
fn extract_since_version(keywords: &[ast::Keyword]) -> Option<String> {
    keywords.iter().find_map(|kw| {
        let label = kw.arg.as_ref()?;
        if label.as_str() != "since" {
            return None;
        }
        match &kw.value {
            ast::Expr::Constant(constant) => match &constant.value {
                ast::Constant::Str(text) => Some(text.to_string()),
                _ => None,
            },
            _ => None,
        }
    })
}
/// Returns the string value of the first `before_version="..."` keyword
/// argument, if one is present. `before_version` keywords whose value is not
/// a string constant are skipped.
fn extract_before_version(keywords: &[ast::Keyword]) -> Option<String> {
    keywords
        .iter()
        .filter(|kw| {
            kw.arg
                .as_ref()
                .map_or(false, |label| label.as_str() == "before_version")
        })
        .find_map(|kw| match &kw.value {
            ast::Expr::Constant(constant) => match &constant.value {
                ast::Constant::Str(text) => Some(text.to_string()),
                _ => None,
            },
            _ => None,
        })
}
/// Returns the string value of the first `remove_in="..."` keyword argument,
/// if one is present. `remove_in` keywords whose value is not a string
/// constant are skipped.
fn extract_remove_in_version(keywords: &[ast::Keyword]) -> Option<String> {
    for kw in keywords {
        let is_remove_in = match &kw.arg {
            Some(label) => label.as_str() == "remove_in",
            None => false,
        };
        if !is_remove_in {
            continue;
        }
        if let ast::Expr::Constant(constant) = &kw.value {
            if let ast::Constant::Str(text) = &constant.value {
                return Some(text.to_string());
            }
        }
    }
    None
}
/// Three-way comparison of two version strings.
///
/// Returns `-1`, `0` or `1` when `v1` is respectively lower than, equal to or
/// greater than `v2`. Both strings are parsed as `Version`; if either fails
/// to parse, the raw strings are compared lexicographically instead.
fn compare_versions(v1: &str, v2: &str) -> i32 {
    use crate::core::types::Version;
    use std::cmp::Ordering;

    let ordering = match (v1.parse::<Version>(), v2.parse::<Version>()) {
        (Ok(left), Ok(right)) => left.cmp(&right),
        // Fallback for strings that are not valid versions.
        _ => v1.cmp(v2),
    };

    match ordering {
        Ordering::Less => -1,
        Ordering::Equal => 0,
        Ordering::Greater => 1,
    }
}
/// Locates the source lines occupied by `stmts[stmt_index]`.
///
/// Handles `def`, `async def` and `class` statements. The header is found by
/// scanning `source` for the first line that starts (after indentation) with
/// the matching keyword followed by exactly the statement's name, so
/// `def foo` is not confused with `def foo_bar`. The statement ends at the
/// next non-blank, non-comment line indented no deeper than the header, or
/// at the end of the file.
///
/// Returns a half-open `(start, end)` range of 0-based line indices, where
/// `start` includes the decorators, blank lines and comments directly above
/// the header (see `find_decorator_start`). Returns `None` for other
/// statement kinds, for an out-of-range index, or when no header line is
/// found.
///
/// This is a textual heuristic: a header spread over several lines, or a
/// multi-line string with text at the header's indentation, ends the range
/// early.
fn find_statement_lines(
    source: &str,
    stmt_index: usize,
    stmts: &[ast::Stmt],
) -> Option<(usize, usize)> {
    const DEF: &[&str] = &["def"];
    const ASYNC_DEF: &[&str] = &["async", "def"];
    const CLASS: &[&str] = &["class"];

    /// Number of leading whitespace characters on `line`.
    fn indent_of(line: &str) -> usize {
        line.chars().take_while(|c| c.is_whitespace()).count()
    }

    /// True when `line` begins with the given keyword sequence followed by
    /// `name` as a whole identifier.
    fn declares(line: &str, keywords: &[&str], name: &str) -> bool {
        let mut rest = line.trim_start();
        for word in keywords {
            rest = match rest.strip_prefix(*word) {
                Some(after) if after.starts_with(char::is_whitespace) => after.trim_start(),
                _ => return false,
            };
        }
        match rest.strip_prefix(name) {
            // Reject longer identifiers that merely share the prefix.
            Some(tail) => !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        }
    }

    let (keywords, name) = match stmts.get(stmt_index)? {
        ast::Stmt::FunctionDef(func) => (DEF, func.name.as_str()),
        ast::Stmt::AsyncFunctionDef(func) => (ASYNC_DEF, func.name.as_str()),
        ast::Stmt::ClassDef(class) => (CLASS, class.name.as_str()),
        _ => return None,
    };

    let lines: Vec<&str> = source.lines().collect();
    let header = lines
        .iter()
        .position(|line| declares(line, keywords, name))?;
    let header_indent = indent_of(lines[header]);

    // First following line that belongs to the enclosing scope again.
    let end = lines[header + 1..]
        .iter()
        .position(|line| {
            let text = line.trim_start();
            !text.is_empty() && !text.starts_with('#') && indent_of(line) <= header_indent
        })
        .map_or(lines.len(), |offset| header + 1 + offset);

    Some((find_decorator_start(&lines, header), end))
}
/// Finds the first line of the block that leads up to a definition.
///
/// Walking upwards from the line just above `def_line`, every contiguous line
/// that is a decorator (`@...`), a comment (`#...`) or blank is considered
/// part of the definition. Returns the index of the topmost such line, or
/// `def_line` itself when the line directly above is ordinary code.
fn find_decorator_start(lines: &[&str], def_line: usize) -> usize {
    let attached = lines[..def_line]
        .iter()
        .rev()
        .take_while(|raw| {
            let text = raw.trim();
            text.is_empty() || text.starts_with('@') || text.starts_with('#')
        })
        .count();
    def_line - attached
}
/// Locates the source lines occupied by `class.body[method_index]`.
///
/// Handles `def` and `async def` members. The class header is the first line
/// that starts (after indentation) with `class <name>` as a whole identifier,
/// so classes with base classes (`class Foo(Base):`) are found as well. The
/// method header is the first later line, indented deeper than the class,
/// that starts with `def <name>` / `async def <name>` as a whole identifier,
/// so `def foo` is not confused with `def foo_bar`.
///
/// The method ends at the next non-blank, non-comment line indented no
/// deeper than its header. If there is none, it ends at the next non-blank
/// line indented no deeper than the class header, or at the end of the file.
///
/// Returns a half-open `(start, end)` range of 0-based line indices, where
/// `start` includes the decorators, blank lines and comments directly above
/// the method (see `find_decorator_start`). Returns `None` for other member
/// kinds, for an out-of-range index, or when the class or method header
/// cannot be found.
///
/// This is a textual heuristic: a header spread over several lines, or a
/// multi-line string with text at the method's indentation, ends the range
/// early.
fn find_method_lines(
    source: &str,
    class: &ast::StmtClassDef,
    method_index: usize,
) -> Option<(usize, usize)> {
    const DEF: &[&str] = &["def"];
    const ASYNC_DEF: &[&str] = &["async", "def"];
    const CLASS: &[&str] = &["class"];

    /// Number of leading whitespace characters on `line`.
    fn indent_of(line: &str) -> usize {
        line.chars().take_while(|c| c.is_whitespace()).count()
    }

    /// True when `line` begins with the given keyword sequence followed by
    /// `name` as a whole identifier.
    fn declares(line: &str, keywords: &[&str], name: &str) -> bool {
        let mut rest = line.trim_start();
        for word in keywords {
            rest = match rest.strip_prefix(*word) {
                Some(after) if after.starts_with(char::is_whitespace) => after.trim_start(),
                _ => return false,
            };
        }
        match rest.strip_prefix(name) {
            // Reject longer identifiers that merely share the prefix.
            Some(tail) => !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        }
    }

    let (keywords, method_name) = match class.body.get(method_index)? {
        ast::Stmt::FunctionDef(method) => (DEF, method.name.as_str()),
        ast::Stmt::AsyncFunctionDef(method) => (ASYNC_DEF, method.name.as_str()),
        _ => return None,
    };

    let lines: Vec<&str> = source.lines().collect();
    let class_name = class.name.as_str();
    let class_start = lines
        .iter()
        .position(|line| declares(line, CLASS, class_name))?;
    let class_indent = indent_of(lines[class_start]);

    // Members of the class body are always indented deeper than its header.
    let def_line = class_start
        + 1
        + lines[class_start + 1..].iter().position(|line| {
            indent_of(line) > class_indent && declares(line, keywords, method_name)
        })?;
    let method_indent = indent_of(lines[def_line]);

    let following = &lines[def_line + 1..];
    let end = following
        .iter()
        .position(|line| {
            let text = line.trim_start();
            !text.is_empty() && !text.starts_with('#') && indent_of(line) <= method_indent
        })
        // No further code: keep any trailing text (comments included) that
        // sits at or outside the class's indentation.
        .or_else(|| {
            following
                .iter()
                .position(|line| !line.trim().is_empty() && indent_of(line) <= class_indent)
        })
        .map_or(lines.len(), |offset| def_line + 1 + offset);

    Some((find_decorator_start(&lines, def_line), end))
}
/// Runs `remove_decorators` on the contents of the file at `file_path`.
///
/// The version and `remove_all` arguments are forwarded unchanged. When
/// `write` is true and the rewritten text differs from what was read, the
/// file is overwritten with the result; otherwise the file is left alone.
///
/// Returns the removal count and the rewritten source.
///
/// # Errors
///
/// Fails when the file cannot be read or written, or when
/// `remove_decorators` fails.
pub fn remove_decorators_from_file(
    file_path: &str,
    before_version: Option<&str>,
    remove_all: bool,
    write: bool,
    current_version: Option<&str>,
) -> Result<(usize, String)> {
    let original = fs::read_to_string(file_path)?;
    let (removed_count, updated) =
        remove_decorators(&original, before_version, remove_all, current_version)?;

    if write {
        // Avoid touching the file when the content is unchanged.
        if updated != original {
            fs::write(file_path, &updated)?;
        }
    }

    Ok((removed_count, updated))
}
/// Alias for `remove_decorators_from_file`.
///
/// Forwards every argument unchanged and returns that function's result.
pub fn remove_from_file(
    file_path: &str,
    before_version: Option<&str>,
    remove_all: bool,
    write: bool,
    current_version: Option<&str>,
) -> Result<(usize, String)> {
    remove_decorators_from_file(file_path, before_version, remove_all, write, current_version)
}
#[cfg(test)]
mod tests {
    use super::*;

    // The Python fixtures below start at column 0 inside the raw strings and
    // indent their function bodies, because the parser rejects a `def` whose
    // body is not indented.

    /// With `remove_all`, every `replace_me`-decorated function is removed
    /// and undecorated functions are kept.
    #[test]
    fn test_remove_all() {
        let source = r#"
from dissolve import replace_me
@replace_me()
def old_function():
    return new_function()
def regular_function():
    return 42
@replace_me(since="1.0.0")
def another_old():
    return new_api()
"#;
        let (count, result) = remove_decorators(source, None, true, None).unwrap();
        assert_eq!(count, 2, "Should remove 2 functions");
        assert!(!result.contains("def old_function"));
        assert!(!result.contains("def another_old"));
        assert!(result.contains("def regular_function"));
    }

    /// Without any removal criterion the source comes back unchanged.
    #[test]
    fn test_no_removal_criteria() {
        let source = r#"
@replace_me()
def old_function():
    return new_function()
"#;
        let (count, result) = remove_decorators(source, None, false, None).unwrap();
        assert_eq!(count, 0, "Should remove 0 functions");
        assert_eq!(result, source);
    }

    /// Only functions whose `since` is lower than the given cutoff version
    /// are removed.
    #[test]
    fn test_remove_before_version() {
        let source = r#"
from dissolve import replace_me
@replace_me(since="1.0.0")
def old_v1():
    return new_v1()
@replace_me(since="2.0.0")
def old_v2():
    return new_v2()
def regular_function():
    return 42
"#;
        let (count, result) = remove_decorators(source, Some("1.5.0"), false, None).unwrap();
        assert_eq!(count, 1, "Should remove 1 function");
        assert!(!result.contains("def old_v1"));
        assert!(result.contains("def old_v2"));
        assert!(result.contains("def regular_function"));
    }
}