infiniloom_engine/index/
query.rs

1//! Call graph query API
2//!
3//! High-level functions for querying call relationships between symbols.
4//! Used by both Python and Node.js bindings.
5
6use super::types::{DepGraph, IndexSymbol, IndexSymbolKind, SymbolIndex, Visibility};
7use serde::Serialize;
8
/// Information about a symbol, returned from call graph queries.
///
/// An owned, serializable snapshot of an `IndexSymbol` in which the file path,
/// kind and visibility have been rendered as strings (see
/// `SymbolInfo::from_index_symbol`).
#[derive(Debug, Clone, Serialize)]
pub struct SymbolInfo {
    /// Symbol ID
    pub id: u32,
    /// Symbol name
    pub name: String,
    /// Symbol kind as a lowercase string, e.g. `"function"`, `"method"`,
    /// `"class"`, `"struct"`, `"type_alias"`
    pub kind: String,
    /// Path of the file containing the symbol, or `"<unknown>"` when the
    /// symbol's file ID could not be resolved in the index
    pub file: String,
    /// Start line of the symbol's span
    pub line: u32,
    /// End line of the symbol's span
    pub end_line: u32,
    /// Function/method signature, if one was recorded
    pub signature: Option<String>,
    /// Visibility as a lowercase string: `"public"`, `"private"`,
    /// `"protected"` or `"internal"`
    pub visibility: String,
}
29
/// A reference to a symbol: the symbol that makes the reference, plus how it
/// references the target.
#[derive(Debug, Clone, Serialize)]
pub struct ReferenceInfo {
    /// Symbol making the reference
    pub symbol: SymbolInfo,
    /// Reference kind. `get_references_by_name` produces `"call"` for call
    /// edges and `"reference"` for any other relationship reported by the
    /// dependency graph's referencer lookup.
    pub kind: String,
}
38
/// An edge in the call graph: `caller` invokes `callee`.
#[derive(Debug, Clone, Serialize)]
pub struct CallGraphEdge {
    /// Caller symbol ID
    pub caller_id: u32,
    /// Callee symbol ID
    pub callee_id: u32,
    /// Caller symbol name
    pub caller: String,
    /// Callee symbol name
    pub callee: String,
    /// Path of the caller's file, or `"<unknown>"` when it could not be
    /// resolved in the index
    pub file: String,
    /// Line associated with the call. `get_call_graph_filtered` fills this with
    /// the caller's start line, not the exact line of the call expression.
    pub line: u32,
}
55
/// A call graph with nodes and edges, possibly truncated by the limits passed
/// to `get_call_graph_filtered`.
#[derive(Debug, Clone, Serialize)]
pub struct CallGraph {
    /// Symbols (nodes)
    pub nodes: Vec<SymbolInfo>,
    /// Call relationships (edges)
    pub edges: Vec<CallGraphEdge>,
    /// Summary statistics over `nodes` and `edges`
    pub stats: CallGraphStats,
}
66
/// Call graph statistics.
///
/// All counts describe the nodes and edges actually returned in the enclosing
/// `CallGraph` (after any node/edge limits), not the whole index.
#[derive(Debug, Clone, Serialize)]
pub struct CallGraphStats {
    /// Number of returned symbols (nodes)
    pub total_symbols: usize,
    /// Number of returned call edges
    pub total_calls: usize,
    /// Number of returned nodes whose kind is `"function"` or `"method"`
    pub functions: usize,
    /// Number of returned nodes whose kind is `"class"` or `"struct"`
    pub classes: usize,
}
79
80impl SymbolInfo {
81    /// Create SymbolInfo from an IndexSymbol
82    pub fn from_index_symbol(sym: &IndexSymbol, index: &SymbolIndex) -> Self {
83        let file_path = index
84            .get_file_by_id(sym.file_id.as_u32())
85            .map(|f| f.path.clone())
86            .unwrap_or_else(|| "<unknown>".to_owned());
87
88        Self {
89            id: sym.id.as_u32(),
90            name: sym.name.clone(),
91            kind: format_symbol_kind(sym.kind),
92            file: file_path,
93            line: sym.span.start_line,
94            end_line: sym.span.end_line,
95            signature: sym.signature.clone(),
96            visibility: format_visibility(sym.visibility),
97        }
98    }
99}
100
101/// Find a symbol by name and return its info
102pub fn find_symbol(index: &SymbolIndex, name: &str) -> Vec<SymbolInfo> {
103    index
104        .find_symbols(name)
105        .into_iter()
106        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
107        .collect()
108}
109
110/// Get all callers of a symbol by name
111///
112/// Returns symbols that call any symbol with the given name.
113pub fn get_callers_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
114    let mut callers = Vec::new();
115
116    // Find all symbols with this name
117    for sym in index.find_symbols(name) {
118        let symbol_id = sym.id.as_u32();
119
120        // Get callers from the dependency graph
121        for caller_id in graph.get_callers(symbol_id) {
122            if let Some(caller_sym) = index.get_symbol(caller_id) {
123                callers.push(SymbolInfo::from_index_symbol(caller_sym, index));
124            }
125        }
126    }
127
128    // Deduplicate by symbol ID
129    callers.sort_by_key(|s| s.id);
130    callers.dedup_by_key(|s| s.id);
131
132    callers
133}
134
135/// Get all callees of a symbol by name
136///
137/// Returns symbols that are called by any symbol with the given name.
138pub fn get_callees_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
139    let mut callees = Vec::new();
140
141    // Find all symbols with this name
142    for sym in index.find_symbols(name) {
143        let symbol_id = sym.id.as_u32();
144
145        // Get callees from the dependency graph
146        for callee_id in graph.get_callees(symbol_id) {
147            if let Some(callee_sym) = index.get_symbol(callee_id) {
148                callees.push(SymbolInfo::from_index_symbol(callee_sym, index));
149            }
150        }
151    }
152
153    // Deduplicate by symbol ID
154    callees.sort_by_key(|s| s.id);
155    callees.dedup_by_key(|s| s.id);
156
157    callees
158}
159
160/// Get all references to a symbol by name
161///
162/// Returns symbols that reference any symbol with the given name
163/// (includes calls, imports, inheritance, and implementations).
164pub fn get_references_by_name(
165    index: &SymbolIndex,
166    graph: &DepGraph,
167    name: &str,
168) -> Vec<ReferenceInfo> {
169    let mut references = Vec::new();
170
171    // Find all symbols with this name
172    for sym in index.find_symbols(name) {
173        let symbol_id = sym.id.as_u32();
174
175        // Get callers (call references)
176        for caller_id in graph.get_callers(symbol_id) {
177            if let Some(caller_sym) = index.get_symbol(caller_id) {
178                references.push(ReferenceInfo {
179                    symbol: SymbolInfo::from_index_symbol(caller_sym, index),
180                    kind: "call".to_owned(),
181                });
182            }
183        }
184
185        // Get referencers (symbol_ref - may include imports/inheritance)
186        for ref_id in graph.get_referencers(symbol_id) {
187            if let Some(ref_sym) = index.get_symbol(ref_id) {
188                // Avoid duplicates with callers
189                if !graph.get_callers(symbol_id).contains(&ref_id) {
190                    references.push(ReferenceInfo {
191                        symbol: SymbolInfo::from_index_symbol(ref_sym, index),
192                        kind: "reference".to_owned(),
193                    });
194                }
195            }
196        }
197    }
198
199    // Deduplicate by symbol ID
200    references.sort_by_key(|r| r.symbol.id);
201    references.dedup_by_key(|r| r.symbol.id);
202
203    references
204}
205
206/// Get the complete call graph
207///
208/// Returns all symbols (nodes) and call relationships (edges).
209/// For large codebases, consider using `get_call_graph_filtered` with limits.
210pub fn get_call_graph(index: &SymbolIndex, graph: &DepGraph) -> CallGraph {
211    get_call_graph_filtered(index, graph, None, None)
212}
213
214/// Get a filtered call graph
215///
216/// Args:
217///   - `max_nodes`: Optional limit on number of symbols returned
218///   - `max_edges`: Optional limit on number of edges returned
219pub fn get_call_graph_filtered(
220    index: &SymbolIndex,
221    graph: &DepGraph,
222    max_nodes: Option<usize>,
223    max_edges: Option<usize>,
224) -> CallGraph {
225    // Collect all nodes
226    let mut nodes: Vec<SymbolInfo> = index
227        .symbols
228        .iter()
229        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
230        .collect();
231
232    // Apply node limit if specified
233    if let Some(limit) = max_nodes {
234        nodes.truncate(limit);
235    }
236
237    // Collect node IDs for filtering edges
238    let node_ids: std::collections::HashSet<u32> = nodes.iter().map(|n| n.id).collect();
239
240    // Collect all edges
241    let mut edges: Vec<CallGraphEdge> = graph
242        .calls
243        .iter()
244        .filter(|(caller, callee)| {
245            // Only include edges where both nodes are in our set
246            max_nodes.is_none() || (node_ids.contains(caller) && node_ids.contains(callee))
247        })
248        .filter_map(|&(caller_id, callee_id)| {
249            let caller_sym = index.get_symbol(caller_id)?;
250            let callee_sym = index.get_symbol(callee_id)?;
251
252            let file_path = index
253                .get_file_by_id(caller_sym.file_id.as_u32())
254                .map(|f| f.path.clone())
255                .unwrap_or_else(|| "<unknown>".to_owned());
256
257            Some(CallGraphEdge {
258                caller_id,
259                callee_id,
260                caller: caller_sym.name.clone(),
261                callee: callee_sym.name.clone(),
262                file: file_path,
263                line: caller_sym.span.start_line,
264            })
265        })
266        .collect();
267
268    // Apply edge limit if specified
269    if let Some(limit) = max_edges {
270        edges.truncate(limit);
271    }
272
273    // Calculate statistics
274    let functions = nodes
275        .iter()
276        .filter(|n| n.kind == "function" || n.kind == "method")
277        .count();
278    let classes = nodes
279        .iter()
280        .filter(|n| n.kind == "class" || n.kind == "struct")
281        .count();
282
283    let stats =
284        CallGraphStats { total_symbols: nodes.len(), total_calls: edges.len(), functions, classes };
285
286    CallGraph { nodes, edges, stats }
287}
288
289/// Get callers of a symbol by its ID
290pub fn get_callers_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
291    graph
292        .get_callers(symbol_id)
293        .into_iter()
294        .filter_map(|id| index.get_symbol(id))
295        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
296        .collect()
297}
298
299/// Get callees of a symbol by its ID
300pub fn get_callees_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
301    graph
302        .get_callees(symbol_id)
303        .into_iter()
304        .filter_map(|id| index.get_symbol(id))
305        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
306        .collect()
307}
308
309// Helper functions
310
311fn format_symbol_kind(kind: IndexSymbolKind) -> String {
312    match kind {
313        IndexSymbolKind::Function => "function",
314        IndexSymbolKind::Method => "method",
315        IndexSymbolKind::Class => "class",
316        IndexSymbolKind::Struct => "struct",
317        IndexSymbolKind::Interface => "interface",
318        IndexSymbolKind::Trait => "trait",
319        IndexSymbolKind::Enum => "enum",
320        IndexSymbolKind::Constant => "constant",
321        IndexSymbolKind::Variable => "variable",
322        IndexSymbolKind::Module => "module",
323        IndexSymbolKind::Import => "import",
324        IndexSymbolKind::Export => "export",
325        IndexSymbolKind::TypeAlias => "type_alias",
326        IndexSymbolKind::Macro => "macro",
327    }
328    .to_owned()
329}
330
331fn format_visibility(vis: Visibility) -> String {
332    match vis {
333        Visibility::Public => "public",
334        Visibility::Private => "private",
335        Visibility::Protected => "protected",
336        Visibility::Internal => "internal",
337    }
338    .to_owned()
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::types::{FileEntry, FileId, Language, Span, SymbolId};

    /// Build a two-symbol index (`main` and `helper`, both in `test.py`) and a
    /// dependency graph containing the single call edge `main -> helper`.
    fn create_test_index() -> (SymbolIndex, DepGraph) {
        let mut index = SymbolIndex::default();

        // Add test file
        index.files.push(FileEntry {
            id: FileId::new(0),
            path: "test.py".to_string(),
            language: Language::Python,
            symbols: 0..2,
            imports: vec![],
            content_hash: [0u8; 32],
            lines: 25,
            tokens: 100,
        });

        // Add test symbols
        index.symbols.push(IndexSymbol {
            id: SymbolId::new(0),
            name: "main".to_string(),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 1, start_col: 0, end_line: 10, end_col: 0 },
            signature: Some("def main()".to_string()),
            parent: None,
            visibility: Visibility::Public,
            docstring: None,
        });

        index.symbols.push(IndexSymbol {
            id: SymbolId::new(1),
            name: "helper".to_string(),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 12, start_col: 0, end_line: 20, end_col: 0 },
            signature: Some("def helper()".to_string()),
            parent: None,
            visibility: Visibility::Private,
            docstring: None,
        });

        // Build name index
        index.symbols_by_name.insert("main".to_string(), vec![0]);
        index.symbols_by_name.insert("helper".to_string(), vec![1]);

        // Create dependency graph with call edge: main -> helper
        let mut graph = DepGraph::new();
        graph.add_call(0, 1); // main calls helper

        (index, graph)
    }

    #[test]
    fn test_find_symbol() {
        let (index, _graph) = create_test_index();

        let results = find_symbol(&index, "main");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "main");
        assert_eq!(results[0].kind, "function");
        assert_eq!(results[0].file, "test.py");
    }

    #[test]
    fn test_find_symbol_unknown_name() {
        let (index, _graph) = create_test_index();

        assert!(find_symbol(&index, "does_not_exist").is_empty());
    }

    #[test]
    fn test_symbol_info_fields() {
        let (index, _graph) = create_test_index();

        let results = find_symbol(&index, "helper");
        assert_eq!(results.len(), 1);
        let helper = &results[0];
        assert_eq!(helper.id, 1);
        assert_eq!(helper.line, 12);
        assert_eq!(helper.end_line, 20);
        assert_eq!(helper.signature.as_deref(), Some("def helper()"));
        assert_eq!(helper.visibility, "private");
    }

    #[test]
    fn test_get_callers() {
        let (index, graph) = create_test_index();

        // helper is called by main
        let callers = get_callers_by_name(&index, &graph, "helper");
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "main");
    }

    #[test]
    fn test_get_callees() {
        let (index, graph) = create_test_index();

        // main calls helper
        let callees = get_callees_by_name(&index, &graph, "main");
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "helper");
    }

    #[test]
    fn test_get_callers_and_callees_unknown_name() {
        let (index, graph) = create_test_index();

        assert!(get_callers_by_name(&index, &graph, "does_not_exist").is_empty());
        assert!(get_callees_by_name(&index, &graph, "does_not_exist").is_empty());
    }

    #[test]
    fn test_get_references() {
        let (index, graph) = create_test_index();

        // main calls helper, so it is reported once, as a call reference.
        let refs = get_references_by_name(&index, &graph, "helper");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].symbol.name, "main");
        assert_eq!(refs[0].kind, "call");
    }

    #[test]
    fn test_get_by_id() {
        let (index, graph) = create_test_index();

        let callers = get_callers_by_id(&index, &graph, 1);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "main");

        let callees = get_callees_by_id(&index, &graph, 0);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "helper");

        // Nothing calls main.
        assert!(get_callers_by_id(&index, &graph, 0).is_empty());
    }

    #[test]
    fn test_get_call_graph() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph(&index, &graph);
        assert_eq!(call_graph.nodes.len(), 2);
        assert_eq!(call_graph.edges.len(), 1);
        assert_eq!(call_graph.stats.functions, 2);

        // Check edge
        assert_eq!(call_graph.edges[0].caller, "main");
        assert_eq!(call_graph.edges[0].callee, "helper");
        // The edge is attributed to the caller's file and start line.
        assert_eq!(call_graph.edges[0].file, "test.py");
        assert_eq!(call_graph.edges[0].line, 1);
    }

    #[test]
    fn test_get_call_graph_node_limit_drops_dangling_edges() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph_filtered(&index, &graph, Some(1), None);
        assert_eq!(call_graph.nodes.len(), 1);
        assert_eq!(call_graph.nodes[0].name, "main");
        // `helper` was cut by the node limit, so main -> helper is dropped.
        assert!(call_graph.edges.is_empty());
        assert_eq!(call_graph.stats.total_symbols, 1);
        assert_eq!(call_graph.stats.total_calls, 0);
    }

    #[test]
    fn test_get_call_graph_edge_limit() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph_filtered(&index, &graph, None, Some(0));
        assert_eq!(call_graph.nodes.len(), 2);
        assert!(call_graph.edges.is_empty());
        assert_eq!(call_graph.stats.total_calls, 0);
    }
}
441}