Skip to main content

lean_ctx/core/
call_graph.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6use super::deep_queries;
7use super::graph_index::{ProjectIndex, SymbolEntry};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallGraph {
11    pub project_root: String,
12    pub edges: Vec<CallEdge>,
13    pub file_hashes: HashMap<String, String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallEdge {
18    pub caller_file: String,
19    pub caller_symbol: String,
20    pub caller_line: usize,
21    pub callee_name: String,
22}
23
24impl CallGraph {
25    pub fn new(project_root: &str) -> Self {
26        Self {
27            project_root: project_root.to_string(),
28            edges: Vec::new(),
29            file_hashes: HashMap::new(),
30        }
31    }
32
33    pub fn build(index: &ProjectIndex) -> Self {
34        let project_root = &index.project_root;
35        let mut graph = Self::new(project_root);
36
37        let symbols_by_file = group_symbols_by_file(index);
38
39        for rel_path in index.files.keys() {
40            let abs_path = resolve_path(rel_path, project_root);
41            let content = match std::fs::read_to_string(&abs_path) {
42                Ok(c) => c,
43                Err(_) => continue,
44            };
45
46            let hash = simple_hash(&content);
47            graph.file_hashes.insert(rel_path.clone(), hash);
48
49            let ext = Path::new(rel_path)
50                .extension()
51                .and_then(|e| e.to_str())
52                .unwrap_or("");
53
54            let analysis = deep_queries::analyze(&content, ext);
55            let file_symbols = symbols_by_file.get(rel_path.as_str());
56
57            for call in &analysis.calls {
58                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
59                graph.edges.push(CallEdge {
60                    caller_file: rel_path.clone(),
61                    caller_symbol: caller_sym,
62                    caller_line: call.line + 1,
63                    callee_name: call.callee.clone(),
64                });
65            }
66        }
67
68        graph
69    }
70
71    pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
72        let project_root = &index.project_root;
73        let mut graph = Self::new(project_root);
74        let symbols_by_file = group_symbols_by_file(index);
75
76        for rel_path in index.files.keys() {
77            let abs_path = resolve_path(rel_path, project_root);
78            let content = match std::fs::read_to_string(&abs_path) {
79                Ok(c) => c,
80                Err(_) => continue,
81            };
82
83            let hash = simple_hash(&content);
84            let changed = previous
85                .file_hashes
86                .get(rel_path)
87                .map(|old| old != &hash)
88                .unwrap_or(true);
89
90            graph.file_hashes.insert(rel_path.clone(), hash);
91
92            if !changed {
93                let old_edges: Vec<_> = previous
94                    .edges
95                    .iter()
96                    .filter(|e| e.caller_file == rel_path.as_str())
97                    .cloned()
98                    .collect();
99                graph.edges.extend(old_edges);
100                continue;
101            }
102
103            let ext = Path::new(rel_path)
104                .extension()
105                .and_then(|e| e.to_str())
106                .unwrap_or("");
107
108            let analysis = deep_queries::analyze(&content, ext);
109            let file_symbols = symbols_by_file.get(rel_path.as_str());
110
111            for call in &analysis.calls {
112                let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
113                graph.edges.push(CallEdge {
114                    caller_file: rel_path.clone(),
115                    caller_symbol: caller_sym,
116                    caller_line: call.line + 1,
117                    callee_name: call.callee.clone(),
118                });
119            }
120        }
121
122        graph
123    }
124
125    pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
126        let sym_lower = symbol.to_lowercase();
127        self.edges
128            .iter()
129            .filter(|e| e.callee_name.to_lowercase() == sym_lower)
130            .collect()
131    }
132
133    pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
134        let sym_lower = symbol.to_lowercase();
135        self.edges
136            .iter()
137            .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
138            .collect()
139    }
140
141    pub fn save(&self) -> Result<(), String> {
142        let dir = call_graph_dir(&self.project_root)
143            .ok_or_else(|| "Cannot determine home directory".to_string())?;
144        std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
145        let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
146        std::fs::write(dir.join("call_graph.json"), json).map_err(|e| e.to_string())
147    }
148
149    pub fn load(project_root: &str) -> Option<Self> {
150        let dir = call_graph_dir(project_root)?;
151        let path = dir.join("call_graph.json");
152        let content = std::fs::read_to_string(path).ok()?;
153        serde_json::from_str(&content).ok()
154    }
155
156    pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
157        if let Some(previous) = Self::load(project_root) {
158            Self::build_incremental(index, &previous)
159        } else {
160            Self::build(index)
161        }
162    }
163}
164
165fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
166    ProjectIndex::index_dir(project_root)
167}
168
169fn group_symbols_by_file(index: &ProjectIndex) -> HashMap<&str, Vec<&SymbolEntry>> {
170    let mut map: HashMap<&str, Vec<&SymbolEntry>> = HashMap::new();
171    for sym in index.symbols.values() {
172        map.entry(sym.file.as_str()).or_default().push(sym);
173    }
174    for syms in map.values_mut() {
175        syms.sort_by_key(|s| s.start_line);
176    }
177    map
178}
179
180fn find_enclosing_symbol(file_symbols: Option<&Vec<&SymbolEntry>>, line: usize) -> String {
181    let syms = match file_symbols {
182        Some(s) => s,
183        None => return "<module>".to_string(),
184    };
185
186    let mut best: Option<&SymbolEntry> = None;
187    for sym in syms {
188        if line >= sym.start_line && line <= sym.end_line {
189            match best {
190                None => best = Some(sym),
191                Some(prev) => {
192                    let prev_span = prev.end_line - prev.start_line;
193                    let cur_span = sym.end_line - sym.start_line;
194                    if cur_span < prev_span {
195                        best = Some(sym);
196                    }
197                }
198            }
199        }
200    }
201
202    best.map(|s| s.name.clone())
203        .unwrap_or_else(|| "<module>".to_string())
204}
205
/// Resolves a path from the project index to the path to read from disk.
///
/// Relative paths are joined onto `project_root`. Absolute paths come back unchanged, because
/// `Path::join` replaces the base when its argument is absolute. No separate check is needed,
/// and in particular no `exists()` filesystem call per file.
fn resolve_path(relative: &str, project_root: &str) -> String {
    Path::new(project_root)
        .join(relative)
        .to_string_lossy()
        .into_owned()
}
214
/// Returns a hex content fingerprint used to detect files that changed between builds.
///
/// The fingerprint is stored in `CallGraph::file_hashes`, which is saved to `call_graph.json`,
/// so it must stay stable across processes and toolchains. `DefaultHasher` explicitly does not
/// promise that; this uses 64-bit FNV-1a, whose constants are fixed by its specification.
/// Hashes saved by older builds will not match, which only triggers a one-time full re-analysis.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:x}", hash)
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `CallEdge` literal.
    fn edge(caller_file: &str, caller_symbol: &str, caller_line: usize, callee_name: &str) -> CallEdge {
        CallEdge {
            caller_file: caller_file.to_string(),
            caller_symbol: caller_symbol.to_string(),
            caller_line,
            callee_name: callee_name.to_string(),
        }
    }

    /// Shorthand for a symbol declared in `a.rs`.
    fn symbol(
        name: &str,
        kind: &str,
        start_line: usize,
        end_line: usize,
        is_exported: bool,
    ) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported,
        }
    }

    /// A graph rooted at `/tmp` containing exactly `edges`.
    fn graph_with(edges: Vec<CallEdge>) -> CallGraph {
        CallGraph {
            edges,
            ..CallGraph::new("/tmp")
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = graph_with(Vec::new());
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let graph = graph_with(vec![
            edge("a.rs", "bar", 10, "foo"),
            edge("b.rs", "baz", 20, "foo"),
            edge("c.rs", "qux", 30, "other"),
        ]);
        assert_eq!(graph.callers_of("foo").len(), 2);
    }

    #[test]
    fn callees_of_finds_edges() {
        let graph = graph_with(vec![
            edge("a.rs", "main", 5, "init"),
            edge("a.rs", "main", 6, "run"),
            edge("a.rs", "other", 15, "init"),
        ]);
        assert_eq!(graph.callees_of("main").len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = symbol("Outer", "struct", 1, 50, true);
        let inner = symbol("inner_fn", "fn", 10, 20, false);
        let candidates = vec![&outer, &inner];
        assert_eq!(find_enclosing_symbol(Some(&candidates), 15), "inner_fn");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let only = symbol("foo", "fn", 10, 20, false);
        let candidates = vec![&only];
        assert_eq!(find_enclosing_symbol(Some(&candidates), 5), "<module>");
    }
}