Skip to main content

tang_sparse/
csr.rs

1use alloc::vec::Vec;
2use tang::Scalar;
3use tang_la::DVec;
4
/// Compressed Sparse Row (CSR) matrix.
///
/// The stored entries of row `i` occupy positions
/// `row_ptrs[i]..row_ptrs[i + 1]` of both `col_indices` and `values`.
#[derive(Debug, Clone)]
pub struct CsrMatrix<S> {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Start offset of each row in `col_indices` / `values`; length `nrows + 1`.
    pub row_ptrs: Vec<usize>,
    /// Column index of each stored entry; length `nnz`.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry; length `nnz`.
    pub values: Vec<S>,
}
13
14impl<S: Scalar> CsrMatrix<S> {
15    pub fn nnz(&self) -> usize {
16        self.values.len()
17    }
18
19    /// Sparse matrix-vector product: y += A * x
20    pub fn spmv_add(&self, x: &DVec<S>, y: &mut DVec<S>) {
21        assert_eq!(x.len(), self.ncols);
22        assert_eq!(y.len(), self.nrows);
23        for i in 0..self.nrows {
24            let start = self.row_ptrs[i];
25            let end = self.row_ptrs[i + 1];
26            let mut sum = S::ZERO;
27            for k in start..end {
28                sum += self.values[k] * x[self.col_indices[k]];
29            }
30            y[i] = y[i] + sum;
31        }
32    }
33
34    /// Sparse matrix-vector product: y = A * x
35    pub fn spmv(&self, x: &DVec<S>) -> DVec<S> {
36        let mut y = DVec::zeros(self.nrows);
37        self.spmv_add(x, &mut y);
38        y
39    }
40
41    /// Get element (i, j). O(nnz_row) lookup.
42    pub fn get(&self, row: usize, col: usize) -> S {
43        let start = self.row_ptrs[row];
44        let end = self.row_ptrs[row + 1];
45        for k in start..end {
46            if self.col_indices[k] == col {
47                return self.values[k];
48            }
49        }
50        S::ZERO
51    }
52
53    /// Transpose to CSR (builds a new CSR of A^T).
54    pub fn transpose(&self) -> Self {
55        let mut row_counts = alloc::vec![0usize; self.ncols + 1];
56        for &c in &self.col_indices {
57            row_counts[c + 1] += 1;
58        }
59        for i in 1..=self.ncols {
60            row_counts[i] += row_counts[i - 1];
61        }
62        let nnz = self.nnz();
63        let mut col_indices = alloc::vec![0usize; nnz];
64        let mut values = alloc::vec![S::ZERO; nnz];
65        let mut offsets = row_counts.clone();
66        for i in 0..self.nrows {
67            for k in self.row_ptrs[i]..self.row_ptrs[i + 1] {
68                let c = self.col_indices[k];
69                let pos = offsets[c];
70                col_indices[pos] = i;
71                values[pos] = self.values[k];
72                offsets[c] += 1;
73            }
74        }
75        CsrMatrix {
76            nrows: self.ncols,
77            ncols: self.nrows,
78            row_ptrs: row_counts,
79            col_indices,
80            values,
81        }
82    }
83
84    /// Convert to dense matrix.
85    pub fn to_dense(&self) -> tang_la::DMat<S> {
86        tang_la::DMat::from_fn(self.nrows, self.ncols, |i, j| self.get(i, j))
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CooMatrix;

    /// Matrix-vector product of a small 2x3 matrix built through the COO path.
    #[test]
    fn spmv() {
        // A = [[1 0 2],   x = [1, 2, 2]^T   =>   A * x = [5, 6]^T
        //      [0 3 0]]
        let mut builder = CooMatrix::new(2, 3);
        for (r, c, v) in [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)] {
            builder.push(r, c, v);
        }
        let a = builder.to_csr();

        let x = DVec::from_slice(&[1.0, 2.0, 2.0]);
        let product = a.spmv(&x);
        for (i, want) in [(0, 5.0), (1, 6.0)] {
            assert!((product[i] - want).abs() < 1e-10);
        }
    }

    /// Transposing swaps the dimensions and moves entry (i, j) to (j, i).
    #[test]
    fn transpose() {
        let mut builder = CooMatrix::new(2, 3);
        for (r, c, v) in [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)] {
            builder.push(r, c, v);
        }
        let at = builder.to_csr().transpose();

        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 2);
        // Each tuple is (row, col, value) in the transposed matrix.
        for (r, c, want) in [(0, 0, 1.0), (2, 0, 2.0), (1, 1, 3.0)] {
            assert!((at.get(r, c) - want).abs() < 1e-10);
        }
    }

    /// Repeated COO coordinates collapse into a single summed CSR entry.
    #[test]
    fn duplicate_entries_summed() {
        let mut builder = CooMatrix::new(2, 2);
        // (0, 0) is pushed twice on purpose.
        for (r, c, v) in [(0, 0, 1.0), (0, 0, 2.0), (1, 1, 3.0)] {
            builder.push(r, c, v);
        }
        let merged = builder.to_csr();

        assert!((merged.get(0, 0) - 3.0).abs() < 1e-10);
        // Two stored entries remain once the duplicate pair is merged.
        assert_eq!(merged.nnz(), 2);
    }
}