1use crate::error::{Result, ScryLearnError};
9use std::ops;
10
/// Borrowed view of one row of a `CsrMatrix`: the stored column indices and
/// their values, in storage order.
///
/// The view only holds two slices into the matrix, so it is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct SparseRow<'a> {
    indices: &'a [usize],
    data: &'a [f64],
}

impl<'a> SparseRow<'a> {
    /// Iterates over the stored `(column, value)` pairs.
    pub fn iter(&self) -> impl Iterator<Item = (usize, f64)> + 'a {
        self.indices.iter().copied().zip(self.data.iter().copied())
    }

    /// Number of stored entries in this row.
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Column indices of the stored entries.
    ///
    /// The slice borrows from the underlying matrix rather than from this
    /// view, so it may outlive the `SparseRow` itself
    /// (e.g. `let idx = csr.row(0).indices();`).
    pub fn indices(&self) -> &'a [usize] {
        self.indices
    }

    /// Values of the stored entries, parallel to [`SparseRow::indices`].
    ///
    /// Like `indices`, the slice borrows from the underlying matrix.
    pub fn values(&self) -> &'a [f64] {
        self.data
    }

    /// Dot product of this row with a dense vector.
    ///
    /// Only stored entries contribute; `other` is indexed by column.
    ///
    /// # Panics
    ///
    /// Panics if any stored column index is `>= other.len()`.
    pub fn dot(&self, other: &[f64]) -> f64 {
        self.iter().map(|(j, v)| v * other[j]).sum()
    }
}
52
/// Borrowed view of one column of a `CscMatrix`: the stored row indices and
/// their values, in storage order.
///
/// The view only holds two slices into the matrix, so it is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct SparseCol<'a> {
    indices: &'a [usize],
    data: &'a [f64],
}

impl<'a> SparseCol<'a> {
    /// Iterates over the stored `(row, value)` pairs.
    pub fn iter(&self) -> impl Iterator<Item = (usize, f64)> + 'a {
        self.indices.iter().copied().zip(self.data.iter().copied())
    }

    /// Number of stored entries in this column.
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Row indices of the stored entries.
    ///
    /// The slice borrows from the underlying matrix rather than from this
    /// view, so it may outlive the `SparseCol` itself.
    pub fn indices(&self) -> &'a [usize] {
        self.indices
    }

    /// Values of the stored entries, parallel to [`SparseCol::indices`].
    pub fn values(&self) -> &'a [f64] {
        self.data
    }

    /// Dot product of this column with a dense vector.
    ///
    /// Only stored entries contribute; `other` is indexed by row.
    ///
    /// # Panics
    ///
    /// Panics if any stored row index is `>= other.len()`.
    pub fn dot(&self, other: &[f64]) -> f64 {
        self.iter().map(|(i, v)| v * other[i]).sum()
    }
}
80
/// Compressed Sparse Row (CSR) matrix of `f64` values.
///
/// The stored entries of row `i` occupy positions `indptr[i]..indptr[i + 1]`
/// of `indices` (column indices) and `data` (values). The constructors and
/// operations in this module produce rows whose column indices are sorted in
/// ascending order; `get` relies on this because it uses a binary search.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CsrMatrix {
    /// Row pointer array of length `n_rows + 1`; `indptr[0] == 0` and
    /// `indptr[n_rows] == nnz`.
    indptr: Vec<usize>,
    /// Column index of every stored entry, grouped by row.
    indices: Vec<usize>,
    /// Value of every stored entry, parallel to `indices`.
    data: Vec<f64>,
    /// Number of rows in the logical (dense) shape.
    n_rows: usize,
    /// Number of columns in the logical (dense) shape.
    n_cols: usize,
}
103
104impl CsrMatrix {
105 pub fn from_triplets(
109 rows: &[usize],
110 cols: &[usize],
111 vals: &[f64],
112 n_rows: usize,
113 n_cols: usize,
114 ) -> Result<Self> {
115 let nnz = rows.len();
116 if cols.len() != nnz || vals.len() != nnz {
117 return Err(ScryLearnError::InvalidParameter(format!(
118 "triplet arrays must have equal length (rows={}, cols={}, vals={})",
119 nnz,
120 cols.len(),
121 vals.len()
122 )));
123 }
124
125 for i in 0..nnz {
127 if rows[i] >= n_rows || cols[i] >= n_cols {
128 return Err(ScryLearnError::InvalidParameter(format!(
129 "triplet index ({}, {}) out of bounds for {}x{} matrix",
130 rows[i], cols[i], n_rows, n_cols
131 )));
132 }
133 }
134
135 let mut row_counts = vec![0usize; n_rows];
137 for &r in rows {
138 row_counts[r] += 1;
139 }
140
141 let mut indptr = vec![0usize; n_rows + 1];
143 for i in 0..n_rows {
144 indptr[i + 1] = indptr[i] + row_counts[i];
145 }
146
147 let total = indptr[n_rows];
149 let mut csr_indices = vec![0usize; total];
150 let mut csr_data = vec![0.0f64; total];
151 let mut offsets = indptr[..n_rows].to_vec();
152
153 for k in 0..nnz {
154 let r = rows[k];
155 let pos = offsets[r];
156 csr_indices[pos] = cols[k];
157 csr_data[pos] = vals[k];
158 offsets[r] += 1;
159 }
160
161 let mut final_indices = Vec::with_capacity(total);
163 let mut final_data = Vec::with_capacity(total);
164 let mut new_indptr = vec![0usize; n_rows + 1];
165
166 for i in 0..n_rows {
167 let start = indptr[i];
168 let end = indptr[i + 1];
169
170 let mut pairs: Vec<(usize, f64)> = csr_indices[start..end]
172 .iter()
173 .copied()
174 .zip(csr_data[start..end].iter().copied())
175 .collect();
176 pairs.sort_by_key(|&(c, _)| c);
177
178 let row_start = final_indices.len();
180 for &(col, val) in &pairs {
181 if final_indices.len() > row_start && final_indices[final_indices.len() - 1] == col
183 {
184 let last = final_data.len() - 1;
185 final_data[last] += val;
186 continue;
187 }
188 final_indices.push(col);
189 final_data.push(val);
190 }
191 new_indptr[i + 1] = final_indices.len();
192 }
193
194 Ok(Self {
195 indptr: new_indptr,
196 indices: final_indices,
197 data: final_data,
198 n_rows,
199 n_cols,
200 })
201 }
202
203 pub fn from_dense(rows: &[Vec<f64>]) -> Self {
205 let n_rows = rows.len();
206 let n_cols = if n_rows > 0 { rows[0].len() } else { 0 };
207
208 let mut indptr = vec![0usize; n_rows + 1];
209 let mut indices = Vec::new();
210 let mut data = Vec::new();
211
212 for (i, row) in rows.iter().enumerate() {
213 for (j, &val) in row.iter().enumerate() {
214 if val != 0.0 {
215 indices.push(j);
216 data.push(val);
217 }
218 }
219 indptr[i + 1] = indices.len();
220 }
221
222 Self {
223 indptr,
224 indices,
225 data,
226 n_rows,
227 n_cols,
228 }
229 }
230
231 #[inline]
233 pub fn n_rows(&self) -> usize {
234 self.n_rows
235 }
236
237 #[inline]
239 pub fn n_cols(&self) -> usize {
240 self.n_cols
241 }
242
243 #[inline]
245 pub fn nnz(&self) -> usize {
246 self.data.len()
247 }
248
249 #[inline]
253 pub fn density(&self) -> f64 {
254 let total = self.n_rows * self.n_cols;
255 if total == 0 {
256 return 0.0;
257 }
258 self.nnz() as f64 / total as f64
259 }
260
261 pub fn row(&self, i: usize) -> SparseRow<'_> {
263 let start = self.indptr[i];
264 let end = self.indptr[i + 1];
265 SparseRow {
266 indices: &self.indices[start..end],
267 data: &self.data[start..end],
268 }
269 }
270
271 pub fn get(&self, row: usize, col: usize) -> f64 {
273 let start = self.indptr[row];
274 let end = self.indptr[row + 1];
275 self.indices[start..end]
276 .binary_search(&col)
277 .map_or(0.0, |pos| self.data[start + pos])
278 }
279
280 pub fn to_csc(&self) -> CscMatrix {
282 let nnz = self.nnz();
283
284 let mut col_counts = vec![0usize; self.n_cols];
286 for &c in &self.indices {
287 col_counts[c] += 1;
288 }
289
290 let mut indptr = vec![0usize; self.n_cols + 1];
291 for j in 0..self.n_cols {
292 indptr[j + 1] = indptr[j] + col_counts[j];
293 }
294
295 let mut csc_indices = vec![0usize; nnz];
296 let mut csc_data = vec![0.0f64; nnz];
297 let mut offsets = indptr[..self.n_cols].to_vec();
298
299 for i in 0..self.n_rows {
300 let start = self.indptr[i];
301 let end = self.indptr[i + 1];
302 for k in start..end {
303 let col = self.indices[k];
304 let pos = offsets[col];
305 csc_indices[pos] = i;
306 csc_data[pos] = self.data[k];
307 offsets[col] += 1;
308 }
309 }
310
311 CscMatrix {
312 indptr,
313 indices: csc_indices,
314 data: csc_data,
315 n_rows: self.n_rows,
316 n_cols: self.n_cols,
317 }
318 }
319
320 pub fn to_dense(&self) -> Vec<Vec<f64>> {
322 let mut dense = vec![vec![0.0; self.n_cols]; self.n_rows];
323 for (i, row) in dense.iter_mut().enumerate() {
324 let start = self.indptr[i];
325 let end = self.indptr[i + 1];
326 for k in start..end {
327 row[self.indices[k]] = self.data[k];
328 }
329 }
330 dense
331 }
332
333 pub fn dot_vec(&self, x: &[f64]) -> Vec<f64> {
335 let mut y = vec![0.0; self.n_rows];
336 for (yi, i) in y.iter_mut().zip(0..self.n_rows) {
337 *yi = self.row(i).dot(x);
338 }
339 y
340 }
341}
342
/// Compressed Sparse Column (CSC) matrix of `f64` values.
///
/// The stored entries of column `j` occupy positions `indptr[j]..indptr[j + 1]`
/// of `indices` (row indices) and `data` (values). The constructors in this
/// module produce columns whose row indices are sorted in ascending order;
/// `get` relies on this because it uses a binary search.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CscMatrix {
    /// Column pointer array of length `n_cols + 1`; `indptr[0] == 0` and
    /// `indptr[n_cols] == nnz`.
    indptr: Vec<usize>,
    /// Row index of every stored entry, grouped by column.
    indices: Vec<usize>,
    /// Value of every stored entry, parallel to `indices`.
    data: Vec<f64>,
    /// Number of rows in the logical (dense) shape.
    n_rows: usize,
    /// Number of columns in the logical (dense) shape.
    n_cols: usize,
}
364
365impl CscMatrix {
366 pub fn from_triplets(
370 rows: &[usize],
371 cols: &[usize],
372 vals: &[f64],
373 n_rows: usize,
374 n_cols: usize,
375 ) -> Result<Self> {
376 let csr = CsrMatrix::from_triplets(rows, cols, vals, n_rows, n_cols)?;
378 Ok(csr.to_csc())
379 }
380
381 pub fn from_dense(cols: &[Vec<f64>]) -> Self {
385 let n_cols = cols.len();
386 let n_rows = if n_cols > 0 { cols[0].len() } else { 0 };
387
388 let mut indptr = vec![0usize; n_cols + 1];
389 let mut indices = Vec::new();
390 let mut data = Vec::new();
391
392 for (j, col) in cols.iter().enumerate() {
393 for (i, &val) in col.iter().enumerate() {
394 if val != 0.0 {
395 indices.push(i);
396 data.push(val);
397 }
398 }
399 indptr[j + 1] = indices.len();
400 }
401
402 Self {
403 indptr,
404 indices,
405 data,
406 n_rows,
407 n_cols,
408 }
409 }
410
411 #[inline]
413 pub fn n_rows(&self) -> usize {
414 self.n_rows
415 }
416
417 #[inline]
419 pub fn n_cols(&self) -> usize {
420 self.n_cols
421 }
422
423 #[inline]
425 pub fn nnz(&self) -> usize {
426 self.data.len()
427 }
428
429 #[inline]
431 pub fn density(&self) -> f64 {
432 let total = self.n_rows * self.n_cols;
433 if total == 0 {
434 return 0.0;
435 }
436 self.nnz() as f64 / total as f64
437 }
438
439 pub fn col(&self, j: usize) -> SparseCol<'_> {
441 let start = self.indptr[j];
442 let end = self.indptr[j + 1];
443 SparseCol {
444 indices: &self.indices[start..end],
445 data: &self.data[start..end],
446 }
447 }
448
449 pub fn get(&self, row: usize, col: usize) -> f64 {
451 let start = self.indptr[col];
452 let end = self.indptr[col + 1];
453 self.indices[start..end]
454 .binary_search(&row)
455 .map_or(0.0, |pos| self.data[start + pos])
456 }
457
458 pub fn to_csr(&self) -> CsrMatrix {
460 let nnz = self.nnz();
461
462 let mut row_counts = vec![0usize; self.n_rows];
464 for &r in &self.indices {
465 row_counts[r] += 1;
466 }
467
468 let mut indptr = vec![0usize; self.n_rows + 1];
469 for i in 0..self.n_rows {
470 indptr[i + 1] = indptr[i] + row_counts[i];
471 }
472
473 let mut csr_indices = vec![0usize; nnz];
474 let mut csr_data = vec![0.0f64; nnz];
475 let mut offsets = indptr[..self.n_rows].to_vec();
476
477 for j in 0..self.n_cols {
478 let start = self.indptr[j];
479 let end = self.indptr[j + 1];
480 for k in start..end {
481 let row = self.indices[k];
482 let pos = offsets[row];
483 csr_indices[pos] = j;
484 csr_data[pos] = self.data[k];
485 offsets[row] += 1;
486 }
487 }
488
489 CsrMatrix {
490 indptr,
491 indices: csr_indices,
492 data: csr_data,
493 n_rows: self.n_rows,
494 n_cols: self.n_cols,
495 }
496 }
497
498 pub fn to_dense(&self) -> Vec<Vec<f64>> {
500 self.to_csr().to_dense()
502 }
503
504 pub fn dot_vec(&self, x: &[f64]) -> Vec<f64> {
506 let mut y = vec![0.0; self.n_rows];
507 for (j, &xj) in x.iter().enumerate() {
508 let start = self.indptr[j];
509 let end = self.indptr[j + 1];
510 for k in start..end {
511 y[self.indices[k]] += self.data[k] * xj;
512 }
513 }
514 y
515 }
516}
517
518impl ops::Add for &CsrMatrix {
523 type Output = CsrMatrix;
524
525 fn add(self, rhs: &CsrMatrix) -> CsrMatrix {
531 assert_eq!(
532 (self.n_rows, self.n_cols),
533 (rhs.n_rows, rhs.n_cols),
534 "CsrMatrix addition requires same shape"
535 );
536
537 let mut indptr = vec![0usize; self.n_rows + 1];
538 let mut indices = Vec::new();
539 let mut data = Vec::new();
540
541 for i in 0..self.n_rows {
542 let a_start = self.indptr[i];
543 let a_end = self.indptr[i + 1];
544 let b_start = rhs.indptr[i];
545 let b_end = rhs.indptr[i + 1];
546
547 let mut a = a_start;
548 let mut b = b_start;
549
550 while a < a_end && b < b_end {
552 match self.indices[a].cmp(&rhs.indices[b]) {
553 std::cmp::Ordering::Less => {
554 indices.push(self.indices[a]);
555 data.push(self.data[a]);
556 a += 1;
557 }
558 std::cmp::Ordering::Greater => {
559 indices.push(rhs.indices[b]);
560 data.push(rhs.data[b]);
561 b += 1;
562 }
563 std::cmp::Ordering::Equal => {
564 let sum = self.data[a] + rhs.data[b];
565 if sum != 0.0 {
566 indices.push(self.indices[a]);
567 data.push(sum);
568 }
569 a += 1;
570 b += 1;
571 }
572 }
573 }
574 while a < a_end {
575 indices.push(self.indices[a]);
576 data.push(self.data[a]);
577 a += 1;
578 }
579 while b < b_end {
580 indices.push(rhs.indices[b]);
581 data.push(rhs.data[b]);
582 b += 1;
583 }
584
585 indptr[i + 1] = indices.len();
586 }
587
588 CsrMatrix {
589 indptr,
590 indices,
591 data,
592 n_rows: self.n_rows,
593 n_cols: self.n_cols,
594 }
595 }
596}
597
598impl ops::Mul<f64> for &CsrMatrix {
599 type Output = CsrMatrix;
600
601 fn mul(self, rhs: f64) -> CsrMatrix {
603 CsrMatrix {
604 indptr: self.indptr.clone(),
605 indices: self.indices.clone(),
606 data: self.data.iter().map(|&v| v * rhs).collect(),
607 n_rows: self.n_rows,
608 n_cols: self.n_cols,
609 }
610 }
611}
612
#[cfg(test)]
// Exact float equality is intentional: the expected values below are small
// integers (or sums and products of them), which f64 represents and computes
// exactly. Tests involving genuine rounding compare against a tolerance.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_from_triplets_basic() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csr.n_rows(), 3);
        assert_eq!(csr.n_cols(), 3);
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
        assert_eq!(csr.get(1, 0), 0.0);
    }

    #[test]
    fn test_duplicate_entries_summed() {
        let rows = vec![0, 0, 0];
        let cols = vec![1, 1, 1];
        let vals = vec![1.0, 2.0, 3.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 1);
        assert_eq!(csr.get(0, 1), 6.0);
    }

    #[test]
    fn test_from_triplets_rejects_length_mismatch() {
        assert!(CsrMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2).is_err());
        assert!(CsrMatrix::from_triplets(&[0], &[0], &[1.0, 2.0], 2, 2).is_err());
    }

    #[test]
    fn test_from_triplets_rejects_out_of_bounds() {
        assert!(CsrMatrix::from_triplets(&[2], &[0], &[1.0], 2, 2).is_err());
        assert!(CsrMatrix::from_triplets(&[0], &[2], &[1.0], 2, 2).is_err());
        assert!(CscMatrix::from_triplets(&[2], &[0], &[1.0], 2, 2).is_err());
    }

    #[test]
    fn test_from_triplets_sorts_columns_within_row() {
        let csr =
            CsrMatrix::from_triplets(&[0, 0, 0], &[2, 0, 1], &[3.0, 1.0, 2.0], 1, 3).unwrap();
        let row = csr.row(0);
        assert_eq!(row.indices(), &[0, 1, 2]);
        assert_eq!(row.values(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_csr_csc_roundtrip() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        let csc = csr.to_csc();
        let csr2 = csc.to_csr();

        assert_eq!(csr.to_dense(), csr2.to_dense());
    }

    #[test]
    fn test_dense_roundtrip() {
        let dense = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 3.0, 0.0],
            vec![4.0, 0.0, 5.0],
        ];

        let csr = CsrMatrix::from_dense(&dense);
        assert_eq!(csr.to_dense(), dense);
    }

    #[test]
    fn test_get_existing_and_missing() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 7.0], vec![8.0, 0.0]]);
        assert_eq!(csr.get(0, 1), 7.0);
        assert_eq!(csr.get(1, 0), 8.0);
        assert_eq!(csr.get(0, 0), 0.0);
        assert_eq!(csr.get(1, 1), 0.0);
    }

    #[test]
    fn test_dot_vec_csr() {
        // [[1, 2], [0, 3]] * [3, 4] = [11, 12]
        let csr = CsrMatrix::from_dense(&[vec![1.0, 2.0], vec![0.0, 3.0]]);
        let result = csr.dot_vec(&[3.0, 4.0]);
        assert_eq!(result, vec![11.0, 12.0]);
    }

    #[test]
    fn test_dot_vec_csc() {
        let dense = vec![vec![1.0, 2.0], vec![0.0, 3.0]];
        let csr = CsrMatrix::from_dense(&dense);
        let csc = csr.to_csc();
        let result = csc.dot_vec(&[3.0, 4.0]);
        assert_eq!(result, vec![11.0, 12.0]);
    }

    #[test]
    fn test_sparse_row_iteration() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 5.0, 0.0, 7.0]]);
        let row = csr.row(0);
        let pairs: Vec<(usize, f64)> = row.iter().collect();
        assert_eq!(pairs, vec![(1, 5.0), (3, 7.0)]);
        assert_eq!(row.nnz(), 2);
    }

    #[test]
    fn test_sparse_col_iteration() {
        let csr = CsrMatrix::from_dense(&[vec![1.0, 0.0], vec![0.0, 0.0], vec![3.0, 0.0]]);
        let csc = csr.to_csc();
        let col = csc.col(0);
        let pairs: Vec<(usize, f64)> = col.iter().collect();
        assert_eq!(pairs, vec![(0, 1.0), (2, 3.0)]);
        assert_eq!(col.nnz(), 2);
    }

    #[test]
    fn test_sparse_col_dot() {
        // Column-major input: column 1 holds 2.0 at row 0 and 3.0 at row 1.
        let csc = CscMatrix::from_dense(&[vec![1.0, 0.0], vec![2.0, 3.0]]);
        assert_eq!(csc.col(1).dot(&[10.0, 100.0]), 320.0);
    }

    #[test]
    fn test_empty_matrix() {
        let csr = CsrMatrix::from_triplets(&[], &[], &[], 0, 0).unwrap();
        assert_eq!(csr.n_rows(), 0);
        assert_eq!(csr.n_cols(), 0);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.density(), 0.0);

        let csr = CsrMatrix::from_triplets(&[], &[], &[], 5, 5).unwrap();
        assert_eq!(csr.n_rows(), 5);
        assert_eq!(csr.n_cols(), 5);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.density(), 0.0);
        assert_eq!(csr.get(2, 3), 0.0);
    }

    #[test]
    fn test_density() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert!((csr.density() - 5.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_large_sparse() {
        let n = 1000;
        let mut rng = fastrand::Rng::with_seed(42);
        let target_nnz = (n * n) / 1000;
        let mut rows = Vec::with_capacity(target_nnz);
        let mut cols = Vec::with_capacity(target_nnz);
        let mut vals = Vec::with_capacity(target_nnz);

        for _ in 0..target_nnz {
            rows.push(rng.usize(..n));
            cols.push(rng.usize(..n));
            vals.push(rng.f64() * 10.0);
        }

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, n, n).unwrap();
        assert_eq!(csr.n_rows(), n);
        assert_eq!(csr.n_cols(), n);
        // Random coordinates can collide and duplicates are merged, so nnz may
        // fall below target_nnz.
        assert!(csr.nnz() <= target_nnz);
        assert!(csr.nnz() > 0);
        assert!(csr.density() < 0.002);

        // CSR -> CSC -> CSR must preserve the number of stored entries.
        let csc = csr.to_csc();
        let csr2 = csc.to_csr();
        assert_eq!(csr.nnz(), csr2.nnz());
    }

    #[test]
    fn test_from_dense_skips_zeros() {
        let dense = vec![
            vec![0.0, 0.0, 1.0],
            vec![0.0, 0.0, 0.0],
            vec![2.0, 0.0, 0.0],
        ];
        let csr = CsrMatrix::from_dense(&dense);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 2), 1.0);
        assert_eq!(csr.get(2, 0), 2.0);
    }

    #[test]
    fn test_csr_add() {
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 0.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 4.0, 0.0], vec![5.0, 0.0, 6.0]]);
        let c = &a + &b;
        assert_eq!(
            c.to_dense(),
            vec![vec![1.0, 4.0, 2.0], vec![5.0, 3.0, 6.0]]
        );
    }

    #[test]
    fn test_csr_scalar_mul() {
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 0.0]]);
        let b = &a * 2.0;
        assert_eq!(
            b.to_dense(),
            vec![vec![2.0, 0.0, 4.0], vec![0.0, 6.0, 0.0]]
        );
    }

    #[test]
    fn test_csc_from_triplets() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];
        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csc.n_rows(), 3);
        assert_eq!(csc.n_cols(), 3);
        assert_eq!(csc.nnz(), 3);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(1, 1), 2.0);
        assert_eq!(csc.get(2, 2), 3.0);
        assert_eq!(csc.get(0, 1), 0.0);
    }

    #[test]
    fn test_csc_from_dense() {
        // Column-major input: each inner Vec is one column.
        let cols = vec![
            vec![1.0, 0.0, 4.0], // column 0
            vec![0.0, 3.0, 0.0], // column 1
            vec![2.0, 0.0, 5.0], // column 2
        ];
        let csc = CscMatrix::from_dense(&cols);
        assert_eq!(csc.n_rows(), 3);
        assert_eq!(csc.n_cols(), 3);
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(1, 1), 3.0);
        assert_eq!(csc.get(0, 2), 2.0);
        assert_eq!(csc.get(2, 2), 5.0);
    }

    #[test]
    fn test_csc_to_dense_is_row_major() {
        let csc = CscMatrix::from_dense(&[
            vec![1.0, 0.0, 4.0],
            vec![0.0, 3.0, 0.0],
            vec![2.0, 0.0, 5.0],
        ]);
        assert_eq!(
            csc.to_dense(),
            vec![
                vec![1.0, 0.0, 2.0],
                vec![0.0, 3.0, 0.0],
                vec![4.0, 0.0, 5.0],
            ]
        );
    }

    #[test]
    fn test_sparse_row_dot() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let row = csr.row(0);
        assert!((row.dot(&[1.0, 10.0, 100.0]) - 320.0).abs() < 1e-10);
    }

    #[test]
    fn test_csr_add_cancellation() {
        // Exact cancellations must not leave explicit zeros behind.
        let a = CsrMatrix::from_dense(&[vec![1.0, 2.0]]);
        let b = CsrMatrix::from_dense(&[vec![-1.0, -2.0]]);
        let c = &a + &b;
        assert_eq!(c.nnz(), 0);
        assert_eq!(c.to_dense(), vec![vec![0.0, 0.0]]);
    }

    #[test]
    fn test_from_triplets_cross_row_dedup() {
        // Same column in different rows must NOT be merged.
        let rows = vec![0, 1];
        let cols = vec![2, 2];
        let vals = vec![1.0, 3.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 2), 1.0);
        assert_eq!(csr.get(1, 2), 3.0);
    }

    #[test]
    fn test_from_triplets_intra_row_dedup() {
        let rows = vec![0, 0, 1, 1];
        let cols = vec![1, 1, 2, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0];

        let csr = CsrMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 1), 3.0); // 1.0 + 2.0
        assert_eq!(csr.get(1, 2), 7.0); // 3.0 + 4.0
    }

    #[test]
    fn test_csc_from_triplets_cross_row_dedup() {
        let rows = vec![0, 1];
        let cols = vec![2, 2];
        let vals = vec![1.0, 3.0];

        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        assert_eq!(csc.nnz(), 2);
        assert_eq!(csc.get(0, 2), 1.0);
        assert_eq!(csc.get(1, 2), 3.0);
    }

    #[test]
    fn test_csc_from_triplets_roundtrip_with_dupes() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 0, 1, 0, 2]; // (0, 0) appears twice
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(csc.get(0, 0), 3.0); // 1.0 + 2.0
        assert_eq!(csc.get(1, 1), 3.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(2, 2), 5.0);

        let csr = csc.to_csr();
        let csc2 = csr.to_csc();
        assert_eq!(csc.to_dense(), csc2.to_dense());
    }

    #[test]
    fn test_sparse_row_accessors() {
        let csr = CsrMatrix::from_dense(&[vec![0.0, 5.0, 0.0, 7.0]]);
        let row = csr.row(0);
        assert_eq!(row.indices(), &[1, 3]);
        assert_eq!(row.values(), &[5.0, 7.0]);
    }
}