1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, OnceLock};
5
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
9use super::deep_queries;
10use super::graph_index::{normalize_project_root, ProjectIndex, SymbolEntry};
11
/// A name-based call graph for one project, persisted (see [`CallGraph::save`]) in the
/// same directory as the project index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallGraph {
    /// Project root, normalized with `normalize_project_root` by [`CallGraph::new`].
    pub project_root: String,
    /// Every call site found in the project's readable indexed files.
    pub edges: Vec<CallEdge>,
    /// `simple_hash` of each analyzed file's content, keyed by the file's index path.
    /// Incremental builds reuse a file's previous edges when this hash is unchanged.
    pub file_hashes: HashMap<String, String>,
}
22
/// One call site: `caller_symbol` in `caller_file` calls something named `callee_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallEdge {
    /// Path of the file containing the call, as keyed in the project index.
    pub caller_file: String,
    /// Narrowest indexed symbol whose line range encloses the call, or `"<module>"`.
    pub caller_symbol: String,
    /// Line of the call: the analyzer's line plus one (presumably 1-based — confirm
    /// against `deep_queries`).
    pub caller_line: usize,
    /// Callee name as reported by the analyzer. Queries match it by name only
    /// (case-insensitively); it is not resolved to a specific definition.
    pub callee_name: String,
}
30
/// A symbol reached by [`CallGraph::bfs_callers`] or [`CallGraph::bfs_callees`].
#[derive(Debug, Clone)]
pub struct BfsNode {
    /// The reached symbol, spelled as in the edge that reached it.
    pub symbol: String,
    /// File of the call site connecting `from_symbol` and `symbol`.
    pub file: String,
    /// Line of that call site.
    pub line: usize,
    /// Number of hops from the start symbol (direct neighbours have depth 1).
    pub depth: usize,
    /// Symbol this node was reached from: the query exactly as given for depth 1,
    /// otherwise the lowercased name of the previous node.
    pub from_symbol: String,
}
39
/// One step of the chain returned by [`CallGraph::find_call_path`].
#[derive(Debug, Clone)]
pub struct PathHop {
    /// Symbol at this step.
    pub symbol: String,
    /// File of the call site that reached this symbol; empty for the first hop.
    pub file: String,
    /// Line of that call site; `0` for the first hop.
    pub line: usize,
}
46
/// Which way a [`CallGraph`] traversal follows edges.
#[derive(Clone, Copy)]
enum BfsDirection {
    /// From a callee name to the symbols that call it.
    Callers,
    /// From a caller symbol to the names it calls.
    Callees,
}
52
/// Coarse risk bucket derived from a caller count (see [`RiskLevel::from_caller_count`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskLevel {
    /// 0–1 callers.
    Low,
    /// 2–4 callers.
    Medium,
    /// 5–10 callers.
    High,
    /// More than 10 callers.
    Critical,
}
60
61impl RiskLevel {
62 pub fn from_caller_count(count: usize) -> Self {
63 match count {
64 0..=1 => Self::Low,
65 2..=4 => Self::Medium,
66 5..=10 => Self::High,
67 _ => Self::Critical,
68 }
69 }
70
71 pub fn label(self) -> &'static str {
72 match self {
73 Self::Low => "LOW",
74 Self::Medium => "MEDIUM",
75 Self::High => "HIGH",
76 Self::Critical => "CRITICAL",
77 }
78 }
79}
80
/// Snapshot of the background build, returned by [`CallGraph::get_or_start_build`]
/// (as the `Err` value) and by [`CallGraph::build_status`].
#[derive(Debug, Clone, Serialize)]
pub struct BuildProgress {
    /// One of `"idle"`, `"building"`, `"ready"`, or `"error"`.
    pub status: &'static str,
    /// Files the build covers; for `"ready"`, the number of files with a recorded hash.
    pub files_total: usize,
    /// Files processed so far.
    pub files_done: usize,
    /// Edges collected so far.
    pub edges_found: usize,
}
92
/// Process-wide state of the background call-graph build (held in `BUILD`).
enum BuildState {
    /// No build has started yet, or the state was reset by [`CallGraph::invalidate`].
    Idle,
    /// A background build thread is running.
    Building {
        /// Number of indexed files the build will process.
        files_total: usize,
        /// Files processed so far; the same counter is updated by the build thread.
        files_done: Arc<AtomicUsize>,
        /// Edges collected so far; the same counter is updated by the build thread.
        edges_found: Arc<AtomicUsize>,
    },
    /// A finished or cache-loaded graph.
    Ready(Arc<CallGraph>),
    /// The last build panicked; holds a description of the panic.
    Failed(String),
}
103
/// Lazily initialized build state shared by every caller in the process.
static BUILD: OnceLock<Mutex<BuildState>> = OnceLock::new();

/// Returns the shared build state, initializing it to [`BuildState::Idle`] on first use.
fn global_state() -> &'static Mutex<BuildState> {
    BUILD.get_or_init(|| Mutex::new(BuildState::Idle))
}
109
110impl CallGraph {
111 pub fn new(project_root: &str) -> Self {
112 Self {
113 project_root: normalize_project_root(project_root),
114 edges: Vec::new(),
115 file_hashes: HashMap::new(),
116 }
117 }
118
119 pub fn build_parallel(
124 index: &ProjectIndex,
125 progress: Option<(&AtomicUsize, &AtomicUsize)>,
126 ) -> Self {
127 let project_root = &index.project_root;
128 let symbols_by_file = group_symbols_by_file_owned(index);
129 let file_keys: Vec<String> = index.files.keys().cloned().collect();
130
131 let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
132 .par_iter()
133 .filter_map(|rel_path| {
134 let abs_path = resolve_path(rel_path, project_root);
135 let content = std::fs::read_to_string(&abs_path).ok()?;
136 let hash = simple_hash(&content);
137
138 let ext = Path::new(rel_path)
139 .extension()
140 .and_then(|e| e.to_str())
141 .unwrap_or("");
142
143 let analysis = deep_queries::analyze(&content, ext);
144 let file_symbols = symbols_by_file.get(rel_path.as_str());
145
146 let edges: Vec<CallEdge> = analysis
147 .calls
148 .iter()
149 .map(|call| {
150 let caller_sym = find_enclosing_symbol_owned(file_symbols, call.line + 1);
151 CallEdge {
152 caller_file: rel_path.clone(),
153 caller_symbol: caller_sym,
154 caller_line: call.line + 1,
155 callee_name: call.callee.clone(),
156 }
157 })
158 .collect();
159
160 if let Some((done, edge_count)) = progress {
161 done.fetch_add(1, Ordering::Relaxed);
162 edge_count.fetch_add(edges.len(), Ordering::Relaxed);
163 }
164
165 Some((rel_path.clone(), hash, edges))
166 })
167 .collect();
168
169 let mut graph = Self::new(project_root);
170 let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
171 graph.edges.reserve(edge_capacity);
172 graph.file_hashes.reserve(results.len());
173
174 for (path, hash, edges) in results {
175 graph.file_hashes.insert(path, hash);
176 graph.edges.extend(edges);
177 }
178
179 graph
180 }
181
182 pub fn build_incremental_parallel(
187 index: &ProjectIndex,
188 previous: &CallGraph,
189 progress: Option<(&AtomicUsize, &AtomicUsize)>,
190 ) -> Self {
191 let project_root = &index.project_root;
192 let symbols_by_file = group_symbols_by_file_owned(index);
193 let file_keys: Vec<String> = index.files.keys().cloned().collect();
194
195 let prev_edges_by_file = group_edges_by_file(&previous.edges);
196
197 let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
198 .par_iter()
199 .filter_map(|rel_path| {
200 let abs_path = resolve_path(rel_path, project_root);
201 let content = std::fs::read_to_string(&abs_path).ok()?;
202 let hash = simple_hash(&content);
203 let changed = previous.file_hashes.get(rel_path.as_str()) != Some(&hash);
204
205 let edges = if changed {
206 let ext = Path::new(rel_path)
207 .extension()
208 .and_then(|e| e.to_str())
209 .unwrap_or("");
210
211 let analysis = deep_queries::analyze(&content, ext);
212 let file_symbols = symbols_by_file.get(rel_path.as_str());
213
214 analysis
215 .calls
216 .iter()
217 .map(|call| {
218 let caller_sym =
219 find_enclosing_symbol_owned(file_symbols, call.line + 1);
220 CallEdge {
221 caller_file: rel_path.clone(),
222 caller_symbol: caller_sym,
223 caller_line: call.line + 1,
224 callee_name: call.callee.clone(),
225 }
226 })
227 .collect()
228 } else {
229 prev_edges_by_file
230 .get(rel_path.as_str())
231 .cloned()
232 .unwrap_or_default()
233 };
234
235 if let Some((done, edge_count)) = progress {
236 done.fetch_add(1, Ordering::Relaxed);
237 edge_count.fetch_add(edges.len(), Ordering::Relaxed);
238 }
239
240 Some((rel_path.clone(), hash, edges))
241 })
242 .collect();
243
244 let mut graph = Self::new(project_root);
245 let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
246 graph.edges.reserve(edge_capacity);
247 graph.file_hashes.reserve(results.len());
248
249 for (path, hash, edges) in results {
250 graph.file_hashes.insert(path, hash);
251 graph.edges.extend(edges);
252 }
253
254 graph
255 }
256
257 pub fn get_or_start_build(
263 project_root: &str,
264 index: Arc<ProjectIndex>,
265 ) -> Result<Arc<CallGraph>, BuildProgress> {
266 let state = global_state();
267 let mut guard = state
268 .lock()
269 .unwrap_or_else(std::sync::PoisonError::into_inner);
270
271 match &*guard {
272 BuildState::Ready(graph) => return Ok(Arc::clone(graph)),
273 BuildState::Building {
274 files_total,
275 files_done,
276 edges_found,
277 } => {
278 return Err(BuildProgress {
279 status: "building",
280 files_total: *files_total,
281 files_done: files_done.load(Ordering::Relaxed),
282 edges_found: edges_found.load(Ordering::Relaxed),
283 });
284 }
285 BuildState::Failed(msg) => {
286 tracing::warn!("[call_graph: previous build failed: {msg} — retrying]");
287 }
288 BuildState::Idle => {}
289 }
290
291 if let Some(cached) = Self::load(project_root) {
293 if !cache_looks_stale(&cached, &index) {
294 let arc = Arc::new(cached);
295 *guard = BuildState::Ready(Arc::clone(&arc));
296 return Ok(arc);
297 }
298 }
299
300 let files_total = index.files.len();
301 let files_done = Arc::new(AtomicUsize::new(0));
302 let edges_found = Arc::new(AtomicUsize::new(0));
303
304 *guard = BuildState::Building {
305 files_total,
306 files_done: Arc::clone(&files_done),
307 edges_found: Arc::clone(&edges_found),
308 };
309 drop(guard);
310
311 let root = normalize_project_root(project_root);
312 let fd = Arc::clone(&files_done);
313 let ef = Arc::clone(&edges_found);
314
315 std::thread::spawn(move || {
316 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
317 let previous = CallGraph::load(&root);
318 if let Some(prev) = &previous {
319 CallGraph::build_incremental_parallel(&index, prev, Some((&fd, &ef)))
320 } else {
321 CallGraph::build_parallel(&index, Some((&fd, &ef)))
322 }
323 }));
324
325 match result {
326 Ok(graph) => {
327 let _ = graph.save();
328 let arc = Arc::new(graph);
329 if let Ok(mut g) = global_state().lock() {
330 *g = BuildState::Ready(Arc::clone(&arc));
331 }
332 tracing::info!(
333 "[call_graph: build complete — {} files, {} edges]",
334 arc.file_hashes.len(),
335 arc.edges.len()
336 );
337 }
338 Err(e) => {
339 let msg = format!("{e:?}");
340 tracing::error!("[call_graph: build panicked: {msg}]");
341 if let Ok(mut g) = global_state().lock() {
342 *g = BuildState::Failed(msg);
343 }
344 }
345 }
346 });
347
348 Err(BuildProgress {
349 status: "building",
350 files_total,
351 files_done: 0,
352 edges_found: 0,
353 })
354 }
355
356 pub fn build_status() -> BuildProgress {
358 let state = global_state();
359 let guard = state
360 .lock()
361 .unwrap_or_else(std::sync::PoisonError::into_inner);
362 match &*guard {
363 BuildState::Idle => BuildProgress {
364 status: "idle",
365 files_total: 0,
366 files_done: 0,
367 edges_found: 0,
368 },
369 BuildState::Building {
370 files_total,
371 files_done,
372 edges_found,
373 } => BuildProgress {
374 status: "building",
375 files_total: *files_total,
376 files_done: files_done.load(Ordering::Relaxed),
377 edges_found: edges_found.load(Ordering::Relaxed),
378 },
379 BuildState::Ready(graph) => BuildProgress {
380 status: "ready",
381 files_total: graph.file_hashes.len(),
382 files_done: graph.file_hashes.len(),
383 edges_found: graph.edges.len(),
384 },
385 BuildState::Failed(msg) => {
386 tracing::debug!("[call_graph: status check — failed: {msg}]");
387 BuildProgress {
388 status: "error",
389 files_total: 0,
390 files_done: 0,
391 edges_found: 0,
392 }
393 }
394 }
395 }
396
397 pub fn invalidate() {
399 if let Ok(mut g) = global_state().lock() {
400 *g = BuildState::Idle;
401 }
402 }
403
404 pub fn build(index: &ProjectIndex) -> Self {
409 Self::build_parallel(index, None)
410 }
411
412 pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
413 Self::build_incremental_parallel(index, previous, None)
414 }
415
416 pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
417 let sym_lower = symbol.to_lowercase();
418 self.edges
419 .iter()
420 .filter(|e| e.callee_name.to_lowercase() == sym_lower)
421 .collect()
422 }
423
424 pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
425 let sym_lower = symbol.to_lowercase();
426 self.edges
427 .iter()
428 .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
429 .collect()
430 }
431
432 pub fn bfs_callers(&self, symbol: &str, max_depth: usize) -> Vec<BfsNode> {
438 self.bfs_traverse(symbol, max_depth, BfsDirection::Callers)
439 }
440
441 pub fn bfs_callees(&self, symbol: &str, max_depth: usize) -> Vec<BfsNode> {
443 self.bfs_traverse(symbol, max_depth, BfsDirection::Callees)
444 }
445
446 fn bfs_traverse(&self, symbol: &str, max_depth: usize, dir: BfsDirection) -> Vec<BfsNode> {
447 use std::collections::{HashSet, VecDeque};
448
449 let mut visited: HashSet<String> = HashSet::new();
450 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
451 let mut result: Vec<BfsNode> = Vec::new();
452
453 let start = symbol.to_lowercase();
454 visited.insert(start.clone());
455 queue.push_back((start, 0));
456
457 while let Some((current, depth)) = queue.pop_front() {
458 if depth >= max_depth {
459 continue;
460 }
461
462 let neighbors: Vec<&CallEdge> = match dir {
463 BfsDirection::Callers => self
464 .edges
465 .iter()
466 .filter(|e| e.callee_name.to_lowercase() == current)
467 .collect(),
468 BfsDirection::Callees => self
469 .edges
470 .iter()
471 .filter(|e| e.caller_symbol.to_lowercase() == current)
472 .collect(),
473 };
474
475 for edge in neighbors {
476 let next_sym = match dir {
477 BfsDirection::Callers => &edge.caller_symbol,
478 BfsDirection::Callees => &edge.callee_name,
479 };
480 let next_lower = next_sym.to_lowercase();
481
482 if !visited.insert(next_lower.clone()) {
483 continue;
484 }
485
486 result.push(BfsNode {
487 symbol: next_sym.clone(),
488 file: edge.caller_file.clone(),
489 line: edge.caller_line,
490 depth: depth + 1,
491 from_symbol: if depth == 0 {
492 symbol.to_string()
493 } else {
494 current.clone()
495 },
496 });
497
498 queue.push_back((next_lower, depth + 1));
499 }
500 }
501
502 result
503 }
504
505 pub fn find_call_path(&self, from: &str, to: &str) -> Option<Vec<PathHop>> {
510 use std::collections::{HashMap as BfsMap, VecDeque};
511
512 let from_lower = from.to_lowercase();
513 let to_lower = to.to_lowercase();
514
515 if from_lower == to_lower {
516 return Some(vec![PathHop {
517 symbol: from.to_string(),
518 file: String::new(),
519 line: 0,
520 }]);
521 }
522
523 const MAX_TRACE_DEPTH: usize = 10;
524
525 let mut visited: BfsMap<String, (String, String, usize, usize)> = BfsMap::new();
527 let mut queue: VecDeque<String> = VecDeque::new();
528
529 visited.insert(from_lower.clone(), (String::new(), String::new(), 0, 0));
530 queue.push_back(from_lower.clone());
531
532 while let Some(current) = queue.pop_front() {
533 let current_depth = visited.get(¤t).map_or(0, |e| e.3);
534 if current_depth >= MAX_TRACE_DEPTH {
535 continue;
536 }
537
538 let callees: Vec<&CallEdge> = self
539 .edges
540 .iter()
541 .filter(|e| e.caller_symbol.to_lowercase() == current)
542 .collect();
543
544 for edge in callees {
545 let next = edge.callee_name.to_lowercase();
546 if visited.contains_key(&next) {
547 continue;
548 }
549
550 visited.insert(
551 next.clone(),
552 (
553 current.clone(),
554 edge.caller_file.clone(),
555 edge.caller_line,
556 current_depth + 1,
557 ),
558 );
559
560 if next == to_lower {
561 return Some(Self::reconstruct_path(
562 &visited,
563 &from_lower,
564 &to_lower,
565 from,
566 to,
567 ));
568 }
569
570 queue.push_back(next);
571 }
572 }
573
574 None
575 }
576
577 fn reconstruct_path(
578 visited: &std::collections::HashMap<String, (String, String, usize, usize)>,
579 from_lower: &str,
580 to_lower: &str,
581 from_orig: &str,
582 to_orig: &str,
583 ) -> Vec<PathHop> {
584 let mut path = Vec::new();
585 let mut current = to_lower.to_string();
586
587 while current != from_lower {
588 let (parent, file, line, _depth) = &visited[¤t];
589 let sym_name = if current == to_lower {
590 to_orig.to_string()
591 } else {
592 current.clone()
593 };
594 path.push(PathHop {
595 symbol: sym_name,
596 file: file.clone(),
597 line: *line,
598 });
599 current = parent.clone();
600 }
601
602 path.push(PathHop {
603 symbol: from_orig.to_string(),
604 file: String::new(),
605 line: 0,
606 });
607
608 path.reverse();
609 path
610 }
611
612 pub fn transitive_caller_count(&self, symbol: &str, max_depth: usize) -> usize {
614 let nodes = self.bfs_callers(symbol, max_depth);
615 let mut unique: std::collections::HashSet<String> = std::collections::HashSet::new();
616 for node in &nodes {
617 unique.insert(node.symbol.to_lowercase());
618 }
619 unique.len()
620 }
621
622 pub fn save(&self) -> Result<(), String> {
623 let dir = call_graph_dir(&self.project_root)
624 .ok_or_else(|| "Cannot determine home directory".to_string())?;
625 std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
626 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
627 let compressed = zstd::encode_all(json.as_bytes(), 9).map_err(|e| format!("zstd: {e}"))?;
628 let target = dir.join("call_graph.json.zst");
629 let tmp = target.with_extension("zst.tmp");
630 std::fs::write(&tmp, &compressed).map_err(|e| e.to_string())?;
631 std::fs::rename(&tmp, &target).map_err(|e| e.to_string())?;
632 let _ = std::fs::remove_file(dir.join("call_graph.json"));
633 Ok(())
634 }
635
636 pub fn load(project_root: &str) -> Option<Self> {
637 let dir = call_graph_dir(project_root)?;
638
639 let zst_path = dir.join("call_graph.json.zst");
640 if zst_path.exists() {
641 let compressed = std::fs::read(&zst_path).ok()?;
642 let data = zstd::decode_all(compressed.as_slice()).ok()?;
643 let content = String::from_utf8(data).ok()?;
644 return serde_json::from_str(&content).ok();
645 }
646
647 let json_path = dir.join("call_graph.json");
648 if json_path.exists() {
649 let content = std::fs::read_to_string(&json_path).ok()?;
650 let parsed: Self = serde_json::from_str(&content).ok()?;
651 if let Ok(compressed) = zstd::encode_all(content.as_bytes(), 9) {
653 let zst_tmp = zst_path.with_extension("zst.tmp");
654 if std::fs::write(&zst_tmp, &compressed).is_ok()
655 && std::fs::rename(&zst_tmp, &zst_path).is_ok()
656 {
657 let _ = std::fs::remove_file(&json_path);
658 }
659 }
660 return Some(parsed);
661 }
662
663 None
664 }
665
666 pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
667 if let Some(previous) = Self::load(project_root) {
668 Self::build_incremental(index, &previous)
669 } else {
670 Self::build(index)
671 }
672 }
673}
674
675fn cache_looks_stale(cached: &CallGraph, index: &ProjectIndex) -> bool {
680 if cached.file_hashes.len() != index.files.len() {
681 return true;
682 }
683 let cached_files: std::collections::HashSet<&str> =
684 cached.file_hashes.keys().map(String::as_str).collect();
685 let index_files: std::collections::HashSet<&str> =
686 index.files.keys().map(String::as_str).collect();
687 cached_files != index_files
688}
689
/// Directory holding the call-graph cache: the same directory as the project index.
/// `None` presumably means the home directory could not be determined (this is the
/// error message `CallGraph::save` reports for it).
fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
    ProjectIndex::index_dir(project_root)
}
697
698fn group_edges_by_file(edges: &[CallEdge]) -> HashMap<&str, Vec<CallEdge>> {
699 let mut map: HashMap<&str, Vec<CallEdge>> = HashMap::new();
700 for edge in edges {
701 map.entry(edge.caller_file.as_str())
702 .or_default()
703 .push(edge.clone());
704 }
705 map
706}
707
708fn group_symbols_by_file_owned(index: &ProjectIndex) -> HashMap<String, Vec<SymbolEntry>> {
710 let mut map: HashMap<String, Vec<SymbolEntry>> = HashMap::new();
711 for sym in index.symbols.values() {
712 map.entry(sym.file.clone()).or_default().push(sym.clone());
713 }
714 for syms in map.values_mut() {
715 syms.sort_by_key(|s| s.start_line);
716 }
717 map
718}
719
720fn find_enclosing_symbol_owned(file_symbols: Option<&Vec<SymbolEntry>>, line: usize) -> String {
721 let Some(syms) = file_symbols else {
722 return "<module>".to_string();
723 };
724 let mut best: Option<&SymbolEntry> = None;
725 for sym in syms {
726 if line >= sym.start_line && line <= sym.end_line {
727 match best {
728 None => best = Some(sym),
729 Some(prev) => {
730 if (sym.end_line - sym.start_line) < (prev.end_line - prev.start_line) {
731 best = Some(sym);
732 }
733 }
734 }
735 }
736 }
737 best.map_or_else(|| "<module>".to_string(), |s| s.name.clone())
738}
739
/// Turns an index path into a path on disk.
///
/// An absolute path that exists is returned unchanged. Anything else has its leading
/// `/` and `\` characters stripped and is joined onto `project_root`.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    let usable_as_is = candidate.is_absolute() && candidate.exists();
    if usable_as_is {
        relative.to_owned()
    } else {
        let stripped = relative.trim_start_matches(|c: char| c == '/' || c == '\\');
        Path::new(project_root)
            .join(stripped)
            .to_string_lossy()
            .into_owned()
    }
}
749
/// Lower-case hex of the std `DefaultHasher` hash of `content`.
///
/// The hash is deterministic within one build, but std does not promise a stable
/// algorithm across Rust releases. A toolchain change can therefore make every
/// stored hash mismatch, which only forces full re-analysis.
fn simple_hash(content: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(content, &mut state);
    let digest = state.finish();
    format!("{digest:x}")
}
756
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(CallEdge {
            caller_file: "a.rs".to_string(),
            caller_symbol: "bar".to_string(),
            caller_line: 10,
            callee_name: "foo".to_string(),
        });
        graph.edges.push(CallEdge {
            caller_file: "b.rs".to_string(),
            caller_symbol: "baz".to_string(),
            caller_line: 20,
            callee_name: "foo".to_string(),
        });
        graph.edges.push(CallEdge {
            caller_file: "c.rs".to_string(),
            caller_symbol: "qux".to_string(),
            caller_line: 30,
            callee_name: "other".to_string(),
        });
        let callers = graph.callers_of("foo");
        assert_eq!(callers.len(), 2);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(CallEdge {
            caller_file: "a.rs".to_string(),
            caller_symbol: "main".to_string(),
            caller_line: 5,
            callee_name: "init".to_string(),
        });
        graph.edges.push(CallEdge {
            caller_file: "a.rs".to_string(),
            caller_symbol: "main".to_string(),
            caller_line: 6,
            callee_name: "run".to_string(),
        });
        graph.edges.push(CallEdge {
            caller_file: "a.rs".to_string(),
            caller_symbol: "other".to_string(),
            caller_line: 15,
            callee_name: "init".to_string(),
        });
        let callees = graph.callees_of("main");
        assert_eq!(callees.len(), 2);
    }

    #[test]
    fn lookups_are_case_insensitive() {
        let graph = build_chain_graph();
        assert_eq!(graph.callers_of("FN_B").len(), 1);
        assert_eq!(graph.callees_of("Fn_A").len(), 1);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = SymbolEntry {
            file: "a.rs".to_string(),
            name: "Outer".to_string(),
            kind: "struct".to_string(),
            start_line: 1,
            end_line: 50,
            is_exported: true,
        };
        let inner = SymbolEntry {
            file: "a.rs".to_string(),
            name: "inner_fn".to_string(),
            kind: "fn".to_string(),
            start_line: 10,
            end_line: 20,
            is_exported: false,
        };
        let syms = vec![outer, inner];
        let result = find_enclosing_symbol_owned(Some(&syms), 15);
        assert_eq!(result, "inner_fn");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let sym = SymbolEntry {
            file: "a.rs".to_string(),
            name: "foo".to_string(),
            kind: "fn".to_string(),
            start_line: 10,
            end_line: 20,
            is_exported: false,
        };
        let syms = vec![sym];
        let result = find_enclosing_symbol_owned(Some(&syms), 5);
        assert_eq!(result, "<module>");
    }

    #[test]
    fn find_enclosing_without_symbols_returns_module() {
        assert_eq!(find_enclosing_symbol_owned(None, 1), "<module>");
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
        assert_eq!(
            resolved,
            Path::new(r"C:\repo")
                .join(r"src\main\kotlin\Example.kt")
                .to_string_lossy()
                .to_string()
        );
    }

    /// Linear chain fn_a -> fn_b -> fn_c -> fn_d, one call site per file.
    fn build_chain_graph() -> CallGraph {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(CallEdge {
            caller_file: "a.rs".into(),
            caller_symbol: "fn_a".into(),
            caller_line: 1,
            callee_name: "fn_b".into(),
        });
        graph.edges.push(CallEdge {
            caller_file: "b.rs".into(),
            caller_symbol: "fn_b".into(),
            caller_line: 10,
            callee_name: "fn_c".into(),
        });
        graph.edges.push(CallEdge {
            caller_file: "c.rs".into(),
            caller_symbol: "fn_c".into(),
            caller_line: 20,
            callee_name: "fn_d".into(),
        });
        graph
    }

    #[test]
    fn bfs_callees_depth_1_returns_direct() {
        let graph = build_chain_graph();
        let nodes = graph.bfs_callees("fn_a", 1);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].symbol, "fn_b");
        assert_eq!(nodes[0].depth, 1);
    }

    #[test]
    fn bfs_callees_depth_3_returns_chain() {
        let graph = build_chain_graph();
        let nodes = graph.bfs_callees("fn_a", 3);
        assert_eq!(nodes.len(), 3);
        let syms: Vec<&str> = nodes.iter().map(|n| n.symbol.as_str()).collect();
        assert!(syms.contains(&"fn_b"));
        assert!(syms.contains(&"fn_c"));
        assert!(syms.contains(&"fn_d"));
    }

    #[test]
    fn bfs_callers_depth_2_returns_transitive() {
        let graph = build_chain_graph();
        let nodes = graph.bfs_callers("fn_c", 2);
        assert_eq!(nodes.len(), 2);
        let syms: Vec<&str> = nodes.iter().map(|n| n.symbol.as_str()).collect();
        assert!(syms.contains(&"fn_b"));
        assert!(syms.contains(&"fn_a"));
    }

    #[test]
    fn bfs_with_zero_depth_is_empty() {
        let graph = build_chain_graph();
        assert!(graph.bfs_callees("fn_a", 0).is_empty());
        assert!(graph.bfs_callers("fn_d", 0).is_empty());
    }

    #[test]
    fn bfs_records_call_site_and_origin() {
        let graph = build_chain_graph();
        let nodes = graph.bfs_callers("fn_c", 1);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].symbol, "fn_b");
        assert_eq!(nodes[0].file, "b.rs");
        assert_eq!(nodes[0].line, 10);
        assert_eq!(nodes[0].from_symbol, "fn_c");
    }

    #[test]
    fn find_call_path_direct() {
        let graph = build_chain_graph();
        let path = graph.find_call_path("fn_a", "fn_b");
        assert!(path.is_some());
        let hops = path.unwrap();
        assert_eq!(hops.len(), 2);
        assert_eq!(hops[0].symbol, "fn_a");
        assert_eq!(hops[1].symbol, "fn_b");
    }

    #[test]
    fn find_call_path_multi_hop() {
        let graph = build_chain_graph();
        let path = graph.find_call_path("fn_a", "fn_d");
        assert!(path.is_some());
        let hops = path.unwrap();
        assert_eq!(hops.len(), 4);
        assert_eq!(hops[0].symbol, "fn_a");
        assert_eq!(hops[3].symbol, "fn_d");
    }

    #[test]
    fn find_call_path_keeps_endpoint_spelling_and_call_sites() {
        let graph = build_chain_graph();
        let hops = graph
            .find_call_path("FN_A", "Fn_C")
            .expect("fn_a reaches fn_c");
        assert_eq!(hops.len(), 3);
        assert_eq!(hops[0].symbol, "FN_A");
        assert_eq!(hops[0].file, "");
        assert_eq!(hops[0].line, 0);
        assert_eq!(hops[2].symbol, "Fn_C");
        assert_eq!(hops[2].file, "b.rs");
        assert_eq!(hops[2].line, 10);
    }

    #[test]
    fn find_call_path_no_connection() {
        let graph = build_chain_graph();
        let path = graph.find_call_path("fn_d", "fn_a");
        assert!(path.is_none());
    }

    #[test]
    fn find_call_path_same_symbol() {
        let graph = build_chain_graph();
        let path = graph.find_call_path("fn_a", "fn_a");
        assert!(path.is_some());
        assert_eq!(path.unwrap().len(), 1);
    }

    #[test]
    fn transitive_caller_count_returns_unique() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(CallEdge {
            caller_file: "x.rs".into(),
            caller_symbol: "x".into(),
            caller_line: 1,
            callee_name: "target".into(),
        });
        graph.edges.push(CallEdge {
            caller_file: "y.rs".into(),
            caller_symbol: "y".into(),
            caller_line: 2,
            callee_name: "target".into(),
        });
        graph.edges.push(CallEdge {
            caller_file: "z.rs".into(),
            caller_symbol: "z".into(),
            caller_line: 3,
            callee_name: "x".into(),
        });
        assert_eq!(graph.transitive_caller_count("target", 5), 3);
    }

    #[test]
    fn risk_level_classification() {
        assert_eq!(RiskLevel::from_caller_count(0), RiskLevel::Low);
        assert_eq!(RiskLevel::from_caller_count(1), RiskLevel::Low);
        assert_eq!(RiskLevel::from_caller_count(3), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_caller_count(7), RiskLevel::High);
        assert_eq!(RiskLevel::from_caller_count(15), RiskLevel::Critical);
    }

    #[test]
    fn risk_level_boundaries_and_labels() {
        assert_eq!(RiskLevel::from_caller_count(2), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_caller_count(4), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_caller_count(5), RiskLevel::High);
        assert_eq!(RiskLevel::from_caller_count(10), RiskLevel::High);
        assert_eq!(RiskLevel::from_caller_count(11), RiskLevel::Critical);
        assert_eq!(RiskLevel::Low.label(), "LOW");
        assert_eq!(RiskLevel::Medium.label(), "MEDIUM");
        assert_eq!(RiskLevel::High.label(), "HIGH");
        assert_eq!(RiskLevel::Critical.label(), "CRITICAL");
    }

    #[test]
    fn bfs_handles_cycle_without_infinite_loop() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(CallEdge {
            caller_file: "a.rs".into(),
            caller_symbol: "a".into(),
            caller_line: 1,
            callee_name: "b".into(),
        });
        graph.edges.push(CallEdge {
            caller_file: "b.rs".into(),
            caller_symbol: "b".into(),
            caller_line: 2,
            callee_name: "a".into(),
        });
        let nodes = graph.bfs_callees("a", 5);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].symbol, "b");
    }

    #[test]
    fn group_edges_by_file_buckets_by_caller_file() {
        let graph = build_chain_graph();
        let grouped = group_edges_by_file(&graph.edges);
        assert_eq!(grouped.len(), 3);
        assert_eq!(grouped["b.rs"].len(), 1);
        assert_eq!(grouped["b.rs"][0].callee_name, "fn_c");
    }

    #[test]
    fn simple_hash_is_deterministic_hex() {
        let first = simple_hash("fn main() {}");
        assert_eq!(first, simple_hash("fn main() {}"));
        assert_ne!(first, simple_hash("fn main() { }"));
        assert!(first.chars().all(|c| c.is_ascii_hexdigit()));
    }
}
1012}