1use rand::Rng;
12use serde::{Deserialize, Serialize};
13
/// Symmetric bound on matrix entries: `Matrix::scale_add` clamps every element
/// it updates to `[-WEIGHT_CLIP, WEIGHT_CLIP]`.
pub const WEIGHT_CLIP: f64 = 5.0;
16
/// Clipping bound exported for use by other modules.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// `max_abs` passed to `clip_vec` when clipping gradients — confirm against
/// callers.
pub const GRAD_CLIP: f64 = 5.0;
19
/// Dense matrix of `f64` values backed by a single flat vector.
///
/// The fields are public, so nothing in the type itself enforces that
/// `data.len() == rows * cols`; the constructors on this type always produce
/// matrices for which that holds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Matrix {
    /// Flat element storage; the accessors address element `(row, col)` as
    /// `data[row * cols + col]` (row-major order).
    pub data: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
42
43impl Matrix {
44 pub fn zeros(rows: usize, cols: usize) -> Self {
55 Self {
56 data: vec![0.0; rows * cols],
57 rows,
58 cols,
59 }
60 }
61
62 pub fn xavier(rows: usize, cols: usize, rng: &mut impl Rng) -> Self {
77 let limit = (6.0 / (rows + cols) as f64).sqrt();
78 let data: Vec<f64> = (0..rows * cols)
79 .map(|_| rng.gen_range(-limit..limit))
80 .collect();
81 Self { data, rows, cols }
82 }
83
84 pub fn get(&self, row: usize, col: usize) -> f64 {
93 assert!(
94 row < self.rows && col < self.cols,
95 "Matrix::get out of bounds: ({row}, {col}) for ({}, {})",
96 self.rows,
97 self.cols
98 );
99 self.data[row * self.cols + col]
100 }
101
102 pub fn set(&mut self, row: usize, col: usize, val: f64) {
112 assert!(
113 row < self.rows && col < self.cols,
114 "Matrix::set out of bounds: ({row}, {col}) for ({}, {})",
115 self.rows,
116 self.cols
117 );
118 self.data[row * self.cols + col] = val;
119 }
120
121 pub fn transpose(&self) -> Self {
127 let mut result = Matrix::zeros(self.cols, self.rows);
128 for r in 0..self.rows {
129 for c in 0..self.cols {
130 result.set(c, r, self.get(r, c));
131 }
132 }
133 result
134 }
135
136 pub fn mul_vec(&self, v: &[f64]) -> Vec<f64> {
150 assert_eq!(
151 v.len(),
152 self.cols,
153 "dimension mismatch: vector length {} != matrix cols {}",
154 v.len(),
155 self.cols
156 );
157 (0..self.rows)
158 .map(|r| {
159 let row_start = r * self.cols;
160 self.data[row_start..row_start + self.cols]
161 .iter()
162 .zip(v.iter())
163 .map(|(a, b)| a * b)
164 .sum()
165 })
166 .collect()
167 }
168
169 pub fn outer(a: &[f64], b: &[f64]) -> Self {
181 if a.is_empty() || b.is_empty() {
182 return Matrix::zeros(0, 0);
183 }
184 let rows = a.len();
185 let cols = b.len();
186 let mut data = vec![0.0; rows * cols];
187 for r in 0..rows {
188 for c in 0..cols {
189 data[r * cols + c] = a[r] * b[c];
190 }
191 }
192 Self { data, rows, cols }
193 }
194
195 pub fn scale_add(&mut self, other: &Matrix, scale: f64) {
206 assert!(
207 self.rows == other.rows && self.cols == other.cols,
208 "dimension mismatch in scale_add: ({},{}) vs ({},{})",
209 self.rows,
210 self.cols,
211 other.rows,
212 other.cols
213 );
214 for i in 0..self.data.len() {
215 self.data[i] += scale * other.data[i];
216 self.data[i] = self.data[i].clamp(-WEIGHT_CLIP, WEIGHT_CLIP);
217 }
218 }
219}
220
/// Softmax over the entries of `logits` selected by `mask`.
///
/// Returns a vector with the same length as `logits`. Positions listed in
/// `mask` receive their softmax probability computed over the masked entries
/// only; every other position is `0.0`. An empty `mask` yields an all-zero
/// vector.
///
/// An index that appears more than once in `mask` is treated as if it were
/// listed once, so the returned probabilities still sum to one.
///
/// # Panics
///
/// Panics if `mask` is non-empty and contains an index `>= logits.len()`.
pub fn softmax_masked(logits: &[f64], mask: &[usize]) -> Vec<f64> {
    let mut result = vec![0.0; logits.len()];
    if mask.is_empty() {
        return result;
    }
    assert!(
        mask.iter().all(|&i| i < logits.len()),
        "softmax_masked: mask index out of bounds (max mask={}, logits len={})",
        mask.iter().max().unwrap_or(&0),
        logits.len()
    );

    // Shift by the largest masked logit before exponentiating so that `exp`
    // never sees a positive argument and cannot overflow.
    let max_val = mask
        .iter()
        .map(|&i| logits[i])
        .fold(f64::NEG_INFINITY, f64::max);

    // `seen` marks indices already accumulated, so a repeated mask entry is
    // neither added to `sum` twice nor divided twice below.
    let mut seen = vec![false; logits.len()];
    let mut sum = 0.0;
    for &i in mask {
        if seen[i] {
            continue;
        }
        seen[i] = true;
        let exp_val = (logits[i] - max_val).exp();
        result[i] = exp_val;
        sum += exp_val;
    }

    // For finite logits the maximum contributes exp(0) = 1, so `sum >= 1`.
    // The guard only skips normalization when `sum` is not a positive number
    // (e.g. NaN logits).
    if sum > 0.0 {
        for &i in mask {
            if seen[i] {
                // Clear the flag so each distinct index is divided exactly once.
                seen[i] = false;
                result[i] /= sum;
            }
        }
    }
    result
}
263
/// Returns the index from `mask` whose entry in `values` is largest.
///
/// Ties are resolved in favor of the index that appears first in `mask`.
/// NaN entries are never preferred over a non-NaN entry; if every masked
/// value is NaN, the first index in `mask` is returned.
///
/// # Panics
///
/// Panics if `mask` is empty or contains an index `>= values.len()`.
pub fn argmax_masked(values: &[f64], mask: &[usize]) -> usize {
    assert!(!mask.is_empty(), "argmax_masked: empty mask");
    assert!(
        mask.iter().all(|&i| i < values.len()),
        "argmax_masked: mask index out of bounds (max mask={}, values len={})",
        mask.iter().max().unwrap_or(&0),
        values.len()
    );
    let mut best_idx = mask[0];
    let mut best_val = values[best_idx];
    for &i in &mask[1..] {
        let candidate = values[i];
        // Every comparison against NaN is false, so a NaN incumbent would
        // otherwise never be displaced; let the first non-NaN value replace it.
        let displaces_nan = best_val.is_nan() && !candidate.is_nan();
        if candidate > best_val || displaces_nan {
            best_val = candidate;
            best_idx = i;
        }
    }
    best_idx
}
292
/// Root-mean-square of every value contained in `error_vecs`.
///
/// All slices are treated as one flat sequence: the result is
/// `sqrt(sum(e^2) / n)` where `n` is the total number of elements. When there
/// are no elements at all (no slices, or only empty ones) the result is `0.0`.
pub fn rms_error(error_vecs: &[&[f64]]) -> f64 {
    let (total_sq, n) = error_vecs
        .iter()
        .flat_map(|errors| errors.iter())
        .fold((0.0_f64, 0usize), |(acc, seen), &e| (acc + e * e, seen + 1));

    match n {
        // Avoid dividing by zero when there is nothing to average.
        0 => 0.0,
        _ => (total_sq / n as f64).sqrt(),
    }
}
316
317pub fn sample_from_probs(probs: &[f64], mask: &[usize], rng: &mut impl Rng) -> usize {
332 assert!(!mask.is_empty(), "sample_from_probs: empty mask");
333
334 if mask.len() == 1 {
335 return mask[0];
336 }
337
338 let sum: f64 = mask.iter().map(|&i| probs[i]).sum();
339 if sum <= 0.0 {
340 return mask[rng.gen_range(0..mask.len())];
342 }
343
344 let threshold: f64 = rng.gen_range(0.0..1.0);
345 let mut cumulative = 0.0;
346 for &i in mask {
347 cumulative += probs[i] / sum;
348 if cumulative >= threshold {
349 return i;
350 }
351 }
352
353 *mask.last().unwrap()
355}
356
/// Clamps every element of `v` in place to the range `[-max_abs, max_abs]`.
///
/// # Panics
///
/// `f64::clamp` panics when its bounds are inverted or NaN, so this panics if
/// `v` is non-empty and `max_abs` is negative or NaN.
pub(crate) fn clip_vec(v: &mut [f64], max_abs: f64) {
    v.iter_mut()
        .for_each(|elem| *elem = elem.clamp(-max_abs, max_abs));
}
368
/// Element-wise difference `a - b`, returned as a new vector.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub(crate) fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_sub: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    let mut diff = Vec::with_capacity(a.len());
    for (lhs, rhs) in a.iter().zip(b) {
        diff.push(lhs - rhs);
    }
    diff
}
389
/// Element-wise sum `a + b`, returned as a new vector.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub(crate) fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(
        a.len(),
        b.len(),
        "vec_add: length mismatch {} vs {}",
        a.len(),
        b.len()
    );
    // The lengths are equal here, so indexing both slices by `i` is in bounds.
    (0..a.len()).map(|i| a[i] + b[i]).collect()
}
410
/// Returns a new vector holding every element of `v` multiplied by `s`.
pub(crate) fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut scaled = v.to_vec();
    for elem in &mut scaled {
        *elem *= s;
    }
    scaled
}
424
/// Unit tests for `Matrix` and the free numeric helpers defined in this file.
///
/// Tests that need randomness use a `StdRng` seeded with a fixed value so
/// runs are reproducible.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // ---- Matrix construction ----

    #[test]
    fn test_zeros_all_zero_correct_dims() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows, 3);
        assert_eq!(m.cols, 4);
        assert_eq!(m.data.len(), 12);
        assert!(m.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_xavier_variance_approx() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(100, 100, &mut rng);
        let n = m.data.len() as f64;
        let mean = m.data.iter().sum::<f64>() / n;
        let variance = m.data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        // A uniform draw on (-limit, limit) has variance limit^2 / 3; with
        // limit = sqrt(6 / (rows + cols)) that is 2 / (rows + cols).
        let expected_var = 2.0 / (100.0 + 100.0);
        assert!(
            (variance - expected_var).abs() < expected_var * 0.5,
            "variance {} not within 50% of expected {}",
            variance,
            expected_var
        );
    }

    #[test]
    fn test_xavier_all_finite() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(50, 50, &mut rng);
        assert!(m.data.iter().all(|x| x.is_finite()));
    }

    // ---- Matrix::get / Matrix::set ----

    #[test]
    fn test_get_set_roundtrip() {
        let mut m = Matrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
    }

    #[test]
    fn test_get_zero_default() {
        let m = Matrix::zeros(2, 2);
        assert_eq!(m.get(0, 0), 0.0);
    }

    // ---- Matrix::transpose ----

    #[test]
    fn test_transpose_swaps_dims() {
        let m = Matrix::zeros(3, 5);
        let t = m.transpose();
        assert_eq!(t.rows, 5);
        assert_eq!(t.cols, 3);
    }

    #[test]
    fn test_transpose_repositions_values() {
        let mut m = Matrix::zeros(2, 3);
        m.set(0, 1, 7.0);
        m.set(1, 2, 3.0);
        let t = m.transpose();
        // Element (r, c) must appear at (c, r).
        assert_eq!(t.get(1, 0), 7.0);
        assert_eq!(t.get(2, 1), 3.0);
    }

    #[test]
    fn test_transpose_double_is_identity() {
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::xavier(3, 5, &mut rng);
        let tt = m.transpose().transpose();
        assert_eq!(m.rows, tt.rows);
        assert_eq!(m.cols, tt.cols);
        for i in 0..m.data.len() {
            assert!((m.data[i] - tt.data[i]).abs() < 1e-15);
        }
    }

    // ---- Matrix::mul_vec ----

    #[test]
    fn test_mul_vec_known_result() {
        let mut m = Matrix::zeros(2, 2);
        // [[1, 2], [3, 4]] * [5, 6] = [1*5 + 2*6, 3*5 + 4*6] = [17, 39]
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);
        let result = m.mul_vec(&[5.0, 6.0]);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 39.0).abs() < 1e-10);
    }

    #[test]
    fn test_mul_vec_output_length_equals_rows() {
        let m = Matrix::zeros(4, 3);
        let result = m.mul_vec(&[1.0, 2.0, 3.0]);
        assert_eq!(result.len(), 4);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_mul_vec_panics_wrong_length() {
        let m = Matrix::zeros(2, 3);
        // The matrix has 3 columns but the vector has only 2 entries.
        m.mul_vec(&[1.0, 2.0]);
    }

    #[test]
    fn test_mul_vec_zero_matrix_returns_zeros() {
        let m = Matrix::zeros(3, 2);
        let result = m.mul_vec(&[5.0, 10.0]);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    // ---- Matrix::outer ----

    #[test]
    fn test_outer_dims_and_values() {
        let m = Matrix::outer(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
        assert!((m.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((m.get(0, 2) - 5.0).abs() < 1e-10);
        assert!((m.get(1, 0) - 6.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 8.0).abs() < 1e-10);
        assert!((m.get(1, 2) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_empty_first_returns_zero_matrix() {
        let m = Matrix::outer(&[], &[1.0, 2.0]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    #[test]
    fn test_outer_empty_second_returns_zero_matrix() {
        let m = Matrix::outer(&[1.0, 2.0], &[]);
        assert_eq!(m.rows, 0);
        assert_eq!(m.cols, 0);
    }

    // ---- Matrix::scale_add ----

    #[test]
    fn test_scale_add_basic() {
        let mut m = Matrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 2.0);
        let mut other = Matrix::zeros(2, 2);
        other.set(0, 0, 0.5);
        other.set(1, 1, 0.5);
        m.scale_add(&other, 2.0);
        // 1.0 + 2.0 * 0.5 = 2.0 and 2.0 + 2.0 * 0.5 = 3.0
        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_clips_to_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, 10.0);
        // 4.0 + 10.0 = 14.0 exceeds the bound and must be clamped.
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - WEIGHT_CLIP).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_negative_clips_to_neg_weight_clip() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, -4.0);
        let mut other = Matrix::zeros(1, 1);
        other.set(0, 0, -10.0);
        // -4.0 - 10.0 = -14.0 is below the lower bound and must be clamped.
        m.scale_add(&other, 1.0);
        assert!((m.get(0, 0) - (-WEIGHT_CLIP)).abs() < 1e-10);
    }

    #[test]
    fn test_scale_add_zero_scale_only_clips() {
        let mut m = Matrix::zeros(1, 1);
        m.set(0, 0, 3.0);
        let other = Matrix::zeros(1, 1);
        m.scale_add(&other, 0.0);
        // 3.0 is already inside [-WEIGHT_CLIP, WEIGHT_CLIP], so it is unchanged.
        assert!((m.get(0, 0) - 3.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "dimension")]
    fn test_scale_add_panics_on_dimension_mismatch() {
        let mut m = Matrix::zeros(2, 2);
        let other = Matrix::zeros(3, 3);
        m.scale_add(&other, 1.0);
    }

    // ---- softmax_masked ----

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![0, 1, 2, 3];
        let probs = softmax_masked(&logits, &mask);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![1, 3];
        let probs = softmax_masked(&logits, &mask);
        assert_eq!(probs[0], 0.0);
        assert_eq!(probs[2], 0.0);
        assert!(probs[1] > 0.0);
        assert!(probs[3] > 0.0);
    }

    #[test]
    fn test_softmax_masked_single_index_is_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let mask = vec![1];
        let probs = softmax_masked(&logits, &mask);
        assert!((probs[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_empty_mask_returns_all_zeros() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax_masked(&logits, &[]);
        assert!(probs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_softmax_masked_numerically_stable_large_logits() {
        // exp(1000.0) overflows f64, so this only passes if the maximum is
        // subtracted before exponentiating.
        let logits = vec![1000.0, 1001.0, 1002.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs.iter().all(|p| p.is_finite()));
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_higher_logit_gets_higher_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let mask = vec![0, 1, 2];
        let probs = softmax_masked(&logits, &mask);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    // ---- argmax_masked ----

    #[test]
    fn test_argmax_masked_returns_highest_in_mask() {
        // Index 1 holds the global maximum but is not in the mask.
        let values = vec![1.0, 5.0, 3.0, 4.0];
        let mask = vec![0, 2, 3];
        assert_eq!(argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_argmax_masked_single_element() {
        let values = vec![1.0, 5.0, 3.0];
        let mask = vec![2];
        assert_eq!(argmax_masked(&values, &mask), 2);
    }

    #[test]
    fn test_argmax_masked_tie_returns_first() {
        let values = vec![3.0, 3.0, 3.0];
        let mask = vec![0, 1, 2];
        assert_eq!(argmax_masked(&values, &mask), 0);
    }

    #[test]
    #[should_panic]
    fn test_argmax_masked_empty_panics() {
        let values = vec![1.0, 2.0];
        argmax_masked(&values, &[]);
    }

    // ---- rms_error ----

    #[test]
    fn test_rms_error_empty_returns_zero() {
        assert_eq!(rms_error(&[]), 0.0);
    }

    #[test]
    fn test_rms_error_single_empty_vec_returns_zero() {
        let empty: &[f64] = &[];
        assert_eq!(rms_error(&[empty]), 0.0);
    }

    #[test]
    fn test_rms_error_known_two_vecs() {
        let v1: &[f64] = &[1.0, 0.0];
        let v2: &[f64] = &[0.0, 1.0];
        let rms = rms_error(&[v1, v2]);
        // sqrt((1 + 0 + 0 + 1) / 4) = sqrt(0.5)
        let expected = (0.5_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_single_vec() {
        let v: &[f64] = &[3.0, 4.0];
        let rms = rms_error(&[v]);
        // sqrt((9 + 16) / 2)
        let expected = (25.0 / 2.0_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_all_zeros_returns_zero() {
        let v: &[f64] = &[0.0, 0.0, 0.0];
        assert_eq!(rms_error(&[v]), 0.0);
    }

    // ---- sample_from_probs ----

    #[test]
    fn test_sample_from_probs_always_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_sample_from_probs_single_action_always_returns_it() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![1];
        for _ in 0..10 {
            assert_eq!(sample_from_probs(&probs, &mask, &mut rng), 1);
        }
    }

    #[test]
    fn test_sample_from_probs_visits_multiple_actions() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        let mask = vec![0, 1];
        let mut seen = [false; 2];
        for _ in 0..100 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            seen[idx] = true;
        }
        assert!(seen[0] && seen[1], "should visit both actions");
    }

    #[test]
    fn test_sample_from_probs_zero_probs_fallback_is_in_mask() {
        let mut rng = StdRng::seed_from_u64(42);
        // All masked probabilities are zero, which exercises the uniform fallback.
        let probs = vec![0.0, 0.0, 0.0];
        let mask = vec![0, 2];
        for _ in 0..20 {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    #[should_panic]
    fn test_sample_from_probs_empty_mask_panics() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.5, 0.5];
        sample_from_probs(&probs, &[], &mut rng);
    }

    // ---- vector helpers ----

    #[test]
    fn test_vec_sub_known() {
        let result = vec_sub(&[3.0, 1.0], &[1.0, 2.0]);
        assert!((result[0] - 2.0).abs() < 1e-10);
        assert!((result[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_vec_add_known() {
        let result = vec_add(&[1.0, 2.0], &[3.0, 4.0]);
        assert!((result[0] - 4.0).abs() < 1e-10);
        assert!((result[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_vec_scale_known() {
        let result = vec_scale(&[1.0, -2.0], 3.0);
        assert!((result[0] - 3.0).abs() < 1e-10);
        assert!((result[1] - (-6.0)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_vec_clamps_positive() {
        let mut v = vec![10.0, -10.0, 0.5];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 5.0).abs() < 1e-10);
        assert!((v[1] - (-5.0)).abs() < 1e-10);
        assert!((v[2] - 0.5).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_sub_panics_on_length_mismatch() {
        vec_sub(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn test_vec_add_panics_on_length_mismatch() {
        vec_add(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = vec![1.0, -1.0, 0.0];
        clip_vec(&mut v, 5.0);
        assert!((v[0] - 1.0).abs() < 1e-10);
        assert!((v[1] - (-1.0)).abs() < 1e-10);
        assert!((v[2] - 0.0).abs() < 1e-10);
    }

    // ---- out-of-bounds panics ----

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_panics_on_oob_row() {
        let m = Matrix::zeros(2, 2);
        // Row 5 does not exist in a 2 x 2 matrix.
        m.get(5, 0);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_set_panics_on_oob_row() {
        let mut m = Matrix::zeros(2, 2);
        // Row 5 does not exist in a 2 x 2 matrix.
        m.set(5, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_softmax_masked_panics_on_oob_mask() {
        let logits = vec![1.0, 2.0, 3.0];
        // Index 5 is past the end of a length-3 slice.
        softmax_masked(&logits, &[0, 5]);
    }

    #[test]
    #[should_panic(expected = "mask index out of bounds")]
    fn test_argmax_masked_panics_on_oob_mask() {
        let values = vec![1.0, 2.0, 3.0];
        // Index 5 is past the end of a length-3 slice.
        argmax_masked(&values, &[0, 5]);
    }

    // ---- sampling distribution ----

    #[test]
    fn test_sample_from_probs_distribution_roughly_correct() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = vec![0.7, 0.3];
        let mask = vec![0, 1];
        let mut counts = [0usize; 2];
        let n = 1000;
        for _ in 0..n {
            let idx = sample_from_probs(&probs, &mask, &mut rng);
            counts[idx] += 1;
        }
        let ratio = counts[0] as f64 / n as f64;
        assert!(
            // Loose tolerance: the empirical frequency over 1000 draws only
            // has to land within 0.1 of the target probability.
            (ratio - 0.7).abs() < 0.1,
            "Expected ~0.7 for action 0, got {ratio}"
        );
    }
}