nalgebra_sparse/csr.rs
1//! An implementation of the CSR sparse matrix format.
2//!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSR implementation.
5
6#[cfg(feature = "serde-serialize")]
7mod csr_serde;
8
9use crate::cs;
10use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
11use crate::csc::CscMatrix;
12use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
13use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
14
15use nalgebra::Scalar;
16use num_traits::One;
17
18use std::slice::{Iter, IterMut};
19
20/// A CSR representation of a sparse matrix.
21///
22/// The Compressed Sparse Row (CSR) format is well-suited as a general-purpose storage format
23/// for many sparse matrix applications.
24///
25/// # Usage
26///
27/// ```
28/// use nalgebra_sparse::coo::CooMatrix;
29/// use nalgebra_sparse::csr::CsrMatrix;
30/// use nalgebra::{DMatrix, Matrix3x4};
31/// use matrixcompare::assert_matrix_eq;
32///
33/// // The sparsity patterns of CSR matrices are immutable. This means that you cannot dynamically
34/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
35/// // way to construct a CSR matrix is to first incrementally construct a COO matrix,
36/// // and then convert it to CSR.
37///
38/// let mut coo = CooMatrix::<f64>::new(3, 3);
39/// coo.push(2, 0, 1.0);
40/// let csr = CsrMatrix::from(&coo);
41///
42/// // Alternatively, a CSR matrix can be constructed directly from raw CSR data.
43/// // Here, we construct a 3x4 matrix
44/// let row_offsets = vec![0, 3, 3, 5];
45/// let col_indices = vec![0, 1, 3, 1, 2];
46/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
47///
48/// // The dense representation of the CSR data, for comparison
49/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 3.0,
50/// 0.0, 0.0, 0.0, 0.0,
51/// 0.0, 4.0, 5.0, 0.0);
52///
53/// // The constructor validates the raw CSR data and returns an error if it is invalid.
54/// let csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
55/// .expect("CSR data must conform to format specifications");
56/// assert_matrix_eq!(csr, dense);
57///
58/// // A third approach is to construct a CSR matrix from a pattern and values. Sometimes this is
59/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
60/// let (pattern, values) = csr.into_pattern_and_values();
61/// let csr = CsrMatrix::try_from_pattern_and_values(pattern, values)
62/// .expect("The pattern and values must be compatible");
63///
64/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
65/// // other CSR matrices and dense matrices/vectors.
66/// let x = csr;
67/// # #[allow(non_snake_case)]
68/// let xTx = x.transpose() * &x;
69/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
70/// let w = 3.0 * xTx * z;
71///
72/// // Although the sparsity pattern of a CSR matrix cannot be changed, its values can.
73/// // Here are two different ways to scale all values by a constant:
74/// let mut x = x;
75/// x *= 5.0;
76/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
77/// ```
78///
79/// # Format
80///
81/// An `m x n` sparse matrix with `nnz` non-zeros in CSR format is represented by the
82/// following three arrays:
83///
84/// - `row_offsets`, an array of integers with length `m + 1`.
85/// - `col_indices`, an array of integers with length `nnz`.
86/// - `values`, an array of values with length `nnz`.
87///
88/// The relationship between the arrays is described below.
89///
90/// - Each consecutive pair of entries `row_offsets[i] .. row_offsets[i + 1]` corresponds to an
91/// offset range in `col_indices` that holds the column indices in row `i`.
92/// - For an entry represented by the index `idx`, `col_indices[idx]` stores its column index and
93/// `values[idx]` stores its value.
94///
95/// The following invariants must be upheld and are enforced by the data structure:
96///
97/// - `row_offsets[0] == 0`
98/// - `row_offsets[m] == nnz`
99/// - `row_offsets` is monotonically increasing.
100/// - `0 <= col_indices[idx] < n` for all `idx < nnz`.
/// - The column indices associated with each row are strictly monotonically increasing, i.e.
///   sorted and free of duplicates (see below).
102///
103/// The CSR format is a standard sparse matrix format (see [Wikipedia article]). The format
104/// represents the matrix in a row-by-row fashion. The entries associated with row `i` are
105/// determined as follows:
106///
107/// ```
108/// # let row_offsets: Vec<usize> = vec![0, 0];
109/// # let col_indices: Vec<usize> = vec![];
110/// # let values: Vec<i32> = vec![];
111/// # let i = 0;
112/// let range = row_offsets[i] .. row_offsets[i + 1];
113/// let row_i_cols = &col_indices[range.clone()];
114/// let row_i_vals = &values[range];
115///
116/// // For each pair (j, v) in (row_i_cols, row_i_vals), we obtain a corresponding entry
117/// // (i, j, v) in the matrix.
118/// assert_eq!(row_i_cols.len(), row_i_vals.len());
119/// ```
120///
121/// In the above example, for each row `i`, the column indices `row_i_cols` must appear in
122/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
123/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
124/// assumption for both correctness and performance for many algorithms.
125///
126/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
127/// column-by-column instead of row-by-row like CSR.
128///
129/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrMatrix<T> {
    // Rows are major, cols are minor in the sparsity pattern: the shared compressed-storage
    // backend is interpreted row-wise here, which is why `nrows` reads the pattern's major
    // dimension and `ncols` reads its minor dimension.
    pub(crate) cs: CsMatrix<T>,
}
135
136impl<T> CsrMatrix<T> {
137 /// Constructs a CSR representation of the (square) `n x n` identity matrix.
138 #[inline]
139 pub fn identity(n: usize) -> Self
140 where
141 T: Scalar + One,
142 {
143 Self {
144 cs: CsMatrix::identity(n),
145 }
146 }
147
148 /// Create a zero CSR matrix with no explicitly stored entries.
149 pub fn zeros(nrows: usize, ncols: usize) -> Self {
150 Self {
151 cs: CsMatrix::new(nrows, ncols),
152 }
153 }
154
155 /// Try to construct a CSR matrix from raw CSR data.
156 ///
157 /// It is assumed that each row contains unique and sorted column indices that are in
158 /// bounds with respect to the number of columns in the matrix. If this is not the case,
159 /// an error is returned to indicate the failure.
160 ///
161 /// An error is returned if the data given does not conform to the CSR storage format.
162 /// See the documentation for [`CsrMatrix`] for more information.
163 pub fn try_from_csr_data(
164 num_rows: usize,
165 num_cols: usize,
166 row_offsets: Vec<usize>,
167 col_indices: Vec<usize>,
168 values: Vec<T>,
169 ) -> Result<Self, SparseFormatError> {
170 let pattern = SparsityPattern::try_from_offsets_and_indices(
171 num_rows,
172 num_cols,
173 row_offsets,
174 col_indices,
175 )
176 .map_err(pattern_format_error_to_csr_error)?;
177 Self::try_from_pattern_and_values(pattern, values)
178 }
179
180 /// Try to construct a CSR matrix from raw CSR data with unsorted column indices.
181 ///
182 /// It is assumed that each row contains unique column indices that are in
183 /// bounds with respect to the number of columns in the matrix. If this is not the case,
184 /// an error is returned to indicate the failure.
185 ///
186 /// An error is returned if the data given does not conform to the CSR storage format
187 /// with the exception of having unsorted column indices and values.
188 /// See the documentation for [`CsrMatrix`] for more information.
189 pub fn try_from_unsorted_csr_data(
190 num_rows: usize,
191 num_cols: usize,
192 row_offsets: Vec<usize>,
193 mut col_indices: Vec<usize>,
194 mut values: Vec<T>,
195 ) -> Result<Self, SparseFormatError>
196 where
197 T: Scalar,
198 {
199 let result = cs::validate_and_optionally_sort_cs_data(
200 num_rows,
201 num_cols,
202 &row_offsets,
203 &mut col_indices,
204 Some(&mut values),
205 true,
206 );
207
208 match result {
209 Ok(()) => {
210 let pattern = unsafe {
211 SparsityPattern::from_offset_and_indices_unchecked(
212 num_rows,
213 num_cols,
214 row_offsets,
215 col_indices,
216 )
217 };
218 Self::try_from_pattern_and_values(pattern, values)
219 }
220 Err(err) => Err(err),
221 }
222 }
223
224 /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
225 ///
226 /// Returns an error if the number of values does not match the number of minor indices
227 /// in the pattern.
228 pub fn try_from_pattern_and_values(
229 pattern: SparsityPattern,
230 values: Vec<T>,
231 ) -> Result<Self, SparseFormatError> {
232 if pattern.nnz() == values.len() {
233 Ok(Self {
234 cs: CsMatrix::from_pattern_and_values(pattern, values),
235 })
236 } else {
237 Err(SparseFormatError::from_kind_and_msg(
238 SparseFormatErrorKind::InvalidStructure,
239 "Number of values and column indices must be the same",
240 ))
241 }
242 }
243
244 /// The number of rows in the matrix.
245 #[inline]
246 #[must_use]
247 pub fn nrows(&self) -> usize {
248 self.cs.pattern().major_dim()
249 }
250
251 /// The number of columns in the matrix.
252 #[inline]
253 #[must_use]
254 pub fn ncols(&self) -> usize {
255 self.cs.pattern().minor_dim()
256 }
257
258 /// The number of non-zeros in the matrix.
259 ///
260 /// Note that this corresponds to the number of explicitly stored entries, *not* the actual
261 /// number of algebraically zero entries in the matrix. Explicitly stored entries can still
262 /// be zero. Corresponds to the number of entries in the sparsity pattern.
263 #[inline]
264 #[must_use]
265 pub fn nnz(&self) -> usize {
266 self.cs.pattern().nnz()
267 }
268
269 /// The row offsets defining part of the CSR format.
270 #[inline]
271 #[must_use]
272 pub fn row_offsets(&self) -> &[usize] {
273 let (offsets, _, _) = self.cs.cs_data();
274 offsets
275 }
276
277 /// The column indices defining part of the CSR format.
278 #[inline]
279 #[must_use]
280 pub fn col_indices(&self) -> &[usize] {
281 let (_, indices, _) = self.cs.cs_data();
282 indices
283 }
284
285 /// The non-zero values defining part of the CSR format.
286 #[inline]
287 #[must_use]
288 pub fn values(&self) -> &[T] {
289 self.cs.values()
290 }
291
292 /// Mutable access to the non-zero values.
293 #[inline]
294 pub fn values_mut(&mut self) -> &mut [T] {
295 self.cs.values_mut()
296 }
297
298 /// An iterator over non-zero triplets (i, j, v).
299 ///
300 /// The iteration happens in row-major fashion, meaning that i increases monotonically,
301 /// and j increases monotonically within each row.
302 ///
303 /// Examples
304 /// --------
305 /// ```
306 /// # use nalgebra_sparse::csr::CsrMatrix;
307 /// let row_offsets = vec![0, 2, 3, 4];
308 /// let col_indices = vec![0, 2, 1, 0];
309 /// let values = vec![1, 2, 3, 4];
310 /// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
311 /// .unwrap();
312 ///
313 /// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
314 /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
315 /// ```
316 pub fn triplet_iter(&self) -> CsrTripletIter<'_, T> {
317 CsrTripletIter {
318 pattern_iter: self.pattern().entries(),
319 values_iter: self.values().iter(),
320 }
321 }
322
323 /// A mutable iterator over non-zero triplets (i, j, v).
324 ///
325 /// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
326 ///
327 /// Examples
328 /// --------
329 /// ```
330 /// # use nalgebra_sparse::csr::CsrMatrix;
331 /// # let row_offsets = vec![0, 2, 3, 4];
332 /// # let col_indices = vec![0, 2, 1, 0];
333 /// # let values = vec![1, 2, 3, 4];
334 /// // Using the same data as in the `triplet_iter` example
335 /// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
336 /// .unwrap();
337 ///
338 /// // Zero out lower-triangular terms
339 /// csr.triplet_iter_mut()
340 /// .filter(|(i, j, _)| j < i)
341 /// .for_each(|(_, _, v)| *v = 0);
342 ///
343 /// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
344 /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
345 /// ```
346 pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<'_, T> {
347 let (pattern, values) = self.cs.pattern_and_values_mut();
348 CsrTripletIterMut {
349 pattern_iter: pattern.entries(),
350 values_mut_iter: values.iter_mut(),
351 }
352 }
353
354 /// Return the row at the given row index.
355 ///
356 /// Panics
357 /// ------
358 /// Panics if row index is out of bounds.
359 #[inline]
360 #[must_use]
361 pub fn row(&self, index: usize) -> CsrRow<'_, T> {
362 self.get_row(index).expect("Row index must be in bounds")
363 }
364
365 /// Mutable row access for the given row index.
366 ///
367 /// Panics
368 /// ------
369 /// Panics if row index is out of bounds.
370 #[inline]
371 pub fn row_mut(&mut self, index: usize) -> CsrRowMut<'_, T> {
372 self.get_row_mut(index)
373 .expect("Row index must be in bounds")
374 }
375
376 /// Return the row at the given row index, or `None` if out of bounds.
377 #[inline]
378 #[must_use]
379 pub fn get_row(&self, index: usize) -> Option<CsrRow<'_, T>> {
380 self.cs.get_lane(index).map(|lane| CsrRow { lane })
381 }
382
383 /// Mutable row access for the given row index, or `None` if out of bounds.
384 #[inline]
385 #[must_use]
386 pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<'_, T>> {
387 self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
388 }
389
390 /// An iterator over rows in the matrix.
391 pub fn row_iter(&self) -> CsrRowIter<'_, T> {
392 CsrRowIter {
393 lane_iter: CsLaneIter::new(self.pattern(), self.values()),
394 }
395 }
396
397 /// A mutable iterator over rows in the matrix.
398 pub fn row_iter_mut(&mut self) -> CsrRowIterMut<'_, T> {
399 let (pattern, values) = self.cs.pattern_and_values_mut();
400 CsrRowIterMut {
401 lane_iter: CsLaneIterMut::new(pattern, values),
402 }
403 }
404
405 /// Disassembles the CSR matrix into its underlying offset, index and value arrays.
406 ///
407 /// If the matrix contains the sole reference to the sparsity pattern,
408 /// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
409 ///
410 /// Examples
411 /// --------
412 ///
413 /// ```
414 /// # use nalgebra_sparse::csr::CsrMatrix;
415 /// let row_offsets = vec![0, 2, 3, 4];
416 /// let col_indices = vec![0, 2, 1, 0];
417 /// let values = vec![1, 2, 3, 4];
418 /// let mut csr = CsrMatrix::try_from_csr_data(
419 /// 3,
420 /// 4,
421 /// row_offsets.clone(),
422 /// col_indices.clone(),
423 /// values.clone())
424 /// .unwrap();
425 /// let (row_offsets2, col_indices2, values2) = csr.disassemble();
426 /// assert_eq!(row_offsets2, row_offsets);
427 /// assert_eq!(col_indices2, col_indices);
428 /// assert_eq!(values2, values);
429 /// ```
430 pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
431 self.cs.disassemble()
432 }
433
434 /// Returns the sparsity pattern and values associated with this matrix.
435 pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
436 self.cs.into_pattern_and_values()
437 }
438
439 /// Returns a reference to the sparsity pattern and a mutable reference to the values.
440 #[inline]
441 pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
442 self.cs.pattern_and_values_mut()
443 }
444
445 /// Returns a reference to the underlying sparsity pattern.
446 #[must_use]
447 pub fn pattern(&self) -> &SparsityPattern {
448 self.cs.pattern()
449 }
450
451 /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
452 ///
453 /// This operation does not touch the CSR data, and is effectively a no-op.
454 pub fn transpose_as_csc(self) -> CscMatrix<T> {
455 let (pattern, values) = self.cs.take_pattern_and_values();
456 CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
457 }
458
459 /// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
460 ///
461 /// Each call to this function incurs the cost of a binary search among the explicitly
462 /// stored column entries for the given row.
463 #[must_use]
464 pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<'_, T>> {
465 self.cs.get_entry(row_index, col_index)
466 }
467
468 /// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
469 /// of bounds.
470 ///
471 /// Each call to this function incurs the cost of a binary search among the explicitly
472 /// stored column entries for the given row.
473 pub fn get_entry_mut(
474 &mut self,
475 row_index: usize,
476 col_index: usize,
477 ) -> Option<SparseEntryMut<'_, T>> {
478 self.cs.get_entry_mut(row_index, col_index)
479 }
480
481 /// Returns an entry for the given row/col indices.
482 ///
483 /// Same as `get_entry`, except that it directly panics upon encountering row/col indices
484 /// out of bounds.
485 ///
486 /// Panics
487 /// ------
488 /// Panics if `row_index` or `col_index` is out of bounds.
489 #[must_use]
490 pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<'_, T> {
491 self.get_entry(row_index, col_index)
492 .expect("Out of bounds matrix indices encountered")
493 }
494
495 /// Returns a mutable entry for the given row/col indices.
496 ///
497 /// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
498 /// out of bounds.
499 ///
500 /// Panics
501 /// ------
502 /// Panics if `row_index` or `col_index` is out of bounds.
503 pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<'_, T> {
504 self.get_entry_mut(row_index, col_index)
505 .expect("Out of bounds matrix indices encountered")
506 }
507
508 /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
509 #[must_use]
510 pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
511 self.cs.cs_data()
512 }
513
514 /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
515 /// where the `values` array is mutable.
516 pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
517 self.cs.cs_data_mut()
518 }
519
520 /// Creates a sparse matrix that contains only the explicit entries decided by the
521 /// given predicate.
522 #[must_use]
523 pub fn filter<P>(&self, predicate: P) -> Self
524 where
525 T: Clone,
526 P: Fn(usize, usize, &T) -> bool,
527 {
528 Self {
529 cs: self
530 .cs
531 .filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
532 }
533 }
534
535 /// Returns a new matrix representing the upper triangular part of this matrix.
536 ///
537 /// The result includes the diagonal of the matrix.
538 #[must_use]
539 pub fn upper_triangle(&self) -> Self
540 where
541 T: Clone,
542 {
543 self.filter(|i, j, _| i <= j)
544 }
545
546 /// Returns a new matrix representing the lower triangular part of this matrix.
547 ///
548 /// The result includes the diagonal of the matrix.
549 #[must_use]
550 pub fn lower_triangle(&self) -> Self
551 where
552 T: Clone,
553 {
554 self.filter(|i, j, _| i >= j)
555 }
556
557 /// Returns the diagonal of the matrix as a sparse matrix.
558 #[must_use]
559 pub fn diagonal_as_csr(&self) -> Self
560 where
561 T: Clone,
562 {
563 Self {
564 cs: self.cs.diagonal_as_matrix(),
565 }
566 }
567
568 /// Compute the transpose of the matrix.
569 #[must_use]
570 pub fn transpose(&self) -> CsrMatrix<T>
571 where
572 T: Scalar,
573 {
574 CscMatrix::from(self).transpose_as_csr()
575 }
576}
577
578impl<T> Default for CsrMatrix<T> {
579 fn default() -> Self {
580 Self {
581 cs: Default::default(),
582 }
583 }
584}
585
586/// Convert pattern format errors into more meaningful CSR-specific errors.
587///
588/// This ensures that the terminology is consistent: we are talking about rows and columns,
589/// not lanes, major and minor dimensions.
590fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
591 use SparseFormatError as E;
592 use SparseFormatErrorKind as K;
593 use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
594 use SparsityPatternFormatError::*;
595
596 match err {
597 InvalidOffsetArrayLength => E::from_kind_and_msg(
598 K::InvalidStructure,
599 "Length of row offset array is not equal to nrows + 1.",
600 ),
601 InvalidOffsetFirstLast => E::from_kind_and_msg(
602 K::InvalidStructure,
603 "First or last row offset is inconsistent with format specification.",
604 ),
605 NonmonotonicOffsets => E::from_kind_and_msg(
606 K::InvalidStructure,
607 "Row offsets are not monotonically increasing.",
608 ),
609 NonmonotonicMinorIndices => E::from_kind_and_msg(
610 K::InvalidStructure,
611 "Column indices are not monotonically increasing (sorted) within each row.",
612 ),
613 MinorIndexOutOfBounds => {
614 E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
615 }
616 PatternDuplicateEntry => {
617 E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
618 }
619 }
620}
621
/// Iterator type for iterating over triplets in a CSR matrix.
///
/// Yields `(row, col, &value)` for every explicitly stored entry, in row-major order.
#[derive(Debug)]
pub struct CsrTripletIter<'a, T> {
    // Produces the `(row, col)` coordinates of the stored entries.
    pattern_iter: SparsityPatternIter<'a>,
    // Produces the stored values; advanced in lockstep with `pattern_iter`.
    values_iter: Iter<'a, T>,
}
628
629impl<'a, T> Clone for CsrTripletIter<'a, T> {
630 fn clone(&self) -> Self {
631 CsrTripletIter {
632 pattern_iter: self.pattern_iter.clone(),
633 values_iter: self.values_iter.clone(),
634 }
635 }
636}
637
638impl<'a, T: Clone> CsrTripletIter<'a, T> {
639 /// Adapts the triplet iterator to return owned values.
640 ///
641 /// The triplet iterator returns references to the values. This method adapts the iterator
642 /// so that the values are cloned.
643 #[inline]
644 pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
645 self.map(|(i, j, v)| (i, j, v.clone()))
646 }
647}
648
649impl<'a, T> Iterator for CsrTripletIter<'a, T> {
650 type Item = (usize, usize, &'a T);
651
652 fn next(&mut self) -> Option<Self::Item> {
653 let next_entry = self.pattern_iter.next();
654 let next_value = self.values_iter.next();
655
656 match (next_entry, next_value) {
657 (Some((i, j)), Some(v)) => Some((i, j, v)),
658 _ => None,
659 }
660 }
661}
662
/// Iterator type for mutably iterating over triplets in a CSR matrix.
///
/// Yields `(row, col, &mut value)` for every explicitly stored entry, in row-major order.
/// Only the values can be modified; the sparsity pattern is read-only.
#[derive(Debug)]
pub struct CsrTripletIterMut<'a, T> {
    // Produces the `(row, col)` coordinates of the stored entries.
    pattern_iter: SparsityPatternIter<'a>,
    // Produces mutable references to the stored values; advanced in lockstep with `pattern_iter`.
    values_mut_iter: IterMut<'a, T>,
}
669
670impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
671 type Item = (usize, usize, &'a mut T);
672
673 #[inline]
674 fn next(&mut self) -> Option<Self::Item> {
675 let next_entry = self.pattern_iter.next();
676 let next_value = self.values_mut_iter.next();
677
678 match (next_entry, next_value) {
679 (Some((i, j)), Some(v)) => Some((i, j, v)),
680 _ => None,
681 }
682 }
683}
684
/// An immutable representation of a row in a CSR matrix.
///
/// Borrows the row's column indices and values from the matrix it was obtained from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrRow<'a, T> {
    // The row is a lane of the shared compressed-storage backend (rows are the major lanes).
    lane: CsLane<'a, T>,
}
690
/// A mutable representation of a row in a CSR matrix.
///
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
/// to the row cannot be modified.
#[derive(Debug, PartialEq, Eq)]
pub struct CsrRowMut<'a, T> {
    // Mutable view of one major lane (a row) of the compressed-storage backend.
    lane: CsLaneMut<'a, T>,
}
699
700/// Implement the methods common to both CsrRow and CsrRowMut
701macro_rules! impl_csr_row_common_methods {
702 ($name:ty) => {
703 impl<'a, T> $name {
704 /// The number of global columns in the row.
705 #[inline]
706 #[must_use]
707 pub fn ncols(&self) -> usize {
708 self.lane.minor_dim()
709 }
710
711 /// The number of non-zeros in this row.
712 #[inline]
713 #[must_use]
714 pub fn nnz(&self) -> usize {
715 self.lane.nnz()
716 }
717
718 /// The column indices corresponding to explicitly stored entries in this row.
719 #[inline]
720 #[must_use]
721 pub fn col_indices(&self) -> &[usize] {
722 self.lane.minor_indices()
723 }
724
725 /// The values corresponding to explicitly stored entries in this row.
726 #[inline]
727 #[must_use]
728 pub fn values(&self) -> &[T] {
729 self.lane.values()
730 }
731
732 /// Returns an entry for the given global column index.
733 ///
734 /// Each call to this function incurs the cost of a binary search among the explicitly
735 /// stored column entries.
736 #[inline]
737 #[must_use]
738 pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
739 self.lane.get_entry(global_col_index)
740 }
741 }
742 };
743}
744
745impl_csr_row_common_methods!(CsrRow<'a, T>);
746impl_csr_row_common_methods!(CsrRowMut<'a, T>);
747
748impl<'a, T> CsrRowMut<'a, T> {
749 /// Mutable access to the values corresponding to explicitly stored entries in this row.
750 #[inline]
751 pub fn values_mut(&mut self) -> &mut [T] {
752 self.lane.values_mut()
753 }
754
755 /// Provides simultaneous access to column indices and mutable values corresponding to the
756 /// explicitly stored entries in this row.
757 ///
758 /// This method primarily facilitates low-level access for methods that process data stored
759 /// in CSR format directly.
760 #[inline]
761 pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
762 self.lane.indices_and_values_mut()
763 }
764
765 /// Returns a mutable entry for the given global column index.
766 #[inline]
767 #[must_use]
768 pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
769 self.lane.get_entry_mut(global_col_index)
770 }
771}
772
773/// Row iterator for [`CsrMatrix`].
774pub struct CsrRowIter<'a, T> {
775 lane_iter: CsLaneIter<'a, T>,
776}
777
778impl<'a, T> Iterator for CsrRowIter<'a, T> {
779 type Item = CsrRow<'a, T>;
780
781 fn next(&mut self) -> Option<Self::Item> {
782 self.lane_iter.next().map(|lane| CsrRow { lane })
783 }
784}
785
786/// Mutable row iterator for [`CsrMatrix`].
787pub struct CsrRowIterMut<'a, T> {
788 lane_iter: CsLaneIterMut<'a, T>,
789}
790
791impl<'a, T> Iterator for CsrRowIterMut<'a, T>
792where
793 T: 'a,
794{
795 type Item = CsrRowMut<'a, T>;
796
797 fn next(&mut self) -> Option<Self::Item> {
798 self.lane_iter.next().map(|lane| CsrRowMut { lane })
799 }
800}