use std::collections::{HashMap, HashSet, VecDeque};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{Deserialize, Serialize};
use crate::core::chunker::ChunkType;
use crate::core::entity::EdgeKind;
/// Node cap used when `TRUSTY_MAX_KG_NODES` is unset or not a valid `usize`.
const DEFAULT_MAX_KG_NODES: usize = 100_000;

/// Returns the maximum number of distinct symbol nodes a `SymbolGraph` may hold.
///
/// Reads the `TRUSTY_MAX_KG_NODES` environment variable. Surrounding whitespace
/// (for example a trailing newline picked up from an env file) is ignored, so
/// `" 5000\n"` is read as `5000` instead of silently falling back to the default.
/// A missing or unparsable value yields [`DEFAULT_MAX_KG_NODES`]. A value of `0`
/// is returned as-is; the graph builder treats it as "no cap".
pub fn max_kg_nodes() -> usize {
    std::env::var("TRUSTY_MAX_KG_NODES")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .unwrap_or(DEFAULT_MAX_KG_NODES)
}
/// A node of the symbol graph: one distinct symbol name.
///
/// Field order is significant for the derived `Debug` / `Serialize` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymbolNode {
/// Symbol name as reported by the chunker; may be `::`-qualified (e.g. `Foo::bar`).
pub symbol: String,
/// Id of the first chunk that registered this symbol. Later chunks with the same
/// name re-use this node and do not change this field.
pub chunk_id: String,
/// File of that first registering chunk.
pub file: String,
}
/// Positional description of one indexed chunk, as consumed by
/// `SymbolGraph::build_from_chunks`:
/// `(chunk_id, file, name, calls, inherits_from, chunk_type)`.
pub type ChunkTuple = (
// chunk_id: unique id of the chunk.
String,
// file: file the chunk came from; used to group symbols for ModuleContains edges.
String,
// name: symbol defined by the chunk; `None` or empty means no node is created.
Option<String>,
// calls: call-target names; unresolvable names are ignored.
Vec<String>,
// inherits_from: parent/implemented names; become Implements edges when resolved.
Vec<String>,
// chunk_type: container kinds (Impl/Class/Struct/Module) emit ModuleContains edges.
ChunkType,
);
/// Directed graph of code symbols whose edges are typed by `EdgeKind`.
///
/// Nodes are keyed by symbol name: each distinct name maps to one node.
/// Build with `SymbolGraph::build_from_chunks`; query with `callers_of`,
/// `callees_of`, and `neighbors_by_edge`.
#[derive(Debug, Default)]
pub struct SymbolGraph {
/// Node weights are symbols; edge weights record the relationship kind.
graph: DiGraph<SymbolNode, EdgeKind>,
/// Symbol name -> index of its node.
by_symbol: HashMap<String, NodeIndex>,
/// Chunk id -> symbol name for every chunk whose symbol is in the graph,
/// including chunks that re-used an already registered symbol.
chunk_to_symbol: HashMap<String, String>,
}
impl SymbolGraph {
    /// Creates an empty graph with no nodes, edges, or chunk mappings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a symbol graph from indexed chunks.
    ///
    /// Construction runs in three passes:
    ///
    /// 1. Every chunk with a non-empty `name` registers one node per distinct symbol
    ///    name (later chunks with the same name re-use the existing node), subject to
    ///    the [`max_kg_nodes`] cap.
    /// 2. Each chunk's `calls` and `inherits_from` entries are resolved to nodes
    ///    (exact symbol name first, then a match against the last `::` segment of
    ///    registered symbols) and become `CallsFunction` / `Implements` edges.
    ///    Unresolved targets and self-edges are skipped.
    /// 3. Container chunks (`Impl`, `Class`, `Struct`, `Module`) get `ModuleContains`
    ///    edges to every other registered symbol that has a chunk in the same file.
    ///
    /// Each `(from, to, kind)` triple is inserted at most once. Repeated call sites,
    /// or a symbol that appears in several chunks, therefore do not produce parallel
    /// duplicate edges.
    pub fn build_from_chunks(chunks: &[ChunkTuple]) -> Self {
        let mut g = Self::new();
        g.register_symbol_nodes(chunks);
        let by_suffix = g.build_suffix_lookup();
        // Edges already inserted during this build, so each relationship is stored
        // once regardless of how many chunks report it.
        let mut seen: HashSet<(NodeIndex, NodeIndex, EdgeKind)> = HashSet::new();
        g.add_call_and_inherit_edges(chunks, &by_suffix, &mut seen);
        g.add_module_contains_edges(chunks, &mut seen);
        g
    }

    /// Pass 1: registers a node (or re-uses an existing one) for every named chunk.
    fn register_symbol_nodes(&mut self, chunks: &[ChunkTuple]) {
        let cap = max_kg_nodes();
        let mut cap_warned = false;
        for (chunk_id, file, name, _calls, _inh, _ct) in chunks {
            self.register_one_symbol(chunk_id, file, name.as_deref(), cap, &mut cap_warned);
        }
    }

    /// Registers `name` for `chunk_id`.
    ///
    /// - `None` / empty names are ignored.
    /// - An already-known name only records the chunk -> symbol mapping; the
    ///   existing node keeps the chunk id and file of its first registration.
    /// - A new name is dropped (with a one-time warning) once the cap is reached.
    fn register_one_symbol(
        &mut self,
        chunk_id: &str,
        file: &str,
        name: Option<&str>,
        cap: usize,
        cap_warned: &mut bool,
    ) {
        let Some(name) = name else { return };
        if name.is_empty() {
            return;
        }
        if self.by_symbol.contains_key(name) {
            self.chunk_to_symbol
                .insert(chunk_id.to_string(), name.to_string());
            return;
        }
        if Self::cap_exceeded(cap, self.by_symbol.len()) {
            Self::warn_cap_once(cap, cap_warned);
            return;
        }
        let idx = self.graph.add_node(SymbolNode {
            symbol: name.to_string(),
            chunk_id: chunk_id.to_string(),
            file: file.to_string(),
        });
        self.by_symbol.insert(name.to_string(), idx);
        self.chunk_to_symbol
            .insert(chunk_id.to_string(), name.to_string());
    }

    /// `true` when a non-zero `cap` has been reached; `cap == 0` means unlimited.
    fn cap_exceeded(cap: usize, current: usize) -> bool {
        cap > 0 && current >= cap
    }

    /// Emits the cap warning the first time it is hit during a build.
    fn warn_cap_once(cap: usize, cap_warned: &mut bool) {
        if !*cap_warned {
            tracing::warn!(
                "symbol graph node cap ({}) reached — skipping further new symbols \
(override via TRUSTY_MAX_KG_NODES; 0 = unlimited)",
                cap
            );
            *cap_warned = true;
        }
    }

    /// Maps the last `::`-separated segment of every registered symbol to its node.
    ///
    /// When several symbols share a suffix (e.g. `Foo::new` and `Bar::new`), the one
    /// registered first (the earliest chunk in build order) wins. Nodes are walked
    /// in insertion order rather than in `by_symbol`'s randomized hash order, so
    /// this choice, and hence which node a bare callee resolves to, is
    /// deterministic across builds and processes.
    fn build_suffix_lookup(&self) -> HashMap<String, NodeIndex> {
        let mut by_suffix: HashMap<String, NodeIndex> =
            HashMap::with_capacity(self.graph.node_count());
        for idx in self.graph.node_indices() {
            let sym = self.graph[idx].symbol.as_str();
            // `rsplit` always yields at least one item; a bare name is its own suffix.
            let suffix = sym.rsplit("::").next().unwrap_or(sym);
            by_suffix.entry(suffix.to_string()).or_insert(idx);
        }
        by_suffix
    }

    /// Pass 2: adds `CallsFunction` edges for `calls` and `Implements` edges for
    /// `inherits_from`, originating at each chunk's own symbol node.
    fn add_call_and_inherit_edges(
        &mut self,
        chunks: &[ChunkTuple],
        by_suffix: &HashMap<String, NodeIndex>,
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
    ) {
        for (_chunk_id, _file, name, calls, inherits_from, _ct) in chunks {
            let Some(name) = name else { continue };
            let Some(&from) = self.by_symbol.get(name) else {
                continue;
            };
            self.add_edges_for_targets(from, calls, by_suffix, EdgeKind::CallsFunction, seen);
            self.add_edges_for_targets(
                from,
                inherits_from,
                by_suffix,
                EdgeKind::Implements,
                seen,
            );
        }
    }

    /// Adds a `kind` edge from `from` to each resolvable target, skipping
    /// unresolved names and self-loops.
    fn add_edges_for_targets(
        &mut self,
        from: NodeIndex,
        targets: &[String],
        by_suffix: &HashMap<String, NodeIndex>,
        kind: EdgeKind,
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
    ) {
        for target in targets {
            let Some(to) = self.resolve_callee_fast(target, by_suffix) else {
                continue;
            };
            if from == to {
                continue;
            }
            self.add_edge_once(seen, from, to, kind.clone());
        }
    }

    /// Inserts `from -> to` with `kind` unless that exact edge was already added
    /// during this build.
    fn add_edge_once(
        &mut self,
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
        from: NodeIndex,
        to: NodeIndex,
        kind: EdgeKind,
    ) {
        if seen.insert((from, to, kind.clone())) {
            self.graph.add_edge(from, to, kind);
        }
    }

    /// Pass 3: links container symbols to the other symbols defined in their file.
    fn add_module_contains_edges(
        &mut self,
        chunks: &[ChunkTuple],
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
    ) {
        // Skip the per-file grouping entirely when nothing can emit these edges.
        if !Self::has_any_container(chunks) {
            return;
        }
        let by_file = self.group_symbols_by_file(chunks);
        for (_chunk_id, file, name, _calls, _inh, ct) in chunks {
            self.emit_container_edges_for(file, name.as_deref(), ct, &by_file, seen);
        }
    }

    /// Emits `ModuleContains` edges for one chunk if it is a registered container.
    fn emit_container_edges_for(
        &mut self,
        file: &str,
        name: Option<&str>,
        ct: &ChunkType,
        by_file: &HashMap<&str, Vec<NodeIndex>>,
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
    ) {
        if !Self::is_container(ct) {
            return;
        }
        let Some(name) = name else { return };
        let Some(&from) = self.by_symbol.get(name) else {
            return;
        };
        let Some(siblings) = by_file.get(file) else {
            return;
        };
        self.add_sibling_edges(from, siblings, seen);
    }

    /// Adds `from -> sibling` `ModuleContains` edges, excluding the container itself.
    ///
    /// Symbol names map 1:1 to node indices, so comparing indices is enough to skip
    /// the owner.
    fn add_sibling_edges(
        &mut self,
        from: NodeIndex,
        siblings: &[NodeIndex],
        seen: &mut HashSet<(NodeIndex, NodeIndex, EdgeKind)>,
    ) {
        for &sib in siblings {
            if sib == from {
                continue;
            }
            self.add_edge_once(seen, from, sib, EdgeKind::ModuleContains);
        }
    }

    /// `true` if at least one named chunk is a container kind.
    fn has_any_container(chunks: &[ChunkTuple]) -> bool {
        chunks
            .iter()
            .any(|(_, _, name, _, _, ct)| name.is_some() && Self::is_container(ct))
    }

    /// Chunk kinds treated as containers for `ModuleContains` edges.
    fn is_container(ct: &ChunkType) -> bool {
        matches!(
            ct,
            ChunkType::Impl | ChunkType::Class | ChunkType::Struct | ChunkType::Module
        )
    }

    /// Groups registered symbol nodes by the file of each chunk that names them.
    ///
    /// A symbol can appear under several files (nodes are keyed by name only), and
    /// may appear more than once per file. Duplicate edges are prevented at
    /// insertion time instead.
    fn group_symbols_by_file<'a>(
        &self,
        chunks: &'a [ChunkTuple],
    ) -> HashMap<&'a str, Vec<NodeIndex>> {
        let mut by_file: HashMap<&'a str, Vec<NodeIndex>> = HashMap::new();
        for (_chunk_id, file, name, _calls, _inh, _ct) in chunks {
            let Some(&idx) = name.as_deref().and_then(|n| self.by_symbol.get(n)) else {
                continue;
            };
            by_file.entry(file.as_str()).or_default().push(idx);
        }
        by_file
    }

    /// Resolves a call/inherit target: exact symbol name first, then the target
    /// matched against the last `::` segment of registered symbols.
    fn resolve_callee_fast(
        &self,
        callee: &str,
        by_suffix: &HashMap<String, NodeIndex>,
    ) -> Option<NodeIndex> {
        self.by_symbol
            .get(callee)
            .or_else(|| by_suffix.get(callee))
            .copied()
    }

    /// Number of symbol nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Number of edges (of all kinds) in the graph.
    pub fn edge_count(&self) -> usize {
        self.graph.edge_count()
    }

    /// Symbol name recorded for `chunk_id`, if that chunk's symbol is in the graph.
    pub fn symbol_for_chunk(&self, chunk_id: &str) -> Option<&str> {
        self.chunk_to_symbol.get(chunk_id).map(|s| s.as_str())
    }

    /// Symbols that reach `symbol` through `CallsFunction` edges within `hops` hops.
    ///
    /// Returns `(symbol, chunk_id)` pairs in BFS order, excluding `symbol` itself.
    /// Empty when `hops == 0` or `symbol` is unknown.
    pub fn callers_of(&self, symbol: &str, hops: usize) -> Vec<(String, String)> {
        self.bfs_neighbors(symbol, hops, Direction::Incoming)
    }

    /// Symbols reachable from `symbol` through `CallsFunction` edges within `hops`
    /// hops.
    ///
    /// Returns `(symbol, chunk_id)` pairs in BFS order, excluding `symbol` itself.
    /// Empty when `hops == 0` or `symbol` is unknown.
    pub fn callees_of(&self, symbol: &str, hops: usize) -> Vec<(String, String)> {
        self.bfs_neighbors(symbol, hops, Direction::Outgoing)
    }

    /// Symbols within `hops` hops of `symbol`, following only edges whose kind is
    /// in `edge_kinds`, in both directions.
    ///
    /// Each reached node is reported once as `(symbol, chunk_id, kind)`, where
    /// `kind` is the kind of the edge through which it was first reached. Empty when
    /// `hops == 0`, `symbol` is unknown, or `edge_kinds` is empty.
    pub fn neighbors_by_edge(
        &self,
        symbol: &str,
        edge_kinds: &[EdgeKind],
        hops: usize,
    ) -> Vec<(String, String, EdgeKind)> {
        let Some(start) = self.start_index(symbol, hops) else {
            return Vec::new();
        };
        if edge_kinds.is_empty() {
            return Vec::new();
        }
        let allowed: HashSet<&EdgeKind> = edge_kinds.iter().collect();
        let mut out: Vec<(String, String, EdgeKind)> = Vec::new();
        self.bfs_walk(
            start,
            hops,
            &[Direction::Outgoing, Direction::Incoming],
            |edge| allowed.contains(edge.weight()),
            |node, edge| {
                out.push((
                    node.symbol.clone(),
                    node.chunk_id.clone(),
                    edge.weight().clone(),
                ));
            },
        );
        out
    }

    /// Shared implementation of `callers_of` / `callees_of`: a BFS over
    /// `CallsFunction` edges in a single direction.
    fn bfs_neighbors(&self, symbol: &str, hops: usize, dir: Direction) -> Vec<(String, String)> {
        let Some(start) = self.start_index(symbol, hops) else {
            return Vec::new();
        };
        let mut out: Vec<(String, String)> = Vec::new();
        self.bfs_walk(
            start,
            hops,
            &[dir],
            |edge| edge.weight() == &EdgeKind::CallsFunction,
            |node, _edge| {
                out.push((node.symbol.clone(), node.chunk_id.clone()));
            },
        );
        out
    }

    /// Node to start a traversal from, or `None` when `hops == 0` or the symbol is
    /// unknown (callers then return an empty result).
    fn start_index(&self, symbol: &str, hops: usize) -> Option<NodeIndex> {
        if hops == 0 {
            return None;
        }
        self.by_symbol.get(symbol).copied()
    }

    /// Breadth-first walk from `start`, up to `hops` levels deep.
    ///
    /// For every expanded node, edges in each of `dirs` are enumerated and those
    /// accepted by `edge_filter` are followed. `on_visit` is called once per newly
    /// discovered node, with the edge through which it was first reached. `start` is
    /// pre-marked visited and never reported.
    fn bfs_walk<F, V>(
        &self,
        start: NodeIndex,
        hops: usize,
        dirs: &[Direction],
        edge_filter: F,
        mut on_visit: V,
    ) where
        F: Fn(petgraph::graph::EdgeReference<'_, EdgeKind>) -> bool,
        V: FnMut(&SymbolNode, petgraph::graph::EdgeReference<'_, EdgeKind>),
    {
        let mut visited: HashSet<NodeIndex> = HashSet::new();
        visited.insert(start);
        let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
        queue.push_back((start, 0));
        while let Some((node, depth)) = queue.pop_front() {
            // Nodes at the hop limit are reported but not expanded further.
            if depth >= hops {
                continue;
            }
            self.expand_node(
                node,
                depth,
                dirs,
                &edge_filter,
                &mut on_visit,
                &mut visited,
                &mut queue,
            );
        }
    }

    /// Follows every filtered edge of `node` in each of `dirs`.
    // BFS state (visited set and queue) is threaded through explicitly instead of
    // being bundled into a struct, which pushes the argument count past clippy's
    // default limit.
    #[allow(clippy::too_many_arguments)]
    fn expand_node<F, V>(
        &self,
        node: NodeIndex,
        depth: usize,
        dirs: &[Direction],
        edge_filter: &F,
        on_visit: &mut V,
        visited: &mut HashSet<NodeIndex>,
        queue: &mut VecDeque<(NodeIndex, usize)>,
    ) where
        F: Fn(petgraph::graph::EdgeReference<'_, EdgeKind>) -> bool,
        V: FnMut(&SymbolNode, petgraph::graph::EdgeReference<'_, EdgeKind>),
    {
        for &dir in dirs {
            for edge in self.graph.edges_directed(node, dir) {
                if !edge_filter(edge) {
                    continue;
                }
                let nb = Self::neighbor_in_direction(edge, dir);
                self.record_neighbor(nb, edge, depth, on_visit, visited, queue);
            }
        }
    }

    /// Endpoint of `edge` opposite the expanded node, given the direction in which
    /// the edge was enumerated.
    fn neighbor_in_direction(
        edge: petgraph::graph::EdgeReference<'_, EdgeKind>,
        dir: Direction,
    ) -> NodeIndex {
        match dir {
            Direction::Outgoing => edge.target(),
            Direction::Incoming => edge.source(),
        }
    }

    /// Marks `nb` visited; if it was not seen before, reports it via `on_visit` and
    /// enqueues it one level deeper.
    fn record_neighbor<V>(
        &self,
        nb: NodeIndex,
        edge: petgraph::graph::EdgeReference<'_, EdgeKind>,
        depth: usize,
        on_visit: &mut V,
        visited: &mut HashSet<NodeIndex>,
        queue: &mut VecDeque<(NodeIndex, usize)>,
    ) where
        V: FnMut(&SymbolNode, petgraph::graph::EdgeReference<'_, EdgeKind>),
    {
        if visited.insert(nb) {
            let n = &self.graph[nb];
            on_visit(n, edge);
            queue.push_back((nb, depth + 1));
        }
    }
}
// Unit tests for SymbolGraph construction, target resolution, and BFS traversal.
#[cfg(test)]
mod tests {
use super::*;
/// Builds a `Function` chunk with the given call targets and no inheritance.
fn chunk(id: &str, file: &str, name: Option<&str>, calls: &[&str]) -> ChunkTuple {
chunk_full(id, file, name, calls, &[], ChunkType::Function)
}
/// Builds a chunk tuple with every positional field given explicitly.
fn chunk_full(
id: &str,
file: &str,
name: Option<&str>,
calls: &[&str],
inherits_from: &[&str],
chunk_type: ChunkType,
) -> ChunkTuple {
(
id.to_string(),
file.to_string(),
name.map(String::from),
calls.iter().map(|s| s.to_string()).collect(),
inherits_from.iter().map(|s| s.to_string()).collect(),
chunk_type,
)
}
// Three distinct symbols with three resolvable calls: 3 nodes, 3 edges.
#[test]
fn test_build_simple_graph() {
let chunks = vec![
chunk("a:1", "a.rs", Some("main"), &["foo", "bar"]),
chunk("a:2", "a.rs", Some("foo"), &["bar"]),
chunk("a:3", "a.rs", Some("bar"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
assert_eq!(g.node_count(), 3);
assert_eq!(g.edge_count(), 3);
}
// Both direct callers are reported together with their chunk ids.
#[test]
fn test_callers_of_one_hop() {
let chunks = vec![
chunk("m:1", "m.rs", Some("main"), &["authenticate"]),
chunk("h:1", "h.rs", Some("login_handler"), &["authenticate"]),
chunk("a:1", "a.rs", Some("authenticate"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let mut callers = g.callers_of("authenticate", 1);
callers.sort();
assert_eq!(
callers,
vec![
("login_handler".to_string(), "h:1".to_string()),
("main".to_string(), "m:1".to_string()),
]
);
}
// Both direct callees are reported together with their chunk ids.
#[test]
fn test_callees_of_one_hop() {
let chunks = vec![
chunk(
"a:1",
"a.rs",
Some("authenticate"),
&["hash_password", "lookup_user"],
),
chunk("p:1", "p.rs", Some("hash_password"), &[]),
chunk("u:1", "u.rs", Some("lookup_user"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let mut callees = g.callees_of("authenticate", 1);
callees.sort();
assert_eq!(
callees,
vec![
("hash_password".to_string(), "p:1".to_string()),
("lookup_user".to_string(), "u:1".to_string()),
]
);
}
// `hops` bounds BFS depth: one hop reaches b; two hops also reach c.
#[test]
fn test_two_hop_traversal() {
let chunks = vec![
chunk("a:1", "a.rs", Some("a"), &["b"]),
chunk("b:1", "b.rs", Some("b"), &["c"]),
chunk("c:1", "c.rs", Some("c"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let one_hop = g.callees_of("a", 1);
assert_eq!(one_hop.len(), 1);
assert_eq!(one_hop[0].0, "b");
let two_hop = g.callees_of("a", 2);
let names: Vec<&str> = two_hop.iter().map(|(s, _)| s.as_str()).collect();
assert!(names.contains(&"b"));
assert!(names.contains(&"c"));
}
// Queries for a symbol that was never registered return nothing.
#[test]
fn test_unknown_symbol_returns_empty() {
let chunks = vec![chunk("a:1", "a.rs", Some("a"), &[])];
let g = SymbolGraph::build_from_chunks(&chunks);
assert!(g.callers_of("nonexistent", 1).is_empty());
assert!(g.callees_of("nonexistent", 1).is_empty());
}
// A qualified caller (`Foo::bar`) still gets an edge to its bare callee.
#[test]
fn test_qualified_method_resolves_simple_callee() {
let chunks = vec![
chunk("f:1", "f.rs", Some("Foo::bar"), &["baz"]),
chunk("b:1", "b.rs", Some("baz"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let callers = g.callers_of("baz", 1);
assert_eq!(callers.len(), 1);
assert_eq!(callers[0].0, "Foo::bar");
}
// A bare callee name resolves to a qualified definition via the suffix lookup.
#[test]
fn test_simple_callee_resolves_to_qualified_definition() {
let chunks = vec![
chunk("c:1", "c.rs", Some("caller"), &["bar"]),
chunk("f:1", "f.rs", Some("Foo::bar"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let callees = g.callees_of("caller", 1);
assert_eq!(callees.len(), 1);
assert_eq!(callees[0].0, "Foo::bar");
}
// Chunks without a symbol name do not create nodes.
#[test]
fn test_chunk_with_no_function_name_is_skipped() {
let chunks = vec![
chunk("s:1", "s.rs", None, &[]),
chunk("f:1", "f.rs", Some("f"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
assert_eq!(g.node_count(), 1);
}
// A zero-hop query returns nothing even when edges exist.
#[test]
fn test_zero_hops_returns_empty() {
let chunks = vec![
chunk("a:1", "a.rs", Some("a"), &["b"]),
chunk("b:1", "b.rs", Some("b"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
assert!(g.callees_of("a", 0).is_empty());
}
// Registered chunks map back to their symbol; unknown chunk ids map to None.
#[test]
fn test_symbol_for_chunk() {
let chunks = vec![chunk("a:1", "a.rs", Some("alpha"), &[])];
let g = SymbolGraph::build_from_chunks(&chunks);
assert_eq!(g.symbol_for_chunk("a:1"), Some("alpha"));
assert_eq!(g.symbol_for_chunk("missing"), None);
}
// Hand-built graph (bypassing build_from_chunks): only edges whose kind is in the
// filter are followed; an empty filter or zero hops yields nothing.
#[test]
fn test_neighbors_by_edge_filters_by_kind() {
let mut g = SymbolGraph::new();
let a = g.graph.add_node(SymbolNode {
symbol: "a".into(),
chunk_id: "a:1".into(),
file: "a.rs".into(),
});
let b = g.graph.add_node(SymbolNode {
symbol: "b".into(),
chunk_id: "b:1".into(),
file: "b.rs".into(),
});
let c = g.graph.add_node(SymbolNode {
symbol: "c".into(),
chunk_id: "c:1".into(),
file: "c.rs".into(),
});
g.by_symbol.insert("a".into(), a);
g.by_symbol.insert("b".into(), b);
g.by_symbol.insert("c".into(), c);
g.graph.add_edge(a, b, EdgeKind::CallsFunction);
g.graph.add_edge(a, c, EdgeKind::Implements);
let calls = g.neighbors_by_edge("a", &[EdgeKind::CallsFunction], 1);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].0, "b");
let impls = g.neighbors_by_edge("a", &[EdgeKind::Implements], 1);
assert_eq!(impls.len(), 1);
assert_eq!(impls[0].0, "c");
let both = g.neighbors_by_edge("a", &[EdgeKind::CallsFunction, EdgeKind::Implements], 1);
assert_eq!(both.len(), 2);
assert!(g.neighbors_by_edge("a", &[], 1).is_empty());
assert!(g
.neighbors_by_edge("a", &[EdgeKind::CallsFunction], 0)
.is_empty());
}
// `calls` entries become CallsFunction edges, and the edge kind is reported.
#[test]
fn test_calls_function_edges_present_in_graph() {
let chunks = vec![
chunk("a:1", "a.rs", Some("alpha"), &["bar"]),
chunk("b:1", "a.rs", Some("bar"), &[]),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let calls = g.neighbors_by_edge("alpha", &[EdgeKind::CallsFunction], 1);
assert_eq!(
calls.len(),
1,
"expected exactly one CallsFunction neighbour, got {calls:?}"
);
assert_eq!(calls[0].0, "bar");
assert!(matches!(calls[0].2, EdgeKind::CallsFunction));
}
// `inherits_from` entries become Implements edges.
#[test]
fn test_inherits_from_emits_implements_edges() {
let chunks = vec![
chunk_full(
"c:1",
"c.rs",
Some("Child"),
&[],
&["Parent"],
ChunkType::Class,
),
chunk_full("p:1", "p.rs", Some("Parent"), &[], &[], ChunkType::Class),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let impls = g.neighbors_by_edge("Child", &[EdgeKind::Implements], 1);
assert_eq!(impls.len(), 1, "expected one Implements edge: {impls:?}");
assert_eq!(impls[0].0, "Parent");
}
// A container (Impl) gets ModuleContains edges to symbols in its own file only.
#[test]
fn test_module_contains_edges_from_container_chunks() {
let chunks = vec![
chunk_full("i:1", "f.rs", Some("FooImpl"), &[], &[], ChunkType::Impl),
chunk_full("m:1", "f.rs", Some("method_a"), &[], &[], ChunkType::Method),
chunk_full("m:2", "f.rs", Some("method_b"), &[], &[], ChunkType::Method),
chunk_full(
"o:1",
"other.rs",
Some("outside"),
&[],
&[],
ChunkType::Function,
),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let contained = g.neighbors_by_edge("FooImpl", &[EdgeKind::ModuleContains], 1);
let names: HashSet<&str> = contained.iter().map(|(n, _, _)| n.as_str()).collect();
assert!(names.contains("method_a"), "got {names:?}");
assert!(names.contains("method_b"), "got {names:?}");
assert!(!names.contains("outside"), "cross-file leak: {names:?}");
}
// Alpha is a Class sharing a file with its neighbours, so it also carries
// ModuleContains edges; the kind filter must keep those out of the results.
#[test]
fn test_neighbors_by_edge_only_returns_filtered_kinds() {
let chunks = vec![
chunk_full(
"a:1",
"a.rs",
Some("Alpha"),
&["beta"],
&["BaseAlpha"],
ChunkType::Class,
),
chunk("b:1", "a.rs", Some("beta"), &[]),
chunk_full(
"ba:1",
"a.rs",
Some("BaseAlpha"),
&[],
&[],
ChunkType::Class,
),
];
let g = SymbolGraph::build_from_chunks(&chunks);
let calls = g.neighbors_by_edge("Alpha", &[EdgeKind::CallsFunction], 1);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].0, "beta");
assert!(calls.iter().all(|(_, _, k)| k == &EdgeKind::CallsFunction));
let impls = g.neighbors_by_edge("Alpha", &[EdgeKind::Implements], 1);
assert!(impls.iter().any(|(n, _, _)| n == "BaseAlpha"));
assert!(impls.iter().all(|(_, _, k)| k == &EdgeKind::Implements));
}
// A symbol that calls itself does not get a self-loop edge.
#[test]
fn test_self_call_does_not_create_self_loop() {
let chunks = vec![chunk("f:1", "f.rs", Some("f"), &["f"])];
let g = SymbolGraph::build_from_chunks(&chunks);
assert_eq!(g.edge_count(), 0);
}
}