rlst/sparse/
csr_mat.rs

1//! Definition of CSR matrices.
2//!
3//! A CSR matrix consists of three arrays.
4//! - `data` - Stores all entries of the CSR matrix.
5//! - `indices` - The column indices associated with each entry in `data`.
//! - `indptr` - An array of row pointers. The data entries for row `i` are contained in `data[indptr[i]..indptr[i + 1]]`.
7//!
8//! The last entry of `indptr` is the number of nonzero elements of the sparse matrix.
9
10use std::ops::{Add, AddAssign, Mul};
11
12use crate::dense::array::reference::{ArrayRef, ArrayRefMut};
13use crate::dense::array::slice::ArraySlice;
14use crate::sparse::tools::normalize_aij;
15use crate::{AijIteratorByValue, BaseItem, Shape, dense::array::DynArray, sparse::SparseMatType};
16use crate::{
17    AijIteratorMut, Array, AsMatrixApply, FromAij, Nonzeros, SparseMatrixType,
18    UnsafeRandom1DAccessMut, UnsafeRandomAccessByValue, UnsafeRandomAccessMut, empty_array,
19};
20use itertools::{Itertools, izip};
21use num::One;
22
23use super::mat_operations::SparseMatOpIterator;
24
25/// A CSR matrix
26pub struct CsrMatrix<Item> {
27    /// The `mat_type` denotes the storage type for the sparse matrix.
28    mat_type: SparseMatType,
29    /// The shape of the sparse matrix.
30    shape: [usize; 2],
31    /// The array of column indices.
32    indices: DynArray<usize, 1>,
33    /// The array of index pointers.
34    indptr: DynArray<usize, 1>,
35    /// The entries of the sparse matrix.
36    data: DynArray<Item, 1>,
37}
38
39impl<Item> CsrMatrix<Item> {
40    /// Create a new CSR matrix
41    pub fn new(
42        shape: [usize; 2],
43        indices: DynArray<usize, 1>,
44        indptr: DynArray<usize, 1>,
45        data: DynArray<Item, 1>,
46    ) -> Self {
47        // Check that the indices cannot be out of bounds.
48        // This is because `apply` uses unsafe access to the entries.
49
50        assert_eq!(indptr.len(), 1 + shape[0]);
51        assert_eq!(data.len(), indices.len());
52        assert_eq!(*indptr.data().unwrap().last().unwrap(), data.len());
53
54        // Check that the indices in indptr are monotonically increasing and
55        // are smaller or equal to the overall length of `indices`. This
56        // guarantees that there cannot be a memory error in the unsafe
57        // access in `apply`.
58        for (first, second) in indptr.iter_value().tuple_windows() {
59            assert!(
60                first <= second,
61                "Elements of indptr not in increasing order {first} > {second}."
62            );
63        }
64        // Check that the last element in indptr is the length of the `indices` array.
65        assert_eq!(*indptr.data().unwrap().last().unwrap(), indices.len());
66
67        // Check that the column indices in `indices` are smaller than `shape[1]`.
68
69        if let Some(&max_col_index) = indices.data().unwrap().iter().max() {
70            assert!(max_col_index < shape[1]);
71        }
72
73        Self {
74            mat_type: SparseMatType::Csr,
75            shape,
76            indices,
77            indptr,
78            data,
79        }
80    }
81
82    /// Return the index pointer.
83    pub fn indptr(&self) -> &DynArray<usize, 1> {
84        &self.indptr
85    }
86
87    /// Return the indices.
88    pub fn indices(&self) -> &DynArray<usize, 1> {
89        &self.indices
90    }
91
92    /// Return the data.
93    pub fn data(&self) -> &DynArray<Item, 1> {
94        &self.data
95    }
96}
97
98impl<Item> FromAij for CsrMatrix<Item>
99where
100    Item: AddAssign + PartialEq + Copy + Default,
101{
102    /// Create a new CSR matrix from arrays `row`, `cols` and `data` that store for
103    /// each nonzero entry the associated row, column, and value.
104    fn from_aij(shape: [usize; 2], rows: &[usize], cols: &[usize], data: &[Item]) -> Self {
105        let (rows, cols, data) = normalize_aij(rows, cols, data, SparseMatType::Csr);
106
107        let max_col = if let Some(col) = cols.iter().max() {
108            *col
109        } else {
110            0
111        };
112        let max_row = if let Some(row) = rows.last() { *row } else { 0 };
113
114        assert!(
115            max_col < shape[1],
116            "Maximum column {} must be smaller than `shape.1` {}",
117            max_col,
118            shape[1]
119        );
120
121        assert!(
122            max_row < shape[0],
123            "Maximum row {} must be smaller than `shape.0` {}",
124            max_row,
125            shape[0]
126        );
127
128        let nelems = data.len();
129
130        let mut indptr = Vec::<usize>::with_capacity(1 + shape[0]);
131
132        let mut count: usize = 0;
133        for row in 0..(shape[0]) {
134            indptr.push(count);
135            while count < nelems && row == rows[count] {
136                count += 1;
137            }
138        }
139        indptr.push(count);
140
141        let indptr = DynArray::from_shape_and_vec([1 + shape[0]], indptr);
142        let indices = DynArray::from_shape_and_vec([nelems], cols);
143        let data = DynArray::from_shape_and_vec([nelems], data);
144
145        Self::new(shape, indices, indptr, data)
146    }
147}
148
149impl<Item> Shape<2> for CsrMatrix<Item> {
150    fn shape(&self) -> [usize; 2] {
151        self.shape
152    }
153}
154
155impl<Item> BaseItem for CsrMatrix<Item>
156where
157    Item: Copy + Default,
158{
159    type Item = Item;
160}
161
162impl<Item> Nonzeros for CsrMatrix<Item> {
163    fn nnz(&self) -> usize {
164        self.data.len()
165    }
166}
167
168impl<Item> SparseMatrixType for CsrMatrix<Item> {
169    fn mat_type(&self) -> SparseMatType {
170        self.mat_type
171    }
172}
173
174impl<Item> AijIteratorByValue for CsrMatrix<Item>
175where
176    Item: Copy + Default,
177{
178    fn iter_aij_value(&self) -> impl Iterator<Item = ([usize; 2], Self::Item)> + '_ {
179        self.indptr
180            .iter_value()
181            .tuple_windows::<(usize, usize)>()
182            .enumerate()
183            .flat_map(|(row, (start, end))| {
184                izip!(
185                    self.indices.data().unwrap()[start..end].iter(),
186                    self.data.data().unwrap()[start..end].iter()
187                )
188                .map(|(col, value)| ([row, *col], *value))
189                .collect::<Vec<_>>()
190            })
191    }
192}
193
194impl<Item> AijIteratorMut for CsrMatrix<Item>
195where
196    Item: Copy + Default,
197{
198    fn iter_aij_mut(&mut self) -> impl Iterator<Item = ([usize; 2], &mut Self::Item)> + '_ {
199        self.indptr
200            .iter_value()
201            .tuple_windows::<(usize, usize)>()
202            .enumerate()
203            .flat_map(|(row, (start, end))| {
204                izip!(
205                    self.indices.data().unwrap()[start..end].iter(),
206                    self.data.data_mut().unwrap()[start..end]
207                        .iter_mut()
208                        // Need to convert the mutable reference to the raw pointer
209                        // as borrow checker does not allow the mutable reference to leak from FnMut.
210                        .map(|v| v as *mut Item)
211                )
212                .map(|(col, value)| ([row, *col], value))
213                .collect::<Vec<_>>()
214            })
215            .map(|(idx, value)| (idx, unsafe { &mut *value }))
216    }
217}
218
219impl<Item: Copy + Default> CsrMatrix<Item> {
220    /// Return as sparse matrix in iterator form.
221    pub fn op(&self) -> SparseMatOpIterator<Item, impl Iterator<Item = ([usize; 2], Item)> + '_> {
222        SparseMatOpIterator::new(self.iter_aij_value(), self.shape())
223    }
224}
225
226impl<Item: Copy + Default> CsrMatrix<Item> {
227    /// Convert to a dense matrix.
228    pub fn todense(&self) -> DynArray<Item, 2> {
229        DynArray::from_iter_aij(self.shape(), self.iter_aij_value())
230    }
231}
232
233impl<Item: Default + Mul<Output = Item> + AddAssign<Item> + Add<Output = Item> + Copy + One>
234    CsrMatrix<Item>
235{
236    /// Apply the matrix to a vector or dense matrix.
237    pub fn dot<ArrayImpl, const NDIM: usize>(
238        &self,
239        other: &Array<ArrayImpl, NDIM>,
240    ) -> DynArray<Item, NDIM>
241    where
242        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
243    {
244        let mut out = empty_array::<Item, NDIM>();
245
246        if NDIM == 1 {
247            let mut out = out.r_mut().coerce_dim::<1>().unwrap();
248            let other = other.r().coerce_dim::<1>().unwrap();
249            out.resize_in_place([self.shape()[0]]);
250            self.apply(One::one(), &other, Default::default(), &mut out);
251        } else if NDIM == 2 {
252            let mut out = out.r_mut().coerce_dim::<2>().unwrap();
253            let other = other.r().coerce_dim::<2>().unwrap();
254            out.resize_in_place([self.shape()[0], other.shape()[1]]);
255            self.apply(One::one(), &other, Default::default(), &mut out);
256        } else {
257            panic!(
258                "Unsupported number of dimensions NDIM = {NDIM}. Only NDIM=1 or NDIM=2 supported."
259            );
260        }
261
262        out
263    }
264}
265
266impl<Item, ArrayImplX, ArrayImplY> AsMatrixApply<Array<ArrayImplX, 1>, Array<ArrayImplY, 1>>
267    for CsrMatrix<Item>
268where
269    Item: Default + Mul<Output = Item> + AddAssign<Item> + Add<Output = Item> + Copy + One,
270    ArrayImplX: UnsafeRandomAccessByValue<1, Item = Item> + Shape<1>,
271    ArrayImplY: UnsafeRandom1DAccessMut<Item = Item> + Shape<1>,
272{
273    fn apply(
274        &self,
275        alpha: Self::Item,
276        x: &crate::Array<ArrayImplX, 1>,
277        beta: Self::Item,
278        y: &mut crate::Array<ArrayImplY, 1>,
279    ) {
280        assert_eq!(y.len(), self.shape()[0]);
281        assert_eq!(x.len(), self.shape()[1]);
282        for (row, out) in y.iter_mut().enumerate() {
283            *out = beta * *out
284                + alpha * {
285                    let c1 = unsafe { self.indptr.get_value_unchecked([row]) };
286                    let c2 = unsafe { self.indptr.get_value_unchecked([1 + row]) };
287                    let mut acc = Item::default();
288
289                    for index in c1..c2 {
290                        let col = unsafe { self.indices.get_value_unchecked([index]) };
291                        acc += unsafe {
292                            self.data.get_value_unchecked([index]) * x.get_value_unchecked([col])
293                        };
294                    }
295                    acc
296                }
297        }
298    }
299}
300
301impl<Item, ArrayImplX, ArrayImplY> AsMatrixApply<Array<ArrayImplX, 2>, Array<ArrayImplY, 2>>
302    for CsrMatrix<Item>
303where
304    Item: Copy,
305    Self: BaseItem<Item = Item>,
306    ArrayImplX: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>,
307    ArrayImplY: UnsafeRandomAccessMut<2, Item = Item> + Shape<2>,
308    for<'b> Self: AsMatrixApply<
309            Array<ArraySlice<ArrayRef<'b, ArrayImplX, 2>, 2, 1>, 1>,
310            Array<ArraySlice<ArrayRefMut<'b, ArrayImplY, 2>, 2, 1>, 1>,
311        >,
312{
313    fn apply(
314        &self,
315        alpha: Self::Item,
316        x: &crate::Array<ArrayImplX, 2>,
317        beta: Self::Item,
318        y: &mut crate::Array<ArrayImplY, 2>,
319    ) {
320        for (colx, mut coly) in izip!(x.col_iter(), y.col_iter_mut()) {
321            self.apply(alpha, &colx, beta, &mut coly)
322        }
323    }
324}
325
#[cfg(test)]
mod test {

    use super::*;

    /// Sparse matrix-vector product agrees with the dense product.
    #[test]
    fn test_csr() {
        // We create a simple CSR matrix.
        let rows: Vec<usize> = vec![1, 4, 4];
        let cols: Vec<usize> = vec![2, 5, 6];
        let data: Vec<f64> = vec![1.0, 2.0, 3.0];

        let shape = [8, 13];
        let sparse_mat = CsrMatrix::from_aij(shape, &rows, &cols, &data);

        let mut x = DynArray::<f64, 1>::from_shape([shape[1]]);
        x.fill_from_seed_equally_distributed(0);

        let y = crate::dot!(sparse_mat, x);
        let expected = crate::dot!(sparse_mat.todense(), x);

        crate::assert_array_relative_eq!(y, expected, 1E-10);
    }

    /// Values modified through `iter_aij_mut` are visible afterwards and keep their positions.
    #[test]
    fn test_csr_iter_aij_mut() {
        // Entries are given in (row, column) order so the stored order is predictable.
        let rows: Vec<usize> = vec![0, 2, 2];
        let cols: Vec<usize> = vec![1, 0, 3];
        let data: Vec<f64> = vec![1.0, 2.0, 3.0];

        let mut sparse_mat = CsrMatrix::from_aij([3, 4], &rows, &cols, &data);

        for (_, value) in sparse_mat.iter_aij_mut() {
            *value *= 2.0;
        }

        let entries: Vec<([usize; 2], f64)> = sparse_mat.iter_aij_value().collect();
        assert_eq!(entries, vec![([0, 1], 2.0), ([2, 0], 4.0), ([2, 3], 6.0)]);
    }
}