1use alloc::vec::Vec;
2use tang::Scalar;
3use tang_la::DVec;
4
/// A sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` occupy the index range
/// `col_ptrs[j]..col_ptrs[j + 1]` of both `row_indices` (their row positions)
/// and `values` (their values). Positions with no stored entry read as zero.
pub struct CscMatrix<S> {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Start offset of each column in `row_indices`/`values`; expected to have
    /// `ncols + 1` elements so that column `j` ends at `col_ptrs[j + 1]`.
    pub col_ptrs: Vec<usize>,
    /// Row index of every stored entry, grouped column by column.
    pub row_indices: Vec<usize>,
    /// Value of every stored entry, parallel to `row_indices`.
    pub values: Vec<S>,
}
13
14impl<S: Scalar> CscMatrix<S> {
15 pub fn nnz(&self) -> usize {
16 self.values.len()
17 }
18
19 pub fn spmv(&self, x: &DVec<S>) -> DVec<S> {
21 assert_eq!(x.len(), self.ncols);
22 let mut y = DVec::zeros(self.nrows);
23 for j in 0..self.ncols {
24 let xj = x[j];
25 for k in self.col_ptrs[j]..self.col_ptrs[j + 1] {
26 y[self.row_indices[k]] = y[self.row_indices[k]] + self.values[k] * xj;
27 }
28 }
29 y
30 }
31
32 pub fn get(&self, row: usize, col: usize) -> S {
34 let start = self.col_ptrs[col];
35 let end = self.col_ptrs[col + 1];
36 for k in start..end {
37 if self.row_indices[k] == row {
38 return self.values[k];
39 }
40 }
41 S::ZERO
42 }
43
44 pub fn from_csr(csr: &super::CsrMatrix<S>) -> Self {
46 let nrows = csr.nrows;
47 let ncols = csr.ncols;
48 let nnz = csr.nnz();
49
50 let mut col_counts = alloc::vec![0usize; ncols + 1];
51 for &c in &csr.col_indices {
52 col_counts[c + 1] += 1;
53 }
54 for j in 1..=ncols {
55 col_counts[j] += col_counts[j - 1];
56 }
57
58 let mut row_indices = alloc::vec![0usize; nnz];
59 let mut values = alloc::vec![S::ZERO; nnz];
60 let mut offsets = col_counts.clone();
61
62 for i in 0..nrows {
63 for k in csr.row_ptrs[i]..csr.row_ptrs[i + 1] {
64 let c = csr.col_indices[k];
65 let pos = offsets[c];
66 row_indices[pos] = i;
67 values[pos] = csr.values[k];
68 offsets[c] += 1;
69 }
70 }
71
72 Self {
73 nrows,
74 ncols,
75 col_ptrs: col_counts,
76 row_indices,
77 values,
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CooMatrix;

    /// The 2x3 matrix [[1, 0, 2], [0, 3, 0]], built via COO -> CSR -> CSC.
    fn sample() -> CscMatrix<f64> {
        let mut coo = CooMatrix::new(2, 3);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(1, 1, 3.0);
        CscMatrix::from_csr(&coo.to_csr())
    }

    #[test]
    fn csc_spmv() {
        let csc = sample();
        let x = DVec::from_slice(&[1.0, 2.0, 2.0]);
        let y = csc.spmv(&x);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn csc_from_csr_layout() {
        let csc = sample();
        assert_eq!((csc.nrows, csc.ncols), (2, 3));
        assert_eq!(csc.nnz(), 3);
        assert_eq!(csc.col_ptrs, [0, 1, 2, 3]);
        assert_eq!(csc.row_indices, [0, 1, 0]);
        assert_eq!(csc.values, [1.0, 3.0, 2.0]);
    }

    #[test]
    fn csc_get() {
        let csc = sample();
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(0, 2), 2.0);
        assert_eq!(csc.get(1, 1), 3.0);
        assert_eq!(csc.get(1, 0), 0.0);
        assert_eq!(csc.get(0, 1), 0.0);
    }

    #[test]
    fn csc_from_csr_empty_column() {
        // 3x3 matrix [[0, 0, 5], [0, 0, 0], [4, 0, 0]]: column 1 is empty.
        let mut coo = CooMatrix::new(3, 3);
        coo.push(2, 0, 4.0);
        coo.push(0, 2, 5.0);
        let csc: CscMatrix<f64> = CscMatrix::from_csr(&coo.to_csr());
        assert_eq!(csc.col_ptrs, [0, 1, 1, 2]);
        assert_eq!(csc.row_indices, [2, 0]);
        assert_eq!(csc.values, [4.0, 5.0]);
        assert_eq!(csc.get(1, 1), 0.0);
    }
}