use super::node::NodeId;
use crate::{GraphError, Result};
/// Sparse matrix in Compressed Sparse Row (CSR) layout, used here as a graph adjacency matrix.
///
/// The stored entries of row `i` are `col_idx[row_ptr[i] as usize..row_ptr[i + 1] as usize]`;
/// when `values` is present, the weight of each entry sits at the same position in `values`.
/// [`CsrMatrix::validate`] checks these structural invariants.
#[derive(Debug, Clone)]
pub struct CsrMatrix {
    /// Number of rows; `row_ptr` is expected to have `num_rows + 1` entries.
    pub num_rows: usize,
    /// Number of columns; every `col_idx` entry is expected to be `< num_cols`.
    pub num_cols: usize,
    /// Row offsets into `col_idx` / `values`; the last entry is expected to equal `col_idx.len()`.
    pub row_ptr: Vec<u64>,
    /// Column index (destination node id) of each stored entry, grouped by row.
    pub col_idx: Vec<u32>,
    /// Optional per-entry weights parallel to `col_idx`; `None` means unweighted
    /// (`weighted_neighbors` reports weight `1.0` in that case).
    pub values: Option<Vec<f64>>,
}
impl CsrMatrix {
    /// Creates a `num_nodes x num_nodes` matrix with no stored entries.
    pub fn empty(num_nodes: usize) -> Self {
        Self {
            num_rows: num_nodes,
            num_cols: num_nodes,
            row_ptr: vec![0; num_nodes + 1],
            col_idx: Vec::new(),
            values: None,
        }
    }

    /// Builds an unweighted `num_nodes x num_nodes` matrix from `(src, dst)` pairs.
    ///
    /// Delegates to [`CsrMatrixBuilder`]; entry ordering and handling of out-of-range
    /// endpoints are whatever `CsrMatrixBuilder::build` does.
    pub fn from_edges(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
        CsrMatrixBuilder::new(num_nodes).with_edges(edges).build()
    }

    /// Builds a weighted `num_nodes x num_nodes` matrix from `(src, dst, weight)` triples.
    ///
    /// Delegates to [`CsrMatrixBuilder`].
    pub fn from_weighted_edges(num_nodes: usize, edges: &[(u32, u32, f64)]) -> Self {
        CsrMatrixBuilder::new(num_nodes)
            .with_weighted_edges(edges)
            .build()
    }

    /// Number of stored entries (edges).
    pub fn num_nonzeros(&self) -> usize {
        self.col_idx.len()
    }

    /// Returns `true` when no entries are stored (the matrix may still have rows).
    pub fn is_empty(&self) -> bool {
        self.col_idx.is_empty()
    }

    /// Number of stored entries in `node`'s row; `0` if `node` is not a valid row.
    pub fn degree(&self, node: NodeId) -> usize {
        let i = node.0 as usize;
        if i >= self.num_rows {
            return 0;
        }
        (self.row_ptr[i + 1] - self.row_ptr[i]) as usize
    }

    /// Column indices stored in `node`'s row; an empty slice if `node` is not a valid row.
    pub fn neighbors(&self, node: NodeId) -> &[u32] {
        let i = node.0 as usize;
        if i >= self.num_rows {
            return &[];
        }
        let start = self.row_ptr[i] as usize;
        let end = self.row_ptr[i + 1] as usize;
        &self.col_idx[start..end]
    }

    /// `(neighbor, weight)` pairs for `node`'s row.
    ///
    /// Unweighted matrices (`values == None`) report a weight of `1.0` for every entry.
    /// Returns an empty vector if `node` is not a valid row.
    pub fn weighted_neighbors(&self, node: NodeId) -> Vec<(NodeId, f64)> {
        let i = node.0 as usize;
        if i >= self.num_rows {
            return Vec::new();
        }
        let start = self.row_ptr[i] as usize;
        let end = self.row_ptr[i + 1] as usize;
        let neighbors = &self.col_idx[start..end];
        match &self.values {
            Some(vals) => neighbors
                .iter()
                .zip(&vals[start..end])
                .map(|(&col, &w)| (NodeId(col), w))
                .collect(),
            None => neighbors.iter().map(|&col| (NodeId(col), 1.0)).collect(),
        }
    }

    /// Returns `true` if `src`'s row stores `dst` (linear scan of that row).
    pub fn has_edge(&self, src: NodeId, dst: NodeId) -> bool {
        self.neighbors(src).contains(&dst.0)
    }

    /// Checks the CSR structural invariants.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::InvalidCsr`] if `row_ptr` does not have `num_rows + 1` entries,
    /// decreases anywhere, does not start at `0`, or does not end at `col_idx.len()`; if
    /// `values` is present with a length different from `col_idx`; or if any column index
    /// is `>= num_cols`.
    pub fn validate(&self) -> Result<()> {
        if self.row_ptr.len() != self.num_rows + 1 {
            return Err(GraphError::InvalidCsr(format!(
                "row_ptr length {} != num_rows + 1 = {}",
                self.row_ptr.len(),
                self.num_rows + 1
            )));
        }
        if let Some(i) = self.row_ptr.windows(2).position(|w| w[0] > w[1]) {
            return Err(GraphError::InvalidCsr(format!(
                "row_ptr not monotonic at index {}",
                i
            )));
        }
        // The first row must start at offset 0; otherwise the leading col_idx entries
        // belong to no row at all.
        if self.row_ptr[0] != 0 {
            return Err(GraphError::InvalidCsr(format!(
                "row_ptr[0] = {} != 0",
                self.row_ptr[0]
            )));
        }
        let nnz = *self.row_ptr.last().unwrap_or(&0) as usize;
        if nnz != self.col_idx.len() {
            return Err(GraphError::InvalidCsr(format!(
                "row_ptr[-1] = {} != col_idx.len() = {}",
                nnz,
                self.col_idx.len()
            )));
        }
        if let Some(ref vals) = self.values {
            if vals.len() != self.col_idx.len() {
                return Err(GraphError::InvalidCsr(format!(
                    "values.len() = {} != col_idx.len() = {}",
                    vals.len(),
                    self.col_idx.len()
                )));
            }
        }
        if let Some(&col) = self.col_idx.iter().find(|&&c| c as usize >= self.num_cols) {
            return Err(GraphError::InvalidCsr(format!(
                "col_idx {} >= num_cols {}",
                col, self.num_cols
            )));
        }
        Ok(())
    }

    /// Returns the transpose: every stored entry `(r, c)` becomes `(c, r)`.
    ///
    /// The result has `num_cols` rows and `num_rows` columns, so non-square matrices are
    /// handled. Weights (if any) move with their entries. Within each output row, entries
    /// appear in increasing order of their original row. Runs in
    /// `O(num_rows + num_cols + nnz)` using a counting sort on column indices.
    ///
    /// # Panics
    ///
    /// May panic if the matrix is structurally invalid (see [`CsrMatrix::validate`]), for
    /// example when a column index is `>= num_cols`.
    pub fn transpose(&self) -> Self {
        let nnz = self.col_idx.len();

        // Count entries per column: row_ptr[c + 1] starts as the count for column c.
        let mut row_ptr = vec![0u64; self.num_cols + 1];
        for &col in &self.col_idx {
            row_ptr[col as usize + 1] += 1;
        }
        // An in-place running sum turns the counts into row offsets of the transpose.
        let mut running = 0u64;
        for offset in row_ptr.iter_mut() {
            running += *offset;
            *offset = running;
        }

        // Next free slot in each output row while scattering entries.
        let mut cursor: Vec<usize> = row_ptr[..self.num_cols]
            .iter()
            .map(|&p| p as usize)
            .collect();
        let mut col_idx = vec![0u32; nnz];
        let mut values = self.values.as_ref().map(|_| vec![0.0f64; nnz]);

        // Visiting source rows in ascending order keeps each output row sorted by original row.
        for row in 0..self.num_rows {
            let start = self.row_ptr[row] as usize;
            let end = self.row_ptr[row + 1] as usize;
            for k in start..end {
                let col = self.col_idx[k] as usize;
                let slot = cursor[col];
                cursor[col] += 1;
                col_idx[slot] = row as u32;
                if let (Some(out), Some(src)) = (values.as_mut(), self.values.as_ref()) {
                    out[slot] = src[k];
                }
            }
        }

        Self {
            num_rows: self.num_cols,
            num_cols: self.num_rows,
            row_ptr,
            col_idx,
            values,
        }
    }
}
/// Accumulates edges and assembles them into a square [`CsrMatrix`].
///
/// Edges are buffered in the order they are added; row grouping and offsets are
/// computed by `build`.
#[derive(Debug, Default)]
pub struct CsrMatrixBuilder {
    /// Side length of the resulting matrix (used for both rows and columns).
    num_nodes: usize,
    /// Buffered `(src, dst, weight)` triples; `None` marks an edge added without a weight.
    edges: Vec<(u32, u32, Option<f64>)>,
}
impl CsrMatrixBuilder {
    /// Creates an empty builder for a `num_nodes x num_nodes` matrix.
    pub fn new(num_nodes: usize) -> Self {
        Self {
            num_nodes,
            edges: Vec::new(),
        }
    }

    /// Appends unweighted `(src, dst)` edges and returns the builder for chaining.
    #[must_use]
    pub fn with_edges(mut self, edges: &[(u32, u32)]) -> Self {
        self.edges
            .extend(edges.iter().map(|&(src, dst)| (src, dst, None)));
        self
    }

    /// Appends weighted `(src, dst, weight)` edges and returns the builder for chaining.
    #[must_use]
    pub fn with_weighted_edges(mut self, edges: &[(u32, u32, f64)]) -> Self {
        self.edges
            .extend(edges.iter().map(|&(src, dst, w)| (src, dst, Some(w))));
        self
    }

    /// Appends one unweighted edge in place.
    pub fn add_edge(&mut self, src: u32, dst: u32) {
        self.edges.push((src, dst, None));
    }

    /// Appends one weighted edge in place.
    pub fn add_weighted_edge(&mut self, src: u32, dst: u32, weight: f64) {
        self.edges.push((src, dst, Some(weight)));
    }

    /// Consumes the builder and assembles a `num_nodes x num_nodes` [`CsrMatrix`].
    ///
    /// - Edges whose source or destination is `>= num_nodes` are discarded, so the
    ///   result always satisfies [`CsrMatrix::validate`].
    /// - Edges are grouped by source with a stable sort, so edges sharing a source keep
    ///   their insertion order. Duplicate edges are kept as separate entries.
    /// - If at least one retained edge carries a weight, `values` is `Some` and edges added
    ///   without a weight get `1.0`; otherwise `values` is `None`.
    pub fn build(mut self) -> CsrMatrix {
        let n = self.num_nodes;

        // Previously, out-of-range sources were skipped when counting rows but still copied
        // into col_idx, leaving row_ptr[n] != col_idx.len(). Dropping them up front keeps
        // row_ptr, col_idx and values describing the same entries.
        self.edges
            .retain(|&(src, dst, _)| (src as usize) < n && (dst as usize) < n);

        // Stable sort: preserves insertion order among edges with the same source.
        self.edges.sort_by_key(|&(src, _, _)| src);

        // row_ptr[src + 1] counts src's edges, then a running sum turns counts into offsets.
        let mut row_ptr = vec![0u64; n + 1];
        for &(src, _, _) in &self.edges {
            row_ptr[src as usize + 1] += 1;
        }
        let mut running = 0u64;
        for offset in row_ptr.iter_mut() {
            running += *offset;
            *offset = running;
        }

        let col_idx: Vec<u32> = self.edges.iter().map(|&(_, dst, _)| dst).collect();
        let has_weights = self.edges.iter().any(|e| e.2.is_some());
        let values = if has_weights {
            Some(self.edges.iter().map(|e| e.2.unwrap_or(1.0)).collect())
        } else {
            None
        };

        CsrMatrix {
            num_rows: n,
            num_cols: n,
            row_ptr,
            col_idx,
            values,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unweighted square matrix built from `(src, dst)` pairs.
    fn unweighted(n: usize, edges: &[(u32, u32)]) -> CsrMatrix {
        CsrMatrix::from_edges(n, edges)
    }

    /// Asserts that `node`'s row holds exactly as many entries as `expected`
    /// and that every expected column appears among them.
    fn assert_neighbors(m: &CsrMatrix, node: u32, expected: &[u32]) {
        let got = m.neighbors(NodeId(node));
        assert_eq!(got.len(), expected.len());
        for col in expected {
            assert!(got.contains(col));
        }
    }

    #[test]
    fn test_empty_matrix() {
        let m = CsrMatrix::empty(5);
        assert_eq!(m.num_rows, 5);
        assert_eq!(m.num_nonzeros(), 0);
        assert!(m.is_empty());
    }

    #[test]
    fn test_from_edges() {
        let m = unweighted(4, &[(0, 1), (1, 2), (1, 3)]);
        assert_eq!(m.num_rows, 4);
        assert_eq!(m.num_nonzeros(), 3);
        assert!(m.validate().is_ok());
    }

    #[test]
    fn test_neighbors() {
        let m = unweighted(3, &[(0, 1), (0, 2), (1, 2)]);
        assert_neighbors(&m, 0, &[1, 2]);
        assert_neighbors(&m, 1, &[2]);
        assert_neighbors(&m, 2, &[]);
    }

    #[test]
    fn test_degree() {
        let m = unweighted(4, &[(0, 1), (0, 2), (0, 3), (1, 2)]);
        let expected_degrees: [(u32, usize); 4] = [(0, 3), (1, 1), (2, 0), (3, 0)];
        for &(node, degree) in expected_degrees.iter() {
            assert_eq!(m.degree(NodeId(node)), degree);
        }
    }

    #[test]
    fn test_has_edge() {
        let m = unweighted(3, &[(0, 1), (1, 2)]);
        let edge = |src: u32, dst: u32| m.has_edge(NodeId(src), NodeId(dst));
        assert!(edge(0, 1));
        assert!(edge(1, 2));
        assert!(!edge(0, 2));
        assert!(!edge(2, 0));
    }

    #[test]
    fn test_weighted_edges() {
        let m = CsrMatrix::from_weighted_edges(3, &[(0, 1, 1.5), (0, 2, 2.5), (1, 2, 3.0)]);
        assert!(m.values.is_some());
        let row0 = m.weighted_neighbors(NodeId(0));
        assert_eq!(row0.len(), 2);
        for expected in [(NodeId(1), 1.5), (NodeId(2), 2.5)].iter() {
            assert!(row0.contains(expected));
        }
    }

    #[test]
    fn test_transpose() {
        let t = unweighted(3, &[(0, 1), (1, 2)]).transpose();
        assert!(t.has_edge(NodeId(1), NodeId(0)));
        assert!(t.has_edge(NodeId(2), NodeId(1)));
        assert!(!t.has_edge(NodeId(0), NodeId(1)));
    }

    #[test]
    fn test_builder() {
        let mut b = CsrMatrixBuilder::new(4);
        for &(src, dst) in [(0, 1), (0, 2)].iter() {
            b.add_edge(src, dst);
        }
        b.add_weighted_edge(1, 3, 2.5);
        let m = b.build();
        assert_eq!(m.num_nonzeros(), 3);
        assert!(m.values.is_some());
    }

    #[test]
    fn test_validation() {
        assert!(unweighted(3, &[(0, 1), (1, 2)]).validate().is_ok());
        let out_of_range_col = CsrMatrix {
            num_rows: 3,
            num_cols: 3,
            row_ptr: vec![0, 1, 2, 2],
            col_idx: vec![1, 10],
            values: None,
        };
        assert!(out_of_range_col.validate().is_err());
    }
}