1use std::ops::{Add, AddAssign, Mul, MulAssign};
8
9use ferrolearn_core::FerroError;
10use ndarray::{Array1, Array2, ArrayView2};
11use num_traits::Zero;
12use sprs::CsMat;
13
14use crate::coo::CooMatrix;
15use crate::csr::CsrMatrix;
16
17#[derive(Debug, Clone)]
27pub struct CscMatrix<T> {
28 inner: CsMat<T>,
29}
30
31impl<T> CscMatrix<T>
32where
33 T: Clone,
34{
35 pub fn new(
50 n_rows: usize,
51 n_cols: usize,
52 indptr: Vec<usize>,
53 indices: Vec<usize>,
54 data: Vec<T>,
55 ) -> Result<Self, FerroError> {
56 CsMat::try_new_csc((n_rows, n_cols), indptr, indices, data)
57 .map(|inner| Self { inner })
58 .map_err(|(_, _, _, err)| FerroError::InvalidParameter {
59 name: "CscMatrix raw components".into(),
60 reason: err.to_string(),
61 })
62 }
63
64 pub(crate) fn from_inner(inner: CsMat<T>) -> Self {
68 debug_assert!(inner.is_csc(), "inner matrix must be in CSC storage");
69 Self { inner }
70 }
71
72 pub fn n_rows(&self) -> usize {
74 self.inner.rows()
75 }
76
77 pub fn n_cols(&self) -> usize {
79 self.inner.cols()
80 }
81
82 pub fn nnz(&self) -> usize {
84 self.inner.nnz()
85 }
86
87 pub fn inner(&self) -> &CsMat<T> {
89 &self.inner
90 }
91
92 pub fn into_inner(self) -> CsMat<T> {
94 self.inner
95 }
96
97 pub fn from_coo(coo: &CooMatrix<T>) -> Result<Self, FerroError>
105 where
106 T: Clone + Add<Output = T> + 'static,
107 {
108 let inner: CsMat<T> = coo.inner().to_csc();
109 Ok(Self { inner })
110 }
111
112 pub fn from_csr(csr: &CsrMatrix<T>) -> Result<Self, FerroError>
118 where
119 T: Clone + Default + 'static,
120 {
121 Ok(csr.to_csc())
122 }
123
124 pub fn to_csr(&self) -> CsrMatrix<T>
126 where
127 T: Clone + Default + 'static,
128 {
129 CsrMatrix::from_csc(self).unwrap()
131 }
132
133 pub fn to_coo(&self) -> CooMatrix<T> {
135 let mut coo = CooMatrix::with_capacity(self.n_rows(), self.n_cols(), self.nnz());
136 for (val, (r, c)) in self.inner.iter() {
137 let _ = coo.push(r, c, val.clone());
139 }
140 coo
141 }
142
143 pub fn to_dense(&self) -> Array2<T>
145 where
146 T: Clone + Zero + 'static,
147 {
148 self.inner.to_dense()
149 }
150
151 pub fn from_dense(dense: &ArrayView2<'_, T>, epsilon: T) -> Self
154 where
155 T: Copy + Zero + PartialOrd + num_traits::Signed + 'static,
156 {
157 let inner = CsMat::csc_from_dense(dense.view(), epsilon);
158 Self { inner }
159 }
160
161 pub fn col_slice(&self, start: usize, end: usize) -> Result<CscMatrix<T>, FerroError>
168 where
169 T: Clone + Default + 'static,
170 {
171 if start > end {
172 return Err(FerroError::InvalidParameter {
173 name: "col_slice range".into(),
174 reason: format!("start ({start}) must be <= end ({end})"),
175 });
176 }
177 if end > self.n_cols() {
178 return Err(FerroError::InvalidParameter {
179 name: "col_slice range".into(),
180 reason: format!("end ({end}) exceeds n_cols ({})", self.n_cols()),
181 });
182 }
183 let view = self.inner.slice_outer(start..end);
184 Ok(Self {
185 inner: view.to_owned(),
186 })
187 }
188
189 pub fn scale(&mut self, scalar: T)
194 where
195 for<'r> T: MulAssign<&'r T>,
196 {
197 self.inner.scale(scalar);
198 }
199
200 pub fn mul_scalar(&self, scalar: T) -> CscMatrix<T>
202 where
203 T: Copy + Mul<Output = T> + Zero + 'static,
204 {
205 let new_inner = self.inner.map(|&v| v * scalar);
206 Self { inner: new_inner }
207 }
208
209 pub fn add(&self, rhs: &CscMatrix<T>) -> Result<CscMatrix<T>, FerroError>
215 where
216 T: Zero + Default + Clone + 'static,
217 for<'r> &'r T: Add<&'r T, Output = T>,
218 {
219 if self.n_rows() != rhs.n_rows() || self.n_cols() != rhs.n_cols() {
220 return Err(FerroError::ShapeMismatch {
221 expected: vec![self.n_rows(), self.n_cols()],
222 actual: vec![rhs.n_rows(), rhs.n_cols()],
223 context: "CscMatrix::add".into(),
224 });
225 }
226 let result = &self.inner + &rhs.inner;
227 Ok(Self { inner: result })
228 }
229
230 pub fn mul_vec(&self, rhs: &Array1<T>) -> Result<Array1<T>, FerroError>
236 where
237 T: Clone + Zero + 'static,
238 for<'r> &'r T: Mul<Output = T>,
239 T: AddAssign,
240 {
241 if rhs.len() != self.n_cols() {
242 return Err(FerroError::ShapeMismatch {
243 expected: vec![self.n_cols()],
244 actual: vec![rhs.len()],
245 context: "CscMatrix::mul_vec".into(),
246 });
247 }
248 let result = &self.inner * rhs;
249 Ok(result)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // 3x3 fixture shared by most tests. Dense form:
    //
    //     [1 0 2]
    //     [0 3 0]
    //     [4 0 5]
    fn sample_csc() -> CscMatrix<f64> {
        CscMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 4.0, 3.0, 2.0, 5.0],
        )
        .unwrap()
    }

    #[test]
    fn test_new_valid() {
        let m = sample_csc();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_new_rejects_wrong_indptr_length() {
        // indptr must have n_cols + 1 = 4 entries; only 3 are given.
        let res = CscMatrix::<f64>::new(3, 3, vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0]);
        assert!(res.is_err());
    }

    #[test]
    fn test_new_rejects_out_of_bounds_row() {
        // Row index 2 does not exist in a 2-row matrix.
        let res = CscMatrix::<f64>::new(2, 2, vec![0, 1, 1], vec![2], vec![1.0]);
        assert!(res.is_err());
    }

    #[test]
    fn test_into_inner_preserves_storage() {
        let inner = sample_csc().into_inner();
        assert!(inner.is_csc());
        assert_eq!(inner.nnz(), 5);
    }

    #[test]
    fn test_to_dense() {
        let m = sample_csc();
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[0, 2]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
        assert_abs_diff_eq!(d[[2, 0]], 4.0);
        assert_abs_diff_eq!(d[[2, 2]], 5.0);
    }

    #[test]
    fn test_from_dense() {
        let dense = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let m = CscMatrix::from_dense(&dense.view(), 0.0);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 1]], 2.0);
    }

    #[test]
    fn test_from_dense_drops_small_entries() {
        // 0.5 is well below the threshold and 2.0 well above it.
        let dense = array![[0.5_f64, 0.0], [0.0, 2.0]];
        let m = CscMatrix::from_dense(&dense.view(), 1.0);
        assert_eq!(m.nnz(), 1);
        assert_abs_diff_eq!(m.to_dense()[[1, 1]], 2.0);
    }

    #[test]
    fn test_from_coo_roundtrip() {
        let mut coo: CooMatrix<f64> = CooMatrix::new(3, 3);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 2, 4.0).unwrap();
        coo.push(2, 1, 7.0).unwrap();
        let csc = CscMatrix::from_coo(&coo).unwrap();
        let dense = csc.to_dense();
        assert_abs_diff_eq!(dense[[0, 0]], 1.0);
        assert_abs_diff_eq!(dense[[1, 2]], 4.0);
        assert_abs_diff_eq!(dense[[2, 1]], 7.0);
    }

    #[test]
    fn test_to_coo_roundtrip() {
        let m = sample_csc();
        let coo = m.to_coo();
        let back = CscMatrix::from_coo(&coo).unwrap();
        assert_eq!(back.nnz(), m.nnz());
        assert_eq!(back.to_dense(), m.to_dense());
    }

    #[test]
    fn test_csc_csr_roundtrip() {
        let csc = sample_csc();
        let csr = csc.to_csr();
        let back = CscMatrix::from_csr(&csr).unwrap();
        assert_eq!(back.to_dense(), csc.to_dense());
    }

    #[test]
    fn test_col_slice() {
        let m = sample_csc();
        let sliced = m.col_slice(0, 2).unwrap();
        assert_eq!(sliced.n_rows(), 3);
        assert_eq!(sliced.n_cols(), 2);
        let d = sliced.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
    }

    #[test]
    fn test_col_slice_contents_with_offset() {
        // A slice that does not start at column 0 exercises the non-zero
        // starting offset in the sliced indptr.
        let m = sample_csc();
        let sliced = m.col_slice(1, 3).unwrap();
        assert_eq!(sliced.nnz(), 3);
        assert_eq!(
            sliced.to_dense(),
            array![[0.0, 2.0], [3.0, 0.0], [0.0, 5.0]]
        );
    }

    #[test]
    fn test_col_slice_empty() {
        let m = sample_csc();
        let sliced = m.col_slice(1, 1).unwrap();
        assert_eq!(sliced.n_cols(), 0);
    }

    #[test]
    fn test_col_slice_invalid() {
        let m = sample_csc();
        assert!(m.col_slice(2, 1).is_err());
        assert!(m.col_slice(0, 4).is_err());
    }

    #[test]
    fn test_mul_scalar() {
        let m = sample_csc();
        let m2 = m.mul_scalar(2.0);
        let d = m2.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
    }

    #[test]
    fn test_scale_in_place() {
        let mut m = sample_csc();
        m.scale(3.0);
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 3.0);
        assert_abs_diff_eq!(d[[2, 2]], 15.0);
    }

    #[test]
    fn test_add() {
        let m = sample_csc();
        let sum = m.add(&m).unwrap();
        let d = sum.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
    }

    #[test]
    fn test_add_shape_mismatch() {
        let m1 = sample_csc();
        let m2 = CscMatrix::new(2, 3, vec![0, 0, 0, 0], vec![], vec![]).unwrap();
        assert!(m1.add(&m2).is_err());
    }

    #[test]
    fn test_mul_vec() {
        let m = sample_csc();
        let v = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        let result = m.mul_vec(&v).unwrap();
        assert_abs_diff_eq!(result[0], 7.0);
        assert_abs_diff_eq!(result[1], 6.0);
        assert_abs_diff_eq!(result[2], 19.0);
    }

    #[test]
    fn test_mul_vec_shape_mismatch() {
        let m = sample_csc();
        let v = Array1::from(vec![1.0_f64, 2.0]);
        assert!(m.mul_vec(&v).is_err());
    }
}
405
#[cfg(kani)]
mod kani_proofs {
    use super::*;
    use crate::coo::CooMatrix;

    /// Largest symbolic dimension explored. Kept small so that every loop in
    /// these harnesses stays within the `#[kani::unwind(5)]` bound.
    const MAX_DIM: usize = 3;

    /// Returns symbolic dimensions `(n_rows, n_cols)`, each in `1..=MAX_DIM`.
    fn any_shape() -> (usize, usize) {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        (n_rows, n_cols)
    }

    /// Returns a symbolic in-bounds position `(row, col)` for an
    /// `n_rows x n_cols` matrix.
    fn any_position(n_rows: usize, n_cols: usize) -> (usize, usize) {
        let row: usize = kani::any();
        let col: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);
        (row, col)
    }

    /// Builds the `indptr` (length `n_cols + 1`) of a matrix whose only stored
    /// entry lives in column `col`. The offsets are 0 up to and including
    /// index `col`, and 1 afterwards.
    fn single_entry_indptr(n_cols: usize, col: usize) -> Vec<usize> {
        let mut indptr = vec![0usize; n_cols + 1];
        for offset in &mut indptr[col + 1..] {
            *offset = 1;
        }
        indptr
    }

    /// Asserts the structural invariants of an owned, zero-based CSC matrix,
    /// as produced by sprs conversions and arithmetic:
    ///
    /// - `indptr` has `n_cols + 1` entries, starts at 0, is non-decreasing,
    ///   and ends at the number of stored entries;
    /// - `indices` and `data` have the same length;
    /// - every row index is `< n_rows`;
    /// - row indices are strictly increasing within each column.
    fn assert_csc_invariants<T>(m: &CscMatrix<T>) {
        let inner = m.inner();
        let n_cols = m.n_cols();
        let indptr = inner.indptr();
        let offsets = indptr.raw_storage();
        let indices = inner.indices();

        assert!(offsets.len() == n_cols + 1);
        assert!(offsets[0] == 0);
        for col in 0..n_cols {
            assert!(offsets[col] <= offsets[col + 1]);
        }
        // The offsets must span exactly the stored entries; otherwise slicing
        // `indices`/`data` by column could go out of bounds.
        assert!(offsets[n_cols] == indices.len());
        assert!(indices.len() == inner.data().len());

        for &row_idx in indices {
            assert!(row_idx < m.n_rows());
        }
        for col in 0..n_cols {
            let rows = &indices[offsets[col]..offsets[col + 1]];
            for pair in rows.windows(2) {
                assert!(pair[0] < pair[1]);
            }
        }
    }

    /// A well-formed empty `indptr` is accepted and keeps length `n_cols + 1`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indptr_length() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];

        // Requiring success keeps the proof from passing vacuously.
        let m = CscMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");
        assert!(m.inner().indptr().raw_storage().len() == n_cols + 1);
    }

    /// A single in-bounds entry is accepted and yields a non-decreasing indptr.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indptr_monotonic() {
        let (n_rows, n_cols) = any_shape();
        let (row, col) = any_position(n_rows, n_cols);

        let m = CscMatrix::new(
            n_rows,
            n_cols,
            single_entry_indptr(n_cols, col),
            vec![row],
            vec![42i32],
        )
        .expect("a single in-bounds entry must be accepted");

        let offsets = m.inner().indptr().raw_storage().to_vec();
        for i in 0..m.n_cols() {
            assert!(offsets[i] <= offsets[i + 1]);
        }
    }

    /// A single in-bounds entry is accepted and its row index stays `< n_rows`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_row_indices_in_bounds() {
        let (n_rows, n_cols) = any_shape();
        let (row, col) = any_position(n_rows, n_cols);

        let m = CscMatrix::new(
            n_rows,
            n_cols,
            single_entry_indptr(n_cols, col),
            vec![row],
            vec![1i32],
        )
        .expect("a single in-bounds entry must be accepted");

        for &r in m.inner().indices() {
            assert!(r < m.n_rows());
        }
    }

    /// An accepted empty matrix has equally long `indices` and `data`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indices_data_same_length() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];

        let m = CscMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");
        assert!(m.inner().indices().len() == m.inner().data().len());
    }

    /// `new` rejects components whose `indices` and `data` lengths differ.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_rejects_mismatched_lengths() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];
        let indices = vec![0usize];
        let data: Vec<i32> = vec![];

        let result = CscMatrix::new(n_rows, n_cols, indptr, indices, data);
        assert!(result.is_err());
    }

    /// Converting a COO matrix with zero or one entry succeeds, preserves the
    /// shape, and yields a structurally valid CSC matrix.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_from_coo_invariants() {
        let (n_rows, n_cols) = any_shape();
        let mut coo = CooMatrix::<i32>::new(n_rows, n_cols);

        let do_insert: bool = kani::any();
        if do_insert {
            let (row, col) = any_position(n_rows, n_cols);
            coo.push(row, col, 1i32)
                .expect("pushing an in-bounds entry must succeed");
        }

        let csc = CscMatrix::from_coo(&coo).expect("COO to CSC conversion must succeed");
        assert_csc_invariants(&csc);
        assert!(csc.n_rows() == n_rows);
        assert!(csc.n_cols() == n_cols);
        assert!(csc.nnz() == usize::from(do_insert));
    }

    /// Adding two empty matrices of equal shape succeeds and preserves the
    /// invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_add_preserves_invariants() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];

        let a = CscMatrix::<i32>::new(n_rows, n_cols, indptr.clone(), vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");
        let b = CscMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");

        let sum = a.add(&b).expect("adding matrices of equal shape must succeed");
        assert!(sum.n_rows() == n_rows);
        assert!(sum.n_cols() == n_cols);
        assert_csc_invariants(&sum);
    }

    /// Adding two non-empty 2x2 matrices with disjoint entries succeeds and
    /// preserves the invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_add_nonempty_preserves_invariants() {
        // a has (0, 0) = 1; b has (1, 1) = 2.
        let a = CscMatrix::<i32>::new(2, 2, vec![0, 1, 1], vec![0], vec![1])
            .expect("fixture a must be valid");
        let b = CscMatrix::<i32>::new(2, 2, vec![0, 0, 1], vec![1], vec![2])
            .expect("fixture b must be valid");

        let sum = a.add(&b).expect("adding matrices of equal shape must succeed");
        assert!(sum.n_rows() == 2);
        assert!(sum.n_cols() == 2);
        assert!(sum.nnz() == 2);
        assert_csc_invariants(&sum);
    }

    /// `mul_vec` with a correctly sized vector returns a vector of length
    /// `n_rows`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_output_dimension() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];

        let m = CscMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");
        let v = Array1::<f64>::zeros(n_cols);
        let result = m
            .mul_vec(&v)
            .expect("mul_vec with a length-n_cols vector must succeed");
        assert!(result.len() == n_rows);
    }

    /// `mul_vec` rejects any vector whose length differs from `n_cols`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_rejects_wrong_dimension() {
        let (n_rows, n_cols) = any_shape();
        let indptr = vec![0usize; n_cols + 1];

        let m = CscMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an empty matrix with a well-formed indptr must be accepted");

        let wrong_len: usize = kani::any();
        kani::assume(wrong_len <= MAX_DIM);
        kani::assume(wrong_len != n_cols);
        let v = Array1::<f64>::zeros(wrong_len);
        assert!(m.mul_vec(&v).is_err());
    }

    /// `mul_vec` on a concrete non-empty matrix stays in bounds and computes
    /// the expected product.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_nonempty_no_oob() {
        // Dense form: [[0, 3, 0], [0, 0, 4]].
        let m = CscMatrix::<f64>::new(2, 3, vec![0, 0, 1, 2], vec![0, 1], vec![3.0, 4.0])
            .expect("fixture must be valid");
        let v = Array1::from(vec![1.0, 2.0, 3.0]);

        let result = m
            .mul_vec(&v)
            .expect("mul_vec with a length-n_cols vector must succeed");
        assert!(result.len() == 2);
        assert!(result[0] == 6.0);
        assert!(result[1] == 12.0);
    }
}