Skip to main content

tang_sparse/
csc.rs

1use alloc::vec::Vec;
2use tang::Scalar;
3use tang_la::DVec;
4
/// Compressed Sparse Column matrix.
///
/// The stored entries of column `j` occupy positions
/// `col_ptrs[j]..col_ptrs[j + 1]` of both `row_indices` and `values`.
#[derive(Debug, Clone, PartialEq)]
pub struct CscMatrix<S> {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Start offset of each column in `row_indices` / `values`; length `ncols + 1`.
    pub col_ptrs: Vec<usize>,
    /// Row index of each stored entry; length `nnz`.
    pub row_indices: Vec<usize>,
    /// Value of each stored entry; length `nnz`.
    pub values: Vec<S>,
}
13
14impl<S: Scalar> CscMatrix<S> {
15    pub fn nnz(&self) -> usize {
16        self.values.len()
17    }
18
19    /// Sparse matrix-vector product: y = A * x
20    pub fn spmv(&self, x: &DVec<S>) -> DVec<S> {
21        assert_eq!(x.len(), self.ncols);
22        let mut y = DVec::zeros(self.nrows);
23        for j in 0..self.ncols {
24            let xj = x[j];
25            for k in self.col_ptrs[j]..self.col_ptrs[j + 1] {
26                y[self.row_indices[k]] = y[self.row_indices[k]] + self.values[k] * xj;
27            }
28        }
29        y
30    }
31
32    /// Get element (i, j).
33    pub fn get(&self, row: usize, col: usize) -> S {
34        let start = self.col_ptrs[col];
35        let end = self.col_ptrs[col + 1];
36        for k in start..end {
37            if self.row_indices[k] == row {
38                return self.values[k];
39            }
40        }
41        S::ZERO
42    }
43
44    /// Convert from CSR.
45    pub fn from_csr(csr: &super::CsrMatrix<S>) -> Self {
46        let nrows = csr.nrows;
47        let ncols = csr.ncols;
48        let nnz = csr.nnz();
49
50        let mut col_counts = alloc::vec![0usize; ncols + 1];
51        for &c in &csr.col_indices {
52            col_counts[c + 1] += 1;
53        }
54        for j in 1..=ncols {
55            col_counts[j] += col_counts[j - 1];
56        }
57
58        let mut row_indices = alloc::vec![0usize; nnz];
59        let mut values = alloc::vec![S::ZERO; nnz];
60        let mut offsets = col_counts.clone();
61
62        for i in 0..nrows {
63            for k in csr.row_ptrs[i]..csr.row_ptrs[i + 1] {
64                let c = csr.col_indices[k];
65                let pos = offsets[c];
66                row_indices[pos] = i;
67                values[pos] = csr.values[k];
68                offsets[c] += 1;
69            }
70        }
71
72        Self {
73            nrows,
74            ncols,
75            col_ptrs: col_counts,
76            row_indices,
77            values,
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CooMatrix;

    /// Builds a 2x3 matrix via COO -> CSR -> CSC and checks `spmv` on it.
    #[test]
    fn csc_spmv() {
        // A = [[1, 0, 2],
        //      [0, 3, 0]]
        let mut triplets = CooMatrix::new(2, 3);
        triplets.push(0, 0, 1.0);
        triplets.push(0, 2, 2.0);
        triplets.push(1, 1, 3.0);
        let matrix = CscMatrix::from_csr(&triplets.to_csr());

        // A * [1, 2, 2]^T = [1*1 + 2*2, 3*2] = [5, 6]
        let input = DVec::from_slice(&[1.0, 2.0, 2.0]);
        let product = matrix.spmv(&input);
        assert!((product[0] - 5.0).abs() < 1e-10);
        assert!((product[1] - 6.0).abs() < 1e-10);
    }
}