1use rand::Rng;
12use serde::{Deserialize, Serialize};
13
/// Hard bound on the absolute value of any weight: `Matrix::scale_add`
/// clamps every entry into `[-WEIGHT_CLIP, WEIGHT_CLIP]` after each update.
pub const WEIGHT_CLIP: f64 = 5.0;

/// Bound intended for gradient clipping (see `clip_vec`'s `max_abs` argument).
/// NOTE(review): not referenced in this file — confirm call sites elsewhere.
pub const GRAD_CLIP: f64 = 5.0;

/// Dense, heap-allocated matrix of `f64` values in row-major order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Matrix {
    /// Flat row-major storage; element (r, c) lives at index `r * cols + c`.
    pub data: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
42
43impl Matrix {
44 pub fn zeros(rows: usize, cols: usize) -> Self {
55 Self {
56 data: vec![0.0; rows * cols],
57 rows,
58 cols,
59 }
60 }
61
62 pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
77 let limit = (6.0 / (rows + cols) as f64).sqrt();
78 let data: Vec<f64> = (0..rows * cols)
79 .map(|_| rng.gen_range(-limit..limit))
80 .collect();
81 Self { data, rows, cols }
82 }
83
84 pub fn get(&self, row: usize, col: usize) -> f64 {
93 assert!(
94 row < self.rows && col < self.cols,
95 "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
96 self.rows,
97 self.cols
98 );
99 self.data[row * self.cols + col]
100 }
101
102 pub fn set(&mut self, row: usize, col: usize, val: f64) {
112 assert!(
113 row < self.rows && col < self.cols,
114 "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
115 self.rows,
116 self.cols
117 );
118 self.data[row * self.cols + col] = val;
119 }
120
121 pub fn transpose(&self) -> Self {
127 let mut result = Matrix::zeros(self.cols, self.rows);
128 for r in 0..self.rows {
129 for c in 0..self.cols {
130 result.set(c, r, self.get(r, c));
131 }
132 }
133 result
134 }
135
136 pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
150 assert_eq!(
151 v.len(),
152 self.cols,
153 "dimension mismatch: vector length {} != matrix cols {}",
154 v.len(),
155 self.cols
156 );
157 (0..self.rows)
158 .map(|r| {
159 let row_start = r * self.cols;
160 self.data[row_start..row_start + self.cols]
161 .iter()
162 .zip(v.iter())
163 .map(|(a, b)| a * b)
164 .sum()
165 })
166 .collect()
167 }
168
169 pub fn outer(a: &[f64], b: &[f64]) -> Self {
181 if a.is_empty() || b.is_empty() {
182 return Matrix::zeros(0, 0);
183 }
184 let rows = a.len();
185 let cols = b.len();
186 let mut data = vec![0.0; rows * cols];
187 for r in 0..rows {
188 for c in 0..cols {
189 data[r * cols + c] = a[r] * b[c];
190 }
191 }
192 Self { data, rows, cols }
193 }
194
195 pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
206 assert!(
207 self.rows == other.rows && self.cols == other.cols,
208 "dimension mismatch in scale_add: ({},{}) vs ({},{})",
209 self.rows,
210 self.cols,
211 other.rows,
212 other.cols
213 );
214 for i in 0..self.data.len() {
215 self.data[i] += scale * other.data[i];
216 self.data[i] = self.data[i].clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
217 }
218 }
219}
220
/// Numerically-stable softmax restricted to the indices in `mask`.
///
/// Unmasked positions come back as 0.0; an empty mask yields an all-zero
/// vector. Panics if any mask index is out of range for `logits`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    let mut probs = vec![0.0; logits.len()];
    if mask.is_empty() {
        return probs;
    }
    assert!(
        mask.iter().all(|&i| i < logits.len()),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        logits.len()
    );

    // Subtract the masked max before exponentiating to avoid overflow.
    let max_val = mask
        .iter()
        .map(|&i| logits[i])
        .fold(f64::NEG_INFINITY, f64::max);

    let mut total = 0.0;
    for &i in mask {
        let e = (logits[i] - max_val).exp();
        probs[i] = e;
        total += e;
    }
    if total > 0.0 {
        mask.iter().for_each(|&i| probs[i] /= total);
    }
    probs
}
263
/// Returns the index (from `mask`) whose value is largest; ties go to the
/// earliest mask entry. Panics on an empty or out-of-range mask.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );
    let mut best = mask[0];
    // Strict `>` keeps the first occurrence on ties.
    for &i in mask.iter().skip(1) {
        if values[i] > values[best] {
            best = i;
        }
    }
    best
}
292
/// Root-mean-square over every element of every slice in `error_vecs`;
/// returns 0.0 when there are no elements at all.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let count: usize = error_vecs.iter().map(|v| v.len()).sum();
    if count == 0 {
        return 0.0;
    }
    let sum_sq: f64 = error_vecs
        .iter()
        .flat_map(|v| v.iter())
        .map(|&e| e * e)
        .sum();
    (sum_sq / count as f64).sqrt()
}
316
317pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
332 assert!(!mask.is_empty(), "sample_from_probs: empty mask");
333
334 if mask.len() == 1 {
335 return mask[0];
336 }
337
338 let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
339 if sum <= 0.0 {
340 return mask[rng.gen_range(0..mask.len())];
342 }
343
344 let threshold: f64 = rng.gen_range(0.0..1.0);
345 let mut cumulative = 0.0;
346 for &i in mask {
347 cumulative += probs[i] / sum;
348 if cumulative >= threshold {
349 return i;
350 }
351 }
352
353 *mask.last().unwrap()
355}
356
/// Clamps every element of `v` into `[-max_abs, max_abs]` in place.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(-max_abs, max_abs));
}
368
/// Element-wise `a - b`; panics when the slices differ in length.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_sub: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for i in 0..a.len() {
        out.push(a[i] - b[i]);
    }
    out
}
389
/// Element-wise `a + b`; panics when the slices differ in length.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_add: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    a.iter().zip(b).map(|(&x, &y)| x + y).collect()
}
410
/// Returns a copy of `v` with every element multiplied by `s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        out.push(x * s);
    }
    out
}
424
425pub fn cca_neuron_alignment<L: crate::linalg::LinAlg>(
441 backend: &L,
442 act_a: &L::Matrix,
443 act_b: &L::Matrix,
444) -> Result<Vec<usize>, crate::error::PcError> {
445 let batch_size = backend.mat_rows(act_a);
446 let n_a = backend.mat_cols(act_a);
447 let n_b = backend.mat_cols(act_b);
448 let k = n_a.min(n_b);
449
450 if k == 0 || batch_size < 2 {
451 return Ok((0..k).collect());
452 }
453
454 let std_a = standardize_columns(backend, act_a);
456 let std_b = standardize_columns(backend, act_b);
457
458 let scale = 1.0 / (batch_size as f64 - 1.0);
459
460 let std_a_t = backend.mat_transpose(&std_a);
462 let std_b_t = backend.mat_transpose(&std_b);
463
464 let mut c_a = backend.mat_mul(&std_a_t, &std_a); let mut c_b = backend.mat_mul(&std_b_t, &std_b); let mut c_ab = backend.mat_mul(&std_a_t, &std_b); scale_matrix(backend, &mut c_a, n_a, n_a, scale);
470 scale_matrix(backend, &mut c_b, n_b, n_b, scale);
471 scale_matrix(backend, &mut c_ab, n_a, n_b, scale);
472
473 let c_a_inv_sqrt = mat_inv_sqrt(backend, &c_a)?;
475 let c_b_inv_sqrt = mat_inv_sqrt(backend, &c_b)?;
476
477 let temp = backend.mat_mul(&c_a_inv_sqrt, &c_ab);
479 let m = backend.mat_mul(&temp, &c_b_inv_sqrt);
480
481 let (u, s, v) = backend.svd(&m)?;
483
484 let n_canonical = backend.mat_cols(&u).min(backend.mat_cols(&v));
486
487 let mut cost = vec![vec![0.0; n_a]; n_b];
490 for (b, cost_row) in cost.iter_mut().enumerate() {
491 for (a, cost_cell) in cost_row.iter_mut().enumerate() {
492 let mut sim = 0.0;
493 for kk in 0..n_canonical {
494 let sk = backend.vec_get(&s, kk);
495 sim += sk * backend.mat_get(&u, a, kk).abs() * backend.mat_get(&v, b, kk).abs();
496 }
497 *cost_cell = -sim; }
499 }
500
501 let assignment = hungarian_assignment(&cost);
503
504 let k = n_a.min(n_b);
506 let mut perm = vec![0usize; k];
507 for (b, &a) in assignment.iter().enumerate().take(k) {
508 perm[b] = a;
509 }
510
511 Ok(perm)
512}
513
514fn scale_matrix<L: crate::linalg::LinAlg>(
516 backend: &L,
517 m: &mut L::Matrix,
518 rows: usize,
519 cols: usize,
520 s: f64,
521) {
522 for r in 0..rows {
523 for c in 0..cols {
524 let val = backend.mat_get(m, r, c);
525 backend.mat_set(m, r, c, val * s);
526 }
527 }
528}
529
530fn standardize_columns<L: crate::linalg::LinAlg>(backend: &L, m: &L::Matrix) -> L::Matrix {
533 let rows = backend.mat_rows(m);
534 let cols = backend.mat_cols(m);
535 let mut result = backend.zeros_mat(rows, cols);
536 let eps = 1e-12;
537
538 for c in 0..cols {
539 let mut sum = 0.0;
541 for r in 0..rows {
542 sum += backend.mat_get(m, r, c);
543 }
544 let mean = sum / rows as f64;
545
546 let mut var_sum = 0.0;
548 for r in 0..rows {
549 let diff = backend.mat_get(m, r, c) - mean;
550 var_sum += diff * diff;
551 }
552 let std = (var_sum / (rows as f64 - 1.0)).sqrt();
553
554 if std > eps {
555 for r in 0..rows {
556 backend.mat_set(&mut result, r, c, (backend.mat_get(m, r, c) - mean) / std);
557 }
558 }
559 }
561 result
562}
563
564fn mat_inv_sqrt<L: crate::linalg::LinAlg>(
567 backend: &L,
568 m: &L::Matrix,
569) -> Result<L::Matrix, crate::error::PcError> {
570 let n = backend.mat_rows(m);
571 let (u, s, _v) = backend.svd(m)?;
572 let eps = 1e-10;
573
574 let k = backend.vec_len(&s);
576 let mut diag_inv_sqrt = backend.zeros_mat(k, k);
577 for i in 0..k {
578 let si = backend.vec_get(&s, i);
579 if si > eps {
580 backend.mat_set(&mut diag_inv_sqrt, i, i, 1.0 / si.sqrt());
581 }
582 }
583
584 let temp = backend.mat_mul(&u, &diag_inv_sqrt);
587 let ut = backend.mat_transpose(&u);
588 let mut result = backend.mat_mul(&temp, &ut);
589
590 if backend.mat_rows(&result) != n || backend.mat_cols(&result) != n {
592 let mut padded = backend.zeros_mat(n, n);
593 let r_rows = backend.mat_rows(&result);
594 let r_cols = backend.mat_cols(&result);
595 for r in 0..r_rows.min(n) {
596 for c in 0..r_cols.min(n) {
597 backend.mat_set(&mut padded, r, c, backend.mat_get(&result, r, c));
598 }
599 }
600 result = padded;
601 }
602
603 Ok(result)
604}
605
/// Solves the assignment problem for a (possibly rectangular) cost matrix,
/// returning for each row the column index of its minimum-total-cost match.
///
/// Shortest-augmenting-path Hungarian algorithm with dual potentials
/// (O(n^3)). The matrix is padded with zero-cost entries to square size
/// `n = max(rows, cols)` and indexed from 1 so index 0 is a sentinel.
pub(crate) fn hungarian_assignment(cost: &[Vec<f64>]) -> Vec<usize> {
    let n_rows = cost.len();
    if n_rows == 0 {
        return vec![];
    }
    let n_cols = cost[0].len();
    let n = n_rows.max(n_cols);

    // 1-based padded square cost matrix; padding cells keep cost 0.0.
    let mut c = vec![vec![0.0; n + 1]; n + 1]; for (i, row) in cost.iter().enumerate() {
        for (j, &val) in row.iter().enumerate() {
            c[i + 1][j + 1] = val;
        }
    }

    // u, v: dual potentials for rows/columns; p[j]: row currently matched to
    // column j (0 = free); way[j]: predecessor column on the augmenting path.
    let mut u = vec![0.0; n + 1];
    let mut v = vec![0.0; n + 1];
    let mut p = vec![0usize; n + 1];
    let mut way = vec![0usize; n + 1];

    for i in 1..=n {
        // Grow an alternating tree from row i until a free column is reached.
        p[0] = i;
        let mut j0 = 0usize; let mut min_v = vec![f64::MAX; n + 1];
        let mut used = vec![false; n + 1];

        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = f64::MAX;
            let mut j1 = 0usize;

            // Track the minimum reduced cost to each unused column.
            for j in 1..=n {
                if !used[j] {
                    let cur = c[i0][j] - u[i0] - v[j];
                    if cur < min_v[j] {
                        min_v[j] = cur;
                        way[j] = j0;
                    }
                    if min_v[j] < delta {
                        delta = min_v[j];
                        j1 = j;
                    }
                }
            }

            // Shift the potentials so at least one new edge becomes tight.
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    min_v[j] -= delta;
                }
            }

            j0 = j1;

            // Reached a free column: the augmenting path is complete.
            if p[j0] == 0 {
                break; }
        }

        // Walk the path backwards, flipping the matches along it.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }

    // Convert the 1-based column->row matching into 0-based row->column,
    // ignoring padding rows.
    let mut result = vec![0usize; n_rows];
    for j in 1..=n {
        if p[j] >= 1 && p[j] <= n_rows {
            result[p[j] - 1] = j - 1;
        }
    }
    result
}
704
/// Greedy (non-optimal) neuron matcher kept as a reference alternative to
/// `hungarian_assignment`: for each canonical component, pairs the unmatched
/// A-neuron and B-neuron with the largest |loading| on that component.
/// Currently unused (see `#[allow(dead_code)]`).
#[allow(dead_code)]
fn greedy_match<L: crate::linalg::LinAlg>(
    backend: &L,
    u: &L::Matrix,
    v: &L::Matrix,
    n_a: usize,
    n_b: usize,
) -> Vec<usize> {
    let k = n_a.min(n_b);
    let n_canonical = backend.mat_cols(u).min(backend.mat_cols(v));

    let mut matched_a = vec![false; n_a];
    let mut matched_b = vec![false; n_b];
    let mut perm = vec![0usize; k];
    let mut assigned = vec![false; k];

    for col in 0..n_canonical {
        // Unmatched A-neuron with the largest absolute loading on this column.
        let mut best_a = 0;
        let mut best_a_val = 0.0_f64;
        for (i, &is_matched) in matched_a
            .iter()
            .enumerate()
            .take(n_a.min(backend.mat_rows(u)))
        {
            let val = backend.mat_get(u, i, col).abs();
            if val > best_a_val && !is_matched {
                best_a_val = val;
                best_a = i;
            }
        }

        // Same selection on the B side.
        let mut best_b = 0;
        let mut best_b_val = 0.0_f64;
        for (i, &is_matched) in matched_b
            .iter()
            .enumerate()
            .take(n_b.min(backend.mat_rows(v)))
        {
            let val = backend.mat_get(v, i, col).abs();
            if val > best_b_val && !is_matched {
                best_b_val = val;
                best_b = i;
            }
        }

        // best_a/best_b default to index 0, which may already be matched (or
        // out of perm range) — re-check before committing the pair.
        if !matched_a[best_a] && !matched_b[best_b] && best_b < k {
            perm[best_b] = best_a;
            assigned[best_b] = true;
            matched_a[best_a] = true;
            matched_b[best_b] = true;
        }
    }

    // Fill any B slots the greedy pass left unassigned with the leftover
    // A-neurons in index order.
    let remaining_a: Vec<usize> = (0..n_a).filter(|i| !matched_a[*i]).collect();
    let unassigned_b: Vec<usize> = (0..k).filter(|i| !assigned[*i]).collect();
    for (idx, &b_idx) in unassigned_b.iter().enumerate() {
        if idx < remaining_a.len() {
            perm[b_idx] = remaining_a[idx];
        }
    }

    perm
}
774
// Unit tests for the matrix/vector helpers, masked operations, and the
// CCA + Hungarian neuron-alignment pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // --- Matrix constructors ---

    #[test]
    fn test_zeros_all_zero_correct_dims() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 4);
        assert_eq!(m.data.len(), 12);
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_xavier_variance_approx() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(100, 100, &mut rng);
        let n = m.data.len() as f64;
        let mean = m.data.iter().sum::<f64>() / n;
        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        // Xavier-uniform targets variance 2 / (fan_in + fan_out).
        let expected_var = 2.0 / (100.0 + 100.0); assert!(
            (variance - expected_var).abs() < expected_var * 0.5,
            "variance {} not within 50% of expected {}",
            variance,
            expected_var
        );
    }

    #[test]
    fn test_xavier_all_finite() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(50, 50, &mut rng);
        assert!(m.data.iter().all(|x| x.is_finite()));
    }

    // --- get / set ---

    #[test]
    fn test_get_set_roundtrip() {
        let mut m = Matrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
    }

    #[test]
    fn test_get_zero_default() {
        let m = Matrix::zeros(2, 2);
        assert_eq!(m.get(0, 0), 0.0);
    }

    // --- transpose ---

    #[test]
    fn test_transpose_swaps_dims() {
        let m = Matrix::zeros(3, 5);
        let t = m.transpose();
        assert_eq!(t.rows, 5);
        assert_eq!(t.cols, 3);
    }

    #[test]
    fn test_transpose_repositions_values() {
        let mut m = Matrix::zeros(2, 3);
        m.set(0, 1, 7.0);
        m.set(1, 2, 3.0);
        let t = m.transpose();
        assert_eq!(t.get(1, 0), 7.0);
        assert_eq!(t.get(2, 1), 3.0);
    }

    #[test]
    fn test_transpose_double_is_identity() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(3, 5, &mut rng);
        let tt = m.transpose().transpose();
        assert_eq!(m.rows, tt.rows);
        assert_eq!(m.cols, tt.cols);
        for i in 0..m.data.len() {
            assert!((m.data[i] - tt.data[i]).abs() < 1e-15);
        }
    }

    // --- mul_vec ---

    #[test]
    fn test_mul_vec_known_result() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);
        let result = m.mul_vec(&[5.0, 6.0]);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 39.0).abs() < 1e-10);
    }

    #[test]
    fn test_mul_vec_output_length_equals_rows() {
        let m = Matrix::zeros(4, 3);
        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 4);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_mul_vec_panics_wrong_length() {
        let m = Matrix::zeros(2, 3);
        m.mul_vec(&[1.0, 2.0]); }

    #[test]
    fn test_mul_vec_zero_matrix_returns_zeros() {
        let m = Matrix::zeros(3, 2);
        let result = m.mul_vec(&[5.0, 10.0]);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    // --- outer ---

    #[test]
    fn test_outer_dims_and_values() {
        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
        assert!((m.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((m.get(0, 2) - 5.0).abs() < 1e-10);
        assert!((m.get(1, 0) - 6.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 8.0).abs() < 1e-10);
        assert!((m.get(1, 2) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_empty_first_returns_zero_matrix() {
        let m = Matrix::outer(&[], &[1.0, 2.0]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_outer_empty_second_returns_zero_matrix() {
        let m = Matrix::outer(&[1.0, 2.0], &[]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    // --- scale_add ---

    #[test]
    fn test_scale_add_basic() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 2.0);
        let mut other = Matrix::zeros(2, 2);
        other.set(0, 0, 0.5);
        other.set(1, 1, 0.5);
        m.scale_add(&other, 2.0);
        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_clips_to_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, 10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - WEIGHT_CLIP).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_negative_clips_to_neg_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, -4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, -10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - (-WEIGHT_CLIP)).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_zero_scale_only_clips() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 3.0);
        let other = Matrix::zeros(1, 1);
        m.scale_add(&other, 0.0);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_scale_add_panics_on_dimension_mismatch() {
        let mut m = Matrix::zeros(2, 2);
        let other = Matrix::zeros(3, 3);
        m.scale_add(&other, 1.0);
    }

    // --- softmax_masked ---

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 1, 2, 3];
        let probs = softmax_masked(&logits, &mask);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[1] > 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_softmax_masked_single_index_is_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let mask = vec![1];
        let probs = softmax_masked(&logits, &mask);
        assert!((probs[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_empty_mask_returns_all_zeros() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax_masked(&logits, &[]);
        assert!(probs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_softmax_masked_numerically_stable_large_logits() {
        let logits = vec![1000.0, 1001.0, 1002.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs.iter().all(|p| p.is_finite()));
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_higher_logit_gets_higher_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    // --- argmax_masked ---

    #[test]
    fn test_argmax_masked_returns_highest_in_mask() {
        let values = vec![1.0, 5.0, 3.0, 4.0];
        let mask = vec![0, 2, 3];
        assert_eq!(argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_argmax_masked_single_element() {
        let values = vec![1.0, 5.0, 3.0];
        let mask = vec![2];
        assert_eq!(argmax_masked(&values, &mask), 2);
    }

    #[test]
    fn test_argmax_masked_tie_returns_first() {
        let values = vec![3.0, 3.0, 3.0];
        let mask = vec![0, 1, 2];
        assert_eq!(argmax_masked(&values, &mask), 0);
    }

    #[test]
    #[should_panic]
    fn test_argmax_masked_empty_panics() {
        let values = vec![1.0, 2.0];
        argmax_masked(&values, &[]);
    }

    // --- rms_error ---

    #[test]
    fn test_rms_error_empty_returns_zero() {
        assert_eq!(rms_error(&[]), 0.0);
    }

    #[test]
    fn test_rms_error_single_empty_vec_returns_zero() {
        let empty: &[f64] = &[];
        assert_eq!(rms_error(&[empty]), 0.0);
    }

    #[test]
    fn test_rms_error_known_two_vecs() {
        let v1: &[f64] = &[1.0, 0.0];
        let v2: &[f64] = &[0.0, 1.0];
        let rms = rms_error(&[v1, v2]);
        let expected = (0.5_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_single_vec() {
        let v: &[f64] = &[3.0, 4.0];
        let rms = rms_error(&[v]);
        let expected = (25.0 / 2.0_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_all_zeros_returns_zero() {
        let v: &[f64] = &[0.0, 0.0, 0.0];
        assert_eq!(rms_error(&[v]), 0.0);
    }

    // --- sample_from_probs ---

    #[test]
    fn test_sample_from_probs_always_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_single_action_always_returns_it() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![1];
        for _ in 0..10 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_visits_multiple_actions() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            seen[idx] = true;
        }
        assert!(seen[0] && seen[1], "should visit both actions");
    }

    #[test]
    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.0, 0.0, 0.0];
        let mask = vec![0, 2];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    #[should_panic]
    fn test_sample_from_probs_empty_mask_panics() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        sample_from_probs(&probs, &[], &mut rng);
    }

    // --- vector helpers ---

    #[test]
    fn test_vec_sub_known() {
        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
        assert!((result[0] - 2.0).abs() < 1e-10);
        assert!((result[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_vec_add_known() {
        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
        assert!((result[0] - 4.0).abs() < 1e-10);
        assert!((result[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_vec_scale_known() {
        let result = vec_scale(&[1.0, -2.0], 3.0);
        assert!((result[0] - 3.0).abs() < 1e-10);
        assert!((result[1] - (-6.0)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_vec_clamps_positive() {
        let mut v = vec![10.0, -10.0, 0.5];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 5.0).abs() < 1e-10);
        assert!((v[1] - (-5.0)).abs() < 1e-10);
        assert!((v[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_sub_panics_on_length_mismatch() {
        vec_sub(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_add_panics_on_length_mismatch() {
        vec_add(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = vec![1.0, -1.0, 0.0];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 1.0).abs() < 1e-10);
        assert!((v[1] - (-1.0)).abs() < 1e-10);
        assert!((v[2] - 0.0).abs() < 1e-10);
    }

    // --- bounds-checking panics ---

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_row() {
        let m = Matrix::zeros(2, 2);
        m.get(5, 0); }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_row() {
        let mut m = Matrix::zeros(2, 2);
        m.set(5, 0, 1.0); }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_softmax_masked_panics_on_oob_mask() {
        let logits = vec![1.0, 2.0, 3.0];
        softmax_masked(&logits, &[0, 5]); }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_argmax_masked_panics_on_oob_mask() {
        let values = vec![1.0, 2.0, 3.0];
        argmax_masked(&values, &[0, 5]); }

    // --- CCA neuron alignment ---

    #[test]
    fn test_cca_identical_activations_identity_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }
        let act_b = act_a.clone();

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        assert_eq!(perm, vec![0, 1, 2]);
    }

    #[test]
    fn test_cca_permutation_length_is_min() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 100;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act = backend.zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act, r, c, val);
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act, &act).unwrap();
        assert_eq!(perm.len(), 4);
    }

    #[test]
    fn test_cca_permuted_activations_recovers_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 500;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        // act_b column j is a copy of act_a column col_map[j].
        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
        let col_map = [2, 0, 1]; for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm, vec![2, 0, 1]);
    }

    #[test]
    fn test_cca_permuted_with_small_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 50;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(99);

        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
        let col_map = [1, 2, 0];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm, vec![1, 2, 0]);
    }

    #[test]
    fn test_cca_permuted_large_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 500;
        let n_neurons = 4;
        let mut rng = StdRng::seed_from_u64(7);

        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
        let col_map = [3, 1, 0, 2];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm, vec![3, 1, 0, 2]);
    }

    #[test]
    fn test_cca_a_larger_than_b() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = backend.zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = backend.zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                backend.mat_set(&mut act_b, r, c, backend.mat_get(&act_a, r, c));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_b_larger_than_a() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = backend.zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = backend.zeros_mat(batch_size, 5);
        for r in 0..batch_size {
            for c in 0..5 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_b, r, c, val);
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_dead_neuron_excluded() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                backend.mat_set(&mut act_a, r, c, val);
            }
        }

        // Column 1 of act_b is a dead (all-zero) neuron.
        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            backend.mat_set(&mut act_b, r, 0, backend.mat_get(&act_a, r, 0));
            backend.mat_set(&mut act_b, r, 1, 0.0); backend.mat_set(&mut act_b, r, 2, backend.mat_get(&act_a, r, 2));
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        for &p in &perm {
            assert!(p < n_neurons, "permutation index {p} out of range");
        }
        let mut sorted = perm.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), n_neurons, "permutation has duplicates");
    }

    // --- Hungarian assignment ---

    #[test]
    fn test_hungarian_assignment_basic() {
        let assignment = hungarian_assignment(&[
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ]);
        let total: f64 = assignment
            .iter()
            .enumerate()
            .map(|(i, &j)| [1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0][i * 3 + j])
            .sum();
        assert!(
            (total - 10.0).abs() < 1e-10,
            "Expected total cost 10, got {total}"
        );
    }

    #[test]
    fn test_hungarian_assignment_permuted() {
        let assignment = hungarian_assignment(&[
            vec![5.0, 1.0, 3.0],
            vec![2.0, 8.0, 7.0],
            vec![6.0, 4.0, 1.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2]);
    }

    #[test]
    fn test_hungarian_assignment_4x4() {
        let assignment = hungarian_assignment(&[
            vec![10.0, 5.0, 13.0, 15.0],
            vec![3.0, 9.0, 18.0, 6.0],
            vec![10.0, 7.0, 2.0, 12.0],
            vec![5.0, 11.0, 9.0, 4.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2, 3]);
    }

    #[test]
    fn test_hungarian_assignment_1x1() {
        let assignment = hungarian_assignment(&[vec![42.0]]);
        assert_eq!(assignment, vec![0]);
    }

    #[test]
    fn test_hungarian_optimal_vs_greedy_on_collision_case() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let backend = CpuLinAlg::new();

        let batch_size = 200;
        let n = 8;
        let mut rng = StdRng::seed_from_u64(42);

        // Correlated neurons: shared `base` signal with per-neuron weight+noise.
        let mut act_a = backend.zeros_mat(batch_size, n);
        for r in 0..batch_size {
            let base: f64 = rng.gen_range(-1.0..1.0);
            for c in 0..n {
                let noise: f64 = rng.gen_range(-0.3..0.3);
                let weight = (c as f64 + 1.0) / n as f64;
                backend.mat_set(&mut act_a, r, c, base * weight + noise);
            }
        }

        let true_perm = [5, 3, 7, 1, 6, 0, 4, 2];
        let mut act_b = backend.zeros_mat(batch_size, n);
        for r in 0..batch_size {
            for (j, &src_col) in true_perm.iter().enumerate() {
                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();

        assert_eq!(
            perm,
            true_perm.to_vec(),
            "Hungarian should recover exact permutation for correlated neurons"
        );
    }

    #[test]
    fn test_sample_from_probs_distribution_roughly_correct() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.7, 0.3];
        let mask = vec![0, 1];
        let mut counts = [0usize; 2];
        let n = 1000;
        for _ in 0..n {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            counts[idx] += 1;
        }
        let ratio = counts[0] as f64 / n as f64;
        assert!(
            (ratio - 0.7).abs() < 0.1,
            "Expected ~0.7 for action 0, got {ratio}"
        );
    }
}