Skip to main content

magellan/
references.rs

//! Reference and call extraction from Rust source code
//!
//! Extracts factual, byte-accurate references and calls to symbols without semantic analysis.
4
use crate::common::safe_slice;
use crate::ingest::{SymbolFact, SymbolKind};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
10
11/// A fact about a reference to a symbol
12///
13/// Pure data structure. No behavior. No semantic resolution.
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
15pub struct ReferenceFact {
16    /// File containing this reference
17    pub file_path: PathBuf,
18    /// Name of the symbol being referenced
19    pub referenced_symbol: String,
20    /// Byte offset where reference starts in file
21    pub byte_start: usize,
22    /// Byte offset where reference ends in file
23    pub byte_end: usize,
24    /// Line where reference starts (1-indexed)
25    pub start_line: usize,
26    /// Column where reference starts (0-indexed, bytes)
27    pub start_col: usize,
28    /// Line where reference ends (1-indexed)
29    pub end_line: usize,
30    /// Column where reference ends (0-indexed, bytes)
31    pub end_col: usize,
32}
33
34/// A fact about a function call (forward call graph edge)
35///
36/// Represents: caller function → callee function
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
38pub struct CallFact {
39    /// File containing this call
40    pub file_path: PathBuf,
41    /// Name of the calling function
42    pub caller: String,
43    /// Name of the function being called
44    pub callee: String,
45    /// Stable symbol ID of the caller (optional, for correlation)
46    #[serde(default)]
47    pub caller_symbol_id: Option<String>,
48    /// Stable symbol ID of the callee (optional, for correlation)
49    #[serde(default)]
50    pub callee_symbol_id: Option<String>,
51    /// Byte offset where call starts in file
52    pub byte_start: usize,
53    /// Byte offset where call ends in file
54    pub byte_end: usize,
55    /// Line where call starts (1-indexed)
56    pub start_line: usize,
57    /// Column where call starts (0-indexed, bytes)
58    pub start_col: usize,
59    /// Line where call ends (1-indexed)
60    pub end_line: usize,
61    /// Column where call ends (0-indexed, bytes)
62    pub end_col: usize,
63}
64
65/// Reference extractor
66pub struct ReferenceExtractor {
67    parser: tree_sitter::Parser,
68}
69
70impl ReferenceExtractor {
71    /// Create a new reference extractor
72    pub fn new() -> anyhow::Result<Self> {
73        let mut parser = tree_sitter::Parser::new();
74        let language = tree_sitter_rust::language();
75        parser.set_language(&language)?;
76
77        Ok(Self { parser })
78    }
79
80    /// Extract reference facts from Rust source code
81    ///
82    /// # Arguments
83    /// * `file_path` - Path to the file (for context only, not accessed)
84    /// * `source` - Source code content as bytes
85    /// * `symbols` - Symbols defined in this file (to match against and exclude)
86    ///
87    /// # Returns
88    /// Vector of reference facts found in the source
89    ///
90    /// # Guarantees
91    /// - Pure function: same input → same output
92    /// - No side effects
93    /// - No filesystem access
94    /// - No semantic analysis (textual + position match only)
95    pub fn extract_references(
96        &mut self,
97        file_path: PathBuf,
98        source: &[u8],
99        symbols: &[SymbolFact],
100    ) -> Vec<ReferenceFact> {
101        let tree = match self.parser.parse(source, None) {
102            Some(t) => t,
103            None => return Vec::new(),
104        };
105
106        let root_node = tree.root_node();
107        let mut references = Vec::new();
108
109        // Walk tree and find references
110        self.walk_tree_for_references(&root_node, source, &file_path, symbols, &mut references);
111
112        references
113    }
114
115    /// Walk tree-sitter tree recursively and extract references
116    fn walk_tree_for_references(
117        &self,
118        node: &tree_sitter::Node,
119        source: &[u8],
120        file_path: &PathBuf,
121        symbols: &[SymbolFact],
122        references: &mut Vec<ReferenceFact>,
123    ) {
124        // Check if this node is a reference we care about
125        if let Some(reference) = self.extract_reference(node, source, file_path, symbols) {
126            references.push(reference);
127
128            // Don't recurse into scoped_identifier - we've already handled it
129            // This prevents extracting child identifier nodes within it
130            if node.kind() == "scoped_identifier" {
131                return;
132            }
133        }
134
135        // Recurse into children
136        let mut cursor = node.walk();
137        for child in node.children(&mut cursor) {
138            self.walk_tree_for_references(&child, source, file_path, symbols, references);
139        }
140    }
141
142    /// Extract a reference fact from a tree-sitter node, if applicable
143    fn extract_reference(
144        &self,
145        node: &tree_sitter::Node,
146        source: &[u8],
147        file_path: &PathBuf,
148        symbols: &[SymbolFact],
149    ) -> Option<ReferenceFact> {
150        let kind = node.kind();
151
152        // Only process identifier and scoped_identifier nodes
153        match kind {
154            "identifier" => {}
155            "scoped_identifier" => {}
156            _ => return None,
157        }
158
159        // Get the text of this node
160        let text_bytes = safe_slice(source, node.start_byte(), node.end_byte())?;
161        let text = std::str::from_utf8(text_bytes).ok()?;
162
163        // For scoped_identifier (e.g., a::foo), extract the final component
164        let symbol_name = if kind == "scoped_identifier" {
165            // Split by :: and take the last part
166            text.split("::").last().unwrap_or(text)
167        } else {
168            text
169        };
170
171        // Find if this matches any symbol
172        let referenced_symbol = symbols
173            .iter()
174            .find(|s| s.name.as_ref().map(|n| n == symbol_name).unwrap_or(false))?;
175
176        // Check if reference is OUTSIDE the symbol's defining span
177        let ref_start = node.start_byte();
178        let ref_end = node.end_byte();
179
180        // Only apply span filter for same-file references (self-references)
181        // Cross-file references should never be filtered by span
182        if referenced_symbol.file_path == *file_path && ref_start < referenced_symbol.byte_end {
183            return None; // Reference is within defining span (same file only)
184        }
185
186        Some(ReferenceFact {
187            file_path: file_path.clone(),
188            referenced_symbol: symbol_name.to_string(),
189            byte_start: ref_start,
190            byte_end: ref_end,
191            start_line: node.start_position().row + 1,
192            start_col: node.start_position().column,
193            end_line: node.end_position().row + 1,
194            end_col: node.end_position().column,
195        })
196    }
197}
198
199impl Default for ReferenceExtractor {
200    fn default() -> Self {
201        Self::new().expect("Failed to create reference extractor")
202    }
203}
204
205/// Extension to Parser for reference extraction (convenience wrapper)
206impl crate::ingest::Parser {
207    /// Extract reference facts using the inner parser
208    pub fn extract_references(
209        &mut self,
210        file_path: PathBuf,
211        source: &[u8],
212        symbols: &[SymbolFact],
213    ) -> Vec<ReferenceFact> {
214        let mut extractor = ReferenceExtractor::new().unwrap();
215        extractor.extract_references(file_path, source, symbols)
216    }
217
218    /// Extract function call facts (forward call graph)
219    ///
220    /// # Arguments
221    /// * `file_path` - Path to the file (for context only, not accessed)
222    /// * `source` - Source code content as bytes
223    /// * `symbols` - Symbols defined in this file (to match against)
224    ///
225    /// # Returns
226    /// Vector of CallFact representing caller → callee relationships
227    ///
228    /// # Guarantees
229    /// - Only function calls are extracted (not type references)
230    /// - Calls are extracted when a function identifier within a function body
231    ///   references another function symbol
232    /// - No semantic analysis (AST-based only)
233    pub fn extract_calls(
234        &mut self,
235        file_path: PathBuf,
236        source: &[u8],
237        symbols: &[SymbolFact],
238    ) -> Vec<CallFact> {
239        let mut extractor = CallExtractor::new().unwrap();
240        extractor.extract_calls(file_path, source, symbols)
241    }
242}
243
244/// Call extractor for forward call graph
245///
246/// Extracts caller → callee relationships from function bodies
247pub struct CallExtractor {
248    parser: tree_sitter::Parser,
249}
250
251impl CallExtractor {
252    /// Create a new call extractor
253    pub fn new() -> anyhow::Result<Self> {
254        let mut parser = tree_sitter::Parser::new();
255        let language = tree_sitter_rust::language();
256        parser.set_language(&language)?;
257
258        Ok(Self { parser })
259    }
260
261    /// Extract function call facts from Rust source code
262    ///
263    /// # Behavior
264    /// 1. Parse the source code
265    /// 2. Find all function definitions
266    /// 3. For each function, find identifier nodes that reference other functions
267    /// 4. Create CallFact for each unique caller → callee relationship
268    pub fn extract_calls(
269        &mut self,
270        file_path: PathBuf,
271        source: &[u8],
272        symbols: &[SymbolFact],
273    ) -> Vec<CallFact> {
274        let tree = match self.parser.parse(source, None) {
275            Some(t) => t,
276            None => return Vec::new(),
277        };
278
279        let root_node = tree.root_node();
280        let mut calls = Vec::new();
281
282        // Build map: symbol name → symbol fact (for quick lookup)
283        let symbol_map: HashMap<String, &SymbolFact> = symbols
284            .iter()
285            .filter_map(|s| s.name.as_ref().map(|name| (name.clone(), s)))
286            .collect();
287
288        // Filter to only functions (potential callers and callees)
289        let functions: Vec<&SymbolFact> = symbols
290            .iter()
291            .filter(|s| s.kind == SymbolKind::Function)
292            .collect();
293
294        // Walk tree and find calls within function bodies
295        self.walk_tree_for_calls(
296            &root_node,
297            source,
298            &file_path,
299            &symbol_map,
300            &functions,
301            &mut calls,
302        );
303
304        calls
305    }
306
307    /// Walk tree-sitter tree and extract function calls
308    fn walk_tree_for_calls(
309        &self,
310        node: &tree_sitter::Node,
311        source: &[u8],
312        file_path: &PathBuf,
313        symbol_map: &HashMap<String, &SymbolFact>,
314        _functions: &[&SymbolFact],
315        calls: &mut Vec<CallFact>,
316    ) {
317        self.walk_tree_for_calls_with_caller(node, source, file_path, symbol_map, None, calls);
318    }
319
320    /// Walk tree-sitter tree and extract function calls, tracking current function
321    fn walk_tree_for_calls_with_caller(
322        &self,
323        node: &tree_sitter::Node,
324        source: &[u8],
325        file_path: &PathBuf,
326        symbol_map: &HashMap<String, &SymbolFact>,
327        current_caller: Option<&SymbolFact>,
328        calls: &mut Vec<CallFact>,
329    ) {
330        let kind = node.kind();
331
332        // Track which function we're inside (if any)
333        let caller: Option<&SymbolFact> = if kind == "function_item" {
334            // Extract function name - this becomes the new caller for children
335            self.extract_function_name(node, source)
336                .and_then(|name| symbol_map.get(&name).copied())
337        } else {
338            current_caller
339        };
340
341        // If we have a caller and this is a call_expression, extract the call
342        if kind == "call_expression" {
343            if let Some(caller_fact) = caller {
344                self.extract_calls_in_node(node, source, file_path, caller_fact, symbol_map, calls);
345            }
346        }
347
348        // Recurse into children
349        let mut cursor = node.walk();
350        for child in node.children(&mut cursor) {
351            self.walk_tree_for_calls_with_caller(
352                &child, source, file_path, symbol_map, caller, calls,
353            );
354        }
355    }
356
357    /// Extract function name from a function_item node
358    fn extract_function_name(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
359        let mut cursor = node.walk();
360        for child in node.children(&mut cursor) {
361            if child.kind() == "identifier" || child.kind() == "type_identifier" {
362                let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
363                return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
364            }
365        }
366        None
367    }
368
369    /// Extract calls within a node (function body)
370    fn extract_calls_in_node(
371        &self,
372        node: &tree_sitter::Node,
373        source: &[u8],
374        file_path: &PathBuf,
375        caller: &SymbolFact,
376        symbol_map: &HashMap<String, &SymbolFact>,
377        calls: &mut Vec<CallFact>,
378    ) {
379        // Look for call_expression nodes or identifier nodes
380        let kind = node.kind();
381
382        if kind == "call_expression" {
383            // Extract the function being called
384            if let Some(callee_name) = self.extract_callee_from_call(node, source) {
385                // Only create call if callee is a known function symbol
386                if symbol_map.contains_key(&callee_name) {
387                    let node_start = node.start_byte();
388                    let node_end = node.end_byte();
389                    let call_fact = CallFact {
390                        file_path: file_path.clone(),
391                        caller: caller.name.clone().unwrap_or_default(),
392                        callee: callee_name,
393                        caller_symbol_id: None,
394                        callee_symbol_id: None,
395                        byte_start: node_start,
396                        byte_end: node_end,
397                        start_line: node.start_position().row + 1,
398                        start_col: node.start_position().column,
399                        end_line: node.end_position().row + 1,
400                        end_col: node.end_position().column,
401                    };
402                    calls.push(call_fact);
403                }
404            }
405        }
406    }
407
408    /// Extract callee name from a call_expression node
409    fn extract_callee_from_call(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
410        // The callee is typically the first child (identifier) or a scoped_identifier
411        let mut cursor = node.walk();
412        for child in node.children(&mut cursor) {
413            let kind = child.kind();
414            if kind == "identifier" {
415                let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
416                return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
417            }
418            // Handle method calls like obj.method() - we want the method name
419            if kind == "field_expression" || kind == "method_expression" {
420                // For a.b(), extract "b"
421                return self.extract_method_name(&child, source);
422            }
423        }
424        None
425    }
426
427    /// Extract method name from a field_expression or method_expression
428    fn extract_method_name(&self, node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
429        let mut cursor = node.walk();
430        for child in node.children(&mut cursor) {
431            // Look for the field_identifier (method name in a.b())
432            if child.kind() == "field_identifier" {
433                let name_bytes = safe_slice(source, child.start_byte(), child.end_byte())?;
434                return std::str::from_utf8(name_bytes).ok().map(|s| s.to_string());
435            }
436        }
437        None
438    }
439}
440
441impl Default for CallExtractor {
442    fn default() -> Self {
443        Self::new().expect("Failed to create call extractor")
444    }
445}