Skip to main content

pc_rl_core/linalg/
cpu.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-29
4
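//! CPU backend for the [`LinAlg`] trait.
//!
//! [`CpuLinAlg`] implements every linear algebra operation on the CPU, using
//! the existing [`Matrix`] struct for matrices and `Vec<f64>` for vectors.
//! Most methods are thin wrappers around the pure-Rust helpers in
//! [`crate::matrix`] and [`crate::activation`], and SVD is delegated to the
//! Golub–Kahan implementation in `crate::linalg::golub_kahan`. A few
//! operations (matrix product, Hadamard and dot products, matrix clipping)
//! are implemented directly in this module.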
11
12use crate::activation::Activation;
13use crate::linalg::LinAlg;
14use crate::matrix::Matrix;
15use rand::Rng;
16
/// CPU linear algebra backend.
///
/// A zero-sized marker type: every operation is an associated function of
/// its [`LinAlg`] implementation, so no instance is needed to call them.
/// Vectors are plain `Vec<f64>` and matrices are the crate's [`Matrix`].
///
/// Most methods forward to helpers in [`crate::matrix`] and
/// [`crate::activation`]; SVD is forwarded to the Golub–Kahan solver in
/// `crate::linalg::golub_kahan`. The matrix product, Hadamard and dot
/// products, and matrix clipping are implemented directly on this type.
#[derive(Debug, Clone)]
pub struct CpuLinAlg;
24
25impl LinAlg for CpuLinAlg {
26    type Vector = Vec<f64>;
27    type Matrix = Matrix;
28
29    fn zeros_vec(size: usize) -> Self::Vector {
30        vec![0.0; size]
31    }
32
33    fn zeros_mat(rows: usize, cols: usize) -> Self::Matrix {
34        Matrix::zeros(rows, cols)
35    }
36
37    fn xavier_mat(rows: usize, cols: usize, rng: &mut impl Rng) -> Self::Matrix {
38        Matrix::xavier(rows, cols, rng)
39    }
40
41    fn mat_vec_mul(m: &Self::Matrix, v: &Self::Vector) -> Self::Vector {
42        m.mul_vec(v)
43    }
44
45    fn mat_transpose(m: &Self::Matrix) -> Self::Matrix {
46        m.transpose()
47    }
48
49    fn outer_product(a: &Self::Vector, b: &Self::Vector) -> Self::Matrix {
50        Matrix::outer(a, b)
51    }
52
53    fn mat_mul(a: &Self::Matrix, b: &Self::Matrix) -> Self::Matrix {
54        assert_eq!(a.cols, b.rows, "mat_mul: inner dimensions mismatch");
55        let mut result = Matrix::zeros(a.rows, b.cols);
56        for i in 0..a.rows {
57            for j in 0..b.cols {
58                let mut sum = 0.0;
59                for k in 0..a.cols {
60                    sum += a.get(i, k) * b.get(k, j);
61                }
62                result.set(i, j, sum);
63            }
64        }
65        result
66    }
67
68    fn svd(m: &Self::Matrix) -> crate::linalg::SvdResult<Self> {
69        Ok(crate::linalg::golub_kahan::GolubKahanSvd::new().compute(m)?)
70    }
71
72    fn mat_scale_add(m: &mut Self::Matrix, other: &Self::Matrix, scale: f64) {
73        m.scale_add(other, scale);
74    }
75
76    fn mat_rows(m: &Self::Matrix) -> usize {
77        m.rows
78    }
79
80    fn mat_cols(m: &Self::Matrix) -> usize {
81        m.cols
82    }
83
84    fn mat_get(m: &Self::Matrix, row: usize, col: usize) -> f64 {
85        m.get(row, col)
86    }
87
88    fn mat_set(m: &mut Self::Matrix, row: usize, col: usize, val: f64) {
89        m.set(row, col, val);
90    }
91
92    fn vec_add(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
93        crate::matrix::vec_add(a, b)
94    }
95
96    fn vec_sub(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
97        crate::matrix::vec_sub(a, b)
98    }
99
100    fn vec_scale(v: &Self::Vector, s: f64) -> Self::Vector {
101        crate::matrix::vec_scale(v, s)
102    }
103
104    fn vec_hadamard(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
105        assert_eq!(a.len(), b.len(), "vec_hadamard: length mismatch");
106        a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
107    }
108
109    fn vec_dot(a: &Self::Vector, b: &Self::Vector) -> f64 {
110        assert_eq!(a.len(), b.len(), "vec_dot: length mismatch");
111        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
112    }
113
114    fn vec_len(v: &Self::Vector) -> usize {
115        v.len()
116    }
117
118    fn vec_get(v: &Self::Vector, i: usize) -> f64 {
119        v[i]
120    }
121
122    fn vec_set(v: &mut Self::Vector, i: usize, val: f64) {
123        v[i] = val;
124    }
125
126    fn vec_from_slice(s: &[f64]) -> Self::Vector {
127        s.to_vec()
128    }
129
130    fn vec_to_vec(v: &Self::Vector) -> Vec<f64> {
131        v.clone()
132    }
133
134    fn vec_as_slice(v: &Self::Vector) -> &[f64] {
135        v.as_slice()
136    }
137
138    fn clip_vec(v: &mut Self::Vector, max_abs: f64) {
139        crate::matrix::clip_vec(v, max_abs);
140    }
141
142    fn clip_mat(m: &mut Self::Matrix, max_abs: f64) {
143        for x in m.data.iter_mut() {
144            *x = x.clamp(-max_abs, max_abs);
145        }
146    }
147
148    fn apply_activation(v: &Self::Vector, act: Activation) -> Self::Vector {
149        v.iter().map(|&x| act.apply(x)).collect()
150    }
151
152    fn apply_derivative(v: &Self::Vector, act: Activation) -> Self::Vector {
153        v.iter().map(|&fx| act.derivative(fx)).collect()
154    }
155
156    fn softmax_masked(logits: &Self::Vector, mask: &[usize]) -> Self::Vector {
157        crate::matrix::softmax_masked(logits, mask)
158    }
159
160    fn argmax_masked(values: &Self::Vector, mask: &[usize]) -> usize {
161        crate::matrix::argmax_masked(values, mask)
162    }
163
164    fn sample_from_probs(probs: &Self::Vector, mask: &[usize], rng: &mut impl Rng) -> usize {
165        crate::matrix::sample_from_probs(probs, mask, rng)
166    }
167
168    fn rms_error(error_vecs: &[&Self::Vector]) -> f64 {
169        let slices: Vec<&[f64]> = error_vecs.iter().map(|v| v.as_slice()).collect();
170        crate::matrix::rms_error(&slices)
171    }
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for [`CpuLinAlg`], grouped by the development cycle that
    //! introduced each operation.

    use super::*;
    use rand::SeedableRng;

    // ── Shared helpers ───────────────────────────────────────────

    /// Builds a `rows × cols` matrix from a row-major slice of values.
    fn mat_from_rows(rows: usize, cols: usize, data: &[f64]) -> Matrix {
        assert_eq!(data.len(), rows * cols);
        let mut m = CpuLinAlg::zeros_mat(rows, cols);
        for r in 0..rows {
            for c in 0..cols {
                CpuLinAlg::mat_set(&mut m, r, c, data[r * cols + c]);
            }
        }
        m
    }

    /// Builds the `n × n` identity matrix.
    fn identity(n: usize) -> Matrix {
        let mut m = CpuLinAlg::zeros_mat(n, n);
        for i in 0..n {
            CpuLinAlg::mat_set(&mut m, i, i, 1.0);
        }
        m
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tol {tol})"
        );
    }

    /// Asserts that two matrices have the same shape and agree element-wise
    /// strictly within `tol`.
    #[track_caller]
    fn assert_mat_close(actual: &Matrix, expected: &Matrix, tol: f64) {
        let rows = CpuLinAlg::mat_rows(expected);
        let cols = CpuLinAlg::mat_cols(expected);
        assert_eq!(CpuLinAlg::mat_rows(actual), rows, "row count mismatch");
        assert_eq!(CpuLinAlg::mat_cols(actual), cols, "column count mismatch");
        for r in 0..rows {
            for c in 0..cols {
                let got = CpuLinAlg::mat_get(actual, r, c);
                let want = CpuLinAlg::mat_get(expected, r, c);
                assert!(
                    (got - want).abs() < tol,
                    "mismatch at ({r},{c}): got {got} expected {want}"
                );
            }
        }
    }

    /// Asserts that `m` is square and approximately the identity.
    #[track_caller]
    fn assert_approx_identity(m: &Matrix, tol: f64) {
        let n = CpuLinAlg::mat_rows(m);
        assert_eq!(n, CpuLinAlg::mat_cols(m), "not square");
        assert_mat_close(m, &identity(n), tol);
    }

    /// Reconstructs `U × diag(S) × Vᵀ` from SVD factors.
    ///
    /// `v` holds the right singular vectors as columns, so the reconstructed
    /// matrix has as many columns as `Vᵀ` (i.e. `mat_rows(v)`), not
    /// `mat_cols(v)`; the two only coincide when `V` is square.
    fn reconstruct_usv(u: &Matrix, s: &Vec<f64>, v: &Matrix) -> Matrix {
        let rows = CpuLinAlg::mat_rows(u);
        let k = CpuLinAlg::vec_len(s);
        let vt = CpuLinAlg::mat_transpose(v);
        let cols = CpuLinAlg::mat_cols(&vt);

        // diag(S) × Vᵀ: scale row i of Vᵀ by s[i].
        let mut svt = CpuLinAlg::zeros_mat(k, cols);
        for i in 0..k {
            let sigma = CpuLinAlg::vec_get(s, i);
            for j in 0..cols {
                CpuLinAlg::mat_set(&mut svt, i, j, sigma * CpuLinAlg::mat_get(&vt, i, j));
            }
        }

        // U × (diag(S) × Vᵀ), summing only over the k retained components so
        // a U with extra columns is tolerated.
        let mut result = CpuLinAlg::zeros_mat(rows, cols);
        for i in 0..rows {
            for j in 0..cols {
                let sum = (0..k).fold(0.0, |acc, l| {
                    acc + CpuLinAlg::mat_get(u, i, l) * CpuLinAlg::mat_get(&svt, l, j)
                });
                CpuLinAlg::mat_set(&mut result, i, j, sum);
            }
        }
        result
    }

    // ── Cycle 1.1: Vector basics ─────────────────────────────────

    #[test]
    fn test_zeros_vec_correct_length() {
        let v = CpuLinAlg::zeros_vec(5);
        assert_eq!(CpuLinAlg::vec_len(&v), 5);
    }

    #[test]
    fn test_zeros_vec_all_zeros() {
        let v = CpuLinAlg::zeros_vec(3);
        for i in 0..3 {
            assert_eq!(CpuLinAlg::vec_get(&v, i), 0.0);
        }
    }

    #[test]
    fn test_zeros_vec_empty() {
        let v = CpuLinAlg::zeros_vec(0);
        assert_eq!(CpuLinAlg::vec_len(&v), 0);
    }

    #[test]
    fn test_vec_get_returns_element() {
        let v = CpuLinAlg::vec_from_slice(&[10.0, 20.0, 30.0]);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 10.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 20.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 2), 30.0);
    }

    #[test]
    fn test_vec_set_modifies_element() {
        let mut v = CpuLinAlg::zeros_vec(3);
        CpuLinAlg::vec_set(&mut v, 1, 42.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 42.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 0.0);
    }

    #[test]
    fn test_vec_from_slice_creates_vector() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        assert_eq!(CpuLinAlg::vec_len(&v), 2);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 1.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 2.0);
    }

    #[test]
    fn test_vec_from_slice_empty() {
        let v = CpuLinAlg::vec_from_slice(&[]);
        assert_eq!(CpuLinAlg::vec_len(&v), 0);
    }

    #[test]
    fn test_vec_to_vec_returns_owned() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let owned = CpuLinAlg::vec_to_vec(&v);
        assert_eq!(owned, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vec_as_slice_returns_slice() {
        let v = CpuLinAlg::vec_from_slice(&[4.0, 5.0]);
        let s = CpuLinAlg::vec_as_slice(&v);
        assert_eq!(s, &[4.0, 5.0]);
    }

    #[test]
    fn test_vec_len_matches_creation_size() {
        let v = CpuLinAlg::zeros_vec(7);
        assert_eq!(CpuLinAlg::vec_len(&v), 7);
    }

    // ── Cycle 1.2: Vector arithmetic ─────────────────────────────

    #[test]
    fn test_vec_add_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[3.0, 4.0]);
        let r = CpuLinAlg::vec_add(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![4.0, 6.0]);
    }

    #[test]
    fn test_vec_sub_known() {
        let a = CpuLinAlg::vec_from_slice(&[5.0, 3.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let r = CpuLinAlg::vec_sub(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![4.0, 1.0]);
    }

    #[test]
    fn test_vec_scale_known() {
        let v = CpuLinAlg::vec_from_slice(&[2.0, 4.0]);
        let r = CpuLinAlg::vec_scale(&v, 0.5);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, 2.0]);
    }

    #[test]
    fn test_vec_hadamard_known() {
        let a = CpuLinAlg::vec_from_slice(&[2.0, 3.0, 4.0]);
        let b = CpuLinAlg::vec_from_slice(&[0.5, -1.0, 2.0]);
        let r = CpuLinAlg::vec_hadamard(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, -3.0, 8.0]);
    }

    #[test]
    #[should_panic(expected = "vec_hadamard: length mismatch")]
    fn test_vec_hadamard_length_mismatch_panics() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let _ = CpuLinAlg::vec_hadamard(&a, &b);
    }

    #[test]
    fn test_clip_vec_clamps() {
        let mut v = CpuLinAlg::vec_from_slice(&[10.0, -10.0, 0.5]);
        CpuLinAlg::clip_vec(&mut v, 5.0);
        assert_close(CpuLinAlg::vec_get(&v, 0), 5.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&v, 1), -5.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&v, 2), 0.5, 1e-10);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = CpuLinAlg::vec_from_slice(&[1.0, -1.0, 0.0]);
        CpuLinAlg::clip_vec(&mut v, 5.0);
        assert_eq!(CpuLinAlg::vec_to_vec(&v), vec![1.0, -1.0, 0.0]);
    }

    // ── Cycle 1.3: Matrix basics ─────────────────────────────────

    #[test]
    fn test_zeros_mat_correct_dims() {
        let m = CpuLinAlg::zeros_mat(3, 4);
        assert_eq!(CpuLinAlg::mat_rows(&m), 3);
        assert_eq!(CpuLinAlg::mat_cols(&m), 4);
    }

    #[test]
    fn test_zeros_mat_all_zeros() {
        let m = CpuLinAlg::zeros_mat(2, 3);
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(CpuLinAlg::mat_get(&m, r, c), 0.0);
            }
        }
    }

    #[test]
    fn test_xavier_mat_correct_dims() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(3, 4, &mut rng);
        assert_eq!(CpuLinAlg::mat_rows(&m), 3);
        assert_eq!(CpuLinAlg::mat_cols(&m), 4);
    }

    #[test]
    fn test_xavier_mat_all_finite() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(10, 10, &mut rng);
        for r in 0..10 {
            for c in 0..10 {
                assert!(CpuLinAlg::mat_get(&m, r, c).is_finite());
            }
        }
    }

    #[test]
    fn test_mat_get_set_roundtrip() {
        let mut m = CpuLinAlg::zeros_mat(3, 3);
        CpuLinAlg::mat_set(&mut m, 1, 2, 42.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 1, 2), 42.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 0, 0), 0.0);
    }

    #[test]
    fn test_mat_transpose_swaps_dims() {
        let m = CpuLinAlg::zeros_mat(3, 5);
        let t = CpuLinAlg::mat_transpose(&m);
        assert_eq!(CpuLinAlg::mat_rows(&t), 5);
        assert_eq!(CpuLinAlg::mat_cols(&t), 3);
    }

    #[test]
    fn test_mat_transpose_repositions_values() {
        let m = mat_from_rows(2, 3, &[0.0, 7.0, 0.0, 0.0, 0.0, 3.0]);
        let t = CpuLinAlg::mat_transpose(&m);
        assert_eq!(CpuLinAlg::mat_get(&t, 1, 0), 7.0);
        assert_eq!(CpuLinAlg::mat_get(&t, 2, 1), 3.0);
    }

    // ── Cycle 1.4: Matrix-vector operations ──────────────────────

    #[test]
    fn test_mat_vec_mul_known() {
        // [[1,2],[3,4]] * [5,6] = [17, 39]
        let m = mat_from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let v = CpuLinAlg::vec_from_slice(&[5.0, 6.0]);
        let r = CpuLinAlg::mat_vec_mul(&m, &v);
        assert_eq!(CpuLinAlg::vec_len(&r), 2);
        assert_close(CpuLinAlg::vec_get(&r, 0), 17.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&r, 1), 39.0, 1e-10);
    }

    #[test]
    fn test_outer_product_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[3.0, 4.0, 5.0]);
        let m = CpuLinAlg::outer_product(&a, &b);
        assert_eq!(CpuLinAlg::mat_rows(&m), 2);
        assert_eq!(CpuLinAlg::mat_cols(&m), 3);
        assert_close(CpuLinAlg::mat_get(&m, 0, 0), 3.0, 1e-10);
        assert_close(CpuLinAlg::mat_get(&m, 1, 2), 10.0, 1e-10);
    }

    #[test]
    fn test_mat_scale_add_basic() {
        let mut m = mat_from_rows(2, 2, &[1.0, 0.0, 0.0, 2.0]);
        let other = mat_from_rows(2, 2, &[0.5, 0.0, 0.0, 0.5]);
        CpuLinAlg::mat_scale_add(&mut m, &other, 2.0);
        assert_mat_close(&m, &mat_from_rows(2, 2, &[2.0, 0.0, 0.0, 3.0]), 1e-10);
    }

    #[test]
    fn test_clip_mat_clamps() {
        let mut m = mat_from_rows(1, 2, &[10.0, -10.0]);
        CpuLinAlg::clip_mat(&mut m, 5.0);
        assert_close(CpuLinAlg::mat_get(&m, 0, 0), 5.0, 1e-10);
        assert_close(CpuLinAlg::mat_get(&m, 0, 1), -5.0, 1e-10);
    }

    #[test]
    fn test_clip_mat_leaves_safe_values() {
        let mut m = mat_from_rows(1, 3, &[1.0, -1.0, 0.0]);
        CpuLinAlg::clip_mat(&mut m, 5.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 0, 0), 1.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 0, 1), -1.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 0, 2), 0.0);
    }

    // ── Cycle 1.5: Activation + softmax + sampling ───────────────

    #[test]
    fn test_apply_activation_tanh() {
        let v = CpuLinAlg::vec_from_slice(&[0.5, -0.5]);
        let r = CpuLinAlg::apply_activation(&v, Activation::Tanh);
        assert_close(CpuLinAlg::vec_get(&r, 0), 0.5_f64.tanh(), 1e-12);
        assert_close(CpuLinAlg::vec_get(&r, 1), (-0.5_f64).tanh(), 1e-12);
    }

    #[test]
    fn test_apply_activation_relu() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, -1.0, 0.0]);
        let r = CpuLinAlg::apply_activation(&v, Activation::Relu);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_apply_derivative_tanh() {
        let v = CpuLinAlg::vec_from_slice(&[0.5]);
        let r = CpuLinAlg::apply_derivative(&v, Activation::Tanh);
        // derivative(0.5) = 1 - 0.25 = 0.75
        assert_close(CpuLinAlg::vec_get(&r, 0), 0.75, 1e-12);
    }

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let mask = vec![0, 1, 2, 3];
        let probs = CpuLinAlg::softmax_masked(&logits, &mask);
        let sum: f64 = CpuLinAlg::vec_to_vec(&probs).iter().sum();
        assert_close(sum, 1.0, 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let mask = vec![1, 3];
        let probs = CpuLinAlg::softmax_masked(&logits, &mask);
        assert_eq!(CpuLinAlg::vec_get(&probs, 0), 0.0);
        assert_eq!(CpuLinAlg::vec_get(&probs, 2), 0.0);
        assert!(CpuLinAlg::vec_get(&probs, 1) > 0.0);
        assert!(CpuLinAlg::vec_get(&probs, 3) > 0.0);
    }

    #[test]
    fn test_argmax_masked_returns_highest() {
        let values = CpuLinAlg::vec_from_slice(&[1.0, 5.0, 3.0, 4.0]);
        let mask = vec![0, 2, 3];
        assert_eq!(CpuLinAlg::argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_sample_from_probs_in_mask() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let probs = CpuLinAlg::vec_from_slice(&[0.1, 0.2, 0.3, 0.4]);
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = CpuLinAlg::sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_rms_error_known() {
        let v1 = CpuLinAlg::vec_from_slice(&[1.0, 0.0]);
        let v2 = CpuLinAlg::vec_from_slice(&[0.0, 1.0]);
        let rms = CpuLinAlg::rms_error(&[&v1, &v2]);
        assert_close(rms, 0.5_f64.sqrt(), 1e-10);
    }

    #[test]
    fn test_rms_error_empty() {
        let rms = CpuLinAlg::rms_error(&[]);
        assert_eq!(rms, 0.0);
    }

    #[test]
    fn test_vec_dot_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let b = CpuLinAlg::vec_from_slice(&[4.0, 5.0, 6.0]);
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert_close(CpuLinAlg::vec_dot(&a, &b), 32.0, 1e-12);
    }

    #[test]
    fn test_vec_dot_orthogonal() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 0.0]);
        let b = CpuLinAlg::vec_from_slice(&[0.0, 1.0]);
        assert_close(CpuLinAlg::vec_dot(&a, &b), 0.0, 1e-12);
    }

    #[test]
    #[should_panic(expected = "vec_dot: length mismatch")]
    fn test_vec_dot_length_mismatch_panics() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0]);
        let _ = CpuLinAlg::vec_dot(&a, &b);
    }

    // ── Phase 1 Cycle 1.1: mat_mul (matrix × matrix) ────────────

    #[test]
    fn test_mat_mul_2x3_by_3x2() {
        // A = [[1,2,3],[4,5,6]] (2×3)
        // B = [[7,8],[9,10],[11,12]] (3×2)
        // C = A*B = [[58,64],[139,154]] (2×2)
        let a = mat_from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = mat_from_rows(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = CpuLinAlg::mat_mul(&a, &b);
        assert_mat_close(&c, &mat_from_rows(2, 2, &[58.0, 64.0, 139.0, 154.0]), 1e-10);
    }

    #[test]
    fn test_mat_mul_identity_left() {
        // I × M = M
        let m = mat_from_rows(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = CpuLinAlg::mat_mul(&identity(3), &m);
        assert_mat_close(&result, &m, 1e-10);
    }

    #[test]
    fn test_mat_mul_identity_right() {
        // M × I = M
        let m = mat_from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = CpuLinAlg::mat_mul(&m, &identity(3));
        assert_mat_close(&result, &m, 1e-10);
    }

    #[test]
    fn test_mat_mul_result_dimensions() {
        // (4×3) × (3×5) = (4×5)
        let a = CpuLinAlg::zeros_mat(4, 3);
        let b = CpuLinAlg::zeros_mat(3, 5);
        let c = CpuLinAlg::mat_mul(&a, &b);
        assert_eq!(CpuLinAlg::mat_rows(&c), 4);
        assert_eq!(CpuLinAlg::mat_cols(&c), 5);
    }

    #[test]
    #[should_panic(expected = "mat_mul: inner dimensions mismatch")]
    fn test_mat_mul_dimension_mismatch_panics() {
        // (2×3) × (2×3): inner dimensions 3 and 2 disagree.
        let a = CpuLinAlg::zeros_mat(2, 3);
        let b = CpuLinAlg::zeros_mat(2, 3);
        let _ = CpuLinAlg::mat_mul(&a, &b);
    }

    // ── Phase 2 Cycle 2.1: SVD of known matrices ────────────────

    #[test]
    fn test_reconstruct_usv_handles_wide_factors() {
        // M = [3, 4] (1×2) = U·diag(S)·Vᵀ with U = [1], S = [5], V = [0.6; 0.8].
        // V is 2×1 (not square), so the result must have mat_rows(V) = 2 columns.
        let u = mat_from_rows(1, 1, &[1.0]);
        let s = CpuLinAlg::vec_from_slice(&[5.0]);
        let v = mat_from_rows(2, 1, &[0.6, 0.8]);
        let recon = reconstruct_usv(&u, &s, &v);
        assert_mat_close(&recon, &mat_from_rows(1, 2, &[3.0, 4.0]), 1e-12);
    }

    #[test]
    fn test_svd_2x2_diagonal() {
        // diag(5, 3) → U≈I, S=[5,3], V≈I (up to sign)
        let m = mat_from_rows(2, 2, &[5.0, 0.0, 0.0, 3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();

        // S values = [5, 3] sorted descending
        assert_close(CpuLinAlg::vec_get(&s, 0), 5.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&s, 1), 3.0, 1e-10);

        // Reconstruction: U × diag(S) × V^T ≈ M
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_3x3_reconstruction() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_rectangular_3x2_reconstruction() {
        let m = mat_from_rows(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();

        // U is 3×2, S has 2 elements, V is 2×2
        assert_eq!(CpuLinAlg::mat_rows(&u), 3);
        assert_eq!(CpuLinAlg::mat_cols(&u), 2);
        assert_eq!(CpuLinAlg::vec_len(&s), 2);
        assert_eq!(CpuLinAlg::mat_rows(&v), 2);
        assert_eq!(CpuLinAlg::mat_cols(&v), 2);

        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_singular_values_non_negative_descending() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        let values = CpuLinAlg::vec_to_vec(&s);

        for (i, &sigma) in values.iter().enumerate() {
            assert!(sigma >= 0.0, "singular value {i} is negative: {sigma}");
        }
        for (i, pair) in values.windows(2).enumerate() {
            assert!(
                pair[0] >= pair[1] - 1e-12,
                "singular values not descending: s[{}]={} < s[{}]={}",
                i,
                pair[0],
                i + 1,
                pair[1]
            );
        }
    }

    #[test]
    fn test_svd_orthonormal_columns() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (u, _s, v) = CpuLinAlg::svd(&m).unwrap();

        // U^T × U ≈ I
        let utu = CpuLinAlg::mat_mul(&CpuLinAlg::mat_transpose(&u), &u);
        assert_approx_identity(&utu, 1e-10);

        // V^T × V ≈ I
        let vtv = CpuLinAlg::mat_mul(&CpuLinAlg::mat_transpose(&v), &v);
        assert_approx_identity(&vtv, 1e-10);
    }

    // ── Phase 2 Cycle 2.2: SVD edge cases ───────────────────────

    #[test]
    fn test_svd_1x1_matrix() {
        let m = mat_from_rows(1, 1, &[7.0]);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        assert_eq!(CpuLinAlg::vec_len(&s), 1);
        assert_close(CpuLinAlg::vec_get(&s, 0), 7.0, 1e-10);
    }

    #[test]
    fn test_svd_1x1_negative() {
        let m = mat_from_rows(1, 1, &[-3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        // S must be non-negative; the sign is carried by U or V.
        assert!(CpuLinAlg::vec_get(&s, 0) >= 0.0);
        assert_close(CpuLinAlg::vec_get(&s, 0), 3.0, 1e-10);
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_zero_matrix() {
        let m = CpuLinAlg::zeros_mat(3, 3);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        for i in 0..CpuLinAlg::vec_len(&s) {
            let sigma = CpuLinAlg::vec_get(&s, i);
            assert!(
                sigma.abs() < 1e-12,
                "expected zero singular value, got {sigma}"
            );
        }
    }

    #[test]
    fn test_svd_repeated_singular_values() {
        // diag(4, 4, 2) → S = [4, 4, 2]
        let m = mat_from_rows(3, 3, &[4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_close(CpuLinAlg::vec_get(&s, 0), 4.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&s, 1), 4.0, 1e-10);
        assert_close(CpuLinAlg::vec_get(&s, 2), 2.0, 1e-10);
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_16x16_reconstruction() {
        // Deterministic 16×16 matrix
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(16, 16, &mut rng);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-8);
    }

    // ── Fix #4: SVD returns Result ──────────────────────────────

    #[test]
    fn test_svd_returns_ok_for_valid_matrix() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let result = CpuLinAlg::svd(&m);
        assert!(result.is_ok(), "SVD of valid matrix should return Ok");
        let (u, s, v) = result.unwrap();
        assert_eq!(CpuLinAlg::vec_len(&s), 3);
        assert_eq!(CpuLinAlg::mat_rows(&u), 3);
        assert_eq!(CpuLinAlg::mat_rows(&v), 3);
    }

    #[test]
    fn test_svd_result_reconstruction() {
        // Verify reconstruction works through Result unwrap
        let m = mat_from_rows(2, 2, &[5.0, 0.0, 0.0, 3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_mat_close(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }
}