1use ndarray::Array2;
15
16use crate::float::Float;
17
/// A sparse matrix stored in compressed sparse row (CSR) form.
///
/// The stored entries of row `i` live at positions `indptr[i]..indptr[i + 1]`
/// of `indices` (their column indices) and `data` (their values). `indptr`
/// therefore needs `n_rows + 1` non-decreasing offsets for row access to work.
/// Matrices built by `from_triplets` additionally have strictly increasing
/// column indices within each row, with duplicates already summed. A value
/// assembled by hand through the public fields is not validated.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Explicit deserialize bound: deserializing `CsrMatrix<F>` requires only
// `F: DeserializeOwned`, replacing the bound serde would otherwise infer.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct CsrMatrix<F: Float> {
    /// Row offsets into `indices`/`data`. `from_triplets` builds this with
    /// `n_rows + 1` entries, starting at 0.
    pub indptr: Vec<usize>,
    /// Column index of each stored entry, grouped by row.
    pub indices: Vec<usize>,
    /// Value of each stored entry, parallel to `indices`.
    pub data: Vec<F>,
    /// Number of rows of the logical (dense) matrix.
    pub n_rows: usize,
    /// Number of columns of the logical (dense) matrix.
    pub n_cols: usize,
}
27
28impl<F: Float> CsrMatrix<F> {
29 pub fn from_triplets(n_rows: usize, n_cols: usize, triplets: Vec<(usize, usize, F)>) -> Self {
33 let mut buckets: Vec<Vec<(usize, F)>> = vec![Vec::new(); n_rows];
35 for (r, c, v) in triplets {
36 buckets[r].push((c, v));
37 }
38 let mut indptr = Vec::with_capacity(n_rows + 1);
40 let mut indices = Vec::new();
41 let mut data = Vec::new();
42 indptr.push(0);
43 for row in buckets.iter_mut() {
44 row.sort_by(|a, b| a.0.cmp(&b.0));
45 let mut last_col: Option<usize> = None;
47 for &(c, v) in row.iter() {
48 if Some(c) == last_col {
49 let n = data.len();
50 data[n - 1] = data[n - 1] + v;
51 } else {
52 indices.push(c);
53 data.push(v);
54 last_col = Some(c);
55 }
56 }
57 indptr.push(indices.len());
58 }
59 Self {
60 indptr,
61 indices,
62 data,
63 n_rows,
64 n_cols,
65 }
66 }
67
68 pub fn nnz(&self) -> usize {
69 self.data.len()
70 }
71
72 pub fn density(&self) -> f64 {
73 if self.n_rows == 0 || self.n_cols == 0 {
74 return 0.0;
75 }
76 self.nnz() as f64 / (self.n_rows as f64 * self.n_cols as f64)
77 }
78
79 pub fn row_iter(&self, i: usize) -> impl Iterator<Item = (usize, F)> + '_ {
81 let start = self.indptr[i];
82 let end = self.indptr[i + 1];
83 self.indices[start..end]
84 .iter()
85 .copied()
86 .zip(self.data[start..end].iter().copied())
87 }
88
89 pub fn to_dense(&self) -> Array2<F> {
90 let mut out = Array2::<F>::zeros((self.n_rows, self.n_cols));
91 for i in 0..self.n_rows {
92 for (c, v) in self.row_iter(i) {
93 out[[i, c]] = v;
94 }
95 }
96 out
97 }
98
99 pub fn matvec(&self, x: &[F]) -> Vec<F> {
102 assert_eq!(x.len(), self.n_cols, "matvec: dimension mismatch");
103 let mut y = vec![F::zero(); self.n_rows];
104 for i in 0..self.n_rows {
105 let mut s = F::zero();
106 for (c, v) in self.row_iter(i) {
107 s = s + v * x[c];
108 }
109 y[i] = s;
110 }
111 y
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_csr_from_triplets_basic() {
        let csr = CsrMatrix::<f64>::from_triplets(
            3,
            4,
            vec![
                (0, 0, 1.0),
                (0, 3, 2.0),
                (1, 1, 3.0),
                (2, 2, 4.0),
                (2, 3, 5.0),
            ],
        );
        assert_eq!(csr.nnz(), 5);
        let dense = csr.to_dense();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 3]], 2.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[2, 2]], 4.0);
        assert_eq!(dense[[2, 3]], 5.0);
        assert_eq!(dense[[1, 0]], 0.0);
    }

    #[test]
    fn test_csr_duplicate_triplets_sum() {
        let csr =
            CsrMatrix::<f64>::from_triplets(1, 3, vec![(0, 1, 1.0), (0, 1, 2.0), (0, 1, 3.0)]);
        assert_eq!(csr.nnz(), 1);
        assert_eq!(csr.to_dense()[[0, 1]], 6.0);
    }

    /// Duplicates interleaved with distinct columns collapse into one entry each.
    #[test]
    fn test_csr_duplicates_mixed_with_distinct_columns() {
        let csr =
            CsrMatrix::<f64>::from_triplets(1, 4, vec![(0, 3, 1.0), (0, 1, 2.0), (0, 3, 4.0)]);
        assert_eq!(csr.indices, vec![1, 3]);
        assert_eq!(csr.data, vec![2.0, 5.0]);
    }

    /// Unsorted input still yields the canonical CSR layout: ascending columns per row.
    #[test]
    fn test_csr_unsorted_triplets_are_sorted_per_row() {
        let csr = CsrMatrix::<f64>::from_triplets(
            2,
            3,
            vec![(1, 2, 5.0), (0, 2, 1.0), (0, 0, 2.0), (1, 0, 3.0)],
        );
        assert_eq!(csr.indptr, vec![0, 2, 4]);
        assert_eq!(csr.indices, vec![0, 2, 0, 2]);
        assert_eq!(csr.data, vec![2.0, 1.0, 3.0, 5.0]);
    }

    /// Rows with no triplets produce empty slices and zero outputs.
    #[test]
    fn test_csr_empty_rows() {
        let csr = CsrMatrix::<f64>::from_triplets(3, 2, vec![(2, 1, 7.0)]);
        assert_eq!(csr.indptr, vec![0, 0, 0, 1]);
        assert_eq!(csr.row_iter(0).count(), 0);
        assert_eq!(csr.row_iter(2).collect::<Vec<_>>(), vec![(1, 7.0)]);
        assert_eq!(csr.matvec(&[1.0, 2.0]), vec![0.0, 0.0, 14.0]);
    }

    #[test]
    fn test_csr_matvec() {
        let csr = CsrMatrix::<f64>::from_triplets(2, 2, vec![(0, 0, 1.0), (1, 1, 2.0)]);
        let y = csr.matvec(&[3.0, 4.0]);
        assert_eq!(y, vec![3.0, 8.0]);
    }

    #[test]
    #[should_panic(expected = "matvec: dimension mismatch")]
    fn test_csr_matvec_dimension_mismatch_panics() {
        let csr = CsrMatrix::<f64>::from_triplets(2, 2, vec![(0, 0, 1.0)]);
        let _ = csr.matvec(&[1.0]);
    }

    #[test]
    fn test_csr_density() {
        let csr = CsrMatrix::<f64>::from_triplets(2, 2, vec![(0, 0, 1.0)]);
        assert!((csr.density() - 0.25).abs() < 1e-12);
    }

    /// A matrix with a zero dimension has no entries and density 0.
    #[test]
    fn test_csr_zero_sized_density() {
        let csr = CsrMatrix::<f64>::from_triplets(0, 5, Vec::new());
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.indptr, vec![0]);
        assert_eq!(csr.density(), 0.0);
    }
}