use petgraph::stable_graph::StableGraph;
use petgraph::visit::EdgeRef;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
/// Handle to a vertex of [`ProgramDependenceGraph`]. The graph is a `StableGraph`, so a
/// handle stays valid when *other* nodes are removed.
pub type NodeId = petgraph::stable_graph::NodeIndex;
/// Handle to an edge of [`ProgramDependenceGraph`].
pub type EdgeId = petgraph::stable_graph::EdgeIndex;
/// A code entity (function, class, method, variable, module or external symbol) stored as a
/// vertex of [`ProgramDependenceGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Symbol identifier; key of `symbol_index` and of the embedding store. The tests use ids
    /// such as `"a.rs:foo"` or `"f:MyClass::foo"` — NOTE(review): confirm the canonical format
    /// with whatever produces nodes.
    pub id: String,
    /// Kind of entity this node represents.
    pub node_type: NodeType,
    /// Name used by the exact-name and lowercase-name indexes.
    pub name: String,
    /// Path of the defining file; key of `file_index`.
    pub file_path: Arc<str>,
    /// Byte offsets `(start, end)` of the entity in its source file — presumably end-exclusive;
    /// TODO confirm.
    pub byte_range: (usize, usize),
    /// Complexity score compared against `TraversalConfig::min_complexity`; the metric itself is
    /// computed outside this module.
    pub complexity: u32,
    /// Source language name (the tests use `"rust"`).
    pub language: String,
}
/// Kind of code entity a [`Node`] represents.
///
/// Variant order is part of the serialized (bincode) encoding; append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeType {
    Function,
    Class,
    Method,
    Variable,
    /// Excluded from results by `TraversalConfig::for_llm_context`.
    Module,
    /// Presumably a symbol defined outside the analyzed sources — TODO confirm with the producer.
    External,
}
/// Relationship carried by an [`Edge`]. Directions below are those used by the
/// `ProgramDependenceGraph::add_*_edges` helpers.
///
/// Variant order is part of the serialized (bincode) encoding; append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Calling node -> called node.
    Call,
    /// Data dependency between two nodes; carries a variable name and a confidence.
    DataDependency,
    /// Child -> parent.
    Inheritance,
    /// Importer -> imported.
    Import,
    /// Container -> contained (e.g. class -> method in the tests).
    Containment,
}
/// Payload stored on each graph edge: its relationship type plus optional annotations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    pub edge_type: EdgeType,
    pub metadata: EdgeMetadata,
}
/// Optional annotations attached to an [`Edge`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeMetadata {
    /// Presumably the number of call sites represented by the edge; no constructor in this
    /// module sets it.
    pub call_count: Option<usize>,
    /// Variable carried by a data-dependency edge (set by `add_data_flow_edges`).
    pub variable_name: Option<String>,
    /// Confidence score. Traversal skips edges scoring below
    /// `TraversalConfig::min_edge_confidence`; edges without a score are always followed.
    pub confidence: Option<f32>,
}
impl EdgeMetadata {
pub fn empty() -> Self {
Self {
call_count: None,
variable_name: None,
confidence: None,
}
}
pub fn with_confidence(confidence: f32) -> Self {
Self {
call_count: None,
variable_name: None,
confidence: Some(confidence),
}
}
pub fn with_variable(name: String) -> Self {
Self {
call_count: None,
variable_name: Some(name),
confidence: None,
}
}
}
/// Limits and filters applied by the graph traversals (`forward_impact`, `backward_impact`,
/// `bidirectional_impact`).
#[derive(Debug, Clone)]
pub struct TraversalConfig {
    /// Maximum hop distance from the start node; nodes at exactly this depth are reported but
    /// not expanded. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Stop once this many nodes have been reported. `None` means unlimited.
    pub max_nodes: Option<usize>,
    /// Whitelist of edge types to follow; `None` follows every type.
    pub allowed_edge_types: Option<&'static [EdgeType]>,
    /// Nodes of these types are traversed through but never reported.
    pub excluded_node_types: Option<Vec<NodeType>>,
    /// Nodes with lower `complexity` are traversed through but never reported.
    pub min_complexity: Option<u32>,
    /// Edges whose confidence is below this are not followed; edges with no confidence always are.
    pub min_edge_confidence: f32,
}
impl TraversalConfig {
pub fn for_llm_context() -> Self {
Self {
max_depth: Some(3),
max_nodes: Some(50),
allowed_edge_types: Some(&[EdgeType::Call, EdgeType::DataDependency]),
excluded_node_types: Some(vec![NodeType::Module]),
min_complexity: None,
min_edge_confidence: 0.5,
}
}
pub fn for_semantic_analysis() -> Self {
Self {
max_depth: Some(5),
max_nodes: Some(150),
allowed_edge_types: Some(&[
EdgeType::Call,
EdgeType::DataDependency,
EdgeType::Inheritance,
]),
excluded_node_types: None,
min_complexity: None,
min_edge_confidence: 0.4,
}
}
pub fn for_impact_analysis() -> Self {
Self {
max_depth: None,
max_nodes: Some(500),
allowed_edge_types: Some(&[
EdgeType::Call,
EdgeType::DataDependency,
EdgeType::Inheritance,
]),
excluded_node_types: None,
min_complexity: None,
min_edge_confidence: 0.0,
}
}
pub fn for_import_graph() -> Self {
Self {
max_depth: Some(10),
max_nodes: Some(1000),
allowed_edge_types: Some(&[EdgeType::Import]),
excluded_node_types: None,
min_complexity: None,
min_edge_confidence: 0.0,
}
}
fn edge_allowed(&self, edge: &Edge) -> bool {
let type_ok = self
.allowed_edge_types
.as_ref()
.map(|types| types.contains(&edge.edge_type))
.unwrap_or(true);
let confidence_ok = edge
.metadata
.confidence
.map(|c| c >= self.min_edge_confidence)
.unwrap_or(true);
type_ok && confidence_ok
}
fn node_should_collect(&self, node: &Node) -> bool {
let type_ok = self
.excluded_node_types
.as_ref()
.map(|excluded| !excluded.contains(&node.node_type))
.unwrap_or(true);
let complexity_ok = self
.min_complexity
.map(|min| node.complexity >= min)
.unwrap_or(true);
type_ok && complexity_ok
}
}
/// In-memory map from a node id string to its embedding vector.
///
/// Keys are whatever string the caller supplies; `ProgramDependenceGraph` uses `Node::id`.
#[derive(Debug, Default, Clone)]
pub struct EmbeddingStore {
    pub(crate) embeddings: HashMap<String, Vec<f32>>,
}

impl EmbeddingStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        EmbeddingStore {
            embeddings: HashMap::new(),
        }
    }

    /// Stores `embedding` under `node_id`, replacing any previous vector for that id.
    pub fn insert(&mut self, node_id: &str, embedding: Vec<f32>) {
        let key = String::from(node_id);
        self.embeddings.insert(key, embedding);
    }

    /// Returns the embedding stored under `node_id`, if any.
    pub fn get(&self, node_id: &str) -> Option<&Vec<f32>> {
        self.embeddings.get(node_id)
    }

    /// Drops the embedding stored under `node_id`; a missing id is a no-op.
    pub fn remove(&mut self, node_id: &str) {
        let _ = self.embeddings.remove(node_id);
    }

    /// Number of stored embeddings.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// `true` when no embeddings are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A node together with its raw index in the graph it was taken from; the index lets
/// `SerializablePDG::to_pdg` remap edges and lookup indexes onto freshly allocated ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableNode {
    index: u32,
    node: Node,
}
/// An edge whose endpoints are stored as raw node indexes (matching `SerializableNode::index`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableEdge {
    source: u32,
    target: u32,
    edge: Edge,
}
/// Serialized form of [`ProgramDependenceGraph`] as written by `ProgramDependenceGraph::serialize`.
///
/// All `NodeId`s are stored as raw `u32` indexes. `name_file_index` is not stored; it is
/// rebuilt on load.
///
/// NOTE(review): `#[serde(default)]` only takes effect for formats that can report a missing
/// field. bincode 1.x encodes structs positionally, so a stream written before a field was
/// added most likely fails with an unexpected-end-of-input error instead of defaulting.
/// Verify this before relying on it for backward compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializablePDG {
    nodes: Vec<SerializableNode>,
    edges: Vec<SerializableEdge>,
    /// `Node::id` -> raw node index.
    symbol_index: HashMap<String, u32>,
    /// File path -> raw node indexes.
    file_index: HashMap<String, Vec<u32>>,
    /// Exact name -> raw node indexes.
    #[serde(default)]
    name_index: HashMap<String, Vec<u32>>,
    /// Lowercased name -> raw node indexes.
    #[serde(default)]
    name_lower_index: HashMap<String, Vec<u32>>,
    /// Embeddings keyed by `Node::id`.
    #[serde(default)]
    embeddings: HashMap<String, Vec<f32>>,
}
impl SerializablePDG {
    /// Snapshots `pdg` into its serializable form.
    ///
    /// Node ids are written as raw `u32` indexes, and each `SerializableNode` records its own
    /// original index. `to_pdg` can therefore remap everything onto freshly allocated ids, even
    /// when earlier removals left holes in the index space. `name_file_index` is not stored.
    fn from_pdg(pdg: &ProgramDependenceGraph) -> Self {
        let nodes = pdg
            .graph
            .node_indices()
            .map(|idx| SerializableNode {
                index: Self::raw_index(idx),
                node: pdg.graph[idx].clone(),
            })
            .collect();
        let edges = pdg
            .graph
            .edge_indices()
            .map(|eidx| {
                // `edge_indices` only yields live edges, whose endpoints always exist.
                let (source, target) = pdg
                    .graph
                    .edge_endpoints(eidx)
                    .expect("Edge endpoints must exist");
                SerializableEdge {
                    source: Self::raw_index(source),
                    target: Self::raw_index(target),
                    edge: pdg.graph[eidx].clone(),
                }
            })
            .collect();
        let symbol_index = pdg
            .symbol_index
            .iter()
            .map(|(symbol, &id)| (symbol.clone(), Self::raw_index(id)))
            .collect();
        Self {
            nodes,
            edges,
            symbol_index,
            file_index: Self::flatten_multi_index(&pdg.file_index),
            name_index: Self::flatten_multi_index(&pdg.name_index),
            name_lower_index: Self::flatten_multi_index(&pdg.name_lower_index),
            embeddings: pdg.embedding_store.embeddings.clone(),
        }
    }

    /// Raw `u32` form of a node index. `NodeIndex` uses petgraph's default `u32` index type,
    /// so the cast is lossless.
    fn raw_index(id: NodeId) -> u32 {
        id.index() as u32
    }

    /// Converts a `key -> [NodeId]` index into its raw `key -> [u32]` form.
    fn flatten_multi_index(index: &HashMap<String, Vec<NodeId>>) -> HashMap<String, Vec<u32>> {
        index
            .iter()
            .map(|(key, ids)| {
                (
                    key.clone(),
                    ids.iter().map(|&id| Self::raw_index(id)).collect(),
                )
            })
            .collect()
    }

    /// Maps a raw `key -> [u32]` index onto new `NodeId`s via `index_map`. Unknown raw indexes
    /// are dropped, and so is any key left with no nodes.
    fn remap_multi_index(
        raw: &HashMap<String, Vec<u32>>,
        index_map: &HashMap<u32, NodeId>,
    ) -> HashMap<String, Vec<NodeId>> {
        raw.iter()
            .filter_map(|(key, old_idxs)| {
                let ids: Vec<NodeId> = old_idxs
                    .iter()
                    .filter_map(|i| index_map.get(i).copied())
                    .collect();
                (!ids.is_empty()).then(|| (key.clone(), ids))
            })
            .collect()
    }

    /// Rebuilds a live graph from this snapshot.
    ///
    /// Index entries that refer to unknown node indexes are silently dropped. The name indexes
    /// are rebuilt from the nodes when the snapshot carries them empty, and `name_file_index`
    /// is always rebuilt.
    ///
    /// # Errors
    /// Returns `Err("Missing source N")` / `Err("Missing target N")` when an edge refers to a
    /// raw node index that is not present in `nodes`.
    fn to_pdg(&self) -> Result<ProgramDependenceGraph, String> {
        let mut pdg = ProgramDependenceGraph::new();
        let mut index_map: HashMap<u32, NodeId> = HashMap::with_capacity(self.nodes.len());
        for sn in &self.nodes {
            let id = pdg.graph.add_node(sn.node.clone());
            index_map.insert(sn.index, id);
        }

        pdg.symbol_index = self
            .symbol_index
            .iter()
            .filter_map(|(sym, old_idx)| index_map.get(old_idx).map(|&nid| (sym.clone(), nid)))
            .collect();
        pdg.file_index = Self::remap_multi_index(&self.file_index, &index_map);
        pdg.name_index = Self::remap_multi_index(&self.name_index, &index_map);
        pdg.name_lower_index = Self::remap_multi_index(&self.name_lower_index, &index_map);

        // Snapshots taken before the name indexes existed carry them empty. Rebuild each one
        // independently: keying both rebuilds off `name_index` alone would leave a missing
        // lowercase index empty (breaking case-insensitive lookup), or duplicate its entries.
        let rebuild_exact = pdg.name_index.is_empty();
        let rebuild_lower = pdg.name_lower_index.is_empty();
        if rebuild_exact || rebuild_lower {
            for nid in pdg.graph.node_indices() {
                let name = &pdg.graph[nid].name;
                if rebuild_exact {
                    pdg.name_index.entry(name.clone()).or_default().push(nid);
                }
                if rebuild_lower {
                    pdg.name_lower_index
                        .entry(name.to_lowercase())
                        .or_default()
                        .push(nid);
                }
            }
        }

        for se in &self.edges {
            let src = *index_map
                .get(&se.source)
                .ok_or_else(|| format!("Missing source {}", se.source))?;
            let tgt = *index_map
                .get(&se.target)
                .ok_or_else(|| format!("Missing target {}", se.target))?;
            pdg.graph.add_edge(src, tgt, se.edge.clone());
        }

        for (node_id, embedding) in &self.embeddings {
            pdg.embedding_store.insert(node_id, embedding.clone());
        }

        // `name_file_index` is never serialized; derive it from the restored nodes. A later
        // node overwrites an earlier one with the same (name, file), as `add_node` does.
        for nid in pdg.graph.node_indices() {
            let node = &pdg.graph[nid];
            pdg.name_file_index
                .insert((node.name.clone(), node.file_path.to_string()), nid);
        }
        Ok(pdg)
    }
}
/// Program dependence graph: code entities ([`Node`]) connected by typed edges ([`Edge`]),
/// plus secondary lookup indexes maintained by `add_node` / `remove_node`.
pub struct ProgramDependenceGraph {
    /// Underlying storage. `StableGraph` keeps existing `NodeId`s valid across removals.
    pub(crate) graph: StableGraph<Node, Edge>,
    /// `Node::id` -> node. A later `add_node` with the same id overwrites the entry.
    pub(crate) symbol_index: HashMap<String, NodeId>,
    /// File path -> nodes from that file, in insertion order.
    pub(crate) file_index: HashMap<String, Vec<NodeId>>,
    /// Exact `Node::name` -> nodes, in insertion order.
    pub(crate) name_index: HashMap<String, Vec<NodeId>>,
    /// Lowercased `Node::name` -> nodes; backs case-insensitive lookup in `find_by_name_in_file`.
    pub(crate) name_lower_index: HashMap<String, Vec<NodeId>>,
    /// Embedding vectors keyed by `Node::id`.
    pub embedding_store: EmbeddingStore,
    /// `(name, file path)` -> most recently added matching node; fast path for
    /// `find_by_name_in_file`. Not serialized; rebuilt on deserialization.
    name_file_index: HashMap<(String, String), NodeId>,
    /// Reusable neighbor buffer for `bfs_directed`. The `Mutex` lets traversal mutate it
    /// through `&self`. `Clone` gives each copy a fresh, empty buffer.
    bfs_scratch: Mutex<Vec<NodeId>>,
}
impl Clone for ProgramDependenceGraph {
    /// Deep-copies the graph and every index. The traversal scratch buffer is not shared; the
    /// copy starts with a fresh empty one (a `Mutex` cannot be cloned).
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without handling it here fails to compile.
        let Self {
            graph,
            symbol_index,
            file_index,
            name_index,
            name_lower_index,
            embedding_store,
            name_file_index,
            bfs_scratch: _,
        } = self;
        Self {
            graph: graph.clone(),
            symbol_index: symbol_index.clone(),
            file_index: file_index.clone(),
            name_index: name_index.clone(),
            name_lower_index: name_lower_index.clone(),
            embedding_store: embedding_store.clone(),
            name_file_index: name_file_index.clone(),
            bfs_scratch: Mutex::new(Vec::new()),
        }
    }
}
impl ProgramDependenceGraph {
/// Creates an empty graph: no nodes, no edges, empty indexes and embedding store.
pub fn new() -> Self {
    Self {
        graph: StableGraph::default(),
        symbol_index: HashMap::default(),
        file_index: HashMap::default(),
        name_index: HashMap::default(),
        name_lower_index: HashMap::default(),
        embedding_store: EmbeddingStore::default(),
        name_file_index: HashMap::default(),
        bfs_scratch: Mutex::default(),
    }
}
/// Inserts `node` and registers it in every lookup index, returning its new id.
///
/// If another node already uses the same `Node::id`, `symbol_index` is repointed at the new
/// node. Likewise, the `(name, file)` fast-path entry always refers to the most recently
/// added match. The file and name indexes accumulate every node.
pub fn add_node(&mut self, node: Node) -> NodeId {
    // Derive every index key up front so the node can be moved into the graph instead of being
    // cloned wholesale (which also copied `language` and the node type for nothing).
    let symbol = node.id.clone();
    let file_key = node.file_path.to_string();
    let exact_name = node.name.clone();
    let lower_name = node.name.to_lowercase();

    let id = self.graph.add_node(node);

    self.symbol_index.insert(symbol, id);
    self.file_index.entry(file_key.clone()).or_default().push(id);
    self.name_index
        .entry(exact_name.clone())
        .or_default()
        .push(id);
    self.name_lower_index.entry(lower_name).or_default().push(id);
    self.name_file_index.insert((exact_name, file_key), id);
    id
}
/// Adds a directed edge `from -> to` carrying `edge` and returns its id.
///
/// In debug builds, panics if either endpoint is not a live node.
pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) -> EdgeId {
    debug_assert!(
        [from, to].iter().all(|&n| self.graph.contains_node(n)),
        "add_edge called with invalid NodeId(s): from={from:?} to={to:?}"
    );
    self.graph.add_edge(from, to, edge)
}
pub fn remove_node(&mut self, node_id: NodeId) -> Option<Node> {
if let Some(node) = self.graph.remove_node(node_id) {
self.symbol_index.remove(&node.id);
self.embedding_store.remove(&node.id);
if let Some(v) = self.file_index.get_mut(&*node.file_path) {
v.retain(|&id| id != node_id);
}
if let Some(v) = self.name_index.get_mut(&node.name) {
v.retain(|&id| id != node_id);
}
if let Some(v) = self.name_lower_index.get_mut(&node.name.to_lowercase()) {
v.retain(|&id| id != node_id);
}
self.name_file_index
.remove(&(node.name.clone(), node.file_path.to_string()));
Some(node)
} else {
None
}
}
/// Removes the edge `id`, returning its payload if it existed.
pub fn remove_edge(&mut self, id: EdgeId) -> Option<Edge> {
    self.graph.remove_edge(id)
}

/// Removes every node indexed under `file_path` (each through `remove_node`), then drops the
/// file's `file_index` entry.
pub fn remove_file(&mut self, file_path: &str) {
    let doomed = self.nodes_in_file(file_path);
    doomed.into_iter().for_each(|id| {
        let _ = self.remove_node(id);
    });
    self.file_index.remove(file_path);
}
/// Returns the node stored under `id`, or `None` if no live node has that id.
pub fn get_node(&self, id: NodeId) -> Option<&Node> {
    self.graph.node_weight(id)
}

/// Mutable access to a node.
///
/// The lookup indexes are only maintained by `add_node` / `remove_node`. Changing `id`,
/// `name` or `file_path` through this reference leaves them stale.
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
    self.graph.node_weight_mut(id)
}

/// Iterates mutably over every node payload. Changing `id`, `name` or `file_path` here
/// leaves the lookup indexes stale.
pub fn node_weights_mut(&mut self) -> impl Iterator<Item = &mut Node> {
    self.graph.node_weights_mut()
}

/// Returns the payload of edge `id`, if it exists.
pub fn get_edge(&self, id: EdgeId) -> Option<&Edge> {
    self.graph.edge_weight(id)
}

/// Number of live nodes.
pub fn node_count(&self) -> usize {
    self.graph.node_count()
}

/// Number of live edges.
pub fn edge_count(&self) -> usize {
    self.graph.edge_count()
}

/// Number of files that currently have at least one indexed node.
pub fn file_count(&self) -> usize {
    self.file_index
        .iter()
        .filter(|(_, ids)| !ids.is_empty())
        .count()
}

/// Iterates over the ids of all live nodes.
pub fn node_indices(&self) -> impl Iterator<Item = NodeId> + '_ {
    self.graph.node_indices()
}

/// Iterates over the ids of all live edges.
pub fn edge_indices(&self) -> impl Iterator<Item = EdgeId> + '_ {
    self.graph.edge_indices()
}

/// Returns `(source, target)` of edge `edge_id`, if it exists.
pub fn edge_endpoints(&self, edge_id: EdgeId) -> Option<(NodeId, NodeId)> {
    self.graph.edge_endpoints(edge_id)
}
pub fn neighbors(&self, node_id: NodeId) -> Vec<NodeId> {
self.graph.neighbors(node_id).collect()
}
pub fn predecessors(&self, node_id: NodeId) -> Vec<NodeId> {
use petgraph::Direction;
self.graph
.neighbors_directed(node_id, Direction::Incoming)
.collect()
}
pub fn predecessor_count(&self, node_id: NodeId) -> usize {
use petgraph::Direction;
self.graph
.neighbors_directed(node_id, Direction::Incoming)
.count()
}
/// Looks up a node by its `Node::id` (alias of `find_by_id`).
pub fn find_by_symbol(&self, symbol: &str) -> Option<NodeId> {
    self.find_by_id(symbol)
}

/// Looks up a node by its `Node::id`.
pub fn find_by_id(&self, node_id: &str) -> Option<NodeId> {
    self.symbol_index.get(node_id).copied()
}

/// Ids of the nodes indexed under `file_path` (empty if none).
pub fn nodes_in_file(&self, file_path: &str) -> Vec<NodeId> {
    match self.file_index.get(file_path) {
        Some(ids) => ids.clone(),
        None => Vec::new(),
    }
}

/// The first-added node with exactly this name, if any.
pub fn find_by_name(&self, name: &str) -> Option<NodeId> {
    self.name_index.get(name)?.first().copied()
}

/// Every node with exactly this name, in insertion order.
pub fn find_all_by_name(&self, name: &str) -> Vec<NodeId> {
    self.name_index
        .get(name)
        .map(|ids| ids.to_vec())
        .unwrap_or_default()
}
/// Resolves `name` to a node, preferring one declared in `file_hint`.
///
/// Tries, in order:
/// 1. the `(name, file_hint)` fast-path index;
/// 2. exact-name matches: one in `file_hint` if any, else the first one added;
/// 3. the same over case-insensitive (lowercased-name) matches;
/// 4. a scan for the first node whose lowercased name or id *contains* the lowercased `name`,
///    restricted to `file_hint`'s nodes when a hint is given.
///
/// Returns `None` when nothing matches. An empty `name` skips step 4: every string contains
/// the empty string, so the scan would return an arbitrary node.
pub fn find_by_name_in_file(&self, name: &str, file_hint: Option<&str>) -> Option<NodeId> {
    if let Some(fp) = file_hint {
        if let Some(&nid) = self
            .name_file_index
            .get(&(name.to_string(), fp.to_string()))
        {
            return Some(nid);
        }
    }

    // Borrow the candidate buckets rather than cloning them on every lookup.
    let exact = self.name_index.get(name).map(Vec::as_slice);
    if let Some(nid) = self.pick_candidate(exact, file_hint) {
        return Some(nid);
    }

    let name_lower = name.to_lowercase();
    let folded = self.name_lower_index.get(&name_lower).map(Vec::as_slice);
    if let Some(nid) = self.pick_candidate(folded, file_hint) {
        return Some(nid);
    }

    if name_lower.is_empty() {
        return None;
    }
    let matches = |nid: &NodeId| {
        self.graph.node_weight(*nid).map_or(false, |node| {
            node.name.to_lowercase().contains(&name_lower)
                || node.id.to_lowercase().contains(&name_lower)
        })
    };
    match file_hint {
        Some(fp) => self
            .file_index
            .get(fp)
            .and_then(|ids| ids.iter().copied().find(|nid| matches(nid))),
        None => self.graph.node_indices().find(|nid| matches(nid)),
    }
}

/// Picks from a bucket of same-named candidates: the first one declared in `file_hint` if
/// there is one, otherwise the first in the bucket. Returns `None` for a missing or empty
/// bucket.
fn pick_candidate(&self, bucket: Option<&[NodeId]>, file_hint: Option<&str>) -> Option<NodeId> {
    let candidates = bucket?;
    let first = *candidates.first()?;
    let in_hinted_file = file_hint.and_then(|fp| {
        candidates
            .iter()
            .copied()
            .find(|&nid| self.get_node(nid).map_or(false, |n| &*n.file_path == fp))
    });
    Some(in_hinted_file.unwrap_or(first))
}
pub fn add_call_edges(&mut self, calls: Vec<(NodeId, NodeId)>) {
for (from, to) in calls {
self.add_edge(
from,
to,
Edge {
edge_type: EdgeType::Call,
metadata: EdgeMetadata::empty(),
},
);
}
}
pub fn add_data_flow_edges(&mut self, flows: Vec<(NodeId, NodeId, String, f32)>) {
for (from, to, var_name, confidence) in flows {
self.add_edge(
from,
to,
Edge {
edge_type: EdgeType::DataDependency,
metadata: EdgeMetadata {
call_count: None,
variable_name: Some(var_name),
confidence: Some(confidence),
},
},
);
}
}
pub fn add_inheritance_edges(&mut self, edges: Vec<(NodeId, NodeId, f32)>) {
for (child, parent, confidence) in edges {
self.add_edge(
child,
parent,
Edge {
edge_type: EdgeType::Inheritance,
metadata: EdgeMetadata::with_confidence(confidence),
},
);
}
}
pub fn add_containment_edges(&mut self, edges: Vec<(NodeId, NodeId)>) {
for (container, contained) in edges {
self.add_edge(
container,
contained,
Edge {
edge_type: EdgeType::Containment,
metadata: EdgeMetadata::empty(),
},
);
}
}
pub fn add_import_edges(&mut self, imports: Vec<(NodeId, NodeId)>) {
for (importer, imported) in imports {
self.add_edge(
importer,
imported,
Edge {
edge_type: EdgeType::Import,
metadata: EdgeMetadata::empty(),
},
);
}
}
/// Stores (or replaces) the embedding for `node_id`. Does not check that such a node exists.
pub fn set_embedding(&mut self, node_id: &str, embedding: Vec<f32>) {
    let store = &mut self.embedding_store;
    store.insert(node_id, embedding);
}

/// Returns the embedding stored for `node_id`, if any.
pub fn get_embedding(&self, node_id: &str) -> Option<&Vec<f32>> {
    let store = &self.embedding_store;
    store.get(node_id)
}

/// Number of stored embeddings.
pub fn embedding_count(&self) -> usize {
    let store = &self.embedding_store;
    store.len()
}
/// Nodes reachable from `start` by following allowed edges forwards, in BFS order.
pub fn forward_impact(&self, start: NodeId, config: &TraversalConfig) -> Vec<NodeId> {
    self.bfs_directed(start, config, Direction::Forward)
}

/// Nodes that reach `start` through allowed edges (followed backwards), in BFS order.
pub fn backward_impact(&self, start: NodeId, config: &TraversalConfig) -> Vec<NodeId> {
    self.bfs_directed(start, config, Direction::Backward)
}

/// Union of forward and backward impact, excluding `start`.
///
/// The two BFS results are interleaved (forward[0], backward[0], forward[1], ...) with
/// duplicates removed, so the order is deterministic and roughly nearest-first in both
/// directions. The combined list is capped at `config.max_nodes`.
pub fn bidirectional_impact(&self, start: NodeId, config: &TraversalConfig) -> Vec<NodeId> {
    let forward = self.bfs_directed(start, config, Direction::Forward);
    let backward = self.bfs_directed(start, config, Direction::Backward);

    let mut seen: HashSet<NodeId> = HashSet::with_capacity(forward.len() + backward.len() + 1);
    seen.insert(start);
    let mut combined = Vec::with_capacity(forward.len() + backward.len());
    let (mut fwd, mut bwd) = (forward.into_iter(), backward.into_iter());
    loop {
        let (f, b) = (fwd.next(), bwd.next());
        if f.is_none() && b.is_none() {
            break;
        }
        for nid in f.into_iter().chain(b) {
            if seen.insert(nid) {
                combined.push(nid);
            }
        }
    }
    // Each direction honors `max_nodes` separately; without this, the union could hold up
    // to twice the configured cap.
    if let Some(max_n) = config.max_nodes {
        combined.truncate(max_n);
    }
    combined
}
/// Breadth-first traversal from `start` in direction `dir`, honoring `config`.
///
/// Only edges accepted by `TraversalConfig::edge_allowed` are followed. Every reached node
/// except `start` is expanded, but only nodes passing `node_should_collect` are reported.
/// Nodes at depth `max_depth` are reported but not expanded. The traversal stops once
/// `max_nodes` nodes have been reported. Results are in BFS discovery order and never
/// include `start`.
fn bfs_directed(&self, start: NodeId, config: &TraversalConfig, dir: Direction) -> Vec<NodeId> {
    use petgraph::Direction as PD;

    let mut visited: HashSet<NodeId> = HashSet::new();
    let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
    let mut result: Vec<NodeId> = Vec::new();
    visited.insert(start);
    queue.push_back((start, 0));

    while let Some((current, depth)) = queue.pop_front() {
        // Checked before collecting, so `result` never grows past the cap.
        if let Some(max_n) = config.max_nodes {
            if result.len() >= max_n {
                break;
            }
        }
        if current != start {
            if let Some(node) = self.graph.node_weight(current) {
                if config.node_should_collect(node) {
                    result.push(current);
                }
            }
        }
        if let Some(max_d) = config.max_depth {
            if depth >= max_d {
                continue;
            }
        }

        // The buffer is cleared before every use and guards no invariant, so a poisoned lock
        // (a panic while another holder had it) is recovered instead of panicking here too.
        let mut scratch = self
            .bfs_scratch
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        scratch.clear();
        match dir {
            Direction::Forward => scratch.extend(
                self.graph
                    .edges(current)
                    .filter(|e| config.edge_allowed(e.weight()))
                    .map(|e| e.target()),
            ),
            Direction::Backward => scratch.extend(
                self.graph
                    .edges_directed(current, PD::Incoming)
                    .filter(|e| config.edge_allowed(e.weight()))
                    .map(|e| e.source()),
            ),
        }
        for &neighbor in scratch.iter() {
            if visited.insert(neighbor) {
                queue.push_back((neighbor, depth + 1));
            }
        }
    }
    result
}
/// Encodes the graph, its indexes and its embeddings with bincode.
///
/// # Errors
/// Returns `"Serialize failed: ..."` if bincode rejects the data.
pub fn serialize(&self) -> Result<Vec<u8>, String> {
    let snapshot = SerializablePDG::from_pdg(self);
    bincode::serialize(&snapshot).map_err(|e| format!("Serialize failed: {}", e))
}

/// Decodes a graph produced by `serialize`.
///
/// NOTE(review): bincode 1.x's default deserialization has no size limit. If `data` can come
/// from an untrusted source, a crafted stream may request very large allocations; verify
/// that callers only pass trusted caches.
///
/// # Errors
/// Returns `"Deserialize failed: ..."` for malformed input, or the error from
/// `SerializablePDG::to_pdg` when an edge refers to a missing node.
pub fn deserialize(data: &[u8]) -> Result<Self, String> {
    let snapshot: SerializablePDG =
        bincode::deserialize(data).map_err(|e| format!("Deserialize failed: {}", e))?;
    snapshot.to_pdg()
}
/// Forward impact using the `for_impact_analysis` preset.
#[deprecated(
    since = "2.0.0",
    note = "Use forward_impact with TraversalConfig instead"
)]
pub fn get_forward_impact(&self, node_id: NodeId) -> Vec<NodeId> {
    let config = TraversalConfig::for_impact_analysis();
    self.forward_impact(node_id, &config)
}

/// Backward impact using the `for_impact_analysis` preset.
#[deprecated(
    since = "2.0.0",
    note = "Use backward_impact with TraversalConfig instead"
)]
pub fn get_backward_impact(&self, node_id: NodeId) -> Vec<NodeId> {
    let config = TraversalConfig::for_impact_analysis();
    self.backward_impact(node_id, &config)
}

/// Forward impact using the `for_impact_analysis` preset limited to `max_depth` hops.
#[deprecated(
    since = "2.0.0",
    note = "Use forward_impact with TraversalConfig instead"
)]
pub fn get_forward_impact_bounded(&self, start: NodeId, max_depth: usize) -> Vec<NodeId> {
    // Same node cap, edge types and filters as the impact preset; only the depth differs.
    let config = TraversalConfig {
        max_depth: Some(max_depth),
        ..TraversalConfig::for_impact_analysis()
    };
    self.forward_impact(start, &config)
}

/// Backward impact using the `for_impact_analysis` preset limited to `max_depth` hops.
#[deprecated(
    since = "2.0.0",
    note = "Use backward_impact with TraversalConfig instead"
)]
pub fn get_backward_impact_bounded(&self, start: NodeId, max_depth: usize) -> Vec<NodeId> {
    let config = TraversalConfig {
        max_depth: Some(max_depth),
        ..TraversalConfig::for_impact_analysis()
    };
    self.backward_impact(start, &config)
}
pub fn add_call_graph_edges(&mut self, calls: Vec<(NodeId, NodeId)>) {
self.add_call_edges(calls);
}
}
impl Default for ProgramDependenceGraph {
fn default() -> Self {
Self::new()
}
}
/// Traversal orientation for `ProgramDependenceGraph::bfs_directed`. This is a module-local
/// type, distinct from `petgraph::Direction`.
enum Direction {
    /// Follow outgoing edges, moving from source to target.
    Forward,
    /// Follow incoming edges, moving from target to source.
    Backward,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with byte range (0, 10), complexity 2 and language "rust".
    fn make_node(id: &str, name: &str, file: &str, ntype: NodeType) -> Node {
        Node {
            id: id.to_string(),
            node_type: ntype,
            name: name.to_string(),
            file_path: Arc::from(file),
            byte_range: (0, 10),
            complexity: 2,
            language: "rust".to_string(),
        }
    }

    /// Counts the edges of `edge_type` currently in `pdg`.
    fn count_edges(pdg: &ProgramDependenceGraph, edge_type: EdgeType) -> usize {
        pdg.edge_indices()
            .filter_map(|e| pdg.get_edge(e))
            .filter(|e| e.edge_type == edge_type)
            .count()
    }

    #[test]
    fn traversal_respects_max_nodes() {
        let mut pdg = ProgramDependenceGraph::new();
        let n: Vec<NodeId> = (0..10)
            .map(|i| {
                pdg.add_node(make_node(
                    &format!("n{i}"),
                    &format!("n{i}"),
                    "f.rs",
                    NodeType::Function,
                ))
            })
            .collect();
        // Linear call chain n0 -> n1 -> ... -> n9.
        pdg.add_call_edges(n.windows(2).map(|w| (w[0], w[1])).collect());
        let config = TraversalConfig {
            max_depth: None,
            max_nodes: Some(3),
            ..TraversalConfig::for_impact_analysis()
        };
        let result = pdg.forward_impact(n[0], &config);
        // BFS along the chain reports the three nearest callees, then stops.
        assert_eq!(result, vec![n[1], n[2], n[3]], "Should respect max_nodes cap");
    }

    #[test]
    fn traversal_filters_containment_edges() {
        let mut pdg = ProgramDependenceGraph::new();
        let cls = pdg.add_node(make_node("f:MyClass", "MyClass", "f.rs", NodeType::Class));
        let method = pdg.add_node(make_node("f:MyClass::foo", "foo", "f.rs", NodeType::Method));
        let callee = pdg.add_node(make_node("f:bar", "bar", "f.rs", NodeType::Function));
        pdg.add_containment_edges(vec![(cls, method)]);
        pdg.add_call_edges(vec![(method, callee)]);
        let config = TraversalConfig::for_semantic_analysis();
        let result = pdg.forward_impact(cls, &config);
        // The class's only outgoing edge is containment, which the preset does not follow,
        // so neither the method nor its callee may be reached.
        assert!(
            result.is_empty(),
            "Containment edges should be filtered from semantic traversal, got {result:?}"
        );
    }

    #[test]
    fn find_by_name_in_file_no_scan_needed() {
        let mut pdg = ProgramDependenceGraph::new();
        let ids: Vec<NodeId> = (0..1000)
            .map(|i| {
                pdg.add_node(make_node(
                    &format!("f:func{i}"),
                    &format!("func{i}"),
                    "f.rs",
                    NodeType::Function,
                ))
            })
            .collect();
        // "FUNC42" has no exact match, so it must resolve through the lowercase-name index.
        assert_eq!(pdg.find_by_name_in_file("FUNC42", None), Some(ids[42]));
    }

    #[test]
    fn name_file_index_provides_o1_lookup() {
        let mut pdg = ProgramDependenceGraph::new();
        let a = pdg.add_node(make_node("a.rs:foo", "foo", "a.rs", NodeType::Function));
        let b = pdg.add_node(make_node("b.rs:foo", "foo", "b.rs", NodeType::Function));
        let c = pdg.add_node(make_node("c.rs:foo", "foo", "c.rs", NodeType::Function));
        assert_eq!(pdg.find_by_name_in_file("foo", Some("a.rs")), Some(a));
        assert_eq!(pdg.find_by_name_in_file("foo", Some("b.rs")), Some(b));
        assert_eq!(pdg.find_by_name_in_file("foo", Some("c.rs")), Some(c));
        // Unknown file: fall back to the first node with that name.
        assert_eq!(pdg.find_by_name_in_file("foo", Some("z.rs")), Some(a));
    }

    #[test]
    fn name_file_index_maintained_on_remove() {
        let mut pdg = ProgramDependenceGraph::new();
        let a = pdg.add_node(make_node("a.rs:foo", "foo", "a.rs", NodeType::Function));
        let b = pdg.add_node(make_node("b.rs:bar", "bar", "b.rs", NodeType::Function));
        assert_eq!(pdg.find_by_name_in_file("foo", Some("a.rs")), Some(a));
        assert_eq!(pdg.find_by_name_in_file("bar", Some("b.rs")), Some(b));
        pdg.remove_node(a);
        assert_eq!(pdg.find_by_name_in_file("foo", Some("a.rs")), None);
        assert_eq!(pdg.find_by_name_in_file("bar", Some("b.rs")), Some(b));
    }

    #[test]
    fn containment_edge_type_is_separate_from_call() {
        let mut pdg = ProgramDependenceGraph::new();
        let cls = pdg.add_node(make_node("f:C", "C", "f.rs", NodeType::Class));
        let m = pdg.add_node(make_node("f:C::m", "m", "f.rs", NodeType::Method));
        pdg.add_containment_edges(vec![(cls, m)]);
        assert_eq!(count_edges(&pdg, EdgeType::Containment), 1);
        assert_eq!(count_edges(&pdg, EdgeType::Call), 0);
    }

    #[test]
    fn confidence_filtering_works() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:a", "a", "f.rs", NodeType::Function));
        let n2 = pdg.add_node(make_node("f:b", "b", "f.rs", NodeType::Function));
        let n3 = pdg.add_node(make_node("f:c", "c", "f.rs", NodeType::Function));
        pdg.add_data_flow_edges(vec![
            (n1, n2, "T".to_string(), 0.3),
            (n1, n3, "U".to_string(), 0.9),
        ]);
        let config = TraversalConfig {
            max_depth: Some(5),
            max_nodes: Some(100),
            allowed_edge_types: Some(&[EdgeType::DataDependency]),
            excluded_node_types: None,
            min_complexity: None,
            min_edge_confidence: 0.5,
        };
        let result = pdg.forward_impact(n1, &config);
        assert!(!result.contains(&n2), "Low confidence edge should be filtered");
        assert!(result.contains(&n3), "High confidence edge should be followed");
    }

    #[test]
    fn backward_traversal_works() {
        let mut pdg = ProgramDependenceGraph::new();
        let n: Vec<NodeId> = (0..5)
            .map(|i| {
                pdg.add_node(make_node(
                    &format!("f:n{i}"),
                    &format!("n{i}"),
                    "f.rs",
                    NodeType::Function,
                ))
            })
            .collect();
        pdg.add_call_edges(n.windows(2).map(|w| (w[0], w[1])).collect());
        let config = TraversalConfig::for_impact_analysis();
        let backward = pdg.backward_impact(n[4], &config);
        for (i, id) in n[..4].iter().enumerate() {
            assert!(backward.contains(id), "n{i} should be reached backward");
        }
        assert!(!backward.contains(&n[4]), "Start node must not be reported");
    }

    #[test]
    fn bidirectional_traversal_works() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:a", "a", "f.rs", NodeType::Function));
        let n2 = pdg.add_node(make_node("f:b", "b", "f.rs", NodeType::Function));
        let n3 = pdg.add_node(make_node("f:c", "c", "f.rs", NodeType::Function));
        pdg.add_call_edges(vec![(n1, n2), (n2, n3)]);
        let config = TraversalConfig::for_impact_analysis();
        let bidirectional = pdg.bidirectional_impact(n2, &config);
        assert!(bidirectional.contains(&n1), "Should reach backward");
        assert!(bidirectional.contains(&n3), "Should reach forward");
        assert!(!bidirectional.contains(&n2), "Should not include start node");
    }

    #[test]
    fn embedding_store_field_initialized_on_new_pdg() {
        let pdg = ProgramDependenceGraph::new();
        assert!(pdg.embedding_store.is_empty());
        assert_eq!(pdg.embedding_count(), 0);
    }

    #[test]
    fn set_and_get_embedding_roundtrip() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:foo", "foo", "f.rs", NodeType::Function));
        let emb = vec![0.1, 0.2, 0.3, 0.4];
        pdg.set_embedding("f:foo", emb.clone());
        assert_eq!(pdg.get_embedding("f:foo"), Some(&emb));
        assert_eq!(pdg.embedding_count(), 1);
        assert!(pdg.get_node(n1).is_some());
    }

    #[test]
    fn remove_node_cleans_up_embedding() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:foo", "foo", "f.rs", NodeType::Function));
        pdg.set_embedding("f:foo", vec![0.5, 0.6]);
        assert_eq!(pdg.embedding_count(), 1);
        assert!(pdg.remove_node(n1).is_some());
        assert_eq!(pdg.embedding_count(), 0);
        assert!(pdg.get_embedding("f:foo").is_none());
    }

    #[test]
    fn remove_file_cleans_up_all_embeddings() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:a", "a", "src/lib.rs", NodeType::Function));
        let n2 = pdg.add_node(make_node("f:b", "b", "src/lib.rs", NodeType::Function));
        let n3 = pdg.add_node(make_node("f:c", "c", "src/other.rs", NodeType::Function));
        pdg.set_embedding("f:a", vec![1.0]);
        pdg.set_embedding("f:b", vec![2.0]);
        pdg.set_embedding("f:c", vec![3.0]);
        assert_eq!(pdg.embedding_count(), 3);
        pdg.remove_file("src/lib.rs");
        assert!(pdg.get_embedding("f:a").is_none(), "a's embedding should be removed");
        assert!(pdg.get_embedding("f:b").is_none(), "b's embedding should be removed");
        assert_eq!(
            pdg.get_embedding("f:c"),
            Some(&vec![3.0]),
            "c's embedding should remain"
        );
        assert_eq!(pdg.embedding_count(), 1);
        assert!(pdg.get_node(n1).is_none());
        assert!(pdg.get_node(n2).is_none());
        assert!(pdg.get_node(n3).is_some());
    }

    #[test]
    fn embedding_store_overwrite() {
        let mut pdg = ProgramDependenceGraph::new();
        pdg.add_node(make_node("f:foo", "foo", "f.rs", NodeType::Function));
        pdg.set_embedding("f:foo", vec![1.0, 2.0]);
        assert_eq!(pdg.get_embedding("f:foo"), Some(&vec![1.0, 2.0]));
        pdg.set_embedding("f:foo", vec![3.0, 4.0]);
        assert_eq!(pdg.get_embedding("f:foo"), Some(&vec![3.0, 4.0]));
        assert_eq!(
            pdg.embedding_count(),
            1,
            "Should still have 1 embedding after overwrite"
        );
    }

    #[test]
    fn serialization_preserves_embeddings() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:foo", "foo", "f.rs", NodeType::Function));
        let n2 = pdg.add_node(make_node("f:bar", "bar", "f.rs", NodeType::Function));
        pdg.add_call_edges(vec![(n1, n2)]);
        pdg.set_embedding("f:foo", vec![0.1, 0.2, 0.3]);
        pdg.set_embedding("f:bar", vec![0.4, 0.5, 0.6]);
        let bytes = pdg.serialize().expect("Serialization should succeed");
        let restored =
            ProgramDependenceGraph::deserialize(&bytes).expect("Deserialization should succeed");
        assert_eq!(restored.get_embedding("f:foo"), Some(&vec![0.1, 0.2, 0.3]));
        assert_eq!(restored.get_embedding("f:bar"), Some(&vec![0.4, 0.5, 0.6]));
        assert_eq!(restored.embedding_count(), 2);
        assert_eq!(restored.node_count(), 2);
        assert_eq!(restored.edge_count(), 1);
    }

    /// Covers a snapshot whose `embeddings` map is empty. It does NOT cover a stream written
    /// before the field existed: bincode encodes structs positionally, so that case cannot be
    /// reproduced by serializing the current struct.
    #[test]
    fn deserialization_backward_compat_no_embeddings() {
        let mut pdg = ProgramDependenceGraph::new();
        let n1 = pdg.add_node(make_node("f:foo", "foo", "f.rs", NodeType::Function));
        pdg.add_call_edges(vec![(n1, n1)]);
        let mut old_format = SerializablePDG::from_pdg(&pdg);
        old_format.embeddings.clear();
        let bytes = bincode::serialize(&old_format).expect("Serialize old format");
        let restored = ProgramDependenceGraph::deserialize(&bytes)
            .expect("Should deserialize old format without error");
        assert_eq!(restored.embedding_count(), 0);
        assert_eq!(restored.node_count(), 1);
        assert_eq!(restored.edge_count(), 1);
    }

    #[test]
    fn bulk_import_edges_helper() {
        let mut pdg = ProgramDependenceGraph::new();
        let a = pdg.add_node(make_node("mod:a", "a", "a.rs", NodeType::Module));
        let b = pdg.add_node(make_node("mod:b", "b", "b.rs", NodeType::Module));
        let c = pdg.add_node(make_node("mod:c", "c", "c.rs", NodeType::Module));
        pdg.add_import_edges(vec![(a, b), (a, c)]);
        assert_eq!(
            count_edges(&pdg, EdgeType::Import),
            2,
            "Should have 2 import edges"
        );
    }

    #[test]
    fn bulk_inheritance_edges_with_confidence() {
        let mut pdg = ProgramDependenceGraph::new();
        let child = pdg.add_node(make_node("f:Child", "Child", "f.rs", NodeType::Class));
        let parent = pdg.add_node(make_node("f:Parent", "Parent", "f.rs", NodeType::Class));
        pdg.add_inheritance_edges(vec![(child, parent, 0.85)]);
        let edges: Vec<_> = pdg
            .edge_indices()
            .filter_map(|e| {
                let edge = pdg.get_edge(e)?;
                (edge.edge_type == EdgeType::Inheritance)
                    .then(|| (pdg.edge_endpoints(e).unwrap(), edge.clone()))
            })
            .collect();
        assert_eq!(edges.len(), 1);
        let ((src, tgt), edge) = &edges[0];
        assert_eq!(*src, child);
        assert_eq!(*tgt, parent);
        assert_eq!(edge.metadata.confidence, Some(0.85));
    }
}