mdarray_linalg/testing/svd/
mod.rs1use approx::assert_relative_eq;
2use num_complex::Complex;
3
4use crate::{assert_complex_matrix_eq, assert_matrix_eq};
5use mdarray::DTensor;
6
7use crate::pretty_print;
8use crate::svd::{SVD, SVDDecomp};
9
10use super::common::naive_matmul;
11
12use num_complex::ComplexFloat;
13use rand::Rng;
14
15fn test_svd_reconstruction<T>(bd: &impl SVD<T>, a: &DTensor<T, 2>, debug_print: bool)
16where
17 T: ComplexFloat<Real = f64>
18 + Default
19 + Copy
20 + std::fmt::Debug
21 + approx::AbsDiffEq<Epsilon = T::Real>
22 + std::fmt::Display
23 + approx::RelativeEq,
24 T::Real: std::fmt::Display,
25{
26 let (m, n) = (a.shape().0, a.shape().1);
27 let min_dim = m.min(n);
28
29 let SVDDecomp { s, u, vt } = bd.svd(&mut a.clone()).expect("SVD failed");
30
31 assert_eq!(*s.shape(), (n, n));
32 assert_eq!(*u.shape(), (m, m));
33 assert_eq!(*vt.shape(), (n, n));
34
35 let mut sigma = DTensor::<T, 2>::zeros([m, n]);
36 for i in 0..min_dim {
37 sigma[[i, i]] = s[[0, i]];
38 }
39
40 if debug_print {
41 println!("=== Σ (Sigma) ===");
42 pretty_print(&sigma);
43 println!("=== U ===");
44 pretty_print(&u);
45 println!("=== Vᵀ ===");
46 pretty_print(&vt);
47 }
48
49 let us = naive_matmul(&u, &sigma);
50 if debug_print {
51 println!("=== U × Σ ===");
52 pretty_print(&us);
53 }
54
55 let usvt = naive_matmul(&us, &vt);
56 if debug_print {
57 println!("=== U × Σ × Vᵀ ===");
58 pretty_print(&usvt);
59 println!("=== A original ===");
60 pretty_print(a);
61 }
62
63 assert_matrix_eq!(*a, usvt);
64}
65
66pub fn test_svd_square_matrix(bd: &impl SVD<f64>) {
67 let n = 3;
68 let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * i[1]) as f64);
69 test_svd_reconstruction(bd, &a, true);
70}
71
72pub fn test_svd_rectangular_m_gt_n(bd: &impl SVD<f64>) {
73 let (m, n) = (4, 3);
74 let a = DTensor::<f64, 2>::from_fn([m, n], |i| (i[0] * i[1]) as f64);
75 test_svd_reconstruction(bd, &a, true);
76}
77
78pub fn test_svd_big_square_matrix(bd: &impl SVD<f64>) {
79 let n = 200;
80 let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * i[1]) as f64);
81 test_svd_reconstruction(bd, &a, false);
82}
83
84pub fn test_svd_random_matrix(bd: &impl SVD<f64>) {
85 let mut rng = rand::rng();
86 let n = 10;
87 let a = DTensor::<f64, 2>::from_fn([n, n], |_| rng.random::<f64>());
88 test_svd_reconstruction(bd, &a, false);
89}
90
91pub fn test_svd_cplx_square_matrix(bd: &impl SVD<Complex<f64>>) {
92 let n = 3;
93 let a = DTensor::<Complex<f64>, 2>::from_fn([n, n], |i| {
94 Complex::new((i[0] * i[1]) as f64, i[1] as f64)
95 });
96
97 let SVDDecomp { s, u, vt } = bd.svd(&mut a.clone()).expect("SVD failed");
98
99 assert_eq!(*s.shape(), (n, n));
100 assert_eq!(*u.shape(), (n, n));
101 assert_eq!(*vt.shape(), (n, n));
102
103 let mut sigma = DTensor::<Complex<f64>, 2>::zeros([n, n]);
104 for i in 0..n {
105 sigma[[i, i]] = s[[0, i]];
106 }
107
108 println!("=== Σ (Sigma) ===");
109 pretty_print(&sigma);
110 println!("=== U ===");
111 pretty_print(&u);
112 println!("=== Vᵀ ===");
113 pretty_print(&vt);
114
115 let us = naive_matmul(&u, &sigma);
116 println!("=== U × Σ ===");
117 pretty_print(&us);
118 let usvt = naive_matmul(&us, &vt);
119 println!("=== U × Σ × Vᵀ ===");
120 pretty_print(&usvt);
121 println!("=== A original ===");
122 pretty_print(&a);
123
124 assert_complex_matrix_eq!(a, usvt);
125}