1use std::collections::HashMap;
2use std::path::Path;
3
4use serde::{Deserialize, Serialize};
5
6use super::deep_queries;
7use super::graph_index::{ProjectIndex, SymbolEntry};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallGraph {
11 pub project_root: String,
12 pub edges: Vec<CallEdge>,
13 pub file_hashes: HashMap<String, String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallEdge {
18 pub caller_file: String,
19 pub caller_symbol: String,
20 pub caller_line: usize,
21 pub callee_name: String,
22}
23
24impl CallGraph {
25 pub fn new(project_root: &str) -> Self {
26 Self {
27 project_root: project_root.to_string(),
28 edges: Vec::new(),
29 file_hashes: HashMap::new(),
30 }
31 }
32
33 pub fn build(index: &ProjectIndex) -> Self {
34 let project_root = &index.project_root;
35 let mut graph = Self::new(project_root);
36
37 let symbols_by_file = group_symbols_by_file(index);
38
39 for rel_path in index.files.keys() {
40 let abs_path = resolve_path(rel_path, project_root);
41 let content = match std::fs::read_to_string(&abs_path) {
42 Ok(c) => c,
43 Err(_) => continue,
44 };
45
46 let hash = simple_hash(&content);
47 graph.file_hashes.insert(rel_path.clone(), hash);
48
49 let ext = Path::new(rel_path)
50 .extension()
51 .and_then(|e| e.to_str())
52 .unwrap_or("");
53
54 let analysis = deep_queries::analyze(&content, ext);
55 let file_symbols = symbols_by_file.get(rel_path.as_str());
56
57 for call in &analysis.calls {
58 let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
59 graph.edges.push(CallEdge {
60 caller_file: rel_path.clone(),
61 caller_symbol: caller_sym,
62 caller_line: call.line + 1,
63 callee_name: call.callee.clone(),
64 });
65 }
66 }
67
68 graph
69 }
70
71 pub fn build_incremental(index: &ProjectIndex, previous: &CallGraph) -> Self {
72 let project_root = &index.project_root;
73 let mut graph = Self::new(project_root);
74 let symbols_by_file = group_symbols_by_file(index);
75
76 for rel_path in index.files.keys() {
77 let abs_path = resolve_path(rel_path, project_root);
78 let content = match std::fs::read_to_string(&abs_path) {
79 Ok(c) => c,
80 Err(_) => continue,
81 };
82
83 let hash = simple_hash(&content);
84 let changed = previous
85 .file_hashes
86 .get(rel_path)
87 .map(|old| old != &hash)
88 .unwrap_or(true);
89
90 graph.file_hashes.insert(rel_path.clone(), hash);
91
92 if !changed {
93 let old_edges: Vec<_> = previous
94 .edges
95 .iter()
96 .filter(|e| e.caller_file == rel_path.as_str())
97 .cloned()
98 .collect();
99 graph.edges.extend(old_edges);
100 continue;
101 }
102
103 let ext = Path::new(rel_path)
104 .extension()
105 .and_then(|e| e.to_str())
106 .unwrap_or("");
107
108 let analysis = deep_queries::analyze(&content, ext);
109 let file_symbols = symbols_by_file.get(rel_path.as_str());
110
111 for call in &analysis.calls {
112 let caller_sym = find_enclosing_symbol(file_symbols, call.line + 1);
113 graph.edges.push(CallEdge {
114 caller_file: rel_path.clone(),
115 caller_symbol: caller_sym,
116 caller_line: call.line + 1,
117 callee_name: call.callee.clone(),
118 });
119 }
120 }
121
122 graph
123 }
124
125 pub fn callers_of(&self, symbol: &str) -> Vec<&CallEdge> {
126 let sym_lower = symbol.to_lowercase();
127 self.edges
128 .iter()
129 .filter(|e| e.callee_name.to_lowercase() == sym_lower)
130 .collect()
131 }
132
133 pub fn callees_of(&self, symbol: &str) -> Vec<&CallEdge> {
134 let sym_lower = symbol.to_lowercase();
135 self.edges
136 .iter()
137 .filter(|e| e.caller_symbol.to_lowercase() == sym_lower)
138 .collect()
139 }
140
141 pub fn save(&self) -> Result<(), String> {
142 let dir = call_graph_dir(&self.project_root)
143 .ok_or_else(|| "Cannot determine home directory".to_string())?;
144 std::fs::create_dir_all(&dir).map_err(|e| e.to_string())?;
145 let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
146 std::fs::write(dir.join("call_graph.json"), json).map_err(|e| e.to_string())
147 }
148
149 pub fn load(project_root: &str) -> Option<Self> {
150 let dir = call_graph_dir(project_root)?;
151 let path = dir.join("call_graph.json");
152 let content = std::fs::read_to_string(path).ok()?;
153 serde_json::from_str(&content).ok()
154 }
155
156 pub fn load_or_build(project_root: &str, index: &ProjectIndex) -> Self {
157 if let Some(previous) = Self::load(project_root) {
158 Self::build_incremental(index, &previous)
159 } else {
160 Self::build(index)
161 }
162 }
163}
164
165fn call_graph_dir(project_root: &str) -> Option<std::path::PathBuf> {
166 ProjectIndex::index_dir(project_root)
167}
168
169fn group_symbols_by_file(index: &ProjectIndex) -> HashMap<&str, Vec<&SymbolEntry>> {
170 let mut map: HashMap<&str, Vec<&SymbolEntry>> = HashMap::new();
171 for sym in index.symbols.values() {
172 map.entry(sym.file.as_str()).or_default().push(sym);
173 }
174 for syms in map.values_mut() {
175 syms.sort_by_key(|s| s.start_line);
176 }
177 map
178}
179
180fn find_enclosing_symbol(file_symbols: Option<&Vec<&SymbolEntry>>, line: usize) -> String {
181 let syms = match file_symbols {
182 Some(s) => s,
183 None => return "<module>".to_string(),
184 };
185
186 let mut best: Option<&SymbolEntry> = None;
187 for sym in syms {
188 if line >= sym.start_line && line <= sym.end_line {
189 match best {
190 None => best = Some(sym),
191 Some(prev) => {
192 let prev_span = prev.end_line - prev.start_line;
193 let cur_span = sym.end_line - sym.start_line;
194 if cur_span < prev_span {
195 best = Some(sym);
196 }
197 }
198 }
199 }
200 }
201
202 best.map(|s| s.name.clone())
203 .unwrap_or_else(|| "<module>".to_string())
204}
205
/// Turns an index path into a path that can be opened.
///
/// An absolute path that exists on disk is returned unchanged. Anything else
/// has its leading `/` and `\` characters stripped and is joined onto
/// `project_root`. This also covers "rooted relative" paths such as
/// `\src\main.kt` on Windows.
fn resolve_path(relative: &str, project_root: &str) -> String {
    let candidate = Path::new(relative);
    if candidate.is_absolute() && candidate.exists() {
        relative.to_owned()
    } else {
        let stripped = relative.trim_start_matches(|c: char| c == '/' || c == '\\');
        Path::new(project_root)
            .join(stripped)
            .to_string_lossy()
            .into_owned()
    }
}
215
/// Fingerprints file content so unchanged files can be detected between runs.
///
/// The hash is stored in `call_graph.json` and compared on later runs, so the
/// algorithm must be stable. It uses 64-bit FNV-1a, which is fully specified,
/// instead of `DefaultHasher`, whose algorithm std explicitly does not
/// guarantee across Rust releases. The output is always 16 lowercase hex
/// digits. Not cryptographic: it only needs accidental collisions between
/// edits to be unlikely.
fn simple_hash(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:016x}", hash)
}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`CallEdge`] from borrowed parts.
    fn edge(file: &str, caller: &str, line: usize, callee: &str) -> CallEdge {
        CallEdge {
            caller_file: file.to_string(),
            caller_symbol: caller.to_string(),
            caller_line: line,
            callee_name: callee.to_string(),
        }
    }

    /// Returns a graph rooted at "/tmp" containing exactly `edges`.
    fn graph_with(edges: Vec<CallEdge>) -> CallGraph {
        let mut graph = CallGraph::new("/tmp");
        graph.edges = edges;
        graph
    }

    /// Builds a [`SymbolEntry`] in "a.rs".
    fn symbol(
        name: &str,
        kind: &str,
        start_line: usize,
        end_line: usize,
        is_exported: bool,
    ) -> SymbolEntry {
        SymbolEntry {
            file: "a.rs".to_string(),
            name: name.to_string(),
            kind: kind.to_string(),
            start_line,
            end_line,
            is_exported,
        }
    }

    #[test]
    fn callers_of_empty_graph() {
        assert!(CallGraph::new("/tmp").callers_of("foo").is_empty());
    }

    #[test]
    fn callers_of_finds_edges() {
        let graph = graph_with(vec![
            edge("a.rs", "bar", 10, "foo"),
            edge("b.rs", "baz", 20, "foo"),
            edge("c.rs", "qux", 30, "other"),
        ]);
        assert_eq!(graph.callers_of("foo").len(), 2);
    }

    #[test]
    fn callees_of_finds_edges() {
        let graph = graph_with(vec![
            edge("a.rs", "main", 5, "init"),
            edge("a.rs", "main", 6, "run"),
            edge("a.rs", "other", 15, "init"),
        ]);
        assert_eq!(graph.callees_of("main").len(), 2);
    }

    #[test]
    fn find_enclosing_picks_narrowest() {
        let outer = symbol("Outer", "struct", 1, 50, true);
        let inner = symbol("inner_fn", "fn", 10, 20, false);
        let candidates = vec![&outer, &inner];
        assert_eq!(find_enclosing_symbol(Some(&candidates), 15), "inner_fn");
    }

    #[test]
    fn find_enclosing_returns_module_when_no_match() {
        let only = symbol("foo", "fn", 10, 20, false);
        let candidates = vec![&only];
        assert_eq!(find_enclosing_symbol(Some(&candidates), 5), "<module>");
    }

    #[test]
    fn resolve_path_trims_rooted_relative_prefix() {
        let expected = Path::new(r"C:\repo")
            .join(r"src\main\kotlin\Example.kt")
            .to_string_lossy()
            .to_string();
        let resolved = resolve_path(r"\src\main\kotlin\Example.kt", r"C:\repo");
        assert_eq!(resolved, expected);
    }
}