Skip to main content

lean_ctx/core/
call_graph.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6use super::deep_queries;
7use super::graph_index::{normalize_project_root, ProjectIndex, SymbolEntry};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallGraph {
11    pub project_root: String,
12    pub edges: Vec<CallEdge>,
13    pub file_hashes: HashMap<String, String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallEdge {
18    pub caller_file: String,
19    pub caller_symbol: String,
20    pub caller_line: usize,
21    pub callee_name: String,
22}
23
24impl CallGraph {
25    pub fn new(project_root: &str) -> Self {
26        Self {
27            project_root: normalize_project_root(project_root),
28            edges: Vec::new(),
29            file_hashes: HashMap::new(),
30        }
31    }
32
33    pub fn build(index: &ProjectIndex) -> Self {
34        let project_root = &index.project_root;
35        let mut graph = Self::new(project_root);
36
37        let symbols_by_file = group_symbols_by_file(index);
38
39        for rel_path in index.files.keys() {
40            let abs_path = resolve_path(rel_path, project_root);
41            let Ok(content) = std::fs::read_to_string(&abs_path) else {
42                continue;
43            };
44
45            let hash = simple_hash(&content);
46            graph.file_hashes.insert(rel_path.clone(), hash);
47
48            let ext = Path::new(rel_path)
49                .extension()
50                .and_then(|e| e.to_str())
51                .unwrap_or("");
52
53            let analysis = deep_queries::analyze(&content, ext);
54            let file_symbols = symbols_by_file.get(rel_path.as_str());
55
56            for call in &analysis.calls {
57                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
58                graph.edges.push(CallEdge {
59                    caller_file: rel_path.clone(),
60                    caller_symbol: caller_sym,
61                    caller_line: call.line + 1,
62                    callee_name: call.callee.clone(),
63                });
64            }
65        }
66
67        graph
68    }
69
70    pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
71        let project_root = &index.project_root;
72        let mut graph = Self::new(project_root);
73        let symbols_by_file = group_symbols_by_file(index);
74
75        for rel_path in index.files.keys() {
76            let abs_path = resolve_path(rel_path, project_root);
77            let Ok(content) = std::fs::read_to_string(&abs_path) else {
78                continue;
79            };
80
81            let hash = simple_hash(&content);
82            let changed = previous.file_hashes.get(rel_path) != Some(&hash);
83
84            graph.file_hashes.insert(rel_path.clone(), hash);
85
86            if !changed {
87                let old_edges: Vec<_> = previous
88                    .edges
89                    .iter()
90                    .filter(|e| e.caller_file == rel_path.as_str())
91                    .cloned()
92                    .collect();
93                graph.edges.extend(old_edges);
94                continue;
95            }
96
97            let ext = Path::new(rel_path)
98                .extension()
99                .and_then(|e| e.to_str())
100                .unwrap_or("");
101
102            let analysis = deep_queries::analyze(&content, ext);
103            let file_symbols = symbols_by_file.get(rel_path.as_str());
104
105            for call in &analysis.calls {
106                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
107                graph.edges.push(CallEdge {
108                    caller_file: rel_path.clone(),
109                    caller_symbol: caller_sym,
110                    caller_line: call.line + 1,
111                    callee_name: call.callee.clone(),
112                });
113            }
114        }
115
116        graph
117    }
118
119    pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
120        let sym_lower = symbol.to_lowercase();
121        self.edges
122            .iter()
123            .filter(|e| e.callee_name.to_lowercase() == sym_lower)
124            .collect()
125    }
126
127    pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
128        let sym_lower = symbol.to_lowercase();
129        self.edges
130            .iter()
131            .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
132            .collect()
133    }
134
135    pub fn save(&self) -> Result<(), String> {
136        let dir = call_graph_dir(&self.project_root)
137            .ok_or_else(|| "Cannot determine home directory".to_string())?;
138        std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
139        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
140        std::fs::write(dir.join("call_graph.json"), json).map_err(|e| e.to_string())
141    }
142
143    pub fn load(project_root: &str) -> Option<Self> {
144        let dir = call_graph_dir(project_root)?;
145        let path = dir.join("call_graph.json");
146        let content = std::fs::read_to_string(path).ok()?;
147        serde_json::from_str(&content).ok()
148    }
149
150    pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
151        if let Some(previous) = Self::load(project_root) {
152            Self::build_incremental(index, &previous)
153        } else {
154            Self::build(index)
155        }
156    }
157}
158
159fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
160    ProjectIndex::index_dir(project_root)
161}
162
163fn group_symbols_by_file(index: &ProjectIndex) -> HashMap<&str, Vec<&SymbolEntry>> {
164    let mut map: HashMap<&str, Vec<&SymbolEntry>> = HashMap::new();
165    for sym in index.symbols.values() {
166        map.entry(sym.file.as_str()).or_default().push(sym);
167    }
168    for syms in map.values_mut() {
169        syms.sort_by_key(|s| s.start_line);
170    }
171    map
172}
173
174fn find_enclosing_symbol(file_symbols: Option<&Vec<&SymbolEntry>>, line: usize) -> String {
175    let Some(syms) = file_symbols else {
176        return "<module>".to_string();
177    };
178
179    let mut best: Option<&SymbolEntry> = None;
180    for sym in syms {
181        if line >= sym.start_line && line <= sym.end_line {
182            match best {
183                None => best = Some(sym),
184                Some(prev) => {
185                    let prev_span = prev.end_line - prev.start_line;
186                    let cur_span = sym.end_line - sym.start_line;
187                    if cur_span < prev_span {
188                        best = Some(sym);
189                    }
190                }
191            }
192        }
193    }
194
195    best.map_or_else(|| "<module>".to_string(), |s| s.name.clone())
196}
197
/// Turns an index-relative path into a filesystem path under `project_root`.
///
/// An absolute path that exists is returned verbatim. Anything else has
/// leading `/` and `\` separators stripped and is joined onto `project_root`.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    let keep_as_is = candidate.is_absolute() && candidate.exists();

    if !keep_as_is {
        // Rooted-but-relative inputs such as `\src\x.kt` would otherwise make
        // `join` discard (or misplace) the project root.
        let stripped = relative.trim_start_matches(['/', '\\']);
        let joined = Path::new(project_root).join(stripped);
        return joined.to_string_lossy().into_owned();
    }

    relative.to_owned()
}
207
/// Hashes `content` with the std `DefaultHasher` and renders the 64-bit
/// digest as lowercase hex; used to detect unchanged files between builds.
///
/// NOTE(review): std does not promise `DefaultHasher`'s algorithm is stable
/// across Rust releases, and these hashes are persisted by `save`. A
/// toolchain upgrade may therefore force a full re-analysis — confirm
/// that is acceptable.
fn simple_hash(content: &str) -> String {
    use std::hash::{Hash, Hasher};

    let mut state = std::collections::hash_map::DefaultHasher::new();
    content.hash(&mut state);
    let digest: u64 = state.finish();
    format!("{:x}", digest)
}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a call edge fixture for the query tests.
    fn edge(file: &str, caller: &str, line: usize, callee: &str) -> CallEdge {
        CallEdge {
            caller_file: file.to_string(),
            caller_symbol: caller.to_string(),
            caller_line: line,
            callee_name: callee.to_string(),
        }
    }

    /// Builds a symbol fixture spanning `start_line..=end_line` in `a.rs`.
    fn symbol(name: &str, kind: &str, start_line: usize, end_line: usize) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported: false,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "bar", 10, "foo"));
        graph.edges.push(edge("b.rs", "baz", 20, "foo"));
        graph.edges.push(edge("c.rs", "qux", 30, "other"));
        let callers = graph.callers_of("foo");
        assert_eq!(callers.len(), 2);
    }

    #[test]
    fn callers_of_is_case_insensitive() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "bar", 10, "DoWork"));
        assert_eq!(graph.callers_of("dowork").len(), 1);
        assert_eq!(graph.callers_of("DOWORK").len(), 1);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges.push(edge("a.rs", "main", 5, "init"));
        graph.edges.push(edge("a.rs", "main", 6, "run"));
        graph.edges.push(edge("a.rs", "other", 15, "init"));
        let callees = graph.callees_of("main");
        assert_eq!(callees.len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = symbol("Outer", "struct", 1, 50);
        let inner = symbol("inner_fn", "fn", 10, 20);
        let syms = vec![&outer, &inner];
        let result = find_enclosing_symbol(Some(&syms), 15);
        assert_eq!(result, "inner_fn");
    }

    #[test]
    fn find_enclosing_includes_range_boundaries() {
        let outer = symbol("Outer", "struct", 1, 50);
        let inner = symbol("inner_fn", "fn", 10, 20);
        let syms = vec![&outer, &inner];
        assert_eq!(find_enclosing_symbol(Some(&syms), 10), "inner_fn");
        assert_eq!(find_enclosing_symbol(Some(&syms), 20), "inner_fn");
        assert_eq!(find_enclosing_symbol(Some(&syms), 21), "Outer");
    }

    #[test]
    fn find_enclosing_prefers_first_on_equal_span() {
        let first = symbol("first", "fn", 10, 20);
        let second = symbol("second", "fn", 10, 20);
        let syms = vec![&first, &second];
        assert_eq!(find_enclosing_symbol(Some(&syms), 15), "first");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let sym = symbol("foo", "fn", 10, 20);
        let syms = vec![&sym];
        let result = find_enclosing_symbol(Some(&syms), 5);
        assert_eq!(result, "<module>");
    }

    #[test]
    fn find_enclosing_returns_module_without_symbols() {
        assert_eq!(find_enclosing_symbol(None, 1), "<module>");
    }

    #[test]
    fn resolve_path_joins_plain_relative_path() {
        let resolved = resolve_path("src/lib.rs", "/repo");
        assert_eq!(
            resolved,
            Path::new("/repo").join("src/lib.rs").to_string_lossy().to_string()
        );
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
        assert_eq!(
            resolved,
            Path::new(r"C:\repo")
                .join(r"src\main\kotlin\Example.kt")
                .to_string_lossy()
                .to_string()
        );
    }
}