1use std::ops::{Add, AddAssign, Mul, MulAssign};
8
9use ferrolearn_core::{Dataset, FerroError};
10use ndarray::{Array1, Array2, ArrayView2};
11use num_traits::{Float, Zero};
12use sprs::CsMat;
13
14use crate::coo::CooMatrix;
15use crate::csc::CscMatrix;
16
/// A sparse matrix stored in Compressed Sparse Row (CSR) format.
///
/// This is a thin wrapper around [`sprs::CsMat`]. It exposes a small,
/// `FerroError`-based API (construction, format conversion, slicing and basic
/// arithmetic) and implements [`Dataset`], with rows as samples and columns as
/// features.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
    /// Underlying `sprs` matrix. Every constructor in this file builds it via a
    /// CSR-producing `sprs` API (`try_new`, `to_csr`, `csr_from_dense`, ...).
    /// NOTE(review): confirm that `&CsMat + &CsMat` (used by `add`) also keeps
    /// CSR storage.
    inner: CsMat<T>,
}
37
38impl<T> CsrMatrix<T>
39where
40 T: Clone,
41{
42 pub fn new(
57 n_rows: usize,
58 n_cols: usize,
59 indptr: Vec<usize>,
60 indices: Vec<usize>,
61 data: Vec<T>,
62 ) -> Result<Self, FerroError> {
63 CsMat::try_new((n_rows, n_cols), indptr, indices, data)
64 .map(|inner| Self { inner })
65 .map_err(|(_, _, _, err)| FerroError::InvalidParameter {
66 name: "CsrMatrix raw components".into(),
67 reason: err.to_string(),
68 })
69 }
70
71 pub fn n_rows(&self) -> usize {
73 self.inner.rows()
74 }
75
76 pub fn n_cols(&self) -> usize {
78 self.inner.cols()
79 }
80
81 pub fn nnz(&self) -> usize {
83 self.inner.nnz()
84 }
85
86 pub fn inner(&self) -> &CsMat<T> {
88 &self.inner
89 }
90
91 pub fn into_inner(self) -> CsMat<T> {
93 self.inner
94 }
95
96 pub fn from_coo(coo: &CooMatrix<T>) -> Result<Self, FerroError>
104 where
105 T: Clone + Add<Output = T> + 'static,
106 {
107 let inner: CsMat<T> = coo.inner().to_csr();
108 Ok(Self { inner })
109 }
110
111 pub fn from_csc(csc: &CscMatrix<T>) -> Result<Self, FerroError>
117 where
118 T: Clone + Default + 'static,
119 {
120 let inner = csc.inner().to_csr();
121 Ok(Self { inner })
122 }
123
124 pub fn to_csc(&self) -> CscMatrix<T>
126 where
127 T: Clone + Default + 'static,
128 {
129 CscMatrix::from_inner(self.inner.to_csc())
130 }
131
132 pub fn to_coo(&self) -> CooMatrix<T> {
136 let mut coo = CooMatrix::with_capacity(self.n_rows(), self.n_cols(), self.nnz());
137 for (val, (r, c)) in self.inner.iter() {
138 let _ = coo.push(r, c, val.clone());
140 }
141 coo
142 }
143
144 pub fn to_dense(&self) -> Array2<T>
146 where
147 T: Clone + Zero + 'static,
148 {
149 self.inner.to_dense()
150 }
151
152 pub fn from_dense(dense: &ArrayView2<'_, T>, epsilon: T) -> Self
158 where
159 T: Copy + Zero + PartialOrd + num_traits::Signed + 'static,
160 {
161 let inner = CsMat::csr_from_dense(dense.view(), epsilon);
162 Self { inner }
163 }
164
165 pub fn row_slice(&self, start: usize, end: usize) -> Result<CsrMatrix<T>, FerroError>
172 where
173 T: Clone + Default + 'static,
174 {
175 if start > end {
176 return Err(FerroError::InvalidParameter {
177 name: "row_slice range".into(),
178 reason: format!("start ({start}) must be <= end ({end})"),
179 });
180 }
181 if end > self.n_rows() {
182 return Err(FerroError::InvalidParameter {
183 name: "row_slice range".into(),
184 reason: format!("end ({end}) exceeds n_rows ({})", self.n_rows()),
185 });
186 }
187 let view = self.inner.slice_outer(start..end);
188 Ok(Self {
189 inner: view.to_owned(),
190 })
191 }
192
193 pub fn scale(&mut self, scalar: T)
198 where
199 for<'r> T: MulAssign<&'r T>,
200 {
201 self.inner.scale(scalar);
202 }
203
204 pub fn mul_scalar(&self, scalar: T) -> CsrMatrix<T>
206 where
207 T: Copy + Mul<Output = T> + Zero + 'static,
208 {
209 let new_inner = self.inner.map(|&v| v * scalar);
210 Self { inner: new_inner }
211 }
212
213 pub fn add(&self, rhs: &CsrMatrix<T>) -> Result<CsrMatrix<T>, FerroError>
219 where
220 T: Zero + Default + Clone + 'static,
221 for<'r> &'r T: Add<&'r T, Output = T>,
222 {
223 if self.n_rows() != rhs.n_rows() || self.n_cols() != rhs.n_cols() {
224 return Err(FerroError::ShapeMismatch {
225 expected: vec![self.n_rows(), self.n_cols()],
226 actual: vec![rhs.n_rows(), rhs.n_cols()],
227 context: "CsrMatrix::add".into(),
228 });
229 }
230 let result = &self.inner + &rhs.inner;
231 Ok(Self { inner: result })
232 }
233
234 pub fn mul_vec(&self, rhs: &Array1<T>) -> Result<Array1<T>, FerroError>
240 where
241 T: Clone + Zero + 'static,
242 for<'r> &'r T: Mul<Output = T>,
243 T: AddAssign,
244 {
245 if rhs.len() != self.n_cols() {
246 return Err(FerroError::ShapeMismatch {
247 expected: vec![self.n_cols()],
248 actual: vec![rhs.len()],
249 context: "CsrMatrix::mul_vec".into(),
250 });
251 }
252 let result = &self.inner * rhs;
253 Ok(result)
254 }
255}
256
257impl<T> CsrMatrix<T>
258where
259 T: Float + Send + Sync + num_traits::Signed + 'static,
260{
261 pub fn from_dense_float(dense: &ArrayView2<'_, T>) -> Self {
264 CsrMatrix::from_dense(dense, T::epsilon())
265 }
266}
267
268impl<F> Dataset for CsrMatrix<F>
275where
276 F: Float + Send + Sync + 'static,
277{
278 fn n_samples(&self) -> usize {
279 self.n_rows()
280 }
281
282 fn n_features(&self) -> usize {
283 self.n_cols()
284 }
285
286 fn is_sparse(&self) -> bool {
287 true
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// 3x3 fixture used by most tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    ///
    /// The off-diagonal entries (0, 2) and (2, 0) hold different values, so a
    /// row/column transposition in any conversion shows up in a whole-matrix
    /// comparison.
    fn sample_csr() -> CsrMatrix<f64> {
        CsrMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    /// Dense form of [`sample_csr`], used for whole-matrix comparisons.
    fn sample_dense() -> Array2<f64> {
        array![[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]]
    }

    #[test]
    fn test_new_valid() {
        let m = sample_csr();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_new_invalid() {
        // indptr must have n_rows + 1 = 3 entries; only 2 are given.
        let res = CsrMatrix::<f64>::new(2, 2, vec![0, 1], vec![0], vec![1.0]);
        assert!(res.is_err());
        assert!(matches!(res, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_to_dense() {
        let m = sample_csr();
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[0, 1]], 0.0);
        assert_abs_diff_eq!(d[[0, 2]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
        assert_abs_diff_eq!(d[[2, 0]], 4.0);
        assert_abs_diff_eq!(d[[2, 2]], 5.0);
        assert_eq!(d, sample_dense());
    }

    #[test]
    fn test_from_dense() {
        let dense = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let m = CsrMatrix::from_dense(&dense.view(), 0.0);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 1]], 2.0);
        assert_eq!(back, dense);
    }

    #[test]
    fn test_from_dense_epsilon_threshold() {
        // Entries with |x| <= epsilon are dropped; the sign does not matter.
        let dense = array![[0.5_f64, -0.05], [-0.2, 0.0]];
        let m = CsrMatrix::from_dense(&dense.view(), 0.1);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 0.5);
        assert_abs_diff_eq!(back[[0, 1]], 0.0);
        assert_abs_diff_eq!(back[[1, 0]], -0.2);
        assert_abs_diff_eq!(back[[1, 1]], 0.0);
    }

    #[test]
    fn test_from_coo_roundtrip() {
        let mut coo: CooMatrix<f64> = CooMatrix::new(3, 3);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 2, 4.0).unwrap();
        coo.push(2, 1, 7.0).unwrap();
        let csr = CsrMatrix::from_coo(&coo).unwrap();
        let dense = csr.to_dense();
        assert_abs_diff_eq!(dense[[0, 0]], 1.0);
        assert_abs_diff_eq!(dense[[1, 2]], 4.0);
        assert_abs_diff_eq!(dense[[2, 1]], 7.0);
        assert_abs_diff_eq!(dense[[0, 1]], 0.0);
        // (1, 2) and (2, 1) must not be swapped.
        assert_abs_diff_eq!(dense[[2, 2]], 0.0);
        assert_abs_diff_eq!(dense[[1, 1]], 0.0);
    }

    #[test]
    fn test_to_coo_roundtrip() {
        let csr = sample_csr();
        let coo = csr.to_coo();
        let back = CsrMatrix::from_coo(&coo).unwrap();
        let d = back.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[2, 2]], 5.0);
        // The diagonal alone cannot detect a (row, col) transposition; compare
        // the whole matrix, including the asymmetric off-diagonal entries.
        assert_abs_diff_eq!(d[[0, 2]], 2.0);
        assert_abs_diff_eq!(d[[2, 0]], 4.0);
        assert_eq!(back.nnz(), csr.nnz());
        assert_eq!(d, sample_dense());
    }

    #[test]
    fn test_csr_csc_roundtrip() {
        let csr = sample_csr();
        let csc = csr.to_csc();
        let back = CsrMatrix::from_csc(&csc).unwrap();
        assert_eq!(back.to_dense(), csr.to_dense());
    }

    #[test]
    fn test_row_slice() {
        let m = sample_csr();
        let sliced = m.row_slice(0, 2).unwrap();
        assert_eq!(sliced.n_rows(), 2);
        assert_eq!(sliced.n_cols(), 3);
        assert_eq!(sliced.nnz(), 3);
        let d = sliced.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
        assert_eq!(d, array![[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]);
    }

    #[test]
    fn test_row_slice_last_row() {
        let m = sample_csr();
        let sliced = m.row_slice(2, 3).unwrap();
        assert_eq!(sliced.n_rows(), 1);
        assert_eq!(sliced.n_cols(), 3);
        assert_eq!(sliced.nnz(), 2);
        assert_eq!(sliced.to_dense(), array![[4.0, 0.0, 5.0]]);
    }

    #[test]
    fn test_row_slice_full() {
        let m = sample_csr();
        let sliced = m.row_slice(0, 3).unwrap();
        assert_eq!(sliced.nnz(), m.nnz());
        assert_eq!(sliced.to_dense(), sample_dense());
    }

    #[test]
    fn test_row_slice_empty() {
        let m = sample_csr();
        let sliced = m.row_slice(1, 1).unwrap();
        assert_eq!(sliced.n_rows(), 0);
        assert_eq!(sliced.n_cols(), 3);
        assert_eq!(sliced.nnz(), 0);
    }

    #[test]
    fn test_row_slice_invalid() {
        let m = sample_csr();
        assert!(m.row_slice(2, 1).is_err());
        assert!(m.row_slice(0, 4).is_err());
        assert!(matches!(
            m.row_slice(2, 1),
            Err(FerroError::InvalidParameter { .. })
        ));
        assert!(matches!(
            m.row_slice(0, 4),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_mul_scalar() {
        let m = sample_csr();
        let m2 = m.mul_scalar(2.0);
        let d = m2.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
        assert_eq!(d, sample_dense().mapv(|x| x * 2.0));
        assert_eq!(m2.nnz(), m.nnz());
        // The source matrix is left untouched.
        assert_eq!(m.to_dense(), sample_dense());
    }

    #[test]
    fn test_scale_in_place() {
        let mut m = sample_csr();
        m.scale(3.0);
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 3.0);
        assert_abs_diff_eq!(d[[2, 2]], 15.0);
        assert_eq!(d, sample_dense().mapv(|x| x * 3.0));
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_add() {
        let m = sample_csr();
        let sum = m.add(&m).unwrap();
        let d = sum.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
        assert_eq!(d, sample_dense().mapv(|x| x * 2.0));
    }

    #[test]
    fn test_add_shape_mismatch() {
        let m1 = sample_csr();
        let m2 = CsrMatrix::new(2, 3, vec![0, 0, 0], vec![], vec![]).unwrap();
        assert!(m1.add(&m2).is_err());
        assert!(matches!(
            m1.add(&m2),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_mul_vec() {
        let m = sample_csr();
        let v = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        // Row 0: 1*1 + 2*3 = 7; row 1: 3*2 = 6; row 2: 4*1 + 5*3 = 19.
        let result = m.mul_vec(&v).unwrap();
        assert_eq!(result.len(), 3);
        assert_abs_diff_eq!(result[0], 7.0);
        assert_abs_diff_eq!(result[1], 6.0);
        assert_abs_diff_eq!(result[2], 19.0);
    }

    #[test]
    fn test_mul_vec_shape_mismatch() {
        let m = sample_csr();
        let v = Array1::from(vec![1.0_f64, 2.0]);
        assert!(m.mul_vec(&v).is_err());
        assert!(matches!(
            m.mul_vec(&v),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_dataset_trait() {
        let m = sample_csr();
        assert_eq!(m.n_samples(), 3);
        assert_eq!(m.n_features(), 3);
        assert!(m.is_sparse());
    }

    #[test]
    fn test_dataset_trait_object() {
        use ferrolearn_core::Dataset;
        let m: CsrMatrix<f64> = sample_csr();
        let d: &dyn Dataset = &m;
        assert_eq!(d.n_samples(), 3);
        assert!(d.is_sparse());
    }

    #[test]
    fn test_from_dense_float() {
        let dense = array![[1.0_f64, 0.0, 0.0], [0.0, 0.0, 2.0]];
        let csr = CsrMatrix::from_dense_float(&dense.view());
        assert_eq!(csr.nnz(), 2);
        let back = csr.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 2]], 2.0);
        assert_eq!(back, dense);
    }
}
487
/// Kani bounded-model-checking harnesses for [`CsrMatrix`].
///
/// Shapes are symbolic but bounded by `MAX_DIM`, so every loop over `indptr`
/// runs at most `MAX_DIM + 1` times. The `#[kani::unwind(5)]` bound appears to
/// be chosen to cover that (TODO confirm it also covers loops inside `sprs`).
#[cfg(kani)]
mod kani_proofs {
    use super::*;
    use crate::coo::CooMatrix;

    /// Upper bound on the symbolic number of rows and columns.
    const MAX_DIM: usize = 3;
    /// Upper bound on the number of stored entries.
    /// NOTE(review): not referenced by any harness in this module.
    const MAX_NNZ: usize = 4;

    /// Asserts the structural CSR invariants of `m`:
    /// `indptr` has `n_rows + 1` entries and is non-decreasing, every column
    /// index is `< n_cols`, and `indices` and `data` have the same length.
    fn assert_csr_invariants<T>(m: &CsrMatrix<T>) {
        let inner = m.inner();

        let indptr = inner.indptr();
        let indptr_raw = indptr.raw_storage();
        assert!(indptr_raw.len() == m.n_rows() + 1);

        // Row pointers must never decrease.
        for i in 0..m.n_rows() {
            assert!(indptr_raw[i] <= indptr_raw[i + 1]);
        }

        // Every stored column index lies inside the matrix.
        let indices = inner.indices();
        for &col_idx in indices {
            assert!(col_idx < m.n_cols());
        }

        assert!(inner.indices().len() == inner.data().len());
    }

    /// An empty matrix (all-zero `indptr`, no entries) of any shape up to
    /// `MAX_DIM x MAX_DIM` keeps an `indptr` of length `n_rows + 1`.
    ///
    /// NOTE(review): if `new` returned `Err`, the assertion would be skipped
    /// and the harness would pass vacuously. Consider asserting success.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indptr_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            let inner_indptr = m.inner().indptr();
            assert!(inner_indptr.raw_storage().len() == n_rows + 1);
        }
    }

    /// A matrix with a single entry at a symbolic position `(row, col)` has a
    /// non-decreasing `indptr`.
    ///
    /// NOTE(review): vacuous if `new` returns `Err` (see `if let Ok`).
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indptr_monotonic() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let row: usize = kani::any();
        let col: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);

        // Rows 0..=row start at offset 0; every pointer after `row` is 1,
        // so the single entry belongs to `row`.
        let mut indptr = vec![0usize; n_rows + 1];
        for i in (row + 1)..=n_rows {
            indptr[i] = 1;
        }
        let indices = vec![col];
        let data = vec![42i32];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            let inner_indptr = m.inner().indptr().raw_storage().to_vec();
            for i in 0..m.n_rows() {
                assert!(inner_indptr[i] <= inner_indptr[i + 1]);
            }
        }
    }

    /// A matrix with a single entry at a symbolic in-bounds position stores
    /// only column indices `< n_cols`.
    ///
    /// NOTE(review): vacuous if `new` returns `Err` (see `if let Ok`).
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_column_indices_in_bounds() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let col: usize = kani::any();
        let row: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);

        // Same single-entry layout as `csr_new_indptr_monotonic`.
        let mut indptr = vec![0usize; n_rows + 1];
        for i in (row + 1)..=n_rows {
            indptr[i] = 1;
        }
        let indices = vec![col];
        let data = vec![1i32];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            for &c in m.inner().indices() {
                assert!(c < m.n_cols());
            }
        }
    }

    /// An empty matrix built via `new` has `indices` and `data` of equal
    /// length.
    ///
    /// NOTE(review): vacuous if `new` returns `Err` (see `if let Ok`).
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_indices_data_same_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];

        if let Ok(m) = CsrMatrix::new(n_rows, n_cols, indptr, indices, data) {
            assert!(m.inner().indices().len() == m.inner().data().len());
        }
    }

    /// `new` must reject components where `indices` and `data` differ in length
    /// (one index, zero values).
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_new_rejects_mismatched_lengths() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let indices = vec![0usize];
        let data: Vec<i32> = vec![];

        let result = CsrMatrix::new(n_rows, n_cols, indptr, indices, data);
        assert!(result.is_err());
    }

    /// Converting a COO matrix with zero or one symbolic in-bounds entry yields
    /// a structurally valid CSR matrix of the same shape.
    ///
    /// NOTE(review): vacuous if `from_coo` returns `Err` (see `if let Ok`).
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_from_coo_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let mut coo = CooMatrix::<i32>::new(n_rows, n_cols);

        // Symbolically choose between an empty matrix and a single entry.
        let do_insert: bool = kani::any();
        if do_insert {
            let row: usize = kani::any();
            let col: usize = kani::any();
            kani::assume(row < n_rows);
            kani::assume(col < n_cols);
            let _ = coo.push(row, col, 1i32);
        }

        if let Ok(csr) = CsrMatrix::from_coo(&coo) {
            assert_csr_invariants(&csr);
            assert!(csr.n_rows() == n_rows);
            assert!(csr.n_cols() == n_cols);
        }
    }

    /// Adding two empty matrices of the same symbolic shape preserves the shape
    /// and the CSR invariants.
    ///
    /// NOTE(review): vacuous if construction or `add` returns `Err`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_add_preserves_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let a = CsrMatrix::<i32>::new(n_rows, n_cols, indptr.clone(), vec![], vec![]);
        let b = CsrMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let (Ok(a), Ok(b)) = (a, b) {
            if let Ok(sum) = a.add(&b) {
                assert!(sum.n_rows() == n_rows);
                assert!(sum.n_cols() == n_cols);
                assert_csr_invariants(&sum);
            }
        }
    }

    /// Adding two concrete 2x2 matrices with disjoint entries (a: (0,0)=1,
    /// b: (1,1)=2) preserves the shape and the CSR invariants.
    ///
    /// NOTE(review): vacuous if construction or `add` returns `Err`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_add_nonempty_preserves_invariants() {
        let a = CsrMatrix::<i32>::new(2, 2, vec![0, 1, 1], vec![0], vec![1]);
        let b = CsrMatrix::<i32>::new(2, 2, vec![0, 0, 1], vec![1], vec![2]);

        if let (Ok(a), Ok(b)) = (a, b) {
            if let Ok(sum) = a.add(&b) {
                assert!(sum.n_rows() == 2);
                assert!(sum.n_cols() == 2);
                assert_csr_invariants(&sum);
            }
        }
    }

    /// Multiplying an empty `n_rows x n_cols` matrix by a length-`n_cols`
    /// vector yields a vector of length `n_rows`.
    ///
    /// NOTE(review): vacuous if construction or `mul_vec` returns `Err`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_output_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let m = CsrMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let Ok(m) = m {
            let v = Array1::<f64>::zeros(n_cols);
            if let Ok(result) = m.mul_vec(&v) {
                assert!(result.len() == n_rows);
            }
        }
    }

    /// `mul_vec` rejects every vector whose length (bounded by `MAX_DIM`)
    /// differs from `n_cols`, instead of panicking inside `sprs`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_rejects_wrong_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);

        let indptr = vec![0usize; n_rows + 1];
        let m = CsrMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![]);

        if let Ok(m) = m {
            let wrong_len: usize = kani::any();
            kani::assume(wrong_len <= MAX_DIM);
            kani::assume(wrong_len != n_cols);
            let v = Array1::<f64>::zeros(wrong_len);
            let result = m.mul_vec(&v);
            assert!(result.is_err());
        }
    }

    /// Multiplying a concrete 2x3 matrix with two stored entries by a length-3
    /// vector completes without out-of-bounds access and yields 2 values.
    ///
    /// NOTE(review): vacuous if construction or `mul_vec` returns `Err`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csr_mul_vec_nonempty_no_oob() {
        let m = CsrMatrix::<f64>::new(2, 3, vec![0, 1, 2], vec![1, 2], vec![3.0, 4.0]);
        if let Ok(m) = m {
            let v = Array1::from(vec![1.0, 2.0, 3.0]);
            if let Ok(result) = m.mul_vec(&v) {
                assert!(result.len() == 2);
            }
        }
    }
}