1use crate::prelude::*;
4use faer::{Mat, MatRef};
5use rayon::prelude::*;
6
7struct SpaAcc<T: EvocFloat> {
19 values: Vec<T>,
22 indices: Vec<usize>,
25 flags: Vec<bool>,
28}
29
30impl<T: EvocFloat> SpaAcc<T> {
31 fn new(size: usize) -> Self {
41 Self {
42 values: vec![T::zero(); size],
43 indices: Vec::with_capacity(size / 10),
44 flags: vec![false; size],
45 }
46 }
47
48 #[inline]
65 unsafe fn scatter(&mut self, idx: usize, val: T) {
66 unsafe {
67 if !*self.flags.get_unchecked(idx) {
68 *self.flags.get_unchecked_mut(idx) = true;
69 self.indices.push(idx);
70 *self.values.get_unchecked_mut(idx) = val;
71 } else {
72 let cur = *self.values.get_unchecked(idx);
73 *self.values.get_unchecked_mut(idx) = cur + val;
74 }
75 }
76 }
77
78 #[inline]
90 fn gather_sorted(&mut self) -> Vec<(usize, T)> {
91 self.indices.sort_unstable();
92 let out: Vec<(usize, T)> = self
93 .indices
94 .iter()
95 .map(|&i| unsafe { (i, *self.values.get_unchecked(i)) })
98 .collect();
99 for &i in &self.indices {
100 unsafe {
102 *self.flags.get_unchecked_mut(i) = false;
103 *self.values.get_unchecked_mut(i) = T::zero();
104 }
105 }
106 self.indices.clear();
107 out
108 }
109}
110
/// Sparse matrix in coordinate (COO / triplet) form.
///
/// Entry `k` is the triplet `(row_indices[k], col_indices[k], values[k])`.
#[derive(Clone, Debug)]
pub struct CoordinateList<T> {
    /// Row index of each stored entry.
    pub row_indices: Vec<usize>,
    /// Column index of each stored entry.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry.
    pub values: Vec<T>,
    /// Side length; `Csr::from_coo` uses it for both the row and column count.
    pub n_samples: usize,
}
129
/// Sparse matrix in compressed sparse row (CSR) form.
///
/// The stored entries of row `i` occupy positions `indptr[i]..indptr[i + 1]`
/// of `indices` (column index) and `data` (value).
#[derive(Clone, Debug)]
pub struct Csr<T> {
    /// Row offsets into `indices` / `data`; `Csr::new` expects `nrows + 1` entries.
    pub indptr: Vec<usize>,
    /// Column index of each stored entry.
    pub indices: Vec<usize>,
    /// Value of each stored entry, parallel to `indices`.
    pub data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
148
149impl<T: EvocFloat> Csr<T> {
150 pub fn new(
163 indptr: Vec<usize>,
164 indices: Vec<usize>,
165 data: Vec<T>,
166 nrows: usize,
167 ncols: usize,
168 ) -> Self {
169 debug_assert_eq!(indptr.len(), nrows + 1);
170 debug_assert_eq!(indices.len(), data.len());
171 debug_assert_eq!(*indptr.last().unwrap(), data.len());
172 Self {
173 indptr,
174 indices,
175 data,
176 nrows,
177 ncols,
178 }
179 }
180
181 pub fn from_coo(coo: &CoordinateList<T>) -> Self {
196 let n = coo.n_samples;
197 let nnz = coo.values.len();
198 if nnz == 0 {
199 return Self::new(vec![0; n + 1], Vec::new(), Vec::new(), n, n);
200 }
201
202 let mut triplets: Vec<(usize, usize, T)> = (0..nnz)
203 .map(|i| (coo.row_indices[i], coo.col_indices[i], coo.values[i]))
204 .collect();
205 triplets.par_sort_unstable_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
206
207 let mut data = Vec::with_capacity(nnz);
208 let mut indices = Vec::with_capacity(nnz);
209 let mut indptr = vec![0usize; n + 1];
210
211 let mut prev_r = usize::MAX;
212 let mut prev_c = usize::MAX;
213 for &(r, c, v) in &triplets {
214 if r == prev_r && c == prev_c {
215 let last = data.last().copied().unwrap();
216 *data.last_mut().unwrap() = last + v;
217 } else {
218 data.push(v);
219 indices.push(c);
220 indptr[r + 1] += 1;
221 prev_r = r;
222 prev_c = c;
223 }
224 }
225 for i in 0..n {
226 indptr[i + 1] += indptr[i];
227 }
228
229 Self {
230 indptr,
231 indices,
232 data,
233 nrows: n,
234 ncols: n,
235 }
236 }
237
238 pub fn from_partition(partition: &[usize], n_parts: usize) -> Self {
249 let n = partition.len();
250 Self {
251 indptr: (0..=n).collect(),
252 indices: partition.to_vec(),
253 data: vec![T::one(); n],
254 nrows: n,
255 ncols: n_parts,
256 }
257 }
258
    /// Returns the number of stored entries (the length of `data`).
    pub fn nnz(&self) -> usize {
        self.data.len()
    }
263
264 pub fn transpose(&self) -> Self {
274 let nnz = self.nnz();
275 let mut col_count = vec![0usize; self.ncols];
276 for &c in &self.indices {
277 col_count[c] += 1;
278 }
279
280 let mut indptr = vec![0usize; self.ncols + 1];
281 for i in 0..self.ncols {
282 indptr[i + 1] = indptr[i] + col_count[i];
283 }
284
285 let mut data = vec![T::zero(); nnz];
286 let mut indices = vec![0usize; nnz];
287 let mut cursor = indptr[..self.ncols].to_vec();
288
289 for row in 0..self.nrows {
290 for idx in self.indptr[row]..self.indptr[row + 1] {
291 let col = self.indices[idx];
292 let pos = cursor[col];
293 data[pos] = self.data[idx];
294 indices[pos] = row;
295 cursor[col] += 1;
296 }
297 }
298
299 Self {
300 indptr,
301 indices,
302 data,
303 nrows: self.ncols,
304 ncols: self.nrows,
305 }
306 }
307
308 pub fn matmul(&self, other: &Csr<T>) -> Self {
319 assert_eq!(
320 self.ncols, other.nrows,
321 "Dimension mismatch: ({} x {}) * ({} x {})",
322 self.nrows, self.ncols, other.nrows, other.ncols
323 );
324
325 let m = self.nrows;
326 let n = other.ncols;
327
328 let row_results: Vec<Vec<(usize, T)>> = (0..m)
329 .into_par_iter()
330 .map(|i| {
331 let mut acc = SpaAcc::new(n);
332 for a_idx in self.indptr[i]..self.indptr[i + 1] {
333 let k = self.indices[a_idx];
334 let a_val = self.data[a_idx];
335 for b_idx in other.indptr[k]..other.indptr[k + 1] {
336 unsafe {
337 acc.scatter(other.indices[b_idx], a_val * other.data[b_idx]);
338 }
339 }
340 }
341 acc.gather_sorted()
342 })
343 .collect();
344
345 let total_nnz: usize = row_results.iter().map(|r| r.len()).sum();
346 let mut data = Vec::with_capacity(total_nnz);
347 let mut indices = Vec::with_capacity(total_nnz);
348 let mut indptr = Vec::with_capacity(m + 1);
349 indptr.push(0);
350
351 for row in row_results {
352 for (col, val) in row {
353 indices.push(col);
354 data.push(val);
355 }
356 indptr.push(data.len());
357 }
358
359 Self {
360 indptr,
361 indices,
362 data,
363 nrows: m,
364 ncols: n,
365 }
366 }
367
368 pub fn elementwise_mul(&self, other: &Csr<T>) -> Self {
380 assert_eq!(
381 (self.nrows, self.ncols),
382 (other.nrows, other.ncols),
383 "Shape mismatch for element-wise multiply"
384 );
385
386 let mut indptr = vec![0usize; self.nrows + 1];
387 let mut indices = Vec::new();
388 let mut data = Vec::new();
389
390 for i in 0..self.nrows {
391 let (mut p, end_p) = (self.indptr[i], self.indptr[i + 1]);
392 let (mut q, end_q) = (other.indptr[i], other.indptr[i + 1]);
393 while p < end_p && q < end_q {
394 let ci = self.indices[p];
395 let cj = other.indices[q];
396 match ci.cmp(&cj) {
397 std::cmp::Ordering::Equal => {
398 indices.push(ci);
399 data.push(self.data[p] * other.data[q]);
400 p += 1;
401 q += 1;
402 }
403 std::cmp::Ordering::Less => p += 1,
404 std::cmp::Ordering::Greater => q += 1,
405 }
406 }
407 indptr[i + 1] = data.len();
408 }
409
410 Self {
411 indptr,
412 indices,
413 data,
414 nrows: self.nrows,
415 ncols: self.ncols,
416 }
417 }
418
419 pub fn normalise_cols_l2(&self) -> Self {
428 let mut col_sq = vec![T::zero(); self.ncols];
429 for (idx, &v) in self.data.iter().enumerate() {
430 let c = self.indices[idx];
431 col_sq[c] += v * v;
432 }
433
434 let col_inv: Vec<T> = col_sq
435 .iter()
436 .map(|&sq| {
437 let norm = sq.sqrt();
438 if norm > T::zero() {
439 T::one() / norm
440 } else {
441 T::one()
442 }
443 })
444 .collect();
445
446 let new_data: Vec<T> = self
447 .data
448 .iter()
449 .enumerate()
450 .map(|(idx, &v)| v * col_inv[self.indices[idx]])
451 .collect();
452
453 Self {
454 indptr: self.indptr.clone(),
455 indices: self.indices.clone(),
456 data: new_data,
457 nrows: self.nrows,
458 ncols: self.ncols,
459 }
460 }
461
462 pub fn normalise_rows_l1(&self) -> Self {
471 let mut new_data = self.data.clone();
472 for i in 0..self.nrows {
473 let start = self.indptr[i];
474 let end = self.indptr[i + 1];
475 let mut norm = T::zero();
476 for idx in start..end {
477 norm += self.data[idx].abs();
478 }
479 if norm > T::zero() {
480 let inv = T::one() / norm;
481 for idx in start..end {
482 new_data[idx] = new_data[idx] * inv;
483 }
484 }
485 }
486
487 Self {
488 indptr: self.indptr.clone(),
489 indices: self.indices.clone(),
490 data: new_data,
491 nrows: self.nrows,
492 ncols: self.ncols,
493 }
494 }
495
496 pub fn clip_values(&mut self, lo: T, hi: T) {
503 for d in &mut self.data {
504 if *d < lo {
505 *d = lo;
506 } else if *d > hi {
507 *d = hi;
508 }
509 }
510 }
511
512 pub fn to_adjacency_list(&self) -> Vec<Vec<(usize, T)>> {
522 (0..self.nrows)
523 .map(|i| {
524 (self.indptr[i]..self.indptr[i + 1])
525 .map(|idx| (self.indices[idx], self.data[idx]))
526 .collect()
527 })
528 .collect()
529 }
530
531 pub fn matmul_dense(&self, rhs: &MatRef<T>) -> Mat<T> {
543 assert_eq!(
544 self.ncols,
545 rhs.nrows(),
546 "Dimension mismatch: CSR cols {} vs Mat rows {}",
547 self.ncols,
548 rhs.nrows()
549 );
550
551 let d = rhs.ncols();
552 let rows: Vec<Vec<T>> = (0..self.nrows)
553 .into_par_iter()
554 .map(|i| {
555 let mut row = vec![T::zero(); d];
556 for idx in self.indptr[i]..self.indptr[i + 1] {
557 let j = self.indices[idx];
558 let v = self.data[idx];
559 for k in 0..d {
560 row[k] += v * rhs[(j, k)];
561 }
562 }
563 row
564 })
565 .collect();
566
567 Mat::from_fn(self.nrows, d, |i, j| rows[i][j])
568 }
569
570 pub fn to_dense(&self) -> Mat<T> {
576 let mut dense = Mat::zeros(self.nrows, self.ncols);
577 for i in 0..self.nrows {
578 for idx in self.indptr[i]..self.indptr[i + 1] {
579 dense[(i, self.indices[idx])] = self.data[idx];
580 }
581 }
582 dense
583 }
584}
585
586pub fn vecs_to_mat<T: EvocFloat>(rows: &[Vec<T>]) -> Mat<T> {
603 let n = rows.len();
604 if n == 0 {
605 return Mat::zeros(0, 0);
606 }
607 let d = rows[0].len();
608 Mat::from_fn(n, d, |i, j| rows[i][j])
609}
610
611pub fn mat_to_vecs<T: EvocFloat>(mat: &Mat<T>) -> Vec<Vec<T>> {
621 (0..mat.nrows())
622 .map(|i| (0..mat.ncols()).map(|j| mat[(i, j)]).collect())
623 .collect()
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3x3 fixture:
    //
    //     [1 0 2]
    //     [0 3 0]
    //     [4 0 5]
    fn make_3x3() -> Csr<f64> {
        Csr::new(
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            3,
            3,
        )
    }

    // Absolute-tolerance float comparison used by the assertions below.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    // Triplets already in row-major order must come out unchanged.
    #[test]
    fn from_coo_basic() {
        let coo = CoordinateList {
            row_indices: vec![0, 0, 1, 2, 2],
            col_indices: vec![0, 2, 1, 0, 2],
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0],
            n_samples: 3,
        };
        let csr = Csr::from_coo(&coo);
        assert_eq!(csr.nrows, 3);
        assert_eq!(csr.ncols, 3);
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.indptr, vec![0, 2, 3, 5]);
        assert_eq!(csr.indices, vec![0, 2, 1, 0, 2]);
        assert_eq!(csr.data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // Two entries at (0, 1) collapse into one stored value 1 + 3 = 4.
    #[test]
    fn from_coo_duplicates_summed() {
        let coo = CoordinateList {
            row_indices: vec![0, 0, 0],
            col_indices: vec![1, 1, 2],
            values: vec![1.0, 3.0, 5.0],
            n_samples: 2,
        };
        let csr = Csr::from_coo(&coo);
        // Row 0 holds both surviving entries; row 1 is empty.
        assert_eq!(csr.indptr, vec![0, 2, 2]);
        assert_eq!(csr.indices, vec![1, 2]);
        assert!(approx_eq(csr.data[0], 4.0));
        assert!(approx_eq(csr.data[1], 5.0));
    }

    // No triplets: an all-empty matrix with `n_samples + 1` zero offsets.
    #[test]
    fn from_coo_empty() {
        let coo: CoordinateList<f64> = CoordinateList {
            row_indices: Vec::new(),
            col_indices: Vec::new(),
            values: Vec::new(),
            n_samples: 5,
        };
        let csr = Csr::from_coo(&coo);
        assert_eq!(csr.nrows, 5);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.indptr, vec![0, 0, 0, 0, 0, 0]);
    }

    // Four elements assigned to three parts: one unit entry per row.
    #[test]
    fn from_partition_basic() {
        let part = vec![2, 0, 1, 2];
        let csr = Csr::<f64>::from_partition(&part, 3);
        assert_eq!(csr.nrows, 4);
        assert_eq!(csr.ncols, 3);
        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.indices, vec![2, 0, 1, 2]);
        assert!(csr.data.iter().all(|&v| approx_eq(v, 1.0)));
    }

    // Transposing twice must reproduce the original arrays exactly.
    #[test]
    fn transpose_roundtrip() {
        let a = make_3x3();
        let at = a.transpose();
        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 3);
        assert_eq!(at.nnz(), 5);

        // Row 0 of the transpose is column 0 of the fixture: [1, 0, 4].
        let row0: Vec<(usize, f64)> = (at.indptr[0]..at.indptr[1])
            .map(|idx| (at.indices[idx], at.data[idx]))
            .collect();
        assert_eq!(row0, vec![(0, 1.0), (2, 4.0)]);

        let att = at.transpose();
        assert_eq!(att.indptr, a.indptr);
        assert_eq!(att.indices, a.indices);
        assert_eq!(att.data, a.data);
    }

    // 2x3 input [[1, 2, 0], [0, 0, 3]] becomes 3x2 [[1, 0], [2, 0], [0, 3]].
    #[test]
    fn transpose_non_square() {
        let m = Csr::new(vec![0, 2, 3], vec![0, 1, 2], vec![1.0, 2.0, 3.0], 2, 3);
        let mt = m.transpose();
        assert_eq!(mt.nrows, 3);
        assert_eq!(mt.ncols, 2);
        assert_eq!(mt.indptr, vec![0, 1, 2, 3]);
        assert_eq!(mt.indices, vec![0, 0, 1]);
        assert_eq!(mt.data, vec![1.0, 2.0, 3.0]);
    }

    // Multiplying by the identity leaves values and pattern unchanged.
    #[test]
    fn matmul_identity() {
        let a = make_3x3();
        let eye = Csr::new(vec![0, 1, 2, 3], vec![0, 1, 2], vec![1.0, 1.0, 1.0], 3, 3);
        let result = a.matmul(&eye);
        assert_eq!(result.data, a.data);
        assert_eq!(result.indices, a.indices);
    }

    // A * A^T for the fixture is [[5, 0, 14], [0, 9, 0], [14, 0, 41]].
    #[test]
    fn matmul_a_times_at() {
        let a = make_3x3();
        let at = a.transpose();
        let aat = a.matmul(&at);
        let dense = aat.to_dense();

        assert!(approx_eq(dense[(0, 0)], 5.0));
        assert!(approx_eq(dense[(0, 1)], 0.0));
        assert!(approx_eq(dense[(0, 2)], 14.0));
        assert!(approx_eq(dense[(1, 1)], 9.0));
        assert!(approx_eq(dense[(2, 0)], 14.0));
        assert!(approx_eq(dense[(2, 2)], 41.0));
    }

    // (2x3) * (3x2): [[1, 2, 0], [0, 0, 3]] * [[4, 0], [0, 5], [6, 0]]
    // = [[4, 10], [18, 0]].
    #[test]
    fn matmul_non_square() {
        let a = Csr::new(vec![0, 2, 3], vec![0, 1, 2], vec![1.0, 2.0, 3.0], 2, 3);
        let b = Csr::new(vec![0, 1, 2, 3], vec![0, 1, 0], vec![4.0, 5.0, 6.0], 3, 2);
        let c = a.matmul(&b);
        assert_eq!(c.nrows, 2);
        assert_eq!(c.ncols, 2);
        let dense = c.to_dense();
        assert!(approx_eq(dense[(0, 0)], 4.0));
        assert!(approx_eq(dense[(0, 1)], 10.0));
        assert!(approx_eq(dense[(1, 0)], 18.0));
        assert!(approx_eq(dense[(1, 1)], 0.0));
    }

    // Fixture times the dense matrix [[1, 0], [0, 1], [1, 1]]
    // = [[3, 2], [0, 3], [9, 5]].
    #[test]
    fn matmul_dense_basic() {
        let a = make_3x3();
        let rhs = Mat::from_fn(3, 2, |i, j| match (i, j) {
            (0, 0) | (1, 1) | (2, 0) | (2, 1) => 1.0_f64,
            _ => 0.0,
        });
        let result = a.matmul_dense(&rhs.as_ref());
        assert!(approx_eq(result[(0, 0)], 3.0));
        assert!(approx_eq(result[(0, 1)], 2.0));
        assert!(approx_eq(result[(1, 0)], 0.0));
        assert!(approx_eq(result[(1, 1)], 3.0));
        assert!(approx_eq(result[(2, 0)], 9.0));
        assert!(approx_eq(result[(2, 1)], 5.0));
    }

    // The fixture and its transpose share the same pattern, so all five
    // positions survive: [[1, 0, 8], [0, 9, 0], [8, 0, 25]].
    #[test]
    fn elementwise_mul_with_transpose() {
        let a = make_3x3();
        let at = a.transpose();
        let h = a.elementwise_mul(&at);
        let dense = h.to_dense();

        assert!(approx_eq(dense[(0, 0)], 1.0));
        assert!(approx_eq(dense[(0, 2)], 8.0));
        assert!(approx_eq(dense[(1, 1)], 9.0));
        assert!(approx_eq(dense[(2, 0)], 8.0));
        assert!(approx_eq(dense[(2, 2)], 25.0));
        assert_eq!(h.nnz(), 5);
    }

    // Operands with no stored position in common give an empty product.
    #[test]
    fn elementwise_mul_disjoint() {
        let a = Csr::new(vec![0, 1, 1], vec![0], vec![5.0], 2, 2);
        let b = Csr::new(vec![0, 0, 1], vec![1], vec![7.0], 2, 2);
        let h = a.elementwise_mul(&b);
        assert_eq!(h.nnz(), 0);
    }

    // 3x2 input [[1, 0], [0, 2], [3, 4]]: column 0 holds {1, 3} and
    // column 1 holds {2, 4}.
    #[test]
    fn normalise_cols_l2_unit_norms() {
        let m = Csr::new(
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0, 2.0, 3.0, 4.0],
            3,
            2,
        );
        let normed = m.normalise_cols_l2();

        // Every column of the result must have unit Euclidean norm.
        let mut col_sq = [0.0f64; 2];
        for (idx, &v) in normed.data.iter().enumerate() {
            col_sq[normed.indices[idx]] += v * v;
        }
        assert!(approx_eq(col_sq[0].sqrt(), 1.0));
        assert!(approx_eq(col_sq[1].sqrt(), 1.0));

        // Each value is the original divided by its column's norm.
        let c0_norm = (1.0f64 + 9.0).sqrt();
        let c1_norm = (4.0f64 + 16.0).sqrt();
        assert!(approx_eq(normed.data[0], 1.0 / c0_norm));
        assert!(approx_eq(normed.data[1], 2.0 / c1_norm));
        assert!(approx_eq(normed.data[2], 3.0 / c0_norm));
        assert!(approx_eq(normed.data[3], 4.0 / c1_norm));
    }

    // Column 1 has no entries; the lone entry in column 0 normalises to 1.
    #[test]
    fn normalise_cols_l2_empty_column() {
        let m = Csr::new(vec![0, 1, 1], vec![0], vec![3.0], 2, 2);
        let normed = m.normalise_cols_l2();
        assert!(approx_eq(normed.data[0], 1.0));
    }

    // Same 3x2 input as the column test; row sums are 1, 2 and 7.
    #[test]
    fn normalise_rows_l1_unit_sums() {
        let m = Csr::new(
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0, 2.0, 3.0, 4.0],
            3,
            2,
        );
        let normed = m.normalise_rows_l1();

        // Absolute values in every row must sum to one.
        for i in 0..normed.nrows {
            let sum: f64 = normed.data[normed.indptr[i]..normed.indptr[i + 1]]
                .iter()
                .map(|v: &f64| v.abs())
                .sum();
            assert!(approx_eq(sum, 1.0));
        }

        assert!(approx_eq(normed.data[0], 1.0));
        assert!(approx_eq(normed.data[1], 1.0));
        assert!(approx_eq(normed.data[2], 3.0 / 7.0));
        assert!(approx_eq(normed.data[3], 4.0 / 7.0));
    }

    // Row 0 is empty and must be skipped; row 1's single entry becomes 1.
    #[test]
    fn normalise_rows_l1_empty_row() {
        let m = Csr::new(vec![0, 0, 1], vec![0], vec![5.0], 2, 2);
        let normed = m.normalise_rows_l1();
        assert!(approx_eq(normed.data[0], 1.0));
    }

    // One value below the range, one inside it, one above it.
    #[test]
    fn clip_values_basic() {
        let mut m = Csr::new(vec![0, 3], vec![0, 1, 2], vec![-1.0, 0.5, 2.0], 1, 3);
        m.clip_values(0.0, 1.0);
        assert!(approx_eq(m.data[0], 0.0));
        assert!(approx_eq(m.data[1], 0.5));
        assert!(approx_eq(m.data[2], 1.0));
    }

    // Each row of the fixture as a list of (column, value) pairs.
    #[test]
    fn to_adjacency_list_roundtrip() {
        let a = make_3x3();
        let adj = a.to_adjacency_list();
        assert_eq!(adj.len(), 3);
        assert_eq!(adj[0], vec![(0, 1.0), (2, 2.0)]);
        assert_eq!(adj[1], vec![(1, 3.0)]);
        assert_eq!(adj[2], vec![(0, 4.0), (2, 5.0)]);
    }

    // Dense expansion of the fixture, including one unstored position.
    #[test]
    fn to_dense_roundtrip() {
        let a = make_3x3();
        let d = a.to_dense();
        assert!(approx_eq(d[(0, 0)], 1.0));
        assert!(approx_eq(d[(0, 1)], 0.0));
        assert!(approx_eq(d[(0, 2)], 2.0));
        assert!(approx_eq(d[(1, 1)], 3.0));
        assert!(approx_eq(d[(2, 0)], 4.0));
        assert!(approx_eq(d[(2, 2)], 5.0));
    }

    // Converting rows to a matrix and back must give the same rows.
    #[test]
    fn vecs_mat_roundtrip() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mat = vecs_to_mat(&rows);
        let back = mat_to_vecs(&mat);
        assert_eq!(rows, back);
    }
}