#[allow(unused_imports)]
use crate::core::{CallEdge, EdgeConfidence, NodeId};
use petgraph::graph::NodeIndex;
/// Directed graph stored in compressed sparse row (CSR) form, with a
/// [`CallEdge`] payload on every edge.
///
/// The outgoing edges of node `i` occupy the index range
/// `row_offsets[i]..row_offsets[i + 1]` in both `column_indices` and
/// `edge_data`. [`CsrGraph::from_edges`] establishes these invariants. The
/// fields are public, so mutating them directly can break the invariants that
/// the lookup methods rely on.
#[derive(Debug, Clone)]
pub struct CsrGraph {
/// Prefix-summed out-degrees. Built with length `node_count + 1` and
/// `row_offsets[0] == 0`.
pub row_offsets: Vec<usize>,
/// Target node of each edge, grouped by source node. Within one source's row
/// the targets are sorted ascending, which `has_edge`/`get_edge` binary-search on.
pub column_indices: Vec<NodeIndex>,
/// Edge payloads, parallel to `column_indices` (same position = same edge).
pub edge_data: Vec<CallEdge>,
/// Number of nodes; valid node indices are `0..node_count`.
pub node_count: usize,
}
impl CsrGraph {
    /// Builds a CSR graph from an unordered list of `(from, to, edge)` triples.
    ///
    /// Edges are sorted by `(from, to)`, so each node's outgoing edges form one
    /// contiguous row ordered by target index. [`CsrGraph::has_edge`] and
    /// [`CsrGraph::get_edge`] rely on that ordering for binary search. The sort
    /// is stable, so parallel edges between the same pair of nodes keep their
    /// input order.
    ///
    /// # Panics
    ///
    /// Panics if any edge's source index is `>= node_count`. Target indices are
    /// not validated.
    pub fn from_edges(
        node_count: usize,
        mut edges: Vec<(NodeIndex, NodeIndex, CallEdge)>,
    ) -> Self {
        // Count the out-degree of node `i` into slot `i + 1`. The prefix sum
        // below then turns slot `i + 1` into the exclusive end of row `i`.
        let mut row_offsets = vec![0usize; node_count + 1];
        for (from_idx, _, _) in &edges {
            let from = from_idx.index();
            assert!(
                from < node_count,
                "edge source index {} out of range for graph with {} nodes",
                from,
                node_count
            );
            row_offsets[from + 1] += 1;
        }
        // In-place inclusive prefix sum. Slot 0 stays 0 because no edge counts into it.
        let mut running = 0;
        for offset in row_offsets.iter_mut() {
            running += *offset;
            *offset = running;
        }

        // Stable sort: duplicate (from, to) pairs keep their insertion order.
        edges.sort_by_key(|(from, to, _)| (from.index(), to.index()));
        let (column_indices, edge_data): (Vec<NodeIndex>, Vec<CallEdge>) =
            edges.into_iter().map(|(_, to, edge)| (to, edge)).unzip();

        CsrGraph {
            row_offsets,
            column_indices,
            edge_data,
            node_count,
        }
    }

    /// Returns the range in `column_indices`/`edge_data` holding the outgoing
    /// edges of `node_idx`, or `None` if `node_idx >= node_count`.
    fn row_range(&self, node_idx: NodeIndex) -> Option<std::ops::Range<usize>> {
        let node_id = node_idx.index();
        if node_id >= self.node_count {
            return None;
        }
        Some(self.row_offsets[node_id]..self.row_offsets[node_id + 1])
    }

    /// Returns the outgoing edges of `node_idx` as `(target, edge)` pairs.
    ///
    /// The pairs are in storage order, i.e. sorted by target when the graph was
    /// built by [`CsrGraph::from_edges`]. The result is empty if `node_idx` is
    /// out of range or has no outgoing edges.
    pub fn get_callees(&self, node_idx: NodeIndex) -> Vec<(NodeIndex, &CallEdge)> {
        match self.row_range(node_idx) {
            Some(range) => self.column_indices[range.clone()]
                .iter()
                .copied()
                .zip(&self.edge_data[range])
                .collect(),
            None => Vec::new(),
        }
    }

    /// Returns `true` if there is at least one edge `from_idx -> to_idx`.
    ///
    /// Returns `false` when `from_idx` is out of range.
    pub fn has_edge(&self, from_idx: NodeIndex, to_idx: NodeIndex) -> bool {
        self.get_edge(from_idx, to_idx).is_some()
    }

    /// Looks up the payload of an edge `from_idx -> to_idx`.
    ///
    /// This uses binary search over `from_idx`'s row, so it requires that row to
    /// be sorted by target, as `from_edges` guarantees. If several parallel
    /// edges connect the same pair, any one of them may be returned. Returns
    /// `None` when `from_idx` is out of range or no such edge exists.
    pub fn get_edge(&self, from_idx: NodeIndex, to_idx: NodeIndex) -> Option<&CallEdge> {
        let range = self.row_range(from_idx)?;
        let targets = &self.column_indices[range.clone()];
        let payloads = &self.edge_data[range];
        targets
            .binary_search(&to_idx)
            .ok()
            .map(|pos| &payloads[pos])
    }

    /// Total number of stored edges, counting parallel edges individually.
    pub fn edge_count(&self) -> usize {
        self.column_indices.len()
    }

    /// Approximate size, in bytes, of the three backing arrays.
    ///
    /// This is based on element counts (`len`), not allocated capacity. It does
    /// not include the struct itself or any heap memory owned by individual
    /// `CallEdge` values.
    pub fn memory_bytes(&self) -> usize {
        let offsets_size = self.row_offsets.len() * std::mem::size_of::<usize>();
        let indices_size = self.column_indices.len() * std::mem::size_of::<NodeIndex>();
        let edge_size = self.edge_data.len() * std::mem::size_of::<CallEdge>();
        offsets_size + indices_size + edge_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `CallEdge` between two raw node ids.
    fn call_edge(from: u32, to: u32, confidence: EdgeConfidence) -> CallEdge {
        CallEdge::new(NodeId::from_u32(from), NodeId::from_u32(to), confidence)
    }

    /// Target indices of `node`'s outgoing edges, in the order `get_callees` yields them.
    fn targets(csr: &CsrGraph, node: usize) -> Vec<NodeIndex> {
        csr.get_callees(NodeIndex::new(node))
            .into_iter()
            .map(|(to, _)| to)
            .collect()
    }

    #[test]
    fn test_csr_empty_graph() {
        let csr = CsrGraph::from_edges(10, Vec::new());
        assert_eq!(csr.node_count, 10);
        assert_eq!(csr.edge_count(), 0);
        for node in 0..10 {
            assert!(targets(&csr, node).is_empty());
        }
    }

    #[test]
    fn test_csr_single_edge() {
        let from = NodeIndex::new(0);
        let to = NodeIndex::new(1);
        let edge = call_edge(100, 101, EdgeConfidence::Certain);
        let csr = CsrGraph::from_edges(5, vec![(from, to, edge)]);
        assert_eq!(csr.edge_count(), 1);
        assert!(csr.has_edge(from, to));
        assert!(!csr.has_edge(to, from));
    }

    #[test]
    fn test_csr_multiple_edges_same_source() {
        let from = NodeIndex::new(0);
        let to1 = NodeIndex::new(1);
        let to2 = NodeIndex::new(2);
        let edge1 = call_edge(0, 1, EdgeConfidence::Certain);
        let edge2 = call_edge(0, 2, EdgeConfidence::HighLikely);
        let csr = CsrGraph::from_edges(5, vec![(from, to1, edge1), (from, to2, edge2)]);
        assert_eq!(csr.edge_count(), 2);

        let callees = csr.get_callees(from);
        assert_eq!(callees.len(), 2);
        assert_eq!(callees[0].0, to1);
        assert_eq!(callees[0].1.confidence, EdgeConfidence::Certain);
        assert_eq!(callees[1].0, to2);
        assert_eq!(callees[1].1.confidence, EdgeConfidence::HighLikely);
    }

    #[test]
    fn test_csr_unsorted_input() {
        // Deliberately out of (from, to) order: lookups must still work because
        // from_edges sorts each row before binary search is used.
        let edges = vec![
            (NodeIndex::new(2), NodeIndex::new(1), call_edge(2, 1, EdgeConfidence::Certain)),
            (NodeIndex::new(0), NodeIndex::new(3), call_edge(0, 3, EdgeConfidence::HighLikely)),
            (NodeIndex::new(0), NodeIndex::new(1), call_edge(0, 1, EdgeConfidence::Certain)),
            (NodeIndex::new(1), NodeIndex::new(0), call_edge(1, 0, EdgeConfidence::Certain)),
        ];
        let csr = CsrGraph::from_edges(5, edges);
        assert_eq!(csr.edge_count(), 4);
        assert_eq!(targets(&csr, 0), vec![NodeIndex::new(1), NodeIndex::new(3)]);
        assert_eq!(targets(&csr, 1), vec![NodeIndex::new(0)]);
        assert_eq!(targets(&csr, 2), vec![NodeIndex::new(1)]);
        assert!(targets(&csr, 3).is_empty());
        assert!(targets(&csr, 4).is_empty());

        assert!(csr.has_edge(NodeIndex::new(0), NodeIndex::new(3)));
        assert!(csr.has_edge(NodeIndex::new(2), NodeIndex::new(1)));
        assert!(!csr.has_edge(NodeIndex::new(2), NodeIndex::new(0)));
        let edge = csr
            .get_edge(NodeIndex::new(0), NodeIndex::new(3))
            .expect("edge 0 -> 3 should exist");
        assert_eq!(edge.confidence, EdgeConfidence::HighLikely);
    }

    #[test]
    fn test_csr_out_of_range_queries() {
        let edges = vec![(
            NodeIndex::new(0),
            NodeIndex::new(1),
            call_edge(0, 1, EdgeConfidence::Certain),
        )];
        let csr = CsrGraph::from_edges(2, edges);
        let outside = NodeIndex::new(99);
        assert!(csr.get_callees(outside).is_empty());
        assert!(!csr.has_edge(outside, NodeIndex::new(0)));
        assert!(csr.get_edge(outside, NodeIndex::new(0)).is_none());
    }

    #[test]
    #[should_panic]
    fn test_csr_source_out_of_range_panics() {
        let edges = vec![(
            NodeIndex::new(5),
            NodeIndex::new(0),
            call_edge(5, 0, EdgeConfidence::Certain),
        )];
        let _ = CsrGraph::from_edges(2, edges);
    }

    #[test]
    fn test_csr_get_edge() {
        let from = NodeIndex::new(0);
        let to = NodeIndex::new(1);
        let edge = call_edge(0, 1, EdgeConfidence::Certain);
        let csr = CsrGraph::from_edges(5, vec![(from, to, edge)]);
        let retrieved = csr.get_edge(from, to).expect("edge 0 -> 1 should exist");
        assert_eq!(retrieved.confidence, EdgeConfidence::Certain);
        assert!(csr.get_edge(to, from).is_none());
    }

    #[test]
    fn test_csr_memory_efficiency() {
        let mut edges = Vec::new();
        for i in 0..100usize {
            for j in 0..10usize {
                let edge = call_edge(i as u32, j as u32, EdgeConfidence::HighLikely);
                edges.push((NodeIndex::new(i), NodeIndex::new(j), edge));
            }
        }
        let csr = CsrGraph::from_edges(100, edges);
        assert_eq!(csr.edge_count(), 1000);
        let memory = csr.memory_bytes();
        assert!(
            memory < 100_000,
            "CSR memory usage too high: {} bytes",
            memory
        );
    }
}