Skip to main content

codelens_engine/
auto_import.rs

1//! Automatic import suggestion and insertion.
2//!
3//! Detects unresolved symbols in a file and suggests imports from the project's symbol index.
4//! Generates language-appropriate import statements and inserts at the correct position.
5
6use crate::import_graph::extract_imports_for_file;
7use crate::project::ProjectRoot;
8use crate::symbols::{SymbolIndex, SymbolInfo, get_symbols_overview, language_for_path};
9use anyhow::Result;
10use regex::Regex;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use std::fs;
14use std::hash::{Hash, Hasher};
15use std::path::Path;
16use std::sync::{LazyLock, Mutex};
17use tree_sitter::{Node, Parser};
18
/// Matches whole words that start with an ASCII uppercase letter and continue with
/// ASCII letters, digits or underscores. Used by `collect_type_candidates_regex`.
19static TYPE_CANDIDATE_RE: LazyLock<Regex> =
20    LazyLock::new(|| Regex::new(r"\b([A-Z][a-zA-Z0-9_]*)\b").unwrap());
21
/// Number of entries `IMPORT_ANALYSIS_CACHE` may hold before an entry is evicted.
22const IMPORT_CACHE_CAPACITY: usize = 64;
23
/// Process-wide cache of analysis results, keyed by the value of `content_cache_key`.
/// NOTE(review): a `HashMap` has no insertion order, so eviction cannot pick the
/// oldest entry — confirm whether true FIFO/LRU behaviour is required.
24static IMPORT_ANALYSIS_CACHE: LazyLock<Mutex<HashMap<u64, MissingImportAnalysis>>> =
25    LazyLock::new(|| Mutex::new(HashMap::new()));
26
/// Build the cache key for one file: a 64-bit hash of the path followed by the content.
///
/// Both parts are fed to the same `DefaultHasher`, path first, so a change to
/// either one yields a different key.
fn content_cache_key(file_path: &str, content: &str) -> u64 {
    let mut state = std::hash::DefaultHasher::new();
    for part in [file_path, content] {
        part.hash(&mut state);
    }
    state.finish()
}
33
/// One proposed import for a single unresolved symbol.
34#[derive(Debug, Clone, Serialize)]
35pub struct ImportSuggestion {
    /// Name of the unresolved symbol the import is meant to bring into scope.
36    pub symbol_name: String,
    /// `file_path` of the symbol-index match the suggestion is based on.
37    pub source_file: String,
    /// Statement text produced by `generate_import_statement`.
38    pub import_statement: String,
    /// 1-based line for the new statement, as computed by `find_import_insert_line`.
39    pub insert_line: usize,
    /// Heuristic score: 0.95 when exactly one other file matches, otherwise 0.7.
40    pub confidence: f64,
41}
42
/// Result of `analyze_missing_imports` for one file.
43#[derive(Debug, Clone, Serialize)]
44pub struct MissingImportAnalysis {
    /// The path that was passed to `analyze_missing_imports`.
45    pub file_path: String,
    /// Capitalized identifiers used in the file that are not defined locally,
    /// not already imported, and not treated as builtins.
46    pub unresolved_symbols: Vec<String>,
    /// Suggestions for the unresolved symbols that the symbol index could locate.
47    pub suggestions: Vec<ImportSuggestion>,
48}
49
50/// Analyze a file for potentially unresolved symbols and suggest imports.
51/// Results are cached by (file_path, content_hash) to avoid redundant parsing and lookups.
52pub fn analyze_missing_imports(
53    project: &ProjectRoot,
54    file_path: &str,
55) -> Result<MissingImportAnalysis> {
56    let resolved = project.resolve(file_path)?;
57    let source = fs::read_to_string(&resolved)?;
58    let cache_key = content_cache_key(file_path, &source);
59
60    // Return cached result if file content unchanged
61    if let Ok(cache) = IMPORT_ANALYSIS_CACHE.lock()
62        && let Some(cached) = cache.get(&cache_key)
63    {
64        return Ok(cached.clone());
65    }
66
67    let ext = resolved
68        .extension()
69        .and_then(|e| e.to_str())
70        .unwrap_or("")
71        .to_ascii_lowercase();
72
73    // Step 1: Extract type identifiers via tree-sitter (not regex)
74    let used_types = collect_type_candidates_ast(&resolved, &source)?;
75
76    // Step 2: Collect locally defined symbols
77    let local_symbols: HashSet<String> = get_symbols_overview(project, file_path, 0)?
78        .into_iter()
79        .flat_map(flatten_names)
80        .collect();
81
82    // Step 3: Collect already-imported symbols
83    let existing_imports = extract_existing_import_names(&resolved);
84
85    // Step 4: Find unresolved = used - local - imported - builtins
86    let unresolved: Vec<String> = used_types
87        .into_iter()
88        .filter(|name| !local_symbols.contains(name) && !existing_imports.contains(name))
89        .filter(|name| !is_builtin(name, &ext))
90        .collect();
91
92    // Step 5: Batch lookup via SymbolIndex (SQLite) — much faster than per-name find_symbol
93    let insert_line = find_import_insert_line(&source, &ext);
94    let mut suggestions = Vec::new();
95    let index = SymbolIndex::new(project.clone());
96
97    for name in &unresolved {
98        if let Ok(matches) = index.find_symbol(name, None, false, true, 3) {
99            // Skip if only found in the same file
100            let external: Vec<_> = matches
101                .iter()
102                .filter(|m| m.file_path != file_path)
103                .collect();
104            let best_ref = external.first().copied().or(matches.first());
105            if let Some(best) = best_ref {
106                let import_stmt = generate_import_statement(name, &best.file_path, &ext);
107                suggestions.push(ImportSuggestion {
108                    symbol_name: name.clone(),
109                    source_file: best.file_path.clone(),
110                    import_statement: import_stmt,
111                    insert_line,
112                    confidence: if external.len() == 1 { 0.95 } else { 0.7 },
113                });
114            }
115        }
116    }
117
118    let result = MissingImportAnalysis {
119        file_path: file_path.to_string(),
120        unresolved_symbols: unresolved,
121        suggestions,
122    };
123
124    // Store in cache, evict oldest if at capacity
125    if let Ok(mut cache) = IMPORT_ANALYSIS_CACHE.lock() {
126        if cache.len() >= IMPORT_CACHE_CAPACITY
127            && let Some(&oldest_key) = cache.keys().next()
128        {
129            cache.remove(&oldest_key);
130        }
131        cache.insert(cache_key, result.clone());
132    }
133
134    Ok(result)
135}
136
137/// Add an import statement to a file at the correct position.
138pub fn add_import(
139    project: &ProjectRoot,
140    file_path: &str,
141    import_statement: &str,
142) -> Result<String> {
143    let resolved = project.resolve(file_path)?;
144    let content = fs::read_to_string(&resolved)?;
145    let ext = resolved
146        .extension()
147        .and_then(|e| e.to_str())
148        .unwrap_or("")
149        .to_ascii_lowercase();
150
151    // Check if already imported
152    if content.contains(import_statement.trim()) {
153        return Ok(content);
154    }
155
156    let insert_line = find_import_insert_line(&content, &ext);
157    let mut lines: Vec<&str> = content.lines().collect();
158    let insert_idx = (insert_line - 1).min(lines.len());
159    lines.insert(insert_idx, import_statement.trim());
160
161    let mut result = lines.join("\n");
162    if content.ends_with('\n') {
163        result.push('\n');
164    }
165    fs::write(&resolved, &result)?;
166    Ok(result)
167}
168
169// ── Helpers ──────────────────────────────────────────────────────────────
170
171/// Collect type candidates from AST — only real type identifiers, not strings/comments.
172fn collect_type_candidates_ast(file_path: &Path, source: &str) -> Result<Vec<String>> {
173    let Some(config) = language_for_path(file_path) else {
174        // Fallback to regex for unsupported languages
175        return Ok(collect_type_candidates_regex(source));
176    };
177
178    let mut parser = Parser::new();
179    parser.set_language(&config.language)?;
180    let Some(tree) = parser.parse(source, None) else {
181        return Ok(collect_type_candidates_regex(source));
182    };
183
184    let source_bytes = source.as_bytes();
185    let mut seen = HashSet::new();
186    let mut result = Vec::new();
187    collect_type_nodes(tree.root_node(), source_bytes, &mut seen, &mut result);
188    Ok(result)
189}
190
191fn collect_type_nodes(
192    node: Node,
193    source: &[u8],
194    seen: &mut HashSet<String>,
195    out: &mut Vec<String>,
196) {
197    let kind = node.kind();
198
199    // Skip comments and strings entirely
200    if matches!(
201        kind,
202        "comment"
203            | "line_comment"
204            | "block_comment"
205            | "string"
206            | "string_literal"
207            | "template_string"
208            | "raw_string_literal"
209            | "interpreted_string_literal"
210    ) {
211        return;
212    }
213
214    // Collect type identifiers and uppercase identifiers in type positions
215    if kind == "type_identifier" || kind == "identifier" {
216        let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
217        if !text.is_empty()
218            && text
219                .chars()
220                .next()
221                .map(|c| c.is_uppercase())
222                .unwrap_or(false)
223            && !is_keyword(text)
224            && seen.insert(text.to_string())
225        {
226            out.push(text.to_string());
227        }
228    }
229
230    for i in 0..node.child_count() {
231        if let Some(child) = node.child(i) {
232            collect_type_nodes(child, source, seen, out);
233        }
234    }
235}
236
237/// Regex fallback for unsupported languages.
238fn collect_type_candidates_regex(source: &str) -> Vec<String> {
239    let re = &*TYPE_CANDIDATE_RE;
240    let mut seen = HashSet::new();
241    let mut result = Vec::new();
242    for line in source.lines() {
243        let trimmed = line.trim();
244        if trimmed.starts_with('#') || trimmed.starts_with("//") || trimmed.starts_with("/*") {
245            continue;
246        }
247        for cap in re.find_iter(line) {
248            let name = cap.as_str().to_string();
249            if !is_keyword(&name) && seen.insert(name.clone()) {
250                result.push(name);
251            }
252        }
253    }
254    result
255}
256
257/// Extract names that are already imported.
258fn extract_existing_import_names(path: &Path) -> HashSet<String> {
259    let raw_imports = extract_imports_for_file(path);
260    let mut names = HashSet::new();
261    for imp in &raw_imports {
262        // Extract last segment: "from foo import Bar" → "Bar", "import foo.Bar" → "Bar"
263        if let Some(last) = imp.rsplit('.').next() {
264            names.insert(last.to_string());
265        }
266        // Also try extracting from "from X import Y" patterns
267        if let Some(pos) = imp.find(" import ") {
268            let after = &imp[pos + 8..];
269            for part in after.split(',') {
270                let name = part.trim().split(" as ").next().unwrap_or("").trim();
271                if !name.is_empty() {
272                    names.insert(name.to_string());
273                }
274            }
275        }
276    }
277    names
278}
279
/// Find the 1-based line number at which a new import line should be inserted.
///
/// The result is the line directly after the last *top-level* import statement
/// of the language selected by `ext` (lower-case extension without the dot):
///
/// * Indented lines are ignored, so function-local imports and imports inside
///   nested modules never pull the insertion point into a nested scope.
/// * An import that opens `(` or `{` without closing it on the same line is
///   followed to its closing line, so nothing is inserted into the middle of a
///   multi-line import. Brackets are counted textually; if the bracket is never
///   closed, the opening line stays the insertion anchor.
/// * For languages with triple-quoted strings (`py`, `java`, `kt`, `kts`) lines
///   that open, close, or lie inside such a string are ignored. Delimiters are
///   counted per line, so a one-line `"""docstring"""` does not hide the rest of
///   the file.
///
/// Without any import the result is the line after the first `package ` /
/// `module ` declaration or `#!` line, or `1` when there is none.
fn find_import_insert_line(source: &str, ext: &str) -> usize {
    let has_triple_quotes = matches!(ext, "py" | "java" | "kt" | "kts");
    // Delimiter of the triple-quoted string currently open, if any.
    let mut open_quote: Option<&'static str> = None;
    // Brackets opened by a multi-line import that are still waiting to be closed.
    let mut unclosed_brackets: i32 = 0;
    let mut last_import_line = 0;

    for (i, line) in source.lines().enumerate() {
        if unclosed_brackets > 0 {
            unclosed_brackets += bracket_balance(line);
            if unclosed_brackets <= 0 {
                unclosed_brackets = 0;
                last_import_line = i + 1;
            }
            continue;
        }

        if has_triple_quotes && advance_triple_quotes(line, &mut open_quote) {
            continue;
        }

        // Only top-level statements qualify.
        if line.starts_with(char::is_whitespace) {
            continue;
        }

        let trimmed = line.trim();
        let is_import = match ext {
            "py" => trimmed.starts_with("import ") || trimmed.starts_with("from "),
            "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
                trimmed.starts_with("import ") || trimmed.starts_with("import{")
            }
            "java" | "kt" | "kts" => trimmed.starts_with("import "),
            "go" => trimmed.starts_with("import ") || trimmed == "import (",
            "rs" => trimmed.starts_with("use "),
            _ => false,
        };

        if is_import {
            last_import_line = i + 1;
            unclosed_brackets = bracket_balance(line).max(0);
        }
    }

    if last_import_line != 0 {
        return last_import_line + 1;
    }

    // No imports: insert directly below a package/module declaration or `#!` line.
    source
        .lines()
        .position(|line| {
            let trimmed = line.trim();
            trimmed.starts_with("package ")
                || trimmed.starts_with("module ")
                || trimmed.starts_with("#!")
        })
        .map_or(1, |i| i + 2)
}

/// Number of `(` / `{` on `line` minus the number of `)` / `}`.
fn bracket_balance(line: &str) -> i32 {
    line.chars().fold(0, |depth, c| match c {
        '(' | '{' => depth + 1,
        ')' | '}' => depth - 1,
        _ => depth,
    })
}

/// Advance the triple-quote state across `line`; returns true when the line
/// starts inside a triple-quoted string or contains a triple-quote delimiter.
fn advance_triple_quotes(line: &str, open: &mut Option<&'static str>) -> bool {
    const DELIMS: [&str; 2] = ["\"\"\"", "'''"];

    let mut touched = open.is_some();
    let mut rest = line;
    loop {
        let hit: Option<(usize, Option<&'static str>)> = match *open {
            // Inside a string: only the matching delimiter closes it.
            Some(delim) => rest.find(delim).map(|at| (at, None)),
            // Outside: the earliest delimiter on the line opens a string.
            None => DELIMS
                .iter()
                .filter_map(|d| rest.find(*d).map(|at| (at, Some(*d))))
                .min_by_key(|&(at, _)| at),
        };
        let Some((at, next_state)) = hit else {
            break;
        };
        touched = true;
        *open = next_state;
        // Both delimiters are three ASCII bytes long.
        rest = &rest[at + 3..];
    }
    touched
}
329
/// Generate an import statement, in the syntax of the importing file's language,
/// that brings `symbol_name` (defined in `source_file`) into scope.
///
/// * `source_file` – project-relative, `/`-separated path of the defining file.
/// * `target_ext`  – lower-case extension of the importing file, without the dot.
///
/// The module path is derived from the file path alone, so it is a best-effort
/// guess that cannot know language-specific source roots or module prefixes:
///
/// * Python: `from pkg.mod import Name`; a trailing `__init__` is dropped.
/// * JS/TS: `import { Name } from './path';` — the path is relative to the
///   project root, not to the importing file. `.mjs` / `.cjs` extensions are kept.
/// * Java: `import pkg.File;` (the file is assumed to be named after its class).
/// * Kotlin: `import pkg.Name` — imports name the declaration, not the file.
/// * Rust: `use crate::module::Name;` — everything up to the last `src`
///   directory, a trailing `mod`, and a crate-root `lib` / `main` are dropped.
/// * Go: `import "dir"` — imports name the package directory, not the file.
/// * Anything else: a `// import …` comment placeholder.
fn generate_import_statement(symbol_name: &str, source_file: &str, target_ext: &str) -> String {
    const SOURCE_EXTENSIONS: [&str; 9] = [
        ".py", ".tsx", ".ts", ".jsx", ".js", ".java", ".kt", ".rs", ".go",
    ];
    const JS_EXTENSIONS: [&str; 4] = [".tsx", ".ts", ".jsx", ".js"];

    let stem = strip_known_extension(source_file, &SOURCE_EXTENSIONS);
    let dotted = |path: &str| path.replace('/', ".");

    match target_ext {
        "py" => {
            let module = dotted(stem.strip_suffix("/__init__").unwrap_or(stem));
            format!("from {module} import {symbol_name}")
        }
        "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
            let rel_path = strip_known_extension(source_file, &JS_EXTENSIONS);
            format!("import {{ {symbol_name} }} from './{rel_path}';")
        }
        "java" => format!("import {};", dotted(stem)),
        "kt" | "kts" => match stem.rsplit_once('/') {
            Some((package_dir, _file)) => {
                format!("import {}.{symbol_name}", dotted(package_dir))
            }
            None => format!("import {symbol_name}"),
        },
        "rs" => {
            let module = rust_module_path(stem);
            if module.is_empty() {
                format!("use crate::{symbol_name};")
            } else {
                format!("use crate::{module}::{symbol_name};")
            }
        }
        "go" => {
            // A file directly in the root has no directory; keep its stem in that case.
            let package_dir = stem.rsplit_once('/').map_or(stem, |(dir, _file)| dir);
            format!("import \"{package_dir}\"")
        }
        _ => format!("// import {symbol_name} from {source_file}"),
    }
}

/// Remove the first extension from `extensions` that `path` ends with, if any.
fn strip_known_extension<'a>(path: &'a str, extensions: &[&str]) -> &'a str {
    extensions
        .iter()
        .find_map(|ext| path.strip_suffix(*ext))
        .unwrap_or(path)
}

/// Turn an extension-less file path into a `::`-separated module path relative
/// to the crate root (empty for the crate root itself).
fn rust_module_path(stem: &str) -> String {
    let mut parts: Vec<&str> = stem.split('/').filter(|part| !part.is_empty()).collect();
    // Everything up to and including the last `src` directory lies outside the module tree.
    if let Some(src_at) = parts.iter().rposition(|part| *part == "src") {
        parts.drain(..=src_at);
    }
    let is_crate_root = parts.len() == 1 && matches!(parts[0], "lib" | "main");
    if is_crate_root || parts.last().is_some_and(|part| *part == "mod") {
        parts.pop();
    }
    parts.join("::")
}
363
364fn flatten_names(symbol: SymbolInfo) -> Vec<String> {
365    let mut names = vec![symbol.name.clone()];
366    for child in symbol.children {
367        names.extend(flatten_names(child));
368    }
369    names
370}
371
/// True for capitalized words that are never import candidates in any language:
/// literal constants, ubiquitous standard types, and comment markers.
/// The comparison is exact and case-sensitive.
fn is_keyword(name: &str) -> bool {
    const NEVER_IMPORTED: &[&str] = &[
        // Literals and self-references
        "True",
        "False",
        "None",
        "Self",
        // Common standard-library types across the supported languages
        "String",
        "Result",
        "Option",
        "Vec",
        "HashMap",
        "HashSet",
        "Object",
        "Array",
        "Map",
        "Set",
        "Promise",
        "Error",
        "TypeError",
        "ValueError",
        "Exception",
        "RuntimeError",
        "Boolean",
        "Integer",
        "Float",
        "Double",
        // All-caps constants and comment markers
        "NULL",
        "EOF",
        "TODO",
        "FIXME",
        "HACK",
    ];
    NEVER_IMPORTED.contains(&name)
}
406
407fn is_builtin(name: &str, ext: &str) -> bool {
408    if is_keyword(name) {
409        return true;
410    }
411    match ext {
412        "py" => matches!(
413            name,
414            "int"
415                | "str"
416                | "float"
417                | "bool"
418                | "list"
419                | "dict"
420                | "tuple"
421                | "set"
422                | "Type"
423                | "Optional"
424                | "List"
425                | "Dict"
426                | "Tuple"
427                | "Set"
428                | "Any"
429                | "Union"
430                | "Callable"
431        ),
432        "ts" | "tsx" | "js" | "jsx" => matches!(
433            name,
434            "Date"
435                | "RegExp"
436                | "JSON"
437                | "Math"
438                | "Number"
439                | "Console"
440                | "Window"
441                | "Document"
442                | "Element"
443                | "HTMLElement"
444                | "Event"
445                | "Response"
446                | "Request"
447                | "Partial"
448                | "Readonly"
449                | "Record"
450                | "Pick"
451                | "Omit"
452        ),
453        "java" | "kt" => matches!(
454            name,
455            "System"
456                | "Math"
457                | "Thread"
458                | "Class"
459                | "Comparable"
460                | "Iterable"
461                | "Iterator"
462                | "Override"
463                | "Deprecated"
464                | "Test"
465                | "Suppress"
466        ),
467        "rs" => matches!(
468            name,
469            "Ok" | "Err"
470                | "Some"
471                | "Copy"
472                | "Clone"
473                | "Debug"
474                | "Default"
475                | "Display"
476                | "From"
477                | "Into"
478                | "Send"
479                | "Sync"
480                | "Sized"
481                | "Drop"
482                | "Fn"
483                | "FnMut"
484                | "FnOnce"
485                | "Box"
486                | "Rc"
487                | "Arc"
488                | "Mutex"
489                | "RwLock"
490                | "Pin"
491                | "Serialize"
492                | "Deserialize"
493                | "Regex"
494                | "Path"
495                | "PathBuf"
496                | "File"
497                | "Read"
498                | "Write"
499                | "BufRead"
500                | "BufReader"
501                | "BufWriter"
502                | "WalkDir"
503                | "Context"
504                | "Cow"
505                | "PhantomData"
506                | "ManuallyDrop"
507        ),
508        _ => false,
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Scratch directory below the system temp dir that is removed again on drop,
    /// so test runs (including failing ones) do not leave fixtures behind.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Create a fresh directory. The name combines a timestamp, the process
        /// id, and a counter so tests running in parallel cannot collide.
        fn new(prefix: &str) -> Self {
            static NEXT_ID: AtomicU64 = AtomicU64::new(0);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let dir = std::env::temp_dir().join(format!(
                "{prefix}-{nanos}-{}-{}",
                std::process::id(),
                NEXT_ID.fetch_add(1, Ordering::Relaxed)
            ));
            fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        /// Write `content` to `name` inside the scratch directory.
        fn write(&self, name: &str, content: &str) {
            fs::write(self.0.join(name), content).unwrap();
        }

        /// Open the scratch directory as a project.
        fn project(&self) -> ProjectRoot {
            ProjectRoot::new(&self.0).unwrap()
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not turn a passing test into a panic.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Two-file Python project in which `service.py` uses `UserModel` from
    /// `models.py` without importing it.
    fn make_fixture() -> (ScratchDir, ProjectRoot) {
        let scratch = ScratchDir::new("codelens-autoimport");
        fs::create_dir_all(scratch.0.join("src")).unwrap();
        scratch.write(
            "src/models.py",
            "class UserModel:\n    def __init__(self, name):\n        self.name = name\n",
        );
        scratch.write(
            "src/service.py",
            "class UserService:\n    def get(self):\n        return UserModel()\n",
        );
        let project = scratch.project();
        (scratch, project)
    }

    #[test]
    fn detects_unresolved_type() {
        let (_scratch, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        assert!(
            result.unresolved_symbols.contains(&"UserModel".to_string()),
            "should detect UserModel as unresolved: {:?}",
            result.unresolved_symbols
        );
    }

    #[test]
    fn suggests_import_for_unresolved() {
        let (_scratch, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        let suggestion = result
            .suggestions
            .iter()
            .find(|s| s.symbol_name == "UserModel");
        assert!(
            suggestion.is_some(),
            "should suggest import for UserModel: {:?}",
            result.suggestions
        );
        let s = suggestion.unwrap();
        assert!(
            s.import_statement.contains("UserModel"),
            "import statement should mention UserModel: {}",
            s.import_statement
        );
        assert!(s.confidence > 0.5);
    }

    #[test]
    fn does_not_suggest_locally_defined() {
        let (_scratch, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/models.py").unwrap();
        assert!(
            !result.unresolved_symbols.contains(&"UserModel".to_string()),
            "locally defined UserModel should not be unresolved"
        );
    }

    #[test]
    fn add_import_inserts_at_correct_position() {
        let scratch = ScratchDir::new("codelens-addimport");
        scratch.write("test.py", "import os\nimport sys\n\ndef main():\n    pass\n");
        let project = scratch.project();
        let result = add_import(&project, "test.py", "from models import User").unwrap();
        let lines: Vec<&str> = result.lines().collect();
        // The new line must land below the existing import block.
        assert!(
            lines.contains(&"from models import User"),
            "should contain new import: {:?}",
            lines
        );
        let import_idx = lines
            .iter()
            .position(|l| *l == "from models import User")
            .unwrap();
        let sys_idx = lines.iter().position(|l| *l == "import sys").unwrap();
        assert!(
            import_idx > sys_idx,
            "new import should be after existing imports"
        );
    }

    #[test]
    fn skip_already_imported() {
        let scratch = ScratchDir::new("codelens-skipimport");
        scratch.write("test.py", "from models import User\n\nx = User()\n");
        let project = scratch.project();
        let result = add_import(&project, "test.py", "from models import User").unwrap();
        // Should not duplicate
        assert_eq!(
            result.matches("from models import User").count(),
            1,
            "should not duplicate import"
        );
    }

    #[test]
    fn find_import_insert_line_python() {
        let source = "import os\nimport sys\n\ndef main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 3);
    }

    #[test]
    fn find_import_insert_line_empty() {
        let source = "def main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 1);
    }

    #[test]
    fn generate_python_import() {
        let stmt = generate_import_statement("UserModel", "src/models.py", "py");
        assert_eq!(stmt, "from src.models import UserModel");
    }

    #[test]
    fn generate_typescript_import() {
        let stmt = generate_import_statement("UserService", "src/service.ts", "ts");
        assert!(stmt.contains("import { UserService }"));
        assert!(stmt.contains("'./src/service'"));
    }
}
663        let stmt = generate_import_statement("UserService", "src/service.ts", "ts");
664        assert!(stmt.contains("import { UserService }"));
665        assert!(stmt.contains("'./src/service'"));
666    }
667}