Skip to main content

dissolve_python/
remover.rs

1// Copyright (C) 2024 Jelmer Vernooij <jelmer@samba.org>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Remove deprecated functions decorated with @replace_me from source files.
16
17use anyhow::Result;
18use rustpython_ast::{self as ast};
19use rustpython_parser::{parse, Mode};
20use std::fs;
21
22/// Remove @replace_me decorated functions from source code
23///
24/// This removes entire functions decorated with @replace_me,
25/// not just the decorators. This should only be used after migration is complete
26/// and library authors want to clean up their codebase.
27///
28/// Returns a tuple of (removed_count, resulting_source)
29pub fn remove_decorators(
30    source: &str,
31    before_version: Option<&str>,
32    remove_all: bool,
33    current_version: Option<&str>,
34) -> Result<(usize, String)> {
35    if !remove_all && before_version.is_none() && current_version.is_none() {
36        // No removal criteria specified, return source unchanged
37        return Ok((0, source.to_string()));
38    }
39
40    // Parse the source
41    let parsed = parse(source, Mode::Module, "<module>")?;
42
43    let mut lines_to_remove = Vec::new();
44    let mut removed_count = 0;
45
46    if let ast::Mod::Module(module) = parsed {
47        // Check top-level statements
48        for (i, stmt) in module.body.iter().enumerate() {
49            if should_remove_statement(stmt, before_version, remove_all, current_version) {
50                removed_count += 1;
51                // Find the line range for this statement
52                // In a real implementation, we'd use source positions
53                // For now, we'll use a simple approach
54                if let Some(line_range) = find_statement_lines(source, i, &module.body) {
55                    lines_to_remove.push(line_range);
56                }
57            }
58        }
59
60        // Check methods inside classes (including nested classes)
61        for stmt in &module.body {
62            let count = collect_removable_statements(
63                stmt,
64                source,
65                before_version,
66                remove_all,
67                current_version,
68                &mut lines_to_remove,
69            );
70            removed_count += count;
71        }
72    }
73
74    // Remove the lines
75    let mut result_lines = Vec::new();
76    let source_lines: Vec<&str> = source.lines().collect();
77    let mut skip_until = 0;
78
79    for (i, line) in source_lines.iter().enumerate() {
80        if i < skip_until {
81            continue;
82        }
83
84        let mut should_skip = false;
85        for (start, end) in &lines_to_remove {
86            if i >= *start && i < *end {
87                should_skip = true;
88                skip_until = *end;
89                break;
90            }
91        }
92
93        if !should_skip {
94            result_lines.push(*line);
95        }
96    }
97
98    Ok((removed_count, result_lines.join("\n")))
99}
100
101/// Recursively collect statements that should be removed, including nested classes
102/// Returns the count of removed functions/methods
103fn collect_removable_statements(
104    stmt: &ast::Stmt,
105    source: &str,
106    before_version: Option<&str>,
107    remove_all: bool,
108    current_version: Option<&str>,
109    lines_to_remove: &mut Vec<(usize, usize)>,
110) -> usize {
111    let mut count = 0;
112    match stmt {
113        ast::Stmt::ClassDef(class) => {
114            // Check methods inside this class
115            for (i, method) in class.body.iter().enumerate() {
116                if should_remove_statement(method, before_version, remove_all, current_version) {
117                    count += 1;
118                    if let Some(line_range) = find_method_lines(source, class, i) {
119                        lines_to_remove.push(line_range);
120                    }
121                }
122
123                // Recursively check nested classes
124                count += collect_removable_statements(
125                    method,
126                    source,
127                    before_version,
128                    remove_all,
129                    current_version,
130                    lines_to_remove,
131                );
132            }
133        }
134        _ => {
135            // For non-class statements, no need to recurse
136        }
137    }
138    count
139}
140
141fn should_remove_statement(
142    stmt: &ast::Stmt,
143    before_version: Option<&str>,
144    remove_all: bool,
145    current_version: Option<&str>,
146) -> bool {
147    match stmt {
148        ast::Stmt::FunctionDef(func) => has_replace_me_decorator(
149            &func.decorator_list,
150            before_version,
151            remove_all,
152            current_version,
153        ),
154        ast::Stmt::AsyncFunctionDef(func) => has_replace_me_decorator(
155            &func.decorator_list,
156            before_version,
157            remove_all,
158            current_version,
159        ),
160        ast::Stmt::ClassDef(class) => has_replace_me_decorator(
161            &class.decorator_list,
162            before_version,
163            remove_all,
164            current_version,
165        ),
166        _ => false,
167    }
168}
169
170fn has_replace_me_decorator(
171    decorators: &[ast::Expr],
172    before_version: Option<&str>,
173    remove_all: bool,
174    current_version: Option<&str>,
175) -> bool {
176    for dec in decorators.iter() {
177        match dec {
178            ast::Expr::Name(name) if name.id.as_str() == "replace_me" => {
179                if remove_all {
180                    return true;
181                }
182            }
183            ast::Expr::Call(call) => {
184                if let ast::Expr::Name(name) = &*call.func {
185                    let func_name = name.id.as_str();
186                    if func_name == "replace_me" {
187                        if remove_all {
188                            return true;
189                        }
190
191                        // Check version if specified
192                        if let Some(before_ver) = before_version {
193                            if let Some(since_ver) = extract_since_version(&call.keywords) {
194                                if compare_versions(&since_ver, before_ver) < 0 {
195                                    return true;
196                                }
197                            }
198                        }
199
200                        // Check before_version parameter in decorator
201                        if let Some(current_ver) = current_version {
202                            let decorator_before_ver = extract_before_version(&call.keywords);
203                            if let Some(decorator_before_ver) = decorator_before_ver {
204                                if compare_versions(current_ver, &decorator_before_ver) >= 0 {
205                                    return true;
206                                }
207                            }
208                        }
209
210                        // Check remove_in version
211                        if let Some(current_ver) = current_version {
212                            if let Some(remove_in_ver) = extract_remove_in_version(&call.keywords) {
213                                if compare_versions(current_ver, &remove_in_ver) >= 0 {
214                                    return true;
215                                }
216                            }
217                        }
218                    }
219                }
220            }
221            _ => {}
222        }
223    }
224    false
225}
226
227fn extract_since_version(keywords: &[ast::Keyword]) -> Option<String> {
228    for keyword in keywords {
229        if let Some(arg) = &keyword.arg {
230            if arg.as_str() == "since" {
231                if let ast::Expr::Constant(c) = &keyword.value {
232                    if let ast::Constant::Str(s) = &c.value {
233                        return Some(s.to_string());
234                    }
235                }
236            }
237        }
238    }
239    None
240}
241
242fn extract_before_version(keywords: &[ast::Keyword]) -> Option<String> {
243    for keyword in keywords {
244        if let Some(arg) = &keyword.arg {
245            if arg.as_str() == "before_version" {
246                if let ast::Expr::Constant(c) = &keyword.value {
247                    if let ast::Constant::Str(s) = &c.value {
248                        return Some(s.to_string());
249                    }
250                }
251            }
252        }
253    }
254    None
255}
256
257fn extract_remove_in_version(keywords: &[ast::Keyword]) -> Option<String> {
258    for keyword in keywords {
259        if let Some(arg) = &keyword.arg {
260            if arg.as_str() == "remove_in" {
261                if let ast::Expr::Constant(c) = &keyword.value {
262                    if let ast::Constant::Str(s) = &c.value {
263                        return Some(s.to_string());
264                    }
265                }
266            }
267        }
268    }
269    None
270}
271
272fn compare_versions(v1: &str, v2: &str) -> i32 {
273    use crate::core::types::Version;
274    match (v1.parse::<Version>(), v2.parse::<Version>()) {
275        (Ok(ver1), Ok(ver2)) => match ver1.cmp(&ver2) {
276            std::cmp::Ordering::Less => -1,
277            std::cmp::Ordering::Equal => 0,
278            std::cmp::Ordering::Greater => 1,
279        },
280        _ => {
281            // Fallback to string comparison if parsing fails
282            v1.cmp(v2) as i32
283        }
284    }
285}
286
287fn find_statement_lines(
288    source: &str,
289    stmt_index: usize,
290    stmts: &[ast::Stmt],
291) -> Option<(usize, usize)> {
292    // Simple heuristic: find function/class definitions by name
293    let lines: Vec<&str> = source.lines().collect();
294
295    match &stmts[stmt_index] {
296        ast::Stmt::FunctionDef(func) => {
297            let func_name = &func.name;
298            for (i, line) in lines.iter().enumerate() {
299                if line.contains(&format!("def {}", func_name)) {
300                    // Find the end of the function (next def, class, or dedent)
301                    let indent = line.chars().take_while(|c| c.is_whitespace()).count();
302                    for (j, end_line) in lines[i + 1..].iter().enumerate() {
303                        let end_i = i + j + 1;
304                        if !end_line.trim().is_empty() {
305                            let end_indent =
306                                end_line.chars().take_while(|c| c.is_whitespace()).count();
307                            if end_indent <= indent && !end_line.trim_start().starts_with('#') {
308                                // Also remove decorators before the function
309                                let start = find_decorator_start(&lines, i);
310                                return Some((start, end_i));
311                            }
312                        }
313                    }
314                    // If we reach the end, include everything
315                    let start = find_decorator_start(&lines, i);
316                    return Some((start, lines.len()));
317                }
318            }
319        }
320        ast::Stmt::ClassDef(class) => {
321            let class_name = &class.name;
322            for (i, line) in lines.iter().enumerate() {
323                if line.contains(&format!("class {}", class_name)) {
324                    let indent = line.chars().take_while(|c| c.is_whitespace()).count();
325                    for (j, end_line) in lines[i + 1..].iter().enumerate() {
326                        let end_i = i + j + 1;
327                        if !end_line.trim().is_empty() {
328                            let end_indent =
329                                end_line.chars().take_while(|c| c.is_whitespace()).count();
330                            if end_indent <= indent && !end_line.trim_start().starts_with('#') {
331                                let start = find_decorator_start(&lines, i);
332                                return Some((start, end_i));
333                            }
334                        }
335                    }
336                    let start = find_decorator_start(&lines, i);
337                    return Some((start, lines.len()));
338                }
339            }
340        }
341        _ => {}
342    }
343
344    None
345}
346
/// Find the first line of the decorator/comment/blank-line run above a definition.
///
/// Walks upwards from the line just before `def_line` while lines are
/// decorators (`@...`), comments (`#...`) or blank, and returns the index of
/// the topmost such line. Returns `def_line` itself when the line directly
/// above is ordinary code or when `def_line` is the first line.
fn find_decorator_start(lines: &[&str], def_line: usize) -> usize {
    let attached = lines[..def_line]
        .iter()
        .rev()
        .take_while(|line| {
            let text = line.trim();
            text.is_empty() || text.starts_with('@') || text.starts_with('#')
        })
        .count();

    def_line - attached
}
360
361fn find_method_lines(
362    source: &str,
363    class: &ast::StmtClassDef,
364    method_index: usize,
365) -> Option<(usize, usize)> {
366    let lines: Vec<&str> = source.lines().collect();
367
368    // Find the class first
369    let class_name = &class.name;
370    let mut class_line = None;
371    for (i, line) in lines.iter().enumerate() {
372        if line.contains(&format!("class {}:", class_name)) {
373            class_line = Some(i);
374            break;
375        }
376    }
377
378    let class_start = class_line?;
379
380    // Find the method within the class
381    match &class.body[method_index] {
382        ast::Stmt::FunctionDef(method) => {
383            let method_name = &method.name;
384
385            // Look for the method definition starting from after the class line
386            for (i, line) in lines[class_start + 1..].iter().enumerate() {
387                let actual_i = class_start + 1 + i;
388                if line.contains(&format!("def {}", method_name)) {
389                    // Find the end of the method (next def, class, or dedent to class level)
390                    let class_indent = lines[class_start]
391                        .chars()
392                        .take_while(|c| c.is_whitespace())
393                        .count();
394                    let method_indent = line.chars().take_while(|c| c.is_whitespace()).count();
395
396                    for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
397                        let end_i = actual_i + j + 1;
398                        if !end_line.trim().is_empty() {
399                            let end_indent =
400                                end_line.chars().take_while(|c| c.is_whitespace()).count();
401                            if end_indent <= method_indent
402                                && !end_line.trim_start().starts_with('#')
403                            {
404                                // Also remove decorators before the method
405                                let start = find_decorator_start(&lines, actual_i);
406                                return Some((start, end_i));
407                            }
408                        }
409                    }
410
411                    // If we reach the end, include everything to end of class
412                    let start = find_decorator_start(&lines, actual_i);
413
414                    // Find end of class
415                    for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
416                        let end_i = actual_i + j + 1;
417                        if !end_line.trim().is_empty() {
418                            let end_indent =
419                                end_line.chars().take_while(|c| c.is_whitespace()).count();
420                            if end_indent <= class_indent {
421                                return Some((start, end_i));
422                            }
423                        }
424                    }
425
426                    return Some((start, lines.len()));
427                }
428            }
429        }
430        ast::Stmt::AsyncFunctionDef(method) => {
431            let method_name = &method.name;
432
433            // Look for the async method definition
434            for (i, line) in lines[class_start + 1..].iter().enumerate() {
435                let actual_i = class_start + 1 + i;
436                if line.contains(&format!("async def {}", method_name)) {
437                    // Similar logic as above for async methods
438                    let _class_indent = lines[class_start]
439                        .chars()
440                        .take_while(|c| c.is_whitespace())
441                        .count();
442                    let method_indent = line.chars().take_while(|c| c.is_whitespace()).count();
443
444                    for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
445                        let end_i = actual_i + j + 1;
446                        if !end_line.trim().is_empty() {
447                            let end_indent =
448                                end_line.chars().take_while(|c| c.is_whitespace()).count();
449                            if end_indent <= method_indent
450                                && !end_line.trim_start().starts_with('#')
451                            {
452                                let start = find_decorator_start(&lines, actual_i);
453                                return Some((start, end_i));
454                            }
455                        }
456                    }
457
458                    let start = find_decorator_start(&lines, actual_i);
459                    return Some((start, lines.len()));
460                }
461            }
462        }
463        _ => {}
464    }
465
466    None
467}
468
469/// Remove functions decorated with @replace_me from a file
470pub fn remove_decorators_from_file(
471    file_path: &str,
472    before_version: Option<&str>,
473    remove_all: bool,
474    write: bool,
475    current_version: Option<&str>,
476) -> Result<(usize, String)> {
477    let source = fs::read_to_string(file_path)?;
478
479    let (removed_count, result) =
480        remove_decorators(&source, before_version, remove_all, current_version)?;
481
482    if write && result != source {
483        fs::write(file_path, &result)?;
484    }
485
486    Ok((removed_count, result))
487}
488
489/// Remove functions from a file - entry point from main
490pub fn remove_from_file(
491    file_path: &str,
492    before_version: Option<&str>,
493    remove_all: bool,
494    write: bool,
495    current_version: Option<&str>,
496) -> Result<(usize, String)> {
497    remove_decorators_from_file(
498        file_path,
499        before_version,
500        remove_all,
501        write,
502        current_version,
503    )
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// With `remove_all`, every `@replace_me` function goes and others stay.
    #[test]
    fn test_remove_all() {
        let python = r#"
from dissolve import replace_me

@replace_me()
def old_function():
    return new_function()

def regular_function():
    return 42

@replace_me(since="1.0.0")
def another_old():
    return new_api()
"#;

        let (removed, output) = remove_decorators(python, None, true, None).unwrap();

        assert_eq!(removed, 2, "Should remove 2 functions");
        for gone in ["def old_function", "def another_old"] {
            assert!(!output.contains(gone));
        }
        assert!(output.contains("def regular_function"));
    }

    /// Without any criterion the source comes back untouched.
    #[test]
    fn test_no_removal_criteria() {
        let python = r#"
@replace_me()
def old_function():
    return new_function()
"#;

        let (removed, output) = remove_decorators(python, None, false, None).unwrap();

        assert_eq!(removed, 0, "Should remove 0 functions");
        assert_eq!(output, python);
    }

    /// Only functions deprecated before the given version are removed.
    #[test]
    fn test_remove_before_version() {
        let python = r#"
from dissolve import replace_me

@replace_me(since="1.0.0")
def old_v1():
    return new_v1()

@replace_me(since="2.0.0")
def old_v2():
    return new_v2()

def regular_function():
    return 42
"#;

        let (removed, output) = remove_decorators(python, Some("1.5.0"), false, None).unwrap();

        assert_eq!(removed, 1, "Should remove 1 function");
        // since < 1.5.0: removed
        assert!(!output.contains("def old_v1"));
        // since >= 1.5.0, or not deprecated at all: kept
        assert!(output.contains("def old_v2"));
        assert!(output.contains("def regular_function"));
    }
}