// lean_ctx/core/call_graph.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, OnceLock};
5
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
9use super::deep_queries;
10use super::graph_index::{normalize_project_root, ProjectIndex, SymbolEntry};
11
// ---------------------------------------------------------------------------
// Data types
// ---------------------------------------------------------------------------
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallGraph {
18    pub project_root: String,
19    pub edges: Vec<CallEdge>,
20    pub file_hashes: HashMap<String, String>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CallEdge {
25    pub caller_file: String,
26    pub caller_symbol: String,
27    pub caller_line: usize,
28    pub callee_name: String,
29}
30
// ---------------------------------------------------------------------------
// Background build state (singleton per process)
// ---------------------------------------------------------------------------
34
35#[derive(Debug, Clone, Serialize)]
36pub struct BuildProgress {
37    pub status: &'static str,
38    pub files_total: usize,
39    pub files_done: usize,
40    pub edges_found: usize,
41}
42
43enum BuildState {
44    Idle,
45    Building {
46        files_total: usize,
47        files_done: Arc<AtomicUsize>,
48        edges_found: Arc<AtomicUsize>,
49    },
50    Ready(Arc<CallGraph>),
51    Failed(String),
52}
53
54static BUILD: OnceLock<Mutex<BuildState>> = OnceLock::new();
55
56fn global_state() -> &'static Mutex<BuildState> {
57    BUILD.get_or_init(|| Mutex::new(BuildState::Idle))
58}
59
60impl CallGraph {
61    pub fn new(project_root: &str) -> Self {
62        Self {
63            project_root: normalize_project_root(project_root),
64            edges: Vec::new(),
65            file_hashes: HashMap::new(),
66        }
67    }
68
    // -----------------------------------------------------------------------
    // Parallel build — processes files via rayon thread pool
    // -----------------------------------------------------------------------
72
73    pub fn build_parallel(
74        index: &ProjectIndex,
75        progress: Option<(&AtomicUsize, &AtomicUsize)>,
76    ) -> Self {
77        let project_root = &index.project_root;
78        let symbols_by_file = group_symbols_by_file_owned(index);
79        let file_keys: Vec<String> = index.files.keys().cloned().collect();
80
81        let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
82            .par_iter()
83            .filter_map(|rel_path| {
84                let abs_path = resolve_path(rel_path, project_root);
85                let content = std::fs::read_to_string(&abs_path).ok()?;
86                let hash = simple_hash(&content);
87
88                let ext = Path::new(rel_path)
89                    .extension()
90                    .and_then(|e| e.to_str())
91                    .unwrap_or("");
92
93                let analysis = deep_queries::analyze(&content, ext);
94                let file_symbols = symbols_by_file.get(rel_path.as_str());
95
96                let edges: Vec<CallEdge> = analysis
97                    .calls
98                    .iter()
99                    .map(|call| {
100                        let caller_sym = find_enclosing_symbol_owned(file_symbols, call.line + 1);
101                        CallEdge {
102                            caller_file: rel_path.clone(),
103                            caller_symbol: caller_sym,
104                            caller_line: call.line + 1,
105                            callee_name: call.callee.clone(),
106                        }
107                    })
108                    .collect();
109
110                if let Some((done, edge_count)) = progress {
111                    done.fetch_add(1, Ordering::Relaxed);
112                    edge_count.fetch_add(edges.len(), Ordering::Relaxed);
113                }
114
115                Some((rel_path.clone(), hash, edges))
116            })
117            .collect();
118
119        let mut graph = Self::new(project_root);
120        let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
121        graph.edges.reserve(edge_capacity);
122        graph.file_hashes.reserve(results.len());
123
124        for (path, hash, edges) in results {
125            graph.file_hashes.insert(path, hash);
126            graph.edges.extend(edges);
127        }
128
129        graph
130    }
131
    // -----------------------------------------------------------------------
    // Incremental parallel build — only re-analyzes changed files
    // -----------------------------------------------------------------------
135
136    pub fn build_incremental_parallel(
137        index: &ProjectIndex,
138        previous: &CallGraph,
139        progress: Option<(&AtomicUsize, &AtomicUsize)>,
140    ) -> Self {
141        let project_root = &index.project_root;
142        let symbols_by_file = group_symbols_by_file_owned(index);
143        let file_keys: Vec<String> = index.files.keys().cloned().collect();
144
145        let prev_edges_by_file = group_edges_by_file(&previous.edges);
146
147        let results: Vec<(String, String, Vec<CallEdge>)> = file_keys
148            .par_iter()
149            .filter_map(|rel_path| {
150                let abs_path = resolve_path(rel_path, project_root);
151                let content = std::fs::read_to_string(&abs_path).ok()?;
152                let hash = simple_hash(&content);
153                let changed = previous.file_hashes.get(rel_path.as_str()) != Some(&hash);
154
155                let edges = if changed {
156                    let ext = Path::new(rel_path)
157                        .extension()
158                        .and_then(|e| e.to_str())
159                        .unwrap_or("");
160
161                    let analysis = deep_queries::analyze(&content, ext);
162                    let file_symbols = symbols_by_file.get(rel_path.as_str());
163
164                    analysis
165                        .calls
166                        .iter()
167                        .map(|call| {
168                            let caller_sym =
169                                find_enclosing_symbol_owned(file_symbols, call.line + 1);
170                            CallEdge {
171                                caller_file: rel_path.clone(),
172                                caller_symbol: caller_sym,
173                                caller_line: call.line + 1,
174                                callee_name: call.callee.clone(),
175                            }
176                        })
177                        .collect()
178                } else {
179                    prev_edges_by_file
180                        .get(rel_path.as_str())
181                        .cloned()
182                        .unwrap_or_default()
183                };
184
185                if let Some((done, edge_count)) = progress {
186                    done.fetch_add(1, Ordering::Relaxed);
187                    edge_count.fetch_add(edges.len(), Ordering::Relaxed);
188                }
189
190                Some((rel_path.clone(), hash, edges))
191            })
192            .collect();
193
194        let mut graph = Self::new(project_root);
195        let edge_capacity: usize = results.iter().map(|(_, _, e)| e.len()).sum();
196        graph.edges.reserve(edge_capacity);
197        graph.file_hashes.reserve(results.len());
198
199        for (path, hash, edges) in results {
200            graph.file_hashes.insert(path, hash);
201            graph.edges.extend(edges);
202        }
203
204        graph
205    }
206
    // -----------------------------------------------------------------------
    // Public API: non-blocking access for the dashboard
    // -----------------------------------------------------------------------
210
211    /// Returns the cached graph immediately, or `None` + starts a background build.
212    pub fn get_or_start_build(
213        project_root: &str,
214        index: Arc<ProjectIndex>,
215    ) -> Result<Arc<CallGraph>, BuildProgress> {
216        let state = global_state();
217        let mut guard = state
218            .lock()
219            .unwrap_or_else(std::sync::PoisonError::into_inner);
220
221        match &*guard {
222            BuildState::Ready(graph) => return Ok(Arc::clone(graph)),
223            BuildState::Building {
224                files_total,
225                files_done,
226                edges_found,
227            } => {
228                return Err(BuildProgress {
229                    status: "building",
230                    files_total: *files_total,
231                    files_done: files_done.load(Ordering::Relaxed),
232                    edges_found: edges_found.load(Ordering::Relaxed),
233                });
234            }
235            BuildState::Failed(msg) => {
236                tracing::warn!("[call_graph: previous build failed: {msg} — retrying]");
237            }
238            BuildState::Idle => {}
239        }
240
241        // Try serving from disk cache first
242        if let Some(cached) = Self::load(project_root) {
243            if !cache_looks_stale(&cached, &index) {
244                let arc = Arc::new(cached);
245                *guard = BuildState::Ready(Arc::clone(&arc));
246                return Ok(arc);
247            }
248        }
249
250        let files_total = index.files.len();
251        let files_done = Arc::new(AtomicUsize::new(0));
252        let edges_found = Arc::new(AtomicUsize::new(0));
253
254        *guard = BuildState::Building {
255            files_total,
256            files_done: Arc::clone(&files_done),
257            edges_found: Arc::clone(&edges_found),
258        };
259        drop(guard);
260
261        let root = normalize_project_root(project_root);
262        let fd = Arc::clone(&files_done);
263        let ef = Arc::clone(&edges_found);
264
265        std::thread::spawn(move || {
266            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
267                let previous = CallGraph::load(&root);
268                if let Some(prev) = &previous {
269                    CallGraph::build_incremental_parallel(&index, prev, Some((&fd, &ef)))
270                } else {
271                    CallGraph::build_parallel(&index, Some((&fd, &ef)))
272                }
273            }));
274
275            match result {
276                Ok(graph) => {
277                    let _ = graph.save();
278                    let arc = Arc::new(graph);
279                    if let Ok(mut g) = global_state().lock() {
280                        *g = BuildState::Ready(Arc::clone(&arc));
281                    }
282                    tracing::info!(
283                        "[call_graph: build complete — {} files, {} edges]",
284                        arc.file_hashes.len(),
285                        arc.edges.len()
286                    );
287                }
288                Err(e) => {
289                    let msg = format!("{e:?}");
290                    tracing::error!("[call_graph: build panicked: {msg}]");
291                    if let Ok(mut g) = global_state().lock() {
292                        *g = BuildState::Failed(msg);
293                    }
294                }
295            }
296        });
297
298        Err(BuildProgress {
299            status: "building",
300            files_total,
301            files_done: 0,
302            edges_found: 0,
303        })
304    }
305
306    /// Returns current build status without starting anything.
307    pub fn build_status() -> BuildProgress {
308        let state = global_state();
309        let guard = state
310            .lock()
311            .unwrap_or_else(std::sync::PoisonError::into_inner);
312        match &*guard {
313            BuildState::Idle => BuildProgress {
314                status: "idle",
315                files_total: 0,
316                files_done: 0,
317                edges_found: 0,
318            },
319            BuildState::Building {
320                files_total,
321                files_done,
322                edges_found,
323            } => BuildProgress {
324                status: "building",
325                files_total: *files_total,
326                files_done: files_done.load(Ordering::Relaxed),
327                edges_found: edges_found.load(Ordering::Relaxed),
328            },
329            BuildState::Ready(graph) => BuildProgress {
330                status: "ready",
331                files_total: graph.file_hashes.len(),
332                files_done: graph.file_hashes.len(),
333                edges_found: graph.edges.len(),
334            },
335            BuildState::Failed(msg) => {
336                tracing::debug!("[call_graph: status check — failed: {msg}]");
337                BuildProgress {
338                    status: "error",
339                    files_total: 0,
340                    files_done: 0,
341                    edges_found: 0,
342                }
343            }
344        }
345    }
346
347    /// Force-invalidate the cached result so next request triggers a rebuild.
348    pub fn invalidate() {
349        if let Ok(mut g) = global_state().lock() {
350            *g = BuildState::Idle;
351        }
352    }
353
    // -----------------------------------------------------------------------
    // Legacy synchronous API (kept for non-dashboard callers)
    // -----------------------------------------------------------------------
357
358    pub fn build(index: &ProjectIndex) -> Self {
359        Self::build_parallel(index, None)
360    }
361
362    pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
363        Self::build_incremental_parallel(index, previous, None)
364    }
365
366    pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
367        let sym_lower = symbol.to_lowercase();
368        self.edges
369            .iter()
370            .filter(|e| e.callee_name.to_lowercase() == sym_lower)
371            .collect()
372    }
373
374    pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
375        let sym_lower = symbol.to_lowercase();
376        self.edges
377            .iter()
378            .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
379            .collect()
380    }
381
382    pub fn save(&self) -> Result<(), String> {
383        let dir = call_graph_dir(&self.project_root)
384            .ok_or_else(|| "Cannot determine home directory".to_string())?;
385        std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
386        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
387        let compressed = zstd::encode_all(json.as_bytes(), 9).map_err(|e| format!("zstd: {e}"))?;
388        let target = dir.join("call_graph.json.zst");
389        let tmp = target.with_extension("zst.tmp");
390        std::fs::write(&tmp, &compressed).map_err(|e| e.to_string())?;
391        std::fs::rename(&tmp, &target).map_err(|e| e.to_string())?;
392        let _ = std::fs::remove_file(dir.join("call_graph.json"));
393        Ok(())
394    }
395
396    pub fn load(project_root: &str) -> Option<Self> {
397        let dir = call_graph_dir(project_root)?;
398
399        let zst_path = dir.join("call_graph.json.zst");
400        if zst_path.exists() {
401            let compressed = std::fs::read(&zst_path).ok()?;
402            let data = zstd::decode_all(compressed.as_slice()).ok()?;
403            let content = String::from_utf8(data).ok()?;
404            return serde_json::from_str(&content).ok();
405        }
406
407        let json_path = dir.join("call_graph.json");
408        if json_path.exists() {
409            let content = std::fs::read_to_string(&json_path).ok()?;
410            let parsed: Self = serde_json::from_str(&content).ok()?;
411            // Auto-migrate: compress legacy JSON to zstd
412            if let Ok(compressed) = zstd::encode_all(content.as_bytes(), 9) {
413                let zst_tmp = zst_path.with_extension("zst.tmp");
414                if std::fs::write(&zst_tmp, &compressed).is_ok()
415                    && std::fs::rename(&zst_tmp, &zst_path).is_ok()
416                {
417                    let _ = std::fs::remove_file(&json_path);
418                }
419            }
420            return Some(parsed);
421        }
422
423        None
424    }
425
426    pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
427        if let Some(previous) = Self::load(project_root) {
428            Self::build_incremental(index, &previous)
429        } else {
430            Self::build(index)
431        }
432    }
433}
434
// ---------------------------------------------------------------------------
// Cache staleness check (fast — compares file-path sets only, no content reads)
// ---------------------------------------------------------------------------
438
439fn cache_looks_stale(cached: &CallGraph, index: &ProjectIndex) -> bool {
440    if cached.file_hashes.len() != index.files.len() {
441        return true;
442    }
443    let cached_files: std::collections::HashSet<&str> =
444        cached.file_hashes.keys().map(String::as_str).collect();
445    let index_files: std::collections::HashSet<&str> =
446        index.files.keys().map(String::as_str).collect();
447    cached_files != index_files
448}
449
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
453
454fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
455    ProjectIndex::index_dir(project_root)
456}
457
458fn group_edges_by_file(edges: &[CallEdge]) -> HashMap<&str, Vec<CallEdge>> {
459    let mut map: HashMap<&str, Vec<CallEdge>> = HashMap::new();
460    for edge in edges {
461        map.entry(edge.caller_file.as_str())
462            .or_default()
463            .push(edge.clone());
464    }
465    map
466}
467
468/// Owned version for safe `Send` across rayon threads.
469fn group_symbols_by_file_owned(index: &ProjectIndex) -> HashMap<String, Vec<SymbolEntry>> {
470    let mut map: HashMap<String, Vec<SymbolEntry>> = HashMap::new();
471    for sym in index.symbols.values() {
472        map.entry(sym.file.clone()).or_default().push(sym.clone());
473    }
474    for syms in map.values_mut() {
475        syms.sort_by_key(|s| s.start_line);
476    }
477    map
478}
479
480fn find_enclosing_symbol_owned(file_symbols: Option<&Vec<SymbolEntry>>, line: usize) -> String {
481    let Some(syms) = file_symbols else {
482        return "<module>".to_string();
483    };
484    let mut best: Option<&SymbolEntry> = None;
485    for sym in syms {
486        if line >= sym.start_line && line <= sym.end_line {
487            match best {
488                None => best = Some(sym),
489                Some(prev) => {
490                    if (sym.end_line - sym.start_line) < (prev.end_line - prev.start_line) {
491                        best = Some(sym);
492                    }
493                }
494            }
495        }
496    }
497    best.map_or_else(|| "<module>".to_string(), |s| s.name.clone())
498}
499
/// Turns an index path into a path that can be opened.
///
/// An absolute path that exists on disk is returned unchanged. Anything else
/// is treated as relative to `project_root`: leading `/` and `\` characters
/// are stripped first so a "rooted" index path such as `/src/lib.rs` is joined
/// below the project root instead of replacing it.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let p = Path::new(relative);
    if p.is_absolute() && p.exists() {
        return relative.to_string();
    }
    let relative = relative.trim_start_matches(['/', '\\']);
    Path::new(project_root)
        .join(relative)
        .to_string_lossy()
        .into_owned()
}
509
/// Fingerprints file content as a lowercase hex string (change detection
/// only, not cryptographic).
///
/// Uses the standard library's `DefaultHasher`, whose algorithm is not
/// guaranteed to stay the same across Rust releases. Hashes persisted by an
/// older build may therefore stop matching after a toolchain upgrade, which
/// simply makes every file look changed once.
fn simple_hash(content: &str) -> String {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    content.hash(&mut hasher);
    format!("{:x}", hasher.finish())
}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an edge with only the fields the lookup tests care about.
    fn edge(file: &str, caller: &str, line: usize, callee: &str) -> CallEdge {
        CallEdge {
            caller_file: file.to_string(),
            caller_symbol: caller.to_string(),
            caller_line: line,
            callee_name: callee.to_string(),
        }
    }

    /// Builds a symbol covering `start_line..=end_line` in `a.rs`.
    fn symbol(name: &str, kind: &str, start_line: usize, end_line: usize) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported: false,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "bar", 10, "foo"));
        graph.edges.push(edge("b.rs", "baz", 20, "foo"));
        graph.edges.push(edge("c.rs", "qux", 30, "other"));
        let callers = graph.callers_of("foo");
        assert_eq!(callers.len(), 2);
        assert!(callers.iter().all(|e| e.callee_name == "foo"));
    }

    #[test]
    fn callers_of_ignores_case() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "bar", 10, "Foo"));
        assert_eq!(graph.callers_of("fOO").len(), 1);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "main", 5, "init"));
        graph.edges.push(edge("a.rs", "main", 6, "run"));
        graph.edges.push(edge("a.rs", "other", 15, "init"));
        let callees = graph.callees_of("main");
        assert_eq!(callees.len(), 2);
        assert!(callees.iter().all(|e| e.caller_symbol == "main"));
    }

    #[test]
    fn callees_of_ignores_case() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "Main", 5, "init"));
        assert_eq!(graph.callees_of("MAIN").len(), 1);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let syms = vec![symbol("Outer", "struct", 1, 50), symbol("inner_fn", "fn", 10, 20)];
        let result = find_enclosing_symbol_owned(Some(&syms), 15);
        assert_eq!(result, "inner_fn");
    }

    #[test]
    fn find_enclosing_range_is_inclusive() {
        let syms = vec![symbol("foo", "fn", 10, 20)];
        assert_eq!(find_enclosing_symbol_owned(Some(&syms), 10), "foo");
        assert_eq!(find_enclosing_symbol_owned(Some(&syms), 20), "foo");
    }

    #[test]
    fn find_enclosing_prefers_first_on_tie() {
        let syms = vec![symbol("first", "fn", 5, 9), symbol("second", "fn", 5, 9)];
        assert_eq!(find_enclosing_symbol_owned(Some(&syms), 7), "first");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let syms = vec![symbol("foo", "fn", 10, 20)];
        let result = find_enclosing_symbol_owned(Some(&syms), 5);
        assert_eq!(result, "<module>");
        assert_eq!(find_enclosing_symbol_owned(None, 5), "<module>");
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
        assert_eq!(
            resolved,
            Path::new(r"C:\repo")
                .join(r"src\main\kotlin\Example.kt")
                .to_string_lossy()
                .to_string()
        );
    }

    #[test]
    fn simple_hash_is_deterministic() {
        assert_eq!(simple_hash("abc"), simple_hash("abc"));
        assert_ne!(simple_hash("abc"), simple_hash("abd"));
    }
}