Skip to main content

codelens_engine/call_graph/
extract.rs

1use std::collections::{HashMap, HashSet};
2use std::fs;
3use std::path::Path;
4use std::sync::{Arc, LazyLock, Mutex};
5
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{Language, Node, Parser, Query, QueryCursor};
8
9use super::js_imports::LocalBindingScope;
10use super::language::call_language_for_path;
11use super::noise::is_noise_callee_for_lang;
12use super::types::CallEdge;
13
14/// Cached compiled tree-sitter Query for call graph extraction.
15/// Key: (canonical language key, query string pointer as usize).
16type CallQueryCacheKey = (&'static str, usize);
17type CallQueryCache = Mutex<HashMap<CallQueryCacheKey, Arc<Query>>>;
18
19static CALL_QUERY_CACHE: LazyLock<CallQueryCache> = LazyLock::new(|| Mutex::new(HashMap::new()));
20fn cached_call_query(
21    language_key: &'static str,
22    language: &Language,
23    query_str: &'static str,
24) -> Option<Arc<Query>> {
25    let key = (language_key, query_str.as_ptr() as usize);
26    let mut cache = CALL_QUERY_CACHE.lock().unwrap_or_else(|p| p.into_inner());
27    if let Some(q) = cache.get(&key) {
28        return Some(Arc::clone(q));
29    }
30    let q = match Query::new(language, query_str) {
31        Ok(q) => q,
32        Err(error) => {
33            #[cfg(test)]
34            {
35                panic!("invalid call graph query: {error}");
36            }
37            #[cfg(not(test))]
38            {
39                let _ = error;
40                return None;
41            }
42        }
43    };
44    let q = Arc::new(q);
45    cache.insert(key, Arc::clone(&q));
46    Some(q)
47}
48/// Parse a file and extract all call edges within each function.
49pub fn extract_calls(path: &Path) -> Vec<CallEdge> {
50    let Ok(source) = fs::read_to_string(path) else {
51        return Vec::new();
52    };
53    extract_calls_from_source(path, &source)
54}
55
56fn collect_identifier_names(node: Node<'_>, source_bytes: &[u8], names: &mut HashSet<String>) {
57    if node.kind() == "identifier" {
58        if let Ok(name) = std::str::from_utf8(&source_bytes[node.start_byte()..node.end_byte()]) {
59            let name = name.trim();
60            if !name.is_empty() {
61                names.insert(name.to_owned());
62            }
63        }
64        return;
65    }
66
67    let mut cursor = node.walk();
68    for child in node.children(&mut cursor) {
69        collect_identifier_names(child, source_bytes, names);
70    }
71}
72
73fn collect_rust_closure_binding_scopes(
74    node: Node<'_>,
75    source_bytes: &[u8],
76    scopes: &mut Vec<LocalBindingScope>,
77) {
78    if node.kind() == "closure_expression" {
79        let mut names = HashSet::new();
80        if let Some(parameters) = node.child_by_field_name("parameters") {
81            collect_identifier_names(parameters, source_bytes, &mut names);
82        }
83        if !names.is_empty() {
84            scopes.push(LocalBindingScope {
85                start_byte: node.start_byte(),
86                end_byte: node.end_byte(),
87                names,
88            });
89        }
90    }
91
92    let mut cursor = node.walk();
93    for child in node.children(&mut cursor) {
94        collect_rust_closure_binding_scopes(child, source_bytes, scopes);
95    }
96}
97
98fn is_argument_identifier_capture(node: Node<'_>) -> bool {
99    node.parent().is_some_and(|parent| {
100        matches!(
101            parent.kind(),
102            "arguments" | "argument_list" | "value_arguments" | "value_argument"
103        )
104    })
105}
106
107fn shadowed_by_rust_closure_binding(
108    scopes: &[LocalBindingScope],
109    start_byte: usize,
110    end_byte: usize,
111    name: &str,
112) -> bool {
113    scopes.iter().any(|scope| {
114        scope.start_byte <= start_byte && scope.end_byte >= end_byte && scope.names.contains(name)
115    })
116}
117
118/// Extract call edges from already-loaded source content (avoids re-reading disk).
119pub fn extract_calls_from_source(path: &Path, source: &str) -> Vec<CallEdge> {
120    let Some(config) = call_language_for_path(path) else {
121        return Vec::new();
122    };
123
124    let mut parser = Parser::new();
125    if parser.set_language(&config.language).is_err() {
126        return Vec::new();
127    }
128    let Some(tree) = parser.parse(source, None) else {
129        return Vec::new();
130    };
131    let source_bytes = source.as_bytes();
132    let rust_closure_binding_scopes = if config.language_key == "rs" {
133        let mut scopes = Vec::new();
134        collect_rust_closure_binding_scopes(tree.root_node(), source_bytes, &mut scopes);
135        scopes
136    } else {
137        Vec::new()
138    };
139
140    // Build a map: byte_range_start -> caller_name for each function definition.
141    // We'll use this to find which function contains each call site.
142    let Some(func_query) =
143        cached_call_query(config.language_key, &config.language, config.func_query)
144    else {
145        return Vec::new();
146    };
147    let mut func_ranges: Vec<(usize, usize, String)> = Vec::new(); // (start, end, name)
148    let mut func_cursor = QueryCursor::new();
149    let mut func_matches = func_cursor.matches(&func_query, tree.root_node(), source_bytes);
150    while let Some(m) = func_matches.next() {
151        let mut def_range: Option<(usize, usize)> = None;
152        let mut func_name: Option<String> = None;
153        for cap in m.captures.iter() {
154            let cap_name = &func_query.capture_names()[cap.index as usize];
155            if *cap_name == "func.def" {
156                def_range = Some((cap.node.start_byte(), cap.node.end_byte()));
157            } else if *cap_name == "func.name" {
158                let start = cap.node.start_byte();
159                let end = cap.node.end_byte();
160                func_name = std::str::from_utf8(&source_bytes[start..end])
161                    .ok()
162                    .map(|s| s.trim().to_owned());
163            }
164        }
165        if let (Some((s, e)), Some(name)) = (def_range, func_name)
166            && !name.is_empty()
167        {
168            func_ranges.push((s, e, name));
169        }
170    }
171
172    // Parse call sites
173    let Some(call_query) =
174        cached_call_query(config.language_key, &config.language, config.call_query)
175    else {
176        return Vec::new();
177    };
178    let mut call_cursor = QueryCursor::new();
179    let mut call_matches = call_cursor.matches(&call_query, tree.root_node(), source_bytes);
180    let file_path = path.to_string_lossy().to_string();
181    let mut edges = Vec::new();
182
183    while let Some(m) = call_matches.next() {
184        let callee_qualifier = m
185            .captures
186            .iter()
187            .find(|cap| call_query.capture_names()[cap.index as usize] == "callee.object")
188            .and_then(|cap| {
189                let start = cap.node.start_byte();
190                let end = cap.node.end_byte();
191                std::str::from_utf8(&source_bytes[start..end])
192                    .ok()
193                    .map(str::trim)
194                    .filter(|name| !name.is_empty())
195                    .map(str::to_owned)
196            });
197        for cap in m.captures.iter() {
198            let cap_name = &call_query.capture_names()[cap.index as usize];
199            if *cap_name != "callee" {
200                continue;
201            }
202            let start = cap.node.start_byte();
203            let end = cap.node.end_byte();
204            let Ok(callee_name) = std::str::from_utf8(&source_bytes[start..end]) else {
205                continue;
206            };
207            let callee_name = callee_name.trim().to_owned();
208            if callee_name.is_empty()
209                || is_noise_callee_for_lang(&callee_name, Some(config.language_key))
210            {
211                continue;
212            }
213            if config.language_key == "rs"
214                && is_argument_identifier_capture(cap.node)
215                && shadowed_by_rust_closure_binding(
216                    &rust_closure_binding_scopes,
217                    start,
218                    end,
219                    &callee_name,
220                )
221            {
222                continue;
223            }
224            let line = cap.node.start_position().row + 1;
225
226            // Find the enclosing function
227            let caller_name = func_ranges
228                .iter()
229                .filter(|(fs, fe, _)| *fs <= start && *fe >= end)
230                // pick the innermost (smallest range)
231                .min_by_key(|(fs, fe, _)| fe - fs)
232                .map(|(_, _, name)| name.clone())
233                .unwrap_or_else(|| "<module>".to_owned());
234
235            edges.push(CallEdge {
236                caller_file: file_path.clone(),
237                caller_name,
238                callee_name,
239                callee_qualifier: callee_qualifier.clone(),
240                line,
241                resolved_file: None,
242                confidence: 0.0,
243                resolution_strategy: None,
244                canonical_callee_name: None,
245            });
246        }
247    }
248
249    edges
250}