1use crate::import_graph::extract_imports_for_file;
13use crate::project::ProjectRoot;
14use crate::symbols::{SymbolIndex, SymbolInfo, get_symbols_overview, language_for_path};
15use anyhow::Result;
16use regex::Regex;
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::fs;
20use std::hash::{Hash, Hasher};
21use std::path::Path;
22use std::sync::{LazyLock, Mutex};
23use tree_sitter::{Node, Parser};
24
25static TYPE_CANDIDATE_RE: LazyLock<Regex> =
26 LazyLock::new(|| Regex::new(r"\b([A-Z][a-zA-Z0-9_]*)\b").unwrap());
27
28const IMPORT_CACHE_CAPACITY: usize = 64;
29
30static IMPORT_ANALYSIS_CACHE: LazyLock<Mutex<HashMap<u64, MissingImportAnalysis>>> =
31 LazyLock::new(|| Mutex::new(HashMap::new()));
32
/// Builds the memo-cache key for one analysis by hashing `file_path` and then `content`
/// with the std `DefaultHasher`.
fn content_cache_key(file_path: &str, content: &str) -> u64 {
    let mut state = std::hash::DefaultHasher::new();
    // A tuple hashes its fields in order, exactly like hashing each one separately.
    (file_path, content).hash(&mut state);
    state.finish()
}
39
/// One proposed import for a symbol that `analyze_missing_imports` could not resolve locally.
#[derive(Debug, Clone, Serialize)]
pub struct ImportSuggestion {
    /// The unresolved identifier this suggestion is for.
    pub symbol_name: String,
    /// Project-relative path of the file the symbol index located the definition in.
    pub source_file: String,
    /// Ready-to-insert import line generated for the target file's language.
    pub import_statement: String,
    /// 1-based line number at which the import should be inserted (from
    /// `find_import_insert_line`).
    pub insert_line: usize,
    /// Heuristic score: 0.95 when exactly one matching definition was found outside the
    /// analyzed file, 0.7 otherwise.
    pub confidence: f64,
}
48
/// Result of `analyze_missing_imports` for a single file.
#[derive(Debug, Clone, Serialize)]
pub struct MissingImportAnalysis {
    /// The file path as given by the caller.
    pub file_path: String,
    /// Capitalized identifiers used in the file that are not defined locally, not already
    /// imported, and not on the builtin/keyword lists.
    pub unresolved_symbols: Vec<String>,
    /// At most one suggestion per unresolved symbol that the symbol index could locate.
    pub suggestions: Vec<ImportSuggestion>,
}
55
56pub fn analyze_missing_imports(
59 project: &ProjectRoot,
60 file_path: &str,
61) -> Result<MissingImportAnalysis> {
62 let resolved = project.resolve(file_path)?;
63 let source = fs::read_to_string(&resolved)?;
64 let cache_key = content_cache_key(file_path, &source);
65
66 if let Ok(cache) = IMPORT_ANALYSIS_CACHE.lock()
68 && let Some(cached) = cache.get(&cache_key)
69 {
70 return Ok(cached.clone());
71 }
72
73 let ext = resolved
74 .extension()
75 .and_then(|e| e.to_str())
76 .unwrap_or("")
77 .to_ascii_lowercase();
78
79 let used_types = collect_type_candidates_ast(&resolved, &source)?;
81
82 let local_symbols: HashSet<String> = get_symbols_overview(project, file_path, 0)?
84 .into_iter()
85 .flat_map(flatten_names)
86 .collect();
87
88 let existing_imports = extract_existing_import_names(&resolved);
90
91 let unresolved: Vec<String> = used_types
93 .into_iter()
94 .filter(|name| !local_symbols.contains(name) && !existing_imports.contains(name))
95 .filter(|name| !is_builtin(name, &ext))
96 .collect();
97
98 let insert_line = find_import_insert_line(&source, &ext);
100 let mut suggestions = Vec::new();
101 let index = SymbolIndex::new(project.clone());
102
103 for name in &unresolved {
104 if let Ok(matches) = index.find_symbol(name, None, false, true, 3) {
105 let external: Vec<_> = matches
107 .iter()
108 .filter(|m| m.file_path != file_path)
109 .collect();
110 let best_ref = external.first().copied().or(matches.first());
111 if let Some(best) = best_ref {
112 let import_stmt = generate_import_statement(name, &best.file_path, &ext);
113 suggestions.push(ImportSuggestion {
114 symbol_name: name.clone(),
115 source_file: best.file_path.clone(),
116 import_statement: import_stmt,
117 insert_line,
118 confidence: if external.len() == 1 { 0.95 } else { 0.7 },
119 });
120 }
121 }
122 }
123
124 let result = MissingImportAnalysis {
125 file_path: file_path.to_string(),
126 unresolved_symbols: unresolved,
127 suggestions,
128 };
129
130 if let Ok(mut cache) = IMPORT_ANALYSIS_CACHE.lock() {
132 if cache.len() >= IMPORT_CACHE_CAPACITY
133 && let Some(&oldest_key) = cache.keys().next()
134 {
135 cache.remove(&oldest_key);
136 }
137 cache.insert(cache_key, result.clone());
138 }
139
140 Ok(result)
141}
142
143pub fn add_import(
145 project: &ProjectRoot,
146 file_path: &str,
147 import_statement: &str,
148) -> Result<(String, crate::edit_transaction::ApplyEvidence)> {
149 let resolved = project.resolve(file_path)?;
150 let content = fs::read_to_string(&resolved)?;
151 let ext = resolved
152 .extension()
153 .and_then(|e| e.to_str())
154 .unwrap_or("")
155 .to_ascii_lowercase();
156
157 if content.contains(import_statement.trim()) {
159 let evidence = crate::edit_transaction::ApplyEvidence {
160 status: crate::edit_transaction::ApplyStatus::Applied,
161 file_hashes_before: Default::default(),
162 file_hashes_after: Default::default(),
163 rollback_report: vec![],
164 modified_files: 0,
165 edit_count: 0,
166 };
167 return Ok((content, evidence));
168 }
169
170 let insert_line = find_import_insert_line(&content, &ext);
171 let mut lines: Vec<&str> = content.lines().collect();
172 let insert_idx = (insert_line - 1).min(lines.len());
173 lines.insert(insert_idx, import_statement.trim());
174
175 let mut new_content = lines.join("\n");
176 if content.ends_with('\n') {
177 new_content.push('\n');
178 }
179
180 let relative_path = file_path;
181 let evidence = match crate::edit_transaction::apply_full_write_with_evidence(
182 project,
183 relative_path,
184 &new_content,
185 ) {
186 Ok(ev) => ev,
187 Err(crate::edit_transaction::ApplyError::ApplyFailed {
188 source: _,
189 evidence,
190 }) => evidence,
191 Err(other) => return Err(anyhow::Error::msg(other.to_string())),
192 };
193 Ok((new_content, evidence))
194}
195
196fn collect_type_candidates_ast(file_path: &Path, source: &str) -> Result<Vec<String>> {
200 let Some(config) = language_for_path(file_path) else {
201 return Ok(collect_type_candidates_regex(source));
203 };
204
205 let mut parser = Parser::new();
206 parser.set_language(&config.language)?;
207 let Some(tree) = parser.parse(source, None) else {
208 return Ok(collect_type_candidates_regex(source));
209 };
210
211 let source_bytes = source.as_bytes();
212 let mut seen = HashSet::new();
213 let mut result = Vec::new();
214 collect_type_nodes(tree.root_node(), source_bytes, &mut seen, &mut result);
215 Ok(result)
216}
217
218fn collect_type_nodes(
219 node: Node,
220 source: &[u8],
221 seen: &mut HashSet<String>,
222 out: &mut Vec<String>,
223) {
224 let kind = node.kind();
225
226 if matches!(
228 kind,
229 "comment"
230 | "line_comment"
231 | "block_comment"
232 | "string"
233 | "string_literal"
234 | "template_string"
235 | "raw_string_literal"
236 | "interpreted_string_literal"
237 ) {
238 return;
239 }
240
241 if kind == "type_identifier" || kind == "identifier" {
243 let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
244 if !text.is_empty()
245 && text
246 .chars()
247 .next()
248 .map(|c| c.is_uppercase())
249 .unwrap_or(false)
250 && !is_keyword(text)
251 && seen.insert(text.to_string())
252 {
253 out.push(text.to_string());
254 }
255 }
256
257 for i in 0..node.child_count() {
258 if let Some(child) = node.child(i) {
259 collect_type_nodes(child, source, seen, out);
260 }
261 }
262}
263
264fn collect_type_candidates_regex(source: &str) -> Vec<String> {
266 let re = &*TYPE_CANDIDATE_RE;
267 let mut seen = HashSet::new();
268 let mut result = Vec::new();
269 for line in source.lines() {
270 let trimmed = line.trim();
271 if trimmed.starts_with('#') || trimmed.starts_with("//") || trimmed.starts_with("/*") {
272 continue;
273 }
274 for cap in re.find_iter(line) {
275 let name = cap.as_str().to_string();
276 if !is_keyword(&name) && seen.insert(name.clone()) {
277 result.push(name);
278 }
279 }
280 }
281 result
282}
283
284fn extract_existing_import_names(path: &Path) -> HashSet<String> {
286 let raw_imports = extract_imports_for_file(path);
287 let mut names = HashSet::new();
288 for imp in &raw_imports {
289 if let Some(last) = imp.rsplit('.').next() {
291 names.insert(last.to_string());
292 }
293 if let Some(pos) = imp.find(" import ") {
295 let after = &imp[pos + 8..];
296 for part in after.split(',') {
297 let name = part.trim().split(" as ").next().unwrap_or("").trim();
298 if !name.is_empty() {
299 names.insert(name.to_string());
300 }
301 }
302 }
303 }
304 names
305}
306
/// Returns the 1-based line number at which a new import statement should be inserted into
/// `source`. `ext` is the lowercase file extension; unknown extensions match no imports.
///
/// The position is just after the last import statement that starts at column 0. Imports
/// nested in functions, `mod` blocks, etc. are ignored, so the new import is never placed
/// inside a nested scope.
///
/// An import that leaves brackets open (`use a::{`, Go `import (`, `import {`,
/// `from x import (`) or ends with a `\` continuation is treated as extending through the
/// line that completes it. This keeps the new import out of a multi-line import list.
///
/// Lines inside triple-quoted (`"""` / `'''`) blocks are skipped. A line with an even number
/// of triple quotes, such as a one-line docstring, opens and closes on the same line.
///
/// When the file has no imports, the result is the line after the first run of header lines
/// (`package …`, `module …`, or `#!…` such as a shebang or Rust inner attributes). If there
/// is no header, the result is 1.
fn find_import_insert_line(source: &str, ext: &str) -> usize {
    // Whether `trimmed` begins an import statement in the language identified by `ext`.
    fn starts_import(trimmed: &str, ext: &str) -> bool {
        match ext {
            "py" => trimmed.starts_with("import ") || trimmed.starts_with("from "),
            "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
                trimmed.starts_with("import ") || trimmed.starts_with("import{")
            }
            "java" | "kt" | "kts" => trimmed.starts_with("import "),
            "go" => trimmed.starts_with("import ") || trimmed.starts_with("import("),
            "rs" => trimmed.starts_with("use "),
            _ => false,
        }
    }

    // Net number of brackets opened (positive) or closed (negative) on one line.
    fn bracket_balance(text: &str) -> isize {
        text.chars().fold(0, |balance, c| match c {
            '(' | '{' | '[' => balance + 1,
            ')' | '}' | ']' => balance - 1,
            _ => balance,
        })
    }

    // Lines that must stay above an inserted import when the file has no imports yet.
    fn is_header(trimmed: &str) -> bool {
        trimmed.starts_with("package ")
            || trimmed.starts_with("module ")
            || trimmed.starts_with("#!")
    }

    let mut last_import_line = 0;
    let mut in_docstring = false;
    let mut open_brackets: isize = 0;
    let mut continues_import = false;

    for (i, line) in source.lines().enumerate() {
        let trimmed = line.trim();

        if continues_import {
            open_brackets += bracket_balance(trimmed);
            last_import_line = i + 1;
            continues_import = open_brackets > 0 || trimmed.ends_with('\\');
            continue;
        }

        let triple_quotes = trimmed.matches("\"\"\"").count() + trimmed.matches("'''").count();
        if triple_quotes > 0 {
            if triple_quotes % 2 == 1 {
                in_docstring = !in_docstring;
            }
            continue;
        }
        if in_docstring {
            continue;
        }

        let at_column_zero = !line.starts_with(char::is_whitespace);
        if at_column_zero && starts_import(trimmed, ext) {
            last_import_line = i + 1;
            open_brackets = bracket_balance(trimmed);
            continues_import = open_brackets > 0 || trimmed.ends_with('\\');
        }
    }

    if last_import_line > 0 {
        return last_import_line + 1;
    }

    let lines: Vec<&str> = source.lines().collect();
    match lines.iter().position(|line| is_header(line.trim())) {
        Some(first) => {
            // Skip the whole run, e.g. several consecutive `#![...]` inner attributes.
            let run = lines[first..]
                .iter()
                .take_while(|line| is_header(line.trim()))
                .count();
            first + run + 1
        }
        None => 1,
    }
}
356
/// Builds an import statement that brings `symbol_name`, defined in the project-relative
/// `source_file`, into a file whose lowercase extension is `target_ext`.
///
/// The module path is derived from `source_file` with one recognised source extension
/// removed:
/// - Python: dotted path (`src/models.py` -> `from src.models import X`); a package's
///   `__init__.py` maps to the package itself.
/// - JS/TS: `./`-prefixed path relative to the project root. This is not relative to the
///   importing file, because that file is not known here.
/// - Java/Kotlin: dotted path of the file (the file stem is assumed to name the class).
/// - Rust: `use crate::<modules>::X;`, with modules taken after the last `src/` component.
///   `mod.rs` stands for its directory, and a top-level `lib.rs`/`main.rs` is the crate root.
/// - Go: the directory containing the file, since Go imports packages, not files. The
///   `go.mod` module prefix is not known here.
/// - anything else: a commented-out placeholder.
fn generate_import_statement(symbol_name: &str, source_file: &str, target_ext: &str) -> String {
    const SOURCE_EXTENSIONS: [&str; 9] =
        [".py", ".ts", ".tsx", ".js", ".jsx", ".java", ".kt", ".rs", ".go"];

    // Removes a single trailing extension from `extensions`, if one matches.
    fn strip_extension<'a>(path: &'a str, extensions: &[&str]) -> &'a str {
        extensions
            .iter()
            .find_map(|ext| path.strip_suffix(*ext))
            .unwrap_or(path)
    }

    let stem_path = strip_extension(source_file, &SOURCE_EXTENSIONS);
    let dotted = stem_path.replace('/', ".");

    match target_ext {
        "py" => {
            let module = dotted.strip_suffix(".__init__").unwrap_or(dotted.as_str());
            format!("from {module} import {symbol_name}")
        }
        "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
            let rel_path = strip_extension(source_file, &[".ts", ".tsx", ".js", ".jsx"]);
            format!("import {{ {symbol_name} }} from './{rel_path}';")
        }
        "java" => format!("import {dotted};"),
        "kt" | "kts" => format!("import {dotted}"),
        "rs" => {
            let in_crate = match stem_path.rfind("/src/") {
                Some(pos) => &stem_path[pos + "/src/".len()..],
                None => stem_path.strip_prefix("src/").unwrap_or(stem_path),
            };
            let mut segments: Vec<&str> =
                in_crate.split('/').filter(|s| !s.is_empty()).collect();
            if segments.last() == Some(&"mod") {
                segments.pop();
            }
            if segments.len() == 1 && matches!(segments[0], "lib" | "main") {
                segments.clear();
            }
            segments.push(symbol_name);
            format!("use crate::{};", segments.join("::"))
        }
        "go" => {
            let package_dir = source_file
                .rsplit_once('/')
                .map_or(stem_path, |(dir, _)| dir);
            format!("import \"{package_dir}\"")
        }
        _ => format!("// import {} from {}", symbol_name, source_file),
    }
}
390
391fn flatten_names(symbol: SymbolInfo) -> Vec<String> {
392 let mut names = vec![symbol.name.clone()];
393 for child in symbol.children {
394 names.extend(flatten_names(child));
395 }
396 names
397}
398
/// Reports whether `name` is on the fixed, language-independent list of capitalized words
/// that are never import candidates. The list covers literals like `True`/`None`, common
/// builtin types like `String`/`HashMap`/`Promise`, and comment markers like `TODO`.
fn is_keyword(name: &str) -> bool {
    const IGNORED: &[&str] = &[
        "True",
        "False",
        "None",
        "Self",
        "String",
        "Result",
        "Option",
        "Vec",
        "HashMap",
        "HashSet",
        "Object",
        "Array",
        "Map",
        "Set",
        "Promise",
        "Error",
        "TypeError",
        "ValueError",
        "Exception",
        "RuntimeError",
        "Boolean",
        "Integer",
        "Float",
        "Double",
        "NULL",
        "EOF",
        "TODO",
        "FIXME",
        "HACK",
    ];
    IGNORED.contains(&name)
}
433
434fn is_builtin(name: &str, ext: &str) -> bool {
435 if is_keyword(name) {
436 return true;
437 }
438 match ext {
439 "py" => matches!(
440 name,
441 "int"
442 | "str"
443 | "float"
444 | "bool"
445 | "list"
446 | "dict"
447 | "tuple"
448 | "set"
449 | "Type"
450 | "Optional"
451 | "List"
452 | "Dict"
453 | "Tuple"
454 | "Set"
455 | "Any"
456 | "Union"
457 | "Callable"
458 ),
459 "ts" | "tsx" | "js" | "jsx" => matches!(
460 name,
461 "Date"
462 | "RegExp"
463 | "JSON"
464 | "Math"
465 | "Number"
466 | "Console"
467 | "Window"
468 | "Document"
469 | "Element"
470 | "HTMLElement"
471 | "Event"
472 | "Response"
473 | "Request"
474 | "Partial"
475 | "Readonly"
476 | "Record"
477 | "Pick"
478 | "Omit"
479 ),
480 "java" | "kt" => matches!(
481 name,
482 "System"
483 | "Math"
484 | "Thread"
485 | "Class"
486 | "Comparable"
487 | "Iterable"
488 | "Iterator"
489 | "Override"
490 | "Deprecated"
491 | "Test"
492 | "Suppress"
493 ),
494 "rs" => matches!(
495 name,
496 "Ok" | "Err"
497 | "Some"
498 | "Copy"
499 | "Clone"
500 | "Debug"
501 | "Default"
502 | "Display"
503 | "From"
504 | "Into"
505 | "Send"
506 | "Sync"
507 | "Sized"
508 | "Drop"
509 | "Fn"
510 | "FnMut"
511 | "FnOnce"
512 | "Box"
513 | "Rc"
514 | "Arc"
515 | "Mutex"
516 | "RwLock"
517 | "Pin"
518 | "Serialize"
519 | "Deserialize"
520 | "Regex"
521 | "Path"
522 | "PathBuf"
523 | "File"
524 | "Read"
525 | "Write"
526 | "BufRead"
527 | "BufReader"
528 | "BufWriter"
529 | "WalkDir"
530 | "Context"
531 | "Cow"
532 | "PhantomData"
533 | "ManuallyDrop"
534 ),
535 _ => false,
536 }
537}
538
/// Unit tests for missing-import detection, import suggestion, import insertion and
/// statement generation. Each fixture lives in its own temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;

    /// Creates a project with `src/models.py` (defines `UserModel`) and `src/service.py`
    /// (uses `UserModel` without importing it). The returned `TempDir` keeps the directory
    /// alive for the duration of the test.
    fn make_fixture() -> (tempfile::TempDir, std::path::PathBuf, ProjectRoot) {
        let (temp_dir, dir) = crate::test_helpers::make_unique_temp_dir("codelens-autoimport-");
        fs::create_dir_all(dir.join("src")).unwrap();
        fs::write(
            dir.join("src/models.py"),
            "class UserModel:\n    def __init__(self, name):\n        self.name = name\n",
        )
        .unwrap();
        fs::write(
            dir.join("src/service.py"),
            "class UserService:\n    def get(self):\n        return UserModel()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        (temp_dir, dir, project)
    }

    /// A type used but neither defined nor imported is reported as unresolved.
    #[test]
    fn detects_unresolved_type() {
        let (_temp, _dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        assert!(
            result.unresolved_symbols.contains(&"UserModel".to_string()),
            "should detect UserModel as unresolved: {:?}",
            result.unresolved_symbols
        );
    }

    /// The symbol index locates `UserModel` in another file and an import is proposed.
    #[test]
    fn suggests_import_for_unresolved() {
        let (_temp, _dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        let suggestion = result
            .suggestions
            .iter()
            .find(|s| s.symbol_name == "UserModel");
        assert!(
            suggestion.is_some(),
            "should suggest import for UserModel: {:?}",
            result.suggestions
        );
        let s = suggestion.unwrap();
        assert!(
            s.import_statement.contains("UserModel"),
            "import statement should mention UserModel: {}",
            s.import_statement
        );
        assert!(s.confidence > 0.5);
    }

    /// Names defined in the analyzed file itself are never reported.
    #[test]
    fn does_not_suggest_locally_defined() {
        let (_temp, _dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/models.py").unwrap();
        assert!(
            !result.unresolved_symbols.contains(&"UserModel".to_string()),
            "locally defined UserModel should not be unresolved"
        );
    }

    /// The new import is written after the existing import block and reported as Applied.
    #[test]
    fn add_import_inserts_at_correct_position() {
        use crate::edit_transaction::ApplyStatus;

        let (_temp_dir, dir) = crate::test_helpers::make_unique_temp_dir("codelens-addimport-");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("test.py"),
            "import os\nimport sys\n\ndef main():\n    pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        let (result, evidence) =
            add_import(&project, "test.py", "from models import User").unwrap();
        assert_eq!(
            evidence.status,
            ApplyStatus::Applied,
            "evidence should be Applied"
        );
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines.contains(&"from models import User"),
            "should contain new import: {:?}",
            lines
        );
        let import_idx = lines
            .iter()
            .position(|l| *l == "from models import User")
            .unwrap();
        let sys_idx = lines.iter().position(|l| *l == "import sys").unwrap();
        assert!(
            import_idx > sys_idx,
            "new import should be after existing imports"
        );
    }

    /// Adding an import that is already present leaves exactly one copy.
    #[test]
    fn skip_already_imported() {
        let (_temp_dir, dir) = crate::test_helpers::make_unique_temp_dir("codelens-skipimport-");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("test.py"),
            "from models import User\n\nx = User()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        let (result, _evidence) =
            add_import(&project, "test.py", "from models import User").unwrap();
        assert_eq!(
            result.matches("from models import User").count(),
            1,
            "should not duplicate import"
        );
    }

    /// Insertion point is the 1-based line after the last import.
    #[test]
    fn find_import_insert_line_python() {
        let source = "import os\nimport sys\n\ndef main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 3);
    }

    /// With no imports and no header line, insertion is at the top of the file.
    #[test]
    fn find_import_insert_line_empty() {
        let source = "def main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 1);
    }

    /// Python imports use the dotted project-relative module path.
    #[test]
    fn generate_python_import() {
        let stmt = generate_import_statement("UserModel", "src/models.py", "py");
        assert_eq!(stmt, "from src.models import UserModel");
    }

    /// TypeScript imports use a named import with a `./`-prefixed, extension-less path.
    #[test]
    fn generate_typescript_import() {
        let stmt = generate_import_statement("UserService", "src/service.ts", "ts");
        assert!(stmt.contains("import { UserService }"));
        assert!(stmt.contains("'./src/service'"));
    }
}
685}