1use std::collections::{HashMap, HashSet};
26
27use tree_sitter::Node;
28
29use crate::lang::Lang;
30use crate::skeleton::{self, parse};
31
/// Key for one definition in the graph: `(file index, symbol name)`.
type Sym = (usize, String);

/// Directed "references" graph between named definitions, possibly spanning
/// several source files.
#[derive(Clone, Debug, Default)]
pub struct SymbolGraph {
    /// For each definition, the set of definitions its body refers to.
    edges: HashMap<Sym, HashSet<Sym>>,
    /// How many sources the graph was linked from (file indices are `0..file_count`).
    file_count: usize,
}
43
44impl SymbolGraph {
45 pub fn from_source(source: &str, lang: Lang) -> Self {
47 let mut defs = Vec::new();
48 collect_symbol_refs(source, lang, &mut defs);
49 SymbolGraph::link(vec![defs])
50 }
51
52 pub fn from_sources<'a>(sources: impl IntoIterator<Item = (Lang, &'a str)>) -> Self {
58 let per_file: Vec<Vec<(String, HashSet<String>)>> = sources
62 .into_iter()
63 .map(|(lang, source)| {
64 let mut defs = Vec::new();
65 collect_symbol_refs(source, lang, &mut defs);
66 defs
67 })
68 .collect();
69 SymbolGraph::link(per_file)
70 }
71
72 fn link(per_file: Vec<Vec<(String, HashSet<String>)>>) -> Self {
77 let file_count = per_file.len();
78 let mut defined_in: HashMap<&str, Vec<usize>> = HashMap::new();
80 for (fi, defs) in per_file.iter().enumerate() {
81 for (name, _) in defs {
82 defined_in.entry(name.as_str()).or_default().push(fi);
83 }
84 }
85 let mut edges: HashMap<Sym, HashSet<Sym>> = HashMap::new();
86 for (fi, defs) in per_file.iter().enumerate() {
87 let local: HashSet<&str> = defs.iter().map(|(n, _)| n.as_str()).collect();
88 for (name, refs) in defs {
89 let entry = edges.entry((fi, name.clone())).or_default();
90 for r in refs {
91 if r == name {
92 continue; }
94 if local.contains(r.as_str()) {
95 entry.insert((fi, r.clone())); } else if let Some(files) = defined_in.get(r.as_str()) {
97 for &tf in files {
98 entry.insert((tf, r.clone())); }
100 }
101 }
102 }
103 }
104 SymbolGraph { edges, file_count }
105 }
106
107 pub fn neighbors_within(&self, target: &str, radius: u8) -> HashSet<String> {
112 self.reach(target, radius)
113 .into_iter()
114 .map(|(_, name)| name)
115 .collect()
116 }
117
118 pub fn neighbors_within_by_file(&self, target: &str, radius: u8) -> Vec<HashSet<String>> {
124 let mut per_file = vec![HashSet::new(); self.file_count.max(1)];
125 for (fi, name) in self.reach(target, radius) {
126 if let Some(set) = per_file.get_mut(fi) {
127 set.insert(name);
128 }
129 }
130 per_file
131 }
132
133 fn reach(&self, target: &str, radius: u8) -> HashSet<Sym> {
136 let mut visited: HashSet<Sym> = HashSet::new();
137 let mut frontier: Vec<Sym> = Vec::new();
138 for fi in 0..self.file_count.max(1) {
140 let sym = (fi, target.to_string());
141 if self.edges.contains_key(&sym) && visited.insert(sym.clone()) {
142 frontier.push(sym);
143 }
144 }
145 if visited.is_empty() {
147 visited.insert((0, target.to_string()));
148 }
149 for _ in 0..radius {
150 let mut next = Vec::new();
151 for node in &frontier {
152 if let Some(refs) = self.edges.get(node) {
153 for r in refs {
154 if visited.insert(r.clone()) {
155 next.push(r.clone());
156 }
157 }
158 }
159 }
160 if next.is_empty() {
161 break;
162 }
163 frontier = next;
164 }
165 visited
166 }
167
168 pub fn names(&self) -> impl Iterator<Item = &str> {
170 self.edges
171 .keys()
172 .map(|(_, name)| name.as_str())
173 .collect::<HashSet<_>>()
174 .into_iter()
175 }
176}
177
178pub fn find_identifier_lines(source: &str, lang: Lang, name: &str) -> Option<Vec<(usize, String)>> {
192 let tree = parse(source, lang)?;
193 let spec = lang.spec();
194 let mut rows: HashSet<usize> = HashSet::new();
195 collect_named_ref_rows(tree.root_node(), source, &spec, name, &mut rows);
196 let src_lines: Vec<&str> = source.lines().collect();
197 let mut rows: Vec<usize> = rows.into_iter().collect();
198 rows.sort_unstable();
199 Some(
200 rows.into_iter()
201 .map(|row| {
202 let text = src_lines
203 .get(row)
204 .map(|s| s.trim())
205 .unwrap_or("")
206 .to_string();
207 (row + 1, text)
208 })
209 .collect(),
210 )
211}
212
213fn collect_named_ref_rows(
215 node: Node,
216 source: &str,
217 spec: &crate::lang::LangSpec,
218 name: &str,
219 out: &mut HashSet<usize>,
220) {
221 if spec.ref_ident_kinds.contains(&node.kind()) && text(node, source) == name {
222 out.insert(node.start_position().row);
223 }
224 let mut cursor = node.walk();
225 for child in node.children(&mut cursor) {
226 collect_named_ref_rows(child, source, spec, name, out);
227 }
228}
229
230pub fn skeleton_with_radius(source: &str, lang: Lang, target: &str, radius: u8) -> String {
234 let keep = SymbolGraph::from_source(source, lang).neighbors_within(target, radius);
235 skeleton::skeletonize(source, lang, &keep)
236}
237
238fn collect_symbol_refs(source: &str, lang: Lang, out: &mut Vec<(String, HashSet<String>)>) {
243 let Some(tree) = parse(source, lang) else {
244 return;
245 };
246 let spec = lang.spec();
247 walk_symbols(tree.root_node(), source, &spec, out);
248}
249
250fn walk_symbols(
251 node: Node,
252 source: &str,
253 spec: &crate::lang::LangSpec,
254 out: &mut Vec<(String, HashSet<String>)>,
255) {
256 if spec.symbol_kinds.contains(&node.kind()) {
257 if let Some(name) = node.child_by_field_name("name").map(|n| text(n, source)) {
258 let mut bound = HashSet::new();
259 collect_bound(node, source, spec, &mut bound);
260 let mut refs = HashSet::new();
261 collect_refs(node, source, spec, &bound, &mut refs);
262 refs.remove(&name); out.push((name, refs));
264 }
265 }
266 let mut cursor = node.walk();
267 for child in node.children(&mut cursor) {
268 walk_symbols(child, source, spec, out);
269 }
270}
271
272fn collect_bound(
278 node: Node,
279 source: &str,
280 spec: &crate::lang::LangSpec,
281 out: &mut HashSet<String>,
282) {
283 if spec.binder_kinds.contains(&node.kind()) {
284 match node
285 .child_by_field_name("pattern")
286 .or_else(|| node.child_by_field_name("name"))
287 {
288 Some(target) => collect_idents(target, source, spec, out),
289 None => {
290 let mut cursor = node.walk();
291 for child in node.children(&mut cursor) {
292 collect_idents(child, source, spec, out);
293 }
294 }
295 }
296 }
297 let mut cursor = node.walk();
298 for child in node.children(&mut cursor) {
299 collect_bound(child, source, spec, out);
300 }
301}
302
303fn collect_refs(
305 node: Node,
306 source: &str,
307 spec: &crate::lang::LangSpec,
308 bound: &HashSet<String>,
309 out: &mut HashSet<String>,
310) {
311 if spec.ref_ident_kinds.contains(&node.kind()) {
312 let t = text(node, source);
313 if !bound.contains(&t) {
314 out.insert(t);
315 }
316 }
317 let mut cursor = node.walk();
318 for child in node.children(&mut cursor) {
319 collect_refs(child, source, spec, bound, out);
320 }
321}
322
323fn collect_idents(
326 node: Node,
327 source: &str,
328 spec: &crate::lang::LangSpec,
329 out: &mut HashSet<String>,
330) {
331 if spec.ref_ident_kinds.contains(&node.kind()) {
332 out.insert(text(node, source));
333 }
334 let mut cursor = node.walk();
335 for child in node.children(&mut cursor) {
336 collect_idents(child, source, spec, out);
337 }
338}
339
340fn text(node: Node, source: &str) -> String {
341 source[node.start_byte()..node.end_byte()].to_string()
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Three functions: `target` calls `helper`, `unrelated` is isolated.
    const SRC: &str = "\
fn helper(x: i32) -> i32 {
    x + 1
}

fn target() -> i32 {
    helper(41)
}

fn unrelated() -> i32 {
    7
}
";

    #[test]
    fn references_in_strings_and_comments_do_not_link() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
    // call helper here later
    let s = \"remember to call helper\";
    2
}
";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        assert!(
            !g.neighbors_within("target", 1).contains("helper"),
            "string/comment mention must not create an edge"
        );
    }

    #[test]
    fn a_local_shadowing_a_symbol_name_does_not_link() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
    let helper = 5;
    helper + 1
}
";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        assert!(
            !g.neighbors_within("target", 1).contains("helper"),
            "a shadowing local must not link to the same-named symbol"
        );
    }

    #[test]
    fn graph_links_caller_to_callee() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let within1 = g.neighbors_within("target", 1);
        assert!(within1.contains("target"));
        assert!(within1.contains("helper"));
        assert!(!within1.contains("unrelated"));
    }

    #[test]
    fn radius_zero_is_just_the_target() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let within0 = g.neighbors_within("target", 0);
        assert_eq!(within0.len(), 1);
        assert!(within0.contains("target"));
    }

    #[test]
    fn radius_counts_hops_along_a_call_chain() {
        let src = "\
fn a() -> i32 { b() }
fn b() -> i32 { c() }
fn c() -> i32 { 3 }
";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        let within1 = g.neighbors_within("a", 1);
        assert!(
            within1.contains("b") && !within1.contains("c"),
            "one hop reaches only the direct callee: {within1:?}"
        );
        let within2 = g.neighbors_within("a", 2);
        assert!(within2.contains("c"), "two hops reach the callee's callee: {within2:?}");
    }

    #[test]
    fn unknown_target_yields_only_itself() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let hits = g.neighbors_within("missing", 3);
        assert_eq!(hits.len(), 1, "got {hits:?}");
        assert!(hits.contains("missing"));
        let per_file = g.neighbors_within_by_file("missing", 3);
        assert_eq!(per_file.len(), 1);
        assert!(per_file[0].contains("missing"));
    }

    #[test]
    fn names_lists_each_definition_once() {
        // The same source twice: identical names in two files must be deduplicated.
        let g = SymbolGraph::from_sources([(Lang::Rust, SRC), (Lang::Rust, SRC)]);
        let mut names: Vec<&str> = g.names().collect();
        names.sort_unstable();
        assert_eq!(names, ["helper", "target", "unrelated"]);
    }

    #[test]
    fn skeleton_with_radius_keeps_only_dependencies() {
        let out = skeleton_with_radius(SRC, Lang::Rust, "target", 1);
        assert!(out.contains("helper(41)"), "target body kept: {out}");
        assert!(out.contains("x + 1"), "helper body kept: {out}");
        assert!(!out.contains(" 7\n"), "unrelated body elided: {out}");
    }

    #[test]
    fn radius_zero_elides_dependencies_too() {
        let out = skeleton_with_radius(SRC, Lang::Rust, "target", 0);
        assert!(out.contains("helper(41)"), "target body kept: {out}");
        assert!(
            !out.contains("x + 1"),
            "helper body should be elided at radius 0: {out}"
        );
    }

    #[test]
    fn find_identifier_lines_matches_code_not_strings_or_comments() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
    // helper is mentioned in this comment
    let s = \"call helper here\";
    helper()
}
";
        let hits = find_identifier_lines(src, Lang::Rust, "helper").unwrap();
        let rows: Vec<usize> = hits.iter().map(|(r, _)| *r).collect();
        assert_eq!(rows, vec![1, 5], "got {hits:?}");
        assert!(hits.iter().any(|(r, t)| *r == 5 && t.contains("helper()")));
    }

    #[test]
    fn find_identifier_lines_dedups_a_line_with_two_uses() {
        let src = "fn f() -> i32 { g() + g() }\nfn g() -> i32 { 1 }\n";
        let hits = find_identifier_lines(src, Lang::Rust, "g").unwrap();
        assert_eq!(hits.iter().map(|(r, _)| *r).collect::<Vec<_>>(), vec![1, 2]);
    }

    #[test]
    fn multi_file_graph_links_across_sources() {
        let a = "fn caller() -> i32 { shared() }";
        let b = "fn shared() -> i32 { 5 }";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        assert!(g.neighbors_within("caller", 1).contains("shared"));
    }

    #[test]
    fn same_name_across_files_resolves_to_the_local_definition() {
        let a = "\
fn helper() -> i32 { 0 }
fn target() -> i32 { helper() }
";
        let b = "\
fn helper() -> i32 { dep() }
fn dep() -> i32 { 9 }
";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        let within2 = g.neighbors_within("target", 2);
        assert!(
            within2.contains("helper"),
            "local helper linked: {within2:?}"
        );
        assert!(
            !within2.contains("dep"),
            "must not reach the OTHER file's helper dependency: {within2:?}"
        );
    }

    #[test]
    fn neighbors_by_file_keeps_a_body_only_in_its_owning_file() {
        let a = "fn target() -> i32 { repo() }\nfn noise_a() -> i32 { 0 }";
        let b = "fn repo() -> i32 { 5 }\nfn noise_b() -> i32 { 1 }";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        let per_file = g.neighbors_within_by_file("target", 1);
        assert_eq!(per_file.len(), 2);
        assert!(per_file[0].contains("target") && !per_file[0].contains("repo"));
        assert!(per_file[1].contains("repo") && !per_file[1].contains("target"));
    }
}