infiniloom_engine/index/
query.rs

1//! Call graph query API
2//!
3//! High-level functions for querying call relationships between symbols.
4//! Used by both Python and Node.js bindings.
5
6use super::types::{DepGraph, IndexSymbol, IndexSymbolKind, SymbolIndex, Visibility};
7use serde::Serialize;
8
/// Information about a symbol, returned from call graph queries
///
/// A flattened, serializable view of an `IndexSymbol`: enum-typed properties
/// (kind, visibility) are rendered as lowercase strings and the file ID is
/// resolved to a path, so the value is self-contained once returned.
#[derive(Debug, Clone, Serialize)]
pub struct SymbolInfo {
    /// Symbol ID (the `u32` value of the indexed symbol's ID)
    pub id: u32,
    /// Symbol name
    pub name: String,
    /// Symbol kind as a lowercase string: `function`, `method`, `class`,
    /// `struct`, `interface`, `trait`, `enum`, `constant`, `variable`,
    /// `module`, `import`, `export`, `type_alias` or `macro`
    pub kind: String,
    /// Path of the file containing the symbol, or `<unknown>` when the
    /// symbol's file ID cannot be resolved in the index
    pub file: String,
    /// Start line number (taken from the symbol's span)
    pub line: u32,
    /// End line number (taken from the symbol's span)
    pub end_line: u32,
    /// Function/method signature, if the index recorded one
    pub signature: Option<String>,
    /// Visibility as a lowercase string: `public`, `private`, `protected`
    /// or `internal`
    pub visibility: String,
}
29
/// A reference location in the codebase
///
/// Pairs the symbol that makes the reference with a string describing how it
/// refers to the queried symbol.
#[derive(Debug, Clone, Serialize)]
pub struct ReferenceInfo {
    /// Symbol making the reference
    pub symbol: SymbolInfo,
    /// Reference kind. `get_references_by_name` in this module emits `call`
    /// for call-graph callers and `reference` for every other referencer;
    /// finer-grained kinds (import, inherit, implement) are not distinguished
    /// there.
    pub kind: String,
}
38
/// An edge in the call graph
///
/// Represents one caller -> callee relationship, carrying both the numeric
/// IDs and the names so the edge can be displayed without a second lookup.
#[derive(Debug, Clone, Serialize)]
pub struct CallGraphEdge {
    /// Caller symbol ID
    pub caller_id: u32,
    /// Callee symbol ID
    pub callee_id: u32,
    /// Caller symbol name
    pub caller: String,
    /// Callee symbol name
    pub callee: String,
    /// File containing the call site. As populated by
    /// `get_call_graph_filtered`, this is the file of the caller symbol
    /// (`<unknown>` when that file cannot be resolved).
    pub file: String,
    /// Line number of the call. NOTE: as populated by
    /// `get_call_graph_filtered`, this is the start line of the caller
    /// symbol, not the exact line of the call expression.
    pub line: u32,
}
55
/// Complete call graph with nodes and edges
///
/// Produced by `get_call_graph` and `get_call_graph_filtered`; when limits
/// are applied, `nodes`, `edges` and `stats` all describe the limited graph.
#[derive(Debug, Clone, Serialize)]
pub struct CallGraph {
    /// All symbols (nodes)
    pub nodes: Vec<SymbolInfo>,
    /// Call relationships (edges)
    pub edges: Vec<CallGraphEdge>,
    /// Summary statistics computed from `nodes` and `edges`
    pub stats: CallGraphStats,
}
66
/// Call graph statistics
///
/// All counts refer to the nodes and edges actually returned in the
/// accompanying `CallGraph` (i.e. after any node/edge limits were applied).
#[derive(Debug, Clone, Serialize)]
pub struct CallGraphStats {
    /// Total number of symbols (length of `CallGraph::nodes`)
    pub total_symbols: usize,
    /// Total number of call edges (length of `CallGraph::edges`)
    pub total_calls: usize,
    /// Number of functions/methods (nodes whose kind is `function` or `method`)
    pub functions: usize,
    /// Number of classes/structs (nodes whose kind is `class` or `struct`)
    pub classes: usize,
}
79
80impl SymbolInfo {
81    /// Create SymbolInfo from an IndexSymbol
82    pub fn from_index_symbol(sym: &IndexSymbol, index: &SymbolIndex) -> Self {
83        let file_path = index
84            .get_file_by_id(sym.file_id.as_u32())
85            .map(|f| f.path.clone())
86            .unwrap_or_else(|| "<unknown>".to_owned());
87
88        Self {
89            id: sym.id.as_u32(),
90            name: sym.name.clone(),
91            kind: format_symbol_kind(sym.kind),
92            file: file_path,
93            line: sym.span.start_line,
94            end_line: sym.span.end_line,
95            signature: sym.signature.clone(),
96            visibility: format_visibility(sym.visibility),
97        }
98    }
99}
100
101/// Find a symbol by name and return its info
102///
103/// Deduplicates results by file path and line number to avoid returning
104/// the same symbol multiple times (e.g., export + declaration).
105pub fn find_symbol(index: &SymbolIndex, name: &str) -> Vec<SymbolInfo> {
106    let mut results: Vec<SymbolInfo> = index
107        .find_symbols(name)
108        .into_iter()
109        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
110        .collect();
111
112    // Deduplicate by (file, line) to avoid returning export+declaration as separate entries
113    results.sort_by(|a, b| (&a.file, a.line).cmp(&(&b.file, b.line)));
114    results.dedup_by(|a, b| a.file == b.file && a.line == b.line);
115
116    results
117}
118
119/// Get all callers of a symbol by name
120///
121/// Returns symbols that call any symbol with the given name.
122pub fn get_callers_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
123    let mut callers = Vec::new();
124
125    // Find all symbols with this name
126    for sym in index.find_symbols(name) {
127        let symbol_id = sym.id.as_u32();
128
129        // Get callers from the dependency graph
130        for caller_id in graph.get_callers(symbol_id) {
131            if let Some(caller_sym) = index.get_symbol(caller_id) {
132                callers.push(SymbolInfo::from_index_symbol(caller_sym, index));
133            }
134        }
135    }
136
137    // Deduplicate by symbol ID
138    callers.sort_by_key(|s| s.id);
139    callers.dedup_by_key(|s| s.id);
140
141    callers
142}
143
144/// Get all callees of a symbol by name
145///
146/// Returns symbols that are called by any symbol with the given name.
147pub fn get_callees_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
148    let mut callees = Vec::new();
149
150    // Find all symbols with this name
151    for sym in index.find_symbols(name) {
152        let symbol_id = sym.id.as_u32();
153
154        // Get callees from the dependency graph
155        for callee_id in graph.get_callees(symbol_id) {
156            if let Some(callee_sym) = index.get_symbol(callee_id) {
157                callees.push(SymbolInfo::from_index_symbol(callee_sym, index));
158            }
159        }
160    }
161
162    // Deduplicate by symbol ID
163    callees.sort_by_key(|s| s.id);
164    callees.dedup_by_key(|s| s.id);
165
166    callees
167}
168
169/// Get all references to a symbol by name
170///
171/// Returns symbols that reference any symbol with the given name
172/// (includes calls, imports, inheritance, and implementations).
173pub fn get_references_by_name(
174    index: &SymbolIndex,
175    graph: &DepGraph,
176    name: &str,
177) -> Vec<ReferenceInfo> {
178    let mut references = Vec::new();
179
180    // Find all symbols with this name
181    for sym in index.find_symbols(name) {
182        let symbol_id = sym.id.as_u32();
183
184        // Get callers (call references)
185        for caller_id in graph.get_callers(symbol_id) {
186            if let Some(caller_sym) = index.get_symbol(caller_id) {
187                references.push(ReferenceInfo {
188                    symbol: SymbolInfo::from_index_symbol(caller_sym, index),
189                    kind: "call".to_owned(),
190                });
191            }
192        }
193
194        // Get referencers (symbol_ref - may include imports/inheritance)
195        for ref_id in graph.get_referencers(symbol_id) {
196            if let Some(ref_sym) = index.get_symbol(ref_id) {
197                // Avoid duplicates with callers
198                if !graph.get_callers(symbol_id).contains(&ref_id) {
199                    references.push(ReferenceInfo {
200                        symbol: SymbolInfo::from_index_symbol(ref_sym, index),
201                        kind: "reference".to_owned(),
202                    });
203                }
204            }
205        }
206    }
207
208    // Deduplicate by symbol ID
209    references.sort_by_key(|r| r.symbol.id);
210    references.dedup_by_key(|r| r.symbol.id);
211
212    references
213}
214
215/// Get the complete call graph
216///
217/// Returns all symbols (nodes) and call relationships (edges).
218/// For large codebases, consider using `get_call_graph_filtered` with limits.
219pub fn get_call_graph(index: &SymbolIndex, graph: &DepGraph) -> CallGraph {
220    get_call_graph_filtered(index, graph, None, None)
221}
222
223/// Get a filtered call graph
224///
225/// Args:
226///   - `max_nodes`: Optional limit on number of symbols returned
227///   - `max_edges`: Optional limit on number of edges returned
228pub fn get_call_graph_filtered(
229    index: &SymbolIndex,
230    graph: &DepGraph,
231    max_nodes: Option<usize>,
232    max_edges: Option<usize>,
233) -> CallGraph {
234    // Bug #5 fix: When only max_edges is specified, limit nodes to those that appear in edges
235    // This ensures users get a small, focused graph rather than all nodes with limited edges
236
237    // First, collect all edges and apply edge limit
238    let mut edges: Vec<CallGraphEdge> = graph
239        .calls
240        .iter()
241        .filter_map(|&(caller_id, callee_id)| {
242            let caller_sym = index.get_symbol(caller_id)?;
243            let callee_sym = index.get_symbol(callee_id)?;
244
245            let file_path = index
246                .get_file_by_id(caller_sym.file_id.as_u32())
247                .map(|f| f.path.clone())
248                .unwrap_or_else(|| "<unknown>".to_owned());
249
250            Some(CallGraphEdge {
251                caller_id,
252                callee_id,
253                caller: caller_sym.name.clone(),
254                callee: callee_sym.name.clone(),
255                file: file_path,
256                line: caller_sym.span.start_line,
257            })
258        })
259        .collect();
260
261    // Apply edge limit first (before node filtering for more intuitive behavior)
262    if let Some(limit) = max_edges {
263        edges.truncate(limit);
264    }
265
266    // Collect node IDs that appear in the (possibly limited) edges
267    let edge_node_ids: std::collections::HashSet<u32> = edges
268        .iter()
269        .flat_map(|e| [e.caller_id, e.callee_id])
270        .collect();
271
272    // Collect nodes - when max_edges is specified without max_nodes, only include nodes from edges
273    let mut nodes: Vec<SymbolInfo> = if max_edges.is_some() && max_nodes.is_none() {
274        // Only include nodes that appear in the limited edges
275        index
276            .symbols
277            .iter()
278            .filter(|sym| edge_node_ids.contains(&sym.id.as_u32()))
279            .map(|sym| SymbolInfo::from_index_symbol(sym, index))
280            .collect()
281    } else {
282        // Include all nodes, then optionally truncate
283        index
284            .symbols
285            .iter()
286            .map(|sym| SymbolInfo::from_index_symbol(sym, index))
287            .collect()
288    };
289
290    // Apply node limit if specified
291    if let Some(limit) = max_nodes {
292        nodes.truncate(limit);
293
294        // When max_nodes is applied, also filter edges to only include those between limited nodes
295        let node_ids: std::collections::HashSet<u32> = nodes.iter().map(|n| n.id).collect();
296        edges.retain(|e| node_ids.contains(&e.caller_id) && node_ids.contains(&e.callee_id));
297    }
298
299    // Calculate statistics
300    let functions = nodes
301        .iter()
302        .filter(|n| n.kind == "function" || n.kind == "method")
303        .count();
304    let classes = nodes
305        .iter()
306        .filter(|n| n.kind == "class" || n.kind == "struct")
307        .count();
308
309    let stats =
310        CallGraphStats { total_symbols: nodes.len(), total_calls: edges.len(), functions, classes };
311
312    CallGraph { nodes, edges, stats }
313}
314
315/// Get callers of a symbol by its ID
316pub fn get_callers_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
317    graph
318        .get_callers(symbol_id)
319        .into_iter()
320        .filter_map(|id| index.get_symbol(id))
321        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
322        .collect()
323}
324
325/// Get callees of a symbol by its ID
326pub fn get_callees_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
327    graph
328        .get_callees(symbol_id)
329        .into_iter()
330        .filter_map(|id| index.get_symbol(id))
331        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
332        .collect()
333}
334
335// Helper functions
336
337fn format_symbol_kind(kind: IndexSymbolKind) -> String {
338    match kind {
339        IndexSymbolKind::Function => "function",
340        IndexSymbolKind::Method => "method",
341        IndexSymbolKind::Class => "class",
342        IndexSymbolKind::Struct => "struct",
343        IndexSymbolKind::Interface => "interface",
344        IndexSymbolKind::Trait => "trait",
345        IndexSymbolKind::Enum => "enum",
346        IndexSymbolKind::Constant => "constant",
347        IndexSymbolKind::Variable => "variable",
348        IndexSymbolKind::Module => "module",
349        IndexSymbolKind::Import => "import",
350        IndexSymbolKind::Export => "export",
351        IndexSymbolKind::TypeAlias => "type_alias",
352        IndexSymbolKind::Macro => "macro",
353    }
354    .to_owned()
355}
356
357fn format_visibility(vis: Visibility) -> String {
358    match vis {
359        Visibility::Public => "public",
360        Visibility::Private => "private",
361        Visibility::Protected => "protected",
362        Visibility::Internal => "internal",
363    }
364    .to_owned()
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::types::{FileEntry, FileId, Language, Span, SymbolId};

    /// Build a minimal fixture: one Python file (`test.py`) holding two
    /// functions, `main` (id 0) and `helper` (id 1), plus a dependency graph
    /// with the single call edge `main -> helper`.
    fn create_test_index() -> (SymbolIndex, DepGraph) {
        let source_file = FileEntry {
            id: FileId::new(0),
            path: String::from("test.py"),
            language: Language::Python,
            symbols: 0..2,
            imports: vec![],
            content_hash: [0u8; 32],
            lines: 25,
            tokens: 100,
        };

        let main_fn = IndexSymbol {
            id: SymbolId::new(0),
            name: String::from("main"),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 1, start_col: 0, end_line: 10, end_col: 0 },
            signature: Some(String::from("def main()")),
            parent: None,
            visibility: Visibility::Public,
            docstring: None,
        };

        let helper_fn = IndexSymbol {
            id: SymbolId::new(1),
            name: String::from("helper"),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 12, start_col: 0, end_line: 20, end_col: 0 },
            signature: Some(String::from("def helper()")),
            parent: None,
            visibility: Visibility::Private,
            docstring: None,
        };

        // Assemble the index: file, symbols, then the name lookup table
        let mut index = SymbolIndex::default();
        index.files.push(source_file);
        index.symbols.push(main_fn);
        index.symbols.push(helper_fn);
        index.symbols_by_name.insert(String::from("main"), vec![0]);
        index.symbols_by_name.insert(String::from("helper"), vec![1]);

        // main calls helper
        let mut graph = DepGraph::new();
        graph.add_call(0, 1);

        (index, graph)
    }

    #[test]
    fn test_find_symbol() {
        let (index, _graph) = create_test_index();

        let found = find_symbol(&index, "main");
        assert_eq!(found.len(), 1);

        let info = &found[0];
        assert_eq!(info.name, "main");
        assert_eq!(info.kind, "function");
        assert_eq!(info.file, "test.py");
    }

    #[test]
    fn test_get_callers() {
        let (index, graph) = create_test_index();

        // The only caller of `helper` is `main`
        let callers = get_callers_by_name(&index, &graph, "helper");
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "main");
    }

    #[test]
    fn test_get_callees() {
        let (index, graph) = create_test_index();

        // The only callee of `main` is `helper`
        let callees = get_callees_by_name(&index, &graph, "main");
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "helper");
    }

    #[test]
    fn test_get_call_graph() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph(&index, &graph);
        assert_eq!(call_graph.nodes.len(), 2);
        assert_eq!(call_graph.edges.len(), 1);
        assert_eq!(call_graph.stats.functions, 2);

        // The single edge is main -> helper
        let edge = &call_graph.edges[0];
        assert_eq!(edge.caller, "main");
        assert_eq!(edge.callee, "helper");
    }
}
467}