use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
use crate::lang::Lang;
use crate::skeleton::{self, parse};
/// Identity of a defined symbol: `(index of the defining file, symbol name)`.
type Sym = (usize, String);

/// Reference graph over named definitions, possibly spanning several files.
///
/// Every defined symbol is a key in `edges`; its value is the set of defined
/// symbols it references.
#[derive(Clone, Debug, Default)]
pub struct SymbolGraph {
    /// Outgoing references, keyed by the referencing definition.
    edges: HashMap<Sym, HashSet<Sym>>,
    /// Number of input files the graph was built from.
    file_count: usize,
}
impl SymbolGraph {
    /// Builds a graph from a single source file.
    ///
    /// This is [`SymbolGraph::from_sources`] with exactly one input, so every
    /// symbol is attributed to file index `0`.
    pub fn from_source(source: &str, lang: Lang) -> Self {
        Self::from_sources([(lang, source)])
    }

    /// Builds a graph spanning several source files.
    ///
    /// Each file's position in `sources` becomes its file index. Those are the
    /// indices [`SymbolGraph::neighbors_within_by_file`] reports against. A file
    /// that fails to parse contributes no symbols but still occupies its index.
    pub fn from_sources<'a>(sources: impl IntoIterator<Item = (Lang, &'a str)>) -> Self {
        let per_file: Vec<Vec<(String, HashSet<String>)>> = sources
            .into_iter()
            .map(|(lang, source)| {
                let mut defs = Vec::new();
                collect_symbol_refs(source, lang, &mut defs);
                defs
            })
            .collect();
        SymbolGraph::link(per_file)
    }

    /// Resolves per-file `(name, referenced names)` lists into symbol edges.
    ///
    /// A reference resolves to the same-file definition when one exists.
    /// Otherwise it links to that name in every file that defines it.
    /// References to names defined nowhere, and self-references, produce no
    /// edge. Every definition gets an entry in `edges`, even if it references
    /// nothing.
    fn link(per_file: Vec<Vec<(String, HashSet<String>)>>) -> Self {
        let file_count = per_file.len();

        // name -> indices of the files that define it. An index may repeat when
        // a file defines the name more than once; harmless, since edges are a set.
        let mut defined_in: HashMap<&str, Vec<usize>> = HashMap::new();
        for (fi, defs) in per_file.iter().enumerate() {
            for (name, _) in defs {
                defined_in.entry(name.as_str()).or_default().push(fi);
            }
        }

        let mut edges: HashMap<Sym, HashSet<Sym>> = HashMap::new();
        for (fi, defs) in per_file.iter().enumerate() {
            let local: HashSet<&str> = defs.iter().map(|(n, _)| n.as_str()).collect();
            for (name, refs) in defs {
                let entry = edges.entry((fi, name.clone())).or_default();
                for r in refs {
                    if r == name {
                        continue;
                    }
                    if local.contains(r.as_str()) {
                        // Prefer the same-file definition so a same-named symbol
                        // in another file is not pulled in.
                        entry.insert((fi, r.clone()));
                    } else if let Some(files) = defined_in.get(r.as_str()) {
                        entry.extend(files.iter().map(|&tf| (tf, r.clone())));
                    }
                }
            }
        }
        SymbolGraph { edges, file_count }
    }

    /// Returns the names of all symbols reachable from `target` in at most
    /// `radius` reference hops, including `target` itself.
    ///
    /// Names from different files are merged, so the result cannot tell which
    /// file's definition was reached. Use
    /// [`SymbolGraph::neighbors_within_by_file`] for that. If `target` is not
    /// defined anywhere, the result is just `{target}`.
    pub fn neighbors_within(&self, target: &str, radius: u8) -> HashSet<String> {
        self.reach(target, radius)
            .into_iter()
            .map(|(_, name)| name)
            .collect()
    }

    /// Like [`SymbolGraph::neighbors_within`], but split by defining file.
    ///
    /// The returned vector has one set per input file (always at least one),
    /// indexed by file position. If `target` is not defined anywhere it is
    /// reported in file `0`.
    pub fn neighbors_within_by_file(&self, target: &str, radius: u8) -> Vec<HashSet<String>> {
        let mut per_file = vec![HashSet::new(); self.file_count.max(1)];
        for (fi, name) in self.reach(target, radius) {
            if let Some(set) = per_file.get_mut(fi) {
                set.insert(name);
            }
        }
        per_file
    }

    /// Breadth-first search from every definition of `target`.
    ///
    /// Expands at most `radius` layers and stops early once a layer adds
    /// nothing new. Returns every visited symbol, including the start nodes.
    fn reach(&self, target: &str, radius: u8) -> HashSet<Sym> {
        let mut visited: HashSet<Sym> = HashSet::new();
        let mut frontier: Vec<Sym> = Vec::new();
        for fi in 0..self.file_count.max(1) {
            let sym = (fi, target.to_string());
            if self.edges.contains_key(&sym) && visited.insert(sym.clone()) {
                frontier.push(sym);
            }
        }
        // Undefined target: still report it (attributed to file 0) so callers
        // always get the target back, but there is nothing to expand.
        if visited.is_empty() {
            visited.insert((0, target.to_string()));
        }
        for _ in 0..radius {
            let mut next = Vec::new();
            for node in &frontier {
                if let Some(refs) = self.edges.get(node) {
                    for r in refs {
                        if visited.insert(r.clone()) {
                            next.push(r.clone());
                        }
                    }
                }
            }
            if next.is_empty() {
                break;
            }
            frontier = next;
        }
        visited
    }

    /// Iterates the distinct names of all defined symbols, in no particular order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.edges
            .keys()
            .map(|(_, name)| name.as_str())
            .collect::<HashSet<_>>()
            .into_iter()
    }
}
/// Lists every line on which `name` appears as a reference identifier.
///
/// Matching is done on syntax-tree nodes whose kind is in the language spec's
/// `ref_ident_kinds`, not on raw text. Each hit is
/// `(1-based line number, trimmed line text)`. There is one entry per line, in
/// ascending order. Returns `None` if `source` cannot be parsed.
pub fn find_identifier_lines(source: &str, lang: Lang, name: &str) -> Option<Vec<(usize, String)>> {
    let tree = parse(source, lang)?;
    let spec = lang.spec();

    let mut hit_rows = HashSet::new();
    collect_named_ref_rows(tree.root_node(), source, &spec, name, &mut hit_rows);

    // Rows are 0-based and already deduplicated by the set; order them for output.
    let mut ordered: Vec<usize> = hit_rows.into_iter().collect();
    ordered.sort_unstable();

    let lines: Vec<&str> = source.lines().collect();
    let mut hits = Vec::with_capacity(ordered.len());
    for row in ordered {
        let snippet = match lines.get(row) {
            Some(line) => line.trim().to_string(),
            None => String::new(),
        };
        hits.push((row + 1, snippet));
    }
    Some(hits)
}
/// Recursively records the 0-based row of every reference identifier under
/// `node` (inclusive) whose text is exactly `name`.
///
/// The node text is compared as a borrowed slice of `source`. No `String` is
/// allocated per visited identifier. A byte range that is not a valid slice of
/// `source` simply does not match instead of panicking.
fn collect_named_ref_rows(
    node: Node,
    source: &str,
    spec: &crate::lang::LangSpec,
    name: &str,
    out: &mut HashSet<usize>,
) {
    if spec.ref_ident_kinds.contains(&node.kind())
        && source.get(node.start_byte()..node.end_byte()) == Some(name)
    {
        out.insert(node.start_position().row);
    }
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        collect_named_ref_rows(child, source, spec, name, out);
    }
}
/// Skeletonizes `source` around `target`.
///
/// The keep-set is `target` plus every symbol reachable from it within `radius`
/// reference hops in a single-file [`SymbolGraph`]. It is handed to
/// `skeleton::skeletonize`. Per this module's tests, bodies of kept symbols
/// survive while other bodies are elided; the exact elision rules live in the
/// `skeleton` module.
pub fn skeleton_with_radius(source: &str, lang: Lang, target: &str, radius: u8) -> String {
    let graph = SymbolGraph::from_source(source, lang);
    let bodies_to_keep = graph.neighbors_within(target, radius);
    skeleton::skeletonize(source, lang, &bodies_to_keep)
}
/// Parses `source` and appends one `(symbol name, referenced names)` entry per
/// named symbol definition to `out`.
///
/// If `source` cannot be parsed, `out` is left untouched.
fn collect_symbol_refs(source: &str, lang: Lang, out: &mut Vec<(String, HashSet<String>)>) {
    if let Some(tree) = parse(source, lang) {
        let lang_spec = lang.spec();
        walk_symbols(tree.root_node(), source, &lang_spec, out);
    }
}
/// Pre-order walk that emits an entry for every node whose kind is in
/// `spec.symbol_kinds` and that has a `name` field.
///
/// Each entry pairs the symbol name with the reference identifiers found
/// anywhere inside the node. Identifiers bound within the node are excluded
/// (see `collect_bound`), as is the symbol's own name. Nested symbols are
/// emitted too. An enclosing symbol's reference set also includes identifiers
/// from its nested symbols' bodies.
fn walk_symbols(
    node: Node,
    source: &str,
    spec: &crate::lang::LangSpec,
    out: &mut Vec<(String, HashSet<String>)>,
) {
    let name_node = spec
        .symbol_kinds
        .contains(&node.kind())
        .then(|| node.child_by_field_name("name"))
        .flatten();

    if let Some(name_node) = name_node {
        let symbol = text(name_node, source);

        let mut locals = HashSet::new();
        collect_bound(node, source, spec, &mut locals);

        let mut used = HashSet::new();
        collect_refs(node, source, spec, &locals, &mut used);
        used.remove(&symbol);

        out.push((symbol, used));
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        walk_symbols(child, source, spec, out);
    }
}
/// Collects every identifier introduced by a binder (a node whose kind is in
/// `spec.binder_kinds`) anywhere under `node`, including `node` itself.
///
/// For a binder, identifiers are taken from its `pattern` field, or else its
/// `name` field. If it has neither, identifiers from all of its children are
/// taken.
///
/// NOTE(review): that fallback also sweeps up identifiers from any initializer
/// or right-hand side under the binder. Confirm `binder_kinds` only lists node
/// kinds for which this is intended.
fn collect_bound(
    node: Node,
    source: &str,
    spec: &crate::lang::LangSpec,
    out: &mut HashSet<String>,
) {
    if spec.binder_kinds.contains(&node.kind()) {
        let binding_site = node
            .child_by_field_name("pattern")
            .or_else(|| node.child_by_field_name("name"));
        if let Some(site) = binding_site {
            collect_idents(site, source, spec, out);
        } else {
            let mut walker = node.walk();
            for child in node.children(&mut walker) {
                collect_idents(child, source, spec, out);
            }
        }
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        collect_bound(child, source, spec, out);
    }
}
/// Adds to `out` the text of every reference identifier under `node`
/// (inclusive) whose text is not in `bound`.
///
/// `bound` is consulted by name only, with no notion of scope or position.
/// A binding anywhere in the analysed subtree hides every same-named
/// reference in it.
fn collect_refs(
    node: Node,
    source: &str,
    spec: &crate::lang::LangSpec,
    bound: &HashSet<String>,
    out: &mut HashSet<String>,
) {
    // Explicit stack instead of recursion; the result is a set, so visit order
    // does not matter.
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if spec.ref_ident_kinds.contains(&current.kind()) {
            let ident = text(current, source);
            if !bound.contains(&ident) {
                out.insert(ident);
            }
        }
        let mut walker = current.walk();
        pending.extend(current.children(&mut walker));
    }
}
/// Adds the text of every reference-identifier node under `node` (inclusive)
/// to `out`.
fn collect_idents(
    node: Node,
    source: &str,
    spec: &crate::lang::LangSpec,
    out: &mut HashSet<String>,
) {
    let is_ident = spec.ref_ident_kinds.contains(&node.kind());
    if is_ident {
        out.insert(text(node, source));
    }
    let mut walker = node.walk();
    node.children(&mut walker)
        .for_each(|child| collect_idents(child, source, spec, out));
}
/// Returns an owned copy of the slice of `source` spanned by `node`.
///
/// Panics if the node's byte range is not a valid slice of `source`, for
/// example if the node came from a tree parsed from a different string.
fn text(node: Node, source: &str) -> String {
    let span = node.start_byte()..node.end_byte();
    String::from(&source[span])
}
#[cfg(test)]
mod tests {
    use super::*;

    const SRC: &str = "\
fn helper(x: i32) -> i32 {
x + 1
}
fn target() -> i32 {
helper(41)
}
fn unrelated() -> i32 {
7
}
";

    #[test]
    fn references_in_strings_and_comments_do_not_link() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
// call helper here later
let s = \"remember to call helper\";
2
}
";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        assert!(
            !g.neighbors_within("target", 1).contains("helper"),
            "string/comment mention must not create an edge"
        );
    }

    #[test]
    fn a_local_shadowing_a_symbol_name_does_not_link() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
let helper = 5;
helper + 1
}
";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        assert!(
            !g.neighbors_within("target", 1).contains("helper"),
            "a shadowing local must not link to the same-named symbol"
        );
    }

    #[test]
    fn graph_links_caller_to_callee() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let within1 = g.neighbors_within("target", 1);
        assert!(within1.contains("target"));
        assert!(within1.contains("helper"));
        assert!(!within1.contains("unrelated"));
    }

    #[test]
    fn radius_zero_is_just_the_target() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let within0 = g.neighbors_within("target", 0);
        assert_eq!(within0.len(), 1);
        assert!(within0.contains("target"));
    }

    #[test]
    fn radius_bounds_the_number_of_hops() {
        let src = "fn a() -> i32 { b() }\nfn b() -> i32 { c() }\nfn c() -> i32 { 0 }\n";
        let g = SymbolGraph::from_source(src, Lang::Rust);
        assert!(!g.neighbors_within("a", 1).contains("c"));
        assert!(g.neighbors_within("a", 2).contains("c"));
    }

    #[test]
    fn undefined_target_reaches_only_itself() {
        let g = SymbolGraph::from_source(SRC, Lang::Rust);
        let within = g.neighbors_within("missing", 3);
        assert_eq!(within.len(), 1, "got {within:?}");
        assert!(within.contains("missing"));
        let per_file = g.neighbors_within_by_file("missing", 3);
        assert_eq!(per_file.len(), 1);
        assert!(per_file[0].contains("missing"));
    }

    #[test]
    fn skeleton_with_radius_keeps_only_dependencies() {
        let out = skeleton_with_radius(SRC, Lang::Rust, "target", 1);
        assert!(out.contains("helper(41)"), "target body kept: {out}");
        assert!(out.contains("x + 1"), "helper body kept: {out}");
        assert!(!out.contains(" 7\n"), "unrelated body elided: {out}");
    }

    #[test]
    fn radius_zero_elides_dependencies_too() {
        let out = skeleton_with_radius(SRC, Lang::Rust, "target", 0);
        assert!(out.contains("helper(41)"), "target body kept: {out}");
        assert!(
            !out.contains("x + 1"),
            "helper body should be elided at radius 0: {out}"
        );
    }

    #[test]
    fn find_identifier_lines_matches_code_not_strings_or_comments() {
        let src = "\
fn helper() -> i32 { 1 }
fn target() -> i32 {
// helper is mentioned in this comment
let s = \"call helper here\";
helper()
}
";
        let hits = find_identifier_lines(src, Lang::Rust, "helper").unwrap();
        let rows: Vec<usize> = hits.iter().map(|(r, _)| *r).collect();
        assert_eq!(rows, vec![1, 5], "got {hits:?}");
        assert!(hits.iter().any(|(r, t)| *r == 5 && t.contains("helper()")));
    }

    #[test]
    fn find_identifier_lines_dedups_a_line_with_two_uses() {
        let src = "fn f() -> i32 { g() + g() }\nfn g() -> i32 { 1 }\n";
        let hits = find_identifier_lines(src, Lang::Rust, "g").unwrap();
        assert_eq!(hits.iter().map(|(r, _)| *r).collect::<Vec<_>>(), vec![1, 2]);
    }

    #[test]
    fn multi_file_graph_links_across_sources() {
        let a = "fn caller() -> i32 { shared() }";
        let b = "fn shared() -> i32 { 5 }";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        assert!(g.neighbors_within("caller", 1).contains("shared"));
    }

    #[test]
    fn names_are_deduplicated_across_files() {
        let a = "fn helper() -> i32 { 1 }";
        let b = "fn helper() -> i32 { 2 }";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        let names: Vec<&str> = g.names().collect();
        assert_eq!(
            names.iter().filter(|n| **n == "helper").count(),
            1,
            "got {names:?}"
        );
    }

    #[test]
    fn same_name_across_files_resolves_to_the_local_definition() {
        let a = "\
fn helper() -> i32 { 0 }
fn target() -> i32 { helper() }
";
        let b = "\
fn helper() -> i32 { dep() }
fn dep() -> i32 { 9 }
";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        let within2 = g.neighbors_within("target", 2);
        assert!(
            within2.contains("helper"),
            "local helper linked: {within2:?}"
        );
        assert!(
            !within2.contains("dep"),
            "must not reach the OTHER file's helper dependency: {within2:?}"
        );
    }

    #[test]
    fn neighbors_by_file_keeps_a_body_only_in_its_owning_file() {
        let a = "fn target() -> i32 { repo() }\nfn noise_a() -> i32 { 0 }";
        let b = "fn repo() -> i32 { 5 }\nfn noise_b() -> i32 { 1 }";
        let g = SymbolGraph::from_sources([(Lang::Rust, a), (Lang::Rust, b)]);
        let per_file = g.neighbors_within_by_file("target", 1);
        assert_eq!(per_file.len(), 2);
        assert!(per_file[0].contains("target") && !per_file[0].contains("repo"));
        assert!(per_file[1].contains("repo") && !per_file[1].contains("target"));
    }
}