Skip to main content

codelens_engine/
auto_import.rs

1//! Automatic import suggestion and insertion.
2//!
3//! Detects unresolved symbols in a file and suggests imports from the project's symbol index.
4//! Generates language-appropriate import statements and inserts at the correct position.
5//!
6//! **Internal API.** `add_import` writes through
7//! `apply_full_write_with_evidence` and bypasses ADR-0009 role
8//! gates, audit sinks, and cache invalidation when called outside
9//! `codelens-mcp` dispatch. Route mutations through MCP — see the
10//! crate-level docs in `lib.rs`.
11
12use crate::import_graph::extract_imports_for_file;
13use crate::project::ProjectRoot;
14use crate::symbols::{SymbolIndex, SymbolInfo, get_symbols_overview, language_for_path};
15use anyhow::Result;
16use regex::Regex;
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::fs;
20use std::hash::{Hash, Hasher};
21use std::path::Path;
22use std::sync::{LazyLock, Mutex};
23use tree_sitter::{Node, Parser};
24
25static TYPE_CANDIDATE_RE: LazyLock<Regex> =
26    LazyLock::new(|| Regex::new(r"\b([A-Z][a-zA-Z0-9_]*)\b").unwrap());
27
28const IMPORT_CACHE_CAPACITY: usize = 64;
29
30static IMPORT_ANALYSIS_CACHE: LazyLock<Mutex<HashMap<u64, MissingImportAnalysis>>> =
31    LazyLock::new(|| Mutex::new(HashMap::new()));
32
/// Hash a file identity together with its full contents into a 64-bit cache key.
///
/// Changing either input almost always changes the key (it is a 64-bit hash, so
/// collisions are possible but unlikely).
fn content_cache_key(file_path: &str, content: &str) -> u64 {
    let mut state = std::hash::DefaultHasher::new();
    // A tuple hashes its fields in order, i.e. exactly `file_path` then `content`.
    (file_path, content).hash(&mut state);
    state.finish()
}
39
40#[derive(Debug, Clone, Serialize)]
41pub struct ImportSuggestion {
42    pub symbol_name: String,
43    pub source_file: String,
44    pub import_statement: String,
45    pub insert_line: usize,
46    pub confidence: f64,
47}
48
49#[derive(Debug, Clone, Serialize)]
50pub struct MissingImportAnalysis {
51    pub file_path: String,
52    pub unresolved_symbols: Vec<String>,
53    pub suggestions: Vec<ImportSuggestion>,
54}
55
56/// Analyze a file for potentially unresolved symbols and suggest imports.
57/// Results are cached by (file_path, content_hash) to avoid redundant parsing and lookups.
58pub fn analyze_missing_imports(
59    project: &ProjectRoot,
60    file_path: &str,
61) -> Result<MissingImportAnalysis> {
62    let resolved = project.resolve(file_path)?;
63    let source = fs::read_to_string(&resolved)?;
64    let cache_key = content_cache_key(file_path, &source);
65
66    // Return cached result if file content unchanged
67    if let Ok(cache) = IMPORT_ANALYSIS_CACHE.lock()
68        && let Some(cached) = cache.get(&cache_key)
69    {
70        return Ok(cached.clone());
71    }
72
73    let ext = resolved
74        .extension()
75        .and_then(|e| e.to_str())
76        .unwrap_or("")
77        .to_ascii_lowercase();
78
79    // Step 1: Extract type identifiers via tree-sitter (not regex)
80    let used_types = collect_type_candidates_ast(&resolved, &source)?;
81
82    // Step 2: Collect locally defined symbols
83    let local_symbols: HashSet<String> = get_symbols_overview(project, file_path, 0)?
84        .into_iter()
85        .flat_map(flatten_names)
86        .collect();
87
88    // Step 3: Collect already-imported symbols
89    let existing_imports = extract_existing_import_names(&resolved);
90
91    // Step 4: Find unresolved = used - local - imported - builtins
92    let unresolved: Vec<String> = used_types
93        .into_iter()
94        .filter(|name| !local_symbols.contains(name) && !existing_imports.contains(name))
95        .filter(|name| !is_builtin(name, &ext))
96        .collect();
97
98    // Step 5: Batch lookup via SymbolIndex (SQLite) — much faster than per-name find_symbol
99    let insert_line = find_import_insert_line(&source, &ext);
100    let mut suggestions = Vec::new();
101    let index = SymbolIndex::new(project.clone());
102
103    for name in &unresolved {
104        if let Ok(matches) = index.find_symbol(name, None, false, true, 3) {
105            // Skip if only found in the same file
106            let external: Vec<_> = matches
107                .iter()
108                .filter(|m| m.file_path != file_path)
109                .collect();
110            let best_ref = external.first().copied().or(matches.first());
111            if let Some(best) = best_ref {
112                let import_stmt = generate_import_statement(name, &best.file_path, &ext);
113                suggestions.push(ImportSuggestion {
114                    symbol_name: name.clone(),
115                    source_file: best.file_path.clone(),
116                    import_statement: import_stmt,
117                    insert_line,
118                    confidence: if external.len() == 1 { 0.95 } else { 0.7 },
119                });
120            }
121        }
122    }
123
124    let result = MissingImportAnalysis {
125        file_path: file_path.to_string(),
126        unresolved_symbols: unresolved,
127        suggestions,
128    };
129
130    // Store in cache, evict oldest if at capacity
131    if let Ok(mut cache) = IMPORT_ANALYSIS_CACHE.lock() {
132        if cache.len() >= IMPORT_CACHE_CAPACITY
133            && let Some(&oldest_key) = cache.keys().next()
134        {
135            cache.remove(&oldest_key);
136        }
137        cache.insert(cache_key, result.clone());
138    }
139
140    Ok(result)
141}
142
143/// Add an import statement to a file at the correct position.
144pub fn add_import(
145    project: &ProjectRoot,
146    file_path: &str,
147    import_statement: &str,
148) -> Result<(String, crate::edit_transaction::ApplyEvidence)> {
149    let resolved = project.resolve(file_path)?;
150    let content = fs::read_to_string(&resolved)?;
151    let ext = resolved
152        .extension()
153        .and_then(|e| e.to_str())
154        .unwrap_or("")
155        .to_ascii_lowercase();
156
157    // Check if already imported — synthesize a no-op Applied evidence
158    if content.contains(import_statement.trim()) {
159        let evidence = crate::edit_transaction::ApplyEvidence {
160            status: crate::edit_transaction::ApplyStatus::Applied,
161            file_hashes_before: Default::default(),
162            file_hashes_after: Default::default(),
163            rollback_report: vec![],
164            modified_files: 0,
165            edit_count: 0,
166        };
167        return Ok((content, evidence));
168    }
169
170    let insert_line = find_import_insert_line(&content, &ext);
171    let mut lines: Vec<&str> = content.lines().collect();
172    let insert_idx = (insert_line - 1).min(lines.len());
173    lines.insert(insert_idx, import_statement.trim());
174
175    let mut new_content = lines.join("\n");
176    if content.ends_with('\n') {
177        new_content.push('\n');
178    }
179
180    let relative_path = file_path;
181    let evidence = match crate::edit_transaction::apply_full_write_with_evidence(
182        project,
183        relative_path,
184        &new_content,
185    ) {
186        Ok(ev) => ev,
187        Err(crate::edit_transaction::ApplyError::ApplyFailed {
188            source: _,
189            evidence,
190        }) => evidence,
191        Err(other) => return Err(anyhow::Error::msg(other.to_string())),
192    };
193    Ok((new_content, evidence))
194}
195
196// ── Helpers ──────────────────────────────────────────────────────────────
197
198/// Collect type candidates from AST — only real type identifiers, not strings/comments.
199fn collect_type_candidates_ast(file_path: &Path, source: &str) -> Result<Vec<String>> {
200    let Some(config) = language_for_path(file_path) else {
201        // Fallback to regex for unsupported languages
202        return Ok(collect_type_candidates_regex(source));
203    };
204
205    let mut parser = Parser::new();
206    parser.set_language(&config.language)?;
207    let Some(tree) = parser.parse(source, None) else {
208        return Ok(collect_type_candidates_regex(source));
209    };
210
211    let source_bytes = source.as_bytes();
212    let mut seen = HashSet::new();
213    let mut result = Vec::new();
214    collect_type_nodes(tree.root_node(), source_bytes, &mut seen, &mut result);
215    Ok(result)
216}
217
218fn collect_type_nodes(
219    node: Node,
220    source: &[u8],
221    seen: &mut HashSet<String>,
222    out: &mut Vec<String>,
223) {
224    let kind = node.kind();
225
226    // Skip comments and strings entirely
227    if matches!(
228        kind,
229        "comment"
230            | "line_comment"
231            | "block_comment"
232            | "string"
233            | "string_literal"
234            | "template_string"
235            | "raw_string_literal"
236            | "interpreted_string_literal"
237    ) {
238        return;
239    }
240
241    // Collect type identifiers and uppercase identifiers in type positions
242    if kind == "type_identifier" || kind == "identifier" {
243        let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
244        if !text.is_empty()
245            && text
246                .chars()
247                .next()
248                .map(|c| c.is_uppercase())
249                .unwrap_or(false)
250            && !is_keyword(text)
251            && seen.insert(text.to_string())
252        {
253            out.push(text.to_string());
254        }
255    }
256
257    for i in 0..node.child_count() {
258        if let Some(child) = node.child(i) {
259            collect_type_nodes(child, source, seen, out);
260        }
261    }
262}
263
264/// Regex fallback for unsupported languages.
265fn collect_type_candidates_regex(source: &str) -> Vec<String> {
266    let re = &*TYPE_CANDIDATE_RE;
267    let mut seen = HashSet::new();
268    let mut result = Vec::new();
269    for line in source.lines() {
270        let trimmed = line.trim();
271        if trimmed.starts_with('#') || trimmed.starts_with("//") || trimmed.starts_with("/*") {
272            continue;
273        }
274        for cap in re.find_iter(line) {
275            let name = cap.as_str().to_string();
276            if !is_keyword(&name) && seen.insert(name.clone()) {
277                result.push(name);
278            }
279        }
280    }
281    result
282}
283
284/// Extract names that are already imported.
285fn extract_existing_import_names(path: &Path) -> HashSet<String> {
286    let raw_imports = extract_imports_for_file(path);
287    let mut names = HashSet::new();
288    for imp in &raw_imports {
289        // Extract last segment: "from foo import Bar" → "Bar", "import foo.Bar" → "Bar"
290        if let Some(last) = imp.rsplit('.').next() {
291            names.insert(last.to_string());
292        }
293        // Also try extracting from "from X import Y" patterns
294        if let Some(pos) = imp.find(" import ") {
295            let after = &imp[pos + 8..];
296            for part in after.split(',') {
297                let name = part.trim().split(" as ").next().unwrap_or("").trim();
298                if !name.is_empty() {
299                    names.insert(name.to_string());
300                }
301            }
302        }
303    }
304    names
305}
306
/// Find the 1-based line number at which a new import statement should be inserted.
///
/// The returned value is the line the new statement will occupy, so callers insert at
/// index `line - 1` of the file's line list.
///
/// Rules:
/// - **Top-level only.** Unindented import lines are the only ones considered, so
///   function-local imports (Python) or `use` items inside nested modules (Rust) never
///   pull the insertion point into an indented body.
/// - **Multi-line imports.** An import whose brackets are still open (Go `import (`,
///   Python `from x import (`, TS `import {`, Rust `use a::{`) is followed until its
///   brackets balance, so the new statement lands after the whole group.
/// - **Triple-quoted blocks.** Lines inside `"""`/`'''` blocks are skipped. Only an odd
///   number of delimiters on a line changes state, so a one-line `"""Doc."""` does not
///   hide the rest of the file.
///
/// When the file has no imports, the statement goes right after a `package`/`module`
/// declaration (languages other than Python, JS/TS and Rust). Otherwise it goes after any
/// leading lines that must stay first: a shebang, a Python encoding cookie, or Rust inner
/// attributes / inner doc comments. Failing both, it goes at line 1.
fn find_import_insert_line(source: &str, ext: &str) -> usize {
    // Net bracket balance of a line: `(`/`{` open, `)`/`}` close.
    fn bracket_delta(text: &str) -> i32 {
        text.chars().fold(0, |depth, c| match c {
            '(' | '{' => depth + 1,
            ')' | '}' => depth - 1,
            _ => depth,
        })
    }

    let mut last_import_line = 0;
    let mut in_docstring = false;
    // Bracket depth of a multi-line import statement that has not been closed yet.
    let mut open_import_depth = 0;

    for (i, line) in source.lines().enumerate() {
        let trimmed = line.trim();

        if open_import_depth > 0 {
            open_import_depth += bracket_delta(trimmed);
            last_import_line = i + 1;
            continue;
        }

        let quote_runs = trimmed.matches("\"\"\"").count() + trimmed.matches("'''").count();
        if quote_runs > 0 {
            if quote_runs % 2 == 1 {
                in_docstring = !in_docstring;
            }
            continue;
        }
        if in_docstring {
            continue;
        }

        // Indented imports live inside a function, class, or nested module.
        if line.starts_with(char::is_whitespace) {
            continue;
        }

        let is_import = match ext {
            "py" => trimmed.starts_with("import ") || trimmed.starts_with("from "),
            "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
                trimmed.starts_with("import ") || trimmed.starts_with("import{")
            }
            "java" | "kt" | "kts" | "go" => trimmed.starts_with("import "),
            "rs" => trimmed.starts_with("use ") || trimmed.starts_with("pub use "),
            _ => false,
        };

        if is_import {
            last_import_line = i + 1;
            open_import_depth = bracket_delta(trimmed).max(0);
        }
    }

    if last_import_line > 0 {
        return last_import_line + 1;
    }

    // No imports yet. A `package `/`module ` line is only a file header in languages that
    // have such declarations. In Python, JS/TS and Rust it would be ordinary code (e.g. a
    // Python variable or a TypeScript namespace).
    let has_package_decl = !matches!(
        ext,
        "py" | "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" | "rs"
    );
    let mut header_end = 0;
    let mut in_leading_header = true;
    for (i, line) in source.lines().enumerate() {
        let trimmed = line.trim();
        if has_package_decl
            && (trimmed.starts_with("package ") || trimmed.starts_with("module "))
        {
            // Directly after the declaration.
            return i + 2;
        }
        if !in_leading_header {
            continue;
        }
        // Lines that must precede every import:
        // - a shebang;
        // - a Python encoding cookie (PEP 263: line 1 or 2);
        // - Rust inner attributes and inner doc comments.
        let must_stay_first = (i == 0 && trimmed.starts_with("#!"))
            || (ext == "py"
                && i <= 1
                && trimmed.starts_with('#')
                && (trimmed.contains("coding:") || trimmed.contains("coding=")))
            || (ext == "rs" && (trimmed.starts_with("#![") || trimmed.starts_with("//!")));
        if must_stay_first {
            header_end = i + 1;
        } else if !trimmed.is_empty() && !trimmed.starts_with("//") {
            in_leading_header = false;
        }
    }
    header_end + 1
}
356
/// Generate an import statement in the `target_ext` language for `symbol_name`, which is
/// defined in the project-relative `source_file`.
///
/// Exactly one known source extension is stripped from `source_file` first. Per language:
/// - Python: `from a.b import X`, where `a/b/__init__.py` maps to package `a.b`.
/// - JS/TS: `import { X } from './path';`, with the path relative to the project root.
/// - Java/Kotlin: the dotted file path, after dropping a Maven/Gradle source root such as
///   `src/main/java/`.
/// - Rust: `use crate::module::X;`. The path is relative to the crate's `src/`;
///   `foo/mod.rs` is module `foo`, and `lib.rs`/`main.rs` are the crate root.
/// - Go: imports the package *directory* that contains the file.
/// - Other languages: a comment placeholder.
fn generate_import_statement(symbol_name: &str, source_file: &str, target_ext: &str) -> String {
    // Known source extensions; at most one is removed.
    const SOURCE_EXTENSIONS: [&str; 10] = [
        ".py", ".ts", ".tsx", ".js", ".jsx", ".java", ".kt", ".kts", ".rs", ".go",
    ];

    // Dotted JVM name of a path, after dropping a conventional Maven/Gradle source root.
    fn jvm_qualified_name(stem: &str) -> String {
        const SOURCE_ROOTS: [&str; 4] = [
            "src/main/java/",
            "src/test/java/",
            "src/main/kotlin/",
            "src/test/kotlin/",
        ];
        let package_path = SOURCE_ROOTS
            .iter()
            .find_map(|root| {
                let pos = stem.find(*root)?;
                let at_boundary = pos == 0 || stem[..pos].ends_with('/');
                at_boundary.then_some(&stem[pos + root.len()..])
            })
            .unwrap_or(stem);
        package_path.replace('/', ".")
    }

    // `module::path::Symbol` for a Rust file, relative to its crate's `src/` directory.
    fn rust_item_path(stem: &str, symbol_name: &str) -> String {
        let in_crate = match stem.rfind("/src/") {
            Some(pos) => &stem[pos + "/src/".len()..],
            None => stem.strip_prefix("src/").unwrap_or(stem),
        };
        let mut segments: Vec<&str> = in_crate.split('/').filter(|s| !s.is_empty()).collect();
        if segments.last() == Some(&"mod") {
            segments.pop();
        }
        if matches!(segments.as_slice(), ["lib"] | ["main"]) {
            segments.clear();
        }
        segments.push(symbol_name);
        segments.join("::")
    }

    // Go imports packages (directories), not files.
    fn go_package_path(stem: &str) -> &str {
        stem.rsplit_once('/').map_or(stem, |(dir, _)| dir)
    }

    let stem = SOURCE_EXTENSIONS
        .iter()
        .find_map(|ext| source_file.strip_suffix(*ext))
        .unwrap_or(source_file);

    match target_ext {
        "py" => {
            let dotted = stem.replace('/', ".");
            // A package's `__init__.py` is imported by the package name itself.
            let module = dotted.strip_suffix(".__init__").unwrap_or(dotted.as_str());
            format!("from {module} import {symbol_name}")
        }
        "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
            let rel_path = format!(
                "./{}",
                source_file
                    .trim_end_matches(".ts")
                    .trim_end_matches(".tsx")
                    .trim_end_matches(".js")
            );
            format!("import {{ {} }} from '{}';", symbol_name, rel_path)
        }
        "java" => format!("import {};", jvm_qualified_name(stem)),
        "kt" | "kts" => format!("import {}", jvm_qualified_name(stem)),
        "rs" => format!("use crate::{};", rust_item_path(stem, symbol_name)),
        "go" => format!("import \"{}\"", go_package_path(stem)),
        _ => format!("// import {} from {}", symbol_name, source_file),
    }
}
390
391fn flatten_names(symbol: SymbolInfo) -> Vec<String> {
392    let mut names = vec![symbol.name.clone()];
393    for child in symbol.children {
394        names.extend(flatten_names(child));
395    }
396    names
397}
398
/// Names that are never reported as missing imports in any language.
///
/// The list covers:
/// - language literals (`True`, `None`, `NULL`, ...);
/// - ubiquitous built-in or standard types (`String`, `Vec`, `Promise`, `Exception`, ...);
/// - comment markers (`TODO`, `FIXME`, `HACK`).
fn is_keyword(name: &str) -> bool {
    const ALWAYS_IGNORED: &[&str] = &[
        "True", "False", "None", "Self", "String", "Result", "Option", "Vec", "HashMap",
        "HashSet", "Object", "Array", "Map", "Set", "Promise", "Error", "TypeError",
        "ValueError", "Exception", "RuntimeError", "Boolean", "Integer", "Float", "Double",
        "NULL", "EOF", "TODO", "FIXME", "HACK",
    ];
    ALWAYS_IGNORED.contains(&name)
}
433
434fn is_builtin(name: &str, ext: &str) -> bool {
435    if is_keyword(name) {
436        return true;
437    }
438    match ext {
439        "py" => matches!(
440            name,
441            "int"
442                | "str"
443                | "float"
444                | "bool"
445                | "list"
446                | "dict"
447                | "tuple"
448                | "set"
449                | "Type"
450                | "Optional"
451                | "List"
452                | "Dict"
453                | "Tuple"
454                | "Set"
455                | "Any"
456                | "Union"
457                | "Callable"
458        ),
459        "ts" | "tsx" | "js" | "jsx" => matches!(
460            name,
461            "Date"
462                | "RegExp"
463                | "JSON"
464                | "Math"
465                | "Number"
466                | "Console"
467                | "Window"
468                | "Document"
469                | "Element"
470                | "HTMLElement"
471                | "Event"
472                | "Response"
473                | "Request"
474                | "Partial"
475                | "Readonly"
476                | "Record"
477                | "Pick"
478                | "Omit"
479        ),
480        "java" | "kt" => matches!(
481            name,
482            "System"
483                | "Math"
484                | "Thread"
485                | "Class"
486                | "Comparable"
487                | "Iterable"
488                | "Iterator"
489                | "Override"
490                | "Deprecated"
491                | "Test"
492                | "Suppress"
493        ),
494        "rs" => matches!(
495            name,
496            "Ok" | "Err"
497                | "Some"
498                | "Copy"
499                | "Clone"
500                | "Debug"
501                | "Default"
502                | "Display"
503                | "From"
504                | "Into"
505                | "Send"
506                | "Sync"
507                | "Sized"
508                | "Drop"
509                | "Fn"
510                | "FnMut"
511                | "FnOnce"
512                | "Box"
513                | "Rc"
514                | "Arc"
515                | "Mutex"
516                | "RwLock"
517                | "Pin"
518                | "Serialize"
519                | "Deserialize"
520                | "Regex"
521                | "Path"
522                | "PathBuf"
523                | "File"
524                | "Read"
525                | "Write"
526                | "BufRead"
527                | "BufReader"
528                | "BufWriter"
529                | "WalkDir"
530                | "Context"
531                | "Cow"
532                | "PhantomData"
533                | "ManuallyDrop"
534        ),
535        _ => false,
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scratch directory under the system temp dir. It is removed when dropped, so test
    /// runs do not accumulate fixture trees.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Create an empty, uniquely named directory.
        ///
        /// The name includes the process id and a per-process counter. Parallel tests, and
        /// concurrently running test binaries, therefore never share a directory even if
        /// the clock reading repeats.
        fn new(prefix: &str) -> Self {
            static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let path = std::env::temp_dir().join(format!(
                "{prefix}-{}-{nanos}-{}",
                std::process::id(),
                NEXT_ID.fetch_add(1, Ordering::Relaxed)
            ));
            fs::create_dir_all(&path).unwrap();
            Self { path }
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: failing to clean up must not fail (or mask) a test result.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Two-file Python project: `models.py` defines `UserModel`, and `service.py` uses it
    /// without importing it.
    fn make_fixture() -> (TempDir, ProjectRoot) {
        let dir = TempDir::new("codelens-autoimport");
        fs::create_dir_all(dir.path.join("src")).unwrap();
        fs::write(
            dir.path.join("src/models.py"),
            "class UserModel:\n    def __init__(self, name):\n        self.name = name\n",
        )
        .unwrap();
        fs::write(
            dir.path.join("src/service.py"),
            "class UserService:\n    def get(self):\n        return UserModel()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir.path).unwrap();
        (dir, project)
    }

    #[test]
    fn detects_unresolved_type() {
        let (_dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        assert!(
            result.unresolved_symbols.contains(&"UserModel".to_string()),
            "should detect UserModel as unresolved: {:?}",
            result.unresolved_symbols
        );
    }

    #[test]
    fn suggests_import_for_unresolved() {
        let (_dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        let suggestion = result
            .suggestions
            .iter()
            .find(|s| s.symbol_name == "UserModel");
        assert!(
            suggestion.is_some(),
            "should suggest import for UserModel: {:?}",
            result.suggestions
        );
        let s = suggestion.unwrap();
        assert!(
            s.import_statement.contains("UserModel"),
            "import statement should mention UserModel: {}",
            s.import_statement
        );
        assert!(s.confidence > 0.5);
    }

    #[test]
    fn does_not_suggest_locally_defined() {
        let (_dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/models.py").unwrap();
        assert!(
            !result.unresolved_symbols.contains(&"UserModel".to_string()),
            "locally defined UserModel should not be unresolved"
        );
    }

    #[test]
    fn add_import_inserts_at_correct_position() {
        use crate::edit_transaction::ApplyStatus;

        let dir = TempDir::new("codelens-addimport");
        fs::write(
            dir.path.join("test.py"),
            "import os\nimport sys\n\ndef main():\n    pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir.path).unwrap();
        let (result, evidence) =
            add_import(&project, "test.py", "from models import User").unwrap();
        assert_eq!(
            evidence.status,
            ApplyStatus::Applied,
            "evidence should be Applied"
        );
        let lines: Vec<&str> = result.lines().collect();
        // Should be inserted after existing imports (line 3)
        assert!(
            lines.contains(&"from models import User"),
            "should contain new import: {:?}",
            lines
        );
        let import_idx = lines
            .iter()
            .position(|l| *l == "from models import User")
            .unwrap();
        let sys_idx = lines.iter().position(|l| *l == "import sys").unwrap();
        assert!(
            import_idx > sys_idx,
            "new import should be after existing imports"
        );
    }

    #[test]
    fn skip_already_imported() {
        let dir = TempDir::new("codelens-skipimport");
        fs::write(
            dir.path.join("test.py"),
            "from models import User\n\nx = User()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir.path).unwrap();
        let (result, _evidence) =
            add_import(&project, "test.py", "from models import User").unwrap();
        // Should not duplicate
        assert_eq!(
            result.matches("from models import User").count(),
            1,
            "should not duplicate import"
        );
    }

    #[test]
    fn find_import_insert_line_python() {
        let source = "import os\nimport sys\n\ndef main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 3);
    }

    #[test]
    fn find_import_insert_line_empty() {
        let source = "def main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 1);
    }

    #[test]
    fn generate_python_import() {
        let stmt = generate_import_statement("UserModel", "src/models.py", "py");
        assert_eq!(stmt, "from src.models import UserModel");
    }

    #[test]
    fn generate_typescript_import() {
        let stmt = generate_import_statement("UserService", "src/service.ts", "ts");
        assert!(stmt.contains("import { UserService }"));
        assert!(stmt.contains("'./src/service'"));
    }
}