1use crate::db::{IndexDb, index_db_path};
6use crate::project::ProjectRoot;
7use crate::project::is_excluded;
8use crate::symbols::language_for_path;
9use anyhow::Result;
10use serde::Serialize;
11use std::fs;
12use tree_sitter::{Node, Parser};
13use walkdir::WalkDir;
14
/// One identifier occurrence of the searched symbol, with its position and enclosing scope.
#[derive(Debug, Clone, Serialize)]
pub struct ScopedReference {
    /// File path of the occurrence, exactly as it was passed to the per-file search.
    pub file_path: String,
    /// 1-based line number (tree-sitter row + 1).
    pub line: usize,
    /// 1-based start column (tree-sitter start column + 1).
    pub column: usize,
    /// Tree-sitter end column + 1.
    pub end_column: usize,
    /// How the occurrence uses the symbol (definition, read, write or import).
    pub kind: ReferenceKind,
    /// Names of the enclosing scopes joined with `.`, outermost first.
    pub scope: String,
    /// Text of the line containing the occurrence, with surrounding whitespace trimmed.
    pub line_content: String,
}
27
/// How an occurrence uses the symbol it names.
///
/// Serialized in `snake_case`, e.g. `"definition"`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReferenceKind {
    /// The occurrence sits in a declaration or parameter position.
    Definition,
    /// Fallback for every occurrence that is not classified as another kind.
    Read,
    /// The occurrence is, or is nested inside, the first child of an assignment-like node.
    Write,
    /// The occurrence has an ancestor whose node kind contains `import`.
    Import,
}
37
/// Tree-sitter node kinds that open a new naming scope.
///
/// `collect_references` pushes a scope name while it visits a node of one of these kinds
/// and pops it again afterwards. Each kind is listed exactly once; some kinds (such as
/// `function_definition`) are presumably shared by more than one grammar.
const SCOPE_NODES: &[&str] = &[
    "function_definition",
    "class_definition",
    "lambda",
    "function_declaration",
    "method_definition",
    "arrow_function",
    "class_declaration",
    "method_declaration",
    "constructor_declaration",
    "class_body",
    "func_literal",
    "function_item",
    "impl_item",
    "closure_expression",
    // Root nodes, so that top-level references still get a scope entry.
    "module",
    "program",
];
69
/// Parent node kinds under which an identifier may be a definition.
///
/// `classify_reference` reports `Definition` for an identifier whose parent is one of
/// these kinds when the parent is a parameter context or the identifier is the parent's
/// name child. Each kind is listed exactly once.
const DEFINITION_PARENTS: &[&str] = &[
    "function_definition",
    "class_definition",
    "parameters",
    "default_parameter",
    "typed_parameter",
    "typed_default_parameter",
    "for_statement",
    "as_pattern",
    "function_declaration",
    "class_declaration",
    "variable_declarator",
    "formal_parameters",
    "required_parameter",
    "optional_parameter",
    "rest_parameter",
    "method_declaration",
    "constructor_declaration",
    "local_variable_declaration",
    "formal_parameter",
    "enhanced_for_statement",
    "short_var_declaration",
    "var_spec",
    "parameter_declaration",
    "range_clause",
    "function_item",
    "let_declaration",
    "parameter",
    "for_expression",
    "declaration",
    "init_declarator",
];
113
/// Parent node kinds that represent an assignment or in-place update.
///
/// `classify_reference` reports `Write` when the identifier is (or lies inside) the first
/// child of a node of one of these kinds. Entries are kept in alphabetical order.
const WRITE_PARENTS: &[&str] = &[
    "assignment",
    "assignment_expression",
    "augmented_assignment",
    "compound_assignment_expr",
    "update_expression",
];
122
/// Node kinds whose whole subtree is skipped by `collect_references`.
///
/// These are comment and string-like kinds, so a symbol name that merely appears in a
/// comment or a string is not reported. Entries are kept in alphabetical order.
const EXCLUDED_NODES: &[&str] = &[
    "block_comment",
    "comment",
    "interpreted_string_literal",
    "line_comment",
    "raw_string_literal",
    "string",
    "string_literal",
    "template_string",
];
134
135pub fn find_scoped_references_in_file(
139 project: &ProjectRoot,
140 file_path: &str,
141 symbol_name: &str,
142 _definition_line: Option<usize>,
143) -> Result<Vec<ScopedReference>> {
144 let resolved = project.resolve(file_path)?;
145 let config = language_for_path(&resolved)
146 .ok_or_else(|| anyhow::anyhow!("unsupported file type: {file_path}"))?;
147 let source = fs::read_to_string(&resolved)?;
148
149 let mut parser = Parser::new();
150 parser.set_language(&config.language)?;
151 let tree = parser
152 .parse(&source, None)
153 .ok_or_else(|| anyhow::anyhow!("failed to parse {file_path}"))?;
154
155 let source_bytes = source.as_bytes();
156 let lines: Vec<&str> = source.lines().collect();
157 let mut results = Vec::new();
158
159 collect_references(
160 tree.root_node(),
161 source_bytes,
162 &lines,
163 symbol_name,
164 file_path,
165 &mut Vec::new(), &mut results,
167 );
168
169 Ok(results)
170}
171
172pub fn find_scoped_references(
174 project: &ProjectRoot,
175 symbol_name: &str,
176 declaration_file: Option<&str>,
177 max_results: usize,
178) -> Result<Vec<ScopedReference>> {
179 let mut all_results = Vec::new();
180
181 let db_path = index_db_path(project.as_path());
183 let indexed_files = IndexDb::open(&db_path)
184 .ok()
185 .and_then(|db| db.all_file_paths().ok())
186 .filter(|paths| !paths.is_empty());
187
188 if let Some(rel_paths) = indexed_files {
189 for rel in &rel_paths {
190 let abs = project.as_path().join(rel);
191 if language_for_path(&abs).is_none() {
192 continue;
193 }
194 match find_scoped_references_in_file(project, rel, symbol_name, None) {
195 Ok(refs) => {
196 for r in refs {
197 all_results.push(r);
198 if all_results.len() >= max_results {
199 return Ok(all_results);
200 }
201 }
202 }
203 Err(_) => continue,
204 }
205 }
206 } else {
207 for entry in WalkDir::new(project.as_path())
209 .into_iter()
210 .filter_entry(|e| !is_excluded(e.path()))
211 {
212 let entry = entry?;
213 if !entry.file_type().is_file() {
214 continue;
215 }
216 if language_for_path(entry.path()).is_none() {
217 continue;
218 }
219 let rel = project.to_relative(entry.path());
220 match find_scoped_references_in_file(project, &rel, symbol_name, None) {
221 Ok(refs) => {
222 for r in refs {
223 all_results.push(r);
224 if all_results.len() >= max_results {
225 return Ok(all_results);
226 }
227 }
228 }
229 Err(_) => continue,
230 }
231 }
232 }
233
234 if let Some(decl_file) = declaration_file {
236 let decl = decl_file.to_string();
237 all_results.sort_by(|a, b| {
238 let a_is_decl = a.file_path == decl;
239 let b_is_decl = b.file_path == decl;
240 b_is_decl
241 .cmp(&a_is_decl)
242 .then(a.file_path.cmp(&b.file_path))
243 .then(a.line.cmp(&b.line))
244 .then(a.column.cmp(&b.column))
245 });
246 }
247
248 Ok(all_results)
249}
250
251fn collect_references(
254 node: Node,
255 source: &[u8],
256 lines: &[&str],
257 target_name: &str,
258 file_path: &str,
259 scope_stack: &mut Vec<String>,
260 results: &mut Vec<ScopedReference>,
261) {
262 let node_type = node.kind();
263
264 if EXCLUDED_NODES.contains(&node_type) {
266 return;
267 }
268
269 let pushed_scope = if SCOPE_NODES.contains(&node_type) {
271 let scope_name = extract_scope_name(node, source);
272 scope_stack.push(scope_name);
273 true
274 } else {
275 false
276 };
277
278 if is_identifier_node(node_type) {
280 let text = node_text(node, source);
281 if text == target_name {
282 let line = node.start_position().row + 1;
283 let column = node.start_position().column + 1;
284 let end_column = node.end_position().column + 1;
285 let kind = classify_reference(node);
286 let scope = scope_stack.join(".");
287 let line_content = lines
288 .get(line - 1)
289 .map(|l| l.trim().to_string())
290 .unwrap_or_default();
291
292 results.push(ScopedReference {
293 file_path: file_path.to_string(),
294 line,
295 column,
296 end_column,
297 kind,
298 scope,
299 line_content,
300 });
301 }
302 }
303
304 let child_count = node.child_count();
306 for i in 0..child_count {
307 if let Some(child) = node.child(i) {
308 collect_references(
309 child,
310 source,
311 lines,
312 target_name,
313 file_path,
314 scope_stack,
315 results,
316 );
317 }
318 }
319
320 if pushed_scope {
322 scope_stack.pop();
323 }
324}
325
/// Returns `true` when `kind` is one of the node kinds treated as an identifier.
fn is_identifier_node(kind: &str) -> bool {
    const IDENTIFIER_KINDS: [&str; 6] = [
        "identifier",
        "type_identifier",
        "field_identifier",
        "property_identifier",
        "shorthand_property_identifier",
        "shorthand_property_identifier_pattern",
    ];
    IDENTIFIER_KINDS.contains(&kind)
}
337
338fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
339 std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
340}
341
342fn extract_scope_name(node: Node, source: &[u8]) -> String {
343 for i in 0..node.child_count() {
345 if let Some(child) = node.child(i) {
346 let kind = child.kind();
347 if kind == "identifier" || kind == "type_identifier" || kind == "name" {
348 return node_text(child, source).to_string();
349 }
350 }
351 }
352 node.kind().to_string()
354}
355
356fn classify_reference(node: Node) -> ReferenceKind {
357 if let Some(parent) = node.parent() {
358 let parent_type = parent.kind();
359
360 if parent_type.contains("import") || is_inside_import(node) {
362 return ReferenceKind::Import;
363 }
364
365 if DEFINITION_PARENTS.contains(&parent_type) {
367 if is_parameter_context(parent) {
369 return ReferenceKind::Definition;
370 }
371 if is_name_child(node, parent) {
373 return ReferenceKind::Definition;
374 }
375 }
376 if let Some(grandparent) = parent.parent() {
378 let _gp_type = grandparent.kind();
379 if is_parameter_context(grandparent) && is_identifier_node(node.kind()) {
380 if parent.kind().contains("parameter") || parent.kind().contains("pattern") {
382 return ReferenceKind::Definition;
383 }
384 }
385 }
386
387 if WRITE_PARENTS.contains(&parent_type) {
389 if let Some(first_child) = parent.child(0)
391 && (first_child.id() == node.id()
392 || (first_child.kind() != "identifier" && contains_node(first_child, node)))
393 {
394 return ReferenceKind::Write;
395 }
396 }
397 }
398
399 ReferenceKind::Read
400}
401
402fn is_name_child(node: Node, parent: Node) -> bool {
403 if let Some(name_node) = parent.child_by_field_name("name") {
406 return name_node.id() == node.id();
407 }
408 for i in 0..parent.child_count() {
410 if let Some(child) = parent.child(i)
411 && is_identifier_node(child.kind())
412 {
413 return child.id() == node.id();
414 }
415 }
416 false
417}
418
419fn is_parameter_context(node: Node) -> bool {
420 let kind = node.kind();
421 matches!(
422 kind,
423 "parameters"
424 | "formal_parameters"
425 | "required_parameter"
426 | "optional_parameter"
427 | "rest_parameter"
428 | "formal_parameter"
429 | "parameter_declaration"
430 | "typed_parameter"
431 | "typed_default_parameter"
432 | "default_parameter"
433 | "parameter"
434 )
435}
436
437fn is_inside_import(node: Node) -> bool {
438 let mut current = node;
439 while let Some(parent) = current.parent() {
440 if parent.kind().contains("import") {
441 return true;
442 }
443 current = parent;
444 }
445 false
446}
447
448fn contains_node(haystack: Node, needle: Node) -> bool {
449 if haystack.id() == needle.id() {
450 return true;
451 }
452 for i in 0..haystack.child_count() {
453 if let Some(child) = haystack.child(i)
454 && contains_node(child, needle)
455 {
456 return true;
457 }
458 }
459 false
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Temporary directory that is removed (best effort) when the guard is dropped, so
    /// test runs do not leave fixture directories behind in the system temp dir.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh directory under the system temp dir.
    ///
    /// The name combines the process id, a timestamp and a per-process counter so that
    /// tests running in parallel never share (and never delete) each other's directory.
    fn temp_project_dir(prefix: &str) -> TempDirGuard {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!(
            "{prefix}-{}-{nanos}-{}",
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        ));
        fs::create_dir_all(&dir).unwrap();
        TempDirGuard(dir)
    }

    /// Builds a two-file Python project: `example.py` with a class and a free function,
    /// and `main.py` that imports and uses the class.
    fn make_fixture() -> (TempDirGuard, ProjectRoot) {
        let guard = temp_project_dir("codelens-scope-fixture");
        // The Python source must keep its indentation: the scope assertions below rely
        // on the methods being nested inside the class.
        fs::write(
            guard.0.join("example.py"),
            r#"class UserService:
    def get_user(self, user_id):
        user = self.db.find(user_id)
        return user

    def delete_user(self, user_id):
        user = self.get_user(user_id)
        self.db.delete(user)

def get_user():
    return "standalone function"
"#,
        )
        .unwrap();
        fs::write(
            guard.0.join("main.py"),
            "from example import UserService\n\nsvc = UserService()\nresult = svc.get_user(1)\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&guard.0).unwrap();
        (guard, project)
    }

    #[test]
    fn finds_references_in_single_file() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user_id", None).unwrap();
        assert!(refs.len() >= 4, "got {} refs", refs.len());
        assert!(
            refs.iter()
                .any(|r| r.kind == ReferenceKind::Definition || r.kind == ReferenceKind::Read),
            "should have at least one definition or read"
        );
    }

    #[test]
    fn classifies_definition_vs_read() {
        let (_dir, project) = make_fixture();
        let refs =
            find_scoped_references_in_file(&project, "example.py", "get_user", None).unwrap();
        let definitions: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Definition)
            .collect();
        let reads: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Read)
            .collect();
        assert!(
            definitions.len() >= 2,
            "expected >= 2 definitions, got {}",
            definitions.len()
        );
        assert!(!reads.is_empty(), "should have reads");
    }

    #[test]
    fn classifies_write() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user", None).unwrap();
        let writes: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Write)
            .collect();
        assert!(
            writes.len() >= 2,
            "expected >= 2 writes, got {}",
            writes.len()
        );
    }

    #[test]
    fn tracks_scope_names() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user_id", None).unwrap();
        let scoped: Vec<_> = refs
            .iter()
            .filter(|r| r.scope.contains("UserService") && r.scope.contains("get_user"))
            .collect();
        assert!(
            !scoped.is_empty(),
            "should track nested scope: {:?}",
            refs.iter().map(|r| &r.scope).collect::<Vec<_>>()
        );
    }

    #[test]
    fn cross_file_search() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references(&project, "UserService", None, 100).unwrap();
        let files: std::collections::HashSet<_> = refs.iter().map(|r| &r.file_path).collect();
        assert!(
            files.len() >= 2,
            "should span multiple files, got: {:?}",
            files
        );
    }

    #[test]
    fn detects_import_reference() {
        let (_dir, project) = make_fixture();
        let refs =
            find_scoped_references_in_file(&project, "main.py", "UserService", None).unwrap();
        let imports: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Import)
            .collect();
        assert!(
            !imports.is_empty(),
            "should detect import of UserService: {:?}",
            refs.iter().map(|r| (&r.kind, r.line)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn excludes_comments_and_strings() {
        let guard = temp_project_dir("codelens-scope-comment");
        fs::write(
            guard.0.join("test.py"),
            "# foo is mentioned in comment\nx = foo\nprint(\"foo in string\")\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&guard.0).unwrap();
        let refs = find_scoped_references_in_file(&project, "test.py", "foo", None).unwrap();
        assert_eq!(
            refs.len(),
            1,
            "should exclude comment/string refs, got: {:?}",
            refs
        );
    }

    #[test]
    fn reference_kind_serialization() {
        assert_eq!(
            serde_json::to_string(&ReferenceKind::Definition).unwrap(),
            "\"definition\""
        );
        assert_eq!(
            serde_json::to_string(&ReferenceKind::Write).unwrap(),
            "\"write\""
        );
    }
}