use rand::Rng;
use serde::{Deserialize, Serialize};
/// Hard bound on the absolute value of every weight entry; enforced after
/// each `Matrix::scale_add` update.
pub const WEIGHT_CLIP: f64 = 5.0;
/// Hard bound for clipping gradient-like vectors (see `clip_vec`).
/// NOTE(review): not referenced in this file's visible code — presumably
/// passed to `clip_vec` by callers; confirm before changing.
pub const GRAD_CLIP: f64 = 5.0;
/// Dense row-major matrix of `f64` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Matrix {
    // Row-major storage: element (r, c) lives at data[r * cols + c].
    pub data: Vec<f64>,
    // Number of rows.
    pub rows: usize,
    // Number of columns.
    pub cols: usize,
}
impl Matrix {
    /// Build a `rows` x `cols` matrix with every entry set to 0.0.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        let data = vec![0.0; rows * cols];
        Self { data, rows, cols }
    }

    /// Xavier/Glorot uniform initialization: entries drawn from
    /// U(-limit, limit) with limit = sqrt(6 / (rows + cols)).
    pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
        let limit = (6.0 / (rows + cols) as f64).sqrt();
        let mut data = Vec::with_capacity(rows * cols);
        for _ in 0..rows * cols {
            data.push(rng.gen_range(-limit..limit));
        }
        Self { data, rows, cols }
    }

    /// Read the entry at (row, col).
    ///
    /// # Panics
    /// Panics if the position is outside the matrix.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        assert!(
            row < self.rows && col < self.cols,
            "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
            self.rows,
            self.cols
        );
        self.data[row * self.cols + col]
    }

    /// Write `val` at (row, col).
    ///
    /// # Panics
    /// Panics if the position is outside the matrix.
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        assert!(
            row < self.rows && col < self.cols,
            "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
            self.rows,
            self.cols
        );
        self.data[row * self.cols + col] = val;
    }

    /// Return a new matrix with rows and columns swapped.
    pub fn transpose(&self) -> Self {
        let mut out = Matrix::zeros(self.cols, self.rows);
        for (idx, &val) in self.data.iter().enumerate() {
            // (r, c) in self maps to (c, r) in the transpose.
            let (r, c) = (idx / self.cols, idx % self.cols);
            out.data[c * out.cols + r] = val;
        }
        out
    }

    /// Matrix-vector product; returns a vector of length `self.rows`.
    ///
    /// # Panics
    /// Panics if `v.len() != self.cols`.
    pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(
            v.len(),
            self.cols,
            "dimension mismatch: vector length {} != matrix cols {}",
            v.len(),
            self.cols
        );
        let mut out = Vec::with_capacity(self.rows);
        for r in 0..self.rows {
            let row = &self.data[r * self.cols..(r + 1) * self.cols];
            let dot: f64 = row.iter().zip(v).map(|(a, b)| a * b).sum();
            out.push(dot);
        }
        out
    }

    /// Outer product: result[r][c] = a[r] * b[c].
    ///
    /// Either input being empty yields an empty (0 x 0) matrix.
    pub fn outer(a: &[f64], b: &[f64]) -> Self {
        if a.is_empty() || b.is_empty() {
            return Matrix::zeros(0, 0);
        }
        let mut result = Matrix::zeros(a.len(), b.len());
        for (r, &av) in a.iter().enumerate() {
            for (c, &bv) in b.iter().enumerate() {
                result.data[r * result.cols + c] = av * bv;
            }
        }
        result
    }

    /// In-place `self += scale * other`, clamping every resulting entry to
    /// `[-WEIGHT_CLIP, WEIGHT_CLIP]` (note: clamping applies even when
    /// `scale == 0.0`).
    ///
    /// # Panics
    /// Panics on a shape mismatch.
    pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
        assert!(
            self.rows == other.rows && self.cols == other.cols,
            "dimension mismatch in scale_add: ({},{}) vs ({},{})",
            self.rows,
            self.cols,
            other.rows,
            other.cols
        );
        for (w, &g) in self.data.iter_mut().zip(other.data.iter()) {
            *w = (*w + scale * g).clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
        }
    }
}
/// Softmax over `logits`, restricted to the indices in `mask`.
///
/// Unmasked positions stay 0.0; masked positions receive probabilities that
/// sum to 1. An empty mask yields an all-zero vector.
///
/// # Panics
/// Panics if any mask index is out of range for `logits`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    let mut result = vec![0.0; logits.len()];
    if mask.is_empty() {
        return result;
    }
    assert!(
        mask.iter().all(|&i| i < logits.len()),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        logits.len()
    );
    // Subtract the masked maximum before exponentiating (numerical stability).
    let max_val = mask
        .iter()
        .fold(f64::NEG_INFINITY, |m, &i| m.max(logits[i]));
    let mut total = 0.0;
    for &i in mask {
        result[i] = (logits[i] - max_val).exp();
        total += result[i];
    }
    if total > 0.0 {
        for &i in mask {
            result[i] /= total;
        }
    }
    result
}
/// Index (into `values`) of the largest value among the masked positions.
/// Ties resolve to the earliest entry of `mask` (strict `>` comparison).
///
/// # Panics
/// Panics if `mask` is empty or contains an index outside `values`.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );
    mask[1..].iter().fold(mask[0], |best, &i| {
        if values[i] > values[best] {
            i
        } else {
            best
        }
    })
}
/// Root-mean-square over every element of every slice in `error_vecs`.
/// Returns 0.0 when there are no elements at all.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let (sum_sq, count) = error_vecs
        .iter()
        .flat_map(|v| v.iter())
        .fold((0.0_f64, 0usize), |(s, n), &e| (s + e * e, n + 1));
    if count == 0 {
        0.0
    } else {
        (sum_sq / count as f64).sqrt()
    }
}
/// Sample an index from `mask` according to the (possibly unnormalized)
/// probabilities in `probs`, restricted to the masked indices.
///
/// Masked probabilities are renormalized to sum to 1. If they sum to zero
/// (or negative), falls back to a uniform choice over the mask.
///
/// # Panics
/// Panics if `mask` is empty or contains an index outside `probs`.
pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
    assert!(!mask.is_empty(), "sample_from_probs: empty mask");
    // Fix + consistency with softmax_masked/argmax_masked: validate mask
    // indices up front. Previously a one-element out-of-range mask was
    // returned silently without ever touching `probs`, and larger masks
    // panicked with an unhelpful slice-index message.
    assert!(
        mask.iter().all(|&i| i < probs.len()),
        "sample_from_probs: mask index out of bounds (max mask={}, probs len={})",
        mask.iter().max().unwrap_or(&0),
        probs.len()
    );
    if mask.len() == 1 {
        return mask[0];
    }
    let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
    if sum <= 0.0 {
        // Degenerate distribution: uniform fallback over the mask.
        return mask[rng.gen_range(0..mask.len())];
    }
    // Inverse-CDF sampling. `threshold` lies in [0, 1); strict `>` (was `>=`)
    // ensures a zero-probability action can never be selected when the
    // threshold is exactly 0.0.
    let threshold: f64 = rng.gen_range(0.0..1.0);
    let mut cumulative = 0.0;
    for &i in mask {
        cumulative += probs[i] / sum;
        if cumulative > threshold {
            return i;
        }
    }
    // Guard against floating-point round-off leaving cumulative just below 1.
    *mask.last().unwrap()
}
/// Clamp every element of `v` into `[-max_abs, max_abs]`, in place.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(-max_abs, max_abs));
}
/// Element-wise difference `a - b`.
///
/// # Panics
/// Panics if the slices differ in length.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_sub: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x - y);
    }
    out
}
/// Element-wise sum `a + b`.
///
/// # Panics
/// Panics if the slices differ in length.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_add: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
/// Return `v` with every element multiplied by `s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        out.push(x * s);
    }
    out
}
/// Match neurons between two activation matrices using CCA similarity.
///
/// `act_a` and `act_b` are (batch_size x n_a) and (batch_size x n_b)
/// activation matrices. Returns a permutation of length k = min(n_a, n_b)
/// with `perm[b] = a`, pairing column `b` of `act_b` with column `a` of
/// `act_a`, chosen to maximize the summed CCA-weighted loading similarity
/// via an optimal (Hungarian) assignment.
pub fn cca_neuron_alignment<L: crate::linalg::LinAlg>(
    act_a: &L::Matrix,
    act_b: &L::Matrix,
) -> Result<Vec<usize>, crate::error::PcError> {
    let batch_size = L::mat_rows(act_a);
    let n_a = L::mat_cols(act_a);
    let n_b = L::mat_cols(act_b);
    let k = n_a.min(n_b);
    // Nothing to match, or too few samples to estimate covariance
    // (normalization below divides by batch_size - 1): fall back to the
    // identity permutation.
    if k == 0 || batch_size < 2 {
        return Ok((0..k).collect());
    }
    // Z-score each column so the covariance matrices become correlations.
    let std_a = standardize_columns::<L>(act_a);
    let std_b = standardize_columns::<L>(act_b);
    // Unbiased (co)variance normalization factor.
    let scale = 1.0 / (batch_size as f64 - 1.0);
    let std_a_t = L::mat_transpose(&std_a);
    let std_b_t = L::mat_transpose(&std_b);
    // Within- and cross-covariance: C_a (n_a x n_a), C_b (n_b x n_b),
    // C_ab (n_a x n_b).
    let mut c_a = L::mat_mul(&std_a_t, &std_a);
    let mut c_b = L::mat_mul(&std_b_t, &std_b);
    let mut c_ab = L::mat_mul(&std_a_t, &std_b);
    scale_matrix::<L>(&mut c_a, n_a, n_a, scale);
    scale_matrix::<L>(&mut c_b, n_b, n_b, scale);
    scale_matrix::<L>(&mut c_ab, n_a, n_b, scale);
    // CCA: the SVD of the whitened cross-covariance
    // M = C_a^{-1/2} C_ab C_b^{-1/2} yields canonical correlations `s` and
    // per-side loading matrices `u` (for a) and `v` (for b).
    let c_a_inv_sqrt = mat_inv_sqrt::<L>(&c_a)?;
    let c_b_inv_sqrt = mat_inv_sqrt::<L>(&c_b)?;
    let temp = L::mat_mul(&c_a_inv_sqrt, &c_ab);
    let m = L::mat_mul(&temp, &c_b_inv_sqrt);
    let (u, s, v) = L::svd(&m)?;
    let n_canonical = L::mat_cols(&u).min(L::mat_cols(&v));
    // cost[b][a] = -(correlation-weighted loading similarity), so the
    // cost-minimizing assignment maximizes total similarity.
    let mut cost = vec![vec![0.0; n_a]; n_b];
    for (b, cost_row) in cost.iter_mut().enumerate() {
        for (a, cost_cell) in cost_row.iter_mut().enumerate() {
            let mut sim = 0.0;
            for kk in 0..n_canonical {
                let sk = L::vec_get(&s, kk);
                sim += sk * L::mat_get(&u, a, kk).abs() * L::mat_get(&v, b, kk).abs();
            }
            *cost_cell = -sim;
        }
    }
    let assignment = hungarian_assignment(&cost);
    let k = n_a.min(n_b);
    // NOTE(review): when n_b > n_a, hungarian_assignment pads the cost matrix
    // and may map some b-rows to dummy columns, so perm entries could end up
    // >= n_a — TODO confirm callers only rely on n_a >= n_b or tolerate this.
    let mut perm = vec![0usize; k];
    for (b, &a) in assignment.iter().enumerate().take(k) {
        perm[b] = a;
    }
    Ok(perm)
}
/// Multiply every entry of the leading `rows` x `cols` region of `m` by `s`,
/// in place.
fn scale_matrix<L: crate::linalg::LinAlg>(m: &mut L::Matrix, rows: usize, cols: usize, s: f64) {
    for row in 0..rows {
        for col in 0..cols {
            let scaled = L::mat_get(m, row, col) * s;
            L::mat_set(m, row, col, scaled);
        }
    }
}
/// Z-score each column of `m`: subtract the column mean and divide by the
/// sample standard deviation (n - 1 denominator). Columns with (near-)zero
/// variance — e.g. dead neurons — are left as all zeros instead of dividing
/// by ~0.
fn standardize_columns<L: crate::linalg::LinAlg>(m: &L::Matrix) -> L::Matrix {
    let rows = L::mat_rows(m);
    let cols = L::mat_cols(m);
    let mut out = L::zeros_mat(rows, cols);
    let eps = 1e-12;
    for col in 0..cols {
        // Column mean.
        let mut total = 0.0;
        for row in 0..rows {
            total += L::mat_get(m, row, col);
        }
        let mean = total / rows as f64;
        // Sample variance (n - 1 denominator).
        let mut sq_sum = 0.0;
        for row in 0..rows {
            let diff = L::mat_get(m, row, col) - mean;
            sq_sum += diff * diff;
        }
        let std = (sq_sum / (rows as f64 - 1.0)).sqrt();
        if std > eps {
            for row in 0..rows {
                L::mat_set(&mut out, row, col, (L::mat_get(m, row, col) - mean) / std);
            }
        }
    }
    out
}
/// Compute M^{-1/2} via SVD: for a symmetric PSD `m` the SVD coincides with
/// the eigendecomposition, so M^{-1/2} = U diag(1/sqrt(s_i)) U^T.
///
/// Singular values <= 1e-10 are dropped (pseudo-inverse behaviour), which
/// keeps this robust to rank-deficient covariance inputs (dead neurons).
///
/// NOTE(review): assumes `m` is symmetric PSD — true for the covariance
/// matrices built in `cca_neuron_alignment`, but not checked here.
fn mat_inv_sqrt<L: crate::linalg::LinAlg>(
    m: &L::Matrix,
) -> Result<L::Matrix, crate::error::PcError> {
    let n = L::mat_rows(m);
    let (u, s, _v) = L::svd(m)?;
    let eps = 1e-10;
    let k = L::vec_len(&s);
    // diag(1 / sqrt(s_i)), zeroing near-singular directions.
    let mut diag_inv_sqrt = L::zeros_mat(k, k);
    for i in 0..k {
        let si = L::vec_get(&s, i);
        if si > eps {
            L::mat_set(&mut diag_inv_sqrt, i, i, 1.0 / si.sqrt());
        }
    }
    let temp = L::mat_mul(&u, &diag_inv_sqrt);
    let ut = L::mat_transpose(&u);
    let mut result = L::mat_mul(&temp, &ut);
    // Some SVD backends return a truncated U (k < n); zero-pad the product
    // back to the expected n x n shape so callers can multiply with it.
    if L::mat_rows(&result) != n || L::mat_cols(&result) != n {
        let mut padded = L::zeros_mat(n, n);
        let r_rows = L::mat_rows(&result);
        let r_cols = L::mat_cols(&result);
        for r in 0..r_rows.min(n) {
            for c in 0..r_cols.min(n) {
                L::mat_set(&mut padded, r, c, L::mat_get(&result, r, c));
            }
        }
        result = padded;
    }
    Ok(result)
}
/// Solve the (rectangular) assignment problem for minimum total cost.
///
/// `cost[r][c]` is the cost of assigning row `r` to column `c`. Returns, for
/// each row `r`, the column `result[r]` it is matched to.
///
/// Implementation: Jonker–Volgenant-style shortest-augmenting-path Hungarian
/// algorithm with dual potentials `u`/`v`, O(n^3). Internally 1-indexed
/// (index 0 is a sentinel row/column); rectangular input is padded with
/// zero-cost entries up to n = max(n_rows, n_cols).
///
/// NOTE(review): when `n_rows > n_cols`, some rows are necessarily matched
/// to padded dummy columns, so `result[r]` can be >= `n_cols` for those
/// rows — callers must tolerate or filter such indices. TODO confirm this is
/// the intended contract.
pub(crate) fn hungarian_assignment(cost: &[Vec<f64>]) -> Vec<usize> {
    let n_rows = cost.len();
    if n_rows == 0 {
        return vec![];
    }
    let n_cols = cost[0].len();
    let n = n_rows.max(n_cols);
    // 1-indexed, zero-padded square copy of the cost matrix.
    let mut c = vec![vec![0.0; n + 1]; n + 1];
    for (i, row) in cost.iter().enumerate() {
        for (j, &val) in row.iter().enumerate() {
            c[i + 1][j + 1] = val;
        }
    }
    // u, v: dual potentials. p[j]: row currently matched to column j (p[0]
    // temporarily holds the row being inserted). way[j]: previous column on
    // the shortest augmenting path reaching column j.
    let mut u = vec![0.0; n + 1];
    let mut v = vec![0.0; n + 1];
    let mut p = vec![0usize; n + 1];
    let mut way = vec![0usize; n + 1];
    for i in 1..=n {
        p[0] = i;
        let mut j0 = 0usize;
        // min_v[j]: shortest reduced-cost distance found so far to column j.
        let mut min_v = vec![f64::MAX; n + 1];
        let mut used = vec![false; n + 1];
        // Dijkstra-like search for the shortest augmenting path from row i.
        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = f64::MAX;
            let mut j1 = 0usize;
            for j in 1..=n {
                if !used[j] {
                    let cur = c[i0][j] - u[i0] - v[j];
                    if cur < min_v[j] {
                        min_v[j] = cur;
                        way[j] = j0;
                    }
                    if min_v[j] < delta {
                        delta = min_v[j];
                        j1 = j;
                    }
                }
            }
            // Update potentials so all reduced costs stay non-negative.
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    min_v[j] -= delta;
                }
            }
            j0 = j1;
            // Reached an unmatched column: the augmenting path is complete.
            if p[j0] == 0 {
                break;
            }
        }
        // Augment: flip matched edges backwards along the found path.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }
    // Convert the column->row matching to 0-indexed row->column, keeping
    // only real (non-padded) rows.
    let mut result = vec![0usize; n_rows];
    for j in 1..=n {
        if p[j] >= 1 && p[j] <= n_rows {
            result[p[j] - 1] = j - 1;
        }
    }
    result
}
/// Greedy (non-optimal) neuron matching by largest canonical-vector loading.
///
/// Superseded by `hungarian_assignment` inside `cca_neuron_alignment`; kept
/// for reference. For each canonical component, pairs the still-unmatched
/// row of `u` with the largest |loading| against the analogous row of `v`,
/// then fills any unassigned b-slots with leftover a-indices in order.
#[allow(dead_code)]
fn greedy_match<L: crate::linalg::LinAlg>(
    u: &L::Matrix,
    v: &L::Matrix,
    n_a: usize,
    n_b: usize,
) -> Vec<usize> {
    let k = n_a.min(n_b);
    let n_canonical = L::mat_cols(u).min(L::mat_cols(v));
    let mut matched_a = vec![false; n_a];
    let mut matched_b = vec![false; n_b];
    let mut perm = vec![0usize; k];
    let mut assigned = vec![false; k];
    for col in 0..n_canonical {
        // Best unmatched a-neuron for this canonical direction.
        let mut best_a = 0;
        let mut best_a_val = 0.0_f64;
        for (i, &is_matched) in matched_a.iter().enumerate().take(n_a.min(L::mat_rows(u))) {
            let val = L::mat_get(u, i, col).abs();
            if val > best_a_val && !is_matched {
                best_a_val = val;
                best_a = i;
            }
        }
        // Best unmatched b-neuron for the same direction.
        let mut best_b = 0;
        let mut best_b_val = 0.0_f64;
        for (i, &is_matched) in matched_b.iter().enumerate().take(n_b.min(L::mat_rows(v))) {
            let val = L::mat_get(v, i, col).abs();
            if val > best_b_val && !is_matched {
                best_b_val = val;
                best_b = i;
            }
        }
        // NOTE(review): if every candidate loading is 0.0 the defaults
        // best_a = best_b = 0 remain; the guards below prevent double
        // assignment but index 0 may then be paired arbitrarily — tolerable
        // for this dead-code fallback, but worth confirming if revived.
        if !matched_a[best_a] && !matched_b[best_b] && best_b < k {
            perm[best_b] = best_a;
            assigned[best_b] = true;
            matched_a[best_a] = true;
            matched_b[best_b] = true;
        }
    }
    // Fill b-slots that never won a canonical component with leftover
    // a-indices, in ascending order.
    let remaining_a: Vec<usize> = (0..n_a).filter(|i| !matched_a[*i]).collect();
    let unassigned_b: Vec<usize> = (0..k).filter(|i| !assigned[*i]).collect();
    for (idx, &b_idx) in unassigned_b.iter().enumerate() {
        if idx < remaining_a.len() {
            perm[b_idx] = remaining_a[idx];
        }
    }
    perm
}
// Unit tests: Matrix ops, masked softmax/argmax, sampling, vector helpers,
// CCA alignment (CPU backend), and the Hungarian assignment solver.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // --- Matrix construction & element access ---

    #[test]
    fn test_zeros_all_zero_correct_dims() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 4);
        assert_eq!(m.data.len(), 12);
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_xavier_variance_approx() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(100, 100, &mut rng);
        let n = m.data.len() as f64;
        let mean = m.data.iter().sum::<f64>() / n;
        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        // Var(U(-a, a)) = a^2 / 3 = 2 / (fan_in + fan_out) for Xavier limits.
        let expected_var = 2.0 / (100.0 + 100.0);
        assert!(
            (variance - expected_var).abs() < expected_var * 0.5,
            "variance {} not within 50% of expected {}",
            variance,
            expected_var
        );
    }

    #[test]
    fn test_xavier_all_finite() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(50, 50, &mut rng);
        assert!(m.data.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_get_set_roundtrip() {
        let mut m = Matrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
    }

    #[test]
    fn test_get_zero_default() {
        let m = Matrix::zeros(2, 2);
        assert_eq!(m.get(0, 0), 0.0);
    }

    // --- transpose / mul_vec / outer / scale_add ---

    #[test]
    fn test_transpose_swaps_dims() {
        let m = Matrix::zeros(3, 5);
        let t = m.transpose();
        assert_eq!(t.rows, 5);
        assert_eq!(t.cols, 3);
    }

    #[test]
    fn test_transpose_repositions_values() {
        let mut m = Matrix::zeros(2, 3);
        m.set(0, 1, 7.0);
        m.set(1, 2, 3.0);
        let t = m.transpose();
        assert_eq!(t.get(1, 0), 7.0);
        assert_eq!(t.get(2, 1), 3.0);
    }

    #[test]
    fn test_transpose_double_is_identity() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(3, 5, &mut rng);
        let tt = m.transpose().transpose();
        assert_eq!(m.rows, tt.rows);
        assert_eq!(m.cols, tt.cols);
        for i in 0..m.data.len() {
            assert!((m.data[i] - tt.data[i]).abs() < 1e-15);
        }
    }

    #[test]
    fn test_mul_vec_known_result() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);
        let result = m.mul_vec(&[5.0, 6.0]);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 39.0).abs() < 1e-10);
    }

    #[test]
    fn test_mul_vec_output_length_equals_rows() {
        let m = Matrix::zeros(4, 3);
        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 4);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_mul_vec_panics_wrong_length() {
        let m = Matrix::zeros(2, 3);
        m.mul_vec(&[1.0, 2.0]);
    }

    #[test]
    fn test_mul_vec_zero_matrix_returns_zeros() {
        let m = Matrix::zeros(3, 2);
        let result = m.mul_vec(&[5.0, 10.0]);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_outer_dims_and_values() {
        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
        assert!((m.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((m.get(0, 2) - 5.0).abs() < 1e-10);
        assert!((m.get(1, 0) - 6.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 8.0).abs() < 1e-10);
        assert!((m.get(1, 2) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_empty_first_returns_zero_matrix() {
        let m = Matrix::outer(&[], &[1.0, 2.0]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_outer_empty_second_returns_zero_matrix() {
        let m = Matrix::outer(&[1.0, 2.0], &[]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_scale_add_basic() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 2.0);
        let mut other = Matrix::zeros(2, 2);
        other.set(0, 0, 0.5);
        other.set(1, 1, 0.5);
        m.scale_add(&other, 2.0);
        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_clips_to_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, 10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - WEIGHT_CLIP).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_negative_clips_to_neg_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, -4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, -10.0);
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - (-WEIGHT_CLIP)).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_zero_scale_only_clips() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 3.0);
        let other = Matrix::zeros(1, 1);
        m.scale_add(&other, 0.0);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_scale_add_panics_on_dimension_mismatch() {
        let mut m = Matrix::zeros(2, 2);
        let other = Matrix::zeros(3, 3);
        m.scale_add(&other, 1.0);
    }

    // --- masked softmax & argmax ---

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 1, 2, 3];
        let probs = softmax_masked(&logits, &mask);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[1] > 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_softmax_masked_single_index_is_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let mask = vec![1];
        let probs = softmax_masked(&logits, &mask);
        assert!((probs[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_empty_mask_returns_all_zeros() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax_masked(&logits, &[]);
        assert!(probs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_softmax_masked_numerically_stable_large_logits() {
        let logits = vec![1000.0, 1001.0, 1002.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs.iter().all(|p| p.is_finite()));
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_higher_logit_gets_higher_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_argmax_masked_returns_highest_in_mask() {
        let values = vec![1.0, 5.0, 3.0, 4.0];
        let mask = vec![0, 2, 3];
        assert_eq!(argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_argmax_masked_single_element() {
        let values = vec![1.0, 5.0, 3.0];
        let mask = vec![2];
        assert_eq!(argmax_masked(&values, &mask), 2);
    }

    #[test]
    fn test_argmax_masked_tie_returns_first() {
        let values = vec![3.0, 3.0, 3.0];
        let mask = vec![0, 1, 2];
        assert_eq!(argmax_masked(&values, &mask), 0);
    }

    #[test]
    #[should_panic]
    fn test_argmax_masked_empty_panics() {
        let values = vec![1.0, 2.0];
        argmax_masked(&values, &[]);
    }

    // --- rms_error ---

    #[test]
    fn test_rms_error_empty_returns_zero() {
        assert_eq!(rms_error(&[]), 0.0);
    }

    #[test]
    fn test_rms_error_single_empty_vec_returns_zero() {
        let empty: &[f64] = &[];
        assert_eq!(rms_error(&[empty]), 0.0);
    }

    #[test]
    fn test_rms_error_known_two_vecs() {
        let v1: &[f64] = &[1.0, 0.0];
        let v2: &[f64] = &[0.0, 1.0];
        let rms = rms_error(&[v1, v2]);
        let expected = (0.5_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_single_vec() {
        let v: &[f64] = &[3.0, 4.0];
        let rms = rms_error(&[v]);
        let expected = (25.0 / 2.0_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_all_zeros_returns_zero() {
        let v: &[f64] = &[0.0, 0.0, 0.0];
        assert_eq!(rms_error(&[v]), 0.0);
    }

    // --- sampling ---

    #[test]
    fn test_sample_from_probs_always_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_single_action_always_returns_it() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![1];
        for _ in 0..10 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_visits_multiple_actions() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            seen[idx] = true;
        }
        assert!(seen[0] && seen[1], "should visit both actions");
    }

    #[test]
    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.0, 0.0, 0.0];
        let mask = vec![0, 2];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    #[should_panic]
    fn test_sample_from_probs_empty_mask_panics() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        sample_from_probs(&probs, &[], &mut rng);
    }

    // --- small vector helpers ---

    #[test]
    fn test_vec_sub_known() {
        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
        assert!((result[0] - 2.0).abs() < 1e-10);
        assert!((result[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_vec_add_known() {
        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
        assert!((result[0] - 4.0).abs() < 1e-10);
        assert!((result[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_vec_scale_known() {
        let result = vec_scale(&[1.0, -2.0], 3.0);
        assert!((result[0] - 3.0).abs() < 1e-10);
        assert!((result[1] - (-6.0)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_vec_clamps_positive() {
        let mut v = vec![10.0, -10.0, 0.5];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 5.0).abs() < 1e-10);
        assert!((v[1] - (-5.0)).abs() < 1e-10);
        assert!((v[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_sub_panics_on_length_mismatch() {
        vec_sub(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_add_panics_on_length_mismatch() {
        vec_add(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = vec![1.0, -1.0, 0.0];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 1.0).abs() < 1e-10);
        assert!((v[1] - (-1.0)).abs() < 1e-10);
        assert!((v[2] - 0.0).abs() < 1e-10);
    }

    // --- panic paths for bounds checks ---

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_row() {
        let m = Matrix::zeros(2, 2);
        m.get(5, 0);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_row() {
        let mut m = Matrix::zeros(2, 2);
        m.set(5, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_softmax_masked_panics_on_oob_mask() {
        let logits = vec![1.0, 2.0, 3.0];
        softmax_masked(&logits, &[0, 5]);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_argmax_masked_panics_on_oob_mask() {
        let values = vec![1.0, 2.0, 3.0];
        argmax_masked(&values, &[0, 5]);
    }

    // --- CCA alignment (CpuLinAlg backend) ---

    #[test]
    fn test_cca_identical_activations_identity_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let act_b = act_a.clone();
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        assert_eq!(perm, vec![0, 1, 2]);
    }

    #[test]
    fn test_cca_permutation_length_is_min() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act = CpuLinAlg::zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act, r, c, val);
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act, &act).unwrap();
        assert_eq!(perm.len(), 4);
    }

    #[test]
    fn test_cca_permuted_activations_recovers_permutation() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 500;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        // act_b's column j is act_a's column col_map[j].
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        let col_map = [2, 0, 1];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![2, 0, 1]);
    }

    #[test]
    fn test_cca_permuted_with_small_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 50;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(99);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        let col_map = [1, 2, 0];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![1, 2, 0]);
    }

    #[test]
    fn test_cca_permuted_large_batch() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 500;
        let n_neurons = 4;
        let mut rng = StdRng::seed_from_u64(7);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        let col_map = [3, 1, 0, 2];
        for r in 0..batch_size {
            for (j, &src_col) in col_map.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm, vec![3, 1, 0, 2]);
    }

    #[test]
    fn test_cca_a_larger_than_b() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, 4);
        for r in 0..batch_size {
            for c in 0..4 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                CpuLinAlg::mat_set(&mut act_b, r, c, CpuLinAlg::mat_get(&act_a, r, c));
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_b_larger_than_a() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 200;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, 3);
        for r in 0..batch_size {
            for c in 0..3 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, 5);
        for r in 0..batch_size {
            for c in 0..5 {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_b, r, c, val);
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), 3);
    }

    #[test]
    fn test_cca_dead_neuron_excluded() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 100;
        let n_neurons = 3;
        let mut rng = StdRng::seed_from_u64(42);
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            for c in 0..n_neurons {
                let val: f64 = rng.gen_range(-1.0..1.0);
                CpuLinAlg::mat_set(&mut act_a, r, c, val);
            }
        }
        // act_b column 1 is all zeros (a "dead" neuron); alignment must still
        // return a valid duplicate-free permutation.
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n_neurons);
        for r in 0..batch_size {
            CpuLinAlg::mat_set(&mut act_b, r, 0, CpuLinAlg::mat_get(&act_a, r, 0));
            CpuLinAlg::mat_set(&mut act_b, r, 1, 0.0);
            CpuLinAlg::mat_set(&mut act_b, r, 2, CpuLinAlg::mat_get(&act_a, r, 2));
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(perm.len(), n_neurons);
        for &p in &perm {
            assert!(p < n_neurons, "permutation index {p} out of range");
        }
        let mut sorted = perm.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), n_neurons, "permutation has duplicates");
    }

    // --- Hungarian assignment ---

    #[test]
    fn test_hungarian_assignment_basic() {
        let assignment = hungarian_assignment(&[
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ]);
        let total: f64 = assignment
            .iter()
            .enumerate()
            .map(|(i, &j)| [1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0][i * 3 + j])
            .sum();
        assert!(
            (total - 10.0).abs() < 1e-10,
            "Expected total cost 10, got {total}"
        );
    }

    #[test]
    fn test_hungarian_assignment_permuted() {
        let assignment = hungarian_assignment(&[
            vec![5.0, 1.0, 3.0],
            vec![2.0, 8.0, 7.0],
            vec![6.0, 4.0, 1.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2]);
    }

    #[test]
    fn test_hungarian_assignment_4x4() {
        let assignment = hungarian_assignment(&[
            vec![10.0, 5.0, 13.0, 15.0],
            vec![3.0, 9.0, 18.0, 6.0],
            vec![10.0, 7.0, 2.0, 12.0],
            vec![5.0, 11.0, 9.0, 4.0],
        ]);
        assert_eq!(assignment, vec![1, 0, 2, 3]);
    }

    #[test]
    fn test_hungarian_assignment_1x1() {
        let assignment = hungarian_assignment(&[vec![42.0]]);
        assert_eq!(assignment, vec![0]);
    }

    #[test]
    fn test_hungarian_optimal_vs_greedy_on_collision_case() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let batch_size = 200;
        let n = 8;
        let mut rng = StdRng::seed_from_u64(42);
        // Correlated neurons (shared `base` signal) make greedy matching
        // collide; the optimal assignment should still recover the truth.
        let mut act_a = CpuLinAlg::zeros_mat(batch_size, n);
        for r in 0..batch_size {
            let base: f64 = rng.gen_range(-1.0..1.0);
            for c in 0..n {
                let noise: f64 = rng.gen_range(-0.3..0.3);
                let weight = (c as f64 + 1.0) / n as f64;
                CpuLinAlg::mat_set(&mut act_a, r, c, base * weight + noise);
            }
        }
        let true_perm = [5, 3, 7, 1, 6, 0, 4, 2];
        let mut act_b = CpuLinAlg::zeros_mat(batch_size, n);
        for r in 0..batch_size {
            for (j, &src_col) in true_perm.iter().enumerate() {
                CpuLinAlg::mat_set(&mut act_b, r, j, CpuLinAlg::mat_get(&act_a, r, src_col));
            }
        }
        let perm = cca_neuron_alignment::<CpuLinAlg>(&act_a, &act_b).unwrap();
        assert_eq!(
            perm,
            true_perm.to_vec(),
            "Hungarian should recover exact permutation for correlated neurons"
        );
    }

    #[test]
    fn test_sample_from_probs_distribution_roughly_correct() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.7, 0.3];
        let mask = vec![0, 1];
        let mut counts = [0usize; 2];
        let n = 1000;
        for _ in 0..n {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            counts[idx] += 1;
        }
        let ratio = counts[0] as f64 / n as f64;
        assert!(
            (ratio - 0.7).abs() < 0.1,
            "Expected ~0.7 for action 0, got {ratio}"
        );
    }
}