use crate::graph::pdg::{EdgeId, NodeId, NodeType, ProgramDependenceGraph};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// A program dependence graph that starts from one root project and can have
/// the graphs of other (external) projects merged into it.
///
/// Besides the merged graph itself, the struct tracks which project each
/// merged node came from and which nodes belong to each external project.
pub struct CrossProjectPDG {
/// Identifier of the root project; nodes whose recorded origin equals this
/// id are not considered external.
pub root_project_id: String,
/// The combined graph: initially the root project's graph, later extended
/// with nodes and edges copied from external graphs.
pub merged_pdg: ProgramDependenceGraph,
/// Maps a node id in `merged_pdg` to the id of the project it was recorded
/// for. Nodes without an entry are treated as local by `is_external_node`.
pub node_origins: HashMap<NodeId, String>,
/// Node ids recorded per project id, populated by `merge_external_pdg` and
/// `add_external_ref`.
pub external_refs: HashMap<String, Vec<NodeId>>,
/// Upper bound on the number of project entries in `external_refs`; once
/// reached, `merge_external_pdg` returns `MergeError::MaxDepthExceeded`.
pub max_depth: usize,
}
impl CrossProjectPDG {
    /// Limit applied by [`CrossProjectPDG::new`] when the caller does not pick one.
    const DEFAULT_MAX_DEPTH: usize = 3;

    /// Creates a cross-project graph rooted at `root_pdg` with the default
    /// limit of three external projects.
    pub fn new(root_project_id: String, root_pdg: ProgramDependenceGraph) -> Self {
        Self::with_max_depth(root_project_id, root_pdg, Self::DEFAULT_MAX_DEPTH)
    }

    /// Creates a cross-project graph rooted at `root_pdg` with an explicit
    /// `max_depth` limit.
    ///
    /// No origins are recorded for the root graph's nodes; nodes without a
    /// recorded origin are reported as local.
    pub fn with_max_depth(
        root_project_id: String,
        root_pdg: ProgramDependenceGraph,
        max_depth: usize,
    ) -> Self {
        Self {
            root_project_id,
            merged_pdg: root_pdg,
            node_origins: HashMap::new(),
            external_refs: HashMap::new(),
            max_depth,
        }
    }

    /// Copies every node and edge of `external_pdg` into the merged graph and
    /// records the copied nodes as belonging to `project_id`.
    ///
    /// Node ids are re-assigned by the merged graph; edges are re-attached to
    /// the new ids. An edge whose endpoint could not be copied is skipped,
    /// because an id from the external graph has no meaning in the merged one.
    /// Nodes already recorded for `project_id` (from an earlier merge or from
    /// [`CrossProjectPDG::add_external_ref`]) are kept and the new ones appended.
    ///
    /// # Errors
    ///
    /// Returns [`MergeError::MaxDepthExceeded`] when the number of project
    /// entries already present in `external_refs` has reached `max_depth`.
    pub fn merge_external_pdg(
        &mut self,
        project_id: &str,
        external_pdg: &ProgramDependenceGraph,
    ) -> Result<(), MergeError> {
        // "Depth" is measured as the number of projects recorded so far.
        if self.external_refs.len() >= self.max_depth {
            return Err(MergeError::MaxDepthExceeded(self.max_depth));
        }

        let expected_nodes = external_pdg.node_count();
        let mut node_id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(expected_nodes);
        let mut added_nodes = Vec::with_capacity(expected_nodes);

        for old_node_id in external_pdg.node_indices() {
            if let Some(node) = external_pdg.get_node(old_node_id) {
                let new_node_id = self.merged_pdg.add_node(node.clone());
                node_id_map.insert(old_node_id, new_node_id);
                self.node_origins
                    .insert(new_node_id, project_id.to_string());
                added_nodes.push(new_node_id);
            }
        }

        for edge_id in external_pdg.edge_indices() {
            if let (Some(edge), Some((old_source, old_target))) = (
                external_pdg.get_edge(edge_id),
                external_pdg.edge_endpoints(edge_id),
            ) {
                // Only connect endpoints that were actually copied. Falling back
                // to the external id would attach the edge to an unrelated node
                // of the merged graph (or to an id that does not exist in it).
                if let (Some(&new_source), Some(&new_target)) = (
                    node_id_map.get(&old_source),
                    node_id_map.get(&old_target),
                ) {
                    self.merged_pdg
                        .add_edge(new_source, new_target, edge.clone());
                }
            }
        }

        // Extend rather than insert so previously recorded nodes for this
        // project are not discarded.
        self.external_refs
            .entry(project_id.to_string())
            .or_default()
            .extend(added_nodes);
        Ok(())
    }

    /// Records `node_id` as belonging to `project_id` without touching the
    /// graph itself; the id is not checked against the merged graph.
    pub fn add_external_ref(&mut self, node_id: NodeId, project_id: &str) {
        self.external_refs
            .entry(project_id.to_string())
            .or_default()
            .push(node_id);
        self.node_origins.insert(node_id, project_id.to_string());
    }

    /// Returns `true` when `node_id` has a recorded origin that differs from
    /// the root project; nodes with no recorded origin count as local.
    pub fn is_external_node(&self, node_id: &NodeId) -> bool {
        self.node_origins
            .get(node_id)
            .map_or(false, |origin| *origin != self.root_project_id)
    }

    /// Returns the project id recorded for `node_id`, if any.
    pub fn get_node_origin(&self, node_id: &NodeId) -> Option<&String> {
        self.node_origins.get(node_id)
    }

    /// Returns the ids of all projects present in `external_refs`, excluding
    /// the root project. The order of the result is unspecified.
    pub fn get_referenced_projects(&self) -> Vec<String> {
        self.external_refs
            .keys()
            .filter(|project_id| **project_id != self.root_project_id)
            .cloned()
            .collect()
    }

    /// Returns the ids of all nodes in the merged graph that are not external.
    pub fn local_nodes(&self) -> Vec<NodeId> {
        self.merged_pdg
            .node_indices()
            .filter(|id| !self.is_external_node(id))
            .collect()
    }

    /// Returns the ids of all nodes in the merged graph that are external.
    pub fn external_nodes(&self) -> Vec<NodeId> {
        self.merged_pdg
            .node_indices()
            .filter(|id| self.is_external_node(id))
            .collect()
    }

    /// Number of nodes in the merged graph.
    pub fn node_count(&self) -> usize {
        self.merged_pdg.node_count()
    }

    /// Number of edges in the merged graph.
    pub fn edge_count(&self) -> usize {
        self.merged_pdg.edge_count()
    }

    /// Shared access to the merged graph.
    pub fn pdg(&self) -> &ProgramDependenceGraph {
        &self.merged_pdg
    }

    /// Mutable access to the merged graph. Changes made through this
    /// reference are not reflected in `node_origins` or `external_refs`.
    pub fn pdg_mut(&mut self) -> &mut ProgramDependenceGraph {
        &mut self.merged_pdg
    }
}
/// Serializable description of one node listed under an external project.
///
/// Produced by `CrossProjectPDG::to_serializable` from the node stored in the
/// merged graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalNodeRef {
/// Index of the node in the merged graph, narrowed to `u32`.
pub node_id: u32,
/// Id of the project under which the node is listed.
pub project_id: String,
/// Copy of the node's `name` field.
pub symbol_name: String,
/// Copy of the node's `node_type` field.
pub node_type: NodeType,
}
/// Serializable snapshot of the bookkeeping held by `CrossProjectPDG`.
///
/// The graph itself and `max_depth` are not part of the snapshot; the graph is
/// supplied separately to `CrossProjectPDG::from_serializable_with_pdg`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableCrossProjectPDG {
/// Identifier of the root project.
pub root_project_id: String,
/// `(node index, project id)` pairs taken from `CrossProjectPDG::node_origins`.
pub node_origins: Vec<(u32, String)>,
/// Per-project node descriptions, keyed by project id.
pub external_refs: HashMap<String, Vec<ExternalNodeRef>>,
}
impl CrossProjectPDG {
pub fn to_serializable(&self) -> SerializableCrossProjectPDG {
let node_origins: Vec<(u32, String)> = self
.node_origins
.iter()
.map(|(id, project)| (id.index() as u32, project.clone()))
.collect();
let external_refs_serializable: HashMap<String, Vec<ExternalNodeRef>> = self
.external_refs
.iter()
.map(|(project, nodes)| {
let refs: Vec<ExternalNodeRef> = nodes
.iter()
.filter_map(|node_id| {
self.merged_pdg
.get_node(*node_id)
.map(|node| ExternalNodeRef {
node_id: node_id.index() as u32,
project_id: project.clone(),
symbol_name: node.name.clone(),
node_type: node.node_type.clone(),
})
})
.collect();
(project.clone(), refs)
})
.collect();
SerializableCrossProjectPDG {
root_project_id: self.root_project_id.clone(),
node_origins,
external_refs: external_refs_serializable,
}
}
pub fn from_serializable_with_pdg(
serializable: SerializableCrossProjectPDG,
pdg: ProgramDependenceGraph,
) -> Self {
let mut node_origins = HashMap::new();
for (node_index, project_id) in serializable.node_origins {
node_origins.insert(NodeId::new(node_index as usize), project_id);
}
Self {
root_project_id: serializable.root_project_id,
merged_pdg: pdg,
node_origins,
external_refs: HashMap::new(),
max_depth: 3,
}
}
}
/// Errors reported when merging an external graph into a `CrossProjectPDG`.
#[derive(Debug, Error)]
pub enum MergeError {
/// A node id is present in both the local and the external graph.
/// NOTE(review): not constructed anywhere in the visible code.
#[error("Node ID conflict: {0:?} exists in both local and external")]
NodeConflict(NodeId),
/// An edge id is present in both the local and the external graph.
/// NOTE(review): not constructed anywhere in the visible code.
#[error("Edge ID conflict: {0:?} exists in both local and external")]
EdgeConflict(EdgeId),
/// The configured limit on merged projects was already reached; the payload
/// is that limit (`max_depth`).
#[error("Max depth exceeded: {0}")]
MaxDepthExceeded(usize),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::pdg::{Edge, EdgeMetadata, EdgeType, Node};

    /// Builds a function node whose id, name and file path all derive from `name`.
    fn create_test_node(name: &str) -> Node {
        let label = name.to_string();
        Node {
            id: label.clone(),
            node_type: NodeType::Function,
            name: label,
            file_path: std::sync::Arc::from(format!("src/{}.rs", name)),
            byte_range: (0, 100),
            complexity: 5,
            language: "rust".to_string(),
        }
    }

    /// Builds an edge-free graph holding one test node per entry in `nodes`.
    fn create_test_pdg(nodes: Vec<&str>) -> ProgramDependenceGraph {
        let mut pdg = ProgramDependenceGraph::new();
        nodes.into_iter().map(create_test_node).for_each(|node| {
            pdg.add_node(node);
        });
        pdg
    }

    /// Builds a cross-project graph rooted at "root_project" with the given nodes.
    fn root_cross_pdg(nodes: Vec<&str>) -> CrossProjectPDG {
        CrossProjectPDG::new("root_project".to_string(), create_test_pdg(nodes))
    }

    #[test]
    fn test_cross_project_pdg_creation() {
        let cross_pdg = root_cross_pdg(vec!["func_a", "func_b"]);

        assert_eq!(cross_pdg.root_project_id, "root_project");
        assert_eq!(cross_pdg.node_count(), 2);
        assert_eq!(cross_pdg.local_nodes().len(), 2);
        assert!(cross_pdg.external_nodes().is_empty());
    }

    #[test]
    fn test_merge_external_pdg() {
        let mut cross_pdg = root_cross_pdg(vec!["root_func"]);
        let ext_pdg = create_test_pdg(vec!["ext_func_a", "ext_func_b"]);

        cross_pdg
            .merge_external_pdg("external_project", &ext_pdg)
            .unwrap();

        assert_eq!(cross_pdg.node_count(), 3);
        assert_eq!(cross_pdg.local_nodes().len(), 1);

        let externals = cross_pdg.external_nodes();
        assert_eq!(externals.len(), 2);

        // Every merged node must be flagged external and carry the project id.
        let expected_origin = "external_project".to_string();
        for node_id in &externals {
            assert!(cross_pdg.is_external_node(node_id));
            assert_eq!(cross_pdg.get_node_origin(node_id), Some(&expected_origin));
        }
    }

    #[test]
    fn test_max_depth_exceeded() {
        let mut cross_pdg = CrossProjectPDG::with_max_depth(
            "root_project".to_string(),
            create_test_pdg(vec!["root_func"]),
            1,
        );
        let ext_pdg = create_test_pdg(vec!["ext_func"]);

        // The first merge fits within the limit of one project.
        cross_pdg
            .merge_external_pdg("ext_project_1", &ext_pdg)
            .unwrap();

        // The second merge must be rejected.
        let second = cross_pdg.merge_external_pdg("ext_project_2", &ext_pdg);
        assert!(matches!(second, Err(MergeError::MaxDepthExceeded(1))));
    }

    #[test]
    fn test_add_external_ref() {
        let mut cross_pdg = root_cross_pdg(vec!["root_func"]);
        let fake_node_id = NodeId::new(100);

        cross_pdg.add_external_ref(fake_node_id, "external_project");

        assert!(cross_pdg.is_external_node(&fake_node_id));
        assert_eq!(
            cross_pdg.get_node_origin(&fake_node_id).map(String::as_str),
            Some("external_project")
        );
    }

    #[test]
    fn test_get_referenced_projects() {
        let mut cross_pdg = root_cross_pdg(vec!["root_func"]);

        let externals = [
            ("ext_project_1", "ext_func_a"),
            ("ext_project_2", "ext_func_b"),
        ];
        for &(project, func) in externals.iter() {
            let ext_pdg = create_test_pdg(vec![func]);
            cross_pdg.merge_external_pdg(project, &ext_pdg).unwrap();
        }

        let referenced = cross_pdg.get_referenced_projects();
        assert_eq!(referenced.len(), 2);
        for &(project, _) in externals.iter() {
            assert!(referenced.iter().any(|found| found.as_str() == project));
        }
    }

    #[test]
    fn test_merge_with_edges() {
        // Root graph: two nodes joined by one call edge.
        let mut root_pdg = ProgramDependenceGraph::new();
        let id_a = root_pdg.add_node(create_test_node("root_func_a"));
        let id_b = root_pdg.add_node(create_test_node("root_func_b"));
        root_pdg.add_edge(
            id_a,
            id_b,
            Edge {
                edge_type: EdgeType::Call,
                metadata: EdgeMetadata {
                    call_count: Some(1),
                    variable_name: None,
                    confidence: None,
                },
            },
        );
        let mut cross_pdg = CrossProjectPDG::new("root_project".to_string(), root_pdg);

        // External graph: two nodes joined by one data-dependency edge.
        let mut ext_pdg = ProgramDependenceGraph::new();
        let id_x = ext_pdg.add_node(create_test_node("ext_func_x"));
        let id_y = ext_pdg.add_node(create_test_node("ext_func_y"));
        ext_pdg.add_edge(
            id_x,
            id_y,
            Edge {
                edge_type: EdgeType::DataDependency,
                metadata: EdgeMetadata {
                    call_count: None,
                    confidence: None,
                    variable_name: Some("data".to_string()),
                },
            },
        );

        cross_pdg
            .merge_external_pdg("external_project", &ext_pdg)
            .unwrap();

        assert_eq!(cross_pdg.node_count(), 4);
        assert_eq!(cross_pdg.edge_count(), 2);
    }

    #[test]
    fn test_serialization() {
        let mut cross_pdg = root_cross_pdg(vec!["root_func"]);
        let ext_pdg = create_test_pdg(vec!["ext_func"]);
        cross_pdg
            .merge_external_pdg("external_project", &ext_pdg)
            .unwrap();

        let serializable = cross_pdg.to_serializable();
        assert_eq!(serializable.root_project_id, "root_project");

        // Only the merged node has a recorded origin.
        let origins = &serializable.node_origins;
        assert_eq!(origins.len(), 1);
        assert_eq!(origins[0].1, "external_project");
    }
}