Skip to main content

codelens_engine/
call_graph.rs

1use crate::project::ProjectRoot;
2use anyhow::Result;
3use serde::Serialize;
4use std::collections::HashMap;
5use std::fs;
6use std::path::{Path, PathBuf};
7use std::sync::{Arc, LazyLock, Mutex};
8use streaming_iterator::StreamingIterator;
9use tree_sitter::{Language, Parser, Query, QueryCursor};
10
11/// Cached compiled tree-sitter Query for call graph extraction.
12/// Key: (language pointer as usize, query string pointer as usize)
13static CALL_QUERY_CACHE: LazyLock<Mutex<HashMap<usize, Arc<Query>>>> =
14    LazyLock::new(|| Mutex::new(HashMap::new()));
15
16fn cached_call_query(language: &Language, query_str: &'static str) -> Option<Arc<Query>> {
17    let key = query_str.as_ptr() as usize;
18    let mut cache = CALL_QUERY_CACHE.lock().unwrap_or_else(|p| p.into_inner());
19    if let Some(q) = cache.get(&key) {
20        return Some(Arc::clone(q));
21    }
22    let q = Query::new(language, query_str).ok()?;
23    let q = Arc::new(q);
24    cache.insert(key, Arc::clone(&q));
25    Some(q)
26}
27
28use crate::project::collect_files;
29
30#[derive(Debug, Clone, Serialize)]
31pub struct CallEdge {
32    pub caller_file: String,
33    pub caller_name: String,
34    pub callee_name: String,
35    pub line: usize,
36    /// Resolved file where the callee is defined (None if unresolved).
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub resolved_file: Option<String>,
39    /// Confidence of the resolution (0.0–1.0). Higher = more certain.
40    pub confidence: f64,
41    /// Which resolution strategy succeeded.
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub resolution_strategy: Option<&'static str>,
44}
45
46#[derive(Debug, Clone, Serialize)]
47pub struct CallerEntry {
48    pub file: String,
49    pub function: String,
50    pub line: usize,
51    /// Confidence that this caller actually calls the target (0.0–1.0).
52    pub confidence: f64,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub resolution: Option<&'static str>,
55}
56
57#[derive(Debug, Clone, Serialize)]
58pub struct CalleeEntry {
59    pub name: String,
60    pub line: usize,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub resolved_file: Option<String>,
63    pub confidence: f64,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub resolution: Option<&'static str>,
66}
67
68struct CallLanguageConfig {
69    language: Language,
70    /// Query to find function definitions: captures @func.name
71    func_query: &'static str,
72    /// Query to find call sites: captures @callee
73    call_query: &'static str,
74}
75
/// Report whether `name` is a common std/builtin callee that should be left out of the call graph.
///
/// The comparison is an exact, case-sensitive match against fixed lists of names grouped by
/// origin: cross-language names, Rust std, Python builtins, JS/TS builtins, Go builtins and
/// Java/Kotlin stdlib.
fn is_noise_callee(name: &str) -> bool {
    const CROSS_LANGUAGE: &[&str] = &[
        "get", "set", "push", "pop", "len", "new", "from", "into",
        "map", "filter", "collect", "contains", "insert", "remove",
        "format", "print", "clone", "default", "next", "read",
        "write", "open", "close", "keys", "values", "sort",
        "reverse", "find", "replace", "delete", "add", "clear",
        "of", "size", "copy",
    ];
    const RUST_STD: &[&str] = &[
        "is_empty", "to_string", "to_owned", "as_str", "as_ref",
        "unwrap", "expect", "ok", "err", "and_then", "or_else",
        "unwrap_or", "unwrap_or_else", "unwrap_or_default",
        "iter", "into_iter", "take", "skip",
        "println", "eprintln", "drop", "enter", "lock", "cloned",
    ];
    const PYTHON_BUILTINS: &[&str] = &[
        "range", "enumerate", "zip", "sorted", "reversed",
        "isinstance", "issubclass", "hasattr", "getattr", "setattr", "delattr",
        "type", "super", "str", "int", "float", "bool",
        "list", "dict", "tuple", "frozenset", "bytes", "bytearray",
        "repr", "abs", "min", "max", "sum", "any", "all",
        "ord", "chr", "hex", "oct", "bin", "hash", "id",
        "input", "vars", "dir", "help", "round",
        "append", "extend", "update", "items", "join", "split",
        "strip", "startswith", "endswith", "encode", "decode",
        "upper", "lower",
    ];
    const JS_TS_BUILTINS: &[&str] = &[
        "log", "warn", "error", "info", "debug",
        "toString", "valueOf", "JSON", "parse", "stringify", "assign",
        "entries", "forEach", "reduce", "findIndex", "some", "every",
        "includes", "indexOf", "slice", "splice", "concat",
        "flat", "flatMap", "fill", "isArray",
        "Promise", "resolve", "reject", "then", "catch", "finally",
        "setTimeout", "setInterval", "clearTimeout", "clearInterval",
        "parseInt", "parseFloat", "isNaN", "isFinite", "require",
    ];
    const GO_BUILTINS: &[&str] = &[
        "make", "cap", "panic", "recover", "real", "imag", "complex",
        "Println", "Printf", "Sprintf", "Fprintf", "Errorf", "New",
    ];
    const JVM_STDLIB: &[&str] = &[
        "equals", "hashCode", "compareTo", "getClass",
        "notify", "notifyAll", "wait", "isEmpty",
        "addAll", "containsKey", "containsValue", "put", "putAll",
        "entrySet", "keySet", "charAt", "substring", "trim",
        "length", "toArray", "stream", "asList",
    ];

    [
        CROSS_LANGUAGE,
        RUST_STD,
        PYTHON_BUILTINS,
        JS_TS_BUILTINS,
        GO_BUILTINS,
        JVM_STDLIB,
    ]
    .iter()
    .any(|group| group.contains(&name))
}
127
128fn call_language_for_path(path: &Path) -> Option<CallLanguageConfig> {
129    let lang_config = crate::lang_config::language_for_path(path)?;
130    // Map canonical extension to call graph queries (not all languages support this)
131    let (func_query, call_query) = match lang_config.extension {
132        "py" => (PYTHON_FUNC_QUERY, PYTHON_CALL_QUERY),
133        "js" => (JS_FUNC_QUERY, JS_CALL_QUERY),
134        "ts" | "tsx" => (JS_FUNC_QUERY, JS_CALL_QUERY),
135        "go" => (GO_FUNC_QUERY, GO_CALL_QUERY),
136        "java" => (JAVA_FUNC_QUERY, JAVA_CALL_QUERY),
137        "kt" => (KOTLIN_FUNC_QUERY, JAVA_CALL_QUERY),
138        "rs" => (RUST_FUNC_QUERY, RUST_CALL_QUERY),
139        _ => return None,
140    };
141    Some(CallLanguageConfig {
142        language: lang_config.language,
143        func_query,
144        call_query,
145    })
146}
147
148fn collect_candidate_files(root: &Path) -> Result<Vec<PathBuf>> {
149    collect_files(root, |path| call_language_for_path(path).is_some())
150}
151
152/// Parse a file and extract all call edges within each function.
153pub fn extract_calls(path: &Path) -> Vec<CallEdge> {
154    let Ok(source) = fs::read_to_string(path) else {
155        return Vec::new();
156    };
157    extract_calls_from_source(path, &source)
158}
159
160/// Extract call edges from already-loaded source content (avoids re-reading disk).
161pub fn extract_calls_from_source(path: &Path, source: &str) -> Vec<CallEdge> {
162    let Some(config) = call_language_for_path(path) else {
163        return Vec::new();
164    };
165
166    let mut parser = Parser::new();
167    if parser.set_language(&config.language).is_err() {
168        return Vec::new();
169    }
170    let Some(tree) = parser.parse(source, None) else {
171        return Vec::new();
172    };
173    let source_bytes = source.as_bytes();
174
175    // Build a map: byte_range_start -> caller_name for each function definition.
176    // We'll use this to find which function contains each call site.
177    let Some(func_query) = cached_call_query(&config.language, config.func_query) else {
178        return Vec::new();
179    };
180    let mut func_ranges: Vec<(usize, usize, String)> = Vec::new(); // (start, end, name)
181    let mut func_cursor = QueryCursor::new();
182    let mut func_matches = func_cursor.matches(&func_query, tree.root_node(), source_bytes);
183    while let Some(m) = func_matches.next() {
184        let mut def_range: Option<(usize, usize)> = None;
185        let mut func_name: Option<String> = None;
186        for cap in m.captures.iter() {
187            let cap_name = &func_query.capture_names()[cap.index as usize];
188            if *cap_name == "func.def" {
189                def_range = Some((cap.node.start_byte(), cap.node.end_byte()));
190            } else if *cap_name == "func.name" {
191                let start = cap.node.start_byte();
192                let end = cap.node.end_byte();
193                func_name = std::str::from_utf8(&source_bytes[start..end])
194                    .ok()
195                    .map(|s| s.trim().to_owned());
196            }
197        }
198        if let (Some((s, e)), Some(name)) = (def_range, func_name)
199            && !name.is_empty()
200        {
201            func_ranges.push((s, e, name));
202        }
203    }
204
205    // Parse call sites
206    let Some(call_query) = cached_call_query(&config.language, config.call_query) else {
207        return Vec::new();
208    };
209    let mut call_cursor = QueryCursor::new();
210    let mut call_matches = call_cursor.matches(&call_query, tree.root_node(), source_bytes);
211    let file_path = path.to_string_lossy().to_string();
212    let mut edges = Vec::new();
213
214    while let Some(m) = call_matches.next() {
215        for cap in m.captures.iter() {
216            let cap_name = &call_query.capture_names()[cap.index as usize];
217            if *cap_name != "callee" {
218                continue;
219            }
220            let start = cap.node.start_byte();
221            let end = cap.node.end_byte();
222            let Ok(callee_name) = std::str::from_utf8(&source_bytes[start..end]) else {
223                continue;
224            };
225            let callee_name = callee_name.trim().to_owned();
226            if callee_name.is_empty() || is_noise_callee(&callee_name) {
227                continue;
228            }
229            let line = cap.node.start_position().row + 1;
230
231            // Find the enclosing function
232            let caller_name = func_ranges
233                .iter()
234                .filter(|(fs, fe, _)| *fs <= start && *fe >= end)
235                // pick the innermost (smallest range)
236                .min_by_key(|(fs, fe, _)| fe - fs)
237                .map(|(_, _, name)| name.clone())
238                .unwrap_or_else(|| "<module>".to_owned());
239
240            edges.push(CallEdge {
241                caller_file: file_path.clone(),
242                caller_name,
243                callee_name,
244                line,
245                resolved_file: None,
246                confidence: 0.0,
247                resolution_strategy: None,
248            });
249        }
250    }
251
252    edges
253}
254
255// ── 6-stage call resolution cascade ──────────────────────────────────────
256
257/// Resolve callee names to their definition files using a 6-stage confidence cascade.
258/// Mutates edges in-place, setting resolved_file, confidence, and resolution_strategy.
259pub fn resolve_call_edges(
260    edges: &mut [CallEdge],
261    project: &ProjectRoot,
262    import_graph: Option<&HashMap<String, crate::import_graph::FileNode>>,
263) {
264    // Build a name→files index from the symbol DB for stages 3-5
265    let db_path = crate::db::index_db_path(project.as_path());
266    let symbol_index: HashMap<String, Vec<String>> = crate::db::IndexDb::open(&db_path)
267        .and_then(|db| {
268            let all = db.all_symbol_names()?;
269            let mut map: HashMap<String, Vec<String>> = HashMap::new();
270            for (name, _kind, _sig, _line, _name_path, file) in all {
271                map.entry(name).or_default().push(file);
272            }
273            Ok(map)
274        })
275        .unwrap_or_default();
276
277    for edge in edges.iter_mut() {
278        if edge.confidence > 0.0 {
279            continue; // already resolved
280        }
281
282        let callee = &edge.callee_name;
283        let caller_file = &edge.caller_file;
284
285        // Stage 1: Import map — callee's prefix matches an import in caller file (0.95)
286        if let Some(graph) = import_graph
287            && let Some(node) = graph.get(caller_file)
288        {
289            for imported_file in &node.imports {
290                // Check if imported file defines callee
291                if let Some(defs) = symbol_index.get(callee)
292                    && defs.iter().any(|f| f == imported_file)
293                {
294                    edge.resolved_file = Some(imported_file.clone());
295                    edge.confidence = 0.95;
296                    edge.resolution_strategy = Some("import_map");
297                    break;
298                }
299            }
300        }
301        if edge.confidence > 0.0 {
302            continue;
303        }
304
305        // Stage 2: Same file — callee defined in the same file (0.90)
306        if let Some(defs) = symbol_index.get(callee)
307            && defs.iter().any(|f| f == caller_file)
308        {
309            edge.resolved_file = Some(caller_file.clone());
310            edge.confidence = 0.90;
311            edge.resolution_strategy = Some("same_file");
312            continue;
313        }
314
315        // Stage 3: Unique name — only one definition exists project-wide (0.75)
316        if let Some(defs) = symbol_index.get(callee)
317            && defs.len() == 1
318        {
319            edge.resolved_file = Some(defs[0].clone());
320            edge.confidence = 0.75;
321            edge.resolution_strategy = Some("unique_name");
322            continue;
323        }
324
325        // Stage 4: Import suffix — callee matches suffix of an imported module (0.60)
326        if let Some(graph) = import_graph
327            && let Some(node) = graph.get(caller_file)
328            && let Some(defs) = symbol_index.get(callee)
329        {
330            // Pick the candidate that is also imported (transitively)
331            for def_file in defs {
332                if node.imports.iter().any(|imp| {
333                    // Match on full path suffix, not just filename
334                    def_file.ends_with(imp)
335                        || def_file.ends_with(&format!("/{imp}"))
336                        || imp.ends_with(def_file)
337                        || imp.ends_with(&format!("/{def_file}"))
338                }) {
339                    edge.resolved_file = Some(def_file.clone());
340                    edge.confidence = 0.60;
341                    edge.resolution_strategy = Some("import_suffix");
342                    break;
343                }
344            }
345        }
346        if edge.confidence > 0.0 {
347            continue;
348        }
349
350        // Stage 5: Multiple candidates — pick closest by path similarity (0.40)
351        if let Some(defs) = symbol_index.get(callee)
352            && !defs.is_empty()
353        {
354            // Pick the one with the most shared path prefix with caller_file
355            let best = defs
356                .iter()
357                .max_by_key(|f| {
358                    f.chars()
359                        .zip(caller_file.chars())
360                        .take_while(|(a, b)| a == b)
361                        .count()
362                })
363                .cloned();
364            if let Some(f) = best {
365                edge.resolved_file = Some(f);
366                edge.confidence = 0.40;
367                edge.resolution_strategy = Some("path_proximity");
368                continue;
369            }
370        }
371
372        // Stage 6: Unresolved — callee not found in symbol DB (0.10)
373        edge.confidence = 0.10;
374        edge.resolution_strategy = Some("unresolved");
375    }
376}
377
378/// Find all functions that call `function_name` across the project.
379/// Edges are resolved via the 6-stage confidence cascade when an import graph is available.
380pub fn get_callers(
381    project: &ProjectRoot,
382    function_name: &str,
383    max_results: usize,
384) -> Result<Vec<CallerEntry>> {
385    let files = collect_candidate_files(project.as_path())?;
386    let mut all_edges: Vec<CallEdge> = Vec::new();
387
388    for file in &files {
389        let mut edges = extract_calls(file);
390        // Relativize caller_file paths
391        for edge in &mut edges {
392            edge.caller_file = project.to_relative(file);
393        }
394        all_edges.extend(edges);
395    }
396
397    // Resolve callee targets (best-effort, no import graph in this path)
398    resolve_call_edges(&mut all_edges, project, None);
399
400    // Filter to edges calling our target
401    let mut seen = std::collections::HashSet::new();
402    let mut results = Vec::new();
403
404    for edge in all_edges {
405        if edge.callee_name == function_name {
406            let key = (
407                edge.caller_file.clone(),
408                edge.caller_name.clone(),
409                edge.line,
410            );
411            if seen.insert(key) {
412                results.push(CallerEntry {
413                    file: edge.caller_file,
414                    function: edge.caller_name,
415                    line: edge.line,
416                    confidence: edge.confidence,
417                    resolution: edge.resolution_strategy,
418                });
419                if max_results > 0 && results.len() >= max_results {
420                    break;
421                }
422            }
423        }
424    }
425
426    // Sort by confidence descending
427    results.sort_by(|a, b| {
428        b.confidence
429            .partial_cmp(&a.confidence)
430            .unwrap_or(std::cmp::Ordering::Equal)
431    });
432    Ok(results)
433}
434
435/// Find all functions called by `function_name` (optionally restricted to a file).
436/// Callee names are resolved to their definition files via the 6-stage cascade.
437pub fn get_callees(
438    project: &ProjectRoot,
439    function_name: &str,
440    file_path: Option<&str>,
441    max_results: usize,
442) -> Result<Vec<CalleeEntry>> {
443    let files: Vec<PathBuf> = if let Some(fp) = file_path {
444        let resolved = project.resolve(fp)?;
445        vec![resolved]
446    } else {
447        collect_candidate_files(project.as_path())?
448    };
449
450    let mut all_edges: Vec<CallEdge> = Vec::new();
451    for file in &files {
452        let mut edges = extract_calls(file);
453        for edge in &mut edges {
454            edge.caller_file = project.to_relative(file);
455        }
456        all_edges.extend(edges);
457    }
458
459    resolve_call_edges(&mut all_edges, project, None);
460
461    let mut seen: HashMap<(String, usize), ()> = HashMap::new();
462    let mut results = Vec::new();
463
464    for edge in all_edges {
465        if edge.caller_name == function_name {
466            let key = (edge.callee_name.clone(), edge.line);
467            if seen.insert(key, ()).is_none() {
468                results.push(CalleeEntry {
469                    name: edge.callee_name,
470                    line: edge.line,
471                    resolved_file: edge.resolved_file,
472                    confidence: edge.confidence,
473                    resolution: edge.resolution_strategy,
474                });
475                if max_results > 0 && results.len() >= max_results {
476                    break;
477                }
478            }
479        }
480    }
481
482    results.sort_by(|a, b| {
483        b.confidence
484            .partial_cmp(&a.confidence)
485            .unwrap_or(std::cmp::Ordering::Equal)
486    });
487    Ok(results)
488}
489
// ---- Tree-sitter queries ----
//
// Function queries capture the whole definition as `@func.def` and its name as `@func.name`.
// Call queries capture the name being called as `@callee`.

/// Python definitions: `function_definition` nodes.
const PYTHON_FUNC_QUERY: &str = r#"
(function_definition name: (identifier) @func.name) @func.def
"#;

/// Python calls: a bare identifier, or the attribute name of an attribute access.
const PYTHON_CALL_QUERY: &str = r#"
(call function: (identifier) @callee)
(call function: (attribute attribute: (identifier) @callee))
"#;

/// JavaScript definitions: function declarations, method definitions and named `function` nodes.
const JS_FUNC_QUERY: &str = r#"
(function_declaration name: (identifier) @func.name) @func.def
(method_definition name: (property_identifier) @func.name) @func.def
(function (identifier) @func.name) @func.def
"#;

/// JavaScript calls: a bare identifier, or the property name of a member expression.
const JS_CALL_QUERY: &str = r#"
(call_expression function: (identifier) @callee)
(call_expression function: (member_expression property: (property_identifier) @callee))
"#;

/// Go definitions: function and method declarations.
const GO_FUNC_QUERY: &str = r#"
(function_declaration name: (identifier) @func.name) @func.def
(method_declaration name: (field_identifier) @func.name) @func.def
"#;

/// Go calls: a bare identifier, or the field name of a selector expression.
const GO_CALL_QUERY: &str = r#"
(call_expression function: (identifier) @callee)
(call_expression function: (selector_expression field: (field_identifier) @callee))
"#;

/// Java definitions: method and constructor declarations.
const JAVA_FUNC_QUERY: &str = r#"
(method_declaration name: (identifier) @func.name) @func.def
(constructor_declaration name: (identifier) @func.name) @func.def
"#;

/// Java calls: the name of a `method_invocation`.
const JAVA_CALL_QUERY: &str = r#"
(method_invocation name: (identifier) @callee)
"#;

/// Kotlin definitions: function declarations.
const KOTLIN_FUNC_QUERY: &str = r#"
(function_declaration name: (identifier) @func.name) @func.def
"#;

/// Rust definitions: `function_item` nodes.
const RUST_FUNC_QUERY: &str = r#"
(function_item name: (identifier) @func.name) @func.def
"#;

/// Rust calls: a bare identifier, or the field name of a field expression.
const RUST_CALL_QUERY: &str = r#"
(call_expression function: (identifier) @callee)
(call_expression function: (field_expression field: (field_identifier) @callee))
"#;
543
#[cfg(test)]
mod tests {
    use super::{extract_calls, get_callees, get_callers};
    use crate::ProjectRoot;
    use std::fs;
    use std::path::PathBuf;

    /// Scratch directory that is removed again when the guard is dropped, so test runs
    /// (including failing ones, via unwinding) do not pile up directories in the temp dir.
    struct ScratchDir {
        path: PathBuf,
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test's own result.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Create a fresh directory under the system temp dir. The name embeds `name` and the
    /// current time in nanoseconds to keep concurrently running tests apart.
    fn temp_dir(name: &str) -> ScratchDir {
        let path = std::env::temp_dir().join(format!(
            "codelens-callgraph-{name}-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time")
                .as_nanos()
        ));
        fs::create_dir_all(&path).expect("create tempdir");
        ScratchDir { path }
    }

    #[test]
    fn extracts_python_calls() {
        let dir = temp_dir("py");
        let path = dir.path.join("main.py");
        fs::write(
            &path,
            "def greet(name):\n    return helper(name)\n\ndef helper(x):\n    return x\n",
        )
        .expect("write");
        let edges = extract_calls(&path);
        assert!(
            edges
                .iter()
                .any(|e| e.caller_name == "greet" && e.callee_name == "helper"),
            "expected greet->helper edge, got {edges:?}"
        );
    }

    #[test]
    fn extracts_rust_calls() {
        let dir = temp_dir("rs");
        let path = dir.path.join("main.rs");
        fs::write(&path, "fn main() {\n    run();\n}\n\nfn run() {}\n").expect("write");
        let edges = extract_calls(&path);
        assert!(
            edges
                .iter()
                .any(|e| e.caller_name == "main" && e.callee_name == "run"),
            "expected main->run edge, got {edges:?}"
        );
    }

    #[test]
    fn get_callers_finds_callers() {
        let dir = temp_dir("callers");
        fs::write(dir.path.join("a.py"), "def foo():\n    bar()\n    baz()\n").expect("write a");
        fs::write(dir.path.join("b.py"), "def qux():\n    bar()\n").expect("write b");
        fs::write(dir.path.join("c.py"), "def bar():\n    pass\n").expect("write c");

        let project = ProjectRoot::new(&dir.path).expect("project");
        let callers = get_callers(&project, "bar", 50).expect("callers");
        let names: Vec<&str> = callers.iter().map(|c| c.function.as_str()).collect();
        assert!(
            names.contains(&"foo"),
            "expected foo as caller, got {names:?}"
        );
        assert!(
            names.contains(&"qux"),
            "expected qux as caller, got {names:?}"
        );
    }

    #[test]
    fn get_callees_finds_callees() {
        let dir = temp_dir("callees");
        fs::write(
            dir.path.join("main.py"),
            "def main():\n    foo()\n    bar()\n\ndef foo():\n    pass\n\ndef bar():\n    pass\n",
        )
        .expect("write");

        let project = ProjectRoot::new(&dir.path).expect("project");
        let callees = get_callees(&project, "main", None, 50).expect("callees");
        let names: Vec<&str> = callees.iter().map(|c| c.name.as_str()).collect();
        assert!(
            names.contains(&"foo"),
            "expected foo as callee, got {names:?}"
        );
        assert!(
            names.contains(&"bar"),
            "expected bar as callee, got {names:?}"
        );
    }

    #[test]
    fn get_callees_scoped_to_file() {
        let dir = temp_dir("callees-file");
        fs::write(dir.path.join("a.py"), "def process():\n    helper()\n").expect("write a");
        fs::write(dir.path.join("b.py"), "def process():\n    other()\n").expect("write b");

        let project = ProjectRoot::new(&dir.path).expect("project");
        let callees = get_callees(&project, "process", Some("a.py"), 50).expect("callees");
        let names: Vec<&str> = callees.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"helper"), "expected helper, got {names:?}");
        assert!(!names.contains(&"other"), "should not have other from b.py");
    }
}