Skip to main content

dissolve_python/
dependency_collector.rs

1// Copyright (C) 2024 Jelmer Vernooij <jelmer@samba.org>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{Context, Result};
16use once_cell::sync::Lazy;
17use std::collections::{HashMap, HashSet};
18use std::fs;
19use std::path::Path;
20use std::sync::Mutex;
21use tracing;
22
23use crate::core::{CollectorResult, ImportInfo, ReplaceInfo};
24use crate::unified_visitor::{UnifiedResult, UnifiedVisitor};
25use crate::RuffDeprecatedFunctionCollector;
26
/// Global cache for module analysis results.
///
/// Entries are inserted by `collect_deprecated_from_module_with_paths` and the
/// whole map is emptied by `clear_module_cache`. Access goes through a `Mutex`;
/// every user in this file simply skips the cache (instead of panicking) when
/// `lock()` returns an error.
static MODULE_CACHE: Lazy<Mutex<HashMap<String, CollectorResult>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
30
/// Collection result for dependency analysis.
///
/// Aggregates what was gathered from one or more imported modules.
#[derive(Debug, Clone)]
pub struct DependencyCollectionResult {
    /// Replacement information keyed by a dotted path string; elsewhere in this
    /// file the keys are compared against `"<module>.<name>"` and
    /// `"<module>.<name>."` prefixes.
    pub replacements: HashMap<String, ReplaceInfo>,
    /// Maps a class to the list of its direct base classes (this is how
    /// `get_inheritance_chain_for_class` walks it).
    pub inheritance_map: HashMap<String, Vec<String>>,
    /// Maps a class name to a set of names — presumably the class's method
    /// names; the contents come from the collector (TODO confirm).
    pub class_methods: HashMap<String, HashSet<String>>,
}
38
39impl Default for DependencyCollectionResult {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl DependencyCollectionResult {
46    pub fn new() -> Self {
47        Self {
48            replacements: HashMap::new(),
49            inheritance_map: HashMap::new(),
50            class_methods: HashMap::new(),
51        }
52    }
53
54    /// Merge another result into this one
55    pub fn update(&mut self, other: &DependencyCollectionResult) {
56        // Avoid cloning by using references where possible and only clone when necessary
57        for (key, value) in &other.replacements {
58            self.replacements.insert(key.clone(), value.clone());
59        }
60
61        for (key, value) in &other.inheritance_map {
62            self.inheritance_map.insert(key.clone(), value.clone());
63        }
64
65        // Merge class_methods, combining sets for same classes
66        for (class_name, methods) in &other.class_methods {
67            self.class_methods
68                .entry(class_name.clone())
69                .or_default()
70                .extend(methods.iter().cloned());
71        }
72    }
73}
74
75impl From<CollectorResult> for DependencyCollectionResult {
76    fn from(result: CollectorResult) -> Self {
77        Self {
78            replacements: result.replacements,
79            inheritance_map: result.inheritance_map,
80            class_methods: result.class_methods,
81        }
82    }
83}
84
85/// Clear the module analysis cache
86pub fn clear_module_cache() {
87    if let Ok(mut cache) = MODULE_CACHE.lock() {
88        cache.clear();
89    }
90}
91
/// Get all base classes in the inheritance chain for a given class.
///
/// `inheritance_map` maps a class to its direct bases. Starting from
/// `class_name`, the bases are followed transitively and every ancestor is
/// returned exactly once, in the order it is first discovered — even when it is
/// reachable along several paths (diamond inheritance). Callers iterate the
/// chain and scan all replacements per entry, so duplicates would only repeat
/// work.
///
/// A class without an entry in the map has no known bases; an unknown
/// `class_name` therefore yields an empty vector. Cycles in the map cannot
/// loop forever because a name is queued at most once after the starting class.
fn get_inheritance_chain_for_class(
    class_name: &str,
    inheritance_map: &HashMap<String, Vec<String>>,
) -> Vec<String> {
    let mut chain: Vec<String> = Vec::new();
    // Names already reported; borrowing from the map avoids cloning while walking.
    let mut seen: HashSet<&str> = HashSet::new();
    let mut to_process: Vec<&str> = vec![class_name];

    while let Some(current) = to_process.pop() {
        if let Some(bases) = inheritance_map.get(current) {
            for base in bases {
                // `insert` returns false for an ancestor that was already found
                // via another path, so it is neither reported nor walked again.
                if seen.insert(base.as_str()) {
                    chain.push(base.clone());
                    to_process.push(base.as_str());
                }
            }
        }
    }

    chain
}
116
117/// Extract all imports from a Python source file
118pub fn collect_imports_from_source(source: &str, module_name: &str) -> Result<Vec<ImportInfo>> {
119    // Create visitor to extract imports
120    let visitor = UnifiedVisitor::new_for_collection(module_name, None);
121    let unified_result = visitor.process_source(source.to_string())?;
122
123    let result = match unified_result {
124        UnifiedResult::Collection(result) => result,
125        _ => return Err(anyhow::anyhow!("Expected collection result")),
126    };
127
128    Ok(result.imports)
129}
130
/// Resolve a module name to its actual import path.
///
/// Absolute names are returned unchanged. A relative name (one or more leading
/// dots) is resolved against `relative_to`: one trailing component of
/// `relative_to` is dropped per leading dot, then the remainder of
/// `module_name` is appended.
///
/// Returns `None` for a relative name when `relative_to` is `None`, or when the
/// number of leading dots is not smaller than the number of components in
/// `relative_to`.
pub fn resolve_module_path(module_name: &str, relative_to: Option<&str>) -> Option<String> {
    if !module_name.starts_with('.') {
        return Some(module_name.to_string());
    }

    let anchor = relative_to?;

    // Everything after the leading dots; dots are ASCII, so the byte
    // difference equals the number of dots.
    let remainder = module_name.trim_start_matches('.');
    let level = module_name.len() - remainder.len();

    let anchor_parts: Vec<&str> = anchor.split('.').collect();
    if level >= anchor_parts.len() {
        return None;
    }

    // At least one component survives because `level < anchor_parts.len()`.
    let kept = &anchor_parts[..anchor_parts.len() - level];
    let mut resolved = kept.join(".");
    if !remainder.is_empty() {
        resolved.push('.');
        resolved.push_str(remainder);
    }

    Some(resolved)
}
159
/// Cheap textual pre-filter: does `source` mention `replace_me` anywhere?
///
/// A `true` result only means a full parse is worthwhile; the text may occur
/// in a comment or string.
pub fn might_contain_replace_me(source: &str) -> bool {
    // "@replace_me" always contains "replace_me", so a single substring search
    // covers decorators as well as imports and other mentions.
    source.contains("replace_me")
}
165
166/// Find Python module file using importlib
167#[allow(dead_code)]
168fn find_module_file(module_path: &str) -> Option<String> {
169    find_module_file_with_paths(module_path, &[])
170}
171
172/// Find Python module file using importlib with additional search paths
173fn find_module_file_with_paths(module_path: &str, additional_paths: &[String]) -> Option<String> {
174    use pyo3::prelude::*;
175
176    Python::with_gil(|py| {
177        // First check additional paths if provided (for test environments)
178        if !additional_paths.is_empty() {
179            tracing::debug!(
180                "Checking additional paths for module {}: {:?}",
181                module_path,
182                additional_paths
183            );
184            // For each additional path, check if the module exists there
185            for base_path in additional_paths {
186                // Convert module path to file path
187                let module_parts: Vec<&str> = module_path.split('.').collect();
188                let mut file_path = std::path::PathBuf::from(base_path);
189                for part in &module_parts {
190                    file_path.push(part);
191                }
192
193                // Check for __init__.py (package)
194                let init_path = file_path.join("__init__.py");
195                if init_path.exists() {
196                    return Some(init_path.to_string_lossy().to_string());
197                }
198
199                // Check for .py file (module)
200                file_path.set_extension("py");
201                tracing::debug!(
202                    "Checking path: {:?}, exists: {}",
203                    file_path,
204                    file_path.exists()
205                );
206                if file_path.exists() {
207                    tracing::debug!("Found module at: {:?}", file_path);
208                    return Some(file_path.to_string_lossy().to_string());
209                }
210            }
211        }
212
213        // If not found in additional paths, try importlib
214        let importlib_util = py.import("importlib.util").ok()?;
215        let find_spec = importlib_util.getattr("find_spec").ok()?;
216
217        // Try to find the module with current sys.path
218        if let Ok(spec) = find_spec.call1((module_path,)) {
219            if !spec.is_none() {
220                if let Ok(origin) = spec.getattr("origin") {
221                    if !origin.is_none() {
222                        if let Ok(path) = origin.extract::<String>() {
223                            return Some(path);
224                        }
225                    }
226                }
227            }
228        }
229
230        None
231    })
232}
233
234/// Collect all deprecated functions from a single module
235pub fn collect_deprecated_from_module(module_path: &str) -> Result<DependencyCollectionResult> {
236    collect_deprecated_from_module_with_paths(module_path, &[])
237}
238
239/// Collect all deprecated functions from a single module with additional search paths
240pub fn collect_deprecated_from_module_with_paths(
241    module_path: &str,
242    additional_paths: &[String],
243) -> Result<DependencyCollectionResult> {
244    // Check cache first
245    if let Ok(cache) = MODULE_CACHE.lock() {
246        if let Some(cached) = cache.get(module_path) {
247            return Ok(cached.clone().into());
248        }
249    }
250
251    let mut result = CollectorResult::new();
252
253    // Find the module file
254    tracing::debug!(
255        "Looking for module {} with additional paths: {:?}",
256        module_path,
257        additional_paths
258    );
259    if let Some(file_path) = find_module_file_with_paths(module_path, additional_paths) {
260        tracing::debug!("Found module {} at {}", module_path, file_path);
261
262        // Read the source file
263        let source = fs::read_to_string(&file_path)
264            .with_context(|| format!("Failed to read module file: {}", file_path))?;
265
266        // Quick check for replace_me
267        if !might_contain_replace_me(&source) {
268            tracing::debug!("Module {} does not contain replace_me", module_path);
269            // Cache empty result
270            if let Ok(mut cache) = MODULE_CACHE.lock() {
271                cache.insert(module_path.to_string(), result.clone());
272            }
273            return Ok(result.into());
274        }
275
276        tracing::debug!("Module {} contains replace_me, collecting...", module_path);
277
278        // Parse and collect using Ruff
279        let collector = RuffDeprecatedFunctionCollector::new(
280            module_path.to_string(),
281            Some(Path::new(&file_path)),
282        );
283        if let Ok(collector_result) = collector.collect_from_source(source) {
284            tracing::debug!(
285                "Found {} replacements in {}",
286                collector_result.replacements.len(),
287                module_path
288            );
289            for (key, replacement) in &collector_result.replacements {
290                tracing::debug!(
291                    "  Replacement key: {} -> {}",
292                    key,
293                    replacement.replacement_expr
294                );
295            }
296            result = collector_result;
297        }
298    } else {
299        tracing::debug!("Module {} not found", module_path);
300    }
301
302    // Cache the result
303    if let Ok(mut cache) = MODULE_CACHE.lock() {
304        cache.insert(module_path.to_string(), result.clone());
305    }
306
307    Ok(result.into())
308}
309
310/// Collect all deprecated functions from imported modules
311pub fn collect_deprecated_from_dependencies(
312    source: &str,
313    module_name: &str,
314    max_depth: i32,
315) -> Result<DependencyCollectionResult> {
316    collect_deprecated_from_dependencies_with_paths(source, module_name, max_depth, &[])
317}
318
319/// Collect all deprecated functions from imported modules with additional search paths
320pub fn collect_deprecated_from_dependencies_with_paths(
321    source: &str,
322    module_name: &str,
323    max_depth: i32,
324    additional_paths: &[String],
325) -> Result<DependencyCollectionResult> {
326    tracing::info!(
327        "Starting recursive collection for module {} with max_depth {}",
328        module_name,
329        max_depth
330    );
331    collect_deprecated_from_dependencies_recursive(
332        source,
333        module_name,
334        max_depth,
335        &mut HashSet::new(),
336        additional_paths,
337    )
338}
339
340/// Internal recursive function that tracks visited modules to avoid cycles
341fn collect_deprecated_from_dependencies_recursive(
342    source: &str,
343    module_name: &str,
344    max_depth: i32,
345    visited_modules: &mut HashSet<String>,
346    additional_paths: &[String],
347) -> Result<DependencyCollectionResult> {
348    let mut result = DependencyCollectionResult::new();
349
350    // Stop if we've reached max depth
351    if max_depth <= 0 {
352        return Ok(result);
353    }
354
355    // Get imports from source
356    let imports = collect_imports_from_source(source, module_name)?;
357    tracing::info!("Found {} imports in source", imports.len());
358    for imp in &imports {
359        tracing::info!("  Import: {:?}", imp);
360    }
361
362    // Group imports by resolved module path
363    let mut module_imports: HashMap<String, Vec<ImportInfo>> = HashMap::new();
364
365    for imp in imports {
366        if let Some(resolved) = resolve_module_path(&imp.module, Some(module_name)) {
367            module_imports.entry(resolved).or_default().push(imp);
368        }
369    }
370
371    // Process each unique module
372    for (resolved, imp_list) in module_imports {
373        // Skip if we've already visited this module (avoid cycles)
374        if visited_modules.contains(&resolved) {
375            tracing::debug!("Skipping already visited module: {}", resolved);
376            continue;
377        }
378        tracing::debug!("Processing module: {} at depth {}", resolved, max_depth);
379        visited_modules.insert(resolved.clone());
380
381        // Collect from this module
382        tracing::debug!("Attempting to collect from module: {}", resolved);
383        if let Ok(module_result) =
384            collect_deprecated_from_module_with_paths(&resolved, additional_paths)
385        {
386            tracing::debug!(
387                "Module {} has {} replacements",
388                resolved,
389                module_result.replacements.len()
390            );
391            tracing::info!(
392                "Module {} has {} replacements and inheritance map: {:?}",
393                resolved,
394                module_result.replacements.len(),
395                module_result.inheritance_map
396            );
397            // Extend the inheritance map efficiently
398            for (key, value) in &module_result.inheritance_map {
399                result.inheritance_map.insert(key.clone(), value.clone());
400            }
401
402            // Collect all imported names
403            let mut all_imported_names = HashSet::new();
404            let mut has_star_import = false;
405
406            for imp in &imp_list {
407                for (name, _alias) in &imp.names {
408                    if name == "*" {
409                        has_star_import = true;
410                    } else {
411                        all_imported_names.insert(name.clone());
412                    }
413                }
414                if imp.names.is_empty() {
415                    // Import entire module
416                    has_star_import = true;
417                }
418            }
419
420            // Filter replacements based on imported names
421            if has_star_import {
422                // Include all replacements
423                tracing::debug!(
424                    "Star import from {}, including all {} replacements",
425                    resolved,
426                    module_result.replacements.len()
427                );
428                result
429                    .replacements
430                    .extend(module_result.replacements.clone());
431
432                // Also process all classes from the module for inheritance
433                for class_path in module_result.replacements.keys() {
434                    if let Some(class_name) = class_path.split('.').nth(1) {
435                        let full_class_path = format!("{}.{}", resolved, class_name);
436
437                        // Get inheritance chain for this class
438                        let inheritance_chain = get_inheritance_chain_for_class(
439                            &full_class_path,
440                            &module_result.inheritance_map,
441                        );
442
443                        for base_class in inheritance_chain {
444                            // Include all methods from base classes
445                            for (repl_path, repl_info) in &module_result.replacements {
446                                if repl_path.starts_with(&format!("{}.", base_class)) {
447                                    result
448                                        .replacements
449                                        .insert(repl_path.clone(), repl_info.clone());
450                                }
451                            }
452                        }
453                    }
454                }
455            } else {
456                // Check each imported name
457                tracing::info!("Checking imported names: {:?}", all_imported_names);
458                for name in &all_imported_names {
459                    let full_path = format!("{}.{}", resolved, name);
460                    tracing::debug!(
461                        "Checking imported name '{}', full_path: '{}'  with replacements: {:?}",
462                        name,
463                        full_path,
464                        module_result.replacements.keys().collect::<Vec<_>>()
465                    );
466
467                    // Check all replacements
468                    for (repl_path, repl_info) in &module_result.replacements {
469                        if repl_path == &full_path
470                            || repl_path.starts_with(&format!("{}.", full_path))
471                        {
472                            result
473                                .replacements
474                                .insert(repl_path.clone(), repl_info.clone());
475                        }
476                    }
477
478                    // Check inherited methods
479                    if !module_result.inheritance_map.is_empty() {
480                        let inheritance_chain = get_inheritance_chain_for_class(
481                            &full_path,
482                            &module_result.inheritance_map,
483                        );
484                        tracing::debug!(
485                            "Inheritance chain for {}: {:?}",
486                            full_path,
487                            inheritance_chain
488                        );
489
490                        for base_class in inheritance_chain {
491                            // Try both the simple name and the fully qualified name
492                            let qualified_base = format!("{}.{}", resolved, base_class);
493
494                            for (repl_path, repl_info) in &module_result.replacements {
495                                if repl_path.starts_with(&format!("{}.", base_class))
496                                    || repl_path.starts_with(&format!("{}.", qualified_base))
497                                {
498                                    tracing::debug!(
499                                        "Including inherited replacement: {}",
500                                        repl_path
501                                    );
502                                    result
503                                        .replacements
504                                        .insert(repl_path.clone(), repl_info.clone());
505                                }
506                            }
507                        }
508                    }
509
510                    // Check submodules
511                    let submodule_path = format!("{}.{}", resolved, name);
512                    if let Ok(submodule_result) =
513                        collect_deprecated_from_module_with_paths(&submodule_path, additional_paths)
514                    {
515                        if !submodule_result.replacements.is_empty() {
516                            result.update(&submodule_result);
517                        }
518                    }
519                }
520            }
521
522            // Always update class_methods from this module
523            result
524                .class_methods
525                .extend(module_result.class_methods.clone());
526
527            // If max_depth > 1, recursively process dependencies of this module
528            if max_depth > 1 {
529                // Read the module's source to find its imports
530                if let Some(module_file) = find_module_file_with_paths(&resolved, additional_paths)
531                {
532                    if let Ok(module_source) = fs::read_to_string(&module_file) {
533                        tracing::debug!(
534                            "Recursively processing imports from {} (depth {})",
535                            resolved,
536                            max_depth - 1
537                        );
538                        if let Ok(dep_result) = collect_deprecated_from_dependencies_recursive(
539                            &module_source,
540                            &resolved,
541                            max_depth - 1,
542                            visited_modules,
543                            additional_paths,
544                        ) {
545                            result.update(&dep_result);
546                        }
547                    }
548                }
549            }
550        }
551    }
552
553    Ok(result)
554}
555
556/// Scan a file and collect deprecated functions from it and its dependencies
557pub fn scan_file_with_dependencies(
558    file_path: &str,
559    module_name: &str,
560) -> Result<HashMap<String, ReplaceInfo>> {
561    let mut all_replacements = HashMap::new();
562
563    // Read the source file
564    let source = fs::read_to_string(file_path)
565        .with_context(|| format!("Failed to read file: {}", file_path))?;
566
567    // First collect from the file itself using Ruff
568    let collector =
569        RuffDeprecatedFunctionCollector::new(module_name.to_string(), Some(Path::new(&file_path)));
570    if let Ok(result) = collector.collect_from_source(source.clone()) {
571        all_replacements.extend(result.replacements);
572    }
573
574    // Then collect from dependencies with proper recursion depth
575    if let Ok(dep_result) = collect_deprecated_from_dependencies(&source, module_name, 5) {
576        all_replacements.extend(dep_result.replacements);
577    }
578
579    Ok(all_replacements)
580}
581
// Unit tests for the pure helpers of this module plus a few smoke tests of the
// collection entry points.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute module names are returned unchanged, with or without a base module.
    #[test]
    fn test_resolve_module_path_absolute() {
        assert_eq!(
            resolve_module_path("os.path", None),
            Some("os.path".to_string())
        );
        assert_eq!(
            resolve_module_path("dulwich.repo", None),
            Some("dulwich.repo".to_string())
        );
    }

    // Each leading dot strips one trailing component from the base module.
    #[test]
    fn test_resolve_module_path_relative() {
        // Test single-level relative import
        assert_eq!(
            resolve_module_path(".sibling", Some("package.module")),
            Some("package.sibling".to_string())
        );

        // Test two-level relative import
        assert_eq!(
            resolve_module_path("..parent", Some("package.sub.module")),
            Some("package.parent".to_string())
        );

        // Test relative import consisting only of dots (no module name after them)
        assert_eq!(
            resolve_module_path("..", Some("package.sub.module")),
            Some("package".to_string())
        );

        // Test relative import that goes too far up
        assert_eq!(
            resolve_module_path("...toomuch", Some("package.module")),
            None
        );
    }

    // The pre-filter is a plain substring check for "replace_me".
    #[test]
    fn test_might_contain_replace_me() {
        assert!(might_contain_replace_me("@replace_me\ndef foo(): pass"));
        assert!(might_contain_replace_me("from dissolve import replace_me"));
        assert!(!might_contain_replace_me("def regular_function(): pass"));
    }

    // A linear hierarchy yields every ancestor of the starting class.
    #[test]
    fn test_get_inheritance_chain() {
        let mut inheritance_map = HashMap::new();
        inheritance_map.insert("Child".to_string(), vec!["Parent".to_string()]);
        inheritance_map.insert("Parent".to_string(), vec!["GrandParent".to_string()]);
        inheritance_map.insert(
            "GrandParent".to_string(),
            vec!["GreatGrandParent".to_string()],
        );

        let chain = get_inheritance_chain_for_class("Child", &inheritance_map);
        assert_eq!(chain.len(), 3);
        assert!(chain.contains(&"Parent".to_string()));
        assert!(chain.contains(&"GrandParent".to_string()));
        assert!(chain.contains(&"GreatGrandParent".to_string()));
    }

    // Diamond-shaped hierarchy: GrandParent is reachable through both parents.
    #[test]
    fn test_get_inheritance_chain_multiple_inheritance() {
        let mut inheritance_map = HashMap::new();
        inheritance_map.insert(
            "Child".to_string(),
            vec!["Parent1".to_string(), "Parent2".to_string()],
        );
        inheritance_map.insert("Parent1".to_string(), vec!["GrandParent".to_string()]);
        inheritance_map.insert("Parent2".to_string(), vec!["GrandParent".to_string()]);

        let chain = get_inheritance_chain_for_class("Child", &inheritance_map);
        assert!(chain.contains(&"Parent1".to_string()));
        assert!(chain.contains(&"Parent2".to_string()));
        assert!(chain.contains(&"GrandParent".to_string()));
        // Only membership is asserted: this test does not pin how many times the
        // shared GrandParent appears in the chain.
    }

    // Import statements are reported in source order, one entry per imported module.
    #[test]
    fn test_collect_imports_from_source() {
        let source = r#"
import os
from sys import path
from ..relative import something
from . import sibling
import multiple, imports, together
"#;

        let imports = collect_imports_from_source(source, "test_module").unwrap();
        assert_eq!(imports.len(), 7); // os, sys, ..relative, ., and one entry each for multiple, imports, together

        // Check first import
        assert_eq!(imports[0].module, "os");
        assert_eq!(imports[0].names.len(), 1); // A plain `import x` entry carries the module name itself in `names`
        assert_eq!(imports[0].names[0], ("os".to_string(), None));

        // Check from import
        assert_eq!(imports[1].module, "sys");
        assert_eq!(imports[1].names, vec![("path".to_string(), None)]);

        // Check relative imports (leading dots are kept in `module`)
        assert_eq!(imports[2].module, "..relative");
        assert_eq!(imports[2].names, vec![("something".to_string(), None)]);

        assert_eq!(imports[3].module, ".");
        assert_eq!(imports[3].names, vec![("sibling".to_string(), None)]);

        // Check multiple imports on one line: each module gets its own entry
        assert_eq!(imports[4].module, "multiple");
        assert_eq!(imports[4].names.len(), 1);
        assert_eq!(imports[4].names[0], ("multiple".to_string(), None));

        assert_eq!(imports[5].module, "imports");
        assert_eq!(imports[5].names.len(), 1);
        assert_eq!(imports[5].names[0], ("imports".to_string(), None));

        assert_eq!(imports[6].module, "together");
        assert_eq!(imports[6].names.len(), 1);
        assert_eq!(imports[6].names[0], ("together".to_string(), None));
    }

    // A module that cannot be found yields an empty result rather than an error.
    #[test]
    fn test_empty_module_cache() {
        clear_module_cache();

        // Cache should work without errors
        let result = collect_deprecated_from_module("nonexistent.module").unwrap();
        assert!(result.replacements.is_empty());
    }

    #[test]
    fn test_max_depth_zero() {
        // max_depth = 0 should return empty results
        let source = "import os";
        let result = collect_deprecated_from_dependencies(source, "test_module", 0).unwrap();
        assert!(result.replacements.is_empty());
    }

    #[test]
    fn test_visited_modules_cycle_prevention() {
        // NOTE: this only exercises `HashSet` membership; it does not call the
        // recursive collector. A real cycle test would need mock modules on disk.
        let mut visited = HashSet::new();
        visited.insert("module_a".to_string());

        // If module_a imports module_b and module_b imports module_a,
        // we should skip module_a when processing module_b's imports
        assert!(visited.contains("module_a"));
    }
}