Skip to main content

ringkernel_graph/models/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format.
2//!
3//! CSR is an efficient format for sparse matrices/graphs that enables:
4//! - O(1) access to row start/end positions
5//! - O(degree) iteration over neighbors
6//! - Cache-friendly sequential access patterns
7//!
8//! Memory layout:
9//! - `row_ptr[i]` = starting index in col_idx for row i
10//! - `col_idx[row_ptr[i]..row_ptr[i+1]]` = column indices (neighbors) of row i
11//! - `values` (optional) = edge weights
12
13use super::node::NodeId;
14use crate::{GraphError, Result};
15
/// Adjacency of a graph stored in Compressed Sparse Row (CSR) form.
///
/// For a graph with `N` nodes and `M` edges the arrays are laid out as:
/// - `row_ptr` holds `N + 1` offsets; row `i`'s edges occupy
///   `col_idx[row_ptr[i]..row_ptr[i + 1]]`.
/// - `col_idx` holds `M` column indices, i.e. the neighbor node IDs.
/// - `values`, when present, holds `M` edge weights parallel to `col_idx`.
#[derive(Debug, Clone)]
pub struct CsrMatrix {
    /// Row count (one row per node).
    pub num_rows: usize,
    /// Column count; usually equal to `num_rows` for a square adjacency matrix.
    pub num_cols: usize,
    /// Row offsets into `col_idx` (`num_rows + 1` entries).
    pub row_ptr: Vec<u64>,
    /// Column index of every stored entry (one per non-zero).
    pub col_idx: Vec<u32>,
    /// Per-entry edge weights parallel to `col_idx`, or `None` when unweighted.
    pub values: Option<Vec<f64>>,
}
35
36impl CsrMatrix {
37    /// Create an empty CSR matrix.
38    pub fn empty(num_nodes: usize) -> Self {
39        Self {
40            num_rows: num_nodes,
41            num_cols: num_nodes,
42            row_ptr: vec![0; num_nodes + 1],
43            col_idx: Vec::new(),
44            values: None,
45        }
46    }
47
48    /// Create CSR from edge list.
49    ///
50    /// # Arguments
51    ///
52    /// * `num_nodes` - Number of nodes in the graph
53    /// * `edges` - List of (source, destination) pairs
54    ///
55    /// # Example
56    ///
57    /// ```
58    /// use ringkernel_graph::CsrMatrix;
59    ///
60    /// // Graph: 0 -> 1 -> 2
61    /// let csr = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]);
62    /// assert_eq!(csr.num_nonzeros(), 2);
63    /// ```
64    pub fn from_edges(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
65        CsrMatrixBuilder::new(num_nodes).with_edges(edges).build()
66    }
67
68    /// Create CSR from edge list with weights.
69    pub fn from_weighted_edges(num_nodes: usize, edges: &[(u32, u32, f64)]) -> Self {
70        CsrMatrixBuilder::new(num_nodes)
71            .with_weighted_edges(edges)
72            .build()
73    }
74
75    /// Number of non-zero entries (edges).
76    pub fn num_nonzeros(&self) -> usize {
77        self.col_idx.len()
78    }
79
80    /// Check if matrix is empty.
81    pub fn is_empty(&self) -> bool {
82        self.col_idx.is_empty()
83    }
84
85    /// Get the degree (number of outgoing edges) of a node.
86    pub fn degree(&self, node: NodeId) -> usize {
87        let i = node.0 as usize;
88        if i >= self.num_rows {
89            return 0;
90        }
91        (self.row_ptr[i + 1] - self.row_ptr[i]) as usize
92    }
93
94    /// Get neighbors of a node.
95    pub fn neighbors(&self, node: NodeId) -> &[u32] {
96        let i = node.0 as usize;
97        if i >= self.num_rows {
98            return &[];
99        }
100        let start = self.row_ptr[i] as usize;
101        let end = self.row_ptr[i + 1] as usize;
102        &self.col_idx[start..end]
103    }
104
105    /// Get neighbors with weights (returns empty if unweighted).
106    pub fn weighted_neighbors(&self, node: NodeId) -> Vec<(NodeId, f64)> {
107        let i = node.0 as usize;
108        if i >= self.num_rows {
109            return Vec::new();
110        }
111        let start = self.row_ptr[i] as usize;
112        let end = self.row_ptr[i + 1] as usize;
113
114        let neighbors = &self.col_idx[start..end];
115        match &self.values {
116            Some(vals) => neighbors
117                .iter()
118                .zip(&vals[start..end])
119                .map(|(&col, &w)| (NodeId(col), w))
120                .collect(),
121            None => neighbors.iter().map(|&col| (NodeId(col), 1.0)).collect(),
122        }
123    }
124
125    /// Check if edge exists from src to dst.
126    pub fn has_edge(&self, src: NodeId, dst: NodeId) -> bool {
127        self.neighbors(src).contains(&dst.0)
128    }
129
130    /// Validate CSR structure.
131    pub fn validate(&self) -> Result<()> {
132        // Check row_ptr length
133        if self.row_ptr.len() != self.num_rows + 1 {
134            return Err(GraphError::InvalidCsr(format!(
135                "row_ptr length {} != num_rows + 1 = {}",
136                self.row_ptr.len(),
137                self.num_rows + 1
138            )));
139        }
140
141        // Check row_ptr is non-decreasing
142        for i in 0..self.num_rows {
143            if self.row_ptr[i] > self.row_ptr[i + 1] {
144                return Err(GraphError::InvalidCsr(format!(
145                    "row_ptr not monotonic at index {}",
146                    i
147                )));
148            }
149        }
150
151        // Check final row_ptr matches col_idx length
152        let nnz = *self.row_ptr.last().unwrap_or(&0) as usize;
153        if nnz != self.col_idx.len() {
154            return Err(GraphError::InvalidCsr(format!(
155                "row_ptr[-1] = {} != col_idx.len() = {}",
156                nnz,
157                self.col_idx.len()
158            )));
159        }
160
161        // Check values length if present
162        if let Some(ref vals) = self.values {
163            if vals.len() != self.col_idx.len() {
164                return Err(GraphError::InvalidCsr(format!(
165                    "values.len() = {} != col_idx.len() = {}",
166                    vals.len(),
167                    self.col_idx.len()
168                )));
169            }
170        }
171
172        // Check col_idx values are in bounds
173        for &col in &self.col_idx {
174            if col as usize >= self.num_cols {
175                return Err(GraphError::InvalidCsr(format!(
176                    "col_idx {} >= num_cols {}",
177                    col, self.num_cols
178                )));
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Create transpose (reverse graph).
186    pub fn transpose(&self) -> Self {
187        let mut builder = CsrMatrixBuilder::new(self.num_cols);
188
189        // Count in-degrees for new row_ptr
190        let mut counts = vec![0u64; self.num_cols];
191        for &col in &self.col_idx {
192            counts[col as usize] += 1;
193        }
194
195        // Build transposed edges
196        for row in 0..self.num_rows {
197            let start = self.row_ptr[row] as usize;
198            let end = self.row_ptr[row + 1] as usize;
199            for (i, &col) in self.col_idx[start..end].iter().enumerate() {
200                let weight = self.values.as_ref().map(|v| v[start + i]);
201                builder.edges.push((col, row as u32, weight));
202            }
203        }
204
205        builder.build()
206    }
207}
208
209/// Builder for CSR matrices.
210#[derive(Debug, Default)]
211pub struct CsrMatrixBuilder {
212    num_nodes: usize,
213    edges: Vec<(u32, u32, Option<f64>)>,
214}
215
216impl CsrMatrixBuilder {
217    /// Create new builder with given number of nodes.
218    pub fn new(num_nodes: usize) -> Self {
219        Self {
220            num_nodes,
221            edges: Vec::new(),
222        }
223    }
224
225    /// Add edges from slice.
226    pub fn with_edges(mut self, edges: &[(u32, u32)]) -> Self {
227        for &(src, dst) in edges {
228            self.edges.push((src, dst, None));
229        }
230        self
231    }
232
233    /// Add weighted edges from slice.
234    pub fn with_weighted_edges(mut self, edges: &[(u32, u32, f64)]) -> Self {
235        for &(src, dst, w) in edges {
236            self.edges.push((src, dst, Some(w)));
237        }
238        self
239    }
240
241    /// Add a single edge.
242    pub fn add_edge(&mut self, src: u32, dst: u32) {
243        self.edges.push((src, dst, None));
244    }
245
246    /// Add a weighted edge.
247    pub fn add_weighted_edge(&mut self, src: u32, dst: u32, weight: f64) {
248        self.edges.push((src, dst, Some(weight)));
249    }
250
251    /// Build the CSR matrix.
252    pub fn build(mut self) -> CsrMatrix {
253        // Sort edges by source
254        self.edges.sort_by_key(|e| e.0);
255
256        let has_weights = self.edges.iter().any(|e| e.2.is_some());
257
258        // Build row_ptr
259        let mut row_ptr = vec![0u64; self.num_nodes + 1];
260        for &(src, _, _) in &self.edges {
261            if (src as usize) < self.num_nodes {
262                row_ptr[src as usize + 1] += 1;
263            }
264        }
265
266        // Cumulative sum
267        for i in 1..=self.num_nodes {
268            row_ptr[i] += row_ptr[i - 1];
269        }
270
271        // Build col_idx and values
272        let col_idx: Vec<u32> = self.edges.iter().map(|e| e.1).collect();
273        let values = if has_weights {
274            Some(self.edges.iter().map(|e| e.2.unwrap_or(1.0)).collect())
275        } else {
276            None
277        };
278
279        CsrMatrix {
280            num_rows: self.num_nodes,
281            num_cols: self.num_nodes,
282            row_ptr,
283            col_idx,
284            values,
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_matrix() {
        let m = CsrMatrix::empty(5);
        assert!(m.is_empty());
        assert_eq!(m.num_nonzeros(), 0);
        assert_eq!(m.num_rows, 5);
    }

    #[test]
    fn test_from_edges() {
        // Shape under test:
        //   0 -> 1 -> 2
        //        |
        //        +-> 3
        let m = CsrMatrix::from_edges(4, &[(0, 1), (1, 2), (1, 3)]);
        assert_eq!((m.num_rows, m.num_nonzeros()), (4, 3));
        assert!(m.validate().is_ok());
    }

    #[test]
    fn test_neighbors() {
        let m = CsrMatrix::from_edges(3, &[(0, 1), (0, 2), (1, 2)]);

        let of_zero = m.neighbors(NodeId(0));
        assert_eq!(of_zero.len(), 2);
        assert!([1, 2].iter().all(|n| of_zero.contains(n)));

        let of_one = m.neighbors(NodeId(1));
        assert_eq!(of_one.to_vec(), vec![2]);

        assert!(m.neighbors(NodeId(2)).is_empty());
    }

    #[test]
    fn test_degree() {
        let m = CsrMatrix::from_edges(4, &[(0, 1), (0, 2), (0, 3), (1, 2)]);
        let expected: [(u32, usize); 4] = [(0, 3), (1, 1), (2, 0), (3, 0)];
        for &(node, deg) in expected.iter() {
            assert_eq!(m.degree(NodeId(node)), deg, "degree of node {}", node);
        }
    }

    #[test]
    fn test_has_edge() {
        let m = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]);
        for &(src, dst) in [(0, 1), (1, 2)].iter() {
            assert!(m.has_edge(NodeId(src), NodeId(dst)));
        }
        for &(src, dst) in [(0, 2), (2, 0)].iter() {
            assert!(!m.has_edge(NodeId(src), NodeId(dst)));
        }
    }

    #[test]
    fn test_weighted_edges() {
        let m = CsrMatrix::from_weighted_edges(3, &[(0, 1, 1.5), (0, 2, 2.5), (1, 2, 3.0)]);
        assert!(m.values.is_some());

        let out = m.weighted_neighbors(NodeId(0));
        assert_eq!(out.len(), 2);
        assert!(out.contains(&(NodeId(1), 1.5)));
        assert!(out.contains(&(NodeId(2), 2.5)));
    }

    #[test]
    fn test_transpose() {
        // Forward: 0 -> 1 -> 2; reversed: 1 -> 0, 2 -> 1.
        let reversed = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]).transpose();
        assert!(reversed.has_edge(NodeId(1), NodeId(0)));
        assert!(reversed.has_edge(NodeId(2), NodeId(1)));
        assert!(!reversed.has_edge(NodeId(0), NodeId(1)));
    }

    #[test]
    fn test_builder() {
        let mut b = CsrMatrixBuilder::new(4);
        b.add_edge(0, 1);
        b.add_edge(0, 2);
        b.add_weighted_edge(1, 3, 2.5);

        let m = b.build();
        assert_eq!(m.num_nonzeros(), 3);
        assert!(m.values.is_some());
    }

    #[test]
    fn test_validation() {
        let good = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]);
        assert!(good.validate().is_ok());

        // Column index 10 lies outside the 3-column matrix.
        let bad = CsrMatrix {
            num_rows: 3,
            num_cols: 3,
            row_ptr: vec![0, 1, 2, 2],
            col_idx: vec![1, 10],
            values: None,
        };
        assert!(bad.validate().is_err());
    }
}