Skip to main content

codelens_engine/
scope_analysis.rs

1//! Scope-aware reference analysis using tree-sitter.
2//!
3//! Replaces JetBrains PSI `find_references` with tree-sitter scope resolution.
4
5use crate::db::{IndexDb, index_db_path};
6use crate::project::ProjectRoot;
7use crate::project::is_excluded;
8use crate::symbols::language_for_path;
9use anyhow::Result;
10use serde::Serialize;
11use std::fs;
12use tree_sitter::{Node, Parser};
13use walkdir::WalkDir;
14
15/// A resolved reference with scope context.
16#[derive(Debug, Clone, Serialize)]
17pub struct ScopedReference {
18    pub file_path: String,
19    pub line: usize,
20    pub column: usize,
21    pub end_column: usize,
22    pub kind: ReferenceKind,
23    /// Enclosing scope name (e.g. "UserService.get_user")
24    pub scope: String,
25    pub line_content: String,
26}
27
28/// Classification of how a symbol is referenced.
29#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
30#[serde(rename_all = "snake_case")]
31pub enum ReferenceKind {
32    Definition,
33    Read,
34    Write,
35    Import,
36}
37
38// ── Node type sets for classification ────────────────────────────────────
39
/// AST node kinds that open a new named scope during the reference walk.
///
/// The table is only used for membership tests, so kinds shared by several
/// grammars are listed once.
const SCOPE_NODES: &[&str] = &[
    // Functions and methods
    "function_definition",     // Python, C/C++
    "function_declaration",    // JS/TS, Go
    "method_declaration",      // Java/Kotlin, Go
    "method_definition",       // JS/TS
    "constructor_declaration", // Java/Kotlin
    "function_item",           // Rust
    // Anonymous functions
    "lambda",             // Python
    "arrow_function",     // JS/TS
    "func_literal",       // Go
    "closure_expression", // Rust
    // Types and implementations
    "class_definition",  // Python
    "class_declaration", // JS/TS
    "class_body",        // Java/Kotlin
    "impl_item",         // Rust
    // File roots
    "module",
    "program",
];
69
/// AST node kinds whose identifier children may define (introduce) a name.
///
/// Which child actually counts as the definition is decided by
/// `classify_reference`. Kinds shared across grammars appear once; the table
/// is only used for membership tests.
const DEFINITION_PARENTS: &[&str] = &[
    // Functions, methods and types
    "function_definition",     // Python, C/C++
    "function_declaration",    // JS/TS, Go
    "function_item",           // Rust
    "method_declaration",      // Java/Kotlin, Go
    "constructor_declaration", // Java/Kotlin
    "class_definition",        // Python
    "class_declaration",       // JS/TS
    // Parameters
    "parameters",
    "default_parameter",
    "typed_parameter",
    "typed_default_parameter",
    "formal_parameters",
    "required_parameter",
    "optional_parameter",
    "rest_parameter",
    "formal_parameter",
    "parameter_declaration",
    "parameter",
    // Variables and other bindings
    "variable_declarator",
    "local_variable_declaration",
    "short_var_declaration",
    "var_spec",
    "let_declaration",
    "declaration",
    "init_declarator",
    "as_pattern",
    // Loop variables
    "for_statement",
    "enhanced_for_statement",
    "range_clause",
    "for_expression",
];
113
/// AST node kinds that assign to (write) one of their operands.
///
/// Besides the original Python/JS/C/Java/Rust kinds, this also covers JS/TS
/// compound assignment (`x += 1` parses as `augmented_assignment_expression`)
/// and Go's statement-level assignment and `++`/`--`, which were previously
/// reported as plain reads.
const WRITE_PARENTS: &[&str] = &[
    "assignment",
    "augmented_assignment",
    "assignment_expression",
    "update_expression",
    "compound_assignment_expr",
    // JS/TS `x += 1`
    "augmented_assignment_expression",
    // Go `x = 1`, `x++`, `x--`
    "assignment_statement",
    "inc_statement",
    "dec_statement",
];
122
/// AST node kinds for comments and string literals, which the reference walk
/// treats as non-code.
const EXCLUDED_NODES: &[&str] = &[
    // Comments
    "comment",
    "line_comment",
    "block_comment",
    // String literals
    "string",
    "string_literal",
    "raw_string_literal",
    "interpreted_string_literal",
    "template_string",
];
134
135// ── Public API ───────────────────────────────────────────────────────────
136
137/// Find all scope-aware references to a symbol in a single file.
138pub fn find_scoped_references_in_file(
139    project: &ProjectRoot,
140    file_path: &str,
141    symbol_name: &str,
142    _definition_line: Option<usize>,
143) -> Result<Vec<ScopedReference>> {
144    let resolved = project.resolve(file_path)?;
145    let config = language_for_path(&resolved)
146        .ok_or_else(|| anyhow::anyhow!("unsupported file type: {file_path}"))?;
147    let source = fs::read_to_string(&resolved)?;
148
149    let mut parser = Parser::new();
150    parser.set_language(&config.language)?;
151    let tree = parser
152        .parse(&source, None)
153        .ok_or_else(|| anyhow::anyhow!("failed to parse {file_path}"))?;
154
155    let source_bytes = source.as_bytes();
156    let lines: Vec<&str> = source.lines().collect();
157    let mut results = Vec::new();
158
159    collect_references(
160        tree.root_node(),
161        source_bytes,
162        &lines,
163        symbol_name,
164        file_path,
165        &mut Vec::new(), // scope stack
166        &mut results,
167    );
168
169    Ok(results)
170}
171
172/// Find all scope-aware references across the project.
173pub fn find_scoped_references(
174    project: &ProjectRoot,
175    symbol_name: &str,
176    declaration_file: Option<&str>,
177    max_results: usize,
178) -> Result<Vec<ScopedReference>> {
179    let mut all_results = Vec::new();
180
181    // Try DB-accelerated path: iterate only indexed files
182    let db_path = index_db_path(project.as_path());
183    let indexed_files = IndexDb::open(&db_path)
184        .ok()
185        .and_then(|db| db.all_file_paths().ok())
186        .filter(|paths| !paths.is_empty());
187
188    if let Some(rel_paths) = indexed_files {
189        for rel in &rel_paths {
190            let abs = project.as_path().join(rel);
191            if language_for_path(&abs).is_none() {
192                continue;
193            }
194            match find_scoped_references_in_file(project, rel, symbol_name, None) {
195                Ok(refs) => {
196                    for r in refs {
197                        all_results.push(r);
198                        if all_results.len() >= max_results {
199                            return Ok(all_results);
200                        }
201                    }
202                }
203                Err(_) => continue,
204            }
205        }
206    } else {
207        // Fallback: full walk
208        for entry in WalkDir::new(project.as_path())
209            .into_iter()
210            .filter_entry(|e| !is_excluded(e.path()))
211        {
212            let entry = entry?;
213            if !entry.file_type().is_file() {
214                continue;
215            }
216            if language_for_path(entry.path()).is_none() {
217                continue;
218            }
219            let rel = project.to_relative(entry.path());
220            match find_scoped_references_in_file(project, &rel, symbol_name, None) {
221                Ok(refs) => {
222                    for r in refs {
223                        all_results.push(r);
224                        if all_results.len() >= max_results {
225                            return Ok(all_results);
226                        }
227                    }
228                }
229                Err(_) => continue,
230            }
231        }
232    }
233
234    // Sort: declaration file first, then by file/line
235    if let Some(decl_file) = declaration_file {
236        let decl = decl_file.to_string();
237        all_results.sort_by(|a, b| {
238            let a_is_decl = a.file_path == decl;
239            let b_is_decl = b.file_path == decl;
240            b_is_decl
241                .cmp(&a_is_decl)
242                .then(a.file_path.cmp(&b.file_path))
243                .then(a.line.cmp(&b.line))
244                .then(a.column.cmp(&b.column))
245        });
246    }
247
248    Ok(all_results)
249}
250
251// ── AST traversal ────────────────────────────────────────────────────────
252
253fn collect_references(
254    node: Node,
255    source: &[u8],
256    lines: &[&str],
257    target_name: &str,
258    file_path: &str,
259    scope_stack: &mut Vec<String>,
260    results: &mut Vec<ScopedReference>,
261) {
262    let node_type = node.kind();
263
264    // Skip excluded nodes (comments, strings)
265    if EXCLUDED_NODES.contains(&node_type) {
266        return;
267    }
268
269    // Push scope
270    let pushed_scope = if SCOPE_NODES.contains(&node_type) {
271        let scope_name = extract_scope_name(node, source);
272        scope_stack.push(scope_name);
273        true
274    } else {
275        false
276    };
277
278    // Check if this is an identifier matching our target
279    if is_identifier_node(node_type) {
280        let text = node_text(node, source);
281        if text == target_name {
282            let line = node.start_position().row + 1;
283            let column = node.start_position().column + 1;
284            let end_column = node.end_position().column + 1;
285            let kind = classify_reference(node);
286            let scope = scope_stack.join(".");
287            let line_content = lines
288                .get(line - 1)
289                .map(|l| l.trim().to_string())
290                .unwrap_or_default();
291
292            results.push(ScopedReference {
293                file_path: file_path.to_string(),
294                line,
295                column,
296                end_column,
297                kind,
298                scope,
299                line_content,
300            });
301        }
302    }
303
304    // Recurse into children
305    let child_count = node.child_count();
306    for i in 0..child_count {
307        if let Some(child) = node.child(i) {
308            collect_references(
309                child,
310                source,
311                lines,
312                target_name,
313                file_path,
314                scope_stack,
315                results,
316            );
317        }
318    }
319
320    // Pop scope
321    if pushed_scope {
322        scope_stack.pop();
323    }
324}
325
/// Returns `true` if `kind` is one of the node kinds that can name a symbol
/// occurrence: plain, type, field and property identifiers, including the
/// shorthand property forms.
fn is_identifier_node(kind: &str) -> bool {
    const IDENTIFIER_KINDS: [&str; 6] = [
        "identifier",
        "type_identifier",
        "field_identifier",
        "property_identifier",
        "shorthand_property_identifier",
        "shorthand_property_identifier_pattern",
    ];
    IDENTIFIER_KINDS.contains(&kind)
}
337
338fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
339    std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
340}
341
342fn extract_scope_name(node: Node, source: &[u8]) -> String {
343    // Try to find a name child (identifier)
344    for i in 0..node.child_count() {
345        if let Some(child) = node.child(i) {
346            let kind = child.kind();
347            if kind == "identifier" || kind == "type_identifier" || kind == "name" {
348                return node_text(child, source).to_string();
349            }
350        }
351    }
352    // Fallback to node type
353    node.kind().to_string()
354}
355
356fn classify_reference(node: Node) -> ReferenceKind {
357    if let Some(parent) = node.parent() {
358        let parent_type = parent.kind();
359
360        // Import detection — check parent chain for import nodes
361        if parent_type.contains("import") || is_inside_import(node) {
362            return ReferenceKind::Import;
363        }
364
365        // Definition detection
366        if DEFINITION_PARENTS.contains(&parent_type) {
367            // Parameters: ALL identifier children are definitions
368            if is_parameter_context(parent) {
369                return ReferenceKind::Definition;
370            }
371            // Other definitions: only the "name" child
372            if is_name_child(node, parent) {
373                return ReferenceKind::Definition;
374            }
375        }
376        // Also check grandparent for typed_parameter → identifier patterns
377        if let Some(grandparent) = parent.parent() {
378            let _gp_type = grandparent.kind();
379            if is_parameter_context(grandparent) && is_identifier_node(node.kind()) {
380                // identifier inside a typed_parameter/default_parameter = definition
381                if parent.kind().contains("parameter") || parent.kind().contains("pattern") {
382                    return ReferenceKind::Definition;
383                }
384            }
385        }
386
387        // Write detection
388        if WRITE_PARENTS.contains(&parent_type) {
389            // Left side of assignment
390            if let Some(first_child) = parent.child(0)
391                && (first_child.id() == node.id()
392                    || (first_child.kind() != "identifier" && contains_node(first_child, node)))
393            {
394                return ReferenceKind::Write;
395            }
396        }
397    }
398
399    ReferenceKind::Read
400}
401
402fn is_name_child(node: Node, parent: Node) -> bool {
403    // In most languages, the "name" of a definition is the first identifier child
404    // or a specifically named field
405    if let Some(name_node) = parent.child_by_field_name("name") {
406        return name_node.id() == node.id();
407    }
408    // Fallback: first identifier child
409    for i in 0..parent.child_count() {
410        if let Some(child) = parent.child(i)
411            && is_identifier_node(child.kind())
412        {
413            return child.id() == node.id();
414        }
415    }
416    false
417}
418
419fn is_parameter_context(node: Node) -> bool {
420    let kind = node.kind();
421    matches!(
422        kind,
423        "parameters"
424            | "formal_parameters"
425            | "required_parameter"
426            | "optional_parameter"
427            | "rest_parameter"
428            | "formal_parameter"
429            | "parameter_declaration"
430            | "typed_parameter"
431            | "typed_default_parameter"
432            | "default_parameter"
433            | "parameter"
434    )
435}
436
437fn is_inside_import(node: Node) -> bool {
438    let mut current = node;
439    while let Some(parent) = current.parent() {
440        if parent.kind().contains("import") {
441            return true;
442        }
443        current = parent;
444    }
445    false
446}
447
448fn contains_node(haystack: Node, needle: Node) -> bool {
449    if haystack.id() == needle.id() {
450        return true;
451    }
452    for i in 0..haystack.child_count() {
453        if let Some(child) = haystack.child(i)
454            && contains_node(child, needle)
455        {
456            return true;
457        }
458    }
459    false
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scratch directory under the system temp dir that is deleted on drop,
    /// so repeated test runs do not accumulate fixture directories.
    struct TempDir {
        root: PathBuf,
    }

    impl TempDir {
        /// Creates a new, uniquely named directory.
        ///
        /// The process id and a per-process counter keep names distinct even
        /// when parallel tests read the same clock value. This matters because
        /// each directory is removed when its owner is dropped.
        fn new(prefix: &str) -> Self {
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            // Relaxed suffices: only the uniqueness of each fetched value matters.
            let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
            let root = std::env::temp_dir().join(format!(
                "{prefix}-{}-{nanos}-{seq}",
                std::process::id()
            ));
            fs::create_dir_all(&root).unwrap();
            Self { root }
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover directory must not fail a test.
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    fn make_fixture() -> (TempDir, ProjectRoot) {
        let dir = TempDir::new("codelens-scope-fixture");
        fs::write(
            dir.root.join("example.py"),
            r#"class UserService:
    def get_user(self, user_id):
        user = self.db.find(user_id)
        return user

    def delete_user(self, user_id):
        user = self.get_user(user_id)
        self.db.delete(user)

def get_user():
    return "standalone function"
"#,
        )
        .unwrap();
        fs::write(
            dir.root.join("main.py"),
            "from example import UserService\n\nsvc = UserService()\nresult = svc.get_user(1)\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir.root).unwrap();
        (dir, project)
    }

    #[test]
    fn finds_references_in_single_file() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user_id", None).unwrap();
        // user_id appears as parameter in get_user and delete_user, plus usages
        assert!(refs.len() >= 4, "got {} refs", refs.len());
        // At least some should be definitions (parameters) or reads
        assert!(
            refs.iter()
                .any(|r| r.kind == ReferenceKind::Definition || r.kind == ReferenceKind::Read),
            "should have at least one definition or read"
        );
    }

    #[test]
    fn classifies_definition_vs_read() {
        let (_dir, project) = make_fixture();
        let refs =
            find_scoped_references_in_file(&project, "example.py", "get_user", None).unwrap();
        let definitions: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Definition)
            .collect();
        let reads: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Read)
            .collect();
        // "def get_user" = 2 definitions (class method + standalone function)
        assert!(
            definitions.len() >= 2,
            "expected >= 2 definitions, got {}",
            definitions.len()
        );
        // "self.get_user(user_id)" = read
        assert!(!reads.is_empty(), "should have reads");
    }

    #[test]
    fn classifies_write() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user", None).unwrap();
        let writes: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Write)
            .collect();
        // "user = self.db.find(user_id)" and "user = self.get_user(user_id)" are writes
        assert!(
            writes.len() >= 2,
            "expected >= 2 writes, got {}",
            writes.len()
        );
    }

    #[test]
    fn tracks_scope_names() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references_in_file(&project, "example.py", "user_id", None).unwrap();
        // Refs inside UserService.get_user should have scope containing both
        let scoped: Vec<_> = refs
            .iter()
            .filter(|r| r.scope.contains("UserService") && r.scope.contains("get_user"))
            .collect();
        assert!(
            !scoped.is_empty(),
            "should track nested scope: {:?}",
            refs.iter().map(|r| &r.scope).collect::<Vec<_>>()
        );
    }

    #[test]
    fn cross_file_search() {
        let (_dir, project) = make_fixture();
        let refs = find_scoped_references(&project, "UserService", None, 100).unwrap();
        let files: std::collections::HashSet<_> = refs.iter().map(|r| &r.file_path).collect();
        assert!(
            files.len() >= 2,
            "should span multiple files, got: {:?}",
            files
        );
    }

    #[test]
    fn detects_import_reference() {
        let (_dir, project) = make_fixture();
        let refs =
            find_scoped_references_in_file(&project, "main.py", "UserService", None).unwrap();
        let imports: Vec<_> = refs
            .iter()
            .filter(|r| r.kind == ReferenceKind::Import)
            .collect();
        assert!(
            !imports.is_empty(),
            "should detect import of UserService: {:?}",
            refs.iter().map(|r| (&r.kind, r.line)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn excludes_comments_and_strings() {
        let dir = TempDir::new("codelens-scope-comment");
        fs::write(
            dir.root.join("test.py"),
            "# foo is mentioned in comment\nx = foo\nprint(\"foo in string\")\n",
        )
        .unwrap();
        let project = ProjectRoot::new(&dir.root).unwrap();
        let refs = find_scoped_references_in_file(&project, "test.py", "foo", None).unwrap();
        // Should only find the assignment "x = foo", not comment or string
        assert_eq!(
            refs.len(),
            1,
            "should exclude comment/string refs, got: {:?}",
            refs
        );
    }

    #[test]
    fn reference_kind_serialization() {
        assert_eq!(
            serde_json::to_string(&ReferenceKind::Definition).unwrap(),
            "\"definition\""
        );
        assert_eq!(
            serde_json::to_string(&ReferenceKind::Write).unwrap(),
            "\"write\""
        );
    }
}