Skip to main content

ferrolearn_sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format.
2//!
3//! [`CsrMatrix<T>`] is a newtype wrapper around [`sprs::CsMat<T>`] in CSR
4//! storage. CSR matrices are efficient for row-wise operations, matrix-vector
5//! products, and row slicing.
6
7use std::ops::{Add, AddAssign, Mul, MulAssign};
8
9use ferrolearn_core::{Dataset, FerroError};
10use ndarray::{Array1, Array2, ArrayView2};
11use num_traits::{Float, Zero};
12use sprs::CsMat;
13
14use crate::coo::CooMatrix;
15use crate::csc::CscMatrix;
16
17/// Compressed Sparse Row (CSR) sparse matrix.
18///
19/// Stores non-zero entries in row-major order using three arrays: `indptr`
20/// (row pointer array of length `n_rows + 1`), `indices` (column indices of
21/// each non-zero), and `data` (values of each non-zero).
22///
23/// # Type Parameter
24///
25/// `T` — the scalar element type. No bounds are required for basic structural
26/// methods; arithmetic methods impose their own bounds.
27///
28/// # Dataset Trait
29///
30/// Implements [`ferrolearn_core::Dataset`] when `T: Float + Send + Sync + 'static`,
31/// reporting `n_samples() == n_rows()`, `n_features() == n_cols()`, and
32/// `is_sparse() == true`.
33#[derive(Debug, Clone)]
34pub struct CsrMatrix<T> {
35    inner: CsMat<T>,
36}
37
38impl<T> CsrMatrix<T>
39where
40    T: Clone,
41{
42    /// Construct a CSR matrix from raw components.
43    ///
44    /// # Arguments
45    ///
46    /// * `n_rows` — number of rows.
47    /// * `n_cols` — number of columns.
48    /// * `indptr` — row pointer array of length `n_rows + 1`.
49    /// * `indices` — column index of each non-zero entry.
50    /// * `data` — value of each non-zero entry.
51    ///
52    /// # Errors
53    ///
54    /// Returns [`FerroError::InvalidParameter`] if the data is structurally
55    /// invalid (wrong lengths, out-of-bound indices, unsorted inner indices).
56    pub fn new(
57        n_rows: usize,
58        n_cols: usize,
59        indptr: Vec<usize>,
60        indices: Vec<usize>,
61        data: Vec<T>,
62    ) -> Result<Self, FerroError> {
63        CsMat::try_new((n_rows, n_cols), indptr, indices, data)
64            .map(|inner| Self { inner })
65            .map_err(|(_, _, _, err)| FerroError::InvalidParameter {
66                name: "CsrMatrix raw components".into(),
67                reason: err.to_string(),
68            })
69    }
70
71    /// Returns the number of rows.
72    pub fn n_rows(&self) -> usize {
73        self.inner.rows()
74    }
75
76    /// Returns the number of columns.
77    pub fn n_cols(&self) -> usize {
78        self.inner.cols()
79    }
80
81    /// Returns the number of stored non-zero entries.
82    pub fn nnz(&self) -> usize {
83        self.inner.nnz()
84    }
85
86    /// Returns a reference to the underlying [`sprs::CsMat<T>`].
87    pub fn inner(&self) -> &CsMat<T> {
88        &self.inner
89    }
90
91    /// Consume this matrix and return the underlying [`sprs::CsMat<T>`].
92    pub fn into_inner(self) -> CsMat<T> {
93        self.inner
94    }
95
96    /// Construct a [`CsrMatrix`] from a [`CooMatrix`] by converting to CSR.
97    ///
98    /// Duplicate entries at the same position are summed.
99    ///
100    /// # Errors
101    ///
102    /// This conversion is always successful for structurally valid inputs.
103    pub fn from_coo(coo: &CooMatrix<T>) -> Result<Self, FerroError>
104    where
105        T: Clone + Add<Output = T> + 'static,
106    {
107        let inner: CsMat<T> = coo.inner().to_csr();
108        Ok(Self { inner })
109    }
110
111    /// Construct a [`CsrMatrix`] from a [`CscMatrix`].
112    ///
113    /// # Errors
114    ///
115    /// This conversion is always successful.
116    pub fn from_csc(csc: &CscMatrix<T>) -> Result<Self, FerroError>
117    where
118        T: Clone + Default + 'static,
119    {
120        let inner = csc.inner().to_csr();
121        Ok(Self { inner })
122    }
123
124    /// Convert to [`CscMatrix`].
125    pub fn to_csc(&self) -> CscMatrix<T>
126    where
127        T: Clone + Default + 'static,
128    {
129        CscMatrix::from_inner(self.inner.to_csc())
130    }
131
132    /// Convert to [`CooMatrix`].
133    ///
134    /// Each non-zero becomes one triplet entry.
135    pub fn to_coo(&self) -> CooMatrix<T> {
136        let mut coo = CooMatrix::with_capacity(self.n_rows(), self.n_cols(), self.nnz());
137        for (val, (r, c)) in self.inner.iter() {
138            // indices come from a valid matrix so push is infallible here
139            let _ = coo.push(r, c, val.clone());
140        }
141        coo
142    }
143
144    /// Convert this sparse matrix to a dense [`Array2<T>`].
145    pub fn to_dense(&self) -> Array2<T>
146    where
147        T: Clone + Zero + 'static,
148    {
149        self.inner.to_dense()
150    }
151
152    /// Construct a [`CsrMatrix`] from a dense [`Array2<T>`], dropping entries
153    /// whose absolute value is less than or equal to `epsilon`.
154    ///
155    /// Entries `v` where `|v| <= epsilon` are treated as structural zeros.
156    /// For integer types, pass `epsilon = 0`.
157    pub fn from_dense(dense: &ArrayView2<'_, T>, epsilon: T) -> Self
158    where
159        T: Copy + Zero + PartialOrd + num_traits::Signed + 'static,
160    {
161        let inner = CsMat::csr_from_dense(dense.view(), epsilon);
162        Self { inner }
163    }
164
165    /// Return a new CSR matrix containing only the rows in `start..end`.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`FerroError::InvalidParameter`] if `start > end` or
170    /// `end > n_rows()`.
171    pub fn row_slice(&self, start: usize, end: usize) -> Result<CsrMatrix<T>, FerroError>
172    where
173        T: Clone + Default + 'static,
174    {
175        if start > end {
176            return Err(FerroError::InvalidParameter {
177                name: "row_slice range".into(),
178                reason: format!("start ({start}) must be <= end ({end})"),
179            });
180        }
181        if end > self.n_rows() {
182            return Err(FerroError::InvalidParameter {
183                name: "row_slice range".into(),
184                reason: format!("end ({end}) exceeds n_rows ({})", self.n_rows()),
185            });
186        }
187        let view = self.inner.slice_outer(start..end);
188        Ok(Self {
189            inner: view.to_owned(),
190        })
191    }
192
193    /// Scalar multiplication in-place: multiplies every non-zero by `scalar`.
194    ///
195    /// Requires `T: for<'r> MulAssign<&'r T>`, which is satisfied by all
196    /// primitive numeric types.
197    pub fn scale(&mut self, scalar: T)
198    where
199        for<'r> T: MulAssign<&'r T>,
200    {
201        self.inner.scale(scalar);
202    }
203
204    /// Scalar multiplication returning a new matrix.
205    pub fn mul_scalar(&self, scalar: T) -> CsrMatrix<T>
206    where
207        T: Copy + Mul<Output = T> + Zero + 'static,
208    {
209        let new_inner = self.inner.map(|&v| v * scalar);
210        Self { inner: new_inner }
211    }
212
213    /// Element-wise addition of two CSR matrices with the same shape.
214    ///
215    /// # Errors
216    ///
217    /// Returns [`FerroError::ShapeMismatch`] if the matrices have different shapes.
218    pub fn add(&self, rhs: &CsrMatrix<T>) -> Result<CsrMatrix<T>, FerroError>
219    where
220        T: Zero + Default + Clone + 'static,
221        for<'r> &'r T: Add<&'r T, Output = T>,
222    {
223        if self.n_rows() != rhs.n_rows() || self.n_cols() != rhs.n_cols() {
224            return Err(FerroError::ShapeMismatch {
225                expected: vec![self.n_rows(), self.n_cols()],
226                actual: vec![rhs.n_rows(), rhs.n_cols()],
227                context: "CsrMatrix::add".into(),
228            });
229        }
230        let result = &self.inner + &rhs.inner;
231        Ok(Self { inner: result })
232    }
233
234    /// Sparse matrix-dense vector product: computes `self * rhs`.
235    ///
236    /// # Errors
237    ///
238    /// Returns [`FerroError::ShapeMismatch`] if `rhs.len() != n_cols()`.
239    pub fn mul_vec(&self, rhs: &Array1<T>) -> Result<Array1<T>, FerroError>
240    where
241        T: Clone + Zero + 'static,
242        for<'r> &'r T: Mul<Output = T>,
243        T: AddAssign,
244    {
245        if rhs.len() != self.n_cols() {
246            return Err(FerroError::ShapeMismatch {
247                expected: vec![self.n_cols()],
248                actual: vec![rhs.len()],
249                context: "CsrMatrix::mul_vec".into(),
250            });
251        }
252        let result = &self.inner * rhs;
253        Ok(result)
254    }
255}
256
257impl<T> CsrMatrix<T>
258where
259    T: Float + Send + Sync + num_traits::Signed + 'static,
260{
261    /// Construct a [`CsrMatrix`] from a dense [`Array2<T>`], treating entries
262    /// with absolute value at or below `T::epsilon()` as structural zeros.
263    pub fn from_dense_float(dense: &ArrayView2<'_, T>) -> Self {
264        CsrMatrix::from_dense(dense, T::epsilon())
265    }
266}
267
268/// Implements [`Dataset`] so that `CsrMatrix<F>` can be passed to any
269/// ferrolearn algorithm that accepts a dataset.
270///
271/// - `n_samples()` — number of rows (one sample per row).
272/// - `n_features()` — number of columns (one feature per column).
273/// - `is_sparse()` — always `true`.
274impl<F> Dataset for CsrMatrix<F>
275where
276    F: Float + Send + Sync + 'static,
277{
278    fn n_samples(&self) -> usize {
279        self.n_rows()
280    }
281
282    fn n_features(&self) -> usize {
283        self.n_cols()
284    }
285
286    fn is_sparse(&self) -> bool {
287        true
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    fn sample_csr() -> CsrMatrix<f64> {
        // 3x3 sparse matrix:
        // [1 0 2]
        // [0 3 0]
        // [4 0 5]
        CsrMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    #[test]
    fn test_new_valid() {
        let m = sample_csr();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_new_invalid() {
        // Wrong indptr length (needs n_rows+1 = 3, not 2)
        let res = CsrMatrix::<f64>::new(2, 2, vec![0, 1], vec![0], vec![1.0]);
        assert!(res.is_err());
    }

    #[test]
    fn test_to_dense() {
        let m = sample_csr();
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[0, 1]], 0.0);
        assert_abs_diff_eq!(d[[0, 2]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
        assert_abs_diff_eq!(d[[2, 0]], 4.0);
        assert_abs_diff_eq!(d[[2, 2]], 5.0);
    }

    #[test]
    fn test_from_dense() {
        let dense = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let m = CsrMatrix::from_dense(&dense.view(), 0.0);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 1]], 2.0);
    }

    #[test]
    fn test_from_dense_epsilon_boundary_is_inclusive() {
        // Documented contract: entries with |v| <= epsilon are dropped.
        // Values are exact binary fractions, so the boundary comparison is exact.
        let dense = array![[0.25_f64, -0.125], [-1.0, 0.5]];
        let m = CsrMatrix::from_dense(&dense.view(), 0.25);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 0.0);
        assert_abs_diff_eq!(back[[0, 1]], 0.0);
        assert_abs_diff_eq!(back[[1, 0]], -1.0);
        assert_abs_diff_eq!(back[[1, 1]], 0.5);
    }

    #[test]
    fn test_from_coo_roundtrip() {
        let mut coo: CooMatrix<f64> = CooMatrix::new(3, 3);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 2, 4.0).unwrap();
        coo.push(2, 1, 7.0).unwrap();
        let csr = CsrMatrix::from_coo(&coo).unwrap();
        let dense = csr.to_dense();
        assert_abs_diff_eq!(dense[[0, 0]], 1.0);
        assert_abs_diff_eq!(dense[[1, 2]], 4.0);
        assert_abs_diff_eq!(dense[[2, 1]], 7.0);
        assert_abs_diff_eq!(dense[[0, 1]], 0.0);
    }

    #[test]
    fn test_from_coo_sums_duplicates() {
        // Documented contract of `from_coo`: duplicates at one position are summed.
        let mut coo: CooMatrix<f64> = CooMatrix::new(2, 2);
        coo.push(0, 1, 1.5).unwrap();
        coo.push(0, 1, 2.5).unwrap();
        let csr = CsrMatrix::from_coo(&coo).unwrap();
        assert_eq!(csr.nnz(), 1);
        assert_abs_diff_eq!(csr.to_dense()[[0, 1]], 4.0);
    }

    #[test]
    fn test_to_coo_roundtrip() {
        let csr = sample_csr();
        let coo = csr.to_coo();
        let back = CsrMatrix::from_coo(&coo).unwrap();
        assert_eq!(back.nnz(), csr.nnz());
        assert_eq!(back.to_dense(), csr.to_dense());
    }

    #[test]
    fn test_csr_csc_roundtrip() {
        let csr = sample_csr();
        let csc = csr.to_csc();
        let back = CsrMatrix::from_csc(&csc).unwrap();
        assert_eq!(back.to_dense(), csr.to_dense());
    }

    #[test]
    fn test_row_slice() {
        let m = sample_csr();
        let sliced = m.row_slice(0, 2).unwrap();
        assert_eq!(sliced.n_rows(), 2);
        assert_eq!(sliced.n_cols(), 3);
        let d = sliced.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
    }

    #[test]
    fn test_row_slice_middle() {
        // A non-zero `start` makes the sliced row pointers begin at an
        // offset; the owned result must still describe exactly rows 1..3.
        let m = sample_csr();
        let sliced = m.row_slice(1, 3).unwrap();
        assert_eq!(sliced.n_rows(), 2);
        assert_eq!(sliced.n_cols(), 3);
        assert_eq!(sliced.nnz(), 3);
        let expected = array![[0.0_f64, 3.0, 0.0], [4.0, 0.0, 5.0]];
        assert_eq!(sliced.to_dense(), expected);
    }

    #[test]
    fn test_row_slice_empty() {
        let m = sample_csr();
        let sliced = m.row_slice(1, 1).unwrap();
        assert_eq!(sliced.n_rows(), 0);
    }

    #[test]
    fn test_row_slice_invalid() {
        let m = sample_csr();
        assert!(m.row_slice(2, 1).is_err());
        assert!(m.row_slice(0, 4).is_err());
    }

    #[test]
    fn test_mul_scalar() {
        let m = sample_csr();
        let m2 = m.mul_scalar(2.0);
        let d = m2.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
        assert_abs_diff_eq!(d[[2, 2]], 10.0);
        // The original matrix is left untouched.
        assert_abs_diff_eq!(m.to_dense()[[0, 0]], 1.0);
    }

    #[test]
    fn test_scale_in_place() {
        let mut m = sample_csr();
        m.scale(3.0);
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 3.0);
        assert_abs_diff_eq!(d[[2, 2]], 15.0);
    }

    #[test]
    fn test_add() {
        let m = sample_csr();
        let sum = m.add(&m).unwrap();
        let d = sum.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
    }

    #[test]
    fn test_add_shape_mismatch() {
        let m1 = sample_csr();
        let m2 = CsrMatrix::new(2, 3, vec![0, 0, 0], vec![], vec![]).unwrap();
        assert!(m1.add(&m2).is_err());
    }

    #[test]
    fn test_mul_vec() {
        let m = sample_csr();
        // [1 0 2]   [1]   [7]
        // [0 3 0] * [2] = [6]
        // [4 0 5]   [3]   [19]
        let v = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        let result = m.mul_vec(&v).unwrap();
        assert_abs_diff_eq!(result[0], 7.0);
        assert_abs_diff_eq!(result[1], 6.0);
        assert_abs_diff_eq!(result[2], 19.0);
    }

    #[test]
    fn test_mul_vec_shape_mismatch() {
        let m = sample_csr();
        let v = Array1::from(vec![1.0_f64, 2.0]);
        assert!(m.mul_vec(&v).is_err());
    }

    #[test]
    fn test_dataset_trait() {
        let m = sample_csr();
        assert_eq!(m.n_samples(), 3);
        assert_eq!(m.n_features(), 3);
        assert!(m.is_sparse());
    }

    #[test]
    fn test_dataset_trait_object() {
        // `Dataset` is already in scope through `use super::*`.
        let m: CsrMatrix<f64> = sample_csr();
        let d: &dyn Dataset = &m;
        assert_eq!(d.n_samples(), 3);
        assert!(d.is_sparse());
    }

    #[test]
    fn test_from_dense_float() {
        let dense = array![[1.0_f64, 0.0, 0.0], [0.0, 0.0, 2.0]];
        let csr = CsrMatrix::from_dense_float(&dense.view());
        assert_eq!(csr.nnz(), 2);
        let back = csr.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 2]], 2.0);
    }
}
487
/// Kani proof harnesses for CsrMatrix structural invariants.
///
/// These proofs verify that after construction via `new()`, `from_coo()`, and
/// `add()`, the underlying CSR representation satisfies all structural
/// invariants:
///
/// - `indptr.len() == n_rows + 1`
/// - `indptr` is monotonically non-decreasing
/// - All column indices are less than `n_cols`
/// - `indices.len() == data.len()`
///
/// All proofs use small symbolic bounds (at most 3 rows/cols) because sparse
/// matrix verification is computationally expensive for Kani.
#[cfg(kani)]
mod kani_proofs {
    use super::*;
    use crate::coo::CooMatrix;

    /// Maximum dimension for symbolic exploration.
    const MAX_DIM: usize = 3;
    /// Maximum number of non-zero entries for symbolic exploration.
    const MAX_NNZ: usize = 4;

    /// Helper: assert all CSR structural invariants on the inner `CsMat`.
    ///
    /// The `T: Clone` bound is required because `inner()`, `n_rows()` and
    /// `n_cols()` are defined in an `impl` block bounded by `T: Clone`;
    /// without it this helper fails to compile. Every element type used by
    /// the harnesses (`i32`, `f64`) satisfies the bound.
    fn assert_csr_invariants<T>(m: &CsrMatrix<T>)
    where
        T: Clone,
    {
        let inner = m.inner();

        // Invariant 1: indptr length == n_rows + 1
        let indptr = inner.indptr();
        let indptr_raw = indptr.raw_storage();
        assert!(indptr_raw.len() == m.n_rows() + 1);

        // Invariant 2: indptr is monotonically non-decreasing
        for i in 0..m.n_rows() {
            assert!(indptr_raw[i] <= indptr_raw[i + 1]);
        }

        // Invariant 3: all column indices < n_cols
        let indices = inner.indices();
        for &col_idx in indices {
            assert!(col_idx < m.n_cols());
        }

        // Invariant 4: indices.len() == data.len()
        assert!(inner.indices().len() == inner.data().len());
    }

    /// Verify `indptr.len() == n_rows + 1` after `new()` with a symbolic
    /// empty matrix of arbitrary dimensions.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indptr_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // Build a valid empty CSR matrix
        let indptr = vec![0usize; n_rows + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            let inner_indptr = m.inner().indptr();
            assert!(inner_indptr.raw_storage().len() == n_rows + 1);
        }
    }

    /// Verify indptr monotonicity after `new()` with a symbolic single-entry
    /// matrix.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indptr_monotonic() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // Place a single non-zero at a symbolic valid position
        let row: usize = kani::any();
        let col: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);

        // Build indptr for a single entry at (row, col)
        let mut indptr = vec![0usize; n_rows + 1];
        for i in (row + 1)..=n_rows {
            indptr[i] = 1;
        }
        let indices = vec![col];
        let data = vec![42i32];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            // `indptr()` returns a temporary view, so copy the raw pointers
            // out before the view is dropped.
            let inner_indptr = m.inner().indptr().raw_storage().to_vec();
            for i in 0..m.n_rows() {
                assert!(inner_indptr[i] <= inner_indptr[i + 1]);
            }
        }
    }

    /// Verify all column indices < n_cols after `new()`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_column_indices_in_bounds() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let col: usize = kani::any();
        let row: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);

        let mut indptr = vec![0usize; n_rows + 1];
        for i in (row + 1)..=n_rows {
            indptr[i] = 1;
        }
        let indices = vec![col];
        let data = vec![1i32];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            for &c in m.inner().indices() {
                assert!(c < m.n_cols());
            }
        }
    }

    /// Verify `indices.len() == data.len()` after `new()`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indices_data_same_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // Try empty matrix
        let indptr = vec![0usize; n_rows + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            assert!(m.inner().indices().len() == m.inner().data().len());
        }
    }

    /// Verify that `new()` rejects inputs where indices and data have
    /// mismatched lengths.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_rejects_mismatched_lengths() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // indices has 1 element, data has 0 — must fail
        let indptr = vec![0usize; n_rows + 1];
        let indices = vec![0usize];
        let data: Vec<i32> = vec![];

        let result = CsrMatrix::new(n_rows, n_cols, indptr, indices, data);
        assert!(result.is_err());
    }

    /// Verify all structural invariants after `from_coo()` with symbolic
    /// entries.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_from_coo_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let mut coo = CooMatrix::<i32>::new(n_rows, n_cols);

        // Insert a symbolic number of entries (0 or 1)
        let do_insert: bool = kani::any();
        if do_insert {
            let row: usize = kani::any();
            let col: usize = kani::any();
            kani::assume(row < n_rows);
            kani::assume(col < n_cols);
            let _ = coo.push(row, col, 1i32);
        }

        if let Ok(csr) = CsrMatrix::from_coo(&coo) {
            assert_csr_invariants(&csr);
            assert!(csr.n_rows() == n_rows);
            assert!(csr.n_cols() == n_cols);
        }
    }

    /// Verify that `add()` preserves shape and structural invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_add_preserves_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // Build two valid empty CSR matrices of the same shape
        let indptr = vec![0usize; n_rows + 1];
        let a = CsrMatrix::<i32>::new(n_rows, n_cols, indptr.clone(), vec![], vec![]);
        let b = CsrMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let (Ok(a), Ok(b)) = (a, b) {
            if let Ok(sum) = a.add(&b) {
                // Shape is preserved
                assert!(sum.n_rows() == n_rows);
                assert!(sum.n_cols() == n_cols);
                // Structural invariants hold
                assert_csr_invariants(&sum);
            }
        }
    }

    /// Verify that `add()` with non-empty matrices preserves invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_add_nonempty_preserves_invariants() {
        // Fixed 2x2 matrices with one entry each in different positions
        let a = CsrMatrix::<i32>::new(2, 2, vec![0, 1, 1], vec![0], vec![1]);
        let b = CsrMatrix::<i32>::new(2, 2, vec![0, 0, 1], vec![1], vec![2]);

        if let (Ok(a), Ok(b)) = (a, b) {
            if let Ok(sum) = a.add(&b) {
                assert!(sum.n_rows() == 2);
                assert!(sum.n_cols() == 2);
                assert_csr_invariants(&sum);
            }
        }
    }

    /// Verify `mul_vec()` output has correct dimension and does not panic.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_output_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        // Empty matrix for tractable verification
        let indptr = vec![0usize; n_rows + 1];
        let m = CsrMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let Ok(m) = m {
            let v = Array1::<f64>::zeros(n_cols);
            if let Ok(result) = m.mul_vec(&v) {
                assert!(result.len() == n_rows);
            }
        }
    }

    /// Verify `mul_vec()` rejects vectors of wrong dimension.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_rejects_wrong_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let m = CsrMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let Ok(m) = m {
            let wrong_len: usize = kani::any();
            kani::assume(wrong_len <= MAX_DIM);
            kani::assume(wrong_len != n_cols);
            let v = Array1::<f64>::zeros(wrong_len);
            let result = m.mul_vec(&v);
            assert!(result.is_err());
        }
    }

    /// Verify `mul_vec()` with a non-empty matrix produces the correct
    /// output dimension and does not trigger any out-of-bounds access.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_nonempty_no_oob() {
        // 2x3 matrix with entries at (0,1) and (1,2)
        let m = CsrMatrix::<f64>::new(2, 3, vec![0, 1, 2], vec![1, 2], vec![3.0, 4.0]);
        if let Ok(m) = m {
            let v = Array1::from(vec![1.0, 2.0, 3.0]);
            if let Ok(result) = m.mul_vec(&v) {
                assert!(result.len() == 2);
            }
        }
    }
}