Skip to main content

aft/
calls.rs

//! Shared call-site extraction helpers, plus type-reference extraction used as
//! a side channel by dead-code analysis.
//!
//! Extracted from `commands/zoom.rs` so both the zoom command and the
//! call-graph engine can reuse the same AST-walking logic.
5
6use std::collections::BTreeSet;
7
8use crate::parser::LangId;
9
10/// Returns the tree-sitter node kind strings that represent call expressions
11/// for the given language.
12pub fn call_node_kinds(lang: LangId) -> Vec<&'static str> {
13    match lang {
14        LangId::TypeScript | LangId::JavaScript => vec!["call_expression", "new_expression"],
15        LangId::Tsx => vec![
16            "call_expression",
17            "new_expression",
18            "jsx_opening_element",
19            "jsx_self_closing_element",
20        ],
21        LangId::Go => vec!["call_expression"],
22        LangId::Python => vec!["call"],
23        LangId::Rust => vec!["call_expression", "macro_invocation"],
24        LangId::Solidity | LangId::Scala => vec!["call_expression"],
25        LangId::Java => vec!["method_invocation"],
26        LangId::Ruby => vec!["call"],
27        LangId::Kotlin | LangId::Swift => vec!["call_expression"],
28        LangId::Php => vec![
29            "function_call_expression",
30            "member_call_expression",
31            "nullsafe_member_call_expression",
32            "scoped_call_expression",
33        ],
34        LangId::Perl => vec![
35            "call_expression_recursive",
36            "call_expression_with_args_with_brackets",
37            "call_expression_with_bareword",
38            "call_expression_with_spaced_args",
39            "call_expression_with_sub",
40            "call_expression_with_variable",
41            "method_invocation",
42        ],
43        LangId::Lua => vec!["function_call"],
44        LangId::C
45        | LangId::Cpp
46        | LangId::Zig
47        | LangId::CSharp
48        | LangId::Bash
49        | LangId::Scss
50        | LangId::Vue
51        | LangId::Html
52        | LangId::Markdown
53        | LangId::Json
54        | LangId::Yaml => vec![],
55    }
56}
57
58/// Recursively walk tree nodes looking for call expressions within a byte range.
59///
60/// Collects `(callee_name, line_number)` pairs into `results`.
61pub fn walk_for_calls(
62    node: tree_sitter::Node,
63    source: &str,
64    byte_start: usize,
65    byte_end: usize,
66    call_kinds: &[&str],
67    results: &mut Vec<(String, u32)>,
68) {
69    let node_start = node.start_byte();
70    let node_end = node.end_byte();
71
72    // Skip nodes entirely outside our range
73    if node_end <= byte_start || node_start >= byte_end {
74        return;
75    }
76
77    if call_kinds.contains(&node.kind()) && node_start >= byte_start && node_end <= byte_end {
78        if let Some(name) = extract_callee_name(&node, source) {
79            results.push((name, node.start_position().row as u32 + 1));
80        }
81    }
82
83    // Recurse into children
84    let mut cursor = node.walk();
85    if cursor.goto_first_child() {
86        loop {
87            walk_for_calls(
88                cursor.node(),
89                source,
90                byte_start,
91                byte_end,
92                call_kinds,
93                results,
94            );
95            if !cursor.goto_next_sibling() {
96                break;
97            }
98        }
99    }
100}
101
102/// Extract the callee name from a call expression node.
103///
104/// For simple calls like `foo()`, returns "foo".
105/// For member access like `this.add()` or `obj.method()`, returns the last
106/// segment ("add" / "method").
107/// For Rust macros like `println!()`, returns "println!".
108pub fn extract_callee_name(node: &tree_sitter::Node, source: &str) -> Option<String> {
109    let kind = node.kind();
110
111    if kind == "macro_invocation" {
112        // Rust macro: first child is the macro name (e.g. `println!`)
113        let first_child = node.child(0)?;
114        let text = &source[first_child.byte_range()];
115        return Some(format!("{}!", text));
116    }
117
118    let func_node = callee_node(node)?;
119
120    let func_kind = func_node.kind();
121    match func_kind {
122        // Simple identifier: foo()
123        "identifier" => Some(source[func_node.byte_range()].to_string()),
124        // Member access: obj.method() / this.method()
125        "member_expression" | "field_expression" | "attribute" => {
126            // Last child that's a property_identifier, field_identifier, or identifier
127            extract_last_segment(&func_node, source)
128        }
129        // Computed member access: obj["method"]()
130        "subscript_expression" => extract_computed_member_name(&func_node, source)
131            .or_else(|| extract_last_segment(&func_node, source)),
132        _ => {
133            // Fallback: use the full text
134            let text = &source[func_node.byte_range()];
135            // If it contains a dot, take the last segment
136            if text.contains('.') {
137                text.rsplit('.').next().map(|s| s.trim().to_string())
138            } else {
139                Some(text.trim().to_string())
140            }
141        }
142    }
143}
144
145/// Extract the full callee expression from a call expression node.
146///
147/// Unlike `extract_callee_name` which returns only the last segment,
148/// this returns the full expression (e.g. "utils.foo" for `utils.foo()`).
149/// Used by the call graph engine to detect namespace-qualified calls.
150pub fn extract_full_callee(node: &tree_sitter::Node, source: &str) -> Option<String> {
151    let kind = node.kind();
152
153    if kind == "macro_invocation" {
154        let first_child = node.child(0)?;
155        let text = &source[first_child.byte_range()];
156        return Some(format!("{}!", text));
157    }
158
159    let func_node = callee_node(node)?;
160
161    Some(source[func_node.byte_range()].trim().to_string())
162}
163
164fn callee_node<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
165    match node.kind() {
166        "new_expression" => node
167            .child_by_field_name("constructor")
168            .or_else(|| node.named_child(0)),
169        "jsx_opening_element" | "jsx_self_closing_element" => node
170            .child_by_field_name("name")
171            .or_else(|| node.named_child(0)),
172        _ => node
173            .child_by_field_name("function")
174            .or_else(|| node.child(0)),
175    }
176}
177
178fn extract_computed_member_name(node: &tree_sitter::Node, source: &str) -> Option<String> {
179    let index = node.child_by_field_name("index")?;
180    let text = source[index.byte_range()].trim();
181    if (text.starts_with('"') && text.ends_with('"'))
182        || (text.starts_with('\'') && text.ends_with('\''))
183    {
184        return Some(text[1..text.len().saturating_sub(1)].to_string());
185    }
186    None
187}
188
189/// Extract the last segment of a member expression (the method/property name).
190pub fn extract_last_segment(node: &tree_sitter::Node, source: &str) -> Option<String> {
191    let child_count = node.child_count();
192    // Walk children from the end looking for an identifier-like node
193    for i in (0..child_count).rev() {
194        if let Some(child) = node.child(i as u32) {
195            match child.kind() {
196                "property_identifier" | "field_identifier" | "identifier" => {
197                    return Some(source[child.byte_range()].to_string());
198                }
199                _ => {}
200            }
201        }
202    }
203    // Fallback: full text, last dot segment
204    let text = &source[node.byte_range()];
205    text.rsplit('.').next().map(|s| s.trim().to_string())
206}
207
208/// Extract type-reference names within a byte range of the AST.
209///
210/// This is intentionally separate from call extraction. The live call graph and
211/// `aft_callgraph` commands remain call-edge-only; dead-code analysis consumes
212/// these type-position names as a side channel.
213pub fn extract_type_references_in_range(
214    source: &str,
215    root: tree_sitter::Node,
216    byte_start: usize,
217    byte_end: usize,
218    lang: LangId,
219) -> BTreeSet<String> {
220    let mut results = BTreeSet::new();
221    collect_type_references(root, source, byte_start, byte_end, lang, &mut results);
222    results
223}
224
225/// Extract all type-reference names in a parsed file.
226pub fn extract_type_references(
227    source: &str,
228    root: tree_sitter::Node,
229    lang: LangId,
230) -> BTreeSet<String> {
231    extract_type_references_in_range(source, root, 0, source.len(), lang)
232}
233
234fn collect_type_references(
235    node: tree_sitter::Node,
236    source: &str,
237    byte_start: usize,
238    byte_end: usize,
239    lang: LangId,
240    results: &mut BTreeSet<String>,
241) {
242    let node_start = node.start_byte();
243    let node_end = node.end_byte();
244
245    if node_end <= byte_start || node_start >= byte_end {
246        return;
247    }
248
249    if node_start >= byte_start && node_end <= byte_end {
250        collect_type_reference_fields(&node, source, lang, results);
251        if is_type_context_node(lang, node.kind()) {
252            collect_type_reference_identifiers(node, source, lang, results);
253            return;
254        }
255    }
256
257    let mut cursor = node.walk();
258    if cursor.goto_first_child() {
259        loop {
260            collect_type_references(cursor.node(), source, byte_start, byte_end, lang, results);
261            if !cursor.goto_next_sibling() {
262                break;
263            }
264        }
265    }
266}
267
268fn collect_type_reference_fields(
269    node: &tree_sitter::Node,
270    source: &str,
271    lang: LangId,
272    results: &mut BTreeSet<String>,
273) {
274    for field in ["type", "return_type", "result", "trait"] {
275        if let Some(child) = node.child_by_field_name(field) {
276            collect_type_reference_identifiers(child, source, lang, results);
277        }
278    }
279
280    if matches!(lang, LangId::TypeScript | LangId::Tsx) && node.kind() == "type_alias_declaration" {
281        if let Some(value) = node.child_by_field_name("value") {
282            collect_type_reference_identifiers(value, source, lang, results);
283        }
284    }
285}
286
287fn is_type_context_node(lang: LangId, kind: &str) -> bool {
288    match lang {
289        LangId::TypeScript | LangId::Tsx => matches!(
290            kind,
291            "type_annotation"
292                | "type_arguments"
293                | "extends_clause"
294                | "implements_clause"
295                | "satisfies_expression"
296        ),
297        LangId::JavaScript => false,
298        LangId::Python => kind == "type",
299        LangId::Rust => matches!(
300            kind,
301            "parameter"
302                | "field_declaration"
303                | "generic_type"
304                | "type_arguments"
305                | "reference_type"
306                | "array_type"
307                | "tuple_type"
308                | "bounded_type"
309        ),
310        LangId::Go => matches!(
311            kind,
312            "field_declaration"
313                | "parameter_declaration"
314                | "generic_type"
315                | "type_arguments"
316                | "type_elem"
317                | "pointer_type"
318                | "array_type"
319                | "slice_type"
320                | "map_type"
321                | "qualified_type"
322                | "channel_type"
323                | "function_type"
324        ),
325        _ => false,
326    }
327}
328
329fn collect_type_reference_identifiers(
330    node: tree_sitter::Node,
331    source: &str,
332    lang: LangId,
333    results: &mut BTreeSet<String>,
334) {
335    if is_type_reference_identifier(lang, node.kind()) {
336        let name = source[node.byte_range()].trim();
337        if let Some(name) = clean_type_reference_name(name) {
338            results.insert(name);
339        }
340    }
341
342    let mut cursor = node.walk();
343    if cursor.goto_first_child() {
344        loop {
345            collect_type_reference_identifiers(cursor.node(), source, lang, results);
346            if !cursor.goto_next_sibling() {
347                break;
348            }
349        }
350    }
351}
352
353fn is_type_reference_identifier(lang: LangId, kind: &str) -> bool {
354    match lang {
355        LangId::TypeScript | LangId::Tsx => matches!(kind, "type_identifier" | "identifier"),
356        LangId::Python => kind == "identifier",
357        LangId::Rust | LangId::Go => kind == "type_identifier",
358        _ => false,
359    }
360}
361
/// Normalises a raw type-reference text to its bare final name.
///
/// Steps:
/// 1. Keep the last non-empty segment after splitting on `.` and `:`
///    (`std::collections::HashMap` -> `HashMap`).
/// 2. Trim it and drop any leading `?` characters.
/// 3. Accept the result only if it starts with `_` or an alphabetic character;
///    otherwise return `None`.
fn clean_type_reference_name(name: &str) -> Option<String> {
    let last_segment = name
        .rsplit(|c: char| c == '.' || c == ':')
        .find(|piece| !piece.is_empty())
        .unwrap_or(name);
    let candidate = last_segment.trim().trim_start_matches('?');

    match candidate.chars().next() {
        Some(first) if first == '_' || first.is_alphabetic() => Some(candidate.to_owned()),
        _ => None,
    }
}
381
382/// Extract call expression names within a byte range of the AST.
383///
384/// Walks all nodes in the tree, finds call_expression/call/macro_invocation
385/// nodes whose byte range falls within [byte_start, byte_end], and extracts
386/// the callee name (last segment for member access like `obj.method()`).
387///
388/// Returns (callee_name, line_number) pairs.
389pub fn extract_calls_in_range(
390    source: &str,
391    root: tree_sitter::Node,
392    byte_start: usize,
393    byte_end: usize,
394    lang: LangId,
395) -> Vec<(String, u32)> {
396    let mut results = Vec::new();
397    let call_kinds = call_node_kinds(lang);
398    walk_for_calls(
399        root,
400        source,
401        byte_start,
402        byte_end,
403        &call_kinds,
404        &mut results,
405    );
406    results
407}
408
409/// Extract calls with full callee expressions (including namespace qualifiers).
410///
411/// Returns `(full_callee, short_name, line)` triples.
412/// `full_callee` is e.g. "utils.foo", `short_name` is "foo".
413pub fn extract_calls_full(
414    source: &str,
415    root: tree_sitter::Node,
416    byte_start: usize,
417    byte_end: usize,
418    lang: LangId,
419) -> Vec<(String, String, u32)> {
420    let mut results = Vec::new();
421    let call_kinds = call_node_kinds(lang);
422    collect_calls_full(
423        root,
424        source,
425        byte_start,
426        byte_end,
427        &call_kinds,
428        &mut results,
429    );
430    results
431}
432
433fn collect_calls_full(
434    node: tree_sitter::Node,
435    source: &str,
436    byte_start: usize,
437    byte_end: usize,
438    call_kinds: &[&str],
439    results: &mut Vec<(String, String, u32)>,
440) {
441    let node_start = node.start_byte();
442    let node_end = node.end_byte();
443
444    if node_end <= byte_start || node_start >= byte_end {
445        return;
446    }
447
448    if call_kinds.contains(&node.kind()) && node_start >= byte_start && node_end <= byte_end {
449        if let (Some(full), Some(short)) = (
450            extract_full_callee(&node, source),
451            extract_callee_name(&node, source),
452        ) {
453            results.push((full, short, node.start_position().row as u32 + 1));
454        }
455    }
456
457    let mut cursor = node.walk();
458    if cursor.goto_first_child() {
459        loop {
460            collect_calls_full(
461                cursor.node(),
462                source,
463                byte_start,
464                byte_end,
465                call_kinds,
466                results,
467            );
468            if !cursor.goto_next_sibling() {
469                break;
470            }
471        }
472    }
473}