greppy/trace/
traverse.rs

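//! Graph Traversal Algorithms
//!
//! Provides algorithms for navigating the semantic index:
//! - BFS backward traversal from a target symbol to its entry points
//! - Reference finding (optionally filtered by reference kind)
//! - Path reconstruction
//! - Dead-symbol detection
//! - Call-chain and invocation-path formatting
//!
//! @module trace/traverse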
9
10use std::collections::{HashSet, VecDeque};
11
12use super::index::SemanticIndex;
13use super::types::{RefKind, Reference, Symbol};
14
15// =============================================================================
16// INVOCATION PATH
17// =============================================================================
18
/// A single invocation path from an entry point to a target symbol.
///
/// The path is stored entry-point-first: `chain[0]` is the entry point and
/// the last element of `chain` is the traced target.
#[derive(Debug, Clone)]
pub struct InvocationPath {
    /// The entry point symbol (starting point of the call chain)
    pub entry_point: u32,
    /// Chain of symbol IDs from entry point to target (inclusive).
    /// First element is `entry_point`, last is the target.
    pub chain: Vec<u32>,
    /// Line number of each call in the chain: `call_lines[i]` belongs to the
    /// call from `chain[i]` to `chain[i + 1]`, so `len = chain.len() - 1`.
    pub call_lines: Vec<u32>,
}

impl InvocationPath {
    /// Create a new invocation path.
    ///
    /// The arguments are stored as given; no consistency check between
    /// `entry_point`, `chain` and `call_lines` is performed.
    pub fn new(entry_point: u32, chain: Vec<u32>, call_lines: Vec<u32>) -> Self {
        Self {
            entry_point,
            chain,
            call_lines,
        }
    }

    /// Get the target symbol (last in chain).
    ///
    /// Falls back to `entry_point` when the chain is empty.
    pub fn target(&self) -> u32 {
        self.chain.last().copied().unwrap_or(self.entry_point)
    }

    /// Get the depth of the call chain, i.e. the number of calls between the
    /// entry point and the target (`chain.len() - 1`).
    ///
    /// * `0` - the chain holds a single symbol (the target is itself the
    ///   entry point) or is empty
    /// * `1` - the entry point calls the target directly
    /// * `n` - there are `n - 1` intermediate symbols
    pub fn depth(&self) -> usize {
        self.chain.len().saturating_sub(1)
    }

    /// Check if this is a direct call (entry point directly calls target).
    ///
    /// Returns `true` whenever the chain has at most two symbols, so it is
    /// also `true` for the degenerate chains of length 0 and 1 in which no
    /// call takes place at all.
    pub fn is_direct(&self) -> bool {
        self.chain.len() <= 2
    }
}
56
57// =============================================================================
58// TRACE RESULT
59// =============================================================================
60
61/// Result of a trace operation
62#[derive(Debug, Clone)]
63pub struct TraceResult {
64    /// Target symbol ID that was traced
65    pub target: u32,
66    /// All invocation paths to the target
67    pub paths: Vec<InvocationPath>,
68    /// Symbols that were visited during traversal
69    pub visited_count: usize,
70}
71
72impl TraceResult {
73    /// Check if any paths were found
74    pub fn has_paths(&self) -> bool {
75        !self.paths.is_empty()
76    }
77
78    /// Get the number of unique entry points
79    pub fn entry_point_count(&self) -> usize {
80        let entries: HashSet<_> = self.paths.iter().map(|p| p.entry_point).collect();
81        entries.len()
82    }
83
84    /// Get the shortest path depth
85    pub fn min_depth(&self) -> Option<usize> {
86        self.paths.iter().map(|p| p.depth()).min()
87    }
88
89    /// Get the longest path depth
90    pub fn max_depth(&self) -> Option<usize> {
91        self.paths.iter().map(|p| p.depth()).max()
92    }
93}
94
95// =============================================================================
96// TRACE SYMBOL
97// =============================================================================
98
99/// Trace all invocation paths from entry points to a target symbol
100///
101/// Uses BFS backward traversal from the target to find all entry points
102/// that can reach it, then reconstructs the paths.
103///
104/// # Arguments
105/// * `index` - The semantic index to search
106/// * `target_id` - The symbol ID to trace to
107/// * `max_depth` - Maximum call chain depth (default: 50)
108///
109/// # Returns
110/// A `TraceResult` containing all invocation paths
111pub fn trace_symbol(
112    index: &SemanticIndex,
113    target_id: u32,
114    max_depth: Option<usize>,
115) -> TraceResult {
116    let max_depth = max_depth.unwrap_or(50);
117
118    // Validate target exists
119    if index.symbol(target_id).is_none() {
120        return TraceResult {
121            target: target_id,
122            paths: Vec::new(),
123            visited_count: 0,
124        };
125    }
126
127    // BFS backward traversal
128    let mut visited: HashSet<u32> = HashSet::new();
129    let mut queue: VecDeque<(u32, Vec<u32>, Vec<u32>)> = VecDeque::new();
130    let mut paths: Vec<InvocationPath> = Vec::new();
131
132    // Start from target
133    queue.push_back((target_id, vec![target_id], Vec::new()));
134    visited.insert(target_id);
135
136    while let Some((current, chain, call_lines)) = queue.pop_front() {
137        // Check depth limit
138        if chain.len() > max_depth {
139            continue;
140        }
141
142        // Get symbol info
143        let symbol = match index.symbol(current) {
144            Some(s) => s,
145            None => continue,
146        };
147
148        // Check if this is an entry point
149        if symbol.is_entry_point() {
150            // Found a complete path - reverse it so entry point is first
151            let mut path_chain = chain.clone();
152            let mut path_lines = call_lines.clone();
153            path_chain.reverse();
154            path_lines.reverse();
155
156            paths.push(InvocationPath::new(current, path_chain, path_lines));
157        }
158
159        // Find all callers (symbols that call this one)
160        for &caller_id in index.callers(current) {
161            if visited.contains(&caller_id) {
162                // Skip already visited to avoid cycles
163                // But we might want to still record the path if it reaches an entry point
164                continue;
165            }
166
167            // Find the edge to get the call line
168            let call_line = index
169                .edges
170                .iter()
171                .find(|e| e.from_symbol == caller_id && e.to_symbol == current)
172                .map(|e| e.line)
173                .unwrap_or(0);
174
175            // Extend the chain
176            let mut new_chain = chain.clone();
177            new_chain.push(caller_id);
178
179            let mut new_lines = call_lines.clone();
180            new_lines.push(call_line);
181
182            visited.insert(caller_id);
183            queue.push_back((caller_id, new_chain, new_lines));
184        }
185    }
186
187    TraceResult {
188        target: target_id,
189        paths,
190        visited_count: visited.len(),
191    }
192}
193
194/// Trace a symbol by name
195///
196/// Finds all symbols matching the name and traces each one.
197pub fn trace_symbol_by_name(
198    index: &SemanticIndex,
199    name: &str,
200    max_depth: Option<usize>,
201) -> Vec<TraceResult> {
202    let symbol_ids = match index.symbols_by_name(name) {
203        Some(ids) => ids.clone(),
204        None => return Vec::new(),
205    };
206
207    symbol_ids
208        .iter()
209        .map(|&id| trace_symbol(index, id, max_depth))
210        .collect()
211}
212
213// =============================================================================
214// FIND REFERENCES
215// =============================================================================
216
217/// Reference with full context
218#[derive(Debug, Clone)]
219pub struct ReferenceContext {
220    /// The reference
221    pub reference: Reference,
222    /// Token ID
223    pub token_id: u32,
224    /// File ID where the reference occurs
225    pub file_id: u16,
226    /// Line number
227    pub line: u32,
228    /// Column number
229    pub column: u16,
230    /// Scope ID
231    pub scope_id: u32,
232    /// Symbol name (if available)
233    pub symbol_name: Option<String>,
234}
235
236/// Find all references to a symbol
237pub fn find_refs(index: &SemanticIndex, symbol_id: u32) -> Vec<ReferenceContext> {
238    let mut results = Vec::new();
239
240    for reference in index.references_to(symbol_id) {
241        if let Some(token) = index.token(reference.token_id) {
242            let symbol_name = index
243                .symbol(symbol_id)
244                .and_then(|s| index.symbol_name(s))
245                .map(|s| s.to_string());
246
247            results.push(ReferenceContext {
248                reference: *reference,
249                token_id: reference.token_id,
250                file_id: token.file_id,
251                line: token.line,
252                column: token.column,
253                scope_id: token.scope_id,
254                symbol_name,
255            });
256        }
257    }
258
259    results
260}
261
262/// Find references of a specific kind
263pub fn find_refs_of_kind(
264    index: &SemanticIndex,
265    symbol_id: u32,
266    kind: RefKind,
267) -> Vec<ReferenceContext> {
268    find_refs(index, symbol_id)
269        .into_iter()
270        .filter(|r| r.reference.ref_kind() == kind)
271        .collect()
272}
273
274/// Find call references to a symbol
275pub fn find_call_refs(index: &SemanticIndex, symbol_id: u32) -> Vec<ReferenceContext> {
276    find_refs_of_kind(index, symbol_id, RefKind::Call)
277}
278
279/// Find read references to a symbol
280pub fn find_read_refs(index: &SemanticIndex, symbol_id: u32) -> Vec<ReferenceContext> {
281    find_refs_of_kind(index, symbol_id, RefKind::Read)
282}
283
284/// Find write references to a symbol
285pub fn find_write_refs(index: &SemanticIndex, symbol_id: u32) -> Vec<ReferenceContext> {
286    find_refs_of_kind(index, symbol_id, RefKind::Write)
287}
288
289// =============================================================================
290// DEAD CODE DETECTION
291// =============================================================================
292
293/// Find potentially dead symbols (no incoming references or calls)
294///
295/// Returns symbols that:
296/// - Are not entry points
297/// - Have no incoming edges (no one calls them)
298/// - Have no references
299pub fn find_dead_symbols(index: &SemanticIndex) -> Vec<&Symbol> {
300    index
301        .symbols
302        .iter()
303        .filter(|s| {
304            // Skip entry points - they're supposed to be "uncalled"
305            if s.is_entry_point() {
306                return false;
307            }
308
309            // Check for incoming edges
310            let has_callers = !index.callers(s.id).is_empty();
311            if has_callers {
312                return false;
313            }
314
315            // Check for references
316            let has_refs = index.references_to(s.id).next().is_some();
317            !has_refs
318        })
319        .collect()
320}
321
322// =============================================================================
323// CALL CHAIN HELPERS
324// =============================================================================
325
326/// Format a call chain as a string
327pub fn format_call_chain(index: &SemanticIndex, chain: &[u32]) -> String {
328    chain
329        .iter()
330        .filter_map(|&id| index.symbol(id).and_then(|s| index.symbol_name(s)))
331        .collect::<Vec<_>>()
332        .join(" -> ")
333}
334
335/// Format an invocation path with file locations
336pub fn format_invocation_path(index: &SemanticIndex, path: &InvocationPath) -> String {
337    let mut result = String::new();
338
339    for (i, &symbol_id) in path.chain.iter().enumerate() {
340        if let Some(symbol) = index.symbol(symbol_id) {
341            let name = index.symbol_name(symbol).unwrap_or("<unknown>");
342            let file = index
343                .file_path(symbol.file_id)
344                .map(|p| p.to_string_lossy().to_string())
345                .unwrap_or_else(|| "<unknown>".into());
346            let start_line = symbol.start_line;
347
348            if i == 0 {
349                result.push_str(&format!("{} ({}:{})", name, file, start_line));
350            } else {
351                let call_line = path.call_lines.get(i - 1).copied().unwrap_or(0);
352                result.push_str(&format!(
353                    "\n  -> {} ({}:{}) [called at line {}]",
354                    name, file, start_line, call_line
355                ));
356            }
357        }
358    }
359
360    result
361}
362
363// =============================================================================
364// TESTS
365// =============================================================================
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::types::{Edge, SymbolFlags, SymbolKind};

    /// Build a fixture with four functions forming the linear call chain
    /// `main -> a -> b -> target`; only `main` is flagged as an entry point.
    fn create_test_index() -> SemanticIndex {
        let function_names = ["main", "a", "b", "target"];

        let mut index = SemanticIndex::new();
        let file_id = index.add_file("test.rs".into());

        // Intern all names first, then register the symbols.
        let mut name_ids = Vec::with_capacity(function_names.len());
        for name in function_names.iter() {
            name_ids.push(index.strings.intern(name));
        }

        // Symbol `id` spans lines `id * 10 + 1` to `(id + 1) * 10`;
        // id 0 (`main`) is the only entry point.
        for (id, name) in function_names.iter().enumerate() {
            let flags = if id == 0 {
                SymbolFlags::IS_ENTRY_POINT
            } else {
                SymbolFlags::empty()
            };

            index.add_symbol(
                Symbol::new(
                    id as u32,
                    name_ids[id],
                    file_id,
                    SymbolKind::Function,
                    flags,
                    (id * 10 + 1) as u32,
                    ((id + 1) * 10) as u32,
                ),
                name,
            );
        }

        // Call edges (caller, callee, line of the call)
        index.add_edge(Edge::new(0, 1, 5)); // main -> a
        index.add_edge(Edge::new(1, 2, 15)); // a -> b
        index.add_edge(Edge::new(2, 3, 25)); // b -> target

        index
    }

    #[test]
    fn test_trace_symbol() {
        let index = create_test_index();

        // id 3 is "target", the far end of the chain
        let traced = trace_symbol(&index, 3, None);

        assert!(traced.has_paths());
        assert_eq!(traced.paths.len(), 1);

        let only_path = &traced.paths[0];
        assert_eq!(only_path.entry_point, 0); // main
        assert_eq!(only_path.chain, vec![0, 1, 2, 3]); // main -> a -> b -> target
        assert_eq!(only_path.depth(), 3);
    }

    #[test]
    fn test_trace_direct_call() {
        let index = create_test_index();

        // id 1 is "a", which main calls directly
        let traced = trace_symbol(&index, 1, None);

        assert!(traced.has_paths());
        assert_eq!(traced.paths.len(), 1);

        let only_path = &traced.paths[0];
        assert_eq!(only_path.chain, vec![0, 1]); // main -> a
        assert!(only_path.is_direct());
    }

    #[test]
    fn test_trace_nonexistent() {
        let index = create_test_index();

        // No symbol with this id exists in the fixture
        let traced = trace_symbol(&index, 999, None);

        assert!(!traced.has_paths());
        assert_eq!(traced.visited_count, 0);
    }

    #[test]
    fn test_format_call_chain() {
        let index = create_test_index();
        let ids = vec![0, 1, 2, 3];

        let rendered = format_call_chain(&index, &ids);
        assert_eq!(rendered, "main -> a -> b -> target");
    }

    #[test]
    fn test_find_dead_symbols() {
        let mut index = SemanticIndex::new();
        let file_id = index.add_file("test.rs".into());

        // (name, flags, start line, end line); the position is the symbol id.
        // "main" is the entry point, "used" gets a caller below, "dead" never does.
        let specs = vec![
            ("main", SymbolFlags::IS_ENTRY_POINT, 1, 10),
            ("used", SymbolFlags::empty(), 15, 25),
            ("dead", SymbolFlags::empty(), 30, 40),
        ];

        for (id, (name, flags, start, end)) in specs.into_iter().enumerate() {
            let name_id = index.strings.intern(name);
            index.add_symbol(
                Symbol::new(
                    id as u32,
                    name_id,
                    file_id,
                    SymbolKind::Function,
                    flags,
                    start,
                    end,
                ),
                name,
            );
        }

        // main calls used
        index.add_edge(Edge::new(0, 1, 5));

        let dead = find_dead_symbols(&index);
        assert_eq!(dead.len(), 1);
        assert_eq!(dead[0].id, 2);
    }
}