// pc_rl_core — matrix.rs
// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-03-25

//! Dense matrix operations and vector utilities for neural networks.
//!
//! Provides a custom [`Matrix`] struct and free functions for softmax,
//! argmax, RMS error, categorical sampling, and element-wise vector ops.
//! Pure Rust with no external linear-algebra dependencies.

use rand::Rng;
use serde::{Deserialize, Serialize};

/// Largest magnitude a weight may take: [`Matrix::scale_add`] clamps every
/// updated element into `[-WEIGHT_CLIP, WEIGHT_CLIP]`.
pub const WEIGHT_CLIP: f64 = 5.0;

/// Largest magnitude a gradient component may take when gradients are clamped.
pub const GRAD_CLIP: f64 = 5.0;

20/// A dense row-major matrix of `f64` values.
21///
22/// Data is stored in a flat `Vec<f64>` of length `rows * cols`.
23///
24/// # Examples
25///
26/// ```
27/// use pc_rl_core::matrix::Matrix;
28///
29/// let m = Matrix::zeros(2, 3);
30/// assert_eq!(m.rows, 2);
31/// assert_eq!(m.cols, 3);
32/// ```
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Matrix {
35    /// Flat row-major storage.
36    pub data: Vec<f64>,
37    /// Number of rows.
38    pub rows: usize,
39    /// Number of columns.
40    pub cols: usize,
41}
42
43impl Matrix {
44    /// Creates a matrix filled with zeros.
45    ///
46    /// # Arguments
47    ///
48    /// * `rows` - Number of rows.
49    /// * `cols` - Number of columns.
50    ///
51    /// # Returns
52    ///
53    /// A `Matrix` with all elements set to `0.0`.
54    pub fn zeros(rows: usize, cols: usize) -> Self {
55        Self {
56            data: vec![0.0; rows * cols],
57            rows,
58            cols,
59        }
60    }
61
62    /// Creates a matrix with Xavier-uniform initialization.
63    ///
64    /// Elements are drawn uniformly from `[-limit, limit]` where
65    /// `limit = sqrt(6.0 / (rows + cols))`.
66    ///
67    /// # Arguments
68    ///
69    /// * `rows` - Number of rows.
70    /// * `cols` - Number of columns.
71    /// * `rng` - Mutable reference to a random number generator.
72    ///
73    /// # Returns
74    ///
75    /// A `Matrix` with Xavier-initialized values.
76    pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
77        let limit = (6.0 / (rows + cols) as f64).sqrt();
78        let data: Vec<f64> = (0..rows * cols)
79            .map(|_| rng.gen_range(-limit..limit))
80            .collect();
81        Self { data, rows, cols }
82    }
83
84    /// Returns the element at `(row, col)`.
85    ///
86    /// Defaults to `0.0` if indices are out of bounds.
87    ///
88    /// # Arguments
89    ///
90    /// * `row` - Row index.
91    /// * `col` - Column index.
92    pub fn get(&self, row: usize, col: usize) -> f64 {
93        assert!(
94            row < self.rows && col < self.cols,
95            "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
96            self.rows,
97            self.cols
98        );
99        self.data[row * self.cols + col]
100    }
101
102    /// Sets the element at `(row, col)` to `val`.
103    ///
104    /// Does nothing if indices are out of bounds.
105    ///
106    /// # Arguments
107    ///
108    /// * `row` - Row index.
109    /// * `col` - Column index.
110    /// * `val` - Value to set.
111    pub fn set(&mut self, row: usize, col: usize, val: f64) {
112        assert!(
113            row < self.rows && col < self.cols,
114            "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
115            self.rows,
116            self.cols
117        );
118        self.data[row * self.cols + col] = val;
119    }
120
121    /// Returns the transpose of this matrix.
122    ///
123    /// # Returns
124    ///
125    /// A new `Matrix` with rows and columns swapped.
126    pub fn transpose(&self) -> Self {
127        let mut result = Matrix::zeros(self.cols, self.rows);
128        for r in 0..self.rows {
129            for c in 0..self.cols {
130                result.set(c, r, self.get(r, c));
131            }
132        }
133        result
134    }
135
136    /// Multiplies this matrix by a column vector.
137    ///
138    /// # Arguments
139    ///
140    /// * `v` - Input vector of length `self.cols`.
141    ///
142    /// # Returns
143    ///
144    /// A vector of length `self.rows`.
145    ///
146    /// # Panics
147    ///
148    /// Panics with "dimension" if `v.len() != self.cols`.
149    pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
150        assert_eq!(
151            v.len(),
152            self.cols,
153            "dimension mismatch: vector length {} != matrix cols {}",
154            v.len(),
155            self.cols
156        );
157        (0..self.rows)
158            .map(|r| {
159                let row_start = r * self.cols;
160                self.data[row_start..row_start + self.cols]
161                    .iter()
162                    .zip(v.iter())
163                    .map(|(a, b)| a * b)
164                    .sum()
165            })
166            .collect()
167    }
168
169    /// Computes the outer product of two vectors.
170    ///
171    /// # Arguments
172    ///
173    /// * `a` - First vector (determines rows).
174    /// * `b` - Second vector (determines cols).
175    ///
176    /// # Returns
177    ///
178    /// A `Matrix` of shape `(a.len(), b.len())`. Returns a 0x0 matrix
179    /// if either vector is empty.
180    pub fn outer(a: &[f64], b: &[f64]) -> Self {
181        if a.is_empty() || b.is_empty() {
182            return Matrix::zeros(0, 0);
183        }
184        let rows = a.len();
185        let cols = b.len();
186        let mut data = vec![0.0; rows * cols];
187        for r in 0..rows {
188            for c in 0..cols {
189                data[r * cols + c] = a[r] * b[c];
190            }
191        }
192        Self { data, rows, cols }
193    }
194
195    /// Adds `scale * other` element-wise and clamps to `[-WEIGHT_CLIP, WEIGHT_CLIP]`.
196    ///
197    /// # Arguments
198    ///
199    /// * `other` - Matrix to add (must have same dimensions).
200    /// * `scale` - Scalar multiplier for `other`.
201    ///
202    /// # Panics
203    ///
204    /// Panics if dimensions do not match.
205    pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
206        assert!(
207            self.rows == other.rows && self.cols == other.cols,
208            "dimension mismatch in scale_add: ({},{}) vs ({},{})",
209            self.rows,
210            self.cols,
211            other.rows,
212            other.cols
213        );
214        for i in 0..self.data.len() {
215            self.data[i] += scale * other.data[i];
216            self.data[i] = self.data[i].clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
217        }
218    }
219}
220
/// Softmax over a subset of positions; every position outside the subset is zero.
///
/// Only the logits whose indices appear in `mask` take part. The largest
/// masked logit is subtracted before exponentiating, so large logits do not
/// overflow. Normalization is skipped when the sum of exponentials is not
/// positive (e.g. NaN produced by infinite logits such as an all-`-inf` mask),
/// in which case the raw exponentials are returned.
///
/// NOTE(review): `mask` is presumably duplicate-free; a repeated index is
/// counted twice in the normalizer.
///
/// # Arguments
///
/// * `logits` - Raw scores.
/// * `mask` - Indices to include in the softmax.
///
/// # Returns
///
/// A vector the same length as `logits`. An empty `mask` yields all zeros.
///
/// # Panics
///
/// Panics if `mask` is non-empty and contains an index `>= logits.len()`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    if mask.is_empty() {
        return vec![0.0; logits.len()];
    }
    assert!(
        mask.iter().all(|&i| i < logits.len()),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        logits.len()
    );

    let mut probs = vec![0.0; logits.len()];
    let peak = mask
        .iter()
        .map(|&i| logits[i])
        .fold(f64::NEG_INFINITY, f64::max);

    let mut total = 0.0;
    for &idx in mask {
        probs[idx] = (logits[idx] - peak).exp();
        total += probs[idx];
    }

    if total > 0.0 {
        mask.iter().for_each(|&idx| probs[idx] /= total);
    }
    probs
}

/// Finds the mask index whose value is largest.
///
/// Candidates are scanned in `mask` order and replaced only by a strictly
/// greater value, so ties go to the earliest entry in `mask`. Because every
/// comparison involving NaN is false, a NaN at `mask[0]` is never replaced
/// and a NaN elsewhere is never selected.
///
/// # Arguments
///
/// * `values` - Slice of values.
/// * `mask` - Indices to consider.
///
/// # Panics
///
/// Panics if `mask` is empty or contains an index `>= values.len()`.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );

    let seed = (mask[0], values[mask[0]]);
    let (winner, _) = mask[1..].iter().fold(seed, |(best_i, best_v), &i| {
        if values[i] > best_v {
            (i, values[i])
        } else {
            (best_i, best_v)
        }
    });
    winner
}

/// Root-mean-square over every element of every vector, taken together.
///
/// All vectors are treated as one concatenated sequence, so longer vectors
/// carry proportionally more weight. This is not an average of per-vector RMS
/// values.
///
/// # Arguments
///
/// * `error_vecs` - Slice of error vector references.
///
/// # Returns
///
/// `sqrt(sum(e^2) / count)` over all elements, or `0.0` when there are no
/// elements at all.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let (sum_sq, count) = error_vecs
        .iter()
        .flat_map(|v| v.iter())
        .fold((0.0_f64, 0_usize), |(acc, n), &e| (acc + e * e, n + 1));
    if count == 0 {
        0.0
    } else {
        (sum_sq / count as f64).sqrt()
    }
}

317/// Samples an action index from a probability distribution over masked indices.
318///
319/// If only one action is valid, returns it directly. If all probabilities among
320/// mask indices are zero, falls back to uniform sampling over the mask.
321///
322/// # Arguments
323///
324/// * `probs` - Probability vector.
325/// * `mask` - Valid action indices.
326/// * `rng` - Mutable reference to a random number generator.
327///
328/// # Panics
329///
330/// Panics if `mask` is empty.
331pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
332    assert!(!mask.is_empty(), "sample_from_probs: empty mask");
333
334    if mask.len() == 1 {
335        return mask[0];
336    }
337
338    let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
339    if sum <= 0.0 {
340        // Uniform fallback
341        return mask[rng.gen_range(0..mask.len())];
342    }
343
344    let threshold: f64 = rng.gen_range(0.0..1.0);
345    let mut cumulative = 0.0;
346    for &i in mask {
347        cumulative += probs[i] / sum;
348        if cumulative >= threshold {
349            return i;
350        }
351    }
352
353    // Fallback to last mask element (rounding)
354    *mask.last().unwrap()
355}
356
/// Clamps every element of `v` into `[-max_abs, max_abs]`, in place.
///
/// NaN elements stay NaN. As with `f64::clamp`, a negative or NaN `max_abs`
/// panics, but only if `v` is non-empty.
///
/// # Arguments
///
/// * `v` - Mutable slice to clamp.
/// * `max_abs` - Maximum absolute value.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    let (lo, hi) = (-max_abs, max_abs);
    v.iter_mut().for_each(|x| *x = x.clamp(lo, hi));
}

/// Element-wise difference `a - b` as a new vector.
///
/// # Arguments
///
/// * `a` - Minuend vector.
/// * `b` - Subtrahend vector (same length as `a`).
///
/// # Returns
///
/// A vector whose element `i` is `a[i] - b[i]`.
///
/// # Panics
///
/// Panics with a "length mismatch" message if `a.len() != b.len()`.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a, len_b,
        "vec_sub: length mismatch {} vs {}",
        len_a, len_b
    );
    let mut diff = Vec::with_capacity(len_a);
    for (x, y) in a.iter().zip(b) {
        diff.push(x - y);
    }
    diff
}

/// Element-wise sum `a + b` as a new vector.
///
/// # Arguments
///
/// * `a` - First addend.
/// * `b` - Second addend (same length as `a`).
///
/// # Returns
///
/// A vector whose element `i` is `a[i] + b[i]`.
///
/// # Panics
///
/// Panics with a "length mismatch" message if `a.len() != b.len()`.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let n = a.len();
    let m = b.len();
    assert_eq!(n, m, "vec_add: length mismatch {} vs {}", n, m);
    (0..n).map(|i| a[i] + b[i]).collect()
}

/// Returns a copy of `v` with every element multiplied by `s`.
///
/// # Arguments
///
/// * `v` - Input vector.
/// * `s` - Scalar multiplier.
///
/// # Returns
///
/// A new vector whose element `i` is `v[i] * s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut scaled = v.to_vec();
    for x in &mut scaled {
        *x *= s;
    }
    scaled
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Asserts that `actual` is within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    // ── Matrix Tests ──────────────────────────────────────────────────

    #[test]
    fn test_zeros_all_zero_correct_dims() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 4);
        assert_eq!(m.data.len(), 12);
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_xavier_variance_approx() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(100, 100, &mut rng);
        let n = m.data.len() as f64;
        let mean = m.data.iter().sum::<f64>() / n;
        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let expected_var = 2.0 / (100.0 + 100.0); // 0.01
        assert!(
            (variance - expected_var).abs() < expected_var * 0.5,
            "variance {} not within 50% of expected {}",
            variance,
            expected_var
        );
    }

    #[test]
    fn test_xavier_all_finite() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(50, 50, &mut rng);
        assert!(m.data.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_xavier_values_within_limit() {
        let mut rng = StdRng::seed_from_u64(7);
        let (rows, cols) = (8, 4);
        let limit = (6.0 / (rows + cols) as f64).sqrt();
        let m = Matrix::xavier(rows, cols, &mut rng);
        assert_eq!(m.data.len(), rows * cols);
        assert!(m.data.iter().all(|x| x.abs() <= limit));
    }

    #[test]
    fn test_get_set_roundtrip() {
        let mut m = Matrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
    }

    #[test]
    fn test_get_on_zeros_returns_zero() {
        let m = Matrix::zeros(2, 2);
        assert_eq!(m.get(0, 0), 0.0);
    }

    #[test]
    fn test_transpose_swaps_dims() {
        let m = Matrix::zeros(3, 5);
        let t = m.transpose();
        assert_eq!(t.rows, 5);
        assert_eq!(t.cols, 3);
    }

    #[test]
    fn test_transpose_repositions_values() {
        let mut m = Matrix::zeros(2, 3);
        m.set(0, 1, 7.0);
        m.set(1, 2, 3.0);
        let t = m.transpose();
        assert_eq!(t.get(1, 0), 7.0);
        assert_eq!(t.get(2, 1), 3.0);
    }

    #[test]
    fn test_transpose_double_is_identity() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(3, 5, &mut rng);
        let tt = m.transpose().transpose();
        assert_eq!(m.rows, tt.rows);
        assert_eq!(m.cols, tt.cols);
        for (a, b) in m.data.iter().zip(&tt.data) {
            assert!((a - b).abs() < 1e-15);
        }
    }

    #[test]
    fn test_transpose_zero_rows() {
        let t = Matrix::zeros(0, 3).transpose();
        assert_eq!((t.rows, t.cols), (3, 0));
        assert!(t.data.is_empty());
    }

    #[test]
    fn test_mul_vec_known_result() {
        // [[1,2],[3,4]] * [5,6] = [17, 39]
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);
        let result = m.mul_vec(&[5.0, 6.0]);
        assert_eq!(result.len(), 2);
        assert_close(result[0], 17.0);
        assert_close(result[1], 39.0);
    }

    #[test]
    fn test_mul_vec_output_length_equals_rows() {
        let m = Matrix::zeros(4, 3);
        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 4);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_mul_vec_panics_wrong_length() {
        let m = Matrix::zeros(2, 3);
        m.mul_vec(&[1.0, 2.0]); // wrong length
    }

    #[test]
    fn test_mul_vec_zero_matrix_returns_zeros() {
        let m = Matrix::zeros(3, 2);
        let result = m.mul_vec(&[5.0, 10.0]);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_mul_vec_zero_cols_returns_zero_per_row() {
        let m = Matrix::zeros(2, 0);
        assert_eq!(m.mul_vec(&[]), vec![0.0, 0.0]);
    }

    #[test]
    fn test_outer_dims_and_values() {
        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_close(m.get(0, 0), 3.0);
        assert_close(m.get(0, 1), 4.0);
        assert_close(m.get(0, 2), 5.0);
        assert_close(m.get(1, 0), 6.0);
        assert_close(m.get(1, 1), 8.0);
        assert_close(m.get(1, 2), 10.0);
    }

    #[test]
    fn test_outer_empty_first_returns_zero_matrix() {
        let m = Matrix::outer(&[], &[1.0, 2.0]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_outer_empty_second_returns_zero_matrix() {
        let m = Matrix::outer(&[1.0, 2.0], &[]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_scale_add_basic() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 2.0);
        let mut other = Matrix::zeros(2, 2);
        other.set(0, 0, 0.5);
        other.set(1, 1, 0.5);
        m.scale_add(&other, 2.0);
        assert_close(m.get(0, 0), 2.0);
        assert_close(m.get(1, 1), 3.0);
    }

    #[test]
    fn test_scale_add_clips_to_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, 10.0);
        m.scale_add(&other, 1.0);
        assert_close(m.get(0, 0), WEIGHT_CLIP);
    }

    #[test]
    fn test_scale_add_negative_clips_to_neg_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, -4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, -10.0);
        m.scale_add(&other, 1.0);
        assert_close(m.get(0, 0), -WEIGHT_CLIP);
    }

    #[test]
    fn test_scale_add_zero_scale_only_clips() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 3.0);
        let other = Matrix::zeros(1, 1);
        m.scale_add(&other, 0.0);
        assert_close(m.get(0, 0), 3.0);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_scale_add_panics_on_dimension_mismatch() {
        let mut m = Matrix::zeros(2, 2);
        let other = Matrix::zeros(3, 3);
        m.scale_add(&other, 1.0);
    }

    // ── Softmax Tests ─────────────────────────────────────────────────

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 1, 2, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_close(probs.iter().sum::<f64>(), 1.0);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[1] > 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_softmax_masked_single_index_is_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let mask = vec![1];
        let probs = softmax_masked(&logits, &mask);
        assert_close(probs[1], 1.0);
    }

    #[test]
    fn test_softmax_masked_empty_mask_returns_all_zeros() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax_masked(&logits, &[]);
        assert!(probs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_softmax_masked_numerically_stable_large_logits() {
        let logits = vec![1000.0, 1001.0, 1002.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs.iter().all(|p| p.is_finite()));
        assert_close(probs.iter().sum::<f64>(), 1.0);
    }

    #[test]
    fn test_softmax_masked_higher_logit_gets_higher_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_softmax_masked_neg_infinity_logit_gets_zero() {
        let probs = softmax_masked(&[f64::NEG_INFINITY, 0.0], &[0, 1]);
        assert_eq!(probs, vec![0.0, 1.0]);
    }

    // ── Argmax Tests ──────────────────────────────────────────────────

    #[test]
    fn test_argmax_masked_returns_highest_in_mask() {
        let values = vec![1.0, 5.0, 3.0, 4.0];
        let mask = vec![0, 2, 3];
        assert_eq!(argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_argmax_masked_single_element() {
        let values = vec![1.0, 5.0, 3.0];
        let mask = vec![2];
        assert_eq!(argmax_masked(&values, &mask), 2);
    }

    #[test]
    fn test_argmax_masked_tie_returns_first() {
        let values = vec![3.0, 3.0, 3.0];
        let mask = vec![0, 1, 2];
        assert_eq!(argmax_masked(&values, &mask), 0);
    }

    #[test]
    #[should_panic]
    fn test_argmax_masked_empty_panics() {
        let values = vec![1.0, 2.0];
        argmax_masked(&values, &[]);
    }

    // ── RMS Error Tests ───────────────────────────────────────────────

    #[test]
    fn test_rms_error_empty_returns_zero() {
        assert_eq!(rms_error(&[]), 0.0);
    }

    #[test]
    fn test_rms_error_single_empty_vec_returns_zero() {
        let empty: &[f64] = &[];
        assert_eq!(rms_error(&[empty]), 0.0);
    }

    #[test]
    fn test_rms_error_known_two_vecs() {
        let v1: &[f64] = &[1.0, 0.0];
        let v2: &[f64] = &[0.0, 1.0];
        // sum_sq = 1+0+0+1 = 2, count = 4, rms = sqrt(2/4) = sqrt(0.5)
        assert_close(rms_error(&[v1, v2]), 0.5_f64.sqrt());
    }

    #[test]
    fn test_rms_error_single_vec() {
        let v: &[f64] = &[3.0, 4.0];
        // sum_sq = 9+16 = 25, count = 2, rms = sqrt(12.5) = 3.5355...
        assert_close(rms_error(&[v]), (25.0 / 2.0_f64).sqrt());
    }

    #[test]
    fn test_rms_error_all_zeros_returns_zero() {
        let v: &[f64] = &[0.0, 0.0, 0.0];
        assert_eq!(rms_error(&[v]), 0.0);
    }

    // ── Sample Tests ──────────────────────────────────────────────────

    #[test]
    fn test_sample_from_probs_always_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_single_action_always_returns_it() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![1];
        for _ in 0..10 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_visits_multiple_actions() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            seen[idx] = true;
        }
        assert!(seen[0] && seen[1], "should visit both actions");
    }

    #[test]
    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.0, 0.0, 0.0];
        let mask = vec![0, 2];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_never_picks_zero_prob_action() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.0, 1.0, 0.0];
        let mask = vec![0, 1, 2];
        for _ in 0..200 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_unnormalized_weights_visit_all() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![2.0, 2.0];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            seen[sample_from_probs(&probs, &mask, &mut rng)] = true;
        }
        assert!(seen[0] && seen[1], "unnormalized weights should still be sampled");
    }

    #[test]
    #[should_panic]
    fn test_sample_from_probs_empty_mask_panics() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        sample_from_probs(&probs, &[], &mut rng);
    }

    // ── Vec Utility Tests ─────────────────────────────────────────────

    #[test]
    fn test_vec_sub_known() {
        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
        assert_close(result[0], 2.0);
        assert_close(result[1], -1.0);
    }

    #[test]
    fn test_vec_add_known() {
        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
        assert_close(result[0], 4.0);
        assert_close(result[1], 6.0);
    }

    #[test]
    fn test_vec_scale_known() {
        let result = vec_scale(&[1.0, -2.0], 3.0);
        assert_close(result[0], 3.0);
        assert_close(result[1], -6.0);
    }

    #[test]
    fn test_clip_vec_clamps_positive() {
        let mut v = vec![10.0, -10.0, 0.5];
        clip_vec(&mut v, 5.0);
        assert_close(v[0], 5.0);
        assert_close(v[1], -5.0);
        assert_close(v[2], 0.5);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_sub_panics_on_length_mismatch() {
        vec_sub(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_add_panics_on_length_mismatch() {
        vec_add(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = vec![1.0, -1.0, 0.0];
        clip_vec(&mut v, 5.0);
        assert_close(v[0], 1.0);
        assert_close(v[1], -1.0);
        assert_close(v[2], 0.0);
    }

    // ── Defensive: OOB assertions ────────────────────────────────

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_row() {
        let m = Matrix::zeros(2, 2);
        m.get(5, 0); // should panic, not return 0.0
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_col() {
        let m = Matrix::zeros(2, 2);
        m.get(0, 2);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_row() {
        let mut m = Matrix::zeros(2, 2);
        m.set(5, 0, 1.0); // should panic, not silently do nothing
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_col() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 2, 1.0);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_softmax_masked_panics_on_oob_mask() {
        let logits = vec![1.0, 2.0, 3.0];
        softmax_masked(&logits, &[0, 5]); // 5 >= logits.len()
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_argmax_masked_panics_on_oob_mask() {
        let values = vec![1.0, 2.0, 3.0];
        argmax_masked(&values, &[0, 5]); // 5 >= values.len()
    }

    // ── sample_from_probs distribution ───────────────────────────

    #[test]
    fn test_sample_from_probs_distribution_roughly_correct() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.7, 0.3];
        let mask = vec![0, 1];
        let mut counts = [0usize; 2];
        let n = 1000;
        for _ in 0..n {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            counts[idx] += 1;
        }
        let ratio = counts[0] as f64 / n as f64;
        // Should be roughly 0.7, allow 10% tolerance
        assert!(
            (ratio - 0.7).abs() < 0.1,
            "Expected ~0.7 for action 0, got {ratio}"
        );
    }
}