1use anyhow::{Context, Result};
16use once_cell::sync::Lazy;
17use std::collections::{HashMap, HashSet};
18use std::fs;
19use std::path::Path;
20use std::sync::Mutex;
21use tracing;
22
23use crate::core::{CollectorResult, ImportInfo, ReplaceInfo};
24use crate::unified_visitor::{UnifiedResult, UnifiedVisitor};
25use crate::RuffDeprecatedFunctionCollector;
26
27static MODULE_CACHE: Lazy<Mutex<HashMap<String, CollectorResult>>> =
29 Lazy::new(|| Mutex::new(HashMap::new()));
30
31#[derive(Debug, Clone)]
33pub struct DependencyCollectionResult {
34 pub replacements: HashMap<String, ReplaceInfo>,
35 pub inheritance_map: HashMap<String, Vec<String>>,
36 pub class_methods: HashMap<String, HashSet<String>>,
37}
38
39impl Default for DependencyCollectionResult {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl DependencyCollectionResult {
46 pub fn new() -> Self {
47 Self {
48 replacements: HashMap::new(),
49 inheritance_map: HashMap::new(),
50 class_methods: HashMap::new(),
51 }
52 }
53
54 pub fn update(&mut self, other: &DependencyCollectionResult) {
56 for (key, value) in &other.replacements {
58 self.replacements.insert(key.clone(), value.clone());
59 }
60
61 for (key, value) in &other.inheritance_map {
62 self.inheritance_map.insert(key.clone(), value.clone());
63 }
64
65 for (class_name, methods) in &other.class_methods {
67 self.class_methods
68 .entry(class_name.clone())
69 .or_default()
70 .extend(methods.iter().cloned());
71 }
72 }
73}
74
75impl From<CollectorResult> for DependencyCollectionResult {
76 fn from(result: CollectorResult) -> Self {
77 Self {
78 replacements: result.replacements,
79 inheritance_map: result.inheritance_map,
80 class_methods: result.class_methods,
81 }
82 }
83}
84
85pub fn clear_module_cache() {
87 if let Ok(mut cache) = MODULE_CACHE.lock() {
88 cache.clear();
89 }
90}
91
/// Returns every ancestor of `class_name` according to `inheritance_map`.
///
/// `inheritance_map` maps a class to its direct bases. The result lists the
/// transitive bases in breadth-first order (direct bases first, in their
/// declared order), each exactly once. `class_name` itself is never included,
/// even if the map contains a cycle leading back to it. A class with no entry
/// in the map yields an empty vector.
fn get_inheritance_chain_for_class(
    class_name: &str,
    inheritance_map: &HashMap<String, Vec<String>>,
) -> Vec<String> {
    use std::collections::VecDeque;

    let mut chain = Vec::new();
    // Seed with the starting class so a cycle back to it is ignored.
    let mut seen: HashSet<&str> = HashSet::new();
    seen.insert(class_name);
    let mut queue: VecDeque<&str> = VecDeque::new();
    queue.push_back(class_name);

    while let Some(current) = queue.pop_front() {
        if let Some(bases) = inheritance_map.get(current) {
            for base in bases {
                // Diamond hierarchies reach shared ancestors along several
                // paths; report each ancestor only once.
                if seen.insert(base.as_str()) {
                    chain.push(base.clone());
                    queue.push_back(base.as_str());
                }
            }
        }
    }

    chain
}
116
117pub fn collect_imports_from_source(source: &str, module_name: &str) -> Result<Vec<ImportInfo>> {
119 let visitor = UnifiedVisitor::new_for_collection(module_name, None);
121 let unified_result = visitor.process_source(source.to_string())?;
122
123 let result = match unified_result {
124 UnifiedResult::Collection(result) => result,
125 _ => return Err(anyhow::anyhow!("Expected collection result")),
126 };
127
128 Ok(result.imports)
129}
130
/// Resolves a possibly-relative Python module name to an absolute dotted path.
///
/// Absolute names (no leading `.`) are returned unchanged. For a relative name,
/// each leading dot strips one trailing component from `relative_to`, and the
/// remainder (if any) is appended: `..parent` from `package.sub.module` gives
/// `package.parent`.
///
/// Returns `None` for a relative name when `relative_to` is `None`, or when the
/// name climbs past the top of `relative_to`.
pub fn resolve_module_path(module_name: &str, relative_to: Option<&str>) -> Option<String> {
    let remainder = module_name.trim_start_matches('.');
    // Dots are single-byte, so the byte difference equals the number of dots.
    let level = module_name.len() - remainder.len();

    if level == 0 {
        return Some(module_name.to_string());
    }

    let anchor = relative_to?;
    let mut components: Vec<&str> = anchor.split('.').collect();
    if level >= components.len() {
        return None;
    }

    components.truncate(components.len() - level);
    if !remainder.is_empty() {
        components.extend(remainder.split('.'));
    }

    Some(components.join("."))
}
159
/// Cheap textual pre-filter: returns `true` if `source` mentions `replace_me`.
///
/// This is a substring test only, so it may report `true` for sources that
/// merely mention the name (e.g. in a comment). A single check suffices,
/// because the decorator form `@replace_me` always contains `replace_me`.
pub fn might_contain_replace_me(source: &str) -> bool {
    source.contains("replace_me")
}
165
166#[allow(dead_code)]
168fn find_module_file(module_path: &str) -> Option<String> {
169 find_module_file_with_paths(module_path, &[])
170}
171
172fn find_module_file_with_paths(module_path: &str, additional_paths: &[String]) -> Option<String> {
174 use pyo3::prelude::*;
175
176 Python::with_gil(|py| {
177 if !additional_paths.is_empty() {
179 tracing::debug!(
180 "Checking additional paths for module {}: {:?}",
181 module_path,
182 additional_paths
183 );
184 for base_path in additional_paths {
186 let module_parts: Vec<&str> = module_path.split('.').collect();
188 let mut file_path = std::path::PathBuf::from(base_path);
189 for part in &module_parts {
190 file_path.push(part);
191 }
192
193 let init_path = file_path.join("__init__.py");
195 if init_path.exists() {
196 return Some(init_path.to_string_lossy().to_string());
197 }
198
199 file_path.set_extension("py");
201 tracing::debug!(
202 "Checking path: {:?}, exists: {}",
203 file_path,
204 file_path.exists()
205 );
206 if file_path.exists() {
207 tracing::debug!("Found module at: {:?}", file_path);
208 return Some(file_path.to_string_lossy().to_string());
209 }
210 }
211 }
212
213 let importlib_util = py.import("importlib.util").ok()?;
215 let find_spec = importlib_util.getattr("find_spec").ok()?;
216
217 if let Ok(spec) = find_spec.call1((module_path,)) {
219 if !spec.is_none() {
220 if let Ok(origin) = spec.getattr("origin") {
221 if !origin.is_none() {
222 if let Ok(path) = origin.extract::<String>() {
223 return Some(path);
224 }
225 }
226 }
227 }
228 }
229
230 None
231 })
232}
233
234pub fn collect_deprecated_from_module(module_path: &str) -> Result<DependencyCollectionResult> {
236 collect_deprecated_from_module_with_paths(module_path, &[])
237}
238
239pub fn collect_deprecated_from_module_with_paths(
241 module_path: &str,
242 additional_paths: &[String],
243) -> Result<DependencyCollectionResult> {
244 if let Ok(cache) = MODULE_CACHE.lock() {
246 if let Some(cached) = cache.get(module_path) {
247 return Ok(cached.clone().into());
248 }
249 }
250
251 let mut result = CollectorResult::new();
252
253 tracing::debug!(
255 "Looking for module {} with additional paths: {:?}",
256 module_path,
257 additional_paths
258 );
259 if let Some(file_path) = find_module_file_with_paths(module_path, additional_paths) {
260 tracing::debug!("Found module {} at {}", module_path, file_path);
261
262 let source = fs::read_to_string(&file_path)
264 .with_context(|| format!("Failed to read module file: {}", file_path))?;
265
266 if !might_contain_replace_me(&source) {
268 tracing::debug!("Module {} does not contain replace_me", module_path);
269 if let Ok(mut cache) = MODULE_CACHE.lock() {
271 cache.insert(module_path.to_string(), result.clone());
272 }
273 return Ok(result.into());
274 }
275
276 tracing::debug!("Module {} contains replace_me, collecting...", module_path);
277
278 let collector = RuffDeprecatedFunctionCollector::new(
280 module_path.to_string(),
281 Some(Path::new(&file_path)),
282 );
283 if let Ok(collector_result) = collector.collect_from_source(source) {
284 tracing::debug!(
285 "Found {} replacements in {}",
286 collector_result.replacements.len(),
287 module_path
288 );
289 for (key, replacement) in &collector_result.replacements {
290 tracing::debug!(
291 " Replacement key: {} -> {}",
292 key,
293 replacement.replacement_expr
294 );
295 }
296 result = collector_result;
297 }
298 } else {
299 tracing::debug!("Module {} not found", module_path);
300 }
301
302 if let Ok(mut cache) = MODULE_CACHE.lock() {
304 cache.insert(module_path.to_string(), result.clone());
305 }
306
307 Ok(result.into())
308}
309
310pub fn collect_deprecated_from_dependencies(
312 source: &str,
313 module_name: &str,
314 max_depth: i32,
315) -> Result<DependencyCollectionResult> {
316 collect_deprecated_from_dependencies_with_paths(source, module_name, max_depth, &[])
317}
318
319pub fn collect_deprecated_from_dependencies_with_paths(
321 source: &str,
322 module_name: &str,
323 max_depth: i32,
324 additional_paths: &[String],
325) -> Result<DependencyCollectionResult> {
326 tracing::info!(
327 "Starting recursive collection for module {} with max_depth {}",
328 module_name,
329 max_depth
330 );
331 collect_deprecated_from_dependencies_recursive(
332 source,
333 module_name,
334 max_depth,
335 &mut HashSet::new(),
336 additional_paths,
337 )
338}
339
340fn collect_deprecated_from_dependencies_recursive(
342 source: &str,
343 module_name: &str,
344 max_depth: i32,
345 visited_modules: &mut HashSet<String>,
346 additional_paths: &[String],
347) -> Result<DependencyCollectionResult> {
348 let mut result = DependencyCollectionResult::new();
349
350 if max_depth <= 0 {
352 return Ok(result);
353 }
354
355 let imports = collect_imports_from_source(source, module_name)?;
357 tracing::info!("Found {} imports in source", imports.len());
358 for imp in &imports {
359 tracing::info!(" Import: {:?}", imp);
360 }
361
362 let mut module_imports: HashMap<String, Vec<ImportInfo>> = HashMap::new();
364
365 for imp in imports {
366 if let Some(resolved) = resolve_module_path(&imp.module, Some(module_name)) {
367 module_imports.entry(resolved).or_default().push(imp);
368 }
369 }
370
371 for (resolved, imp_list) in module_imports {
373 if visited_modules.contains(&resolved) {
375 tracing::debug!("Skipping already visited module: {}", resolved);
376 continue;
377 }
378 tracing::debug!("Processing module: {} at depth {}", resolved, max_depth);
379 visited_modules.insert(resolved.clone());
380
381 tracing::debug!("Attempting to collect from module: {}", resolved);
383 if let Ok(module_result) =
384 collect_deprecated_from_module_with_paths(&resolved, additional_paths)
385 {
386 tracing::debug!(
387 "Module {} has {} replacements",
388 resolved,
389 module_result.replacements.len()
390 );
391 tracing::info!(
392 "Module {} has {} replacements and inheritance map: {:?}",
393 resolved,
394 module_result.replacements.len(),
395 module_result.inheritance_map
396 );
397 for (key, value) in &module_result.inheritance_map {
399 result.inheritance_map.insert(key.clone(), value.clone());
400 }
401
402 let mut all_imported_names = HashSet::new();
404 let mut has_star_import = false;
405
406 for imp in &imp_list {
407 for (name, _alias) in &imp.names {
408 if name == "*" {
409 has_star_import = true;
410 } else {
411 all_imported_names.insert(name.clone());
412 }
413 }
414 if imp.names.is_empty() {
415 has_star_import = true;
417 }
418 }
419
420 if has_star_import {
422 tracing::debug!(
424 "Star import from {}, including all {} replacements",
425 resolved,
426 module_result.replacements.len()
427 );
428 result
429 .replacements
430 .extend(module_result.replacements.clone());
431
432 for class_path in module_result.replacements.keys() {
434 if let Some(class_name) = class_path.split('.').nth(1) {
435 let full_class_path = format!("{}.{}", resolved, class_name);
436
437 let inheritance_chain = get_inheritance_chain_for_class(
439 &full_class_path,
440 &module_result.inheritance_map,
441 );
442
443 for base_class in inheritance_chain {
444 for (repl_path, repl_info) in &module_result.replacements {
446 if repl_path.starts_with(&format!("{}.", base_class)) {
447 result
448 .replacements
449 .insert(repl_path.clone(), repl_info.clone());
450 }
451 }
452 }
453 }
454 }
455 } else {
456 tracing::info!("Checking imported names: {:?}", all_imported_names);
458 for name in &all_imported_names {
459 let full_path = format!("{}.{}", resolved, name);
460 tracing::debug!(
461 "Checking imported name '{}', full_path: '{}' with replacements: {:?}",
462 name,
463 full_path,
464 module_result.replacements.keys().collect::<Vec<_>>()
465 );
466
467 for (repl_path, repl_info) in &module_result.replacements {
469 if repl_path == &full_path
470 || repl_path.starts_with(&format!("{}.", full_path))
471 {
472 result
473 .replacements
474 .insert(repl_path.clone(), repl_info.clone());
475 }
476 }
477
478 if !module_result.inheritance_map.is_empty() {
480 let inheritance_chain = get_inheritance_chain_for_class(
481 &full_path,
482 &module_result.inheritance_map,
483 );
484 tracing::debug!(
485 "Inheritance chain for {}: {:?}",
486 full_path,
487 inheritance_chain
488 );
489
490 for base_class in inheritance_chain {
491 let qualified_base = format!("{}.{}", resolved, base_class);
493
494 for (repl_path, repl_info) in &module_result.replacements {
495 if repl_path.starts_with(&format!("{}.", base_class))
496 || repl_path.starts_with(&format!("{}.", qualified_base))
497 {
498 tracing::debug!(
499 "Including inherited replacement: {}",
500 repl_path
501 );
502 result
503 .replacements
504 .insert(repl_path.clone(), repl_info.clone());
505 }
506 }
507 }
508 }
509
510 let submodule_path = format!("{}.{}", resolved, name);
512 if let Ok(submodule_result) =
513 collect_deprecated_from_module_with_paths(&submodule_path, additional_paths)
514 {
515 if !submodule_result.replacements.is_empty() {
516 result.update(&submodule_result);
517 }
518 }
519 }
520 }
521
522 result
524 .class_methods
525 .extend(module_result.class_methods.clone());
526
527 if max_depth > 1 {
529 if let Some(module_file) = find_module_file_with_paths(&resolved, additional_paths)
531 {
532 if let Ok(module_source) = fs::read_to_string(&module_file) {
533 tracing::debug!(
534 "Recursively processing imports from {} (depth {})",
535 resolved,
536 max_depth - 1
537 );
538 if let Ok(dep_result) = collect_deprecated_from_dependencies_recursive(
539 &module_source,
540 &resolved,
541 max_depth - 1,
542 visited_modules,
543 additional_paths,
544 ) {
545 result.update(&dep_result);
546 }
547 }
548 }
549 }
550 }
551 }
552
553 Ok(result)
554}
555
556pub fn scan_file_with_dependencies(
558 file_path: &str,
559 module_name: &str,
560) -> Result<HashMap<String, ReplaceInfo>> {
561 let mut all_replacements = HashMap::new();
562
563 let source = fs::read_to_string(file_path)
565 .with_context(|| format!("Failed to read file: {}", file_path))?;
566
567 let collector =
569 RuffDeprecatedFunctionCollector::new(module_name.to_string(), Some(Path::new(&file_path)));
570 if let Ok(result) = collector.collect_from_source(source.clone()) {
571 all_replacements.extend(result.replacements);
572 }
573
574 if let Ok(dep_result) = collect_deprecated_from_dependencies(&source, module_name, 5) {
576 all_replacements.extend(dep_result.replacements);
577 }
578
579 Ok(all_replacements)
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_module_path_absolute() {
        assert_eq!(
            resolve_module_path("os.path", None),
            Some("os.path".to_string())
        );
        assert_eq!(
            resolve_module_path("dulwich.repo", None),
            Some("dulwich.repo".to_string())
        );
    }

    #[test]
    fn test_resolve_module_path_relative() {
        assert_eq!(
            resolve_module_path(".sibling", Some("package.module")),
            Some("package.sibling".to_string())
        );

        assert_eq!(
            resolve_module_path("..parent", Some("package.sub.module")),
            Some("package.parent".to_string())
        );

        assert_eq!(
            resolve_module_path("..", Some("package.sub.module")),
            Some("package".to_string())
        );

        // Climbing past the top-level package cannot be resolved.
        assert_eq!(
            resolve_module_path("...toomuch", Some("package.module")),
            None
        );
    }

    #[test]
    fn test_might_contain_replace_me() {
        assert!(might_contain_replace_me("@replace_me\ndef foo(): pass"));
        assert!(might_contain_replace_me("from dissolve import replace_me"));
        assert!(!might_contain_replace_me("def regular_function(): pass"));
    }

    #[test]
    fn test_get_inheritance_chain() {
        let mut inheritance_map = HashMap::new();
        inheritance_map.insert("Child".to_string(), vec!["Parent".to_string()]);
        inheritance_map.insert("Parent".to_string(), vec!["GrandParent".to_string()]);
        inheritance_map.insert(
            "GrandParent".to_string(),
            vec!["GreatGrandParent".to_string()],
        );

        let chain = get_inheritance_chain_for_class("Child", &inheritance_map);
        assert_eq!(chain.len(), 3);
        assert!(chain.contains(&"Parent".to_string()));
        assert!(chain.contains(&"GrandParent".to_string()));
        assert!(chain.contains(&"GreatGrandParent".to_string()));
    }

    #[test]
    fn test_get_inheritance_chain_multiple_inheritance() {
        let mut inheritance_map = HashMap::new();
        inheritance_map.insert(
            "Child".to_string(),
            vec!["Parent1".to_string(), "Parent2".to_string()],
        );
        inheritance_map.insert("Parent1".to_string(), vec!["GrandParent".to_string()]);
        inheritance_map.insert("Parent2".to_string(), vec!["GrandParent".to_string()]);

        let chain = get_inheritance_chain_for_class("Child", &inheritance_map);
        assert!(chain.contains(&"Parent1".to_string()));
        assert!(chain.contains(&"Parent2".to_string()));
        assert!(chain.contains(&"GrandParent".to_string()));
    }

    #[test]
    fn test_collect_imports_from_source() {
        let source = r#"
import os
from sys import path
from ..relative import something
from . import sibling
import multiple, imports, together
"#;

        let imports = collect_imports_from_source(source, "test_module").unwrap();
        assert_eq!(imports.len(), 7);

        assert_eq!(imports[0].module, "os");
        assert_eq!(imports[0].names.len(), 1);
        assert_eq!(imports[0].names[0], ("os".to_string(), None));

        assert_eq!(imports[1].module, "sys");
        assert_eq!(imports[1].names, vec![("path".to_string(), None)]);

        assert_eq!(imports[2].module, "..relative");
        assert_eq!(imports[2].names, vec![("something".to_string(), None)]);

        assert_eq!(imports[3].module, ".");
        assert_eq!(imports[3].names, vec![("sibling".to_string(), None)]);

        // `import a, b, c` is reported as one entry per module.
        assert_eq!(imports[4].module, "multiple");
        assert_eq!(imports[4].names.len(), 1);
        assert_eq!(imports[4].names[0], ("multiple".to_string(), None));

        assert_eq!(imports[5].module, "imports");
        assert_eq!(imports[5].names.len(), 1);
        assert_eq!(imports[5].names[0], ("imports".to_string(), None));

        assert_eq!(imports[6].module, "together");
        assert_eq!(imports[6].names.len(), 1);
        assert_eq!(imports[6].names[0], ("together".to_string(), None));
    }

    #[test]
    fn test_empty_module_cache() {
        clear_module_cache();

        let result = collect_deprecated_from_module("nonexistent.module").unwrap();
        assert!(result.replacements.is_empty());
    }

    #[test]
    fn test_max_depth_zero() {
        let source = "import os";
        let result = collect_deprecated_from_dependencies(source, "test_module", 0).unwrap();
        assert!(result.replacements.is_empty());
    }

    #[test]
    fn test_visited_modules_cycle_prevention() {
        // A module already marked as visited must be skipped entirely: nothing
        // is collected from it and the visited set is left unchanged.
        let mut visited = HashSet::new();
        visited.insert("module_a".to_string());

        let result = collect_deprecated_from_dependencies_recursive(
            "import module_a",
            "test_module",
            3,
            &mut visited,
            &[],
        )
        .unwrap();

        assert!(result.replacements.is_empty());
        assert!(result.inheritance_map.is_empty());
        assert_eq!(visited.len(), 1);
        assert!(visited.contains("module_a"));
    }
}