Skip to main content

lean_ctx/core/
call_graph.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6use super::deep_queries;
7use super::graph_index::{ProjectIndex, SymbolEntry};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallGraph {
11    pub project_root: String,
12    pub edges: Vec<CallEdge>,
13    pub file_hashes: HashMap<String, String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallEdge {
18    pub caller_file: String,
19    pub caller_symbol: String,
20    pub caller_line: usize,
21    pub callee_name: String,
22}
23
24impl CallGraph {
25    pub fn new(project_root: &str) -> Self {
26        Self {
27            project_root: project_root.to_string(),
28            edges: Vec::new(),
29            file_hashes: HashMap::new(),
30        }
31    }
32
33    pub fn build(index: &ProjectIndex) -> Self {
34        let project_root = &index.project_root;
35        let mut graph = Self::new(project_root);
36
37        let symbols_by_file = group_symbols_by_file(index);
38
39        for rel_path in index.files.keys() {
40            let abs_path = resolve_path(rel_path, project_root);
41            let content = match std::fs::read_to_string(&abs_path) {
42                Ok(c) => c,
43                Err(_) => continue,
44            };
45
46            let hash = simple_hash(&content);
47            graph.file_hashes.insert(rel_path.clone(), hash);
48
49            let ext = Path::new(rel_path)
50                .extension()
51                .and_then(|e| e.to_str())
52                .unwrap_or("");
53
54            let analysis = deep_queries::analyze(&content, ext);
55            let file_symbols = symbols_by_file.get(rel_path.as_str());
56
57            for call in &analysis.calls {
58                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
59                graph.edges.push(CallEdge {
60                    caller_file: rel_path.clone(),
61                    caller_symbol: caller_sym,
62                    caller_line: call.line + 1,
63                    callee_name: call.callee.clone(),
64                });
65            }
66        }
67
68        graph
69    }
70
71    pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
72        let project_root = &index.project_root;
73        let mut graph = Self::new(project_root);
74        let symbols_by_file = group_symbols_by_file(index);
75
76        for rel_path in index.files.keys() {
77            let abs_path = resolve_path(rel_path, project_root);
78            let content = match std::fs::read_to_string(&abs_path) {
79                Ok(c) => c,
80                Err(_) => continue,
81            };
82
83            let hash = simple_hash(&content);
84            let changed = previous
85                .file_hashes
86                .get(rel_path)
87                .map(|old| old != &hash)
88                .unwrap_or(true);
89
90            graph.file_hashes.insert(rel_path.clone(), hash);
91
92            if !changed {
93                let old_edges: Vec<_> = previous
94                    .edges
95                    .iter()
96                    .filter(|e| e.caller_file == rel_path.as_str())
97                    .cloned()
98                    .collect();
99                graph.edges.extend(old_edges);
100                continue;
101            }
102
103            let ext = Path::new(rel_path)
104                .extension()
105                .and_then(|e| e.to_str())
106                .unwrap_or("");
107
108            let analysis = deep_queries::analyze(&content, ext);
109            let file_symbols = symbols_by_file.get(rel_path.as_str());
110
111            for call in &analysis.calls {
112                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
113                graph.edges.push(CallEdge {
114                    caller_file: rel_path.clone(),
115                    caller_symbol: caller_sym,
116                    caller_line: call.line + 1,
117                    callee_name: call.callee.clone(),
118                });
119            }
120        }
121
122        graph
123    }
124
125    pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
126        let sym_lower = symbol.to_lowercase();
127        self.edges
128            .iter()
129            .filter(|e| e.callee_name.to_lowercase() == sym_lower)
130            .collect()
131    }
132
133    pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
134        let sym_lower = symbol.to_lowercase();
135        self.edges
136            .iter()
137            .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
138            .collect()
139    }
140
141    pub fn save(&self) -> Result<(), String> {
142        let dir = call_graph_dir(&self.project_root)
143            .ok_or_else(|| "Cannot determine home directory".to_string())?;
144        std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
145        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
146        std::fs::write(dir.join("call_graph.json"), json).map_err(|e| e.to_string())
147    }
148
149    pub fn load(project_root: &str) -> Option<Self> {
150        let dir = call_graph_dir(project_root)?;
151        let path = dir.join("call_graph.json");
152        let content = std::fs::read_to_string(path).ok()?;
153        serde_json::from_str(&content).ok()
154    }
155
156    pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
157        if let Some(previous) = Self::load(project_root) {
158            Self::build_incremental(index, &previous)
159        } else {
160            Self::build(index)
161        }
162    }
163}
164
165fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
166    ProjectIndex::index_dir(project_root)
167}
168
169fn group_symbols_by_file(index: &ProjectIndex) -> HashMap<&str, Vec<&SymbolEntry>> {
170    let mut map: HashMap<&str, Vec<&SymbolEntry>> = HashMap::new();
171    for sym in index.symbols.values() {
172        map.entry(sym.file.as_str()).or_default().push(sym);
173    }
174    for syms in map.values_mut() {
175        syms.sort_by_key(|s| s.start_line);
176    }
177    map
178}
179
180fn find_enclosing_symbol(file_symbols: Option<&Vec<&SymbolEntry>>, line: usize) -> String {
181    let syms = match file_symbols {
182        Some(s) => s,
183        None => return "<module>".to_string(),
184    };
185
186    let mut best: Option<&SymbolEntry> = None;
187    for sym in syms {
188        if line >= sym.start_line && line <= sym.end_line {
189            match best {
190                None => best = Some(sym),
191                Some(prev) => {
192                    let prev_span = prev.end_line - prev.start_line;
193                    let cur_span = sym.end_line - sym.start_line;
194                    if cur_span < prev_span {
195                        best = Some(sym);
196                    }
197                }
198            }
199        }
200    }
201
202    best.map(|s| s.name.clone())
203        .unwrap_or_else(|| "<module>".to_string())
204}
205
/// Turns an index path into a filesystem path rooted at `project_root`.
///
/// An absolute path that already exists is returned unchanged. Anything else
/// has its leading `/` and `\` separators stripped and is then joined onto
/// `project_root`. This makes rooted-relative paths such as `\src\main.kt`
/// resolve inside the project.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    if candidate.is_absolute() && candidate.exists() {
        relative.to_owned()
    } else {
        let trimmed = relative.trim_start_matches(|c: char| matches!(c, '/' | '\\'));
        Path::new(project_root)
            .join(trimmed)
            .to_string_lossy()
            .into_owned()
    }
}
215
/// Returns a lowercase-hex fingerprint of `content` for change detection.
///
/// The fingerprint is persisted in `call_graph.json` and compared on later
/// runs, possibly by a binary built with a newer compiler. `DefaultHasher`
/// documents that its algorithm may change between Rust releases, and any such
/// change would make every stored hash mismatch and force a full rebuild.
/// 64-bit FNV-1a is fully specified, so stored fingerprints remain comparable.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:x}", hash)
}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CallEdge` from borrowed parts to keep the query tests short.
    fn edge(
        caller_file: &str,
        caller_symbol: &str,
        caller_line: usize,
        callee_name: &str,
    ) -> CallEdge {
        CallEdge {
            caller_file: caller_file.to_string(),
            caller_symbol: caller_symbol.to_string(),
            caller_line,
            callee_name: callee_name.to_string(),
        }
    }

    /// Builds a `SymbolEntry` located in `a.rs`.
    fn symbol(
        name: &str,
        kind: &str,
        start_line: usize,
        end_line: usize,
        is_exported: bool,
    ) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = vec![
            edge("a.rs", "bar", 10, "foo"),
            edge("b.rs", "baz", 20, "foo"),
            edge("c.rs", "qux", 30, "other"),
        ];
        let callers = graph.callers_of("foo");
        assert_eq!(callers.len(), 2);
    }

    #[test]
    fn callers_of_is_case_insensitive() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = vec![edge("a.rs", "main", 1, "Foo"), edge("a.rs", "main", 2, "foo_bar")];
        let callers = graph.callers_of("FOO");
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].caller_line, 1);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = vec![
            edge("a.rs", "main", 5, "init"),
            edge("a.rs", "main", 6, "run"),
            edge("a.rs", "other", 15, "init"),
        ];
        let callees = graph.callees_of("main");
        assert_eq!(callees.len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = symbol("Outer", "struct", 1, 50, true);
        let inner = symbol("inner_fn", "fn", 10, 20, false);
        let syms = vec![&outer, &inner];
        let result = find_enclosing_symbol(Some(&syms), 15);
        assert_eq!(result, "inner_fn");
    }

    #[test]
    fn find_enclosing_keeps_first_on_equal_span() {
        let first = symbol("first", "fn", 10, 20, false);
        let second = symbol("second", "fn", 10, 20, false);
        let syms = vec![&first, &second];
        assert_eq!(find_enclosing_symbol(Some(&syms), 10), "first");
        assert_eq!(find_enclosing_symbol(Some(&syms), 20), "first");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let sym = symbol("foo", "fn", 10, 20, false);
        let syms = vec![&sym];
        let result = find_enclosing_symbol(Some(&syms), 5);
        assert_eq!(result, "<module>");
    }

    #[test]
    fn find_enclosing_returns_module_without_symbols() {
        assert_eq!(find_enclosing_symbol(None, 1), "<module>");
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
        assert_eq!(
            resolved,
            Path::new(r"C:\repo")
                .join(r"src\main\kotlin\Example.kt")
                .to_string_lossy()
                .to_string()
        );
    }

    #[test]
    fn resolve_path_joins_plain_relative() {
        let resolved = resolve_path("src/lib.rs", "/repo");
        assert_eq!(
            resolved,
            Path::new("/repo").join("src/lib.rs").to_string_lossy().to_string()
        );
    }

    #[test]
    fn resolve_path_keeps_existing_absolute() {
        let cwd = std::env::current_dir()
            .expect("test process has a current directory")
            .to_string_lossy()
            .to_string();
        assert_eq!(resolve_path(&cwd, "/elsewhere"), cwd);
    }

    #[test]
    fn simple_hash_is_deterministic_and_content_sensitive() {
        assert_eq!(simple_hash("fn main() {}"), simple_hash("fn main() {}"));
        assert_ne!(simple_hash("fn a() {}"), simple_hash("fn b() {}"));
    }
}