1use std::ops::{Add, AddAssign, Mul};
11
12use crate::dense::array::reference::{ArrayRef, ArrayRefMut};
13use crate::dense::array::slice::ArraySlice;
14use crate::sparse::tools::normalize_aij;
15use crate::{AijIteratorByValue, BaseItem, Shape, dense::array::DynArray, sparse::SparseMatType};
16use crate::{
17 AijIteratorMut, Array, AsMatrixApply, FromAij, Nonzeros, SparseMatrixType,
18 UnsafeRandom1DAccessMut, UnsafeRandomAccessByValue, UnsafeRandomAccessMut, empty_array,
19};
20use itertools::{Itertools, izip};
21use num::One;
22
23use super::mat_operations::SparseMatOpIterator;
24
/// Sparse matrix stored in compressed sparse row (CSR) format.
///
/// The stored entries of row `r` occupy positions `indptr[r]..indptr[r + 1]` of both
/// `indices` (their column indices) and `data` (their values).
pub struct CsrMatrix<Item> {
    /// Storage-format tag; `CsrMatrix::new` always sets it to `SparseMatType::Csr`.
    mat_type: SparseMatType,
    /// `[number of rows, number of columns]`.
    shape: [usize; 2],
    /// Column index of every stored entry (validated to be `< shape[1]` in `new`).
    indices: DynArray<usize, 1>,
    /// Row pointers of length `shape[0] + 1` delimiting each row's entries.
    indptr: DynArray<usize, 1>,
    /// Value of every stored entry, aligned element-wise with `indices`.
    data: DynArray<Item, 1>,
}
38
39impl<Item> CsrMatrix<Item> {
40 pub fn new(
42 shape: [usize; 2],
43 indices: DynArray<usize, 1>,
44 indptr: DynArray<usize, 1>,
45 data: DynArray<Item, 1>,
46 ) -> Self {
47 assert_eq!(indptr.len(), 1 + shape[0]);
51 assert_eq!(data.len(), indices.len());
52 assert_eq!(*indptr.data().unwrap().last().unwrap(), data.len());
53
54 for (first, second) in indptr.iter_value().tuple_windows() {
59 assert!(
60 first <= second,
61 "Elements of indptr not in increasing order {first} > {second}."
62 );
63 }
64 assert_eq!(*indptr.data().unwrap().last().unwrap(), indices.len());
66
67 if let Some(&max_col_index) = indices.data().unwrap().iter().max() {
70 assert!(max_col_index < shape[1]);
71 }
72
73 Self {
74 mat_type: SparseMatType::Csr,
75 shape,
76 indices,
77 indptr,
78 data,
79 }
80 }
81
82 pub fn indptr(&self) -> &DynArray<usize, 1> {
84 &self.indptr
85 }
86
87 pub fn indices(&self) -> &DynArray<usize, 1> {
89 &self.indices
90 }
91
92 pub fn data(&self) -> &DynArray<Item, 1> {
94 &self.data
95 }
96}
97
98impl<Item> FromAij for CsrMatrix<Item>
99where
100 Item: AddAssign + PartialEq + Copy + Default,
101{
102 fn from_aij(shape: [usize; 2], rows: &[usize], cols: &[usize], data: &[Item]) -> Self {
105 let (rows, cols, data) = normalize_aij(rows, cols, data, SparseMatType::Csr);
106
107 let max_col = if let Some(col) = cols.iter().max() {
108 *col
109 } else {
110 0
111 };
112 let max_row = if let Some(row) = rows.last() { *row } else { 0 };
113
114 assert!(
115 max_col < shape[1],
116 "Maximum column {} must be smaller than `shape.1` {}",
117 max_col,
118 shape[1]
119 );
120
121 assert!(
122 max_row < shape[0],
123 "Maximum row {} must be smaller than `shape.0` {}",
124 max_row,
125 shape[0]
126 );
127
128 let nelems = data.len();
129
130 let mut indptr = Vec::<usize>::with_capacity(1 + shape[0]);
131
132 let mut count: usize = 0;
133 for row in 0..(shape[0]) {
134 indptr.push(count);
135 while count < nelems && row == rows[count] {
136 count += 1;
137 }
138 }
139 indptr.push(count);
140
141 let indptr = DynArray::from_shape_and_vec([1 + shape[0]], indptr);
142 let indices = DynArray::from_shape_and_vec([nelems], cols);
143 let data = DynArray::from_shape_and_vec([nelems], data);
144
145 Self::new(shape, indices, indptr, data)
146 }
147}
148
149impl<Item> Shape<2> for CsrMatrix<Item> {
150 fn shape(&self) -> [usize; 2] {
151 self.shape
152 }
153}
154
155impl<Item> BaseItem for CsrMatrix<Item>
156where
157 Item: Copy + Default,
158{
159 type Item = Item;
160}
161
162impl<Item> Nonzeros for CsrMatrix<Item> {
163 fn nnz(&self) -> usize {
164 self.data.len()
165 }
166}
167
168impl<Item> SparseMatrixType for CsrMatrix<Item> {
169 fn mat_type(&self) -> SparseMatType {
170 self.mat_type
171 }
172}
173
174impl<Item> AijIteratorByValue for CsrMatrix<Item>
175where
176 Item: Copy + Default,
177{
178 fn iter_aij_value(&self) -> impl Iterator<Item = ([usize; 2], Self::Item)> + '_ {
179 self.indptr
180 .iter_value()
181 .tuple_windows::<(usize, usize)>()
182 .enumerate()
183 .flat_map(|(row, (start, end))| {
184 izip!(
185 self.indices.data().unwrap()[start..end].iter(),
186 self.data.data().unwrap()[start..end].iter()
187 )
188 .map(|(col, value)| ([row, *col], *value))
189 .collect::<Vec<_>>()
190 })
191 }
192}
193
194impl<Item> AijIteratorMut for CsrMatrix<Item>
195where
196 Item: Copy + Default,
197{
198 fn iter_aij_mut(&mut self) -> impl Iterator<Item = ([usize; 2], &mut Self::Item)> + '_ {
199 self.indptr
200 .iter_value()
201 .tuple_windows::<(usize, usize)>()
202 .enumerate()
203 .flat_map(|(row, (start, end))| {
204 izip!(
205 self.indices.data().unwrap()[start..end].iter(),
206 self.data.data_mut().unwrap()[start..end]
207 .iter_mut()
208 .map(|v| v as *mut Item)
211 )
212 .map(|(col, value)| ([row, *col], value))
213 .collect::<Vec<_>>()
214 })
215 .map(|(idx, value)| (idx, unsafe { &mut *value }))
216 }
217}
218
219impl<Item: Copy + Default> CsrMatrix<Item> {
220 pub fn op(&self) -> SparseMatOpIterator<Item, impl Iterator<Item = ([usize; 2], Item)> + '_> {
222 SparseMatOpIterator::new(self.iter_aij_value(), self.shape())
223 }
224}
225
226impl<Item: Copy + Default> CsrMatrix<Item> {
227 pub fn todense(&self) -> DynArray<Item, 2> {
229 DynArray::from_iter_aij(self.shape(), self.iter_aij_value())
230 }
231}
232
233impl<Item: Default + Mul<Output = Item> + AddAssign<Item> + Add<Output = Item> + Copy + One>
234 CsrMatrix<Item>
235{
236 pub fn dot<ArrayImpl, const NDIM: usize>(
238 &self,
239 other: &Array<ArrayImpl, NDIM>,
240 ) -> DynArray<Item, NDIM>
241 where
242 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
243 {
244 let mut out = empty_array::<Item, NDIM>();
245
246 if NDIM == 1 {
247 let mut out = out.r_mut().coerce_dim::<1>().unwrap();
248 let other = other.r().coerce_dim::<1>().unwrap();
249 out.resize_in_place([self.shape()[0]]);
250 self.apply(One::one(), &other, Default::default(), &mut out);
251 } else if NDIM == 2 {
252 let mut out = out.r_mut().coerce_dim::<2>().unwrap();
253 let other = other.r().coerce_dim::<2>().unwrap();
254 out.resize_in_place([self.shape()[0], other.shape()[1]]);
255 self.apply(One::one(), &other, Default::default(), &mut out);
256 } else {
257 panic!(
258 "Unsupported number of dimensions NDIM = {NDIM}. Only NDIM=1 or NDIM=2 supported."
259 );
260 }
261
262 out
263 }
264}
265
266impl<Item, ArrayImplX, ArrayImplY> AsMatrixApply<Array<ArrayImplX, 1>, Array<ArrayImplY, 1>>
267 for CsrMatrix<Item>
268where
269 Item: Default + Mul<Output = Item> + AddAssign<Item> + Add<Output = Item> + Copy + One,
270 ArrayImplX: UnsafeRandomAccessByValue<1, Item = Item> + Shape<1>,
271 ArrayImplY: UnsafeRandom1DAccessMut<Item = Item> + Shape<1>,
272{
273 fn apply(
274 &self,
275 alpha: Self::Item,
276 x: &crate::Array<ArrayImplX, 1>,
277 beta: Self::Item,
278 y: &mut crate::Array<ArrayImplY, 1>,
279 ) {
280 assert_eq!(y.len(), self.shape()[0]);
281 assert_eq!(x.len(), self.shape()[1]);
282 for (row, out) in y.iter_mut().enumerate() {
283 *out = beta * *out
284 + alpha * {
285 let c1 = unsafe { self.indptr.get_value_unchecked([row]) };
286 let c2 = unsafe { self.indptr.get_value_unchecked([1 + row]) };
287 let mut acc = Item::default();
288
289 for index in c1..c2 {
290 let col = unsafe { self.indices.get_value_unchecked([index]) };
291 acc += unsafe {
292 self.data.get_value_unchecked([index]) * x.get_value_unchecked([col])
293 };
294 }
295 acc
296 }
297 }
298 }
299}
300
301impl<Item, ArrayImplX, ArrayImplY> AsMatrixApply<Array<ArrayImplX, 2>, Array<ArrayImplY, 2>>
302 for CsrMatrix<Item>
303where
304 Item: Copy,
305 Self: BaseItem<Item = Item>,
306 ArrayImplX: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>,
307 ArrayImplY: UnsafeRandomAccessMut<2, Item = Item> + Shape<2>,
308 for<'b> Self: AsMatrixApply<
309 Array<ArraySlice<ArrayRef<'b, ArrayImplX, 2>, 2, 1>, 1>,
310 Array<ArraySlice<ArrayRefMut<'b, ArrayImplY, 2>, 2, 1>, 1>,
311 >,
312{
313 fn apply(
314 &self,
315 alpha: Self::Item,
316 x: &crate::Array<ArrayImplX, 2>,
317 beta: Self::Item,
318 y: &mut crate::Array<ArrayImplY, 2>,
319 ) {
320 for (colx, mut coly) in izip!(x.col_iter(), y.col_iter_mut()) {
321 self.apply(alpha, &colx, beta, &mut coly)
322 }
323 }
324}
325
#[cfg(test)]
mod test {

    use super::*;

    /// A CSR matrix-vector product must agree with the product of the densified matrix.
    #[test]
    fn test_csr() {
        let shape = [8, 13];

        // Three stored entries, two of which share row 4.
        let rows: Vec<usize> = vec![1, 4, 4];
        let cols: Vec<usize> = vec![2, 5, 6];
        let values: Vec<f64> = vec![1.0, 2.0, 3.0];
        let sparse_mat = CsrMatrix::from_aij(shape, &rows, &cols, &values);

        let mut x = DynArray::<f64, 1>::from_shape([shape[1]]);
        x.fill_from_seed_equally_distributed(0);

        let dense_mat = sparse_mat.todense();
        let expected = crate::dot!(dense_mat, x);
        let y = crate::dot!(sparse_mat, x);

        crate::assert_array_relative_eq!(y, expected, 1E-10);
    }
}