1use anyhow::Result;
18use rustpython_ast::{self as ast};
19use rustpython_parser::{parse, Mode};
20use std::fs;
21
22pub fn remove_decorators(
30 source: &str,
31 before_version: Option<&str>,
32 remove_all: bool,
33 current_version: Option<&str>,
34) -> Result<(usize, String)> {
35 if !remove_all && before_version.is_none() && current_version.is_none() {
36 return Ok((0, source.to_string()));
38 }
39
40 let parsed = parse(source, Mode::Module, "<module>")?;
42
43 let mut lines_to_remove = Vec::new();
44 let mut removed_count = 0;
45
46 if let ast::Mod::Module(module) = parsed {
47 for (i, stmt) in module.body.iter().enumerate() {
49 if should_remove_statement(stmt, before_version, remove_all, current_version) {
50 removed_count += 1;
51 if let Some(line_range) = find_statement_lines(source, i, &module.body) {
55 lines_to_remove.push(line_range);
56 }
57 }
58 }
59
60 for stmt in &module.body {
62 let count = collect_removable_statements(
63 stmt,
64 source,
65 before_version,
66 remove_all,
67 current_version,
68 &mut lines_to_remove,
69 );
70 removed_count += count;
71 }
72 }
73
74 let mut result_lines = Vec::new();
76 let source_lines: Vec<&str> = source.lines().collect();
77 let mut skip_until = 0;
78
79 for (i, line) in source_lines.iter().enumerate() {
80 if i < skip_until {
81 continue;
82 }
83
84 let mut should_skip = false;
85 for (start, end) in &lines_to_remove {
86 if i >= *start && i < *end {
87 should_skip = true;
88 skip_until = *end;
89 break;
90 }
91 }
92
93 if !should_skip {
94 result_lines.push(*line);
95 }
96 }
97
98 Ok((removed_count, result_lines.join("\n")))
99}
100
101fn collect_removable_statements(
104 stmt: &ast::Stmt,
105 source: &str,
106 before_version: Option<&str>,
107 remove_all: bool,
108 current_version: Option<&str>,
109 lines_to_remove: &mut Vec<(usize, usize)>,
110) -> usize {
111 let mut count = 0;
112 match stmt {
113 ast::Stmt::ClassDef(class) => {
114 for (i, method) in class.body.iter().enumerate() {
116 if should_remove_statement(method, before_version, remove_all, current_version) {
117 count += 1;
118 if let Some(line_range) = find_method_lines(source, class, i) {
119 lines_to_remove.push(line_range);
120 }
121 }
122
123 count += collect_removable_statements(
125 method,
126 source,
127 before_version,
128 remove_all,
129 current_version,
130 lines_to_remove,
131 );
132 }
133 }
134 _ => {
135 }
137 }
138 count
139}
140
141fn should_remove_statement(
142 stmt: &ast::Stmt,
143 before_version: Option<&str>,
144 remove_all: bool,
145 current_version: Option<&str>,
146) -> bool {
147 match stmt {
148 ast::Stmt::FunctionDef(func) => has_replace_me_decorator(
149 &func.decorator_list,
150 before_version,
151 remove_all,
152 current_version,
153 ),
154 ast::Stmt::AsyncFunctionDef(func) => has_replace_me_decorator(
155 &func.decorator_list,
156 before_version,
157 remove_all,
158 current_version,
159 ),
160 ast::Stmt::ClassDef(class) => has_replace_me_decorator(
161 &class.decorator_list,
162 before_version,
163 remove_all,
164 current_version,
165 ),
166 _ => false,
167 }
168}
169
170fn has_replace_me_decorator(
171 decorators: &[ast::Expr],
172 before_version: Option<&str>,
173 remove_all: bool,
174 current_version: Option<&str>,
175) -> bool {
176 for dec in decorators.iter() {
177 match dec {
178 ast::Expr::Name(name) if name.id.as_str() == "replace_me" => {
179 if remove_all {
180 return true;
181 }
182 }
183 ast::Expr::Call(call) => {
184 if let ast::Expr::Name(name) = &*call.func {
185 let func_name = name.id.as_str();
186 if func_name == "replace_me" {
187 if remove_all {
188 return true;
189 }
190
191 if let Some(before_ver) = before_version {
193 if let Some(since_ver) = extract_since_version(&call.keywords) {
194 if compare_versions(&since_ver, before_ver) < 0 {
195 return true;
196 }
197 }
198 }
199
200 if let Some(current_ver) = current_version {
202 let decorator_before_ver = extract_before_version(&call.keywords);
203 if let Some(decorator_before_ver) = decorator_before_ver {
204 if compare_versions(current_ver, &decorator_before_ver) >= 0 {
205 return true;
206 }
207 }
208 }
209
210 if let Some(current_ver) = current_version {
212 if let Some(remove_in_ver) = extract_remove_in_version(&call.keywords) {
213 if compare_versions(current_ver, &remove_in_ver) >= 0 {
214 return true;
215 }
216 }
217 }
218 }
219 }
220 }
221 _ => {}
222 }
223 }
224 false
225}
226
227fn extract_since_version(keywords: &[ast::Keyword]) -> Option<String> {
228 for keyword in keywords {
229 if let Some(arg) = &keyword.arg {
230 if arg.as_str() == "since" {
231 if let ast::Expr::Constant(c) = &keyword.value {
232 if let ast::Constant::Str(s) = &c.value {
233 return Some(s.to_string());
234 }
235 }
236 }
237 }
238 }
239 None
240}
241
242fn extract_before_version(keywords: &[ast::Keyword]) -> Option<String> {
243 for keyword in keywords {
244 if let Some(arg) = &keyword.arg {
245 if arg.as_str() == "before_version" {
246 if let ast::Expr::Constant(c) = &keyword.value {
247 if let ast::Constant::Str(s) = &c.value {
248 return Some(s.to_string());
249 }
250 }
251 }
252 }
253 }
254 None
255}
256
257fn extract_remove_in_version(keywords: &[ast::Keyword]) -> Option<String> {
258 for keyword in keywords {
259 if let Some(arg) = &keyword.arg {
260 if arg.as_str() == "remove_in" {
261 if let ast::Expr::Constant(c) = &keyword.value {
262 if let ast::Constant::Str(s) = &c.value {
263 return Some(s.to_string());
264 }
265 }
266 }
267 }
268 }
269 None
270}
271
272fn compare_versions(v1: &str, v2: &str) -> i32 {
273 use crate::core::types::Version;
274 match (v1.parse::<Version>(), v2.parse::<Version>()) {
275 (Ok(ver1), Ok(ver2)) => match ver1.cmp(&ver2) {
276 std::cmp::Ordering::Less => -1,
277 std::cmp::Ordering::Equal => 0,
278 std::cmp::Ordering::Greater => 1,
279 },
280 _ => {
281 v1.cmp(v2) as i32
283 }
284 }
285}
286
287fn find_statement_lines(
288 source: &str,
289 stmt_index: usize,
290 stmts: &[ast::Stmt],
291) -> Option<(usize, usize)> {
292 let lines: Vec<&str> = source.lines().collect();
294
295 match &stmts[stmt_index] {
296 ast::Stmt::FunctionDef(func) => {
297 let func_name = &func.name;
298 for (i, line) in lines.iter().enumerate() {
299 if line.contains(&format!("def {}", func_name)) {
300 let indent = line.chars().take_while(|c| c.is_whitespace()).count();
302 for (j, end_line) in lines[i + 1..].iter().enumerate() {
303 let end_i = i + j + 1;
304 if !end_line.trim().is_empty() {
305 let end_indent =
306 end_line.chars().take_while(|c| c.is_whitespace()).count();
307 if end_indent <= indent && !end_line.trim_start().starts_with('#') {
308 let start = find_decorator_start(&lines, i);
310 return Some((start, end_i));
311 }
312 }
313 }
314 let start = find_decorator_start(&lines, i);
316 return Some((start, lines.len()));
317 }
318 }
319 }
320 ast::Stmt::ClassDef(class) => {
321 let class_name = &class.name;
322 for (i, line) in lines.iter().enumerate() {
323 if line.contains(&format!("class {}", class_name)) {
324 let indent = line.chars().take_while(|c| c.is_whitespace()).count();
325 for (j, end_line) in lines[i + 1..].iter().enumerate() {
326 let end_i = i + j + 1;
327 if !end_line.trim().is_empty() {
328 let end_indent =
329 end_line.chars().take_while(|c| c.is_whitespace()).count();
330 if end_indent <= indent && !end_line.trim_start().starts_with('#') {
331 let start = find_decorator_start(&lines, i);
332 return Some((start, end_i));
333 }
334 }
335 }
336 let start = find_decorator_start(&lines, i);
337 return Some((start, lines.len()));
338 }
339 }
340 }
341 _ => {}
342 }
343
344 None
345}
346
/// Finds the first line that belongs to the definition whose header is at `def_line`.
///
/// Walking upwards from the header, the result is moved back over every contiguous line that is
/// blank, a `#` comment, or part of a decorator. A decorator may span several lines
/// (`@replace_me(` ... `)`); such a decorator is recognised by balancing brackets backwards
/// until a line starting with `@` closes the balance. The walk stops at the first line that is
/// none of these.
///
/// Bracket balancing is a per-line heuristic: text inside single-line quotes and after `#` is
/// ignored, but strings spanning several lines are not tracked.
///
/// # Panics
/// Panics if `def_line` is greater than `lines.len()`.
fn find_decorator_start(lines: &[&str], def_line: usize) -> usize {
    /// Closing brackets minus opening brackets on `line`, skipping quoted text and comments.
    fn closing_surplus(line: &str) -> i32 {
        let mut surplus = 0;
        let mut quote: Option<char> = None;
        let mut escaped = false;
        for c in line.chars() {
            match quote {
                Some(open_quote) => {
                    if escaped {
                        escaped = false;
                    } else if c == '\\' {
                        escaped = true;
                    } else if c == open_quote {
                        quote = None;
                    }
                }
                None => match c {
                    '#' => break,
                    '"' | '\'' => quote = Some(c),
                    ')' | ']' | '}' => surplus += 1,
                    '(' | '[' | '{' => surplus -= 1,
                    _ => {}
                },
            }
        }
        surplus
    }

    /// If `tail` is the last line of a decorator written across several lines, returns the
    /// index of the line holding its `@`.
    fn multiline_decorator_head(lines: &[&str], tail: usize) -> Option<usize> {
        let mut unclosed = 0;
        for idx in (0..=tail).rev() {
            unclosed += closing_surplus(lines[idx]);
            if unclosed < 0 {
                return None;
            }
            if unclosed == 0 {
                // Balanced: this is the head only if it is an earlier line opening with `@`.
                let is_head = idx < tail && lines[idx].trim_start().starts_with('@');
                return if is_head { Some(idx) } else { None };
            }
        }
        None
    }

    let mut start = def_line;
    while start > 0 {
        let previous = lines[start - 1].trim();
        if previous.is_empty() || previous.starts_with('@') || previous.starts_with('#') {
            start -= 1;
        } else if let Some(head) = multiline_decorator_head(lines, start - 1) {
            // Without this, the `@...(` lines would be left behind and end up decorating
            // whatever definition follows the removed one.
            start = head;
        } else {
            break;
        }
    }
    start
}
360
361fn find_method_lines(
362 source: &str,
363 class: &ast::StmtClassDef,
364 method_index: usize,
365) -> Option<(usize, usize)> {
366 let lines: Vec<&str> = source.lines().collect();
367
368 let class_name = &class.name;
370 let mut class_line = None;
371 for (i, line) in lines.iter().enumerate() {
372 if line.contains(&format!("class {}:", class_name)) {
373 class_line = Some(i);
374 break;
375 }
376 }
377
378 let class_start = class_line?;
379
380 match &class.body[method_index] {
382 ast::Stmt::FunctionDef(method) => {
383 let method_name = &method.name;
384
385 for (i, line) in lines[class_start + 1..].iter().enumerate() {
387 let actual_i = class_start + 1 + i;
388 if line.contains(&format!("def {}", method_name)) {
389 let class_indent = lines[class_start]
391 .chars()
392 .take_while(|c| c.is_whitespace())
393 .count();
394 let method_indent = line.chars().take_while(|c| c.is_whitespace()).count();
395
396 for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
397 let end_i = actual_i + j + 1;
398 if !end_line.trim().is_empty() {
399 let end_indent =
400 end_line.chars().take_while(|c| c.is_whitespace()).count();
401 if end_indent <= method_indent
402 && !end_line.trim_start().starts_with('#')
403 {
404 let start = find_decorator_start(&lines, actual_i);
406 return Some((start, end_i));
407 }
408 }
409 }
410
411 let start = find_decorator_start(&lines, actual_i);
413
414 for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
416 let end_i = actual_i + j + 1;
417 if !end_line.trim().is_empty() {
418 let end_indent =
419 end_line.chars().take_while(|c| c.is_whitespace()).count();
420 if end_indent <= class_indent {
421 return Some((start, end_i));
422 }
423 }
424 }
425
426 return Some((start, lines.len()));
427 }
428 }
429 }
430 ast::Stmt::AsyncFunctionDef(method) => {
431 let method_name = &method.name;
432
433 for (i, line) in lines[class_start + 1..].iter().enumerate() {
435 let actual_i = class_start + 1 + i;
436 if line.contains(&format!("async def {}", method_name)) {
437 let _class_indent = lines[class_start]
439 .chars()
440 .take_while(|c| c.is_whitespace())
441 .count();
442 let method_indent = line.chars().take_while(|c| c.is_whitespace()).count();
443
444 for (j, end_line) in lines[actual_i + 1..].iter().enumerate() {
445 let end_i = actual_i + j + 1;
446 if !end_line.trim().is_empty() {
447 let end_indent =
448 end_line.chars().take_while(|c| c.is_whitespace()).count();
449 if end_indent <= method_indent
450 && !end_line.trim_start().starts_with('#')
451 {
452 let start = find_decorator_start(&lines, actual_i);
453 return Some((start, end_i));
454 }
455 }
456 }
457
458 let start = find_decorator_start(&lines, actual_i);
459 return Some((start, lines.len()));
460 }
461 }
462 }
463 _ => {}
464 }
465
466 None
467}
468
469pub fn remove_decorators_from_file(
471 file_path: &str,
472 before_version: Option<&str>,
473 remove_all: bool,
474 write: bool,
475 current_version: Option<&str>,
476) -> Result<(usize, String)> {
477 let source = fs::read_to_string(file_path)?;
478
479 let (removed_count, result) =
480 remove_decorators(&source, before_version, remove_all, current_version)?;
481
482 if write && result != source {
483 fs::write(file_path, &result)?;
484 }
485
486 Ok((removed_count, result))
487}
488
489pub fn remove_from_file(
491 file_path: &str,
492 before_version: Option<&str>,
493 remove_all: bool,
494 write: bool,
495 current_version: Option<&str>,
496) -> Result<(usize, String)> {
497 remove_decorators_from_file(
498 file_path,
499 before_version,
500 remove_all,
501 write,
502 current_version,
503 )
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// With `remove_all`, every `replace_me`-decorated function goes and plain ones stay.
    #[test]
    fn test_remove_all() {
        let python = r#"
from dissolve import replace_me

@replace_me()
def old_function():
    return new_function()

def regular_function():
    return 42

@replace_me(since="1.0.0")
def another_old():
    return new_api()
"#;

        let (removed, output) = remove_decorators(python, None, true, None).unwrap();

        assert_eq!(removed, 2, "Should remove 2 functions");
        for gone in ["def old_function", "def another_old"] {
            assert!(!output.contains(gone));
        }
        assert!(output.contains("def regular_function"));
    }

    /// Without any criterion the source comes back untouched.
    #[test]
    fn test_no_removal_criteria() {
        let python = r#"
@replace_me()
def old_function():
    return new_function()
"#;

        let (removed, output) = remove_decorators(python, None, false, None).unwrap();

        assert_eq!(removed, 0, "Should remove 0 functions");
        assert_eq!(output, python);
    }

    /// Only functions deprecated before the given version are removed.
    #[test]
    fn test_remove_before_version() {
        let python = r#"
from dissolve import replace_me

@replace_me(since="1.0.0")
def old_v1():
    return new_v1()

@replace_me(since="2.0.0")
def old_v2():
    return new_v2()

def regular_function():
    return 42
"#;

        let (removed, output) = remove_decorators(python, Some("1.5.0"), false, None).unwrap();

        assert_eq!(removed, 1, "Should remove 1 function");
        assert!(!output.contains("def old_v1"));
        for kept in ["def old_v2", "def regular_function"] {
            assert!(output.contains(kept));
        }
    }
}