1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6use super::deep_queries;
7use super::graph_index::{ProjectIndex, SymbolEntry};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallGraph {
11 pub project_root: String,
12 pub edges: Vec<CallEdge>,
13 pub file_hashes: HashMap<String, String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallEdge {
18 pub caller_file: String,
19 pub caller_symbol: String,
20 pub caller_line: usize,
21 pub callee_name: String,
22}
23
24impl CallGraph {
25 pub fn new(project_root: &str) -> Self {
26 Self {
27 project_root: project_root.to_string(),
28 edges: Vec::new(),
29 file_hashes: HashMap::new(),
30 }
31 }
32
33 pub fn build(index: &ProjectIndex) -> Self {
34 let project_root = &index.project_root;
35 let mut graph = Self::new(project_root);
36
37 let symbols_by_file = group_symbols_by_file(index);
38
39 for rel_path in index.files.keys() {
40 let abs_path = resolve_path(rel_path, project_root);
41 let content = match std::fs::read_to_string(&abs_path) {
42 Ok(c) => c,
43 Err(_) => continue,
44 };
45
46 let hash = simple_hash(&content);
47 graph.file_hashes.insert(rel_path.clone(), hash);
48
49 let ext = Path::new(rel_path)
50 .extension()
51 .and_then(|e| e.to_str())
52 .unwrap_or("");
53
54 let analysis = deep_queries::analyze(&content, ext);
55 let file_symbols = symbols_by_file.get(rel_path.as_str());
56
57 for call in &analysis.calls {
58 let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
59 graph.edges.push(CallEdge {
60 caller_file: rel_path.clone(),
61 caller_symbol: caller_sym,
62 caller_line: call.line + 1,
63 callee_name: call.callee.clone(),
64 });
65 }
66 }
67
68 graph
69 }
70
71 pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
72 let project_root = &index.project_root;
73 let mut graph = Self::new(project_root);
74 let symbols_by_file = group_symbols_by_file(index);
75
76 for rel_path in index.files.keys() {
77 let abs_path = resolve_path(rel_path, project_root);
78 let content = match std::fs::read_to_string(&abs_path) {
79 Ok(c) => c,
80 Err(_) => continue,
81 };
82
83 let hash = simple_hash(&content);
84 let changed = previous
85 .file_hashes
86 .get(rel_path)
87 .map(|old| old != &hash)
88 .unwrap_or(true);
89
90 graph.file_hashes.insert(rel_path.clone(), hash);
91
92 if !changed {
93 let old_edges: Vec<_> = previous
94 .edges
95 .iter()
96 .filter(|e| e.caller_file == rel_path.as_str())
97 .cloned()
98 .collect();
99 graph.edges.extend(old_edges);
100 continue;
101 }
102
103 let ext = Path::new(rel_path)
104 .extension()
105 .and_then(|e| e.to_str())
106 .unwrap_or("");
107
108 let analysis = deep_queries::analyze(&content, ext);
109 let file_symbols = symbols_by_file.get(rel_path.as_str());
110
111 for call in &analysis.calls {
112 let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
113 graph.edges.push(CallEdge {
114 caller_file: rel_path.clone(),
115 caller_symbol: caller_sym,
116 caller_line: call.line + 1,
117 callee_name: call.callee.clone(),
118 });
119 }
120 }
121
122 graph
123 }
124
125 pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
126 let sym_lower = symbol.to_lowercase();
127 self.edges
128 .iter()
129 .filter(|e| e.callee_name.to_lowercase() == sym_lower)
130 .collect()
131 }
132
133 pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
134 let sym_lower = symbol.to_lowercase();
135 self.edges
136 .iter()
137 .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
138 .collect()
139 }
140
141 pub fn save(&self) -> Result<(), String> {
142 let dir = call_graph_dir(&self.project_root)
143 .ok_or_else(|| "Cannot determine home directory".to_string())?;
144 std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
145 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
146 std::fs::write(dir.join("call_graph.json"), json).map_err(|e| e.to_string())
147 }
148
149 pub fn load(project_root: &str) -> Option<Self> {
150 let dir = call_graph_dir(project_root)?;
151 let path = dir.join("call_graph.json");
152 let content = std::fs::read_to_string(path).ok()?;
153 serde_json::from_str(&content).ok()
154 }
155
156 pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
157 if let Some(previous) = Self::load(project_root) {
158 Self::build_incremental(index, &previous)
159 } else {
160 Self::build(index)
161 }
162 }
163}
164
/// Directory in which `call_graph.json` is stored for `project_root`.
///
/// This is simply the project's index directory as returned by
/// `ProjectIndex::index_dir`. `None` is propagated unchanged; presumably it means
/// the home directory could not be determined, matching the error text in
/// `CallGraph::save` (TODO confirm against `index_dir`).
fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
    ProjectIndex::index_dir(project_root)
}
168
169fn group_symbols_by_file(index: &ProjectIndex) -> HashMap<&str, Vec<&SymbolEntry>> {
170 let mut map: HashMap<&str, Vec<&SymbolEntry>> = HashMap::new();
171 for sym in index.symbols.values() {
172 map.entry(sym.file.as_str()).or_default().push(sym);
173 }
174 for syms in map.values_mut() {
175 syms.sort_by_key(|s| s.start_line);
176 }
177 map
178}
179
180fn find_enclosing_symbol(file_symbols: Option<&Vec<&SymbolEntry>>, line: usize) -> String {
181 let syms = match file_symbols {
182 Some(s) => s,
183 None => return "<module>".to_string(),
184 };
185
186 let mut best: Option<&SymbolEntry> = None;
187 for sym in syms {
188 if line >= sym.start_line && line <= sym.end_line {
189 match best {
190 None => best = Some(sym),
191 Some(prev) => {
192 let prev_span = prev.end_line - prev.start_line;
193 let cur_span = sym.end_line - sym.start_line;
194 if cur_span < prev_span {
195 best = Some(sym);
196 }
197 }
198 }
199 }
200 }
201
202 best.map(|s| s.name.clone())
203 .unwrap_or_else(|| "<module>".to_string())
204}
205
/// Turn an index path into a path that can be opened.
///
/// An absolute path that exists on disk is returned untouched. Anything else is
/// joined onto `project_root`.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    if candidate.is_absolute() && candidate.exists() {
        relative.to_owned()
    } else {
        Path::new(project_root)
            .join(candidate)
            .to_string_lossy()
            .into_owned()
    }
}
214
/// Hex-encoded 64-bit FNV-1a hash of `content`, used to detect unchanged files.
///
/// These hashes are stored in the persisted graph and compared on a later run,
/// so the algorithm must be stable across processes and toolchain versions.
/// `std`'s `DefaultHasher` explicitly does not guarantee stability across
/// releases, so a fully specified algorithm is used instead. Switching
/// algorithms only makes every file look changed once, forcing a single full
/// re-analysis. FNV-1a is not collision-resistant against adversarial input;
/// it only needs to notice ordinary edits.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:x}", hash)
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`CallEdge`] from borrowed parts.
    fn edge(file: &str, caller: &str, line: usize, callee: &str) -> CallEdge {
        CallEdge {
            caller_file: file.to_string(),
            caller_symbol: caller.to_string(),
            caller_line: line,
            callee_name: callee.to_string(),
        }
    }

    /// Build a [`SymbolEntry`] located in `a.rs`.
    fn symbol(
        name: &str,
        kind: &str,
        start_line: usize,
        end_line: usize,
        is_exported: bool,
    ) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        let graph = CallGraph::new("/tmp");
        assert!(graph.callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = vec![
            edge("a.rs", "bar", 10, "foo"),
            edge("b.rs", "baz", 20, "foo"),
            edge("c.rs", "qux", 30, "other"),
        ];
        assert_eq!(graph.callers_of("foo").len(), 2);
    }

    #[test]
    fn callees_of_finds_edges() {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = vec![
            edge("a.rs", "main", 5, "init"),
            edge("a.rs", "main", 6, "run"),
            edge("a.rs", "other", 15, "init"),
        ];
        assert_eq!(graph.callees_of("main").len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = symbol("Outer", "struct", 1, 50, true);
        let inner = symbol("inner_fn", "fn", 10, 20, false);
        let syms = vec![&outer, &inner];
        assert_eq!(find_enclosing_symbol(Some(&syms), 15), "inner_fn");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let only = symbol("foo", "fn", 10, 20, false);
        let syms = vec![&only];
        assert_eq!(find_enclosing_symbol(Some(&syms), 5), "<module>");
    }
}