Skip to main content

pc_rl_core/
matrix.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Dense matrix operations and vector utilities for neural networks.
6//!
7//! Provides a custom [`Matrix`] struct and free functions for softmax,
8//! argmax, RMS error, categorical sampling, and element-wise vector ops.
9//! Pure Rust with no external linear-algebra dependencies.
10
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
/// Maximum absolute value for weight clamping after updates
/// (applied element-wise by [`Matrix::scale_add`]).
pub const WEIGHT_CLIP: f64 = 5.0;

/// Maximum absolute value for gradient clamping.
/// NOTE(review): not referenced in this file; presumably used by gradient
/// update code elsewhere in the crate — confirm against callers.
pub const GRAD_CLIP: f64 = 5.0;
19
/// A dense row-major matrix of `f64` values.
///
/// Data is stored in a flat `Vec<f64>` of length `rows * cols`; the element
/// at `(row, col)` lives at index `row * cols + col`.
///
/// Derives serde `Serialize`/`Deserialize` so matrices can be persisted.
///
/// # Examples
///
/// ```
/// use pc_rl_core::matrix::Matrix;
///
/// let m = Matrix::zeros(2, 3);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Matrix {
    /// Flat row-major storage; `data.len() == rows * cols`.
    pub data: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
42
43impl Matrix {
44    /// Creates a matrix filled with zeros.
45    ///
46    /// # Arguments
47    ///
48    /// * `rows` - Number of rows.
49    /// * `cols` - Number of columns.
50    ///
51    /// # Returns
52    ///
53    /// A `Matrix` with all elements set to `0.0`.
54    pub fn zeros(rows: usize, cols: usize) -> Self {
55        Self {
56            data: vec![0.0; rows * cols],
57            rows,
58            cols,
59        }
60    }
61
62    /// Creates a matrix with Xavier-uniform initialization.
63    ///
64    /// Elements are drawn uniformly from `[-limit, limit]` where
65    /// `limit = sqrt(6.0 / (rows + cols))`.
66    ///
67    /// # Arguments
68    ///
69    /// * `rows` - Number of rows.
70    /// * `cols` - Number of columns.
71    /// * `rng` - Mutable reference to a random number generator.
72    ///
73    /// # Returns
74    ///
75    /// A `Matrix` with Xavier-initialized values.
76    pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
77        let limit = (6.0 / (rows + cols) as f64).sqrt();
78        let data: Vec<f64> = (0..rows * cols)
79            .map(|_| rng.gen_range(-limit..limit))
80            .collect();
81        Self { data, rows, cols }
82    }
83
84    /// Returns the element at `(row, col)`.
85    ///
86    /// Defaults to `0.0` if indices are out of bounds.
87    ///
88    /// # Arguments
89    ///
90    /// * `row` - Row index.
91    /// * `col` - Column index.
92    pub fn get(&self, row: usize, col: usize) -> f64 {
93        assert!(
94            row < self.rows && col < self.cols,
95            "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
96            self.rows,
97            self.cols
98        );
99        self.data[row * self.cols + col]
100    }
101
102    /// Sets the element at `(row, col)` to `val`.
103    ///
104    /// Does nothing if indices are out of bounds.
105    ///
106    /// # Arguments
107    ///
108    /// * `row` - Row index.
109    /// * `col` - Column index.
110    /// * `val` - Value to set.
111    pub fn set(&mut self, row: usize, col: usize, val: f64) {
112        assert!(
113            row < self.rows && col < self.cols,
114            "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
115            self.rows,
116            self.cols
117        );
118        self.data[row * self.cols + col] = val;
119    }
120
121    /// Returns the transpose of this matrix.
122    ///
123    /// # Returns
124    ///
125    /// A new `Matrix` with rows and columns swapped.
126    pub fn transpose(&self) -> Self {
127        let mut result = Matrix::zeros(self.cols, self.rows);
128        for r in 0..self.rows {
129            for c in 0..self.cols {
130                result.set(c, r, self.get(r, c));
131            }
132        }
133        result
134    }
135
136    /// Multiplies this matrix by a column vector.
137    ///
138    /// # Arguments
139    ///
140    /// * `v` - Input vector of length `self.cols`.
141    ///
142    /// # Returns
143    ///
144    /// A vector of length `self.rows`.
145    ///
146    /// # Panics
147    ///
148    /// Panics with "dimension" if `v.len() != self.cols`.
149    pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
150        assert_eq!(
151            v.len(),
152            self.cols,
153            "dimension mismatch: vector length {} != matrix cols {}",
154            v.len(),
155            self.cols
156        );
157        (0..self.rows)
158            .map(|r| {
159                let row_start = r * self.cols;
160                self.data[row_start..row_start + self.cols]
161                    .iter()
162                    .zip(v.iter())
163                    .map(|(a, b)| a * b)
164                    .sum()
165            })
166            .collect()
167    }
168
169    /// Computes the outer product of two vectors.
170    ///
171    /// # Arguments
172    ///
173    /// * `a` - First vector (determines rows).
174    /// * `b` - Second vector (determines cols).
175    ///
176    /// # Returns
177    ///
178    /// A `Matrix` of shape `(a.len(), b.len())`. Returns a 0x0 matrix
179    /// if either vector is empty.
180    pub fn outer(a: &[f64], b: &[f64]) -> Self {
181        if a.is_empty() || b.is_empty() {
182            return Matrix::zeros(0, 0);
183        }
184        let rows = a.len();
185        let cols = b.len();
186        let mut data = vec![0.0; rows * cols];
187        for r in 0..rows {
188            for c in 0..cols {
189                data[r * cols + c] = a[r] * b[c];
190            }
191        }
192        Self { data, rows, cols }
193    }
194
195    /// Adds `scale * other` element-wise and clamps to `[-WEIGHT_CLIP, WEIGHT_CLIP]`.
196    ///
197    /// # Arguments
198    ///
199    /// * `other` - Matrix to add (must have same dimensions).
200    /// * `scale` - Scalar multiplier for `other`.
201    ///
202    /// # Panics
203    ///
204    /// Panics if dimensions do not match.
205    pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
206        assert!(
207            self.rows == other.rows && self.cols == other.cols,
208            "dimension mismatch in scale_add: ({},{}) vs ({},{})",
209            self.rows,
210            self.cols,
211            other.rows,
212            other.cols
213        );
214        for i in 0..self.data.len() {
215            self.data[i] += scale * other.data[i];
216            self.data[i] = self.data[i].clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
217        }
218    }
219}
220
/// Numerically stable masked softmax.
///
/// Only the indices listed in `mask` participate; every other position in the
/// output stays `0.0`. The maximum masked logit is subtracted before
/// exponentiation for numerical stability.
///
/// # Arguments
///
/// * `logits` - Raw scores.
/// * `mask` - Indices to include in the softmax.
///
/// # Returns
///
/// A probability vector of the same length as `logits`. An empty mask yields
/// an all-zero vector.
///
/// # Panics
///
/// Panics if any mask index is out of bounds for `logits`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    let n = logits.len();
    let mut probs = vec![0.0; n];
    if mask.is_empty() {
        return probs;
    }
    assert!(
        mask.iter().all(|&i| i < n),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        n
    );

    // Max-subtraction trick: shift so the largest masked logit maps to exp(0).
    let max_val = mask
        .iter()
        .fold(f64::NEG_INFINITY, |m, &i| m.max(logits[i]));
    for &i in mask {
        probs[i] = (logits[i] - max_val).exp();
    }
    let total: f64 = mask.iter().map(|&i| probs[i]).sum();
    if total > 0.0 {
        for &i in mask {
            probs[i] /= total;
        }
    }
    probs
}
263
/// Returns the index of the maximum value among masked indices.
///
/// Ties resolve to the earliest index in `mask`.
///
/// # Arguments
///
/// * `values` - Slice of values.
/// * `mask` - Indices to consider.
///
/// # Panics
///
/// Panics if `mask` is empty or contains an index out of bounds for `values`.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );
    // Strict `>` keeps the first-seen index on ties, matching a
    // left-to-right scan.
    mask[1..].iter().fold(mask[0], |best, &i| {
        if values[i] > values[best] {
            i
        } else {
            best
        }
    })
}
292
/// Combined RMS error across multiple error vectors.
///
/// # Arguments
///
/// * `error_vecs` - Slice of error vector references.
///
/// # Returns
///
/// The root-mean-square over all elements of all vectors; `0.0` when there
/// are no elements at all.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let count: usize = error_vecs.iter().map(|v| v.len()).sum();
    if count == 0 {
        return 0.0;
    }
    let sum_sq: f64 = error_vecs
        .iter()
        .flat_map(|v| v.iter())
        .map(|&e| e * e)
        .sum();
    (sum_sq / count as f64).sqrt()
}
316
317/// Samples an action index from a probability distribution over masked indices.
318///
319/// If only one action is valid, returns it directly. If all probabilities among
320/// mask indices are zero, falls back to uniform sampling over the mask.
321///
322/// # Arguments
323///
324/// * `probs` - Probability vector.
325/// * `mask` - Valid action indices.
326/// * `rng` - Mutable reference to a random number generator.
327///
328/// # Panics
329///
330/// Panics if `mask` is empty.
331pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
332    assert!(!mask.is_empty(), "sample_from_probs: empty mask");
333
334    if mask.len() == 1 {
335        return mask[0];
336    }
337
338    let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
339    if sum <= 0.0 {
340        // Uniform fallback
341        return mask[rng.gen_range(0..mask.len())];
342    }
343
344    let threshold: f64 = rng.gen_range(0.0..1.0);
345    let mut cumulative = 0.0;
346    for &i in mask {
347        cumulative += probs[i] / sum;
348        if cumulative >= threshold {
349            return i;
350        }
351    }
352
353    // Fallback to last mask element (rounding)
354    *mask.last().unwrap()
355}
356
/// Clamps each element of `v` to `[-max_abs, max_abs]` in place.
///
/// # Arguments
///
/// * `v` - Mutable slice to clamp.
/// * `max_abs` - Maximum absolute value.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(-max_abs, max_abs));
}
368
/// Element-wise subtraction: `a - b`.
///
/// # Arguments
///
/// * `a` - First vector.
/// * `b` - Second vector.
///
/// # Returns
///
/// A new vector where element `i` equals `a[i] - b[i]`.
///
/// # Panics
///
/// Panics if the slices differ in length.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_sub: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for i in 0..a.len() {
        out.push(a[i] - b[i]);
    }
    out
}
389
/// Element-wise addition: `a + b`.
///
/// # Arguments
///
/// * `a` - First vector.
/// * `b` - Second vector.
///
/// # Returns
///
/// A new vector where element `i` equals `a[i] + b[i]`.
///
/// # Panics
///
/// Panics if the slices differ in length.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_add: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
410
/// Scales every element of `v` by `s`.
///
/// # Arguments
///
/// * `v` - Input vector.
/// * `s` - Scalar multiplier.
///
/// # Returns
///
/// A new vector where element `i` equals `v[i] * s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for &x in v {
        out.push(x * s);
    }
    out
}
424
/// CCA-based neuron alignment between two activation matrices.
///
/// Given activation matrices from parent A and parent B (rows = batch samples,
/// columns = neurons), computes a permutation that aligns B's neurons to A's
/// neurons based on functional similarity via Canonical Correlation Analysis.
/// The final B→A matching is solved optimally with [`hungarian_assignment`].
///
/// # Arguments
///
/// * `act_a` - Activation matrix for parent A `[batch_size × n_a]`.
/// * `act_b` - Activation matrix for parent B `[batch_size × n_b]`.
///
/// # Returns
///
/// A permutation vector of length `min(n_a, n_b)` where `perm[i]` is the
/// index of the neuron in A that B's neuron `i` maps to. With no neurons or
/// fewer than 2 batch samples, returns the identity mapping `0..k`.
///
/// # Errors
///
/// Propagates any error from the backend's SVD (used both directly and via
/// `mat_inv_sqrt`).
pub fn cca_neuron_alignment<L: crate::linalg::LinAlg>(
    backend: &L,
    act_a: &L::Matrix,
    act_b: &L::Matrix,
) -> Result<Vec<usize>, crate::error::PcError> {
    let batch_size = backend.mat_rows(act_a);
    let n_a = backend.mat_cols(act_a);
    let n_b = backend.mat_cols(act_b);
    let k = n_a.min(n_b);

    // Covariance needs at least 2 samples (1/(batch-1) scaling below);
    // otherwise fall back to the identity permutation.
    if k == 0 || batch_size < 2 {
        return Ok((0..k).collect());
    }

    // Phase 1: Standardize columns (mean=0, std=1)
    let std_a = standardize_columns(backend, act_a);
    let std_b = standardize_columns(backend, act_b);

    let scale = 1.0 / (batch_size as f64 - 1.0);

    // Phase 2: Compute covariance matrices
    let std_a_t = backend.mat_transpose(&std_a);
    let std_b_t = backend.mat_transpose(&std_b);

    let mut c_a = backend.mat_mul(&std_a_t, &std_a); // n_a × n_a
    let mut c_b = backend.mat_mul(&std_b_t, &std_b); // n_b × n_b
    let mut c_ab = backend.mat_mul(&std_a_t, &std_b); // n_a × n_b

    // Scale by 1/(batch_size - 1)
    scale_matrix(backend, &mut c_a, n_a, n_a, scale);
    scale_matrix(backend, &mut c_b, n_b, n_b, scale);
    scale_matrix(backend, &mut c_ab, n_a, n_b, scale);

    // Phase 3: Compute C_a^(-1/2) and C_b^(-1/2) via SVD
    let c_a_inv_sqrt = mat_inv_sqrt(backend, &c_a)?;
    let c_b_inv_sqrt = mat_inv_sqrt(backend, &c_b)?;

    // M = C_a^(-1/2) × C_ab × C_b^(-1/2)
    let temp = backend.mat_mul(&c_a_inv_sqrt, &c_ab);
    let m = backend.mat_mul(&temp, &c_b_inv_sqrt);

    // SVD(M) → U, S, V
    let (u, s, v) = backend.svd(&m)?;

    // Phase 4: Build cost matrix and solve optimal assignment via Hungarian
    let n_canonical = backend.mat_cols(&u).min(backend.mat_cols(&v));

    // Cost matrix: cost[b][a] = -similarity(a, b)
    // similarity = sum_k S[k] * |U[a,k]| * |V[b,k]|
    let mut cost = vec![vec![0.0; n_a]; n_b];
    for (b, cost_row) in cost.iter_mut().enumerate() {
        for (a, cost_cell) in cost_row.iter_mut().enumerate() {
            let mut sim = 0.0;
            for kk in 0..n_canonical {
                let sk = backend.vec_get(&s, kk);
                sim += sk * backend.mat_get(&u, a, kk).abs() * backend.mat_get(&v, b, kk).abs();
            }
            *cost_cell = -sim; // Negate: Hungarian minimizes
        }
    }

    // Solve assignment
    let assignment = hungarian_assignment(&cost);

    // Build permutation: perm[b] = assigned a for each b
    // NOTE(review): shadows the identical `k` computed at the top; same value.
    let k = n_a.min(n_b);
    let mut perm = vec![0usize; k];
    for (b, &a) in assignment.iter().enumerate().take(k) {
        perm[b] = a;
    }

    Ok(perm)
}
513
514/// Scale all elements of a matrix by a scalar.
515fn scale_matrix<L: crate::linalg::LinAlg>(
516    backend: &L,
517    m: &mut L::Matrix,
518    rows: usize,
519    cols: usize,
520    s: f64,
521) {
522    for r in 0..rows {
523        for c in 0..cols {
524            let val = backend.mat_get(m, r, c);
525            backend.mat_set(m, r, c, val * s);
526        }
527    }
528}
529
530/// Standardize columns of a matrix to mean=0, std=1.
531/// Dead neurons (std < epsilon) get zeroed columns.
532fn standardize_columns<L: crate::linalg::LinAlg>(backend: &L, m: &L::Matrix) -> L::Matrix {
533    let rows = backend.mat_rows(m);
534    let cols = backend.mat_cols(m);
535    let mut result = backend.zeros_mat(rows, cols);
536    let eps = 1e-12;
537
538    for c in 0..cols {
539        // Compute mean
540        let mut sum = 0.0;
541        for r in 0..rows {
542            sum += backend.mat_get(m, r, c);
543        }
544        let mean = sum / rows as f64;
545
546        // Compute std
547        let mut var_sum = 0.0;
548        for r in 0..rows {
549            let diff = backend.mat_get(m, r, c) - mean;
550            var_sum += diff * diff;
551        }
552        let std = (var_sum / (rows as f64 - 1.0)).sqrt();
553
554        if std > eps {
555            for r in 0..rows {
556                backend.mat_set(&mut result, r, c, (backend.mat_get(m, r, c) - mean) / std);
557            }
558        }
559        // Dead neuron: column stays zero
560    }
561    result
562}
563
/// Compute M^(-1/2) for a symmetric positive semi-definite matrix via SVD.
/// Singular values below epsilon are treated as zero (pseudo-inverse style),
/// so rank-deficient matrices do not blow up.
///
/// # Errors
///
/// Propagates any error from `backend.svd`.
fn mat_inv_sqrt<L: crate::linalg::LinAlg>(
    backend: &L,
    m: &L::Matrix,
) -> Result<L::Matrix, crate::error::PcError> {
    let n = backend.mat_rows(m);
    let (u, s, _v) = backend.svd(m)?;
    let eps = 1e-10;

    // Build diag(1/sqrt(s_i)) for non-zero singular values
    let k = backend.vec_len(&s);
    let mut diag_inv_sqrt = backend.zeros_mat(k, k);
    for i in 0..k {
        let si = backend.vec_get(&s, i);
        if si > eps {
            backend.mat_set(&mut diag_inv_sqrt, i, i, 1.0 / si.sqrt());
        }
    }

    // M^(-1/2) = V × diag(1/sqrt(S)) × U^T
    // For symmetric M: U ≈ V, so M^(-1/2) = U × diag(1/sqrt(S)) × U^T
    let temp = backend.mat_mul(&u, &diag_inv_sqrt);
    let ut = backend.mat_transpose(&u);
    let mut result = backend.mat_mul(&temp, &ut);

    // Ensure result is n×n (SVD may truncate); copy into a zero-padded
    // n×n matrix when shapes disagree.
    if backend.mat_rows(&result) != n || backend.mat_cols(&result) != n {
        let mut padded = backend.zeros_mat(n, n);
        let r_rows = backend.mat_rows(&result);
        let r_cols = backend.mat_cols(&result);
        for r in 0..r_rows.min(n) {
            for c in 0..r_cols.min(n) {
                backend.mat_set(&mut padded, r, c, backend.mat_get(&result, r, c));
            }
        }
        result = padded;
    }

    Ok(result)
}
605
/// Hungarian algorithm (Kuhn-Munkres) for optimal assignment.
///
/// Given an n×m cost matrix (rows = workers, cols = jobs), finds the
/// assignment that minimizes total cost. Returns a vector where
/// `result[i]` is the column assigned to row i.
///
/// Handles rectangular matrices by padding to square with zeros.
/// Time complexity: O(n^3) where n = max(rows, cols).
///
/// Implementation uses the standard potentials (`u`/`v`) formulation with
/// 1-indexed arrays; index 0 is a virtual row/column used to seed each
/// augmenting path.
pub(crate) fn hungarian_assignment(cost: &[Vec<f64>]) -> Vec<usize> {
    let n_rows = cost.len();
    if n_rows == 0 {
        return vec![];
    }
    let n_cols = cost[0].len();
    let n = n_rows.max(n_cols);

    // Pad to square matrix
    let mut c = vec![vec![0.0; n + 1]; n + 1]; // 1-indexed
    for (i, row) in cost.iter().enumerate() {
        for (j, &val) in row.iter().enumerate() {
            c[i + 1][j + 1] = val;
        }
    }

    // u[i] = potential for row i, v[j] = potential for col j
    let mut u = vec![0.0; n + 1];
    let mut v = vec![0.0; n + 1];
    // p[j] = row assigned to column j (0 = unassigned)
    let mut p = vec![0usize; n + 1];
    // way[j] = column that leads to j in the augmenting path
    let mut way = vec![0usize; n + 1];

    for i in 1..=n {
        // Start augmenting path from row i
        p[0] = i;
        let mut j0 = 0usize; // virtual column
        // min_v[j] = smallest reduced cost seen so far for column j
        let mut min_v = vec![f64::MAX; n + 1];
        let mut used = vec![false; n + 1];

        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = f64::MAX;
            let mut j1 = 0usize;

            // Scan unused columns for the smallest reduced cost.
            for j in 1..=n {
                if !used[j] {
                    let cur = c[i0][j] - u[i0] - v[j];
                    if cur < min_v[j] {
                        min_v[j] = cur;
                        way[j] = j0;
                    }
                    if min_v[j] < delta {
                        delta = min_v[j];
                        j1 = j;
                    }
                }
            }

            // Update potentials
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    min_v[j] -= delta;
                }
            }

            j0 = j1;

            if p[j0] == 0 {
                break; // Found augmenting path
            }
        }

        // Trace back augmenting path
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }

    // Extract assignment: for each row i, find its assigned column
    // (only real rows; padded rows may be assigned to padded columns).
    let mut result = vec![0usize; n_rows];
    for j in 1..=n {
        if p[j] >= 1 && p[j] <= n_rows {
            result[p[j] - 1] = j - 1;
        }
    }
    result
}
704
/// Greedy matching from CCA canonical directions (deprecated, kept for reference).
/// Use `hungarian_assignment` instead for optimal matching.
///
/// Walks the canonical directions in order and pairs the unmatched A neuron
/// with the largest |U| coefficient to the unmatched B neuron with the
/// largest |V| coefficient. Returns `perm` of length `min(n_a, n_b)` where
/// `perm[b]` is the A neuron assigned to B neuron `b`.
#[allow(dead_code)]
fn greedy_match<L: crate::linalg::LinAlg>(
    backend: &L,
    u: &L::Matrix,
    v: &L::Matrix,
    n_a: usize,
    n_b: usize,
) -> Vec<usize> {
    let k = n_a.min(n_b);
    let n_canonical = backend.mat_cols(u).min(backend.mat_cols(v));

    let mut matched_a = vec![false; n_a];
    let mut matched_b = vec![false; n_b];
    let mut perm = vec![0usize; k];
    let mut assigned = vec![false; k];

    // Match by strongest canonical correlation first
    for col in 0..n_canonical {
        // Find neuron in A with largest |u_k| coefficient
        // NOTE(review): if every A neuron is already matched, best_a stays 0;
        // the guard below then skips the assignment.
        let mut best_a = 0;
        let mut best_a_val = 0.0_f64;
        for (i, &is_matched) in matched_a
            .iter()
            .enumerate()
            .take(n_a.min(backend.mat_rows(u)))
        {
            let val = backend.mat_get(u, i, col).abs();
            if val > best_a_val && !is_matched {
                best_a_val = val;
                best_a = i;
            }
        }

        // Find neuron in B with largest |v_k| coefficient
        let mut best_b = 0;
        let mut best_b_val = 0.0_f64;
        for (i, &is_matched) in matched_b
            .iter()
            .enumerate()
            .take(n_b.min(backend.mat_rows(v)))
        {
            let val = backend.mat_get(v, i, col).abs();
            if val > best_b_val && !is_matched {
                best_b_val = val;
                best_b = i;
            }
        }

        if !matched_a[best_a] && !matched_b[best_b] && best_b < k {
            perm[best_b] = best_a;
            assigned[best_b] = true;
            matched_a[best_a] = true;
            matched_b[best_b] = true;
        }
    }

    // Assign remaining unmatched B neurons to remaining A positions
    let remaining_a: Vec<usize> = (0..n_a).filter(|i| !matched_a[*i]).collect();
    let unassigned_b: Vec<usize> = (0..k).filter(|i| !assigned[*i]).collect();
    for (idx, &b_idx) in unassigned_b.iter().enumerate() {
        if idx < remaining_a.len() {
            perm[b_idx] = remaining_a[idx];
        }
    }

    perm
}
774
775#[cfg(test)]
776mod tests {
777    use super::*;
778    use rand::rngs::StdRng;
779    use rand::SeedableRng;
780
781    // ── Matrix Tests ──────────────────────────────────────────────────
782
783    #[test]
784    fn test_zeros_all_zero_correct_dims() {
785        let m = Matrix::zeros(3, 4);
786        assert_eq!(m.rows, 3);
787        assert_eq!(m.cols, 4);
788        assert_eq!(m.data.len(), 12);
789        assert!(m.data.iter().all(|&v| v == 0.0));
790    }
791
792    #[test]
793    fn test_xavier_variance_approx() {
794        let mut rng = StdRng::seed_from_u64(42);
795        let m = Matrix::xavier(100, 100, &mut rng);
796        let n = m.data.len() as f64;
797        let mean = m.data.iter().sum::<f64>() / n;
798        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
799        let expected_var = 2.0 / (100.0 + 100.0); // 0.01
800        assert!(
801            (variance - expected_var).abs() < expected_var * 0.5,
802            "variance {} not within 50% of expected {}",
803            variance,
804            expected_var
805        );
806    }
807
808    #[test]
809    fn test_xavier_all_finite() {
810        let mut rng = StdRng::seed_from_u64(42);
811        let m = Matrix::xavier(50, 50, &mut rng);
812        assert!(m.data.iter().all(|x| x.is_finite()));
813    }
814
815    #[test]
816    fn test_get_set_roundtrip() {
817        let mut m = Matrix::zeros(3, 3);
818        m.set(1, 2, 42.0);
819        assert_eq!(m.get(1, 2), 42.0);
820    }
821
822    #[test]
823    fn test_get_zero_default() {
824        let m = Matrix::zeros(2, 2);
825        assert_eq!(m.get(0, 0), 0.0);
826    }
827
828    #[test]
829    fn test_transpose_swaps_dims() {
830        let m = Matrix::zeros(3, 5);
831        let t = m.transpose();
832        assert_eq!(t.rows, 5);
833        assert_eq!(t.cols, 3);
834    }
835
836    #[test]
837    fn test_transpose_repositions_values() {
838        let mut m = Matrix::zeros(2, 3);
839        m.set(0, 1, 7.0);
840        m.set(1, 2, 3.0);
841        let t = m.transpose();
842        assert_eq!(t.get(1, 0), 7.0);
843        assert_eq!(t.get(2, 1), 3.0);
844    }
845
846    #[test]
847    fn test_transpose_double_is_identity() {
848        let mut rng = StdRng::seed_from_u64(42);
849        let m = Matrix::xavier(3, 5, &mut rng);
850        let tt = m.transpose().transpose();
851        assert_eq!(m.rows, tt.rows);
852        assert_eq!(m.cols, tt.cols);
853        for i in 0..m.data.len() {
854            assert!((m.data[i] - tt.data[i]).abs() < 1e-15);
855        }
856    }
857
858    #[test]
859    fn test_mul_vec_known_result() {
860        // [[1,2],[3,4]] * [5,6] = [17, 39]
861        let mut m = Matrix::zeros(2, 2);
862        m.set(0, 0, 1.0);
863        m.set(0, 1, 2.0);
864        m.set(1, 0, 3.0);
865        m.set(1, 1, 4.0);
866        let result = m.mul_vec(&[5.0, 6.0]);
867        assert_eq!(result.len(), 2);
868        assert!((result[0] - 17.0).abs() < 1e-10);
869        assert!((result[1] - 39.0).abs() < 1e-10);
870    }
871
872    #[test]
873    fn test_mul_vec_output_length_equals_rows() {
874        let m = Matrix::zeros(4, 3);
875        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
876        assert_eq!(result.len(), 4);
877    }
878
879    #[test]
880    #[should_panic(expected = "dimension")]
881    fn test_mul_vec_panics_wrong_length() {
882        let m = Matrix::zeros(2, 3);
883        m.mul_vec(&[1.0, 2.0]); // wrong length
884    }
885
886    #[test]
887    fn test_mul_vec_zero_matrix_returns_zeros() {
888        let m = Matrix::zeros(3, 2);
889        let result = m.mul_vec(&[5.0, 10.0]);
890        assert!(result.iter().all(|&v| v == 0.0));
891    }
892
893    #[test]
894    fn test_outer_dims_and_values() {
895        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
896        assert_eq!(m.rows, 2);
897        assert_eq!(m.cols, 3);
898        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
899        assert!((m.get(0, 1) - 4.0).abs() < 1e-10);
900        assert!((m.get(0, 2) - 5.0).abs() < 1e-10);
901        assert!((m.get(1, 0) - 6.0).abs() < 1e-10);
902        assert!((m.get(1, 1) - 8.0).abs() < 1e-10);
903        assert!((m.get(1, 2) - 10.0).abs() < 1e-10);
904    }
905
906    #[test]
907    fn test_outer_empty_first_returns_zero_matrix() {
908        let m = Matrix::outer(&[], &[1.0, 2.0]);
909        assert_eq!(m.rows, 0);
910        assert_eq!(m.cols, 0);
911    }
912
913    #[test]
914    fn test_outer_empty_second_returns_zero_matrix() {
915        let m = Matrix::outer(&[1.0, 2.0], &[]);
916        assert_eq!(m.rows, 0);
917        assert_eq!(m.cols, 0);
918    }
919
920    #[test]
921    fn test_scale_add_basic() {
922        let mut m = Matrix::zeros(2, 2);
923        m.set(0, 0, 1.0);
924        m.set(1, 1, 2.0);
925        let mut other = Matrix::zeros(2, 2);
926        other.set(0, 0, 0.5);
927        other.set(1, 1, 0.5);
928        m.scale_add(&other, 2.0);
929        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
930        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
931    }
932
933    #[test]
934    fn test_scale_add_clips_to_weight_clip() {
935        let mut m = Matrix::zeros(1, 1);
936        m.set(0, 0, 4.0);
937        let mut other = Matrix::zeros(1, 1);
938        other.set(0, 0, 10.0);
939        m.scale_add(&other, 1.0);
940        assert!((m.get(0, 0) - WEIGHT_CLIP).abs() < 1e-10);
941    }
942
943    #[test]
944    fn test_scale_add_negative_clips_to_neg_weight_clip() {
945        let mut m = Matrix::zeros(1, 1);
946        m.set(0, 0, -4.0);
947        let mut other = Matrix::zeros(1, 1);
948        other.set(0, 0, -10.0);
949        m.scale_add(&other, 1.0);
950        assert!((m.get(0, 0) - (-WEIGHT_CLIP)).abs() < 1e-10);
951    }
952
953    #[test]
954    fn test_scale_add_zero_scale_only_clips() {
955        let mut m = Matrix::zeros(1, 1);
956        m.set(0, 0, 3.0);
957        let other = Matrix::zeros(1, 1);
958        m.scale_add(&other, 0.0);
959        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
960    }
961
962    #[test]
963    #[should_panic(expected = "dimension")]
964    fn test_scale_add_panics_on_dimension_mismatch() {
965        let mut m = Matrix::zeros(2, 2);
966        let other = Matrix::zeros(3, 3);
967        m.scale_add(&other, 1.0);
968    }
969
970    // ── Softmax Tests ─────────────────────────────────────────────────
971
972    #[test]
973    fn test_softmax_masked_sums_to_one() {
974        let logits = vec![1.0, 2.0, 3.0, 4.0];
975        let mask = vec![0, 1, 2, 3];
976        let probs = softmax_masked(&logits, &mask);
977        let sum: f64 = probs.iter().sum();
978        assert!((sum - 1.0).abs() < 1e-10);
979    }
980
981    #[test]
982    fn test_softmax_masked_unmasked_are_zero() {
983        let logits = vec![1.0, 2.0, 3.0, 4.0];
984        let mask = vec![1, 3];
985        let probs = softmax_masked(&logits, &mask);
986        assert_eq!(probs[0], 0.0);
987        assert_eq!(probs[2], 0.0);
988        assert!(probs[1] > 0.0);
989        assert!(probs[3] > 0.0);
990    }
991
992    #[test]
993    fn test_softmax_masked_single_index_is_one() {
994        let logits = vec![1.0, 2.0, 3.0];
995        let mask = vec![1];
996        let probs = softmax_masked(&logits, &mask);
997        assert!((probs[1] - 1.0).abs() < 1e-10);
998    }
999
1000    #[test]
1001    fn test_softmax_masked_empty_mask_returns_all_zeros() {
1002        let logits = vec![1.0, 2.0, 3.0];
1003        let probs = softmax_masked(&logits, &[]);
1004        assert!(probs.iter().all(|&v| v == 0.0));
1005    }
1006
1007    #[test]
1008    fn test_softmax_masked_numerically_stable_large_logits() {
1009        let logits = vec![1000.0, 1001.0, 1002.0];
1010        let mask = vec![0, 1, 2];
1011        let probs = softmax_masked(&logits, &mask);
1012        assert!(probs.iter().all(|p| p.is_finite()));
1013        let sum: f64 = probs.iter().sum();
1014        assert!((sum - 1.0).abs() < 1e-10);
1015    }
1016
1017    #[test]
1018    fn test_softmax_masked_higher_logit_gets_higher_prob() {
1019        let logits = vec![1.0, 5.0, 2.0];
1020        let mask = vec![0, 1, 2];
1021        let probs = softmax_masked(&logits, &mask);
1022        assert!(probs[1] > probs[2]);
1023        assert!(probs[2] > probs[0]);
1024    }
1025
1026    // ── Argmax Tests ──────────────────────────────────────────────────
1027
1028    #[test]
1029    fn test_argmax_masked_returns_highest_in_mask() {
1030        let values = vec![1.0, 5.0, 3.0, 4.0];
1031        let mask = vec![0, 2, 3];
1032        assert_eq!(argmax_masked(&values, &mask), 3);
1033    }
1034
1035    #[test]
1036    fn test_argmax_masked_single_element() {
1037        let values = vec![1.0, 5.0, 3.0];
1038        let mask = vec![2];
1039        assert_eq!(argmax_masked(&values, &mask), 2);
1040    }
1041
1042    #[test]
1043    fn test_argmax_masked_tie_returns_first() {
1044        let values = vec![3.0, 3.0, 3.0];
1045        let mask = vec![0, 1, 2];
1046        assert_eq!(argmax_masked(&values, &mask), 0);
1047    }
1048
1049    #[test]
1050    #[should_panic]
1051    fn test_argmax_masked_empty_panics() {
1052        let values = vec![1.0, 2.0];
1053        argmax_masked(&values, &[]);
1054    }
1055
1056    // ── RMS Error Tests ───────────────────────────────────────────────
1057
1058    #[test]
1059    fn test_rms_error_empty_returns_zero() {
1060        assert_eq!(rms_error(&[]), 0.0);
1061    }
1062
1063    #[test]
1064    fn test_rms_error_single_empty_vec_returns_zero() {
1065        let empty: &[f64] = &[];
1066        assert_eq!(rms_error(&[empty]), 0.0);
1067    }
1068
1069    #[test]
1070    fn test_rms_error_known_two_vecs() {
1071        let v1: &[f64] = &[1.0, 0.0];
1072        let v2: &[f64] = &[0.0, 1.0];
1073        let rms = rms_error(&[v1, v2]);
1074        // sum_sq = 1+0+0+1 = 2, count = 4, rms = sqrt(2/4) = sqrt(0.5)
1075        let expected = (0.5_f64).sqrt();
1076        assert!((rms - expected).abs() < 1e-10);
1077    }
1078
1079    #[test]
1080    fn test_rms_error_single_vec() {
1081        let v: &[f64] = &[3.0, 4.0];
1082        let rms = rms_error(&[v]);
1083        // sum_sq = 9+16 = 25, count = 2, rms = sqrt(12.5) = 3.5355...
1084        let expected = (25.0 / 2.0_f64).sqrt();
1085        assert!((rms - expected).abs() < 1e-10);
1086    }
1087
1088    #[test]
1089    fn test_rms_error_all_zeros_returns_zero() {
1090        let v: &[f64] = &[0.0, 0.0, 0.0];
1091        assert_eq!(rms_error(&[v]), 0.0);
1092    }
1093
1094    // ── Sample Tests ──────────────────────────────────────────────────
1095
1096    #[test]
1097    fn test_sample_from_probs_always_in_mask() {
1098        let mut rng = StdRng::seed_from_u64(42);
1099        let probs = vec![0.1, 0.2, 0.3, 0.4];
1100        let mask = vec![1, 3];
1101        for _ in 0..20 {
1102            let idx = sample_from_probs(&probs, &mask, &mut rng);
1103            assert!(mask.contains(&idx));
1104        }
1105    }
1106
1107    #[test]
1108    fn test_sample_from_probs_single_action_always_returns_it() {
1109        let mut rng = StdRng::seed_from_u64(42);
1110        let probs = vec![0.5, 0.5];
1111        let mask = vec![1];
1112        for _ in 0..10 {
1113            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
1114        }
1115    }
1116
1117    #[test]
1118    fn test_sample_from_probs_visits_multiple_actions() {
1119        let mut rng = StdRng::seed_from_u64(42);
1120        let probs = vec![0.5, 0.5];
1121        let mask = vec![0, 1];
1122        let mut seen = [false; 2];
1123        for _ in 0..100 {
1124            let idx = sample_from_probs(&probs, &mask, &mut rng);
1125            seen[idx] = true;
1126        }
1127        assert!(seen[0] && seen[1], "should visit both actions");
1128    }
1129
1130    #[test]
1131    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
1132        let mut rng = StdRng::seed_from_u64(42);
1133        let probs = vec![0.0, 0.0, 0.0];
1134        let mask = vec![0, 2];
1135        for _ in 0..20 {
1136            let idx = sample_from_probs(&probs, &mask, &mut rng);
1137            assert!(mask.contains(&idx));
1138        }
1139    }
1140
1141    #[test]
1142    #[should_panic]
1143    fn test_sample_from_probs_empty_mask_panics() {
1144        let mut rng = StdRng::seed_from_u64(42);
1145        let probs = vec![0.5, 0.5];
1146        sample_from_probs(&probs, &[], &mut rng);
1147    }
1148
1149    // ── Vec Utility Tests ─────────────────────────────────────────────
1150
1151    #[test]
1152    fn test_vec_sub_known() {
1153        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
1154        assert!((result[0] - 2.0).abs() < 1e-10);
1155        assert!((result[1] - (-1.0)).abs() < 1e-10);
1156    }
1157
1158    #[test]
1159    fn test_vec_add_known() {
1160        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
1161        assert!((result[0] - 4.0).abs() < 1e-10);
1162        assert!((result[1] - 6.0).abs() < 1e-10);
1163    }
1164
1165    #[test]
1166    fn test_vec_scale_known() {
1167        let result = vec_scale(&[1.0, -2.0], 3.0);
1168        assert!((result[0] - 3.0).abs() < 1e-10);
1169        assert!((result[1] - (-6.0)).abs() < 1e-10);
1170    }
1171
1172    #[test]
1173    fn test_clip_vec_clamps_positive() {
1174        let mut v = vec![10.0, -10.0, 0.5];
1175        clip_vec(&mut v, 5.0);
1176        assert!((v[0] - 5.0).abs() < 1e-10);
1177        assert!((v[1] - (-5.0)).abs() < 1e-10);
1178        assert!((v[2] - 0.5).abs() < 1e-10);
1179    }
1180
1181    #[test]
1182    #[should_panic(expected = "length mismatch")]
1183    fn test_vec_sub_panics_on_length_mismatch() {
1184        vec_sub(&[1.0, 2.0], &[1.0]);
1185    }
1186
1187    #[test]
1188    #[should_panic(expected = "length mismatch")]
1189    fn test_vec_add_panics_on_length_mismatch() {
1190        vec_add(&[1.0, 2.0], &[1.0]);
1191    }
1192
1193    #[test]
1194    fn test_clip_vec_leaves_safe_values() {
1195        let mut v = vec![1.0, -1.0, 0.0];
1196        clip_vec(&mut v, 5.0);
1197        assert!((v[0] - 1.0).abs() < 1e-10);
1198        assert!((v[1] - (-1.0)).abs() < 1e-10);
1199        assert!((v[2] - 0.0).abs() < 1e-10);
1200    }
1201
1202    // ── Defensive: OOB assertions ────────────────────────────────
1203
1204    #[test]
1205    #[should_panic(expected = "out of bounds")]
1206    fn test_get_panics_on_oob_row() {
1207        let m = Matrix::zeros(2, 2);
1208        m.get(5, 0); // should panic, not return 0.0
1209    }
1210
1211    #[test]
1212    #[should_panic(expected = "out of bounds")]
1213    fn test_set_panics_on_oob_row() {
1214        let mut m = Matrix::zeros(2, 2);
1215        m.set(5, 0, 1.0); // should panic, not silently do nothing
1216    }
1217
1218    #[test]
1219    #[should_panic(expected = "mask index out of bounds")]
1220    fn test_softmax_masked_panics_on_oob_mask() {
1221        let logits = vec![1.0, 2.0, 3.0];
1222        softmax_masked(&logits, &[0, 5]); // 5 >= logits.len()
1223    }
1224
1225    #[test]
1226    #[should_panic(expected = "mask index out of bounds")]
1227    fn test_argmax_masked_panics_on_oob_mask() {
1228        let values = vec![1.0, 2.0, 3.0];
1229        argmax_masked(&values, &[0, 5]); // 5 >= values.len()
1230    }
1231
    // ── sample_from_probs distribution (see test_sample_from_probs_distribution_roughly_correct at end of module) ──
1233
1234    // ── Phase 3 Cycle 3.1: CCA identical activations → identity ────
1235
1236    #[test]
1237    fn test_cca_identical_activations_identity_permutation() {
1238        // Same activations for A and B → identity permutation [0, 1, 2]
1239        use crate::linalg::cpu::CpuLinAlg;
1240        use crate::linalg::LinAlg;
1241        let backend = CpuLinAlg::new();
1242
1243        let batch_size = 100;
1244        let n_neurons = 3;
1245        let mut rng = StdRng::seed_from_u64(42);
1246
1247        // Generate random activations (batch_size × n_neurons)
1248        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
1249        for r in 0..batch_size {
1250            for c in 0..n_neurons {
1251                let val: f64 = rng.gen_range(-1.0..1.0);
1252                backend.mat_set(&mut act_a, r, c, val);
1253            }
1254        }
1255        let act_b = act_a.clone();
1256
1257        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1258        assert_eq!(perm.len(), n_neurons);
1259        assert_eq!(perm, vec![0, 1, 2]);
1260    }
1261
1262    #[test]
1263    fn test_cca_permutation_length_is_min() {
1264        // A has 4 neurons, B has 4 neurons → perm length = 4
1265        use crate::linalg::cpu::CpuLinAlg;
1266        use crate::linalg::LinAlg;
1267        let backend = CpuLinAlg::new();
1268
1269        let batch_size = 100;
1270        let mut rng = StdRng::seed_from_u64(42);
1271
1272        let mut act = backend.zeros_mat(batch_size, 4);
1273        for r in 0..batch_size {
1274            for c in 0..4 {
1275                let val: f64 = rng.gen_range(-1.0..1.0);
1276                backend.mat_set(&mut act, r, c, val);
1277            }
1278        }
1279
1280        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act, &act).unwrap();
1281        assert_eq!(perm.len(), 4);
1282    }
1283
1284    // ── Phase 3 Cycle 3.2: CCA recovers permutation ────────────
1285
1286    #[test]
1287    fn test_cca_permuted_activations_recovers_permutation() {
1288        // B = permuted A with columns [2, 0, 1] → perm should map back
1289        use crate::linalg::cpu::CpuLinAlg;
1290        use crate::linalg::LinAlg;
1291        let backend = CpuLinAlg::new();
1292
1293        let batch_size = 500;
1294        let n_neurons = 3;
1295        let mut rng = StdRng::seed_from_u64(42);
1296
1297        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
1298        for r in 0..batch_size {
1299            for c in 0..n_neurons {
1300                let val: f64 = rng.gen_range(-1.0..1.0);
1301                backend.mat_set(&mut act_a, r, c, val);
1302            }
1303        }
1304
1305        // B columns = [A_col2, A_col0, A_col1]
1306        // So B neuron 0 = A neuron 2, B neuron 1 = A neuron 0, B neuron 2 = A neuron 1
1307        // permutation[i] = which A neuron maps to B neuron i
1308        // Expected: [2, 0, 1]
1309        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
1310        let col_map = [2, 0, 1]; // B_col_j = A_col_{col_map[j]}
1311        for r in 0..batch_size {
1312            for (j, &src_col) in col_map.iter().enumerate() {
1313                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
1314            }
1315        }
1316
1317        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1318        assert_eq!(perm, vec![2, 0, 1]);
1319    }
1320
1321    #[test]
1322    fn test_cca_permuted_with_small_batch() {
1323        // Same permutation test but with batch_size=50
1324        use crate::linalg::cpu::CpuLinAlg;
1325        use crate::linalg::LinAlg;
1326        let backend = CpuLinAlg::new();
1327
1328        let batch_size = 50;
1329        let n_neurons = 3;
1330        let mut rng = StdRng::seed_from_u64(99);
1331
1332        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
1333        for r in 0..batch_size {
1334            for c in 0..n_neurons {
1335                let val: f64 = rng.gen_range(-1.0..1.0);
1336                backend.mat_set(&mut act_a, r, c, val);
1337            }
1338        }
1339
1340        // Permutation [1, 2, 0]
1341        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
1342        let col_map = [1, 2, 0];
1343        for r in 0..batch_size {
1344            for (j, &src_col) in col_map.iter().enumerate() {
1345                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
1346            }
1347        }
1348
1349        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1350        assert_eq!(perm, vec![1, 2, 0]);
1351    }
1352
1353    #[test]
1354    fn test_cca_permuted_large_batch() {
1355        // batch_size=500, verifying robustness
1356        use crate::linalg::cpu::CpuLinAlg;
1357        use crate::linalg::LinAlg;
1358        let backend = CpuLinAlg::new();
1359
1360        let batch_size = 500;
1361        let n_neurons = 4;
1362        let mut rng = StdRng::seed_from_u64(7);
1363
1364        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
1365        for r in 0..batch_size {
1366            for c in 0..n_neurons {
1367                let val: f64 = rng.gen_range(-1.0..1.0);
1368                backend.mat_set(&mut act_a, r, c, val);
1369            }
1370        }
1371
1372        // Permutation [3, 1, 0, 2]
1373        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
1374        let col_map = [3, 1, 0, 2];
1375        for r in 0..batch_size {
1376            for (j, &src_col) in col_map.iter().enumerate() {
1377                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
1378            }
1379        }
1380
1381        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1382        assert_eq!(perm, vec![3, 1, 0, 2]);
1383    }
1384
1385    // ── Phase 3 Cycle 3.3: CCA with different dimensions ────────
1386
1387    #[test]
1388    fn test_cca_a_larger_than_b() {
1389        // A has 4 neurons, B has 3 neurons → permutation length = 3
1390        use crate::linalg::cpu::CpuLinAlg;
1391        use crate::linalg::LinAlg;
1392        let backend = CpuLinAlg::new();
1393
1394        let batch_size = 200;
1395        let mut rng = StdRng::seed_from_u64(42);
1396
1397        let mut act_a = backend.zeros_mat(batch_size, 4);
1398        for r in 0..batch_size {
1399            for c in 0..4 {
1400                let val: f64 = rng.gen_range(-1.0..1.0);
1401                backend.mat_set(&mut act_a, r, c, val);
1402            }
1403        }
1404
1405        // B has 3 neurons: B_col_j = A_col_j (first 3 columns)
1406        let mut act_b = backend.zeros_mat(batch_size, 3);
1407        for r in 0..batch_size {
1408            for c in 0..3 {
1409                backend.mat_set(&mut act_b, r, c, backend.mat_get(&act_a, r, c));
1410            }
1411        }
1412
1413        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1414        assert_eq!(perm.len(), 3);
1415    }
1416
1417    #[test]
1418    fn test_cca_b_larger_than_a() {
1419        // A has 3 neurons, B has 5 neurons → permutation length = 3 (min)
1420        use crate::linalg::cpu::CpuLinAlg;
1421        use crate::linalg::LinAlg;
1422        let backend = CpuLinAlg::new();
1423
1424        let batch_size = 200;
1425        let mut rng = StdRng::seed_from_u64(42);
1426
1427        let mut act_a = backend.zeros_mat(batch_size, 3);
1428        for r in 0..batch_size {
1429            for c in 0..3 {
1430                let val: f64 = rng.gen_range(-1.0..1.0);
1431                backend.mat_set(&mut act_a, r, c, val);
1432            }
1433        }
1434
1435        let mut act_b = backend.zeros_mat(batch_size, 5);
1436        for r in 0..batch_size {
1437            for c in 0..5 {
1438                let val: f64 = rng.gen_range(-1.0..1.0);
1439                backend.mat_set(&mut act_b, r, c, val);
1440            }
1441        }
1442
1443        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1444        assert_eq!(perm.len(), 3);
1445    }
1446
1447    #[test]
1448    fn test_cca_dead_neuron_excluded() {
1449        // One neuron in B has zero variance (dead) → still produces valid permutation
1450        use crate::linalg::cpu::CpuLinAlg;
1451        use crate::linalg::LinAlg;
1452        let backend = CpuLinAlg::new();
1453
1454        let batch_size = 100;
1455        let n_neurons = 3;
1456        let mut rng = StdRng::seed_from_u64(42);
1457
1458        let mut act_a = backend.zeros_mat(batch_size, n_neurons);
1459        for r in 0..batch_size {
1460            for c in 0..n_neurons {
1461                let val: f64 = rng.gen_range(-1.0..1.0);
1462                backend.mat_set(&mut act_a, r, c, val);
1463            }
1464        }
1465
1466        // B: neuron 1 is dead (constant 0), neurons 0 and 2 copy from A
1467        let mut act_b = backend.zeros_mat(batch_size, n_neurons);
1468        for r in 0..batch_size {
1469            backend.mat_set(&mut act_b, r, 0, backend.mat_get(&act_a, r, 0));
1470            backend.mat_set(&mut act_b, r, 1, 0.0); // dead neuron
1471            backend.mat_set(&mut act_b, r, 2, backend.mat_get(&act_a, r, 2));
1472        }
1473
1474        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1475        // Should produce a valid permutation of length 3, no panic
1476        assert_eq!(perm.len(), n_neurons);
1477        // All indices in range [0, n_neurons)
1478        for &p in &perm {
1479            assert!(p < n_neurons, "permutation index {p} out of range");
1480        }
1481        // All indices unique
1482        let mut sorted = perm.clone();
1483        sorted.sort();
1484        sorted.dedup();
1485        assert_eq!(sorted.len(), n_neurons, "permutation has duplicates");
1486    }
1487
1488    // ── Fix #9: Hungarian matching ─────────────────────────────
1489
1490    #[test]
1491    fn test_hungarian_assignment_basic() {
1492        // Verify hungarian_assignment solves a known 3x3 cost matrix optimally
1493        // Cost matrix (minimize):
1494        //   [[1, 2, 3],
1495        //    [2, 4, 6],
1496        //    [3, 6, 9]]
1497        // Optimal: (0,2)=3, (1,1)=4, (2,0)=3 → total=10
1498        let assignment = hungarian_assignment(&[
1499            vec![1.0, 2.0, 3.0],
1500            vec![2.0, 4.0, 6.0],
1501            vec![3.0, 6.0, 9.0],
1502        ]);
1503        // Verify total cost is optimal (10)
1504        let total: f64 = assignment
1505            .iter()
1506            .enumerate()
1507            .map(|(i, &j)| [1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0][i * 3 + j])
1508            .sum();
1509        assert!(
1510            (total - 10.0).abs() < 1e-10,
1511            "Expected total cost 10, got {total}"
1512        );
1513    }
1514
1515    #[test]
1516    fn test_hungarian_assignment_permuted() {
1517        // Cost matrix where optimal is NOT the diagonal:
1518        //   [[5, 1, 3],
1519        //    [2, 8, 7],
1520        //    [6, 4, 1]]
1521        // Optimal: (0,1)=1, (1,0)=2, (2,2)=1 → total=4
1522        let assignment = hungarian_assignment(&[
1523            vec![5.0, 1.0, 3.0],
1524            vec![2.0, 8.0, 7.0],
1525            vec![6.0, 4.0, 1.0],
1526        ]);
1527        assert_eq!(assignment, vec![1, 0, 2]);
1528    }
1529
1530    #[test]
1531    fn test_hungarian_assignment_4x4() {
1532        // 4×4 cost matrix:
1533        //   [[10, 5, 13, 15],
1534        //    [3, 9, 18,  6],
1535        //    [10, 7, 2, 12],
1536        //    [5, 11, 9,  4]]
1537        // Optimal: (0,1)=5, (1,0)=3, (2,2)=2, (3,3)=4 → total=14
1538        let assignment = hungarian_assignment(&[
1539            vec![10.0, 5.0, 13.0, 15.0],
1540            vec![3.0, 9.0, 18.0, 6.0],
1541            vec![10.0, 7.0, 2.0, 12.0],
1542            vec![5.0, 11.0, 9.0, 4.0],
1543        ]);
1544        assert_eq!(assignment, vec![1, 0, 2, 3]);
1545    }
1546
1547    #[test]
1548    fn test_hungarian_assignment_1x1() {
1549        let assignment = hungarian_assignment(&[vec![42.0]]);
1550        assert_eq!(assignment, vec![0]);
1551    }
1552
1553    #[test]
1554    fn test_hungarian_optimal_vs_greedy_on_collision_case() {
1555        // Construct activations where greedy matching fails due to argmax collision
1556        // but Hungarian finds the optimal alignment.
1557        //
1558        // Create two networks where neurons in B are a known permutation of A,
1559        // but with correlated noise that causes argmax collisions in canonical
1560        // directions. Use enough neurons that collisions are likely.
1561        use crate::linalg::cpu::CpuLinAlg;
1562        use crate::linalg::LinAlg;
1563        let backend = CpuLinAlg::new();
1564
1565        let batch_size = 200;
1566        let n = 8;
1567        let mut rng = StdRng::seed_from_u64(42);
1568
1569        // Generate A with distinct but correlated neuron activations
1570        let mut act_a = backend.zeros_mat(batch_size, n);
1571        for r in 0..batch_size {
1572            // Base signal
1573            let base: f64 = rng.gen_range(-1.0..1.0);
1574            for c in 0..n {
1575                let noise: f64 = rng.gen_range(-0.3..0.3);
1576                // Each neuron = base * weight_c + noise
1577                let weight = (c as f64 + 1.0) / n as f64;
1578                backend.mat_set(&mut act_a, r, c, base * weight + noise);
1579            }
1580        }
1581
1582        // B is a known permutation of A: [5, 3, 7, 1, 6, 0, 4, 2]
1583        let true_perm = [5, 3, 7, 1, 6, 0, 4, 2];
1584        let mut act_b = backend.zeros_mat(batch_size, n);
1585        for r in 0..batch_size {
1586            for (j, &src_col) in true_perm.iter().enumerate() {
1587                backend.mat_set(&mut act_b, r, j, backend.mat_get(&act_a, r, src_col));
1588            }
1589        }
1590
1591        let perm = cca_neuron_alignment::<CpuLinAlg>(&backend, &act_a, &act_b).unwrap();
1592
1593        // With Hungarian, the optimal assignment should recover the true permutation
1594        assert_eq!(
1595            perm,
1596            true_perm.to_vec(),
1597            "Hungarian should recover exact permutation for correlated neurons"
1598        );
1599    }
1600
1601    #[test]
1602    fn test_sample_from_probs_distribution_roughly_correct() {
1603        let mut rng = StdRng::seed_from_u64(42);
1604        let probs = vec![0.7, 0.3];
1605        let mask = vec![0, 1];
1606        let mut counts = [0usize; 2];
1607        let n = 1000;
1608        for _ in 0..n {
1609            let idx = sample_from_probs(&probs, &mask, &mut rng);
1610            counts[idx] += 1;
1611        }
1612        let ratio = counts[0] as f64 / n as f64;
1613        // Should be roughly 0.7, allow 10% tolerance
1614        assert!(
1615            (ratio - 0.7).abs() < 0.1,
1616            "Expected ~0.7 for action 0, got {ratio}"
1617        );
1618    }
1619}