use crate::define_index;
use crate::symbol::{FileId, SymbolId};
use crate::SymbolKind;
use serde::{Deserialize, Serialize};
use slotmap::SecondaryMap;
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet};
// Declares the `EdgeId` index type through the crate's `define_index!` macro.
// `CodeGraphV2::add_edge` creates these from an edge's position in the graph's
// edge vector, and `CodeGraphV2::edge` converts them back with `as_usize()`.
define_index! {
pub struct EdgeId;
}
// Declares the `MatchExprId` index type through the crate's `define_index!` macro.
// `CodeGraphV2::add_match_expr` creates these from a record's position in the
// graph's match-expression vector.
define_index! {
pub struct MatchExprId;
}
/// Relationship expressed by a directed edge of [`CodeGraphV2`].
///
/// Every edge runs from `EdgeData::from` to `EdgeData::to`; the variant says
/// how the two symbols are related.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CodeEdgeV2 {
/// `from` is the parent of `to` (see `children_of` / `parent_of`).
Contains,
/// `from` is a caller of `to` (see `callers_of` / `callees_of`).
Calls,
/// `from` is an implementor of `to` (see `implementors_of`).
Implements,
}
/// A single directed, typed edge stored in [`CodeGraphV2`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EdgeData {
/// Symbol the edge starts at.
pub from: SymbolId,
/// Symbol the edge points to.
pub to: SymbolId,
/// Relationship the edge represents.
pub kind: CodeEdgeV2,
}
/// Record describing one `match` expression registered with
/// `CodeGraphV2::add_match_expr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MatchExprDataV2 {
/// NOTE(review): presumably the file containing the expression — confirm with the producer.
pub file_id: FileId,
/// Symbol compared against by `CodeGraphV2::match_exprs_for_enum`.
pub enum_id: SymbolId,
/// NOTE(review): presumably a position within the file; the unit (bytes/chars) is not visible here.
pub offset: u32,
/// NOTE(review): presumably the source line; whether it is 0- or 1-based is not visible here.
pub line: u32,
}
/// Directed multigraph over `SymbolId`s with a few secondary indexes.
///
/// Edges live in one vector and are addressed by `EdgeId`; per-symbol
/// adjacency lists store those ids rather than the edge data itself.
#[derive(Clone, Default, Serialize)]
pub struct CodeGraphV2 {
// Edge storage; an `EdgeId` is the position of its edge in this vector.
// No method in this impl removes entries from it.
edges: Vec<EdgeData>,
// For each symbol, ids of the edges that start at it.
outgoing: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,
// For each symbol, ids of the edges that end at it.
incoming: SecondaryMap<SymbolId, SmallVec<[EdgeId; 4]>>,
// Membership set of the graph's nodes (the unit value carries no data).
nodes: SecondaryMap<SymbolId, ()>,
// Symbols grouped by kind; filled only through `add_to_kind_index`.
by_kind: HashMap<SymbolKind, SmallVec<[SymbolId; 16]>>,
// Symbols registered through `add_crate_root`, without duplicates.
crate_roots: SmallVec<[SymbolId; 4]>,
// For each function symbol, ids of the match expressions recorded for it.
match_expr_index: SecondaryMap<SymbolId, SmallVec<[MatchExprId; 2]>>,
// Match-expression storage; a `MatchExprId` is a position in this vector.
match_exprs: Vec<MatchExprDataV2>,
}
impl CodeGraphV2 {
pub fn new() -> Self {
Self::default()
}
/// Creates an empty graph with storage reserved up front.
///
/// * `nodes` — expected number of symbols. It sizes the node set and both
///   adjacency maps, which `add_edge` populates for every endpoint. A
///   `SecondaryMap` is indexed by key slot, so the useful value is the
///   capacity of the slot map the `SymbolId`s come from.
/// * `edges` — expected number of edges; sizes the edge vector.
///
/// The kind index, crate-root list and match-expression storage start empty
/// because neither hint says anything about their size.
pub fn with_capacity(nodes: usize, edges: usize) -> Self {
    Self {
        edges: Vec::with_capacity(edges),
        outgoing: SecondaryMap::with_capacity(nodes),
        incoming: SecondaryMap::with_capacity(nodes),
        nodes: SecondaryMap::with_capacity(nodes),
        by_kind: HashMap::new(),
        crate_roots: SmallVec::new(),
        match_expr_index: SecondaryMap::new(),
        match_exprs: Vec::new(),
    }
}
/// Registers `id` as a node of the graph.
///
/// Returns `true` when `id` was not a node before the call and `false` when
/// it already was (in which case nothing changes).
pub fn add_node(&mut self, id: SymbolId) -> bool {
    let newly_added = !self.nodes.contains_key(id);
    if newly_added {
        self.nodes.insert(id, ());
    }
    newly_added
}
/// Returns `true` when `id` is currently a node of the graph.
#[inline]
pub fn contains(&self, id: SymbolId) -> bool {
    self.nodes.get(id).is_some()
}
/// Removes `id` from the graph together with every index entry that mentions it.
///
/// Returns `false`, changing nothing, when `id` is not a node of the graph.
///
/// Besides dropping the node's own adjacency lists, the ids of its edges are
/// pruned from the adjacency lists of the *other* endpoint. Without that,
/// queries on a neighbour (`callers_of`, `reference_count`, `has_edge`, the
/// chain traversals, ...) would keep reporting the removed node.
///
/// The `EdgeData` records stay in the edge vector so that previously issued
/// `EdgeId`s keep their positions; `edge_count` and `edge` are therefore
/// unaffected. Match-expression records are likewise kept in storage but are
/// no longer reachable through the per-function index.
pub fn remove_node(&mut self, id: SymbolId) -> bool {
    if self.nodes.remove(id).is_none() {
        return false;
    }
    // Edges leaving `id`: unlink them from each target's incoming list.
    if let Some(edge_ids) = self.outgoing.remove(id) {
        for edge_id in edge_ids.iter().copied() {
            if let Some(edge) = self.edges.get(edge_id.as_usize()) {
                let target = edge.to;
                if let Some(incoming) = self.incoming.get_mut(target) {
                    incoming.retain(|eid| *eid != edge_id);
                }
            }
        }
    }
    // Edges arriving at `id`: unlink them from each source's outgoing list.
    // Self-loops were already handled above and find no list to update here.
    if let Some(edge_ids) = self.incoming.remove(id) {
        for edge_id in edge_ids.iter().copied() {
            if let Some(edge) = self.edges.get(edge_id.as_usize()) {
                let source = edge.from;
                if let Some(outgoing) = self.outgoing.get_mut(source) {
                    outgoing.retain(|eid| *eid != edge_id);
                }
            }
        }
    }
    for symbols in self.by_kind.values_mut() {
        symbols.retain(|s| *s != id);
    }
    self.crate_roots.retain(|s| *s != id);
    self.match_expr_index.remove(id);
    true
}
/// Unlinks every edge that starts at `id`.
///
/// The node's outgoing list is dropped and each of its edge ids is removed
/// from the incoming list of the edge's target. The edge records themselves
/// remain in the edge vector. Does nothing when `id` has no outgoing list.
pub fn clear_outgoing_edges(&mut self, id: SymbolId) {
    let detached = match self.outgoing.remove(id) {
        Some(list) => list,
        None => return,
    };
    for edge_id in detached {
        let target = match self.edges.get(edge_id.as_usize()) {
            Some(edge) => edge.to,
            None => continue,
        };
        if let Some(back_refs) = self.incoming.get_mut(target) {
            back_refs.retain(|candidate| *candidate != edge_id);
        }
    }
}
/// Adds a directed edge `from -> to` of the given `kind` and returns its id.
///
/// Both endpoints are registered as nodes if they are not nodes yet. Edges
/// are not deduplicated: adding the same triple twice stores two edges.
///
/// # Panics
///
/// * if the graph already holds more edges than a `u32` edge index can
///   address (previously the index was silently truncated, which would have
///   aliased an existing edge);
/// * if the adjacency maps reject `from` or `to` (the `expect` calls below).
pub fn add_edge(&mut self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> EdgeId {
    // Check the invariant before mutating anything so a failure leaves the graph intact.
    let raw_index = self.edges.len();
    assert!(
        raw_index <= u32::MAX as usize,
        "CodeGraphV2 edge count exceeds the u32 range of EdgeId"
    );
    self.add_node(from);
    self.add_node(to);
    // Lossless: guarded by the assertion above.
    let edge_id = EdgeId::from_raw(raw_index as u32);
    self.edges.push(EdgeData { from, to, kind });
    self.outgoing
        .entry(from)
        .expect("caller must supply a SymbolId already present in the SlotMap")
        .or_default()
        .push(edge_id);
    self.incoming
        .entry(to)
        .expect("caller must supply a SymbolId already present in the SlotMap")
        .or_default()
        .push(edge_id);
    edge_id
}
/// Looks up the stored data of edge `id`; `None` when the id is out of range.
#[inline]
pub fn edge(&self, id: EdgeId) -> Option<&EdgeData> {
    let position = id.as_usize();
    self.edges.get(position)
}
/// Returns `true` when at least one edge `from -> to` of exactly `kind` is
/// linked in `from`'s outgoing list.
pub fn has_edge(&self, from: SymbolId, to: SymbolId, kind: CodeEdgeV2) -> bool {
    self.outgoing_edges(from)
        .any(|edge| edge.to == to && edge.kind == kind)
}
/// Marks `id` as a crate root, registering it as a node first.
///
/// Calling this again with the same `id` does not create a duplicate entry.
pub fn add_crate_root(&mut self, id: SymbolId) {
    self.add_node(id);
    let already_listed = self.crate_roots.iter().any(|root| *root == id);
    if !already_listed {
        self.crate_roots.push(id);
    }
}
/// Crate-root symbols in the order they were registered.
#[inline]
pub fn crate_roots(&self) -> &[SymbolId] {
    self.crate_roots.as_slice()
}
pub fn add_to_kind_index(&mut self, id: SymbolId, kind: SymbolKind) {
let symbols = self.by_kind.entry(kind).or_default();
if !symbols.contains(&id) {
symbols.push(id);
}
}
/// Iterates the symbols recorded under `kind`, in insertion order.
///
/// Yields nothing when no symbol was ever indexed under `kind`.
pub fn iter_by_kind(&self, kind: SymbolKind) -> impl Iterator<Item = SymbolId> + '_ {
    let indexed: &[SymbolId] = match self.by_kind.get(&kind) {
        Some(ids) => ids.as_slice(),
        None => &[],
    };
    indexed.iter().copied()
}
pub fn outgoing_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
self.outgoing
.get(id)
.into_iter()
.flat_map(|edges| edges.iter())
.filter_map(|&eid| self.edges.get(eid.as_usize()))
}
pub fn incoming_edges(&self, id: SymbolId) -> impl Iterator<Item = &EdgeData> + '_ {
self.incoming
.get(id)
.into_iter()
.flat_map(|edges| edges.iter())
.filter_map(|&eid| self.edges.get(eid.as_usize()))
}
/// Iterates the distinct symbols that have a `Calls` edge into `id`.
///
/// Each caller is yielded once, at the position of its first call edge.
pub fn callers_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
    let mut already_yielded = HashSet::new();
    self.incoming_edges(id).filter_map(move |edge| {
        if edge.kind == CodeEdgeV2::Calls && already_yielded.insert(edge.from) {
            Some(edge.from)
        } else {
            None
        }
    })
}
/// Iterates the distinct symbols that `id` has a `Calls` edge to.
///
/// Each callee is yielded once, at the position of its first call edge.
pub fn callees_of(&self, id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
    let mut already_yielded = HashSet::new();
    self.outgoing_edges(id).filter_map(move |edge| {
        if edge.kind == CodeEdgeV2::Calls && already_yielded.insert(edge.to) {
            Some(edge.to)
        } else {
            None
        }
    })
}
/// Iterates the sources of every `Implements` edge pointing at `trait_id`.
///
/// No deduplication is performed: one item is yielded per edge.
pub fn implementors_of(&self, trait_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
    self.incoming_edges(trait_id).filter_map(|edge| {
        if edge.kind == CodeEdgeV2::Implements {
            Some(edge.from)
        } else {
            None
        }
    })
}
/// Iterates the targets of every `Contains` edge leaving `parent_id`.
///
/// No deduplication is performed: one item is yielded per edge.
pub fn children_of(&self, parent_id: SymbolId) -> impl Iterator<Item = SymbolId> + '_ {
    self.outgoing_edges(parent_id).filter_map(|edge| {
        if edge.kind == CodeEdgeV2::Contains {
            Some(edge.to)
        } else {
            None
        }
    })
}
/// Returns the source of the first `Contains` edge pointing at `id`, if any.
pub fn parent_of(&self, id: SymbolId) -> Option<SymbolId> {
    self.incoming_edges(id).find_map(|edge| {
        if edge.kind == CodeEdgeV2::Contains {
            Some(edge.from)
        } else {
            None
        }
    })
}
/// Counts the `Calls` edges pointing at `id`.
///
/// This counts edges, not distinct callers: a caller with several call
/// edges contributes once per edge.
pub fn reference_count(&self, id: SymbolId) -> usize {
    let mut call_edges = 0;
    for edge in self.incoming_edges(id) {
        if edge.kind == CodeEdgeV2::Calls {
            call_edges += 1;
        }
    }
    call_edges
}
/// Counts the `Implements` edges pointing at `id` (one per edge).
pub fn impl_count(&self, id: SymbolId) -> usize {
    let mut impl_edges = 0;
    for edge in self.incoming_edges(id) {
        if edge.kind == CodeEdgeV2::Implements {
            impl_edges += 1;
        }
    }
    impl_edges
}
/// Stores `data` as a match expression belonging to `function_id` and
/// returns the id of the new record.
///
/// `function_id` is not registered as a node by this call.
///
/// # Panics
///
/// * if the graph already holds more match expressions than a `u32` index
///   can address (previously the index was silently truncated, which would
///   have aliased an existing record);
/// * if the per-function index rejects `function_id` (the `expect` below).
pub fn add_match_expr(&mut self, function_id: SymbolId, data: MatchExprDataV2) -> MatchExprId {
    // Check the invariant before mutating anything.
    let raw_index = self.match_exprs.len();
    assert!(
        raw_index <= u32::MAX as usize,
        "CodeGraphV2 match expression count exceeds the u32 range of MatchExprId"
    );
    // Lossless: guarded by the assertion above.
    let expr_id = MatchExprId::from_raw(raw_index as u32);
    self.match_exprs.push(data);
    self.match_expr_index
        .entry(function_id)
        .expect("caller must supply a function SymbolId already present in the SlotMap")
        .or_default()
        .push(expr_id);
    expr_id
}
pub fn match_exprs_in(
&self,
function_id: SymbolId,
) -> impl Iterator<Item = &MatchExprDataV2> + '_ {
self.match_expr_index
.get(function_id)
.into_iter()
.flat_map(|ids| ids.iter())
.filter_map(|&id| self.match_exprs.get(id.as_usize()))
}
/// Iterates every recorded match expression whose `enum_id` equals the given
/// one, paired with the function symbol it was recorded for.
///
/// This walks the whole per-function index on every call.
pub fn match_exprs_for_enum(
    &self,
    enum_id: SymbolId,
) -> impl Iterator<Item = (SymbolId, &MatchExprDataV2)> + '_ {
    self.match_expr_index
        .iter()
        .flat_map(move |(owner, expr_ids)| {
            expr_ids.iter().filter_map(move |&expr_id| {
                match self.match_exprs.get(expr_id.as_usize()) {
                    Some(data) if data.enum_id == enum_id => Some((owner, data)),
                    _ => None,
                }
            })
        })
}
/// Number of match-expression records held in storage.
///
/// `remove_node` does not shrink the storage vector, so records belonging to
/// removed functions are still counted.
pub fn match_expr_count(&self) -> usize {
self.match_exprs.len()
}
/// Number of symbols currently registered as nodes.
#[inline]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
/// Number of edge records held in storage.
///
/// This is the length of the edge vector, which no method in this impl
/// shrinks; edges unlinked by `clear_outgoing_edges` are still counted.
#[inline]
pub fn edge_count(&self) -> usize {
self.edges.len()
}
/// Returns `true` when the graph has no nodes.
#[inline]
pub fn is_empty(&self) -> bool {
    self.node_count() == 0
}
/// Collects the transitive callers of `start` up to `max_depth` call levels.
///
/// `start` itself is not part of the result; each returned node carries the
/// depth at which it was first reached.
pub fn callers_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
    let direction = ChainDirection::Callers;
    self.traverse_chain(start, max_depth, direction)
}
/// Collects the transitive callees of `start` up to `max_depth` call levels.
///
/// `start` itself is not part of the result; each returned node carries the
/// depth at which it was first reached.
pub fn callees_chain(&self, start: SymbolId, max_depth: usize) -> Vec<ChainNode> {
    let direction = ChainDirection::Callees;
    self.traverse_chain(start, max_depth, direction)
}
/// Breadth-first walk over `Calls` edges starting at `start`.
///
/// Returns every symbol reachable within `max_depth` steps, in visiting
/// order, tagged with the depth at which it was first reached. `start` is
/// never included, and each symbol appears at most once, so cycles terminate.
///
/// # Panics
///
/// Panics when `direction` is `TypeUsers` or `TypeDeps` and at least one node
/// gets expanded (i.e. `max_depth > 0`); those directions are not served by
/// this graph.
fn traverse_chain(
    &self,
    start: SymbolId,
    max_depth: usize,
    direction: ChainDirection,
) -> Vec<ChainNode> {
    use std::collections::{HashSet, VecDeque};
    let mut seen: HashSet<SymbolId> = HashSet::new();
    seen.insert(start);
    let mut frontier: VecDeque<(SymbolId, usize)> = VecDeque::new();
    frontier.push_back((start, 0));
    let mut chain = Vec::new();
    while let Some((symbol, depth)) = frontier.pop_front() {
        // Depth 0 is the start symbol, which is not reported.
        if depth != 0 {
            chain.push(ChainNode { symbol, depth });
        }
        if depth < max_depth {
            let next_hop: Vec<SymbolId> = match direction {
                ChainDirection::Callers => self.callers_of(symbol).collect(),
                ChainDirection::Callees => self.callees_of(symbol).collect(),
                ChainDirection::TypeUsers | ChainDirection::TypeDeps => {
                    unreachable!("TypeUsers/TypeDeps must use TypeFlowGraphV2")
                }
            };
            for neighbor in next_hop {
                // `insert` returns true only for symbols not seen before.
                if seen.insert(neighbor) {
                    frontier.push_back((neighbor, depth + 1));
                }
            }
        }
    }
    chain
}
/// Runs a chain traversal and summarises it.
///
/// The result holds the visited nodes, a per-depth node count and the
/// deepest level actually reached (0 when nothing was reached).
///
/// # Panics
///
/// Panics for `ChainDirection::TypeUsers` / `ChainDirection::TypeDeps` when
/// `max_depth > 0`, because the underlying traversal only follows call edges.
pub fn analyze_chain(
    &self,
    start: SymbolId,
    max_depth: usize,
    direction: ChainDirection,
) -> ChainResult {
    let nodes = self.traverse_chain(start, max_depth, direction);
    let mut by_depth: HashMap<usize, usize> = HashMap::new();
    let mut max_actual_depth: usize = 0;
    for node in nodes.iter() {
        *by_depth.entry(node.depth).or_insert(0) += 1;
        if node.depth > max_actual_depth {
            max_actual_depth = node.depth;
        }
    }
    ChainResult {
        start,
        direction,
        max_depth,
        nodes,
        max_actual_depth,
        by_depth,
    }
}
}
/// Direction of a chain traversal.
///
/// `CodeGraphV2` only handles `Callers` and `Callees`; its traversal hits an
/// `unreachable!` for the two type-related variants, whose message points at
/// `TypeFlowGraphV2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChainDirection {
/// Follow `Calls` edges backwards (who calls the symbol).
Callers,
/// Follow `Calls` edges forwards (what the symbol calls).
Callees,
/// Not handled by `CodeGraphV2`. Displayed as `type_users`.
TypeUsers,
/// Not handled by `CodeGraphV2`. Displayed as `type_deps`.
TypeDeps,
}
/// Writes the lowercase snake_case label of the direction.
impl std::fmt::Display for ChainDirection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ChainDirection::Callers => "callers",
            ChainDirection::Callees => "callees",
            ChainDirection::TypeUsers => "type_users",
            ChainDirection::TypeDeps => "type_deps",
        };
        f.write_str(label)
    }
}
/// One symbol reached by a chain traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChainNode {
/// The symbol that was reached.
pub symbol: SymbolId,
/// Number of steps from the start symbol; at least 1 for nodes produced by
/// `CodeGraphV2`'s traversal, which never reports the start itself.
pub depth: usize,
}
/// Summary of a chain traversal as produced by `CodeGraphV2::analyze_chain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
/// Symbol the traversal started from (not contained in `nodes`).
pub start: SymbolId,
/// Direction that was requested.
pub direction: ChainDirection,
/// Depth limit that was requested.
pub max_depth: usize,
/// Reached symbols in visiting order.
pub nodes: Vec<ChainNode>,
/// Largest depth among `nodes`, or 0 when `nodes` is empty.
pub max_actual_depth: usize,
/// Number of reached symbols per depth level.
pub by_depth: HashMap<usize, usize>,
}
impl ChainResult {
pub fn total_count(&self) -> usize {
self.nodes.len()
}
pub fn at_depth(&self, depth: usize) -> impl Iterator<Item = &ChainNode> {
self.nodes.iter().filter(move |n| n.depth == depth)
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn symbols(&self) -> impl Iterator<Item = SymbolId> + '_ {
self.nodes.iter().map(|n| n.symbol)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::symbol::{SymbolPath, SymbolRegistry};
fn setup() -> (SymbolRegistry, SymbolId, SymbolId, SymbolId) {
let mut registry = SymbolRegistry::new();
let id1 = registry
.register(SymbolPath::parse("foo::Bar").unwrap(), SymbolKind::Struct)
.unwrap();
let id2 = registry
.register(SymbolPath::parse("foo::baz").unwrap(), SymbolKind::Function)
.unwrap();
let id3 = registry
.register(SymbolPath::parse("foo::qux").unwrap(), SymbolKind::Function)
.unwrap();
(registry, id1, id2, id3)
}
// Adding a node reports `true` the first time and `false` on a repeat.
#[test]
fn test_add_node() {
    let (_, node, _, _) = setup();
    let mut graph = CodeGraphV2::new();
    let first_insert = graph.add_node(node);
    assert!(first_insert);
    let second_insert = graph.add_node(node);
    assert!(!second_insert);
    assert!(graph.contains(node));
}
// An edge registers both endpoints and is directional.
#[test]
fn test_add_edge() {
    let (_, parent, child, _) = setup();
    let mut graph = CodeGraphV2::new();
    graph.add_edge(parent, child, CodeEdgeV2::Contains);
    for endpoint in &[parent, child] {
        assert!(graph.contains(*endpoint));
    }
    assert!(graph.has_edge(parent, child, CodeEdgeV2::Contains));
    assert!(!graph.has_edge(child, parent, CodeEdgeV2::Contains));
}
// Two distinct callers of the same symbol are both reported.
#[test]
fn test_callers_of() {
    let (_, first_caller, second_caller, callee) = setup();
    let mut graph = CodeGraphV2::new();
    graph.add_edge(first_caller, callee, CodeEdgeV2::Calls);
    graph.add_edge(second_caller, callee, CodeEdgeV2::Calls);
    let callers = graph.callers_of(callee).collect::<Vec<_>>();
    assert_eq!(callers.len(), 2);
    assert!(callers.contains(&first_caller));
    assert!(callers.contains(&second_caller));
}
// Every `Contains` edge leaving the parent yields one child.
#[test]
fn test_children_of() {
    let (_, parent, first_child, second_child) = setup();
    let mut graph = CodeGraphV2::new();
    graph.add_edge(parent, first_child, CodeEdgeV2::Contains);
    graph.add_edge(parent, second_child, CodeEdgeV2::Contains);
    let child_count = graph.children_of(parent).collect::<Vec<_>>().len();
    assert_eq!(child_count, 2);
}
// The parent is found through the incoming `Contains` edge; a root has none.
#[test]
fn test_parent_of() {
    let (_, parent, child, _) = setup();
    let mut graph = CodeGraphV2::new();
    graph.add_edge(parent, child, CodeEdgeV2::Contains);
    let parent_of_child = graph.parent_of(child);
    assert_eq!(parent_of_child, Some(parent));
    let parent_of_root = graph.parent_of(parent);
    assert_eq!(parent_of_root, None);
}
// Removing one endpoint leaves the other endpoint in the graph.
#[test]
fn test_remove_node() {
    let (_, removed, kept, _) = setup();
    let mut graph = CodeGraphV2::new();
    graph.add_edge(removed, kept, CodeEdgeV2::Calls);
    assert_eq!(graph.node_count(), 2);
    let was_present = graph.remove_node(removed);
    assert!(was_present);
    assert_eq!(graph.node_count(), 1);
    assert!(!graph.contains(removed));
    assert!(graph.contains(kept));
}
// Symbols are retrievable by the kind they were indexed under.
#[test]
fn test_kind_index() {
    let (_, strukt, first_fn, second_fn) = setup();
    let mut graph = CodeGraphV2::new();
    let indexed = [
        (strukt, SymbolKind::Struct),
        (first_fn, SymbolKind::Function),
        (second_fn, SymbolKind::Function),
    ];
    // Register all nodes first, then index them, in the same order as before.
    for &(symbol, _) in indexed.iter() {
        graph.add_node(symbol);
    }
    for &(symbol, kind) in indexed.iter() {
        graph.add_to_kind_index(symbol, kind);
    }
    let struct_count = graph.iter_by_kind(SymbolKind::Struct).collect::<Vec<_>>().len();
    assert_eq!(struct_count, 1);
    let function_count = graph.iter_by_kind(SymbolKind::Function).collect::<Vec<_>>().len();
    assert_eq!(function_count, 2);
}
/// Builds a registry holding five functions `test::fn_a` .. `test::fn_e` and
/// returns it with the five ids in alphabetical order.
fn setup_chain() -> (
    SymbolRegistry,
    SymbolId,
    SymbolId,
    SymbolId,
    SymbolId,
    SymbolId,
) {
    let mut registry = SymbolRegistry::new();
    let mut register_fn = |path: &'static str| {
        registry
            .register(SymbolPath::parse(path).unwrap(), SymbolKind::Function)
            .unwrap()
    };
    let a = register_fn("test::fn_a");
    let b = register_fn("test::fn_b");
    let c = register_fn("test::fn_c");
    let d = register_fn("test::fn_d");
    let e = register_fn("test::fn_e");
    (registry, a, b, c, d, e)
}
// In a -> b -> c -> d, the callers of d are c, b, a at depths 1, 2, 3.
#[test]
fn test_callers_chain_simple() {
    let (_, a, b, c, d, _) = setup_chain();
    let mut graph = CodeGraphV2::new();
    for &(caller, callee) in [(a, b), (b, c), (c, d)].iter() {
        graph.add_edge(caller, callee, CodeEdgeV2::Calls);
    }
    let chain = graph.callers_chain(d, 10);
    assert_eq!(chain.len(), 3);
    for &(symbol, expected_depth) in [(c, 1usize), (b, 2), (a, 3)].iter() {
        let node = chain.iter().find(|n| n.symbol == symbol).unwrap();
        assert_eq!(node.depth, expected_depth);
    }
}
// In a -> b -> c -> d, the callees of a are b, c, d at depths 1, 2, 3.
#[test]
fn test_callees_chain_simple() {
    let (_, a, b, c, d, _) = setup_chain();
    let mut graph = CodeGraphV2::new();
    for &(caller, callee) in [(a, b), (b, c), (c, d)].iter() {
        graph.add_edge(caller, callee, CodeEdgeV2::Calls);
    }
    let chain = graph.callees_chain(a, 10);
    assert_eq!(chain.len(), 3);
    for &(symbol, expected_depth) in [(b, 1usize), (c, 2), (d, 3)].iter() {
        let node = chain.iter().find(|n| n.symbol == symbol).unwrap();
        assert_eq!(node.depth, expected_depth);
    }
}
// A depth limit of 2 stops the walk before the third callee.
#[test]
fn test_chain_with_max_depth() {
    let (_, a, b, c, d, _) = setup_chain();
    let mut graph = CodeGraphV2::new();
    for &(caller, callee) in [(a, b), (b, c), (c, d)].iter() {
        graph.add_edge(caller, callee, CodeEdgeV2::Calls);
    }
    let chain = graph.callees_chain(a, 2);
    assert_eq!(chain.len(), 2);
    let reached = chain.iter().map(|n| n.symbol).collect::<Vec<_>>();
    assert!(reached.contains(&b));
    assert!(reached.contains(&c));
    assert!(!reached.contains(&d));
}
// A call cycle a -> b -> c -> a terminates and never reports the start.
#[test]
fn test_chain_with_cycle() {
    let (_, a, b, c, _, _) = setup_chain();
    let mut graph = CodeGraphV2::new();
    for &(caller, callee) in [(a, b), (b, c), (c, a)].iter() {
        graph.add_edge(caller, callee, CodeEdgeV2::Calls);
    }
    let chain = graph.callees_chain(a, 10);
    assert_eq!(chain.len(), 2);
}
// A linear chain of four calls yields one node at each of depths 1..=4.
#[test]
fn test_analyze_chain() {
    let (_, a, b, c, d, e) = setup_chain();
    let mut graph = CodeGraphV2::new();
    for &(caller, callee) in [(a, b), (b, c), (c, d), (d, e)].iter() {
        graph.add_edge(caller, callee, CodeEdgeV2::Calls);
    }
    let result = graph.analyze_chain(a, 10, ChainDirection::Callees);
    assert_eq!(result.start, a);
    assert_eq!(result.direction, ChainDirection::Callees);
    assert_eq!(result.total_count(), 4);
    assert_eq!(result.max_actual_depth, 4);
    for depth in 1usize..=4 {
        assert_eq!(*result.by_depth.get(&depth).unwrap_or(&0), 1);
    }
}
}