1use crate::import_graph::extract_imports_for_file;
7use crate::project::ProjectRoot;
8use crate::symbols::{SymbolIndex, SymbolInfo, get_symbols_overview, language_for_path};
9use anyhow::Result;
10use regex::Regex;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use std::fs;
14use std::hash::{Hash, Hasher};
15use std::path::Path;
16use std::sync::{LazyLock, Mutex};
17use tree_sitter::{Node, Parser};
18
19static TYPE_CANDIDATE_RE: LazyLock<Regex> =
20 LazyLock::new(|| Regex::new(r"\b([A-Z][a-zA-Z0-9_]*)\b").unwrap());
21
22const IMPORT_CACHE_CAPACITY: usize = 64;
23
24static IMPORT_ANALYSIS_CACHE: LazyLock<Mutex<HashMap<u64, MissingImportAnalysis>>> =
25 LazyLock::new(|| Mutex::new(HashMap::new()));
26
/// Builds the cache key for an analysis by feeding the path and then the file
/// content, in that order, into a single `DefaultHasher`.
fn content_cache_key(file_path: &str, content: &str) -> u64 {
    let mut state = std::hash::DefaultHasher::new();
    for part in [file_path, content] {
        part.hash(&mut state);
    }
    state.finish()
}
33
34#[derive(Debug, Clone, Serialize)]
35pub struct ImportSuggestion {
36 pub symbol_name: String,
37 pub source_file: String,
38 pub import_statement: String,
39 pub insert_line: usize,
40 pub confidence: f64,
41}
42
43#[derive(Debug, Clone, Serialize)]
44pub struct MissingImportAnalysis {
45 pub file_path: String,
46 pub unresolved_symbols: Vec<String>,
47 pub suggestions: Vec<ImportSuggestion>,
48}
49
50pub fn analyze_missing_imports(
53 project: &ProjectRoot,
54 file_path: &str,
55) -> Result<MissingImportAnalysis> {
56 let resolved = project.resolve(file_path)?;
57 let source = fs::read_to_string(&resolved)?;
58 let cache_key = content_cache_key(file_path, &source);
59
60 if let Ok(cache) = IMPORT_ANALYSIS_CACHE.lock()
62 && let Some(cached) = cache.get(&cache_key)
63 {
64 return Ok(cached.clone());
65 }
66
67 let ext = resolved
68 .extension()
69 .and_then(|e| e.to_str())
70 .unwrap_or("")
71 .to_ascii_lowercase();
72
73 let used_types = collect_type_candidates_ast(&resolved, &source)?;
75
76 let local_symbols: HashSet<String> = get_symbols_overview(project, file_path, 0)?
78 .into_iter()
79 .flat_map(flatten_names)
80 .collect();
81
82 let existing_imports = extract_existing_import_names(&resolved);
84
85 let unresolved: Vec<String> = used_types
87 .into_iter()
88 .filter(|name| !local_symbols.contains(name) && !existing_imports.contains(name))
89 .filter(|name| !is_builtin(name, &ext))
90 .collect();
91
92 let insert_line = find_import_insert_line(&source, &ext);
94 let mut suggestions = Vec::new();
95 let index = SymbolIndex::new(project.clone());
96
97 for name in &unresolved {
98 if let Ok(matches) = index.find_symbol(name, None, false, true, 3) {
99 let external: Vec<_> = matches
101 .iter()
102 .filter(|m| m.file_path != file_path)
103 .collect();
104 let best_ref = external.first().copied().or(matches.first());
105 if let Some(best) = best_ref {
106 let import_stmt = generate_import_statement(name, &best.file_path, &ext);
107 suggestions.push(ImportSuggestion {
108 symbol_name: name.clone(),
109 source_file: best.file_path.clone(),
110 import_statement: import_stmt,
111 insert_line,
112 confidence: if external.len() == 1 { 0.95 } else { 0.7 },
113 });
114 }
115 }
116 }
117
118 let result = MissingImportAnalysis {
119 file_path: file_path.to_string(),
120 unresolved_symbols: unresolved,
121 suggestions,
122 };
123
124 if let Ok(mut cache) = IMPORT_ANALYSIS_CACHE.lock() {
126 if cache.len() >= IMPORT_CACHE_CAPACITY
127 && let Some(&oldest_key) = cache.keys().next()
128 {
129 cache.remove(&oldest_key);
130 }
131 cache.insert(cache_key, result.clone());
132 }
133
134 Ok(result)
135}
136
137pub fn add_import(
139 project: &ProjectRoot,
140 file_path: &str,
141 import_statement: &str,
142) -> Result<String> {
143 let resolved = project.resolve(file_path)?;
144 let content = fs::read_to_string(&resolved)?;
145 let ext = resolved
146 .extension()
147 .and_then(|e| e.to_str())
148 .unwrap_or("")
149 .to_ascii_lowercase();
150
151 if content.contains(import_statement.trim()) {
153 return Ok(content);
154 }
155
156 let insert_line = find_import_insert_line(&content, &ext);
157 let mut lines: Vec<&str> = content.lines().collect();
158 let insert_idx = (insert_line - 1).min(lines.len());
159 lines.insert(insert_idx, import_statement.trim());
160
161 let mut result = lines.join("\n");
162 if content.ends_with('\n') {
163 result.push('\n');
164 }
165 fs::write(&resolved, &result)?;
166 Ok(result)
167}
168
169fn collect_type_candidates_ast(file_path: &Path, source: &str) -> Result<Vec<String>> {
173 let Some(config) = language_for_path(file_path) else {
174 return Ok(collect_type_candidates_regex(source));
176 };
177
178 let mut parser = Parser::new();
179 parser.set_language(&config.language)?;
180 let Some(tree) = parser.parse(source, None) else {
181 return Ok(collect_type_candidates_regex(source));
182 };
183
184 let source_bytes = source.as_bytes();
185 let mut seen = HashSet::new();
186 let mut result = Vec::new();
187 collect_type_nodes(tree.root_node(), source_bytes, &mut seen, &mut result);
188 Ok(result)
189}
190
191fn collect_type_nodes(
192 node: Node,
193 source: &[u8],
194 seen: &mut HashSet<String>,
195 out: &mut Vec<String>,
196) {
197 let kind = node.kind();
198
199 if matches!(
201 kind,
202 "comment"
203 | "line_comment"
204 | "block_comment"
205 | "string"
206 | "string_literal"
207 | "template_string"
208 | "raw_string_literal"
209 | "interpreted_string_literal"
210 ) {
211 return;
212 }
213
214 if kind == "type_identifier" || kind == "identifier" {
216 let text = std::str::from_utf8(&source[node.byte_range()]).unwrap_or("");
217 if !text.is_empty()
218 && text
219 .chars()
220 .next()
221 .map(|c| c.is_uppercase())
222 .unwrap_or(false)
223 && !is_keyword(text)
224 && seen.insert(text.to_string())
225 {
226 out.push(text.to_string());
227 }
228 }
229
230 for i in 0..node.child_count() {
231 if let Some(child) = node.child(i) {
232 collect_type_nodes(child, source, seen, out);
233 }
234 }
235}
236
237fn collect_type_candidates_regex(source: &str) -> Vec<String> {
239 let re = &*TYPE_CANDIDATE_RE;
240 let mut seen = HashSet::new();
241 let mut result = Vec::new();
242 for line in source.lines() {
243 let trimmed = line.trim();
244 if trimmed.starts_with('#') || trimmed.starts_with("//") || trimmed.starts_with("/*") {
245 continue;
246 }
247 for cap in re.find_iter(line) {
248 let name = cap.as_str().to_string();
249 if !is_keyword(&name) && seen.insert(name.clone()) {
250 result.push(name);
251 }
252 }
253 }
254 result
255}
256
257fn extract_existing_import_names(path: &Path) -> HashSet<String> {
259 let raw_imports = extract_imports_for_file(path);
260 let mut names = HashSet::new();
261 for imp in &raw_imports {
262 if let Some(last) = imp.rsplit('.').next() {
264 names.insert(last.to_string());
265 }
266 if let Some(pos) = imp.find(" import ") {
268 let after = &imp[pos + 8..];
269 for part in after.split(',') {
270 let name = part.trim().split(" as ").next().unwrap_or("").trim();
271 if !name.is_empty() {
272 names.insert(name.to_string());
273 }
274 }
275 }
276 }
277 names
278}
279
/// Returns the 1-based line number at which a new import should be inserted
/// into `source`, given the file extension `ext` (lowercase, without the dot).
///
/// * With existing imports, the result is the line right after the last one.
///   An import that opens a bracket it does not close on the same line (Go
///   `import (`, Python `from x import (`, TS `import {`, Rust `use a::{`) is
///   followed to its closing line so the new statement never lands inside it.
/// * Lines inside a triple-quoted string, and lines containing a triple quote,
///   are ignored. A string opened and closed on the same line does not leave
///   the scanner "inside" a string.
/// * Without any import, the result is the line after the first
///   `package `/`module ` declaration or `#!` line, or `1` if there is none.
fn find_import_insert_line(source: &str, ext: &str) -> usize {
    let mut last_import_line = 0;
    // Delimiter of the triple-quoted string currently open, if any.
    let mut open_quote: Option<&'static str> = None;
    // Unclosed `(`/`{` count of the import statement being continued.
    let mut bracket_depth: i32 = 0;

    for (i, line) in source.lines().enumerate() {
        let trimmed = line.trim();

        if bracket_depth > 0 {
            // Continuation of a multi-line import: the whole span counts.
            bracket_depth = (bracket_depth + bracket_balance(trimmed, ext)).max(0);
            last_import_line = i + 1;
            continue;
        }

        let was_inside_string = open_quote.is_some();
        open_quote = scan_triple_quotes(trimmed, open_quote);
        if was_inside_string || trimmed.contains("\"\"\"") || trimmed.contains("'''") {
            continue;
        }

        let is_import = match ext {
            "py" => trimmed.starts_with("import ") || trimmed.starts_with("from "),
            "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
                trimmed.starts_with("import ") || trimmed.starts_with("import{")
            }
            "java" | "kt" | "kts" => trimmed.starts_with("import "),
            "go" => trimmed.starts_with("import ") || trimmed == "import (",
            "rs" => trimmed.starts_with("use "),
            _ => false,
        };

        if is_import {
            last_import_line = i + 1;
            bracket_depth = bracket_balance(trimmed, ext).max(0);
        }
    }

    if last_import_line == 0 {
        // No imports yet: go right below the first header-like line. `#!`
        // covers shebangs (and Rust inner attributes) without matching every
        // comment that merely contains an exclamation mark.
        return source
            .lines()
            .position(|line| {
                let trimmed = line.trim();
                trimmed.starts_with("package ")
                    || trimmed.starts_with("module ")
                    || trimmed.starts_with("#!")
            })
            .map_or(1, |idx| idx + 2);
    }

    last_import_line + 1
}

/// Tracks triple-quoted string state across one line: given the delimiter that
/// is open at the start of `line` (if any), returns the one still open at its end.
fn scan_triple_quotes(line: &str, mut open: Option<&'static str>) -> Option<&'static str> {
    const DELIMITERS: [&str; 2] = ["\"\"\"", "'''"];
    let mut rest = line;
    loop {
        let hit = match open {
            // Inside a string: only the matching delimiter can close it.
            Some(delimiter) => rest.find(delimiter).map(|pos| (pos, None)),
            // Outside: whichever delimiter comes first opens a string.
            None => DELIMITERS
                .iter()
                .filter_map(|delimiter| rest.find(*delimiter).map(|pos| (pos, Some(*delimiter))))
                .min_by_key(|(pos, _)| *pos),
        };
        let Some((pos, next_state)) = hit else {
            return open;
        };
        // Both delimiters are three ASCII bytes long.
        rest = &rest[pos + 3..];
        open = next_state;
    }
}

/// Net count of opening minus closing `(`/`{` on `line`, ignoring anything
/// after the language's line-comment marker.
fn bracket_balance(line: &str, ext: &str) -> i32 {
    let comment_marker = if ext == "py" { "#" } else { "//" };
    let code = line.split(comment_marker).next().unwrap_or(line);
    code.chars().fold(0, |depth, c| match c {
        '(' | '{' => depth + 1,
        ')' | '}' => depth - 1,
        _ => depth,
    })
}
329
/// Builds an import statement, in the syntax selected by `target_ext`, that
/// brings `symbol_name` from the project-relative `source_file` into scope.
///
/// * `py`: `from a.b import Symbol`; a trailing `/__init__` is dropped so the
///   package itself is named.
/// * JS/TS family: `import { Symbol } from './a/b';` with a `.ts`, `.tsx`,
///   `.js` or `.jsx` suffix removed. The path is relative to the project root,
///   not to the importing file, because the importing file is not an input.
/// * `java` / `kt` / `kts`: the path with `/` turned into `.`; the symbol name
///   is not appended, so this relies on the file being named after the type.
/// * `rs`: `use crate::a::b::Symbol;`. Everything up to and including the last
///   `src/` directory is dropped, as are `mod`, `lib` and `main` file names.
/// * `go`: imports the directory containing the file, since Go imports name
///   packages rather than files. No module prefix is known here, so the path
///   stays project-relative.
/// * anything else: a `// import … from …` comment placeholder.
fn generate_import_statement(symbol_name: &str, source_file: &str, target_ext: &str) -> String {
    /// Removes the first matching suffix from `path`, if any.
    fn strip_known<'a>(path: &'a str, extensions: &[&str]) -> &'a str {
        extensions
            .iter()
            .find_map(|ext| path.strip_suffix(*ext))
            .unwrap_or(path)
    }

    let stem = strip_known(
        source_file,
        &[".py", ".ts", ".tsx", ".js", ".jsx", ".java", ".kt", ".rs", ".go"],
    );

    match target_ext {
        "py" => {
            let package = stem.strip_suffix("/__init__").unwrap_or(stem);
            format!("from {} import {symbol_name}", package.replace('/', "."))
        }
        "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => {
            let module = strip_known(source_file, &[".ts", ".tsx", ".js", ".jsx"]);
            format!("import {{ {symbol_name} }} from './{module}';")
        }
        "java" => format!("import {};", stem.replace('/', ".")),
        "kt" | "kts" => format!("import {}", stem.replace('/', ".")),
        "rs" => {
            // Module paths are rooted at the crate's `src` directory.
            let in_crate = stem
                .rsplit_once("/src/")
                .map(|(_, tail)| tail)
                .or_else(|| stem.strip_prefix("src/"))
                .unwrap_or(stem);
            // `lib.rs`/`main.rs` are the crate root; `x/mod.rs` is module `x`.
            let module_path = match in_crate {
                "lib" | "main" | "mod" => "",
                nested => nested.strip_suffix("/mod").unwrap_or(nested),
            };
            if module_path.is_empty() {
                format!("use crate::{symbol_name};")
            } else {
                format!("use crate::{}::{symbol_name};", module_path.replace('/', "::"))
            }
        }
        "go" => {
            // A file directly in the project root has no directory to name;
            // keep the bare stem in that case.
            let package_dir = stem.rsplit_once('/').map_or(stem, |(dir, _)| dir);
            format!("import \"{package_dir}\"")
        }
        _ => format!("// import {symbol_name} from {source_file}"),
    }
}
363
364fn flatten_names(symbol: SymbolInfo) -> Vec<String> {
365 let mut names = vec![symbol.name.clone()];
366 for child in symbol.children {
367 names.extend(flatten_names(child));
368 }
369 names
370}
371
/// Returns `true` for capitalised names that are never treated as import
/// candidates in any language: literals, ubiquitous standard types and
/// all-caps comment markers. The comparison is exact and case-sensitive.
fn is_keyword(name: &str) -> bool {
    const NON_CANDIDATES: [&str; 29] = [
        "True",
        "False",
        "None",
        "Self",
        "String",
        "Result",
        "Option",
        "Vec",
        "HashMap",
        "HashSet",
        "Object",
        "Array",
        "Map",
        "Set",
        "Promise",
        "Error",
        "TypeError",
        "ValueError",
        "Exception",
        "RuntimeError",
        "Boolean",
        "Integer",
        "Float",
        "Double",
        "NULL",
        "EOF",
        "TODO",
        "FIXME",
        "HACK",
    ];
    NON_CANDIDATES.contains(&name)
}
406
407fn is_builtin(name: &str, ext: &str) -> bool {
408 if is_keyword(name) {
409 return true;
410 }
411 match ext {
412 "py" => matches!(
413 name,
414 "int"
415 | "str"
416 | "float"
417 | "bool"
418 | "list"
419 | "dict"
420 | "tuple"
421 | "set"
422 | "Type"
423 | "Optional"
424 | "List"
425 | "Dict"
426 | "Tuple"
427 | "Set"
428 | "Any"
429 | "Union"
430 | "Callable"
431 ),
432 "ts" | "tsx" | "js" | "jsx" => matches!(
433 name,
434 "Date"
435 | "RegExp"
436 | "JSON"
437 | "Math"
438 | "Number"
439 | "Console"
440 | "Window"
441 | "Document"
442 | "Element"
443 | "HTMLElement"
444 | "Event"
445 | "Response"
446 | "Request"
447 | "Partial"
448 | "Readonly"
449 | "Record"
450 | "Pick"
451 | "Omit"
452 ),
453 "java" | "kt" => matches!(
454 name,
455 "System"
456 | "Math"
457 | "Thread"
458 | "Class"
459 | "Comparable"
460 | "Iterable"
461 | "Iterator"
462 | "Override"
463 | "Deprecated"
464 | "Test"
465 | "Suppress"
466 ),
467 "rs" => matches!(
468 name,
469 "Ok" | "Err"
470 | "Some"
471 | "Copy"
472 | "Clone"
473 | "Debug"
474 | "Default"
475 | "Display"
476 | "From"
477 | "Into"
478 | "Send"
479 | "Sync"
480 | "Sized"
481 | "Drop"
482 | "Fn"
483 | "FnMut"
484 | "FnOnce"
485 | "Box"
486 | "Rc"
487 | "Arc"
488 | "Mutex"
489 | "RwLock"
490 | "Pin"
491 | "Serialize"
492 | "Deserialize"
493 | "Regex"
494 | "Path"
495 | "PathBuf"
496 | "File"
497 | "Read"
498 | "Write"
499 | "BufRead"
500 | "BufReader"
501 | "BufWriter"
502 | "WalkDir"
503 | "Context"
504 | "Cow"
505 | "PhantomData"
506 | "ManuallyDrop"
507 ),
508 _ => false,
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Creates a fresh directory under the system temp dir. The name combines
    /// the process id, the current time and a per-process counter so tests
    /// running in parallel never share (and never delete) each other's files.
    fn scratch_dir(prefix: &str) -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!(
            "{prefix}-{}-{nanos}-{}",
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        ));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Best-effort removal of a scratch directory; failures are ignored.
    fn cleanup(dir: &Path) {
        let _ = fs::remove_dir_all(dir);
    }

    /// Two-file Python project: `models.py` defines `UserModel`, `service.py`
    /// uses it without importing it.
    fn make_fixture() -> (PathBuf, ProjectRoot) {
        let dir = scratch_dir("codelens-autoimport");
        fs::create_dir_all(dir.join("src")).unwrap();
        fs::write(
            dir.join("src/models.py"),
            "class UserModel:\n    def __init__(self, name):\n        self.name = name\n",
        )
        .unwrap();
        fs::write(
            dir.join("src/service.py"),
            "class UserService:\n    def get(self):\n        return UserModel()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        (dir, project)
    }

    #[test]
    fn detects_unresolved_type() {
        let (dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        assert!(
            result.unresolved_symbols.contains(&"UserModel".to_string()),
            "should detect UserModel as unresolved: {:?}",
            result.unresolved_symbols
        );
        cleanup(&dir);
    }

    #[test]
    fn suggests_import_for_unresolved() {
        let (dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/service.py").unwrap();
        let suggestion = result
            .suggestions
            .iter()
            .find(|s| s.symbol_name == "UserModel");
        assert!(
            suggestion.is_some(),
            "should suggest import for UserModel: {:?}",
            result.suggestions
        );
        let s = suggestion.unwrap();
        assert!(
            s.import_statement.contains("UserModel"),
            "import statement should mention UserModel: {}",
            s.import_statement
        );
        assert!(s.confidence > 0.5);
        cleanup(&dir);
    }

    #[test]
    fn does_not_suggest_locally_defined() {
        let (dir, project) = make_fixture();
        let result = analyze_missing_imports(&project, "src/models.py").unwrap();
        assert!(
            !result.unresolved_symbols.contains(&"UserModel".to_string()),
            "locally defined UserModel should not be unresolved"
        );
        cleanup(&dir);
    }

    #[test]
    fn add_import_inserts_at_correct_position() {
        let dir = scratch_dir("codelens-addimport");
        fs::write(
            dir.join("test.py"),
            "import os\nimport sys\n\ndef main():\n    pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        let result = add_import(&project, "test.py", "from models import User").unwrap();
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines.contains(&"from models import User"),
            "should contain new import: {:?}",
            lines
        );
        let import_idx = lines
            .iter()
            .position(|l| *l == "from models import User")
            .unwrap();
        let sys_idx = lines.iter().position(|l| *l == "import sys").unwrap();
        assert!(
            import_idx > sys_idx,
            "new import should be after existing imports"
        );
        cleanup(&dir);
    }

    #[test]
    fn skip_already_imported() {
        let dir = scratch_dir("codelens-skipimport");
        fs::write(
            dir.join("test.py"),
            "from models import User\n\nx = User()\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir).unwrap();
        let result = add_import(&project, "test.py", "from models import User").unwrap();
        assert_eq!(
            result.matches("from models import User").count(),
            1,
            "should not duplicate import"
        );
        cleanup(&dir);
    }

    #[test]
    fn find_import_insert_line_python() {
        let source = "import os\nimport sys\n\ndef main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 3);
    }

    #[test]
    fn find_import_insert_line_empty() {
        let source = "def main():\n    pass\n";
        assert_eq!(find_import_insert_line(source, "py"), 1);
    }

    #[test]
    fn generate_python_import() {
        let stmt = generate_import_statement("UserModel", "src/models.py", "py");
        assert_eq!(stmt, "from src.models import UserModel");
    }

    #[test]
    fn generate_typescript_import() {
        let stmt = generate_import_statement("UserService", "src/service.ts", "ts");
        assert!(stmt.contains("import { UserService }"));
        assert!(stmt.contains("'./src/service'"));
    }
}