Skip to main content

lean_ctx/core/
call_graph.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, OnceLock};
5
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
9use super::deep_queries;
10use super::graph_index::{normalize_project_root, ProjectIndex, SymbolEntry};
11
// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------

/// Call graph of one project: every detected call site plus a content hash of each
/// analyzed file (the hashes let incremental builds skip files whose content is unchanged).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallGraph {
    /// Project root as produced by `normalize_project_root` (see `CallGraph::new`).
    pub project_root: String,
    /// One entry per call site found in the analyzed files.
    pub edges: Vec<CallEdge>,
    /// File path (as keyed in the project index) -> `simple_hash` of its content at build time.
    pub file_hashes: HashMap<String, String>,
}

/// A single call site: `caller_symbol`, in `caller_file` at `caller_line`, calls `callee_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallEdge {
    /// File path as keyed in the project index (normally relative to the project root).
    pub caller_file: String,
    /// Name of the narrowest indexed symbol enclosing the call, or `"<module>"` if none.
    pub caller_symbol: String,
    /// Line of the call site: the analyzer's `line` plus one (presumably making it 1-based,
    /// like the symbol ranges it is matched against).
    pub caller_line: usize,
    /// Callee name exactly as reported by the analyzer; it is not resolved to a definition.
    pub callee_name: String,
}

/// One node reached by `CallGraph::bfs_callers` / `CallGraph::bfs_callees`.
#[derive(Debug, Clone)]
pub struct BfsNode {
    /// Symbol reached at this hop.
    pub symbol: String,
    /// File of the call site (the edge's `caller_file`) through which this node was reached.
    pub file: String,
    /// Line of that call site.
    pub line: usize,
    /// Number of hops from the start symbol; direct neighbours have depth 1.
    pub depth: usize,
    /// Symbol this node was reached from.
    pub from_symbol: String,
}

/// One step of a path returned by `CallGraph::find_call_path`.
///
/// The first hop is the start symbol itself and carries an empty `file` and line 0; every
/// later hop records the call site through which it was reached.
#[derive(Debug, Clone)]
pub struct PathHop {
    pub symbol: String,
    pub file: String,
    pub line: usize,
}

/// Which end of each edge `CallGraph::bfs_traverse` follows.
#[derive(Clone, Copy)]
enum BfsDirection {
    /// Walk from a callee to the symbols that call it.
    Callers,
    /// Walk from a caller to the names it calls.
    Callees,
}
52
/// Coarse impact rating derived from how many callers a symbol has.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}

impl RiskLevel {
    /// Maps a caller count onto a risk bucket.
    ///
    /// Buckets: `0..=1` is low, `2..=4` medium, `5..=10` high, and anything above 10 critical.
    pub fn from_caller_count(count: usize) -> Self {
        if count <= 1 {
            RiskLevel::Low
        } else if count <= 4 {
            RiskLevel::Medium
        } else if count <= 10 {
            RiskLevel::High
        } else {
            RiskLevel::Critical
        }
    }

    /// Upper-case display label for this level (e.g. `"HIGH"`).
    pub fn label(self) -> &'static str {
        match self {
            RiskLevel::Critical => "CRITICAL",
            RiskLevel::High => "HIGH",
            RiskLevel::Medium => "MEDIUM",
            RiskLevel::Low => "LOW",
        }
    }
}
80
// ---------------------------------------------------------------------------
// Background build state (singleton per process)
// ---------------------------------------------------------------------------

/// Snapshot of the background build, returned to dashboard callers.
#[derive(Debug, Clone, Serialize)]
pub struct BuildProgress {
    /// One of `"idle"`, `"building"`, `"ready"` or `"error"` (see `CallGraph::build_status`).
    pub status: &'static str,
    /// Number of files being (or already) processed; reported as 0 when idle or failed.
    pub files_total: usize,
    /// Files finished so far.
    pub files_done: usize,
    /// Call edges found so far.
    pub edges_found: usize,
}

/// State machine behind `CallGraph::get_or_start_build`, guarded by the `BUILD` mutex.
enum BuildState {
    /// Nothing built yet, or the result was invalidated.
    Idle,
    /// A background thread is building; the counters are `Arc`s shared with that thread.
    Building {
        files_total: usize,
        files_done: Arc<AtomicUsize>,
        edges_found: Arc<AtomicUsize>,
    },
    /// A finished graph; callers receive clones of the `Arc`.
    Ready(Arc<CallGraph>),
    /// The last build panicked; holds the message recorded by the build thread.
    /// The next `get_or_start_build` retries the build.
    Failed(String),
}

/// Process-wide build state, created lazily by `global_state`.
static BUILD: OnceLock<Mutex<BuildState>> = OnceLock::new();

/// Returns the process-wide build state, initializing it to `Idle` on first use.
fn global_state() -> &'static Mutex<BuildState> {
    BUILD.get_or_init(|| Mutex::new(BuildState::Idle))
}
109
110impl CallGraph {
111    pub fn new(project_root: &str) -> Self {
112        Self {
113            project_root: normalize_project_root(project_root),
114            edges: Vec::new(),
115            file_hashes: HashMap::new(),
116        }
117    }
118
119    // -----------------------------------------------------------------------
120    // Parallel build — processes files via rayon thread pool
121    // -----------------------------------------------------------------------
122
123    pub fn build_parallel(
124        index: &ProjectIndex,
125        progress: Option<(&AtomicUsize, &AtomicUsize)>,
126    ) -> Self {
127        let project_root = &index.project_root;
128        let symbols_by_file = group_symbols_by_file_owned(index);
129        let file_keys: Vec<String> = index.files.keys().cloned().collect();
130
131        let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
132            .par_iter()
133            .filter_map(|rel_path| {
134                let abs_path = resolve_path(rel_path, project_root);
135                let content = std::fs::read_to_string(&abs_path).ok()?;
136                let hash = simple_hash(&content);
137
138                let ext = Path::new(rel_path)
139                    .extension()
140                    .and_then(|e| e.to_str())
141                    .unwrap_or("");
142
143                let analysis = deep_queries::analyze(&content, ext);
144                let file_symbols = symbols_by_file.get(rel_path.as_str());
145
146                let edges: Vec<CallEdge> = analysis
147                    .calls
148                    .iter()
149                    .map(|call| {
150                        let caller_sym = find_enclosing_symbol_owned(file_symbols, call.line + 1);
151                        CallEdge {
152                            caller_file: rel_path.clone(),
153                            caller_symbol: caller_sym,
154                            caller_line: call.line + 1,
155                            callee_name: call.callee.clone(),
156                        }
157                    })
158                    .collect();
159
160                if let Some((done, edge_count)) = progress {
161                    done.fetch_add(1, Ordering::Relaxed);
162                    edge_count.fetch_add(edges.len(), Ordering::Relaxed);
163                }
164
165                Some((rel_path.clone(), hash, edges))
166            })
167            .collect();
168
169        let mut graph = Self::new(project_root);
170        let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
171        graph.edges.reserve(edge_capacity);
172        graph.file_hashes.reserve(results.len());
173
174        for (path, hash, edges) in results {
175            graph.file_hashes.insert(path, hash);
176            graph.edges.extend(edges);
177        }
178
179        graph
180    }
181
182    // -----------------------------------------------------------------------
183    // Incremental parallel build — only re-analyzes changed files
184    // -----------------------------------------------------------------------
185
186    pub fn build_incremental_parallel(
187        index: &ProjectIndex,
188        previous: &CallGraph,
189        progress: Option<(&AtomicUsize, &AtomicUsize)>,
190    ) -> Self {
191        let project_root = &index.project_root;
192        let symbols_by_file = group_symbols_by_file_owned(index);
193        let file_keys: Vec<String> = index.files.keys().cloned().collect();
194
195        let prev_edges_by_file = group_edges_by_file(&previous.edges);
196
197        let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
198            .par_iter()
199            .filter_map(|rel_path| {
200                let abs_path = resolve_path(rel_path, project_root);
201                let content = std::fs::read_to_string(&abs_path).ok()?;
202                let hash = simple_hash(&content);
203                let changed = previous.file_hashes.get(rel_path.as_str()) != Some(&hash);
204
205                let edges = if changed {
206                    let ext = Path::new(rel_path)
207                        .extension()
208                        .and_then(|e| e.to_str())
209                        .unwrap_or("");
210
211                    let analysis = deep_queries::analyze(&content, ext);
212                    let file_symbols = symbols_by_file.get(rel_path.as_str());
213
214                    analysis
215                        .calls
216                        .iter()
217                        .map(|call| {
218                            let caller_sym =
219                                find_enclosing_symbol_owned(file_symbols, call.line + 1);
220                            CallEdge {
221                                caller_file: rel_path.clone(),
222                                caller_symbol: caller_sym,
223                                caller_line: call.line + 1,
224                                callee_name: call.callee.clone(),
225                            }
226                        })
227                        .collect()
228                } else {
229                    prev_edges_by_file
230                        .get(rel_path.as_str())
231                        .cloned()
232                        .unwrap_or_default()
233                };
234
235                if let Some((done, edge_count)) = progress {
236                    done.fetch_add(1, Ordering::Relaxed);
237                    edge_count.fetch_add(edges.len(), Ordering::Relaxed);
238                }
239
240                Some((rel_path.clone(), hash, edges))
241            })
242            .collect();
243
244        let mut graph = Self::new(project_root);
245        let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
246        graph.edges.reserve(edge_capacity);
247        graph.file_hashes.reserve(results.len());
248
249        for (path, hash, edges) in results {
250            graph.file_hashes.insert(path, hash);
251            graph.edges.extend(edges);
252        }
253
254        graph
255    }
256
257    // -----------------------------------------------------------------------
258    // Public API: non-blocking access for the dashboard
259    // -----------------------------------------------------------------------
260
261    /// Returns the cached graph immediately, or `None` + starts a background build.
262    pub fn get_or_start_build(
263        project_root: &str,
264        index: Arc<ProjectIndex>,
265    ) -> Result<Arc<CallGraph>, BuildProgress> {
266        let state = global_state();
267        let mut guard = state
268            .lock()
269            .unwrap_or_else(std::sync::PoisonError::into_inner);
270
271        match &*guard {
272            BuildState::Ready(graph) => return Ok(Arc::clone(graph)),
273            BuildState::Building {
274                files_total,
275                files_done,
276                edges_found,
277            } => {
278                return Err(BuildProgress {
279                    status: "building",
280                    files_total: *files_total,
281                    files_done: files_done.load(Ordering::Relaxed),
282                    edges_found: edges_found.load(Ordering::Relaxed),
283                });
284            }
285            BuildState::Failed(msg) => {
286                tracing::warn!("[call_graph: previous build failed: {msg} — retrying]");
287            }
288            BuildState::Idle => {}
289        }
290
291        // Try serving from disk cache first
292        if let Some(cached) = Self::load(project_root) {
293            if !cache_looks_stale(&cached, &index) {
294                let arc = Arc::new(cached);
295                *guard = BuildState::Ready(Arc::clone(&arc));
296                return Ok(arc);
297            }
298        }
299
300        let files_total = index.files.len();
301        let files_done = Arc::new(AtomicUsize::new(0));
302        let edges_found = Arc::new(AtomicUsize::new(0));
303
304        *guard = BuildState::Building {
305            files_total,
306            files_done: Arc::clone(&files_done),
307            edges_found: Arc::clone(&edges_found),
308        };
309        drop(guard);
310
311        let root = normalize_project_root(project_root);
312        let fd = Arc::clone(&files_done);
313        let ef = Arc::clone(&edges_found);
314
315        std::thread::spawn(move || {
316            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
317                let previous = CallGraph::load(&root);
318                if let Some(prev) = &previous {
319                    CallGraph::build_incremental_parallel(&index, prev, Some((&fd, &ef)))
320                } else {
321                    CallGraph::build_parallel(&index, Some((&fd, &ef)))
322                }
323            }));
324
325            match result {
326                Ok(graph) => {
327                    let _ = graph.save();
328                    let arc = Arc::new(graph);
329                    if let Ok(mut g) = global_state().lock() {
330                        *g = BuildState::Ready(Arc::clone(&arc));
331                    }
332                    tracing::info!(
333                        "[call_graph: build complete — {} files, {} edges]",
334                        arc.file_hashes.len(),
335                        arc.edges.len()
336                    );
337                }
338                Err(e) => {
339                    let msg = format!("{e:?}");
340                    tracing::error!("[call_graph: build panicked: {msg}]");
341                    if let Ok(mut g) = global_state().lock() {
342                        *g = BuildState::Failed(msg);
343                    }
344                }
345            }
346        });
347
348        Err(BuildProgress {
349            status: "building",
350            files_total,
351            files_done: 0,
352            edges_found: 0,
353        })
354    }
355
356    /// Returns current build status without starting anything.
357    pub fn build_status() -> BuildProgress {
358        let state = global_state();
359        let guard = state
360            .lock()
361            .unwrap_or_else(std::sync::PoisonError::into_inner);
362        match &*guard {
363            BuildState::Idle => BuildProgress {
364                status: "idle",
365                files_total: 0,
366                files_done: 0,
367                edges_found: 0,
368            },
369            BuildState::Building {
370                files_total,
371                files_done,
372                edges_found,
373            } => BuildProgress {
374                status: "building",
375                files_total: *files_total,
376                files_done: files_done.load(Ordering::Relaxed),
377                edges_found: edges_found.load(Ordering::Relaxed),
378            },
379            BuildState::Ready(graph) => BuildProgress {
380                status: "ready",
381                files_total: graph.file_hashes.len(),
382                files_done: graph.file_hashes.len(),
383                edges_found: graph.edges.len(),
384            },
385            BuildState::Failed(msg) => {
386                tracing::debug!("[call_graph: status check — failed: {msg}]");
387                BuildProgress {
388                    status: "error",
389                    files_total: 0,
390                    files_done: 0,
391                    edges_found: 0,
392                }
393            }
394        }
395    }
396
397    /// Force-invalidate the cached result so next request triggers a rebuild.
398    pub fn invalidate() {
399        if let Ok(mut g) = global_state().lock() {
400            *g = BuildState::Idle;
401        }
402    }
403
404    // -----------------------------------------------------------------------
405    // Legacy synchronous API (kept for non-dashboard callers)
406    // -----------------------------------------------------------------------
407
408    pub fn build(index: &ProjectIndex) -> Self {
409        Self::build_parallel(index, None)
410    }
411
412    pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
413        Self::build_incremental_parallel(index, previous, None)
414    }
415
416    pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
417        let sym_lower = symbol.to_lowercase();
418        self.edges
419            .iter()
420            .filter(|e| e.callee_name.to_lowercase() == sym_lower)
421            .collect()
422    }
423
424    pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
425        let sym_lower = symbol.to_lowercase();
426        self.edges
427            .iter()
428            .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
429            .collect()
430    }
431
432    // -----------------------------------------------------------------------
433    // Multi-hop BFS traversal
434    // -----------------------------------------------------------------------
435
436    /// BFS callers up to `max_depth` hops. Returns (symbol, file, line, depth) per node.
437    pub fn bfs_callers(&self, symbol: &str, max_depth: usize) -> Vec<BfsNode> {
438        self.bfs_traverse(symbol, max_depth, BfsDirection::Callers)
439    }
440
441    /// BFS callees up to `max_depth` hops. Returns (symbol, file, line, depth) per node.
442    pub fn bfs_callees(&self, symbol: &str, max_depth: usize) -> Vec<BfsNode> {
443        self.bfs_traverse(symbol, max_depth, BfsDirection::Callees)
444    }
445
446    fn bfs_traverse(&self, symbol: &str, max_depth: usize, dir: BfsDirection) -> Vec<BfsNode> {
447        use std::collections::{HashSet, VecDeque};
448
449        let mut visited: HashSet<String> = HashSet::new();
450        let mut queue: VecDeque<(String, usize)> = VecDeque::new();
451        let mut result: Vec<BfsNode> = Vec::new();
452
453        let start = symbol.to_lowercase();
454        visited.insert(start.clone());
455        queue.push_back((start, 0));
456
457        while let Some((current, depth)) = queue.pop_front() {
458            if depth >= max_depth {
459                continue;
460            }
461
462            let neighbors: Vec<&CallEdge> = match dir {
463                BfsDirection::Callers => self
464                    .edges
465                    .iter()
466                    .filter(|e| e.callee_name.to_lowercase() == current)
467                    .collect(),
468                BfsDirection::Callees => self
469                    .edges
470                    .iter()
471                    .filter(|e| e.caller_symbol.to_lowercase() == current)
472                    .collect(),
473            };
474
475            for edge in neighbors {
476                let next_sym = match dir {
477                    BfsDirection::Callers => &edge.caller_symbol,
478                    BfsDirection::Callees => &edge.callee_name,
479                };
480                let next_lower = next_sym.to_lowercase();
481
482                if !visited.insert(next_lower.clone()) {
483                    continue;
484                }
485
486                result.push(BfsNode {
487                    symbol: next_sym.clone(),
488                    file: edge.caller_file.clone(),
489                    line: edge.caller_line,
490                    depth: depth + 1,
491                    from_symbol: if depth == 0 {
492                        symbol.to_string()
493                    } else {
494                        current.clone()
495                    },
496                });
497
498                queue.push_back((next_lower, depth + 1));
499            }
500        }
501
502        result
503    }
504
505    /// Find shortest call path from `from` to `to` using BFS.
506    /// Returns None if no path exists (searched up to depth 10).
507    /// Find shortest call path from `from` to `to` using BFS.
508    /// Returns None if no path exists (searched up to depth 10).
509    pub fn find_call_path(&self, from: &str, to: &str) -> Option<Vec<PathHop>> {
510        use std::collections::{HashMap as BfsMap, VecDeque};
511
512        let from_lower = from.to_lowercase();
513        let to_lower = to.to_lowercase();
514
515        if from_lower == to_lower {
516            return Some(vec![PathHop {
517                symbol: from.to_string(),
518                file: String::new(),
519                line: 0,
520            }]);
521        }
522
523        const MAX_TRACE_DEPTH: usize = 10;
524
525        // (parent_symbol, file, line, depth)
526        let mut visited: BfsMap<String, (String, String, usize, usize)> = BfsMap::new();
527        let mut queue: VecDeque<String> = VecDeque::new();
528
529        visited.insert(from_lower.clone(), (String::new(), String::new(), 0, 0));
530        queue.push_back(from_lower.clone());
531
532        while let Some(current) = queue.pop_front() {
533            let current_depth = visited.get(&current).map_or(0, |e| e.3);
534            if current_depth >= MAX_TRACE_DEPTH {
535                continue;
536            }
537
538            let callees: Vec<&CallEdge> = self
539                .edges
540                .iter()
541                .filter(|e| e.caller_symbol.to_lowercase() == current)
542                .collect();
543
544            for edge in callees {
545                let next = edge.callee_name.to_lowercase();
546                if visited.contains_key(&next) {
547                    continue;
548                }
549
550                visited.insert(
551                    next.clone(),
552                    (
553                        current.clone(),
554                        edge.caller_file.clone(),
555                        edge.caller_line,
556                        current_depth + 1,
557                    ),
558                );
559
560                if next == to_lower {
561                    return Some(Self::reconstruct_path(
562                        &visited,
563                        &from_lower,
564                        &to_lower,
565                        from,
566                        to,
567                    ));
568                }
569
570                queue.push_back(next);
571            }
572        }
573
574        None
575    }
576
577    fn reconstruct_path(
578        visited: &std::collections::HashMap<String, (String, String, usize, usize)>,
579        from_lower: &str,
580        to_lower: &str,
581        from_orig: &str,
582        to_orig: &str,
583    ) -> Vec<PathHop> {
584        let mut path = Vec::new();
585        let mut current = to_lower.to_string();
586
587        while current != from_lower {
588            let (parent, file, line, _depth) = &visited[&current];
589            let sym_name = if current == to_lower {
590                to_orig.to_string()
591            } else {
592                current.clone()
593            };
594            path.push(PathHop {
595                symbol: sym_name,
596                file: file.clone(),
597                line: *line,
598            });
599            current = parent.clone();
600        }
601
602        path.push(PathHop {
603            symbol: from_orig.to_string(),
604            file: String::new(),
605            line: 0,
606        });
607
608        path.reverse();
609        path
610    }
611
612    /// Count unique transitive callers up to `max_depth`.
613    pub fn transitive_caller_count(&self, symbol: &str, max_depth: usize) -> usize {
614        let nodes = self.bfs_callers(symbol, max_depth);
615        let mut unique: std::collections::HashSet<String> = std::collections::HashSet::new();
616        for node in &nodes {
617            unique.insert(node.symbol.to_lowercase());
618        }
619        unique.len()
620    }
621
622    pub fn save(&self) -> Result<(), String> {
623        let dir = call_graph_dir(&self.project_root)
624            .ok_or_else(|| "Cannot determine home directory".to_string())?;
625        std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
626        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
627        let compressed = zstd::encode_all(json.as_bytes(), 9).map_err(|e| format!("zstd: {e}"))?;
628        let target = dir.join("call_graph.json.zst");
629        let tmp = target.with_extension("zst.tmp");
630        std::fs::write(&tmp, &compressed).map_err(|e| e.to_string())?;
631        std::fs::rename(&tmp, &target).map_err(|e| e.to_string())?;
632        let _ = std::fs::remove_file(dir.join("call_graph.json"));
633        Ok(())
634    }
635
636    pub fn load(project_root: &str) -> Option<Self> {
637        let dir = call_graph_dir(project_root)?;
638
639        let zst_path = dir.join("call_graph.json.zst");
640        if zst_path.exists() {
641            let compressed = std::fs::read(&zst_path).ok()?;
642            let data = zstd::decode_all(compressed.as_slice()).ok()?;
643            let content = String::from_utf8(data).ok()?;
644            return serde_json::from_str(&content).ok();
645        }
646
647        let json_path = dir.join("call_graph.json");
648        if json_path.exists() {
649            let content = std::fs::read_to_string(&json_path).ok()?;
650            let parsed: Self = serde_json::from_str(&content).ok()?;
651            // Auto-migrate: compress legacy JSON to zstd
652            if let Ok(compressed) = zstd::encode_all(content.as_bytes(), 9) {
653                let zst_tmp = zst_path.with_extension("zst.tmp");
654                if std::fs::write(&zst_tmp, &compressed).is_ok()
655                    && std::fs::rename(&zst_tmp, &zst_path).is_ok()
656                {
657                    let _ = std::fs::remove_file(&json_path);
658                }
659            }
660            return Some(parsed);
661        }
662
663        None
664    }
665
666    pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
667        if let Some(previous) = Self::load(project_root) {
668            Self::build_incremental(index, &previous)
669        } else {
670            Self::build(index)
671        }
672    }
673}
674
675// ---------------------------------------------------------------------------
676// Cache staleness check (fast — mtime-based, no content reads)
677// ---------------------------------------------------------------------------
678
679fn cache_looks_stale(cached: &CallGraph, index: &ProjectIndex) -> bool {
680    if cached.file_hashes.len() != index.files.len() {
681        return true;
682    }
683    let cached_files: std::collections::HashSet<&str> =
684        cached.file_hashes.keys().map(String::as_str).collect();
685    let index_files: std::collections::HashSet<&str> =
686        index.files.keys().map(String::as_str).collect();
687    cached_files != index_files
688}
689
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Directory in which `CallGraph::save` / `CallGraph::load` keep `call_graph.json.zst`.
///
/// Delegates to `ProjectIndex::index_dir`, so the call graph shares the project index's
/// directory; a `None` from there is propagated (`save` reports it as
/// "Cannot determine home directory").
fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
    ProjectIndex::index_dir(project_root)
}
697
698fn group_edges_by_file(edges: &[CallEdge]) -> HashMap<&str, Vec<CallEdge>> {
699    let mut map: HashMap<&str, Vec<CallEdge>> = HashMap::new();
700    for edge in edges {
701        map.entry(edge.caller_file.as_str())
702            .or_default()
703            .push(edge.clone());
704    }
705    map
706}
707
708/// Owned version for safe `Send` across rayon threads.
709fn group_symbols_by_file_owned(index: &ProjectIndex) -> HashMap<String, Vec<SymbolEntry>> {
710    let mut map: HashMap<String, Vec<SymbolEntry>> = HashMap::new();
711    for sym in index.symbols.values() {
712        map.entry(sym.file.clone()).or_default().push(sym.clone());
713    }
714    for syms in map.values_mut() {
715        syms.sort_by_key(|s| s.start_line);
716    }
717    map
718}
719
720fn find_enclosing_symbol_owned(file_symbols: Option<&Vec<SymbolEntry>>, line: usize) -> String {
721    let Some(syms) = file_symbols else {
722        return "<module>".to_string();
723    };
724    let mut best: Option<&SymbolEntry> = None;
725    for sym in syms {
726        if line >= sym.start_line && line <= sym.end_line {
727            match best {
728                None => best = Some(sym),
729                Some(prev) => {
730                    if (sym.end_line - sym.start_line) < (prev.end_line - prev.start_line) {
731                        best = Some(sym);
732                    }
733                }
734            }
735        }
736    }
737    best.map_or_else(|| "<module>".to_string(), |s| s.name.clone())
738}
739
/// Turns an index path into a path that can be read from disk.
///
/// An absolute path that already exists is returned unchanged. Anything else has its leading
/// `/` and `\` characters stripped and is joined onto `project_root`.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let as_path = Path::new(relative);
    if !(as_path.is_absolute() && as_path.exists()) {
        let without_root = relative.trim_start_matches(&['/', '\\'][..]);
        return Path::new(project_root)
            .join(without_root)
            .to_string_lossy()
            .into_owned();
    }
    relative.to_owned()
}
749
/// Content fingerprint stored in `CallGraph::file_hashes` to detect changed files.
///
/// Uses 64-bit FNV-1a instead of `DefaultHasher`. These hashes are persisted in the on-disk
/// cache and compared by later builds, but `DefaultHasher`'s algorithm is documented as
/// unspecified across Rust releases. With it, a toolchain upgrade could silently mark every
/// file as changed. FNV-1a is fixed and deterministic.
///
/// Hashes written by older versions won't match, so caches re-analyze once after upgrading.
/// Not cryptographic: it only needs to notice ordinary edits.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{hash:x}")
}
756
757#[cfg(test)]
758mod tests {
759    use super::*;
760
761    #[test]
762    fn callers_of_empty_graph() {
763        let graph = CallGraph::new("/tmp");
764        assert!(graph.callers_of("foo").is_empty());
765    }
766
767    #[test]
768    fn callers_of_finds_edges() {
769        let mut graph = CallGraph::new("/tmp");
770        graph.edges.push(CallEdge {
771            caller_file: "a.rs".to_string(),
772            caller_symbol: "bar".to_string(),
773            caller_line: 10,
774            callee_name: "foo".to_string(),
775        });
776        graph.edges.push(CallEdge {
777            caller_file: "b.rs".to_string(),
778            caller_symbol: "baz".to_string(),
779            caller_line: 20,
780            callee_name: "foo".to_string(),
781        });
782        graph.edges.push(CallEdge {
783            caller_file: "c.rs".to_string(),
784            caller_symbol: "qux".to_string(),
785            caller_line: 30,
786            callee_name: "other".to_string(),
787        });
788        let callers = graph.callers_of("foo");
789        assert_eq!(callers.len(), 2);
790    }
791
792    #[test]
793    fn callees_of_finds_edges() {
794        let mut graph = CallGraph::new("/tmp");
795        graph.edges.push(CallEdge {
796            caller_file: "a.rs".to_string(),
797            caller_symbol: "main".to_string(),
798            caller_line: 5,
799            callee_name: "init".to_string(),
800        });
801        graph.edges.push(CallEdge {
802            caller_file: "a.rs".to_string(),
803            caller_symbol: "main".to_string(),
804            caller_line: 6,
805            callee_name: "run".to_string(),
806        });
807        graph.edges.push(CallEdge {
808            caller_file: "a.rs".to_string(),
809            caller_symbol: "other".to_string(),
810            caller_line: 15,
811            callee_name: "init".to_string(),
812        });
813        let callees = graph.callees_of("main");
814        assert_eq!(callees.len(), 2);
815    }
816
817    #[test]
818    fn find_enclosing_picks_narrowest() {
819        let outer = SymbolEntry {
820            file: "a.rs".to_string(),
821            name: "Outer".to_string(),
822            kind: "struct".to_string(),
823            start_line: 1,
824            end_line: 50,
825            is_exported: true,
826        };
827        let inner = SymbolEntry {
828            file: "a.rs".to_string(),
829            name: "inner_fn".to_string(),
830            kind: "fn".to_string(),
831            start_line: 10,
832            end_line: 20,
833            is_exported: false,
834        };
835        let syms = vec![outer, inner];
836        let result = find_enclosing_symbol_owned(Some(&syms), 15);
837        assert_eq!(result, "inner_fn");
838    }
839
840    #[test]
841    fn find_enclosing_returns_module_when_no_match() {
842        let sym = SymbolEntry {
843            file: "a.rs".to_string(),
844            name: "foo".to_string(),
845            kind: "fn".to_string(),
846            start_line: 10,
847            end_line: 20,
848            is_exported: false,
849        };
850        let syms = vec![sym];
851        let result = find_enclosing_symbol_owned(Some(&syms), 5);
852        assert_eq!(result, "<module>");
853    }
854
855    #[test]
856    fn resolve_path_trims_rooted_relative_prefix() {
857        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
858        assert_eq!(
859            resolved,
860            Path::new(r"C:\repo")
861                .join(r"src\main\kotlin\Example.kt")
862                .to_string_lossy()
863                .to_string()
864        );
865    }
866
867    fn build_chain_graph() -> CallGraph {
868        // A -> B -> C -> D
869        let mut graph = CallGraph::new("/tmp");
870        graph.edges.push(CallEdge {
871            caller_file: "a.rs".into(),
872            caller_symbol: "fn_a".into(),
873            caller_line: 1,
874            callee_name: "fn_b".into(),
875        });
876        graph.edges.push(CallEdge {
877            caller_file: "b.rs".into(),
878            caller_symbol: "fn_b".into(),
879            caller_line: 10,
880            callee_name: "fn_c".into(),
881        });
882        graph.edges.push(CallEdge {
883            caller_file: "c.rs".into(),
884            caller_symbol: "fn_c".into(),
885            caller_line: 20,
886            callee_name: "fn_d".into(),
887        });
888        graph
889    }
890
891    #[test]
892    fn bfs_callees_depth_1_returns_direct() {
893        let graph = build_chain_graph();
894        let nodes = graph.bfs_callees("fn_a", 1);
895        assert_eq!(nodes.len(), 1);
896        assert_eq!(nodes[0].symbol, "fn_b");
897        assert_eq!(nodes[0].depth, 1);
898    }
899
900    #[test]
901    fn bfs_callees_depth_3_returns_chain() {
902        let graph = build_chain_graph();
903        let nodes = graph.bfs_callees("fn_a", 3);
904        assert_eq!(nodes.len(), 3);
905        let syms: Vec<&str> = nodes.iter().map(|n| n.symbol.as_str()).collect();
906        assert!(syms.contains(&"fn_b"));
907        assert!(syms.contains(&"fn_c"));
908        assert!(syms.contains(&"fn_d"));
909    }
910
911    #[test]
912    fn bfs_callers_depth_2_returns_transitive() {
913        let graph = build_chain_graph();
914        let nodes = graph.bfs_callers("fn_c", 2);
915        assert_eq!(nodes.len(), 2);
916        let syms: Vec<&str> = nodes.iter().map(|n| n.symbol.as_str()).collect();
917        assert!(syms.contains(&"fn_b"));
918        assert!(syms.contains(&"fn_a"));
919    }
920
921    #[test]
922    fn find_call_path_direct() {
923        let graph = build_chain_graph();
924        let path = graph.find_call_path("fn_a", "fn_b");
925        assert!(path.is_some());
926        let hops = path.unwrap();
927        assert_eq!(hops.len(), 2);
928        assert_eq!(hops[0].symbol, "fn_a");
929        assert_eq!(hops[1].symbol, "fn_b");
930    }
931
932    #[test]
933    fn find_call_path_multi_hop() {
934        let graph = build_chain_graph();
935        let path = graph.find_call_path("fn_a", "fn_d");
936        assert!(path.is_some());
937        let hops = path.unwrap();
938        assert_eq!(hops.len(), 4);
939        assert_eq!(hops[0].symbol, "fn_a");
940        assert_eq!(hops[3].symbol, "fn_d");
941    }
942
943    #[test]
944    fn find_call_path_no_connection() {
945        let graph = build_chain_graph();
946        let path = graph.find_call_path("fn_d", "fn_a");
947        assert!(path.is_none());
948    }
949
950    #[test]
951    fn find_call_path_same_symbol() {
952        let graph = build_chain_graph();
953        let path = graph.find_call_path("fn_a", "fn_a");
954        assert!(path.is_some());
955        assert_eq!(path.unwrap().len(), 1);
956    }
957
958    #[test]
959    fn transitive_caller_count_returns_unique() {
960        let mut graph = CallGraph::new("/tmp");
961        // x -> target, y -> target, z -> x (so z is transitive caller of target)
962        graph.edges.push(CallEdge {
963            caller_file: "x.rs".into(),
964            caller_symbol: "x".into(),
965            caller_line: 1,
966            callee_name: "target".into(),
967        });
968        graph.edges.push(CallEdge {
969            caller_file: "y.rs".into(),
970            caller_symbol: "y".into(),
971            caller_line: 2,
972            callee_name: "target".into(),
973        });
974        graph.edges.push(CallEdge {
975            caller_file: "z.rs".into(),
976            caller_symbol: "z".into(),
977            caller_line: 3,
978            callee_name: "x".into(),
979        });
980        assert_eq!(graph.transitive_caller_count("target", 5), 3);
981    }
982
983    #[test]
984    fn risk_level_classification() {
985        assert_eq!(RiskLevel::from_caller_count(0), RiskLevel::Low);
986        assert_eq!(RiskLevel::from_caller_count(1), RiskLevel::Low);
987        assert_eq!(RiskLevel::from_caller_count(3), RiskLevel::Medium);
988        assert_eq!(RiskLevel::from_caller_count(7), RiskLevel::High);
989        assert_eq!(RiskLevel::from_caller_count(15), RiskLevel::Critical);
990    }
991
992    #[test]
993    fn bfs_handles_cycle_without_infinite_loop() {
994        let mut graph = CallGraph::new("/tmp");
995        graph.edges.push(CallEdge {
996            caller_file: "a.rs".into(),
997            caller_symbol: "a".into(),
998            caller_line: 1,
999            callee_name: "b".into(),
1000        });
1001        graph.edges.push(CallEdge {
1002            caller_file: "b.rs".into(),
1003            caller_symbol: "b".into(),
1004            caller_line: 2,
1005            callee_name: "a".into(),
1006        });
1007        let nodes = graph.bfs_callees("a", 5);
1008        // Should visit b once (depth 1), then a is already visited
1009        assert_eq!(nodes.len(), 1);
1010        assert_eq!(nodes[0].symbol, "b");
1011    }
1012}