Skip to main content

aft/
calls.rs

1//! Shared call-site extraction helpers.
2//!
3//! Extracted from `commands/zoom.rs` so both the zoom command and the
4//! call-graph engine can reuse the same AST-walking logic.
5
6use crate::parser::LangId;
7
8/// Returns the tree-sitter node kind strings that represent call expressions
9/// for the given language.
10pub fn call_node_kinds(lang: LangId) -> Vec<&'static str> {
11    match lang {
12        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Go => {
13            vec!["call_expression"]
14        }
15        LangId::Python => vec!["call"],
16        LangId::Rust => vec!["call_expression", "macro_invocation"],
17        LangId::Markdown => vec![],
18    }
19}
20
21/// Recursively walk tree nodes looking for call expressions within a byte range.
22///
23/// Collects `(callee_name, line_number)` pairs into `results`.
24pub fn walk_for_calls(
25    node: tree_sitter::Node,
26    source: &str,
27    byte_start: usize,
28    byte_end: usize,
29    call_kinds: &[&str],
30    results: &mut Vec<(String, u32)>,
31) {
32    let node_start = node.start_byte();
33    let node_end = node.end_byte();
34
35    // Skip nodes entirely outside our range
36    if node_end <= byte_start || node_start >= byte_end {
37        return;
38    }
39
40    if call_kinds.contains(&node.kind()) && node_start >= byte_start && node_end <= byte_end {
41        if let Some(name) = extract_callee_name(&node, source) {
42            results.push((name, node.start_position().row as u32 + 1));
43        }
44    }
45
46    // Recurse into children
47    let mut cursor = node.walk();
48    if cursor.goto_first_child() {
49        loop {
50            walk_for_calls(
51                cursor.node(),
52                source,
53                byte_start,
54                byte_end,
55                call_kinds,
56                results,
57            );
58            if !cursor.goto_next_sibling() {
59                break;
60            }
61        }
62    }
63}
64
65/// Extract the callee name from a call expression node.
66///
67/// For simple calls like `foo()`, returns "foo".
68/// For member access like `this.add()` or `obj.method()`, returns the last
69/// segment ("add" / "method").
70/// For Rust macros like `println!()`, returns "println!".
71pub fn extract_callee_name(node: &tree_sitter::Node, source: &str) -> Option<String> {
72    let kind = node.kind();
73
74    if kind == "macro_invocation" {
75        // Rust macro: first child is the macro name (e.g. `println!`)
76        let first_child = node.child(0)?;
77        let text = &source[first_child.byte_range()];
78        return Some(format!("{}!", text));
79    }
80
81    // call_expression / call — get the "function" child
82    let func_node = node
83        .child_by_field_name("function")
84        .or_else(|| node.child(0))?;
85
86    let func_kind = func_node.kind();
87    match func_kind {
88        // Simple identifier: foo()
89        "identifier" => Some(source[func_node.byte_range()].to_string()),
90        // Member access: obj.method() / this.method()
91        "member_expression" | "field_expression" | "attribute" => {
92            // Last child that's a property_identifier, field_identifier, or identifier
93            extract_last_segment(&func_node, source)
94        }
95        _ => {
96            // Fallback: use the full text
97            let text = &source[func_node.byte_range()];
98            // If it contains a dot, take the last segment
99            if text.contains('.') {
100                text.rsplit('.').next().map(|s| s.trim().to_string())
101            } else {
102                Some(text.trim().to_string())
103            }
104        }
105    }
106}
107
108/// Extract the full callee expression from a call expression node.
109///
110/// Unlike `extract_callee_name` which returns only the last segment,
111/// this returns the full expression (e.g. "utils.foo" for `utils.foo()`).
112/// Used by the call graph engine to detect namespace-qualified calls.
113pub fn extract_full_callee(node: &tree_sitter::Node, source: &str) -> Option<String> {
114    let kind = node.kind();
115
116    if kind == "macro_invocation" {
117        let first_child = node.child(0)?;
118        let text = &source[first_child.byte_range()];
119        return Some(format!("{}!", text));
120    }
121
122    let func_node = node
123        .child_by_field_name("function")
124        .or_else(|| node.child(0))?;
125
126    Some(source[func_node.byte_range()].trim().to_string())
127}
128
129/// Extract the last segment of a member expression (the method/property name).
130pub fn extract_last_segment(node: &tree_sitter::Node, source: &str) -> Option<String> {
131    let child_count = node.child_count();
132    // Walk children from the end looking for an identifier-like node
133    for i in (0..child_count).rev() {
134        if let Some(child) = node.child(i as u32) {
135            match child.kind() {
136                "property_identifier" | "field_identifier" | "identifier" => {
137                    return Some(source[child.byte_range()].to_string());
138                }
139                _ => {}
140            }
141        }
142    }
143    // Fallback: full text, last dot segment
144    let text = &source[node.byte_range()];
145    text.rsplit('.').next().map(|s| s.trim().to_string())
146}
147
148/// Extract call expression names within a byte range of the AST.
149///
150/// Walks all nodes in the tree, finds call_expression/call/macro_invocation
151/// nodes whose byte range falls within [byte_start, byte_end], and extracts
152/// the callee name (last segment for member access like `obj.method()`).
153///
154/// Returns (callee_name, line_number) pairs.
155pub fn extract_calls_in_range(
156    source: &str,
157    root: tree_sitter::Node,
158    byte_start: usize,
159    byte_end: usize,
160    lang: LangId,
161) -> Vec<(String, u32)> {
162    let mut results = Vec::new();
163    let call_kinds = call_node_kinds(lang);
164    walk_for_calls(
165        root,
166        source,
167        byte_start,
168        byte_end,
169        &call_kinds,
170        &mut results,
171    );
172    results
173}
174
175/// Extract calls with full callee expressions (including namespace qualifiers).
176///
177/// Returns `(full_callee, short_name, line)` triples.
178/// `full_callee` is e.g. "utils.foo", `short_name` is "foo".
179pub fn extract_calls_full(
180    source: &str,
181    root: tree_sitter::Node,
182    byte_start: usize,
183    byte_end: usize,
184    lang: LangId,
185) -> Vec<(String, String, u32)> {
186    let mut results = Vec::new();
187    let call_kinds = call_node_kinds(lang);
188    collect_calls_full(
189        root,
190        source,
191        byte_start,
192        byte_end,
193        &call_kinds,
194        &mut results,
195    );
196    results
197}
198
199fn collect_calls_full(
200    node: tree_sitter::Node,
201    source: &str,
202    byte_start: usize,
203    byte_end: usize,
204    call_kinds: &[&str],
205    results: &mut Vec<(String, String, u32)>,
206) {
207    let node_start = node.start_byte();
208    let node_end = node.end_byte();
209
210    if node_end <= byte_start || node_start >= byte_end {
211        return;
212    }
213
214    if call_kinds.contains(&node.kind()) && node_start >= byte_start && node_end <= byte_end {
215        if let (Some(full), Some(short)) = (
216            extract_full_callee(&node, source),
217            extract_callee_name(&node, source),
218        ) {
219            results.push((full, short, node.start_position().row as u32 + 1));
220        }
221    }
222
223    let mut cursor = node.walk();
224    if cursor.goto_first_child() {
225        loop {
226            collect_calls_full(
227                cursor.node(),
228                source,
229                byte_start,
230                byte_end,
231                call_kinds,
232                results,
233            );
234            if !cursor.goto_next_sibling() {
235                break;
236            }
237        }
238    }
239}