1use alloc::vec::Vec;
2use tang::Scalar;
3use tang_la::DVec;
4
/// A sparse matrix stored in compressed sparse row (CSR) form.
///
/// The stored entries of row `i` occupy positions `row_ptrs[i]..row_ptrs[i + 1]`
/// of the parallel arrays `col_indices` (column numbers) and `values` (entry
/// values), so `row_ptrs` is expected to hold `nrows + 1` offsets.
pub struct CsrMatrix<S> {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Row offsets into `col_indices` / `values`; row `i` spans
    /// `row_ptrs[i]..row_ptrs[i + 1]`.
    pub row_ptrs: Vec<usize>,
    /// Column index of each stored entry.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry, parallel to `col_indices`.
    pub values: Vec<S>,
}
13
14impl<S: Scalar> CsrMatrix<S> {
15 pub fn nnz(&self) -> usize {
16 self.values.len()
17 }
18
19 pub fn spmv_add(&self, x: &DVec<S>, y: &mut DVec<S>) {
21 assert_eq!(x.len(), self.ncols);
22 assert_eq!(y.len(), self.nrows);
23 for i in 0..self.nrows {
24 let start = self.row_ptrs[i];
25 let end = self.row_ptrs[i + 1];
26 let mut sum = S::ZERO;
27 for k in start..end {
28 sum += self.values[k] * x[self.col_indices[k]];
29 }
30 y[i] = y[i] + sum;
31 }
32 }
33
34 pub fn spmv(&self, x: &DVec<S>) -> DVec<S> {
36 let mut y = DVec::zeros(self.nrows);
37 self.spmv_add(x, &mut y);
38 y
39 }
40
41 pub fn get(&self, row: usize, col: usize) -> S {
43 let start = self.row_ptrs[row];
44 let end = self.row_ptrs[row + 1];
45 for k in start..end {
46 if self.col_indices[k] == col {
47 return self.values[k];
48 }
49 }
50 S::ZERO
51 }
52
53 pub fn transpose(&self) -> Self {
55 let mut row_counts = alloc::vec![0usize; self.ncols + 1];
56 for &c in &self.col_indices {
57 row_counts[c + 1] += 1;
58 }
59 for i in 1..=self.ncols {
60 row_counts[i] += row_counts[i - 1];
61 }
62 let nnz = self.nnz();
63 let mut col_indices = alloc::vec![0usize; nnz];
64 let mut values = alloc::vec![S::ZERO; nnz];
65 let mut offsets = row_counts.clone();
66 for i in 0..self.nrows {
67 for k in self.row_ptrs[i]..self.row_ptrs[i + 1] {
68 let c = self.col_indices[k];
69 let pos = offsets[c];
70 col_indices[pos] = i;
71 values[pos] = self.values[k];
72 offsets[c] += 1;
73 }
74 }
75 CsrMatrix {
76 nrows: self.ncols,
77 ncols: self.nrows,
78 row_ptrs: row_counts,
79 col_indices,
80 values,
81 }
82 }
83
84 pub fn to_dense(&self) -> tang_la::DMat<S> {
86 tang_la::DMat::from_fn(self.nrows, self.ncols, |i, j| self.get(i, j))
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CooMatrix;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Builds, via the COO -> CSR conversion, the 2x3 matrix
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// ```
    fn sample_2x3() -> CsrMatrix<f64> {
        let mut coo = CooMatrix::new(2, 3);
        for &(r, c, v) in &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)] {
            coo.push(r, c, v);
        }
        coo.to_csr()
    }

    #[test]
    fn spmv() {
        let x = DVec::from_slice(&[1.0, 2.0, 2.0]);
        let y = sample_2x3().spmv(&x);
        assert_close(y[0], 5.0);
        assert_close(y[1], 6.0);
    }

    #[test]
    fn transpose() {
        let t = sample_2x3().transpose();
        assert_eq!((t.nrows, t.ncols), (3, 2));
        for &(r, c, v) in &[(0, 0, 1.0), (2, 0, 2.0), (1, 1, 3.0)] {
            assert_close(t.get(r, c), v);
        }
    }

    #[test]
    fn duplicate_entries_summed() {
        let mut coo = CooMatrix::new(2, 2);
        coo.push(0, 0, 1.0);
        coo.push(0, 0, 2.0);
        coo.push(1, 1, 3.0);
        let csr = coo.to_csr();
        assert_close(csr.get(0, 0), 3.0);
        assert_eq!(csr.nnz(), 2);
    }
}