use anyhow::{anyhow, Result};
use std::collections::HashMap;
/// Identifier of a graph node.
///
/// A thin wrapper around `u32`; the wrapped value is used directly as the row index into the
/// offset arrays of [`CsrGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u32);
/// Directed, weighted graph stored in compressed sparse row (CSR) form.
///
/// Two parallel indices are kept: a forward one grouping edges by source node and a reverse one
/// grouping the same edges by target node, so both outgoing and incoming neighbors can be read
/// as contiguous slices.
#[derive(Debug, Clone)]
pub struct CsrGraph {
// Forward index: the outgoing edges of node `i` occupy positions
// `row_offsets[i]..row_offsets[i + 1]` of `col_indices` and `edge_weights`.
row_offsets: Vec<u32>,
// Target node ID of every edge, grouped by source node.
col_indices: Vec<u32>,
// Weight of every edge, parallel to `col_indices`.
edge_weights: Vec<f32>,
// Reverse index: the incoming edges of node `i` occupy positions
// `rev_row_offsets[i]..rev_row_offsets[i + 1]` of `rev_col_indices` and `rev_edge_weights`.
rev_row_offsets: Vec<u32>,
// Source node ID of every edge, grouped by target node.
rev_col_indices: Vec<u32>,
// Weight of every edge, parallel to `rev_col_indices`.
rev_edge_weights: Vec<f32>,
// Optional names attached to nodes through `set_node_name`.
node_names: HashMap<NodeId, String>,
// Number of nodes; IDs `0..num_nodes` are accepted by the bounds-checked accessors.
num_nodes: usize,
}
impl CsrGraph {
/// Creates a graph with no nodes, no edges and no node names.
///
/// Both offset arrays start out holding a single `0`, which is the offset layout the other
/// methods expect for a graph with zero rows (one more entry than there are nodes).
#[must_use]
pub fn new() -> Self {
    let forward_offsets: Vec<u32> = vec![0];
    let reverse_offsets = forward_offsets.clone();
    Self {
        num_nodes: 0,
        node_names: HashMap::new(),
        row_offsets: forward_offsets,
        rev_row_offsets: reverse_offsets,
        col_indices: vec![],
        rev_col_indices: vec![],
        edge_weights: vec![],
        rev_edge_weights: vec![],
    }
}
/// Builds a graph from a list of `(source, target, weight)` edges.
///
/// The node count is one more than the largest node ID that appears in any edge, so IDs below
/// that maximum which are never mentioned still get (empty) rows. Edges sharing a source keep
/// their relative input order in the forward index, and edges sharing a target keep it in the
/// reverse index. Duplicate edges are stored as given. An empty slice yields the same graph as
/// [`CsrGraph::new`]. No node names are set.
///
/// # Errors
///
/// Returns an error when the number of edges does not fit the `u32` offset type, or when the
/// required node count cannot be represented as `usize` on the current platform.
pub fn from_edge_list(edges: &[(NodeId, NodeId, f32)]) -> Result<Self> {
    /// Lays `edges` out in CSR form, grouped by the row that `endpoints` assigns to each edge.
    ///
    /// `endpoints` maps an edge to `(row, column)`. Returns `(offsets, columns, weights)`.
    /// The caller guarantees every row is `< num_nodes` and `edges.len()` fits in `u32`.
    fn build_csr(
        num_nodes: usize,
        edges: &[(NodeId, NodeId, f32)],
        endpoints: impl Fn(&(NodeId, NodeId, f32)) -> (u32, u32),
    ) -> (Vec<u32>, Vec<u32>, Vec<f32>) {
        // Pass 1: count edges per row, stored one slot to the right so that the running sum
        // below turns the counts into start offsets in place.
        let mut offsets = vec![0_u32; num_nodes + 1];
        for edge in edges {
            let (row, _) = endpoints(edge);
            offsets[row as usize + 1] += 1;
        }
        let mut running = 0_u32;
        for slot in &mut offsets {
            running += *slot;
            *slot = running;
        }

        // Pass 2: drop every edge into the next free position of its row. Walking the input
        // in order keeps edges of the same row in their original relative order.
        let mut next_free = offsets[..num_nodes].to_vec();
        let mut columns = vec![0_u32; edges.len()];
        let mut weights = vec![0.0_f32; edges.len()];
        for edge in edges {
            let (row, column) = endpoints(edge);
            let cursor = &mut next_free[row as usize];
            let position = *cursor as usize;
            columns[position] = column;
            weights[position] = edge.2;
            *cursor += 1;
        }
        (offsets, columns, weights)
    }

    let max_node = match edges.iter().flat_map(|(src, dst, _)| [src.0, dst.0]).max() {
        Some(id) => id,
        // No edges at all: nothing to index.
        None => return Ok(Self::new()),
    };

    // Offsets are stored as `u32`; reject oversized inputs instead of silently truncating.
    if edges.len() > u32::MAX as usize {
        return Err(anyhow!(
            "Edge count {} exceeds the u32 offset range",
            edges.len()
        ));
    }

    // Widen before adding one: `max_node + 1` computed in `u32` overflows for
    // `NodeId(u32::MAX)`.
    let num_nodes = (max_node as usize)
        .checked_add(1)
        .ok_or_else(|| anyhow!("Node ID {} is too large for this platform", max_node))?;

    let (row_offsets, col_indices, edge_weights) =
        build_csr(num_nodes, edges, |(src, dst, _)| (src.0, dst.0));
    let (rev_row_offsets, rev_col_indices, rev_edge_weights) =
        build_csr(num_nodes, edges, |(src, dst, _)| (dst.0, src.0));

    Ok(Self {
        row_offsets,
        col_indices,
        edge_weights,
        rev_row_offsets,
        rev_col_indices,
        rev_edge_weights,
        node_names: HashMap::new(),
        num_nodes,
    })
}
pub fn add_edge(&mut self, src: NodeId, dst: NodeId, weight: f32) -> Result<()> {
let max_node = src.0.max(dst.0) as usize;
if max_node >= self.num_nodes {
self.expand_to(max_node + 1);
}
let src_idx = src.0 as usize;
let end = self.row_offsets[src_idx + 1] as usize;
self.col_indices.insert(end, dst.0);
self.edge_weights.insert(end, weight);
for offset in &mut self.row_offsets[src_idx + 1..] {
*offset += 1;
}
let dst_idx = dst.0 as usize;
let rev_end = self.rev_row_offsets[dst_idx + 1] as usize;
self.rev_col_indices.insert(rev_end, src.0);
self.rev_edge_weights.insert(rev_end, weight);
for offset in &mut self.rev_row_offsets[dst_idx + 1..] {
*offset += 1;
}
Ok(())
}
/// Returns the target IDs of all edges leaving `node`, in stored order.
///
/// # Errors
///
/// Returns an error when `node` is not below the graph's node count.
pub fn outgoing_neighbors(&self, node: NodeId) -> Result<&[u32]> {
    let row = node.0 as usize;
    if row < self.num_nodes {
        let first = self.row_offsets[row] as usize;
        let past_last = self.row_offsets[row + 1] as usize;
        Ok(&self.col_indices[first..past_last])
    } else {
        Err(anyhow!("Node ID {} out of bounds", node.0))
    }
}
/// Returns the source IDs of all edges arriving at `target`, read from the reverse index.
///
/// # Errors
///
/// Returns an error when `target` is not below the graph's node count.
pub fn incoming_neighbors(&self, target: NodeId) -> Result<&[u32]> {
    let row = target.0 as usize;
    if row >= self.num_nodes {
        return Err(anyhow!("Node ID {} out of bounds", target.0));
    }
    let sources = (self.rev_row_offsets[row] as usize)..(self.rev_row_offsets[row + 1] as usize);
    Ok(&self.rev_col_indices[sources])
}
/// Associates `name` with `node`, replacing any name stored for it before.
///
/// The ID is not checked against the node count, so a name can be recorded for a node the
/// graph does not (yet) contain.
pub fn set_node_name(&mut self, node: NodeId, name: String) {
    // The previous name, if any, is intentionally discarded.
    drop(self.node_names.insert(node, name));
}
/// Looks up the name stored for `node`, or `None` when no name has been set for it.
#[must_use]
pub fn get_node_name(&self, node: NodeId) -> Option<&str> {
    let stored = self.node_names.get(&node)?;
    Some(stored.as_str())
}
/// Returns the number of nodes in the graph.
///
/// Node IDs `0..num_nodes()` are the ones accepted by the bounds-checked accessors.
#[must_use]
pub const fn num_nodes(&self) -> usize {
self.num_nodes
}
/// Returns the number of stored edges, measured as the length of the forward column array.
#[must_use]
pub fn num_edges(&self) -> usize {
self.col_indices.len()
}
/// Borrows the forward offset array: row `i` spans `offsets[i]..offsets[i + 1]`.
#[must_use]
pub fn row_offsets_slice(&self) -> &[u32] {
    self.row_offsets.as_slice()
}
/// Borrows the forward column array: the target node ID of every edge, grouped by source.
#[must_use]
pub fn col_indices_slice(&self) -> &[u32] {
    self.col_indices.as_slice()
}
/// Borrows the forward weight array, parallel to the forward column array.
#[must_use]
pub fn edge_weights_slice(&self) -> &[f32] {
    self.edge_weights.as_slice()
}
/// Returns the outgoing edges of `node_id` as parallel `(targets, weights)` slices.
///
/// Unlike `outgoing_neighbors`, an ID outside the graph is not an error: it simply yields a
/// pair of empty slices.
#[must_use]
pub fn adjacency(&self, node_id: NodeId) -> (&[u32], &[f32]) {
    let row = node_id.0 as usize;
    if row >= self.num_nodes {
        return (&[], &[]);
    }
    let span = (self.row_offsets[row] as usize)..(self.row_offsets[row + 1] as usize);
    let targets = &self.col_indices[span.clone()];
    let weights = &self.edge_weights[span];
    (targets, weights)
}
/// Iterates over every node in ascending ID order, yielding `(node, targets, weights)` where
/// the two slices describe that node's outgoing edges.
pub fn iter_adjacency(&self) -> impl Iterator<Item = (NodeId, &[u32], &[f32])> + '_ {
    let rows = 0..self.num_nodes;
    rows.map(move |row| {
        let span = (self.row_offsets[row] as usize)..(self.row_offsets[row + 1] as usize);
        #[allow(clippy::cast_possible_truncation)]
        let id = NodeId(row as u32);
        (id, &self.col_indices[span.clone()], &self.edge_weights[span])
    })
}
/// Grows the graph to `new_size` nodes; does nothing when it is already at least that large.
///
/// Each added node gets an empty row in both indices: its offset entry repeats the current
/// final offset, so the row's start and end coincide.
fn expand_to(&mut self, new_size: usize) {
    if new_size > self.num_nodes {
        let added = new_size - self.num_nodes;

        let forward_tail = self.row_offsets.last().copied().unwrap_or(0);
        self.row_offsets
            .extend(std::iter::repeat(forward_tail).take(added));

        let reverse_tail = self.rev_row_offsets.last().copied().unwrap_or(0);
        self.rev_row_offsets
            .extend(std::iter::repeat(reverse_tail).take(added));

        self.num_nodes = new_size;
    }
}
/// Borrows the three forward CSR arrays at once as `(row_offsets, col_indices, edge_weights)`.
#[must_use]
pub fn csr_components(&self) -> (&[u32], &[u32], &[f32]) {
    let offsets = self.row_offsets.as_slice();
    let targets = self.col_indices.as_slice();
    let weights = self.edge_weights.as_slice();
    (offsets, targets, weights)
}
}
/// The default graph is the empty graph produced by `CsrGraph::new`.
impl Default for CsrGraph {
// Delegates to `new` so both construction paths yield the same value.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly constructed graph reports zero nodes and zero edges.
#[test]
fn test_empty_graph() {
    let empty = CsrGraph::new();
    assert_eq!((empty.num_nodes(), empty.num_edges()), (0, 0));
}
/// Three edges over three nodes produce the expected counts and forward CSR arrays.
#[test]
fn test_from_edge_list_simple() {
    let graph = CsrGraph::from_edge_list(&[
        (NodeId(0), NodeId(1), 1.0),
        (NodeId(0), NodeId(2), 1.0),
        (NodeId(1), NodeId(2), 1.0),
    ])
    .unwrap();

    assert_eq!(graph.num_nodes(), 3);
    assert_eq!(graph.num_edges(), 3);

    assert_eq!(graph.row_offsets, [0, 2, 3, 3]);
    assert_eq!(graph.col_indices, [1, 2, 2]);
    assert_eq!(graph.edge_weights, [1.0, 1.0, 1.0]);
}
/// Outgoing neighbors come back in insertion order; a node without out-edges yields nothing.
#[test]
fn test_outgoing_neighbors() {
    let graph =
        CsrGraph::from_edge_list(&[(NodeId(0), NodeId(1), 1.0), (NodeId(0), NodeId(2), 2.0)])
            .unwrap();

    let from_zero = graph.outgoing_neighbors(NodeId(0)).unwrap();
    assert_eq!(from_zero, &[1, 2]);

    let from_one = graph.outgoing_neighbors(NodeId(1)).unwrap();
    assert!(from_one.is_empty());
}
/// Both sources pointing at node 2 are reported as its incoming neighbors.
#[test]
fn test_incoming_neighbors() {
    let graph =
        CsrGraph::from_edge_list(&[(NodeId(0), NodeId(2), 1.0), (NodeId(1), NodeId(2), 1.0)])
            .unwrap();

    // Order is not part of the expectation, so compare the sorted sources.
    let mut callers = graph.incoming_neighbors(NodeId(2)).unwrap().to_vec();
    callers.sort_unstable();
    assert_eq!(callers, [0, 1]);
}
/// The reverse index lists, for every node, exactly the sources of its incoming edges.
#[test]
fn test_reverse_csr_structure() {
    let graph = CsrGraph::from_edge_list(&[
        (NodeId(0), NodeId(1), 1.0),
        (NodeId(0), NodeId(2), 2.0),
        (NodeId(1), NodeId(2), 3.0),
    ])
    .unwrap();

    assert!(graph.incoming_neighbors(NodeId(0)).unwrap().is_empty());
    assert_eq!(graph.incoming_neighbors(NodeId(1)).unwrap(), &[0]);

    // Order within node 2's row is not asserted, only its contents.
    let mut into_two = graph.incoming_neighbors(NodeId(2)).unwrap().to_vec();
    into_two.sort_unstable();
    assert_eq!(into_two, [0, 1]);
}
/// Parallel edges are kept as separate entries in the reverse index.
#[test]
fn test_reverse_csr_multi_edges() {
    let graph = CsrGraph::from_edge_list(&[
        (NodeId(0), NodeId(1), 1.0),
        (NodeId(0), NodeId(1), 2.0),
        (NodeId(2), NodeId(1), 3.0),
    ])
    .unwrap();

    let incoming = graph.incoming_neighbors(NodeId(1)).unwrap();
    assert_eq!(incoming.len(), 3);

    let occurrences = |source: u32| incoming.iter().filter(|&&id| id == source).count();
    assert_eq!(occurrences(0), 2, "Should have 2 edges from node 0");
    assert_eq!(occurrences(2), 1, "Should have 1 edge from node 2");
}
/// Edges added one at a time show up in the reverse index, including after further growth.
#[test]
fn test_reverse_csr_with_add_edge() {
    let mut graph = CsrGraph::new();
    for (source, weight) in [(0, 1.0), (2, 2.0)] {
        graph.add_edge(NodeId(source), NodeId(1), weight).unwrap();
    }

    let incoming = graph.incoming_neighbors(NodeId(1)).unwrap();
    assert_eq!(incoming.len(), 2);
    for expected in [0, 2] {
        assert!(incoming.contains(&expected));
    }

    // A third source with a brand-new ID forces the graph to grow first.
    graph.add_edge(NodeId(3), NodeId(1), 3.0).unwrap();

    let incoming = graph.incoming_neighbors(NodeId(1)).unwrap();
    assert_eq!(incoming.len(), 3);
    for expected in [0, 2, 3] {
        assert!(incoming.contains(&expected));
    }
}
/// Adding edges to an empty graph grows the node count and preserves insertion order.
#[test]
fn test_add_edge_dynamic() {
    let mut graph = CsrGraph::new();
    for target in [1, 2] {
        graph.add_edge(NodeId(0), NodeId(target), 1.0).unwrap();
    }

    assert_eq!(graph.num_nodes(), 3);
    assert_eq!(graph.num_edges(), 2);
    assert_eq!(graph.outgoing_neighbors(NodeId(0)).unwrap(), &[1, 2]);
}
/// Names stored for nodes can be read back unchanged.
#[test]
fn test_node_names() {
    let mut graph = CsrGraph::new();
    graph.add_edge(NodeId(0), NodeId(1), 1.0).unwrap();

    let labelled = [(NodeId(0), "main"), (NodeId(1), "parse_args")];
    for (node, label) in labelled {
        graph.set_node_name(node, label.to_string());
    }
    for (node, label) in labelled {
        assert_eq!(graph.get_node_name(node), Some(label));
    }
}
/// `csr_components` exposes the forward offsets, targets and weights together.
#[test]
fn test_csr_components() {
    let graph =
        CsrGraph::from_edge_list(&[(NodeId(0), NodeId(1), 1.0), (NodeId(0), NodeId(2), 2.0)])
            .unwrap();

    let (offsets, targets, edge_weights) = graph.csr_components();
    assert_eq!(offsets, &[0, 2, 2, 2]);
    assert_eq!(targets, &[1, 2]);
    assert_eq!(edge_weights, &[1.0, 2.0]);
}
/// `adjacency` returns matching target and weight slices for each node's outgoing edges.
#[test]
fn test_adjacency() {
    let graph = CsrGraph::from_edge_list(&[
        (NodeId(0), NodeId(1), 1.5),
        (NodeId(0), NodeId(2), 2.5),
        (NodeId(1), NodeId(2), 3.5),
    ])
    .unwrap();

    let (out_of_zero, weights_of_zero) = graph.adjacency(NodeId(0));
    assert_eq!(out_of_zero, &[1, 2]);
    assert_eq!(weights_of_zero, &[1.5, 2.5]);

    let (out_of_one, weights_of_one) = graph.adjacency(NodeId(1));
    assert_eq!(out_of_one, &[2]);
    assert_eq!(weights_of_one, &[3.5]);

    // Node 2 is only ever a target, so it has no outgoing edges.
    let (out_of_two, weights_of_two) = graph.adjacency(NodeId(2));
    assert!(out_of_two.is_empty());
    assert!(weights_of_two.is_empty());
}
/// An ID beyond the graph yields empty slices from `adjacency` rather than a failure.
#[test]
fn test_adjacency_out_of_bounds() {
    let graph = CsrGraph::from_edge_list(&[(NodeId(0), NodeId(1), 1.0)]).unwrap();

    let (targets, weights) = graph.adjacency(NodeId(999));
    assert!(targets.is_empty());
    assert!(weights.is_empty());
}
/// `iter_adjacency` visits every node in ID order with its outgoing targets and weights.
#[test]
fn test_iter_adjacency() {
    let graph = CsrGraph::from_edge_list(&[
        (NodeId(0), NodeId(1), 1.0),
        (NodeId(0), NodeId(2), 2.0),
        (NodeId(1), NodeId(2), 3.0),
    ])
    .unwrap();

    let rows: Vec<_> = graph.iter_adjacency().collect();
    assert_eq!(rows.len(), 3);

    let (id, targets, weights) = rows[0];
    assert_eq!(id, NodeId(0));
    assert_eq!(targets, &[1, 2]);
    assert_eq!(weights, &[1.0, 2.0]);

    let (id, targets, weights) = rows[1];
    assert_eq!(id, NodeId(1));
    assert_eq!(targets, &[2]);
    assert_eq!(weights, &[3.0]);

    let (id, targets, weights) = rows[2];
    assert_eq!(id, NodeId(2));
    assert!(targets.is_empty());
    assert!(weights.is_empty());
}
/// The individual slice accessors expose the forward CSR arrays.
#[test]
fn test_slice_methods() {
    let graph =
        CsrGraph::from_edge_list(&[(NodeId(0), NodeId(1), 1.0), (NodeId(0), NodeId(2), 2.0)])
            .unwrap();

    let offsets = graph.row_offsets_slice();
    let targets = graph.col_indices_slice();
    let weights = graph.edge_weights_slice();

    assert_eq!(offsets, &[0, 2, 2, 2]);
    assert_eq!(targets, &[1, 2]);
    assert_eq!(weights, &[1.0, 2.0]);
}
/// Asking for a name that was never set yields `None`.
#[test]
fn test_get_node_name_nonexistent() {
    let unnamed = CsrGraph::new();
    assert!(unnamed.get_node_name(NodeId(0)).is_none());
}
/// Iterating the adjacency of an empty graph produces no rows.
#[test]
fn test_empty_adjacency_iterator() {
    let graph = CsrGraph::new();
    assert_eq!(graph.iter_adjacency().count(), 0);
}
}