1use rand::Rng;
12use serde::{Deserialize, Serialize};
13
/// Hard bound on the absolute value of any weight entry; enforced after every
/// `Matrix::scale_add` update.
pub const WEIGHT_CLIP: f64 = 5.0;

/// Hard bound on the absolute value of gradient components; intended as the
/// `max_abs` argument to `clip_vec` (not referenced in this file — confirm at
/// call sites).
pub const GRAD_CLIP: f64 = 5.0;

/// Dense `f64` matrix in row-major order: element `(r, c)` lives at
/// `data[r * cols + c]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Matrix {
    // Flat row-major storage; invariant: data.len() == rows * cols.
    pub data: Vec<f64>,
    // Number of rows.
    pub rows: usize,
    // Number of columns.
    pub cols: usize,
}
42
43impl Matrix {
44 pub fn zeros(rows: usize, cols: usize) -> Self {
55 Self {
56 data: vec![0.0; rows * cols],
57 rows,
58 cols,
59 }
60 }
61
62 pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
77 let limit = (6.0 / (rows + cols) as f64).sqrt();
78 let data: Vec<f64> = (0..rows * cols)
79 .map(|_| rng.gen_range(-limit..limit))
80 .collect();
81 Self { data, rows, cols }
82 }
83
84 pub fn get(&self, row: usize, col: usize) -> f64 {
93 assert!(
94 row < self.rows && col < self.cols,
95 "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
96 self.rows,
97 self.cols
98 );
99 self.data[row * self.cols + col]
100 }
101
102 pub fn set(&mut self, row: usize, col: usize, val: f64) {
112 assert!(
113 row < self.rows && col < self.cols,
114 "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
115 self.rows,
116 self.cols
117 );
118 self.data[row * self.cols + col] = val;
119 }
120
121 pub fn transpose(&self) -> Self {
127 let mut result = Matrix::zeros(self.cols, self.rows);
128 for r in 0..self.rows {
129 for c in 0..self.cols {
130 result.set(c, r, self.get(r, c));
131 }
132 }
133 result
134 }
135
136 pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
150 assert_eq!(
151 v.len(),
152 self.cols,
153 "dimension mismatch: vector length {} != matrix cols {}",
154 v.len(),
155 self.cols
156 );
157 (0..self.rows)
158 .map(|r| {
159 let row_start = r * self.cols;
160 self.data[row_start..row_start + self.cols]
161 .iter()
162 .zip(v.iter())
163 .map(|(a, b)| a * b)
164 .sum()
165 })
166 .collect()
167 }
168
169 pub fn outer(a: &[f64], b: &[f64]) -> Self {
181 if a.is_empty() || b.is_empty() {
182 return Matrix::zeros(0, 0);
183 }
184 let rows = a.len();
185 let cols = b.len();
186 let mut data = vec![0.0; rows * cols];
187 for r in 0..rows {
188 for c in 0..cols {
189 data[r * cols + c] = a[r] * b[c];
190 }
191 }
192 Self { data, rows, cols }
193 }
194
195 pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
206 assert!(
207 self.rows == other.rows && self.cols == other.cols,
208 "dimension mismatch in scale_add: ({},{}) vs ({},{})",
209 self.rows,
210 self.cols,
211 other.rows,
212 other.cols
213 );
214 for i in 0..self.data.len() {
215 self.data[i] += scale * other.data[i];
216 self.data[i] = self.data[i].clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
217 }
218 }
219}
220
/// Numerically stable softmax restricted to the indices in `mask`.
///
/// Entries outside the mask come back as 0.0; an empty mask yields an
/// all-zero vector. Panics if any mask index is out of range for `logits`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    let mut out = vec![0.0; logits.len()];
    if mask.is_empty() {
        return out;
    }
    assert!(
        mask.iter().all(|&i| i < logits.len()),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        logits.len()
    );

    // Shift by the max over masked entries so exp() never overflows.
    let shift = mask
        .iter()
        .map(|&i| logits[i])
        .fold(f64::NEG_INFINITY, f64::max);

    // Exponentiate masked entries in place while accumulating the normalizer.
    let total: f64 = mask
        .iter()
        .map(|&i| {
            let e = (logits[i] - shift).exp();
            out[i] = e;
            e
        })
        .sum();

    if total > 0.0 {
        for &i in mask {
            out[i] /= total;
        }
    }
    out
}
263
/// Index (into `values`) of the largest masked entry; ties keep the earliest
/// mask position. Panics on an empty mask or an out-of-range mask index.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );
    // Strict `>` keeps the first index on ties.
    let mut winner = mask[0];
    for &candidate in &mask[1..] {
        if values[candidate] > values[winner] {
            winner = candidate;
        }
    }
    winner
}
292
/// Root-mean-square over every element of every slice in `error_vecs`.
/// Returns 0.0 when there are no elements at all.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let count: usize = error_vecs.iter().map(|v| v.len()).sum();
    if count == 0 {
        return 0.0;
    }
    let sum_sq: f64 = error_vecs
        .iter()
        .flat_map(|v| v.iter())
        .map(|&e| e * e)
        .sum();
    (sum_sq / count as f64).sqrt()
}
316
317pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
332 assert!(!mask.is_empty(), "sample_from_probs: empty mask");
333
334 if mask.len() == 1 {
335 return mask[0];
336 }
337
338 let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
339 if sum <= 0.0 {
340 return mask[rng.gen_range(0..mask.len())];
342 }
343
344 let threshold: f64 = rng.gen_range(0.0..1.0);
345 let mut cumulative = 0.0;
346 for &i in mask {
347 cumulative += probs[i] / sum;
348 if cumulative >= threshold {
349 return i;
350 }
351 }
352
353 *mask.last().unwrap()
355}
356
/// Clamp every element of `v` into [-max_abs, max_abs] in place.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(-max_abs, max_abs));
}
368
/// Element-wise `a - b`; panics when the lengths differ.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_sub: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x - y);
    }
    out
}
389
/// Element-wise `a + b`; panics when the lengths differ.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_add: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    (0..a.len()).map(|i| a[i] + b[i]).collect()
}
410
/// Return a copy of `v` with every element multiplied by `s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    v.iter().map(|&x| s * x).collect()
}
424
/// Align the neurons (columns) of two activation matrices via CCA.
///
/// Both inputs are `batch_size` x `n_neurons` activation matrices over the
/// same batch. Columns are standardized, auto-/cross-covariances are formed,
/// the cross-covariance is whitened via `mat_inv_sqrt`, and the SVD of the
/// whitened matrix yields canonical loadings. A similarity score for every
/// (a, b) neuron pair is accumulated from the singular-value-weighted
/// absolute loadings, and an optimal one-to-one matching is found with
/// `hungarian_assignment` on the negated similarities.
///
/// Returns `perm` of length `min(n_a, n_b)` with `perm[b] = a`, i.e. neuron
/// `b` of `act_b` is matched to neuron `perm[b]` of `act_a`.
///
/// # Errors
/// Propagates any `PcError` from the backend's SVD.
pub fn cca_neuron_alignment<L: crate::linalg::LinAlg>(
    act_a: &L::Matrix,
    act_b: &L::Matrix,
) -> Result<Vec<usize>, crate::error::PcError> {
    let batch_size = L::mat_rows(act_a);
    let n_a = L::mat_cols(act_a);
    let n_b = L::mat_cols(act_b);
    let k = n_a.min(n_b);

    // Degenerate inputs: no neurons, or too few samples for a covariance.
    if k == 0 || batch_size < 2 {
        return Ok((0..k).collect());
    }

    let std_a = standardize_columns::<L>(act_a);
    let std_b = standardize_columns::<L>(act_b);

    // Sample-covariance normalizer (n - 1).
    let scale = 1.0 / (batch_size as f64 - 1.0);

    let std_a_t = L::mat_transpose(&std_a);
    let std_b_t = L::mat_transpose(&std_b);

    // Auto-covariances C_aa, C_bb and cross-covariance C_ab.
    let mut c_a = L::mat_mul(&std_a_t, &std_a);
    let mut c_b = L::mat_mul(&std_b_t, &std_b);
    let mut c_ab = L::mat_mul(&std_a_t, &std_b);
    scale_matrix::<L>(&mut c_a, n_a, n_a, scale);
    scale_matrix::<L>(&mut c_b, n_b, n_b, scale);
    scale_matrix::<L>(&mut c_ab, n_a, n_b, scale);

    let c_a_inv_sqrt = mat_inv_sqrt::<L>(&c_a)?;
    let c_b_inv_sqrt = mat_inv_sqrt::<L>(&c_b)?;

    // Whitened cross-covariance M = C_aa^(-1/2) C_ab C_bb^(-1/2).
    let temp = L::mat_mul(&c_a_inv_sqrt, &c_ab);
    let m = L::mat_mul(&temp, &c_b_inv_sqrt);

    let (u, s, v) = L::svd(&m)?;

    let n_canonical = L::mat_cols(&u).min(L::mat_cols(&v));

    // cost[b][a] = -similarity(neuron a of A, neuron b of B); Hungarian
    // minimizes cost, so negation maximizes total similarity.
    let mut cost = vec![vec![0.0; n_a]; n_b];
    for (b, cost_row) in cost.iter_mut().enumerate() {
        for (a, cost_cell) in cost_row.iter_mut().enumerate() {
            let mut sim = 0.0;
            for kk in 0..n_canonical {
                let sk = L::vec_get(&s, kk);
                sim += sk * L::mat_get(&u, a, kk).abs() * L::mat_get(&v, b, kk).abs();
            }
            *cost_cell = -sim;
        }
    }

    let assignment = hungarian_assignment(&cost);

    // Keep only the first k rows of the (possibly padded) assignment.
    let k = n_a.min(n_b);
    let mut perm = vec![0usize; k];
    for (b, &a) in assignment.iter().enumerate().take(k) {
        perm[b] = a;
    }

    Ok(perm)
}
512
513fn scale_matrix<L: crate::linalg::LinAlg>(m: &mut L::Matrix, rows: usize, cols: usize, s: f64) {
515 for r in 0..rows {
516 for c in 0..cols {
517 let val = L::mat_get(m, r, c);
518 L::mat_set(m, r, c, val * s);
519 }
520 }
521}
522
523fn standardize_columns<L: crate::linalg::LinAlg>(m: &L::Matrix) -> L::Matrix {
526 let rows = L::mat_rows(m);
527 let cols = L::mat_cols(m);
528 let mut result = L::zeros_mat(rows, cols);
529 let eps = 1e-12;
530
531 for c in 0..cols {
532 let mut sum = 0.0;
534 for r in 0..rows {
535 sum += L::mat_get(m, r, c);
536 }
537 let mean = sum / rows as f64;
538
539 let mut var_sum = 0.0;
541 for r in 0..rows {
542 let diff = L::mat_get(m, r, c) - mean;
543 var_sum += diff * diff;
544 }
545 let std = (var_sum / (rows as f64 - 1.0)).sqrt();
546
547 if std > eps {
548 for r in 0..rows {
549 L::mat_set(&mut result, r, c, (L::mat_get(m, r, c) - mean) / std);
550 }
551 }
552 }
554 result
555}
556
/// Compute `M^(-1/2)` via SVD as `U * diag(1/sqrt(s_i)) * U^T`, dropping
/// components with singular values <= eps (a pseudo-inverse square root).
///
/// NOTE(review): using `U` on both sides is only exact for symmetric PSD
/// matrices — true for the covariance matrices built in
/// `cca_neuron_alignment`, which are this function's only callers here.
fn mat_inv_sqrt<L: crate::linalg::LinAlg>(
    m: &L::Matrix,
) -> Result<L::Matrix, crate::error::PcError> {
    let n = L::mat_rows(m);
    let (u, s, _v) = L::svd(m)?;
    let eps = 1e-10;

    // Build diag(1/sqrt(s_i)), zeroing tiny singular values whose
    // reciprocal square root would explode.
    let k = L::vec_len(&s);
    let mut diag_inv_sqrt = L::zeros_mat(k, k);
    for i in 0..k {
        let si = L::vec_get(&s, i);
        if si > eps {
            L::mat_set(&mut diag_inv_sqrt, i, i, 1.0 / si.sqrt());
        }
    }

    let temp = L::mat_mul(&u, &diag_inv_sqrt);
    let ut = L::mat_transpose(&u);
    let mut result = L::mat_mul(&temp, &ut);

    // If the backend returned a truncated SVD, the product can come out
    // smaller than n x n; zero-pad it back so callers can keep multiplying.
    if L::mat_rows(&result) != n || L::mat_cols(&result) != n {
        let mut padded = L::zeros_mat(n, n);
        let r_rows = L::mat_rows(&result);
        let r_cols = L::mat_cols(&result);
        for r in 0..r_rows.min(n) {
            for c in 0..r_cols.min(n) {
                L::mat_set(&mut padded, r, c, L::mat_get(&result, r, c));
            }
        }
        result = padded;
    }

    Ok(result)
}
597
/// Minimum-cost one-to-one assignment via the Hungarian algorithm
/// (Kuhn–Munkres with row/column potentials, O(n^3)).
///
/// `cost[i][j]` is the cost of assigning row `i` to column `j`. Rectangular
/// inputs are zero-padded to square internally. Returns, for each row `i`,
/// the column `result[i]` it was assigned.
///
/// NOTE(review): assumes all rows of `cost` have the same length as
/// `cost[0]` — shorter rows are zero-padded, longer ones would panic on the
/// copy below.
pub(crate) fn hungarian_assignment(cost: &[Vec<f64>]) -> Vec<usize> {
    let n_rows = cost.len();
    if n_rows == 0 {
        return vec![];
    }
    let n_cols = cost[0].len();
    let n = n_rows.max(n_cols);

    // 1-indexed square cost matrix (index 0 is a sentinel row/column).
    let mut c = vec![vec![0.0; n + 1]; n + 1];
    for (i, row) in cost.iter().enumerate() {
        for (j, &val) in row.iter().enumerate() {
            c[i + 1][j + 1] = val;
        }
    }

    // u, v: dual potentials; p[j]: row currently matched to column j
    // (p[0] is the row being inserted); way[j]: predecessor column on the
    // alternating path.
    let mut u = vec![0.0; n + 1];
    let mut v = vec![0.0; n + 1];
    let mut p = vec![0usize; n + 1];
    let mut way = vec![0usize; n + 1];

    // Insert rows one at a time, growing the matching by one each iteration.
    for i in 1..=n {
        p[0] = i;
        let mut j0 = 0usize;
        // min_v[j]: smallest reduced cost seen from the visited rows to column j.
        let mut min_v = vec![f64::MAX; n + 1];
        let mut used = vec![false; n + 1];

        // Dijkstra-like search for the cheapest augmenting path.
        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = f64::MAX;
            let mut j1 = 0usize;

            for j in 1..=n {
                if !used[j] {
                    let cur = c[i0][j] - u[i0] - v[j];
                    if cur < min_v[j] {
                        min_v[j] = cur;
                        way[j] = j0;
                    }
                    if min_v[j] < delta {
                        delta = min_v[j];
                        j1 = j;
                    }
                }
            }

            // Update potentials so visited edges keep zero reduced cost.
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    min_v[j] -= delta;
                }
            }

            j0 = j1;

            // Reached an unmatched column: augmenting path found.
            if p[j0] == 0 {
                break;
            }
        }

        // Walk the path backwards, flipping the matching along it.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }

    // Convert column->row matching back to row->column, dropping padding.
    let mut result = vec![0usize; n_rows];
    for j in 1..=n {
        if p[j] >= 1 && p[j] <= n_rows {
            result[p[j] - 1] = j - 1;
        }
    }
    result
}
696
#[allow(dead_code)]
/// Greedy per-canonical-component matcher — an earlier alternative to
/// `hungarian_assignment`, kept for reference (dead code).
///
/// For each canonical column it pairs the strongest unmatched neuron of `u`
/// with the strongest unmatched neuron of `v`. Returns `perm` of length
/// `min(n_a, n_b)` with `perm[b] = a`.
fn greedy_match<L: crate::linalg::LinAlg>(
    u: &L::Matrix,
    v: &L::Matrix,
    n_a: usize,
    n_b: usize,
) -> Vec<usize> {
    let k = n_a.min(n_b);
    let n_canonical = L::mat_cols(u).min(L::mat_cols(v));

    let mut matched_a = vec![false; n_a];
    let mut matched_b = vec![false; n_b];
    let mut perm = vec![0usize; k];
    let mut assigned = vec![false; k];

    for col in 0..n_canonical {
        // Largest |loading| among unmatched A-neurons on this component.
        // NOTE(review): best_a stays 0 when no candidate beats 0.0 (all
        // matched or all-zero loadings); the `!matched_a[best_a]` guard below
        // covers the already-matched case.
        let mut best_a = 0;
        let mut best_a_val = 0.0_f64;
        for (i, &is_matched) in matched_a.iter().enumerate().take(n_a.min(L::mat_rows(u))) {
            let val = L::mat_get(u, i, col).abs();
            if val > best_a_val && !is_matched {
                best_a_val = val;
                best_a = i;
            }
        }

        // Same search on the B side.
        let mut best_b = 0;
        let mut best_b_val = 0.0_f64;
        for (i, &is_matched) in matched_b.iter().enumerate().take(n_b.min(L::mat_rows(v))) {
            let val = L::mat_get(v, i, col).abs();
            if val > best_b_val && !is_matched {
                best_b_val = val;
                best_b = i;
            }
        }

        if !matched_a[best_a] && !matched_b[best_b] && best_b < k {
            perm[best_b] = best_a;
            assigned[best_b] = true;
            matched_a[best_a] = true;
            matched_b[best_b] = true;
        }
    }

    // Pair leftover B slots with leftover A neurons in index order.
    let remaining_a: Vec<usize> = (0..n_a).filter(|i| !matched_a[*i]).collect();
    let unassigned_b: Vec<usize> = (0..k).filter(|i| !assigned[*i]).collect();
    for (idx, &b_idx) in unassigned_b.iter().enumerate() {
        if idx < remaining_a.len() {
            perm[b_idx] = remaining_a[idx];
        }
    }

    perm
}
757
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // ---- Matrix construction ----

    #[test]
    fn test_zeros_all_zero_correct_dims() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 4);
        assert_eq!(m.data.len(), 12);
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_xavier_variance_approx() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(100, 100, &mut rng);
        let n = m.data.len() as f64;
        let mean = m.data.iter().sum::<f64>() / n;
        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        // U(-l, l) has variance l^2/3 = 2/(fan_in + fan_out) for Xavier init.
        let expected_var = 2.0 / (100.0 + 100.0);
        assert!(
            (variance - expected_var).abs() < expected_var * 0.5,
            "variance {} not within 50% of expected {}",
            variance,
            expected_var
        );
    }

    #[test]
    fn test_xavier_all_finite() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(50, 50, &mut rng);
        assert!(m.data.iter().all(|x| x.is_finite()));
    }

    // ---- get / set / transpose ----

    #[test]
    fn test_get_set_roundtrip() {
        let mut m = Matrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
    }

    #[test]
    fn test_get_zero_default() {
        let m = Matrix::zeros(2, 2);
        assert_eq!(m.get(0, 0), 0.0);
    }

    #[test]
    fn test_transpose_swaps_dims() {
        let m = Matrix::zeros(3, 5);
        let t = m.transpose();
        assert_eq!(t.rows, 5);
        assert_eq!(t.cols, 3);
    }

    #[test]
    fn test_transpose_repositions_values() {
        let mut m = Matrix::zeros(2, 3);
        m.set(0, 1, 7.0);
        m.set(1, 2, 3.0);
        let t = m.transpose();
        assert_eq!(t.get(1, 0), 7.0);
        assert_eq!(t.get(2, 1), 3.0);
    }

    #[test]
    fn test_transpose_double_is_identity() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(3, 5, &mut rng);
        let tt = m.transpose().transpose();
        assert_eq!(m.rows, tt.rows);
        assert_eq!(m.cols, tt.cols);
        for i in 0..m.data.len() {
            assert!((m.data[i] - tt.data[i]).abs() < 1e-15);
        }
    }

    // ---- mul_vec / outer ----

    #[test]
    fn test_mul_vec_known_result() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);
        let result = m.mul_vec(&[5.0, 6.0]);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 39.0).abs() < 1e-10);
    }

    #[test]
    fn test_mul_vec_output_length_equals_rows() {
        let m = Matrix::zeros(4, 3);
        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 4);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_mul_vec_panics_wrong_length() {
        let m = Matrix::zeros(2, 3);
        m.mul_vec(&[1.0, 2.0]);
    }

    #[test]
    fn test_mul_vec_zero_matrix_returns_zeros() {
        let m = Matrix::zeros(3, 2);
        let result = m.mul_vec(&[5.0, 10.0]);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_outer_dims_and_values() {
        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
        assert!((m.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((m.get(0, 2) - 5.0).abs() < 1e-10);
        assert!((m.get(1, 0) - 6.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 8.0).abs() < 1e-10);
        assert!((m.get(1, 2) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_empty_first_returns_zero_matrix() {
        let m = Matrix::outer(&[], &[1.0, 2.0]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_outer_empty_second_returns_zero_matrix() {
        let m = Matrix::outer(&[1.0, 2.0], &[]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    // ---- scale_add ----

    #[test]
    fn test_scale_add_basic() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 2.0);
        let mut other = Matrix::zeros(2, 2);
        other.set(0, 0, 0.5);
        other.set(1, 1, 0.5);
        m.scale_add(&other, 2.0);
        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_clips_to_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, 10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - WEIGHT_CLIP).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_negative_clips_to_neg_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, -4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, -10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - (-WEIGHT_CLIP)).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_zero_scale_only_clips() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 3.0);
        let other = Matrix::zeros(1, 1);
        m.scale_add(&other, 0.0);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_scale_add_panics_on_dimension_mismatch() {
        let mut m = Matrix::zeros(2, 2);
        let other = Matrix::zeros(3, 3);
        m.scale_add(&other, 1.0);
    }

    // ---- softmax_masked ----

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 1, 2, 3];
        let probs = softmax_masked(&logits, &mask);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[1] > 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_softmax_masked_single_index_is_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let mask = vec![1];
        let probs = softmax_masked(&logits, &mask);
        assert!((probs[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_empty_mask_returns_all_zeros() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax_masked(&logits, &[]);
        assert!(probs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_softmax_masked_numerically_stable_large_logits() {
        let logits = vec![1000.0, 1001.0, 1002.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs.iter().all(|p| p.is_finite()));
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_higher_logit_gets_higher_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    // ---- argmax_masked ----

    #[test]
    fn test_argmax_masked_returns_highest_in_mask() {
        let values = vec![1.0, 5.0, 3.0, 4.0];
        let mask = vec![0, 2, 3];
        assert_eq!(argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_argmax_masked_single_element() {
        let values = vec![1.0, 5.0, 3.0];
        let mask = vec![2];
        assert_eq!(argmax_masked(&values, &mask), 2);
    }

    #[test]
    fn test_argmax_masked_tie_returns_first() {
        let values = vec![3.0, 3.0, 3.0];
        let mask = vec![0, 1, 2];
        assert_eq!(argmax_masked(&values, &mask), 0);
    }

    #[test]
    #[should_panic]
    fn test_argmax_masked_empty_panics() {
        let values = vec![1.0, 2.0];
        argmax_masked(&values, &[]);
    }

    // ---- rms_error ----

    #[test]
    fn test_rms_error_empty_returns_zero() {
        assert_eq!(rms_error(&[]), 0.0);
    }

    #[test]
    fn test_rms_error_single_empty_vec_returns_zero() {
        let empty: &[f64] = &[];
        assert_eq!(rms_error(&[empty]), 0.0);
    }

    #[test]
    fn test_rms_error_known_two_vecs() {
        let v1: &[f64] = &[1.0, 0.0];
        let v2: &[f64] = &[0.0, 1.0];
        let rms = rms_error(&[v1, v2]);
        // sqrt((1 + 1) / 4) = sqrt(0.5)
        let expected = (0.5_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_single_vec() {
        let v: &[f64] = &[3.0, 4.0];
        let rms = rms_error(&[v]);
        // sqrt((9 + 16) / 2)
        let expected = (25.0 / 2.0_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_all_zeros_returns_zero() {
        let v: &[f64] = &[0.0, 0.0, 0.0];
        assert_eq!(rms_error(&[v]), 0.0);
    }

    // ---- sample_from_probs ----

    #[test]
    fn test_sample_from_probs_always_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_single_action_always_returns_it() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![1];
        for _ in 0..10 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_visits_multiple_actions() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            seen[idx] = true;
        }
        assert!(seen[0] && seen[1], "should visit both actions");
    }

    #[test]
    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.0, 0.0, 0.0];
        let mask = vec![0, 2];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    #[should_panic]
    fn test_sample_from_probs_empty_mask_panics() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        sample_from_probs(&probs, &[], &mut rng);
    }

    // ---- vector helpers ----

    #[test]
    fn test_vec_sub_known() {
        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
        assert!((result[0] - 2.0).abs() < 1e-10);
        assert!((result[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_vec_add_known() {
        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
        assert!((result[0] - 4.0).abs() < 1e-10);
        assert!((result[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_vec_scale_known() {
        let result = vec_scale(&[1.0, -2.0], 3.0);
        assert!((result[0] - 3.0).abs() < 1e-10);
        assert!((result[1] - (-6.0)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_vec_clamps_positive() {
        let mut v = vec![10.0, -10.0, 0.5];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 5.0).abs() < 1e-10);
        assert!((v[1] - (-5.0)).abs() < 1e-10);
        assert!((v[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_sub_panics_on_length_mismatch() {
        vec_sub(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_add_panics_on_length_mismatch() {
        vec_add(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = vec![1.0, -1.0, 0.0];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 1.0).abs() < 1e-10);
        assert!((v[1] - (-1.0)).abs() < 1e-10);
        assert!((v[2] - 0.0).abs() < 1e-10);
    }

    // ---- panic paths on out-of-bounds access ----

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_row() {
        let m = Matrix::zeros(2, 2);
        m.get(5, 0);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_row() {
        let mut m = Matrix::zeros(2, 2);
        m.set(5, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_softmax_masked_panics_on_oob_mask() {
        let logits = vec![1.0, 2.0, 3.0];
        softmax_masked(&logits, &[0, 5]);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_argmax_masked_panics_on_oob_mask() {
        let values = vec![1.0, 2.0, 3.0];
        argmax_masked(&values, &[0, 5]);
    }

    // ---- CCA neuron alignment ----

    #[test]
    fn test_cca_identical_activations_identity_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let act_b = act_a.clone();

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        assert_eq!(perm, vec![0, 1, 2]);
    }

    #[test]
    fn test_cca_permutation_length_is_min() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 100;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act = CpuLinAlg::zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act, r, c, val);
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act, &act).unwrap();
        assert_eq!(perm.len(), 4);
    }

    #[test]
    fn test_cca_permuted_activations_recovers_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 500;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        // Column j of act_b is column col_map[j] of act_a.
        let col_map = [2, 0, 1];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![2, 0, 1]);
    }

    #[test]
    fn test_cca_permuted_with_small_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 50;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(99);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        let col_map = [1, 2, 0];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![1, 2, 0]);
    }

    #[test]
    fn test_cca_permuted_large_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 500;
        let n_neurons = 4;
        let mut rng = StdRng::seed_from_u64(7);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        let col_map = [3, 1, 0, 2];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![3, 1, 0, 2]);
    }

    // ---- CCA with unequal neuron counts ----

    #[test]
    fn test_cca_a_larger_than_b() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                CpuLinAlg::mat_set(&mut act_b, r, c, CpuLinAlg::mat_get(&act_a, r, c));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_b_larger_than_a() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, 5);
        for r in 0..batch_size {
            for c in 0..5 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_b, r, c, val);
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_dead_neuron_excluded() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);

        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }

        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            CpuLinAlg::mat_set(&mut act_b, r, 0, CpuLinAlg::mat_get(&act_a, r, 0));
            // Neuron 1 of act_b is dead (constant zero).
            CpuLinAlg::mat_set(&mut act_b, r, 1, 0.0);
            CpuLinAlg::mat_set(&mut act_b, r, 2, CpuLinAlg::mat_get(&act_a, r, 2));
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        // The result must still be a valid permutation of 0..n_neurons.
        for &p in &perm {
            assert!(p < n_neurons, "permutation index {p} out of range");
        }
        let mut sorted = perm.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), n_neurons, "permutation has duplicates");
    }

    // ---- Hungarian assignment ----

    #[test]
    fn test_hungarian_assignment_basic() {
        let assignment = hungarian_assignment(&[
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ]);
        // Recompute the total cost of the chosen assignment from the
        // flattened cost matrix; the optimum for this matrix is 10.
        let total: f64 = assignment
            .iter()
            .enumerate()
            .map(|(i, &j)| [1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0][i * 3 + j])
            .sum();
        assert!(
            (total - 10.0).abs() < 1e-10,
            "Expected total cost 10, got {total}"
        );
    }

    #[test]
    fn test_hungarian_assignment_permuted() {
        let assignment = hungarian_assignment(&[
            vec![5.0, 1.0, 3.0],
            vec![2.0, 8.0, 7.0],
            vec![6.0, 4.0, 1.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2]);
    }

    #[test]
    fn test_hungarian_assignment_4x4() {
        let assignment = hungarian_assignment(&[
            vec![10.0, 5.0, 13.0, 15.0],
            vec![3.0, 9.0, 18.0, 6.0],
            vec![10.0, 7.0, 2.0, 12.0],
            vec![5.0, 11.0, 9.0, 4.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2, 3]);
    }

    #[test]
    fn test_hungarian_assignment_1x1() {
        let assignment = hungarian_assignment(&[vec![42.0]]);
        assert_eq!(assignment, vec![0]);
    }

    #[test]
    fn test_hungarian_optimal_vs_greedy_on_collision_case() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;

        let batch_size = 200;
        let n = 8;
        let mut rng = StdRng::seed_from_u64(42);

        // Correlated neurons: all share one latent `base` with per-neuron
        // weights, so greedy matching tends to collide on the same columns.
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n);
        for r in 0..batch_size {
            let base: f64 = rng.gen_range(-1.0..1.0);
            for c in 0..n {
                let noise: f64 = rng.gen_range(-0.3..0.3);
                let weight = (c as f64 + 1.0) / n as f64;
                CpuLinAlg::mat_set(&mut act_a, r, c, base * weight + noise);
            }
        }

        let true_perm = [5, 3, 7, 1, 6, 0, 4, 2];
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n);
        for r in 0..batch_size {
            for (j, &src_col) in true_perm.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }

        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();

        assert_eq!(
            perm,
            true_perm.to_vec(),
            "Hungarian should recover exact permutation for correlated neurons"
        );
    }

    #[test]
    fn test_sample_from_probs_distribution_roughly_correct() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.7, 0.3];
        let mask = vec![0, 1];
        let mut counts = [0usize; 2];
        let n = 1000;
        for _ in 0..n {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            counts[idx] += 1;
        }
        let ratio = counts[0] as f64 / n as f64;
        // Loose statistical bound; seed is fixed so this is deterministic.
        assert!(
            (ratio - 0.7).abs() < 0.1,
            "Expected ~0.7 for action 0, got {ratio}"
        );
    }
}
1593}