infiniloom_engine/index/
query.rs

1//! Call graph query API
2//!
3//! High-level functions for querying call relationships between symbols.
4//! Used by both Python and Node.js bindings.
5
6use super::types::{DepGraph, IndexSymbol, IndexSymbolKind, SymbolIndex, Visibility};
7use serde::Serialize;
8
9/// Information about a symbol, returned from call graph queries
10#[derive(Debug, Clone, Serialize)]
11pub struct SymbolInfo {
12    /// Symbol ID
13    pub id: u32,
14    /// Symbol name
15    pub name: String,
16    /// Symbol kind (function, class, method, etc.)
17    pub kind: String,
18    /// File path containing the symbol
19    pub file: String,
20    /// Start line number
21    pub line: u32,
22    /// End line number
23    pub end_line: u32,
24    /// Function/method signature
25    pub signature: Option<String>,
26    /// Visibility (public, private, etc.)
27    pub visibility: String,
28}
29
30/// A reference location in the codebase
31#[derive(Debug, Clone, Serialize)]
32pub struct ReferenceInfo {
33    /// Symbol making the reference
34    pub symbol: SymbolInfo,
35    /// Reference kind (call, import, inherit, implement)
36    pub kind: String,
37}
38
39/// An edge in the call graph
40#[derive(Debug, Clone, Serialize)]
41pub struct CallGraphEdge {
42    /// Caller symbol ID
43    pub caller_id: u32,
44    /// Callee symbol ID
45    pub callee_id: u32,
46    /// Caller symbol name
47    pub caller: String,
48    /// Callee symbol name
49    pub callee: String,
50    /// File containing the call site
51    pub file: String,
52    /// Line number of the call
53    pub line: u32,
54}
55
56/// Complete call graph with nodes and edges
57#[derive(Debug, Clone, Serialize)]
58pub struct CallGraph {
59    /// All symbols (nodes)
60    pub nodes: Vec<SymbolInfo>,
61    /// Call relationships (edges)
62    pub edges: Vec<CallGraphEdge>,
63    /// Summary statistics
64    pub stats: CallGraphStats,
65}
66
67/// Call graph statistics
68#[derive(Debug, Clone, Serialize)]
69pub struct CallGraphStats {
70    /// Total number of symbols
71    pub total_symbols: usize,
72    /// Total number of call edges
73    pub total_calls: usize,
74    /// Number of functions/methods
75    pub functions: usize,
76    /// Number of classes/structs
77    pub classes: usize,
78}
79
80impl SymbolInfo {
81    /// Create SymbolInfo from an IndexSymbol
82    pub fn from_index_symbol(sym: &IndexSymbol, index: &SymbolIndex) -> Self {
83        let file_path = index
84            .get_file_by_id(sym.file_id.as_u32())
85            .map(|f| f.path.clone())
86            .unwrap_or_else(|| "<unknown>".to_owned());
87
88        Self {
89            id: sym.id.as_u32(),
90            name: sym.name.clone(),
91            kind: format_symbol_kind(sym.kind),
92            file: file_path,
93            line: sym.span.start_line,
94            end_line: sym.span.end_line,
95            signature: sym.signature.clone(),
96            visibility: format_visibility(sym.visibility),
97        }
98    }
99}
100
101/// Find a symbol by name and return its info
102pub fn find_symbol(index: &SymbolIndex, name: &str) -> Vec<SymbolInfo> {
103    index
104        .find_symbols(name)
105        .into_iter()
106        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
107        .collect()
108}
109
110/// Get all callers of a symbol by name
111///
112/// Returns symbols that call any symbol with the given name.
113pub fn get_callers_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
114    let mut callers = Vec::new();
115
116    // Find all symbols with this name
117    for sym in index.find_symbols(name) {
118        let symbol_id = sym.id.as_u32();
119
120        // Get callers from the dependency graph
121        for caller_id in graph.get_callers(symbol_id) {
122            if let Some(caller_sym) = index.get_symbol(caller_id) {
123                callers.push(SymbolInfo::from_index_symbol(caller_sym, index));
124            }
125        }
126    }
127
128    // Deduplicate by symbol ID
129    callers.sort_by_key(|s| s.id);
130    callers.dedup_by_key(|s| s.id);
131
132    callers
133}
134
135/// Get all callees of a symbol by name
136///
137/// Returns symbols that are called by any symbol with the given name.
138pub fn get_callees_by_name(index: &SymbolIndex, graph: &DepGraph, name: &str) -> Vec<SymbolInfo> {
139    let mut callees = Vec::new();
140
141    // Find all symbols with this name
142    for sym in index.find_symbols(name) {
143        let symbol_id = sym.id.as_u32();
144
145        // Get callees from the dependency graph
146        for callee_id in graph.get_callees(symbol_id) {
147            if let Some(callee_sym) = index.get_symbol(callee_id) {
148                callees.push(SymbolInfo::from_index_symbol(callee_sym, index));
149            }
150        }
151    }
152
153    // Deduplicate by symbol ID
154    callees.sort_by_key(|s| s.id);
155    callees.dedup_by_key(|s| s.id);
156
157    callees
158}
159
160/// Get all references to a symbol by name
161///
162/// Returns symbols that reference any symbol with the given name
163/// (includes calls, imports, inheritance, and implementations).
164pub fn get_references_by_name(
165    index: &SymbolIndex,
166    graph: &DepGraph,
167    name: &str,
168) -> Vec<ReferenceInfo> {
169    let mut references = Vec::new();
170
171    // Find all symbols with this name
172    for sym in index.find_symbols(name) {
173        let symbol_id = sym.id.as_u32();
174
175        // Get callers (call references)
176        for caller_id in graph.get_callers(symbol_id) {
177            if let Some(caller_sym) = index.get_symbol(caller_id) {
178                references.push(ReferenceInfo {
179                    symbol: SymbolInfo::from_index_symbol(caller_sym, index),
180                    kind: "call".to_owned(),
181                });
182            }
183        }
184
185        // Get referencers (symbol_ref - may include imports/inheritance)
186        for ref_id in graph.get_referencers(symbol_id) {
187            if let Some(ref_sym) = index.get_symbol(ref_id) {
188                // Avoid duplicates with callers
189                if !graph.get_callers(symbol_id).contains(&ref_id) {
190                    references.push(ReferenceInfo {
191                        symbol: SymbolInfo::from_index_symbol(ref_sym, index),
192                        kind: "reference".to_owned(),
193                    });
194                }
195            }
196        }
197    }
198
199    // Deduplicate by symbol ID
200    references.sort_by_key(|r| r.symbol.id);
201    references.dedup_by_key(|r| r.symbol.id);
202
203    references
204}
205
206/// Get the complete call graph
207///
208/// Returns all symbols (nodes) and call relationships (edges).
209/// For large codebases, consider using `get_call_graph_filtered` with limits.
210pub fn get_call_graph(index: &SymbolIndex, graph: &DepGraph) -> CallGraph {
211    get_call_graph_filtered(index, graph, None, None)
212}
213
214/// Get a filtered call graph
215///
216/// Args:
217///   - `max_nodes`: Optional limit on number of symbols returned
218///   - `max_edges`: Optional limit on number of edges returned
219pub fn get_call_graph_filtered(
220    index: &SymbolIndex,
221    graph: &DepGraph,
222    max_nodes: Option<usize>,
223    max_edges: Option<usize>,
224) -> CallGraph {
225    // Collect all nodes
226    let mut nodes: Vec<SymbolInfo> = index
227        .symbols
228        .iter()
229        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
230        .collect();
231
232    // Apply node limit if specified
233    if let Some(limit) = max_nodes {
234        nodes.truncate(limit);
235    }
236
237    // Collect node IDs for filtering edges
238    let node_ids: std::collections::HashSet<u32> = nodes.iter().map(|n| n.id).collect();
239
240    // Collect all edges
241    let mut edges: Vec<CallGraphEdge> = graph
242        .calls
243        .iter()
244        .filter(|(caller, callee)| {
245            // Only include edges where both nodes are in our set
246            max_nodes.is_none() || (node_ids.contains(caller) && node_ids.contains(callee))
247        })
248        .filter_map(|&(caller_id, callee_id)| {
249            let caller_sym = index.get_symbol(caller_id)?;
250            let callee_sym = index.get_symbol(callee_id)?;
251
252            let file_path = index
253                .get_file_by_id(caller_sym.file_id.as_u32())
254                .map(|f| f.path.clone())
255                .unwrap_or_else(|| "<unknown>".to_owned());
256
257            Some(CallGraphEdge {
258                caller_id,
259                callee_id,
260                caller: caller_sym.name.clone(),
261                callee: callee_sym.name.clone(),
262                file: file_path,
263                line: caller_sym.span.start_line,
264            })
265        })
266        .collect();
267
268    // Apply edge limit if specified
269    if let Some(limit) = max_edges {
270        edges.truncate(limit);
271    }
272
273    // Calculate statistics
274    let functions = nodes
275        .iter()
276        .filter(|n| n.kind == "function" || n.kind == "method")
277        .count();
278    let classes = nodes
279        .iter()
280        .filter(|n| n.kind == "class" || n.kind == "struct")
281        .count();
282
283    let stats =
284        CallGraphStats { total_symbols: nodes.len(), total_calls: edges.len(), functions, classes };
285
286    CallGraph { nodes, edges, stats }
287}
288
289/// Get callers of a symbol by its ID
290pub fn get_callers_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
291    graph
292        .get_callers(symbol_id)
293        .into_iter()
294        .filter_map(|id| index.get_symbol(id))
295        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
296        .collect()
297}
298
299/// Get callees of a symbol by its ID
300pub fn get_callees_by_id(index: &SymbolIndex, graph: &DepGraph, symbol_id: u32) -> Vec<SymbolInfo> {
301    graph
302        .get_callees(symbol_id)
303        .into_iter()
304        .filter_map(|id| index.get_symbol(id))
305        .map(|sym| SymbolInfo::from_index_symbol(sym, index))
306        .collect()
307}
308
309// Helper functions
310
311fn format_symbol_kind(kind: IndexSymbolKind) -> String {
312    match kind {
313        IndexSymbolKind::Function => "function",
314        IndexSymbolKind::Method => "method",
315        IndexSymbolKind::Class => "class",
316        IndexSymbolKind::Struct => "struct",
317        IndexSymbolKind::Interface => "interface",
318        IndexSymbolKind::Trait => "trait",
319        IndexSymbolKind::Enum => "enum",
320        IndexSymbolKind::Constant => "constant",
321        IndexSymbolKind::Variable => "variable",
322        IndexSymbolKind::Module => "module",
323        IndexSymbolKind::Import => "import",
324        IndexSymbolKind::Export => "export",
325        IndexSymbolKind::TypeAlias => "type_alias",
326        IndexSymbolKind::Macro => "macro",
327    }.to_owned()
328}
329
330fn format_visibility(vis: Visibility) -> String {
331    match vis {
332        Visibility::Public => "public",
333        Visibility::Private => "private",
334        Visibility::Protected => "protected",
335        Visibility::Internal => "internal",
336    }.to_owned()
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::types::{FileEntry, FileId, Language, Span, SymbolId};

    /// Builds a two-symbol index (`main` and `helper`, both in `test.py`) and
    /// a dependency graph with a single call edge `main -> helper`.
    fn create_test_index() -> (SymbolIndex, DepGraph) {
        let mut index = SymbolIndex::default();

        // Add test file
        index.files.push(FileEntry {
            id: FileId::new(0),
            path: "test.py".to_string(),
            language: Language::Python,
            symbols: 0..2,
            imports: vec![],
            content_hash: [0u8; 32],
            lines: 25,
            tokens: 100,
        });

        // Add test symbols
        index.symbols.push(IndexSymbol {
            id: SymbolId::new(0),
            name: "main".to_string(),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 1, start_col: 0, end_line: 10, end_col: 0 },
            signature: Some("def main()".to_string()),
            parent: None,
            visibility: Visibility::Public,
            docstring: None,
        });

        index.symbols.push(IndexSymbol {
            id: SymbolId::new(1),
            name: "helper".to_string(),
            kind: IndexSymbolKind::Function,
            file_id: FileId::new(0),
            span: Span { start_line: 12, start_col: 0, end_line: 20, end_col: 0 },
            signature: Some("def helper()".to_string()),
            parent: None,
            visibility: Visibility::Private,
            docstring: None,
        });

        // Build name index
        index.symbols_by_name.insert("main".to_string(), vec![0]);
        index.symbols_by_name.insert("helper".to_string(), vec![1]);

        // Create dependency graph with call edge: main -> helper
        let mut graph = DepGraph::new();
        graph.add_call(0, 1); // main calls helper

        (index, graph)
    }

    #[test]
    fn test_find_symbol() {
        let (index, _graph) = create_test_index();

        let results = find_symbol(&index, "main");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "main");
        assert_eq!(results[0].kind, "function");
        assert_eq!(results[0].file, "test.py");
    }

    #[test]
    fn test_find_symbol_reports_span_signature_and_visibility() {
        let (index, _graph) = create_test_index();

        let results = find_symbol(&index, "helper");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
        assert_eq!(results[0].line, 12);
        assert_eq!(results[0].end_line, 20);
        assert_eq!(results[0].signature.as_deref(), Some("def helper()"));
        assert_eq!(results[0].visibility, "private");
    }

    #[test]
    fn test_unknown_name_yields_nothing() {
        let (index, graph) = create_test_index();

        assert!(find_symbol(&index, "missing").is_empty());
        assert!(get_callers_by_name(&index, &graph, "missing").is_empty());
        assert!(get_callees_by_name(&index, &graph, "missing").is_empty());
        assert!(get_references_by_name(&index, &graph, "missing").is_empty());
    }

    #[test]
    fn test_get_callers() {
        let (index, graph) = create_test_index();

        // helper is called by main
        let callers = get_callers_by_name(&index, &graph, "helper");
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "main");
    }

    #[test]
    fn test_get_callees() {
        let (index, graph) = create_test_index();

        // main calls helper
        let callees = get_callees_by_name(&index, &graph, "main");
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "helper");
    }

    #[test]
    fn test_by_id_queries() {
        let (index, graph) = create_test_index();

        let callers = get_callers_by_id(&index, &graph, 1);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "main");

        let callees = get_callees_by_id(&index, &graph, 0);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "helper");

        // main has no callers; helper calls nothing
        assert!(get_callers_by_id(&index, &graph, 0).is_empty());
        assert!(get_callees_by_id(&index, &graph, 1).is_empty());
    }

    #[test]
    fn test_get_references_reports_call_kind() {
        let (index, graph) = create_test_index();

        let refs = get_references_by_name(&index, &graph, "helper");
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].symbol.name, "main");
        assert_eq!(refs[0].kind, "call");
    }

    #[test]
    fn test_get_call_graph() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph(&index, &graph);
        assert_eq!(call_graph.nodes.len(), 2);
        assert_eq!(call_graph.edges.len(), 1);
        assert_eq!(call_graph.stats.functions, 2);

        // Check edge
        assert_eq!(call_graph.edges[0].caller, "main");
        assert_eq!(call_graph.edges[0].callee, "helper");
    }

    #[test]
    fn test_call_graph_edge_location_fields() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph(&index, &graph);
        let edge = &call_graph.edges[0];
        assert_eq!(edge.caller_id, 0);
        assert_eq!(edge.callee_id, 1);
        // Edge location is the caller's file and definition start line.
        assert_eq!(edge.file, "test.py");
        assert_eq!(edge.line, 1);
        assert_eq!(call_graph.stats.classes, 0);
    }

    #[test]
    fn test_call_graph_node_limit_drops_dangling_edges() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph_filtered(&index, &graph, Some(1), None);
        assert_eq!(call_graph.nodes.len(), 1);
        assert_eq!(call_graph.nodes[0].name, "main");
        // helper is outside the node set, so main -> helper is dropped.
        assert!(call_graph.edges.is_empty());
        assert_eq!(call_graph.stats.total_symbols, 1);
        assert_eq!(call_graph.stats.total_calls, 0);
        assert_eq!(call_graph.stats.functions, 1);
    }

    #[test]
    fn test_call_graph_edge_limit() {
        let (index, graph) = create_test_index();

        let call_graph = get_call_graph_filtered(&index, &graph, None, Some(0));
        assert_eq!(call_graph.nodes.len(), 2);
        assert!(call_graph.edges.is_empty());
        assert_eq!(call_graph.stats.total_calls, 0);

        let unlimited = get_call_graph_filtered(&index, &graph, Some(10), Some(10));
        assert_eq!(unlimited.nodes.len(), 2);
        assert_eq!(unlimited.edges.len(), 1);
    }
}