nalgebra_sparse/
csr.rs

//! An implementation of the CSR sparse matrix format.
//!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSR implementation.
5
6#[cfg(feature = "serde-serialize")]
7mod csr_serde;
8
9use crate::cs;
10use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
11use crate::csc::CscMatrix;
12use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
13use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
14
15use nalgebra::Scalar;
16use num_traits::One;
17
18use std::slice::{Iter, IterMut};
19
20/// A CSR representation of a sparse matrix.
21///
22/// The Compressed Sparse Row (CSR) format is well-suited as a general-purpose storage format
23/// for many sparse matrix applications.
24///
25/// # Usage
26///
27/// ```
28/// use nalgebra_sparse::coo::CooMatrix;
29/// use nalgebra_sparse::csr::CsrMatrix;
30/// use nalgebra::{DMatrix, Matrix3x4};
31/// use matrixcompare::assert_matrix_eq;
32///
33/// // The sparsity patterns of CSR matrices are immutable. This means that you cannot dynamically
34/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
35/// // way to construct a CSR matrix is to first incrementally construct a COO matrix,
36/// // and then convert it to CSR.
37///
38/// let mut coo = CooMatrix::<f64>::new(3, 3);
39/// coo.push(2, 0, 1.0);
40/// let csr = CsrMatrix::from(&coo);
41///
42/// // Alternatively, a CSR matrix can be constructed directly from raw CSR data.
43/// // Here, we construct a 3x4 matrix
44/// let row_offsets = vec![0, 3, 3, 5];
45/// let col_indices = vec![0, 1, 3, 1, 2];
46/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
47///
48/// // The dense representation of the CSR data, for comparison
49/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 3.0,
50///                            0.0, 0.0, 0.0, 0.0,
51///                            0.0, 4.0, 5.0, 0.0);
52///
53/// // The constructor validates the raw CSR data and returns an error if it is invalid.
54/// let csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
55///     .expect("CSR data must conform to format specifications");
56/// assert_matrix_eq!(csr, dense);
57///
58/// // A third approach is to construct a CSR matrix from a pattern and values. Sometimes this is
59/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
60/// let (pattern, values) = csr.into_pattern_and_values();
61/// let csr = CsrMatrix::try_from_pattern_and_values(pattern, values)
62///     .expect("The pattern and values must be compatible");
63///
64/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
65/// // other CSR matrices and dense matrices/vectors.
66/// let x = csr;
67/// # #[allow(non_snake_case)]
68/// let xTx = x.transpose() * &x;
69/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
70/// let w = 3.0 * xTx * z;
71///
72/// // Although the sparsity pattern of a CSR matrix cannot be changed, its values can.
73/// // Here are two different ways to scale all values by a constant:
74/// let mut x = x;
75/// x *= 5.0;
76/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
77/// ```
78///
79/// # Format
80///
81/// An `m x n` sparse matrix with `nnz` non-zeros in CSR format is represented by the
82/// following three arrays:
83///
84/// - `row_offsets`, an array of integers with length `m + 1`.
85/// - `col_indices`, an array of integers with length `nnz`.
86/// - `values`, an array of values with length `nnz`.
87///
88/// The relationship between the arrays is described below.
89///
90/// - Each consecutive pair of entries `row_offsets[i] .. row_offsets[i + 1]` corresponds to an
91///   offset range in `col_indices` that holds the column indices in row `i`.
92/// - For an entry represented by the index `idx`, `col_indices[idx]` stores its column index and
93///   `values[idx]` stores its value.
94///
95/// The following invariants must be upheld and are enforced by the data structure:
96///
97/// - `row_offsets[0] == 0`
98/// - `row_offsets[m] == nnz`
99/// - `row_offsets` is monotonically increasing.
100/// - `0 <= col_indices[idx] < n` for all `idx < nnz`.
101/// - The column indices associated with each row are monotonically increasing (see below).
102///
103/// The CSR format is a standard sparse matrix format (see [Wikipedia article]). The format
104/// represents the matrix in a row-by-row fashion. The entries associated with row `i` are
105/// determined as follows:
106///
107/// ```
108/// # let row_offsets: Vec<usize> = vec![0, 0];
109/// # let col_indices: Vec<usize> = vec![];
110/// # let values: Vec<i32> = vec![];
111/// # let i = 0;
112/// let range = row_offsets[i] .. row_offsets[i + 1];
113/// let row_i_cols = &col_indices[range.clone()];
114/// let row_i_vals = &values[range];
115///
116/// // For each pair (j, v) in (row_i_cols, row_i_vals), we obtain a corresponding entry
117/// // (i, j, v) in the matrix.
118/// assert_eq!(row_i_cols.len(), row_i_vals.len());
119/// ```
120///
121/// In the above example, for each row `i`, the column indices `row_i_cols` must appear in
122/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
123/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
124/// assumption for both correctness and performance for many algorithms.
125///
126/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
127/// column-by-column instead of row-by-row like CSR.
128///
129/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
130#[derive(Debug, Clone, PartialEq, Eq)]
131pub struct CsrMatrix<T> {
132    // Rows are major, cols are minor in the sparsity pattern
133    pub(crate) cs: CsMatrix<T>,
134}
135
136impl<T> CsrMatrix<T> {
137    /// Constructs a CSR representation of the (square) `n x n` identity matrix.
138    #[inline]
139    pub fn identity(n: usize) -> Self
140    where
141        T: Scalar + One,
142    {
143        Self {
144            cs: CsMatrix::identity(n),
145        }
146    }
147
148    /// Create a zero CSR matrix with no explicitly stored entries.
149    pub fn zeros(nrows: usize, ncols: usize) -> Self {
150        Self {
151            cs: CsMatrix::new(nrows, ncols),
152        }
153    }
154
155    /// Try to construct a CSR matrix from raw CSR data.
156    ///
157    /// It is assumed that each row contains unique and sorted column indices that are in
158    /// bounds with respect to the number of columns in the matrix. If this is not the case,
159    /// an error is returned to indicate the failure.
160    ///
161    /// An error is returned if the data given does not conform to the CSR storage format.
162    /// See the documentation for [`CsrMatrix`] for more information.
163    pub fn try_from_csr_data(
164        num_rows: usize,
165        num_cols: usize,
166        row_offsets: Vec<usize>,
167        col_indices: Vec<usize>,
168        values: Vec<T>,
169    ) -> Result<Self, SparseFormatError> {
170        let pattern = SparsityPattern::try_from_offsets_and_indices(
171            num_rows,
172            num_cols,
173            row_offsets,
174            col_indices,
175        )
176        .map_err(pattern_format_error_to_csr_error)?;
177        Self::try_from_pattern_and_values(pattern, values)
178    }
179
180    /// Try to construct a CSR matrix from raw CSR data with unsorted column indices.
181    ///
182    /// It is assumed that each row contains unique column indices that are in
183    /// bounds with respect to the number of columns in the matrix. If this is not the case,
184    /// an error is returned to indicate the failure.
185    ///
186    /// An error is returned if the data given does not conform to the CSR storage format
187    /// with the exception of having unsorted column indices and values.
188    /// See the documentation for [`CsrMatrix`] for more information.
189    pub fn try_from_unsorted_csr_data(
190        num_rows: usize,
191        num_cols: usize,
192        row_offsets: Vec<usize>,
193        mut col_indices: Vec<usize>,
194        mut values: Vec<T>,
195    ) -> Result<Self, SparseFormatError>
196    where
197        T: Scalar,
198    {
199        let result = cs::validate_and_optionally_sort_cs_data(
200            num_rows,
201            num_cols,
202            &row_offsets,
203            &mut col_indices,
204            Some(&mut values),
205            true,
206        );
207
208        match result {
209            Ok(()) => {
210                let pattern = unsafe {
211                    SparsityPattern::from_offset_and_indices_unchecked(
212                        num_rows,
213                        num_cols,
214                        row_offsets,
215                        col_indices,
216                    )
217                };
218                Self::try_from_pattern_and_values(pattern, values)
219            }
220            Err(err) => Err(err),
221        }
222    }
223
224    /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
225    ///
226    /// Returns an error if the number of values does not match the number of minor indices
227    /// in the pattern.
228    pub fn try_from_pattern_and_values(
229        pattern: SparsityPattern,
230        values: Vec<T>,
231    ) -> Result<Self, SparseFormatError> {
232        if pattern.nnz() == values.len() {
233            Ok(Self {
234                cs: CsMatrix::from_pattern_and_values(pattern, values),
235            })
236        } else {
237            Err(SparseFormatError::from_kind_and_msg(
238                SparseFormatErrorKind::InvalidStructure,
239                "Number of values and column indices must be the same",
240            ))
241        }
242    }
243
244    /// The number of rows in the matrix.
245    #[inline]
246    #[must_use]
247    pub fn nrows(&self) -> usize {
248        self.cs.pattern().major_dim()
249    }
250
251    /// The number of columns in the matrix.
252    #[inline]
253    #[must_use]
254    pub fn ncols(&self) -> usize {
255        self.cs.pattern().minor_dim()
256    }
257
258    /// The number of non-zeros in the matrix.
259    ///
260    /// Note that this corresponds to the number of explicitly stored entries, *not* the actual
261    /// number of algebraically zero entries in the matrix. Explicitly stored entries can still
262    /// be zero. Corresponds to the number of entries in the sparsity pattern.
263    #[inline]
264    #[must_use]
265    pub fn nnz(&self) -> usize {
266        self.cs.pattern().nnz()
267    }
268
269    /// The row offsets defining part of the CSR format.
270    #[inline]
271    #[must_use]
272    pub fn row_offsets(&self) -> &[usize] {
273        let (offsets, _, _) = self.cs.cs_data();
274        offsets
275    }
276
277    /// The column indices defining part of the CSR format.
278    #[inline]
279    #[must_use]
280    pub fn col_indices(&self) -> &[usize] {
281        let (_, indices, _) = self.cs.cs_data();
282        indices
283    }
284
285    /// The non-zero values defining part of the CSR format.
286    #[inline]
287    #[must_use]
288    pub fn values(&self) -> &[T] {
289        self.cs.values()
290    }
291
292    /// Mutable access to the non-zero values.
293    #[inline]
294    pub fn values_mut(&mut self) -> &mut [T] {
295        self.cs.values_mut()
296    }
297
298    /// An iterator over non-zero triplets (i, j, v).
299    ///
300    /// The iteration happens in row-major fashion, meaning that i increases monotonically,
301    /// and j increases monotonically within each row.
302    ///
303    /// Examples
304    /// --------
305    /// ```
306    /// # use nalgebra_sparse::csr::CsrMatrix;
307    /// let row_offsets = vec![0, 2, 3, 4];
308    /// let col_indices = vec![0, 2, 1, 0];
309    /// let values = vec![1, 2, 3, 4];
310    /// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
311    ///     .unwrap();
312    ///
313    /// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
314    /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
315    /// ```
316    pub fn triplet_iter(&self) -> CsrTripletIter<'_, T> {
317        CsrTripletIter {
318            pattern_iter: self.pattern().entries(),
319            values_iter: self.values().iter(),
320        }
321    }
322
323    /// A mutable iterator over non-zero triplets (i, j, v).
324    ///
325    /// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
326    ///
327    /// Examples
328    /// --------
329    /// ```
330    /// # use nalgebra_sparse::csr::CsrMatrix;
331    /// # let row_offsets = vec![0, 2, 3, 4];
332    /// # let col_indices = vec![0, 2, 1, 0];
333    /// # let values = vec![1, 2, 3, 4];
334    /// // Using the same data as in the `triplet_iter` example
335    /// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
336    ///     .unwrap();
337    ///
338    /// // Zero out lower-triangular terms
339    /// csr.triplet_iter_mut()
340    ///    .filter(|(i, j, _)| j < i)
341    ///    .for_each(|(_, _, v)| *v = 0);
342    ///
343    /// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
344    /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
345    /// ```
346    pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<'_, T> {
347        let (pattern, values) = self.cs.pattern_and_values_mut();
348        CsrTripletIterMut {
349            pattern_iter: pattern.entries(),
350            values_mut_iter: values.iter_mut(),
351        }
352    }
353
354    /// Return the row at the given row index.
355    ///
356    /// Panics
357    /// ------
358    /// Panics if row index is out of bounds.
359    #[inline]
360    #[must_use]
361    pub fn row(&self, index: usize) -> CsrRow<'_, T> {
362        self.get_row(index).expect("Row index must be in bounds")
363    }
364
365    /// Mutable row access for the given row index.
366    ///
367    /// Panics
368    /// ------
369    /// Panics if row index is out of bounds.
370    #[inline]
371    pub fn row_mut(&mut self, index: usize) -> CsrRowMut<'_, T> {
372        self.get_row_mut(index)
373            .expect("Row index must be in bounds")
374    }
375
376    /// Return the row at the given row index, or `None` if out of bounds.
377    #[inline]
378    #[must_use]
379    pub fn get_row(&self, index: usize) -> Option<CsrRow<'_, T>> {
380        self.cs.get_lane(index).map(|lane| CsrRow { lane })
381    }
382
383    /// Mutable row access for the given row index, or `None` if out of bounds.
384    #[inline]
385    #[must_use]
386    pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<'_, T>> {
387        self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
388    }
389
390    /// An iterator over rows in the matrix.
391    pub fn row_iter(&self) -> CsrRowIter<'_, T> {
392        CsrRowIter {
393            lane_iter: CsLaneIter::new(self.pattern(), self.values()),
394        }
395    }
396
397    /// A mutable iterator over rows in the matrix.
398    pub fn row_iter_mut(&mut self) -> CsrRowIterMut<'_, T> {
399        let (pattern, values) = self.cs.pattern_and_values_mut();
400        CsrRowIterMut {
401            lane_iter: CsLaneIterMut::new(pattern, values),
402        }
403    }
404
405    /// Disassembles the CSR matrix into its underlying offset, index and value arrays.
406    ///
407    /// If the matrix contains the sole reference to the sparsity pattern,
408    /// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
409    ///
410    /// Examples
411    /// --------
412    ///
413    /// ```
414    /// # use nalgebra_sparse::csr::CsrMatrix;
415    /// let row_offsets = vec![0, 2, 3, 4];
416    /// let col_indices = vec![0, 2, 1, 0];
417    /// let values = vec![1, 2, 3, 4];
418    /// let mut csr = CsrMatrix::try_from_csr_data(
419    ///     3,
420    ///     4,
421    ///     row_offsets.clone(),
422    ///     col_indices.clone(),
423    ///     values.clone())
424    ///     .unwrap();
425    /// let (row_offsets2, col_indices2, values2) = csr.disassemble();
426    /// assert_eq!(row_offsets2, row_offsets);
427    /// assert_eq!(col_indices2, col_indices);
428    /// assert_eq!(values2, values);
429    /// ```
430    pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
431        self.cs.disassemble()
432    }
433
434    /// Returns the sparsity pattern and values associated with this matrix.
435    pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
436        self.cs.into_pattern_and_values()
437    }
438
439    /// Returns a reference to the sparsity pattern and a mutable reference to the values.
440    #[inline]
441    pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
442        self.cs.pattern_and_values_mut()
443    }
444
445    /// Returns a reference to the underlying sparsity pattern.
446    #[must_use]
447    pub fn pattern(&self) -> &SparsityPattern {
448        self.cs.pattern()
449    }
450
451    /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
452    ///
453    /// This operation does not touch the CSR data, and is effectively a no-op.
454    pub fn transpose_as_csc(self) -> CscMatrix<T> {
455        let (pattern, values) = self.cs.take_pattern_and_values();
456        CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
457    }
458
459    /// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
460    ///
461    /// Each call to this function incurs the cost of a binary search among the explicitly
462    /// stored column entries for the given row.
463    #[must_use]
464    pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<'_, T>> {
465        self.cs.get_entry(row_index, col_index)
466    }
467
468    /// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
469    /// of bounds.
470    ///
471    /// Each call to this function incurs the cost of a binary search among the explicitly
472    /// stored column entries for the given row.
473    pub fn get_entry_mut(
474        &mut self,
475        row_index: usize,
476        col_index: usize,
477    ) -> Option<SparseEntryMut<'_, T>> {
478        self.cs.get_entry_mut(row_index, col_index)
479    }
480
481    /// Returns an entry for the given row/col indices.
482    ///
483    /// Same as `get_entry`, except that it directly panics upon encountering row/col indices
484    /// out of bounds.
485    ///
486    /// Panics
487    /// ------
488    /// Panics if `row_index` or `col_index` is out of bounds.
489    #[must_use]
490    pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<'_, T> {
491        self.get_entry(row_index, col_index)
492            .expect("Out of bounds matrix indices encountered")
493    }
494
495    /// Returns a mutable entry for the given row/col indices.
496    ///
497    /// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
498    /// out of bounds.
499    ///
500    /// Panics
501    /// ------
502    /// Panics if `row_index` or `col_index` is out of bounds.
503    pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<'_, T> {
504        self.get_entry_mut(row_index, col_index)
505            .expect("Out of bounds matrix indices encountered")
506    }
507
508    /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
509    #[must_use]
510    pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
511        self.cs.cs_data()
512    }
513
514    /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
515    /// where the `values` array is mutable.
516    pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
517        self.cs.cs_data_mut()
518    }
519
520    /// Creates a sparse matrix that contains only the explicit entries decided by the
521    /// given predicate.
522    #[must_use]
523    pub fn filter<P>(&self, predicate: P) -> Self
524    where
525        T: Clone,
526        P: Fn(usize, usize, &T) -> bool,
527    {
528        Self {
529            cs: self
530                .cs
531                .filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
532        }
533    }
534
535    /// Returns a new matrix representing the upper triangular part of this matrix.
536    ///
537    /// The result includes the diagonal of the matrix.
538    #[must_use]
539    pub fn upper_triangle(&self) -> Self
540    where
541        T: Clone,
542    {
543        self.filter(|i, j, _| i <= j)
544    }
545
546    /// Returns a new matrix representing the lower triangular part of this matrix.
547    ///
548    /// The result includes the diagonal of the matrix.
549    #[must_use]
550    pub fn lower_triangle(&self) -> Self
551    where
552        T: Clone,
553    {
554        self.filter(|i, j, _| i >= j)
555    }
556
557    /// Returns the diagonal of the matrix as a sparse matrix.
558    #[must_use]
559    pub fn diagonal_as_csr(&self) -> Self
560    where
561        T: Clone,
562    {
563        Self {
564            cs: self.cs.diagonal_as_matrix(),
565        }
566    }
567
568    /// Compute the transpose of the matrix.
569    #[must_use]
570    pub fn transpose(&self) -> CsrMatrix<T>
571    where
572        T: Scalar,
573    {
574        CscMatrix::from(self).transpose_as_csr()
575    }
576}
577
578impl<T> Default for CsrMatrix<T> {
579    fn default() -> Self {
580        Self {
581            cs: Default::default(),
582        }
583    }
584}
585
586/// Convert pattern format errors into more meaningful CSR-specific errors.
587///
588/// This ensures that the terminology is consistent: we are talking about rows and columns,
589/// not lanes, major and minor dimensions.
590fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
591    use SparseFormatError as E;
592    use SparseFormatErrorKind as K;
593    use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
594    use SparsityPatternFormatError::*;
595
596    match err {
597        InvalidOffsetArrayLength => E::from_kind_and_msg(
598            K::InvalidStructure,
599            "Length of row offset array is not equal to nrows + 1.",
600        ),
601        InvalidOffsetFirstLast => E::from_kind_and_msg(
602            K::InvalidStructure,
603            "First or last row offset is inconsistent with format specification.",
604        ),
605        NonmonotonicOffsets => E::from_kind_and_msg(
606            K::InvalidStructure,
607            "Row offsets are not monotonically increasing.",
608        ),
609        NonmonotonicMinorIndices => E::from_kind_and_msg(
610            K::InvalidStructure,
611            "Column indices are not monotonically increasing (sorted) within each row.",
612        ),
613        MinorIndexOutOfBounds => {
614            E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
615        }
616        PatternDuplicateEntry => {
617            E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
618        }
619    }
620}
621
622/// Iterator type for iterating over triplets in a CSR matrix.
623#[derive(Debug)]
624pub struct CsrTripletIter<'a, T> {
625    pattern_iter: SparsityPatternIter<'a>,
626    values_iter: Iter<'a, T>,
627}
628
629impl<'a, T> Clone for CsrTripletIter<'a, T> {
630    fn clone(&self) -> Self {
631        CsrTripletIter {
632            pattern_iter: self.pattern_iter.clone(),
633            values_iter: self.values_iter.clone(),
634        }
635    }
636}
637
638impl<'a, T: Clone> CsrTripletIter<'a, T> {
639    /// Adapts the triplet iterator to return owned values.
640    ///
641    /// The triplet iterator returns references to the values. This method adapts the iterator
642    /// so that the values are cloned.
643    #[inline]
644    pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
645        self.map(|(i, j, v)| (i, j, v.clone()))
646    }
647}
648
649impl<'a, T> Iterator for CsrTripletIter<'a, T> {
650    type Item = (usize, usize, &'a T);
651
652    fn next(&mut self) -> Option<Self::Item> {
653        let next_entry = self.pattern_iter.next();
654        let next_value = self.values_iter.next();
655
656        match (next_entry, next_value) {
657            (Some((i, j)), Some(v)) => Some((i, j, v)),
658            _ => None,
659        }
660    }
661}
662
663/// Iterator type for mutably iterating over triplets in a CSR matrix.
664#[derive(Debug)]
665pub struct CsrTripletIterMut<'a, T> {
666    pattern_iter: SparsityPatternIter<'a>,
667    values_mut_iter: IterMut<'a, T>,
668}
669
670impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
671    type Item = (usize, usize, &'a mut T);
672
673    #[inline]
674    fn next(&mut self) -> Option<Self::Item> {
675        let next_entry = self.pattern_iter.next();
676        let next_value = self.values_mut_iter.next();
677
678        match (next_entry, next_value) {
679            (Some((i, j)), Some(v)) => Some((i, j, v)),
680            _ => None,
681        }
682    }
683}
684
685/// An immutable representation of a row in a CSR matrix.
686#[derive(Debug, Clone, PartialEq, Eq)]
687pub struct CsrRow<'a, T> {
688    lane: CsLane<'a, T>,
689}
690
691/// A mutable representation of a row in a CSR matrix.
692///
693/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
694/// to the row cannot be modified.
695#[derive(Debug, PartialEq, Eq)]
696pub struct CsrRowMut<'a, T> {
697    lane: CsLaneMut<'a, T>,
698}
699
700/// Implement the methods common to both CsrRow and CsrRowMut
701macro_rules! impl_csr_row_common_methods {
702    ($name:ty) => {
703        impl<'a, T> $name {
704            /// The number of global columns in the row.
705            #[inline]
706            #[must_use]
707            pub fn ncols(&self) -> usize {
708                self.lane.minor_dim()
709            }
710
711            /// The number of non-zeros in this row.
712            #[inline]
713            #[must_use]
714            pub fn nnz(&self) -> usize {
715                self.lane.nnz()
716            }
717
718            /// The column indices corresponding to explicitly stored entries in this row.
719            #[inline]
720            #[must_use]
721            pub fn col_indices(&self) -> &[usize] {
722                self.lane.minor_indices()
723            }
724
725            /// The values corresponding to explicitly stored entries in this row.
726            #[inline]
727            #[must_use]
728            pub fn values(&self) -> &[T] {
729                self.lane.values()
730            }
731
732            /// Returns an entry for the given global column index.
733            ///
734            /// Each call to this function incurs the cost of a binary search among the explicitly
735            /// stored column entries.
736            #[inline]
737            #[must_use]
738            pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
739                self.lane.get_entry(global_col_index)
740            }
741        }
742    };
743}
744
745impl_csr_row_common_methods!(CsrRow<'a, T>);
746impl_csr_row_common_methods!(CsrRowMut<'a, T>);
747
748impl<'a, T> CsrRowMut<'a, T> {
749    /// Mutable access to the values corresponding to explicitly stored entries in this row.
750    #[inline]
751    pub fn values_mut(&mut self) -> &mut [T] {
752        self.lane.values_mut()
753    }
754
755    /// Provides simultaneous access to column indices and mutable values corresponding to the
756    /// explicitly stored entries in this row.
757    ///
758    /// This method primarily facilitates low-level access for methods that process data stored
759    /// in CSR format directly.
760    #[inline]
761    pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
762        self.lane.indices_and_values_mut()
763    }
764
765    /// Returns a mutable entry for the given global column index.
766    #[inline]
767    #[must_use]
768    pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
769        self.lane.get_entry_mut(global_col_index)
770    }
771}
772
773/// Row iterator for [`CsrMatrix`].
774pub struct CsrRowIter<'a, T> {
775    lane_iter: CsLaneIter<'a, T>,
776}
777
778impl<'a, T> Iterator for CsrRowIter<'a, T> {
779    type Item = CsrRow<'a, T>;
780
781    fn next(&mut self) -> Option<Self::Item> {
782        self.lane_iter.next().map(|lane| CsrRow { lane })
783    }
784}
785
786/// Mutable row iterator for [`CsrMatrix`].
787pub struct CsrRowIterMut<'a, T> {
788    lane_iter: CsLaneIterMut<'a, T>,
789}
790
791impl<'a, T> Iterator for CsrRowIterMut<'a, T>
792where
793    T: 'a,
794{
795    type Item = CsrRowMut<'a, T>;
796
797    fn next(&mut self) -> Option<Self::Item> {
798        self.lane_iter.next().map(|lane| CsrRowMut { lane })
799    }
800}