1use crate::activation::Activation;
13use crate::linalg::LinAlg;
14use crate::matrix::Matrix;
15use rand::Rng;
16
/// Marker type that selects the CPU backend of [`LinAlg`].
///
/// It is a unit struct and holds no data: every operation is an associated
/// function of its `LinAlg` implementation (e.g. `CpuLinAlg::zeros_vec(3)`).
#[derive(Clone, Debug)]
pub struct CpuLinAlg;
24
25impl LinAlg for CpuLinAlg {
26 type Vector = Vec<f64>;
27 type Matrix = Matrix;
28
29 fn zeros_vec(size: usize) -> Self::Vector {
30 vec![0.0; size]
31 }
32
33 fn zeros_mat(rows: usize, cols: usize) -> Self::Matrix {
34 Matrix::zeros(rows, cols)
35 }
36
37 fn xavier_mat(rows: usize, cols: usize, rng: &mut impl Rng) -> Self::Matrix {
38 Matrix::xavier(rows, cols, rng)
39 }
40
41 fn mat_vec_mul(m: &Self::Matrix, v: &Self::Vector) -> Self::Vector {
42 m.mul_vec(v)
43 }
44
45 fn mat_transpose(m: &Self::Matrix) -> Self::Matrix {
46 m.transpose()
47 }
48
49 fn outer_product(a: &Self::Vector, b: &Self::Vector) -> Self::Matrix {
50 Matrix::outer(a, b)
51 }
52
53 fn mat_mul(a: &Self::Matrix, b: &Self::Matrix) -> Self::Matrix {
54 assert_eq!(a.cols, b.rows, "mat_mul: inner dimensions mismatch");
55 let mut result = Matrix::zeros(a.rows, b.cols);
56 for i in 0..a.rows {
57 for j in 0..b.cols {
58 let mut sum = 0.0;
59 for k in 0..a.cols {
60 sum += a.get(i, k) * b.get(k, j);
61 }
62 result.set(i, j, sum);
63 }
64 }
65 result
66 }
67
68 fn svd(m: &Self::Matrix) -> crate::linalg::SvdResult<Self> {
69 Ok(crate::linalg::golub_kahan::GolubKahanSvd::new().compute(m)?)
70 }
71
72 fn mat_scale_add(m: &mut Self::Matrix, other: &Self::Matrix, scale: f64) {
73 m.scale_add(other, scale);
74 }
75
76 fn mat_rows(m: &Self::Matrix) -> usize {
77 m.rows
78 }
79
80 fn mat_cols(m: &Self::Matrix) -> usize {
81 m.cols
82 }
83
84 fn mat_get(m: &Self::Matrix, row: usize, col: usize) -> f64 {
85 m.get(row, col)
86 }
87
88 fn mat_set(m: &mut Self::Matrix, row: usize, col: usize, val: f64) {
89 m.set(row, col, val);
90 }
91
92 fn vec_add(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
93 crate::matrix::vec_add(a, b)
94 }
95
96 fn vec_sub(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
97 crate::matrix::vec_sub(a, b)
98 }
99
100 fn vec_scale(v: &Self::Vector, s: f64) -> Self::Vector {
101 crate::matrix::vec_scale(v, s)
102 }
103
104 fn vec_hadamard(a: &Self::Vector, b: &Self::Vector) -> Self::Vector {
105 assert_eq!(a.len(), b.len(), "vec_hadamard: length mismatch");
106 a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
107 }
108
109 fn vec_dot(a: &Self::Vector, b: &Self::Vector) -> f64 {
110 assert_eq!(a.len(), b.len(), "vec_dot: length mismatch");
111 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
112 }
113
114 fn vec_len(v: &Self::Vector) -> usize {
115 v.len()
116 }
117
118 fn vec_get(v: &Self::Vector, i: usize) -> f64 {
119 v[i]
120 }
121
122 fn vec_set(v: &mut Self::Vector, i: usize, val: f64) {
123 v[i] = val;
124 }
125
126 fn vec_from_slice(s: &[f64]) -> Self::Vector {
127 s.to_vec()
128 }
129
130 fn vec_to_vec(v: &Self::Vector) -> Vec<f64> {
131 v.clone()
132 }
133
134 fn vec_as_slice(v: &Self::Vector) -> &[f64] {
135 v.as_slice()
136 }
137
138 fn clip_vec(v: &mut Self::Vector, max_abs: f64) {
139 crate::matrix::clip_vec(v, max_abs);
140 }
141
142 fn clip_mat(m: &mut Self::Matrix, max_abs: f64) {
143 for x in m.data.iter_mut() {
144 *x = x.clamp(-max_abs, max_abs);
145 }
146 }
147
148 fn apply_activation(v: &Self::Vector, act: Activation) -> Self::Vector {
149 v.iter().map(|&x| act.apply(x)).collect()
150 }
151
152 fn apply_derivative(v: &Self::Vector, act: Activation) -> Self::Vector {
153 v.iter().map(|&fx| act.derivative(fx)).collect()
154 }
155
156 fn softmax_masked(logits: &Self::Vector, mask: &[usize]) -> Self::Vector {
157 crate::matrix::softmax_masked(logits, mask)
158 }
159
160 fn argmax_masked(values: &Self::Vector, mask: &[usize]) -> usize {
161 crate::matrix::argmax_masked(values, mask)
162 }
163
164 fn sample_from_probs(probs: &Self::Vector, mask: &[usize], rng: &mut impl Rng) -> usize {
165 crate::matrix::sample_from_probs(probs, mask, rng)
166 }
167
168 fn rms_error(error_vecs: &[&Self::Vector]) -> f64 {
169 let slices: Vec<&[f64]> = error_vecs.iter().map(|v| v.as_slice()).collect();
170 crate::matrix::rms_error(&slices)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    // -----------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------

    /// Builds a `rows x cols` matrix from `data`, laid out row by row.
    ///
    /// # Panics
    /// Panics if `data.len() != rows * cols`.
    fn mat_from_rows(rows: usize, cols: usize, data: &[f64]) -> Matrix {
        assert_eq!(data.len(), rows * cols, "mat_from_rows: data length mismatch");
        let mut m = CpuLinAlg::zeros_mat(rows, cols);
        // The assert above guarantees cols > 0 whenever data is non-empty.
        for (idx, &val) in data.iter().enumerate() {
            CpuLinAlg::mat_set(&mut m, idx / cols, idx % cols, val);
        }
        m
    }

    /// Builds the `n x n` identity matrix.
    fn identity(n: usize) -> Matrix {
        let mut m = CpuLinAlg::zeros_mat(n, n);
        for i in 0..n {
            CpuLinAlg::mat_set(&mut m, i, i, 1.0);
        }
        m
    }

    /// Asserts that `actual` has the same shape as `expected` and that every
    /// element agrees within `tol`.
    fn assert_mat_approx_eq(actual: &Matrix, expected: &Matrix, tol: f64) {
        let rows = CpuLinAlg::mat_rows(expected);
        let cols = CpuLinAlg::mat_cols(expected);
        assert_eq!(CpuLinAlg::mat_rows(actual), rows, "row count mismatch");
        assert_eq!(CpuLinAlg::mat_cols(actual), cols, "column count mismatch");
        for r in 0..rows {
            for c in 0..cols {
                let got = CpuLinAlg::mat_get(actual, r, c);
                let want = CpuLinAlg::mat_get(expected, r, c);
                assert!(
                    (got - want).abs() < tol,
                    "mismatch at ({r},{c}): got {got} expected {want}"
                );
            }
        }
    }

    /// Reconstructs `U * diag(s) * V^T` from a (possibly thin) SVD.
    ///
    /// `U` is `m x k`, `s` has `k` entries and `V` is `n x k`, so the result
    /// is `m x n`: the column count comes from the rows of `V`, not its columns.
    fn reconstruct_usv(u: &Matrix, s: &[f64], v: &Matrix) -> Matrix {
        let rows = CpuLinAlg::mat_rows(u);
        let cols = CpuLinAlg::mat_rows(v);
        let mut result = CpuLinAlg::zeros_mat(rows, cols);
        for i in 0..rows {
            for j in 0..cols {
                let sum: f64 = s
                    .iter()
                    .enumerate()
                    .map(|(l, &sigma)| {
                        CpuLinAlg::mat_get(u, i, l) * (sigma * CpuLinAlg::mat_get(v, j, l))
                    })
                    .sum();
                CpuLinAlg::mat_set(&mut result, i, j, sum);
            }
        }
        result
    }

    /// Asserts that the square matrix `m` is the identity within `tol`.
    fn assert_approx_identity(m: &Matrix, tol: f64) {
        let n = CpuLinAlg::mat_rows(m);
        assert_eq!(n, CpuLinAlg::mat_cols(m), "not square");
        assert_mat_approx_eq(m, &identity(n), tol);
    }

    // -----------------------------------------------------------------
    // Vector construction and access
    // -----------------------------------------------------------------

    #[test]
    fn test_zeros_vec_correct_length() {
        let v = CpuLinAlg::zeros_vec(5);
        assert_eq!(CpuLinAlg::vec_len(&v), 5);
    }

    #[test]
    fn test_zeros_vec_all_zeros() {
        let v = CpuLinAlg::zeros_vec(3);
        assert_eq!(CpuLinAlg::vec_len(&v), 3);
        for i in 0..3 {
            assert_eq!(CpuLinAlg::vec_get(&v, i), 0.0);
        }
    }

    #[test]
    fn test_zeros_vec_empty() {
        let v = CpuLinAlg::zeros_vec(0);
        assert_eq!(CpuLinAlg::vec_len(&v), 0);
    }

    #[test]
    fn test_vec_get_returns_element() {
        let v = CpuLinAlg::vec_from_slice(&[10.0, 20.0, 30.0]);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 10.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 20.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 2), 30.0);
    }

    #[test]
    fn test_vec_set_modifies_element() {
        let mut v = CpuLinAlg::zeros_vec(3);
        CpuLinAlg::vec_set(&mut v, 1, 42.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 42.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 0.0);
    }

    #[test]
    fn test_vec_from_slice_creates_vector() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        assert_eq!(CpuLinAlg::vec_len(&v), 2);
        assert_eq!(CpuLinAlg::vec_get(&v, 0), 1.0);
        assert_eq!(CpuLinAlg::vec_get(&v, 1), 2.0);
    }

    #[test]
    fn test_vec_from_slice_empty() {
        let v = CpuLinAlg::vec_from_slice(&[]);
        assert_eq!(CpuLinAlg::vec_len(&v), 0);
    }

    #[test]
    fn test_vec_to_vec_returns_owned() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let owned = CpuLinAlg::vec_to_vec(&v);
        assert_eq!(owned, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_vec_as_slice_returns_slice() {
        let v = CpuLinAlg::vec_from_slice(&[4.0, 5.0]);
        let s = CpuLinAlg::vec_as_slice(&v);
        assert_eq!(s, &[4.0, 5.0]);
    }

    #[test]
    fn test_vec_len_matches_creation_size() {
        let v = CpuLinAlg::zeros_vec(7);
        assert_eq!(CpuLinAlg::vec_len(&v), 7);
    }

    // -----------------------------------------------------------------
    // Vector arithmetic
    // -----------------------------------------------------------------

    #[test]
    fn test_vec_add_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[3.0, 4.0]);
        let r = CpuLinAlg::vec_add(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![4.0, 6.0]);
    }

    #[test]
    fn test_vec_sub_known() {
        let a = CpuLinAlg::vec_from_slice(&[5.0, 3.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let r = CpuLinAlg::vec_sub(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![4.0, 1.0]);
    }

    #[test]
    fn test_vec_scale_known() {
        let v = CpuLinAlg::vec_from_slice(&[2.0, 4.0]);
        let r = CpuLinAlg::vec_scale(&v, 0.5);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, 2.0]);
    }

    #[test]
    fn test_vec_hadamard_known() {
        let a = CpuLinAlg::vec_from_slice(&[2.0, 3.0, 4.0]);
        let b = CpuLinAlg::vec_from_slice(&[0.5, -1.0, 2.0]);
        let r = CpuLinAlg::vec_hadamard(&a, &b);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, -3.0, 8.0]);
    }

    #[test]
    #[should_panic(expected = "vec_hadamard: length mismatch")]
    fn test_vec_hadamard_length_mismatch_panics() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0]);
        let _ = CpuLinAlg::vec_hadamard(&a, &b);
    }

    #[test]
    fn test_vec_dot_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0]);
        let b = CpuLinAlg::vec_from_slice(&[4.0, 5.0, 6.0]);
        let dot = CpuLinAlg::vec_dot(&a, &b);
        // 1*4 + 2*5 + 3*6 = 32
        assert!((dot - 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_vec_dot_orthogonal() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 0.0]);
        let b = CpuLinAlg::vec_from_slice(&[0.0, 1.0]);
        assert!((CpuLinAlg::vec_dot(&a, &b)).abs() < 1e-12);
    }

    #[test]
    #[should_panic(expected = "vec_dot: length mismatch")]
    fn test_vec_dot_length_mismatch_panics() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[1.0]);
        let _ = CpuLinAlg::vec_dot(&a, &b);
    }

    #[test]
    fn test_clip_vec_clamps() {
        let mut v = CpuLinAlg::vec_from_slice(&[10.0, -10.0, 0.5]);
        CpuLinAlg::clip_vec(&mut v, 5.0);
        assert!((CpuLinAlg::vec_get(&v, 0) - 5.0).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&v, 1) - (-5.0)).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&v, 2) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_clip_vec_leaves_safe_values() {
        let mut v = CpuLinAlg::vec_from_slice(&[1.0, -1.0, 0.0]);
        CpuLinAlg::clip_vec(&mut v, 5.0);
        assert_eq!(CpuLinAlg::vec_to_vec(&v), vec![1.0, -1.0, 0.0]);
    }

    // -----------------------------------------------------------------
    // Matrix construction, access and element-wise ops
    // -----------------------------------------------------------------

    #[test]
    fn test_zeros_mat_correct_dims() {
        let m = CpuLinAlg::zeros_mat(3, 4);
        assert_eq!(CpuLinAlg::mat_rows(&m), 3);
        assert_eq!(CpuLinAlg::mat_cols(&m), 4);
    }

    #[test]
    fn test_zeros_mat_all_zeros() {
        let m = CpuLinAlg::zeros_mat(2, 3);
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(CpuLinAlg::mat_get(&m, r, c), 0.0);
            }
        }
    }

    #[test]
    fn test_xavier_mat_correct_dims() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(3, 4, &mut rng);
        assert_eq!(CpuLinAlg::mat_rows(&m), 3);
        assert_eq!(CpuLinAlg::mat_cols(&m), 4);
    }

    #[test]
    fn test_xavier_mat_all_finite() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(10, 10, &mut rng);
        for r in 0..10 {
            for c in 0..10 {
                assert!(CpuLinAlg::mat_get(&m, r, c).is_finite());
            }
        }
    }

    #[test]
    fn test_mat_get_set_roundtrip() {
        let mut m = CpuLinAlg::zeros_mat(3, 3);
        CpuLinAlg::mat_set(&mut m, 1, 2, 42.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 1, 2), 42.0);
        assert_eq!(CpuLinAlg::mat_get(&m, 0, 0), 0.0);
    }

    #[test]
    fn test_mat_transpose_swaps_dims() {
        let m = CpuLinAlg::zeros_mat(3, 5);
        let t = CpuLinAlg::mat_transpose(&m);
        assert_eq!(CpuLinAlg::mat_rows(&t), 5);
        assert_eq!(CpuLinAlg::mat_cols(&t), 3);
    }

    #[test]
    fn test_mat_transpose_repositions_values() {
        let m = mat_from_rows(2, 3, &[0.0, 7.0, 0.0, 0.0, 0.0, 3.0]);
        let t = CpuLinAlg::mat_transpose(&m);
        assert_eq!(CpuLinAlg::mat_get(&t, 1, 0), 7.0);
        assert_eq!(CpuLinAlg::mat_get(&t, 2, 1), 3.0);
    }

    #[test]
    fn test_mat_vec_mul_known() {
        // [[1, 2], [3, 4]] * [5, 6] = [17, 39]
        let m = mat_from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let v = CpuLinAlg::vec_from_slice(&[5.0, 6.0]);
        let r = CpuLinAlg::mat_vec_mul(&m, &v);
        assert_eq!(CpuLinAlg::vec_len(&r), 2);
        assert!((CpuLinAlg::vec_get(&r, 0) - 17.0).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&r, 1) - 39.0).abs() < 1e-10);
    }

    #[test]
    fn test_outer_product_known() {
        let a = CpuLinAlg::vec_from_slice(&[1.0, 2.0]);
        let b = CpuLinAlg::vec_from_slice(&[3.0, 4.0, 5.0]);
        let m = CpuLinAlg::outer_product(&a, &b);
        let expected = mat_from_rows(2, 3, &[3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
        assert_mat_approx_eq(&m, &expected, 1e-10);
    }

    #[test]
    fn test_mat_scale_add_basic() {
        let mut m = mat_from_rows(2, 2, &[1.0, 0.0, 0.0, 2.0]);
        let other = mat_from_rows(2, 2, &[0.5, 0.0, 0.0, 0.5]);
        CpuLinAlg::mat_scale_add(&mut m, &other, 2.0);
        let expected = mat_from_rows(2, 2, &[2.0, 0.0, 0.0, 3.0]);
        assert_mat_approx_eq(&m, &expected, 1e-10);
    }

    #[test]
    fn test_clip_mat_clamps() {
        let mut m = mat_from_rows(1, 2, &[10.0, -10.0]);
        CpuLinAlg::clip_mat(&mut m, 5.0);
        assert!((CpuLinAlg::mat_get(&m, 0, 0) - 5.0).abs() < 1e-10);
        assert!((CpuLinAlg::mat_get(&m, 0, 1) - (-5.0)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_mat_leaves_safe_values() {
        let mut m = mat_from_rows(1, 3, &[1.0, -1.0, 0.0]);
        CpuLinAlg::clip_mat(&mut m, 5.0);
        assert_mat_approx_eq(&m, &mat_from_rows(1, 3, &[1.0, -1.0, 0.0]), 1e-15);
    }

    // -----------------------------------------------------------------
    // Activations and sampling
    // -----------------------------------------------------------------

    #[test]
    fn test_apply_activation_tanh() {
        let v = CpuLinAlg::vec_from_slice(&[0.5, -0.5]);
        let r = CpuLinAlg::apply_activation(&v, Activation::Tanh);
        assert!((CpuLinAlg::vec_get(&r, 0) - 0.5_f64.tanh()).abs() < 1e-12);
        assert!((CpuLinAlg::vec_get(&r, 1) - (-0.5_f64).tanh()).abs() < 1e-12);
    }

    #[test]
    fn test_apply_activation_relu() {
        let v = CpuLinAlg::vec_from_slice(&[1.0, -1.0, 0.0]);
        let r = CpuLinAlg::apply_activation(&v, Activation::Relu);
        assert_eq!(CpuLinAlg::vec_to_vec(&r), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_apply_derivative_tanh() {
        // For tanh output fx = 0.5 the derivative is 1 - fx^2 = 0.75.
        let v = CpuLinAlg::vec_from_slice(&[0.5]);
        let r = CpuLinAlg::apply_derivative(&v, Activation::Tanh);
        assert!((CpuLinAlg::vec_get(&r, 0) - 0.75).abs() < 1e-12);
    }

    #[test]
    fn test_softmax_masked_sums_to_one() {
        let logits = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let mask = vec![0, 1, 2, 3];
        let probs = CpuLinAlg::softmax_masked(&logits, &mask);
        let sum: f64 = CpuLinAlg::vec_to_vec(&probs).iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_masked_unmasked_are_zero() {
        let logits = CpuLinAlg::vec_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let mask = vec![1, 3];
        let probs = CpuLinAlg::softmax_masked(&logits, &mask);
        assert_eq!(CpuLinAlg::vec_get(&probs, 0), 0.0);
        assert_eq!(CpuLinAlg::vec_get(&probs, 2), 0.0);
        assert!(CpuLinAlg::vec_get(&probs, 1) > 0.0);
        assert!(CpuLinAlg::vec_get(&probs, 3) > 0.0);
    }

    #[test]
    fn test_argmax_masked_returns_highest() {
        let values = CpuLinAlg::vec_from_slice(&[1.0, 5.0, 3.0, 4.0]);
        let mask = vec![0, 2, 3];
        assert_eq!(CpuLinAlg::argmax_masked(&values, &mask), 3);
    }

    #[test]
    fn test_sample_from_probs_in_mask() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let probs = CpuLinAlg::vec_from_slice(&[0.1, 0.2, 0.3, 0.4]);
        let mask = vec![1, 3];
        for _ in 0..20 {
            let idx = CpuLinAlg::sample_from_probs(&probs, &mask, &mut rng);
            assert!(mask.contains(&idx));
        }
    }

    #[test]
    fn test_rms_error_known() {
        let v1 = CpuLinAlg::vec_from_slice(&[1.0, 0.0]);
        let v2 = CpuLinAlg::vec_from_slice(&[0.0, 1.0]);
        let rms = CpuLinAlg::rms_error(&[&v1, &v2]);
        let expected = (0.5_f64).sqrt();
        assert!((rms - expected).abs() < 1e-10);
    }

    #[test]
    fn test_rms_error_empty() {
        let rms = CpuLinAlg::rms_error(&[]);
        assert_eq!(rms, 0.0);
    }

    // -----------------------------------------------------------------
    // Matrix multiplication
    // -----------------------------------------------------------------

    #[test]
    fn test_mat_mul_2x3_by_3x2() {
        // [[1,2,3],[4,5,6]] * [[7,8],[9,10],[11,12]] = [[58,64],[139,154]]
        let a = mat_from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = mat_from_rows(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = CpuLinAlg::mat_mul(&a, &b);
        let expected = mat_from_rows(2, 2, &[58.0, 64.0, 139.0, 154.0]);
        assert_mat_approx_eq(&c, &expected, 1e-10);
    }

    #[test]
    fn test_mat_mul_identity_left() {
        let m = mat_from_rows(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = CpuLinAlg::mat_mul(&identity(3), &m);
        assert_mat_approx_eq(&result, &m, 1e-10);
    }

    #[test]
    fn test_mat_mul_identity_right() {
        let m = mat_from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = CpuLinAlg::mat_mul(&m, &identity(3));
        assert_mat_approx_eq(&result, &m, 1e-10);
    }

    #[test]
    fn test_mat_mul_result_dimensions() {
        let a = CpuLinAlg::zeros_mat(4, 3);
        let b = CpuLinAlg::zeros_mat(3, 5);
        let c = CpuLinAlg::mat_mul(&a, &b);
        assert_eq!(CpuLinAlg::mat_rows(&c), 4);
        assert_eq!(CpuLinAlg::mat_cols(&c), 5);
    }

    #[test]
    #[should_panic(expected = "mat_mul: inner dimensions mismatch")]
    fn test_mat_mul_dimension_mismatch_panics() {
        let a = CpuLinAlg::zeros_mat(2, 3);
        let b = CpuLinAlg::zeros_mat(2, 3);
        let _ = CpuLinAlg::mat_mul(&a, &b);
    }

    // -----------------------------------------------------------------
    // SVD
    // -----------------------------------------------------------------

    #[test]
    fn test_svd_2x2_diagonal() {
        let m = mat_from_rows(2, 2, &[5.0, 0.0, 0.0, 3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();

        assert!((CpuLinAlg::vec_get(&s, 0) - 5.0).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&s, 1) - 3.0).abs() < 1e-10);

        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_3x3_reconstruction() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_rectangular_3x2_reconstruction() {
        let m = mat_from_rows(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();

        // Thin SVD: U is 3x2, s has 2 entries, V is 2x2.
        assert_eq!(CpuLinAlg::mat_rows(&u), 3);
        assert_eq!(CpuLinAlg::mat_cols(&u), 2);
        assert_eq!(CpuLinAlg::vec_len(&s), 2);
        assert_eq!(CpuLinAlg::mat_rows(&v), 2);
        assert_eq!(CpuLinAlg::mat_cols(&v), 2);

        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_singular_values_non_negative_descending() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        let sv = CpuLinAlg::vec_as_slice(&s);

        for (i, &sigma) in sv.iter().enumerate() {
            assert!(sigma >= 0.0, "singular value {i} is negative: {sigma}");
        }
        for (i, pair) in sv.windows(2).enumerate() {
            assert!(
                pair[0] >= pair[1] - 1e-12,
                "singular values not descending: s[{}]={} < s[{}]={}",
                i,
                pair[0],
                i + 1,
                pair[1]
            );
        }
    }

    #[test]
    fn test_svd_orthonormal_columns() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let (u, _s, v) = CpuLinAlg::svd(&m).unwrap();

        // U^T U = I
        let utu = CpuLinAlg::mat_mul(&CpuLinAlg::mat_transpose(&u), &u);
        assert_approx_identity(&utu, 1e-10);

        // V^T V = I
        let vtv = CpuLinAlg::mat_mul(&CpuLinAlg::mat_transpose(&v), &v);
        assert_approx_identity(&vtv, 1e-10);
    }

    #[test]
    fn test_svd_1x1_matrix() {
        let m = mat_from_rows(1, 1, &[7.0]);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        assert_eq!(CpuLinAlg::vec_len(&s), 1);
        assert!((CpuLinAlg::vec_get(&s, 0) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_svd_1x1_negative() {
        let m = mat_from_rows(1, 1, &[-3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        // The sign must be carried by U or V, never by the singular value.
        assert!(CpuLinAlg::vec_get(&s, 0) >= 0.0);
        assert!((CpuLinAlg::vec_get(&s, 0) - 3.0).abs() < 1e-10);
        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_zero_matrix() {
        let m = CpuLinAlg::zeros_mat(3, 3);
        let (_u, s, _v) = CpuLinAlg::svd(&m).unwrap();
        for &sigma in CpuLinAlg::vec_as_slice(&s) {
            assert!(
                sigma.abs() < 1e-12,
                "expected zero singular value, got {sigma}"
            );
        }
    }

    #[test]
    fn test_svd_repeated_singular_values() {
        let m = mat_from_rows(3, 3, &[4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert!((CpuLinAlg::vec_get(&s, 0) - 4.0).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&s, 1) - 4.0).abs() < 1e-10);
        assert!((CpuLinAlg::vec_get(&s, 2) - 2.0).abs() < 1e-10);

        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }

    #[test]
    fn test_svd_16x16_reconstruction() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let m = CpuLinAlg::xavier_mat(16, 16, &mut rng);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        // Looser tolerance: more rotations accumulate more rounding error.
        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-8);
    }

    #[test]
    fn test_svd_returns_ok_for_valid_matrix() {
        let m = mat_from_rows(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0]);
        let result = CpuLinAlg::svd(&m);
        assert!(result.is_ok(), "SVD of valid matrix should return Ok");
        let (u, s, v) = result.unwrap();
        assert_eq!(CpuLinAlg::vec_len(&s), 3);
        assert_eq!(CpuLinAlg::mat_rows(&u), 3);
        assert_eq!(CpuLinAlg::mat_cols(&u), 3);
        assert_eq!(CpuLinAlg::mat_rows(&v), 3);
        assert_eq!(CpuLinAlg::mat_cols(&v), 3);
    }

    #[test]
    fn test_svd_result_reconstruction() {
        let m = mat_from_rows(2, 2, &[5.0, 0.0, 0.0, 3.0]);
        let (u, s, v) = CpuLinAlg::svd(&m).unwrap();
        assert_mat_approx_eq(&reconstruct_usv(&u, &s, &v), &m, 1e-10);
    }
}