1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, OnceLock};
5
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
9use super::deep_queries;
10use super::graph_index::{normalize_project_root, ProjectIndex, SymbolEntry};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallGraph {
18 pub project_root: String,
19 pub edges: Vec<CallEdge>,
20 pub file_hashes: HashMap<String, String>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CallEdge {
25 pub caller_file: String,
26 pub caller_symbol: String,
27 pub caller_line: usize,
28 pub callee_name: String,
29}
30
31#[derive(Debug, Clone, Serialize)]
36pub struct BuildProgress {
37 pub status: &'static str,
38 pub files_total: usize,
39 pub files_done: usize,
40 pub edges_found: usize,
41}
42
43enum BuildState {
44 Idle,
45 Building {
46 files_total: usize,
47 files_done: Arc<AtomicUsize>,
48 edges_found: Arc<AtomicUsize>,
49 },
50 Ready(Arc<CallGraph>),
51 Failed(String),
52}
53
54static BUILD: OnceLock<Mutex<BuildState>> = OnceLock::new();
55
56fn global_state() -> &'static Mutex<BuildState> {
57 BUILD.get_or_init(|| Mutex::new(BuildState::Idle))
58}
59
60impl CallGraph {
61 pub fn new(project_root: &str) -> Self {
62 Self {
63 project_root: normalize_project_root(project_root),
64 edges: Vec::new(),
65 file_hashes: HashMap::new(),
66 }
67 }
68
69 pub fn build_parallel(
74 index: &ProjectIndex,
75 progress: Option<(&AtomicUsize, &AtomicUsize)>,
76 ) -> Self {
77 let project_root = &index.project_root;
78 let symbols_by_file = group_symbols_by_file_owned(index);
79 let file_keys: Vec<String> = index.files.keys().cloned().collect();
80
81 let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
82 .par_iter()
83 .filter_map(|rel_path| {
84 let abs_path = resolve_path(rel_path, project_root);
85 let content = std::fs::read_to_string(&abs_path).ok()?;
86 let hash = simple_hash(&content);
87
88 let ext = Path::new(rel_path)
89 .extension()
90 .and_then(|e| e.to_str())
91 .unwrap_or("");
92
93 let analysis = deep_queries::analyze(&content, ext);
94 let file_symbols = symbols_by_file.get(rel_path.as_str());
95
96 let edges: Vec<CallEdge> = analysis
97 .calls
98 .iter()
99 .map(|call| {
100 let caller_sym = find_enclosing_symbol_owned(file_symbols, call.line + 1);
101 CallEdge {
102 caller_file: rel_path.clone(),
103 caller_symbol: caller_sym,
104 caller_line: call.line + 1,
105 callee_name: call.callee.clone(),
106 }
107 })
108 .collect();
109
110 if let Some((done, edge_count)) = progress {
111 done.fetch_add(1, Ordering::Relaxed);
112 edge_count.fetch_add(edges.len(), Ordering::Relaxed);
113 }
114
115 Some((rel_path.clone(), hash, edges))
116 })
117 .collect();
118
119 let mut graph = Self::new(project_root);
120 let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
121 graph.edges.reserve(edge_capacity);
122 graph.file_hashes.reserve(results.len());
123
124 for (path, hash, edges) in results {
125 graph.file_hashes.insert(path, hash);
126 graph.edges.extend(edges);
127 }
128
129 graph
130 }
131
132 pub fn build_incremental_parallel(
137 index: &ProjectIndex,
138 previous: &CallGraph,
139 progress: Option<(&AtomicUsize, &AtomicUsize)>,
140 ) -> Self {
141 let project_root = &index.project_root;
142 let symbols_by_file = group_symbols_by_file_owned(index);
143 let file_keys: Vec<String> = index.files.keys().cloned().collect();
144
145 let prev_edges_by_file = group_edges_by_file(&previous.edges);
146
147 let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
148 .par_iter()
149 .filter_map(|rel_path| {
150 let abs_path = resolve_path(rel_path, project_root);
151 let content = std::fs::read_to_string(&abs_path).ok()?;
152 let hash = simple_hash(&content);
153 let changed = previous.file_hashes.get(rel_path.as_str()) != Some(&hash);
154
155 let edges = if changed {
156 let ext = Path::new(rel_path)
157 .extension()
158 .and_then(|e| e.to_str())
159 .unwrap_or("");
160
161 let analysis = deep_queries::analyze(&content, ext);
162 let file_symbols = symbols_by_file.get(rel_path.as_str());
163
164 analysis
165 .calls
166 .iter()
167 .map(|call| {
168 let caller_sym =
169 find_enclosing_symbol_owned(file_symbols, call.line + 1);
170 CallEdge {
171 caller_file: rel_path.clone(),
172 caller_symbol: caller_sym,
173 caller_line: call.line + 1,
174 callee_name: call.callee.clone(),
175 }
176 })
177 .collect()
178 } else {
179 prev_edges_by_file
180 .get(rel_path.as_str())
181 .cloned()
182 .unwrap_or_default()
183 };
184
185 if let Some((done, edge_count)) = progress {
186 done.fetch_add(1, Ordering::Relaxed);
187 edge_count.fetch_add(edges.len(), Ordering::Relaxed);
188 }
189
190 Some((rel_path.clone(), hash, edges))
191 })
192 .collect();
193
194 let mut graph = Self::new(project_root);
195 let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
196 graph.edges.reserve(edge_capacity);
197 graph.file_hashes.reserve(results.len());
198
199 for (path, hash, edges) in results {
200 graph.file_hashes.insert(path, hash);
201 graph.edges.extend(edges);
202 }
203
204 graph
205 }
206
207 pub fn get_or_start_build(
213 project_root: &str,
214 index: Arc<ProjectIndex>,
215 ) -> Result<Arc<CallGraph>, BuildProgress> {
216 let state = global_state();
217 let mut guard = state
218 .lock()
219 .unwrap_or_else(std::sync::PoisonError::into_inner);
220
221 match &*guard {
222 BuildState::Ready(graph) => return Ok(Arc::clone(graph)),
223 BuildState::Building {
224 files_total,
225 files_done,
226 edges_found,
227 } => {
228 return Err(BuildProgress {
229 status: "building",
230 files_total: *files_total,
231 files_done: files_done.load(Ordering::Relaxed),
232 edges_found: edges_found.load(Ordering::Relaxed),
233 });
234 }
235 BuildState::Failed(msg) => {
236 tracing::warn!("[call_graph: previous build failed: {msg} — retrying]");
237 }
238 BuildState::Idle => {}
239 }
240
241 if let Some(cached) = Self::load(project_root) {
243 if !cache_looks_stale(&cached, &index) {
244 let arc = Arc::new(cached);
245 *guard = BuildState::Ready(Arc::clone(&arc));
246 return Ok(arc);
247 }
248 }
249
250 let files_total = index.files.len();
251 let files_done = Arc::new(AtomicUsize::new(0));
252 let edges_found = Arc::new(AtomicUsize::new(0));
253
254 *guard = BuildState::Building {
255 files_total,
256 files_done: Arc::clone(&files_done),
257 edges_found: Arc::clone(&edges_found),
258 };
259 drop(guard);
260
261 let root = normalize_project_root(project_root);
262 let fd = Arc::clone(&files_done);
263 let ef = Arc::clone(&edges_found);
264
265 std::thread::spawn(move || {
266 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
267 let previous = CallGraph::load(&root);
268 if let Some(prev) = &previous {
269 CallGraph::build_incremental_parallel(&index, prev, Some((&fd, &ef)))
270 } else {
271 CallGraph::build_parallel(&index, Some((&fd, &ef)))
272 }
273 }));
274
275 match result {
276 Ok(graph) => {
277 let _ = graph.save();
278 let arc = Arc::new(graph);
279 if let Ok(mut g) = global_state().lock() {
280 *g = BuildState::Ready(Arc::clone(&arc));
281 }
282 tracing::info!(
283 "[call_graph: build complete — {} files, {} edges]",
284 arc.file_hashes.len(),
285 arc.edges.len()
286 );
287 }
288 Err(e) => {
289 let msg = format!("{e:?}");
290 tracing::error!("[call_graph: build panicked: {msg}]");
291 if let Ok(mut g) = global_state().lock() {
292 *g = BuildState::Failed(msg);
293 }
294 }
295 }
296 });
297
298 Err(BuildProgress {
299 status: "building",
300 files_total,
301 files_done: 0,
302 edges_found: 0,
303 })
304 }
305
306 pub fn build_status() -> BuildProgress {
308 let state = global_state();
309 let guard = state
310 .lock()
311 .unwrap_or_else(std::sync::PoisonError::into_inner);
312 match &*guard {
313 BuildState::Idle => BuildProgress {
314 status: "idle",
315 files_total: 0,
316 files_done: 0,
317 edges_found: 0,
318 },
319 BuildState::Building {
320 files_total,
321 files_done,
322 edges_found,
323 } => BuildProgress {
324 status: "building",
325 files_total: *files_total,
326 files_done: files_done.load(Ordering::Relaxed),
327 edges_found: edges_found.load(Ordering::Relaxed),
328 },
329 BuildState::Ready(graph) => BuildProgress {
330 status: "ready",
331 files_total: graph.file_hashes.len(),
332 files_done: graph.file_hashes.len(),
333 edges_found: graph.edges.len(),
334 },
335 BuildState::Failed(msg) => {
336 tracing::debug!("[call_graph: status check — failed: {msg}]");
337 BuildProgress {
338 status: "error",
339 files_total: 0,
340 files_done: 0,
341 edges_found: 0,
342 }
343 }
344 }
345 }
346
347 pub fn invalidate() {
349 if let Ok(mut g) = global_state().lock() {
350 *g = BuildState::Idle;
351 }
352 }
353
354 pub fn build(index: &ProjectIndex) -> Self {
359 Self::build_parallel(index, None)
360 }
361
362 pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
363 Self::build_incremental_parallel(index, previous, None)
364 }
365
366 pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
367 let sym_lower = symbol.to_lowercase();
368 self.edges
369 .iter()
370 .filter(|e| e.callee_name.to_lowercase() == sym_lower)
371 .collect()
372 }
373
374 pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
375 let sym_lower = symbol.to_lowercase();
376 self.edges
377 .iter()
378 .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
379 .collect()
380 }
381
382 pub fn save(&self) -> Result<(), String> {
383 let dir = call_graph_dir(&self.project_root)
384 .ok_or_else(|| "Cannot determine home directory".to_string())?;
385 std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
386 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
387 let compressed = zstd::encode_all(json.as_bytes(), 9).map_err(|e| format!("zstd: {e}"))?;
388 let target = dir.join("call_graph.json.zst");
389 let tmp = target.with_extension("zst.tmp");
390 std::fs::write(&tmp, &compressed).map_err(|e| e.to_string())?;
391 std::fs::rename(&tmp, &target).map_err(|e| e.to_string())?;
392 let _ = std::fs::remove_file(dir.join("call_graph.json"));
393 Ok(())
394 }
395
396 pub fn load(project_root: &str) -> Option<Self> {
397 let dir = call_graph_dir(project_root)?;
398
399 let zst_path = dir.join("call_graph.json.zst");
400 if zst_path.exists() {
401 let compressed = std::fs::read(&zst_path).ok()?;
402 let data = zstd::decode_all(compressed.as_slice()).ok()?;
403 let content = String::from_utf8(data).ok()?;
404 return serde_json::from_str(&content).ok();
405 }
406
407 let json_path = dir.join("call_graph.json");
408 if json_path.exists() {
409 let content = std::fs::read_to_string(&json_path).ok()?;
410 let parsed: Self = serde_json::from_str(&content).ok()?;
411 if let Ok(compressed) = zstd::encode_all(content.as_bytes(), 9) {
413 let zst_tmp = zst_path.with_extension("zst.tmp");
414 if std::fs::write(&zst_tmp, &compressed).is_ok()
415 && std::fs::rename(&zst_tmp, &zst_path).is_ok()
416 {
417 let _ = std::fs::remove_file(&json_path);
418 }
419 }
420 return Some(parsed);
421 }
422
423 None
424 }
425
426 pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
427 if let Some(previous) = Self::load(project_root) {
428 Self::build_incremental(index, &previous)
429 } else {
430 Self::build(index)
431 }
432 }
433}
434
435fn cache_looks_stale(cached: &CallGraph, index: &ProjectIndex) -> bool {
440 if cached.file_hashes.len() != index.files.len() {
441 return true;
442 }
443 let cached_files: std::collections::HashSet<&str> =
444 cached.file_hashes.keys().map(String::as_str).collect();
445 let index_files: std::collections::HashSet<&str> =
446 index.files.keys().map(String::as_str).collect();
447 cached_files != index_files
448}
449
450fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
455 ProjectIndex::index_dir(project_root)
456}
457
458fn group_edges_by_file(edges: &[CallEdge]) -> HashMap<&str, Vec<CallEdge>> {
459 let mut map: HashMap<&str, Vec<CallEdge>> = HashMap::new();
460 for edge in edges {
461 map.entry(edge.caller_file.as_str())
462 .or_default()
463 .push(edge.clone());
464 }
465 map
466}
467
468fn group_symbols_by_file_owned(index: &ProjectIndex) -> HashMap<String, Vec<SymbolEntry>> {
470 let mut map: HashMap<String, Vec<SymbolEntry>> = HashMap::new();
471 for sym in index.symbols.values() {
472 map.entry(sym.file.clone()).or_default().push(sym.clone());
473 }
474 for syms in map.values_mut() {
475 syms.sort_by_key(|s| s.start_line);
476 }
477 map
478}
479
480fn find_enclosing_symbol_owned(file_symbols: Option<&Vec<SymbolEntry>>, line: usize) -> String {
481 let Some(syms) = file_symbols else {
482 return "<module>".to_string();
483 };
484 let mut best: Option<&SymbolEntry> = None;
485 for sym in syms {
486 if line >= sym.start_line && line <= sym.end_line {
487 match best {
488 None => best = Some(sym),
489 Some(prev) => {
490 if (sym.end_line - sym.start_line) < (prev.end_line - prev.start_line) {
491 best = Some(sym);
492 }
493 }
494 }
495 }
496 }
497 best.map_or_else(|| "<module>".to_string(), |s| s.name.clone())
498}
499
/// Turns an index-relative path into a path under `project_root`.
///
/// An absolute path that exists on disk is returned unchanged. Anything else has
/// its leading `/` and `\` characters stripped and is joined onto `project_root`,
/// so "rooted" relative paths such as `\src\main.kt` still resolve inside the
/// project.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    if candidate.is_absolute() && candidate.exists() {
        relative.to_owned()
    } else {
        let trimmed = relative.trim_start_matches(|c: char| c == '/' || c == '\\');
        Path::new(project_root)
            .join(trimmed)
            .to_string_lossy()
            .into_owned()
    }
}
509
/// Returns the lowercase hex of a 64-bit FNV-1a hash of `content`.
///
/// These hashes are persisted in the on-disk call-graph cache and compared on
/// later runs to decide whether a file must be re-analyzed. The algorithm must
/// therefore give the same value for the same input across builds and Rust
/// releases. `std`'s `DefaultHasher` explicitly does not guarantee that, so a
/// fixed FNV-1a is used instead. Not suitable for security purposes.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{hash:x}")
}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`CallEdge`] fixture.
    fn edge(file: &str, caller: &str, line: usize, callee: &str) -> CallEdge {
        CallEdge {
            caller_file: file.to_string(),
            caller_symbol: caller.to_string(),
            caller_line: line,
            callee_name: callee.to_string(),
        }
    }

    /// Builds a [`SymbolEntry`] fixture located in `a.rs`; `lines` is `(start, end)`.
    fn symbol(name: &str, kind: &str, lines: (usize, usize), is_exported: bool) -> SymbolEntry {
        let (start_line, end_line) = lines;
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.extend([
            edge("a.rs", "bar", 10, "foo"),
            edge("b.rs", "baz", 20, "foo"),
            edge("c.rs", "qux", 30, "other"),
        ]);
        assert_eq!(graph.callers_of("foo").len(), 2);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.extend([
            edge("a.rs", "main", 5, "init"),
            edge("a.rs", "main", 6, "run"),
            edge("a.rs", "other", 15, "init"),
        ]);
        assert_eq!(graph.callees_of("main").len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let syms = vec![
            symbol("Outer", "struct", (1, 50), true),
            symbol("inner_fn", "fn", (10, 20), false),
        ];
        assert_eq!(find_enclosing_symbol_owned(Some(&syms), 15), "inner_fn");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let syms = vec![symbol("foo", "fn", (10, 20), false)];
        assert_eq!(find_enclosing_symbol_owned(Some(&syms), 5), "<module>");
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let expected = Path::new(r"C:\repo").join(r"src\main\kotlin\Example.kt");
        assert_eq!(
            resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo"),
            expected.to_string_lossy().to_string()
        );
    }
}