mdarray_linalg/testing/svd/mod.rs

1use approx::assert_relative_eq;
2use num_complex::Complex;
3
4use crate::{assert_complex_matrix_eq, assert_matrix_eq};
5use mdarray::DTensor;
6
7use crate::pretty_print;
8use crate::svd::{SVD, SVDDecomp};
9
10use super::common::naive_matmul;
11
12use num_complex::ComplexFloat;
13use rand::Rng;
14
15fn test_svd_reconstruction<T>(bd: &impl SVD<T>, a: &DTensor<T, 2>, debug_print: bool)
16where
17    T: ComplexFloat<Real = f64>
18        + Default
19        + Copy
20        + std::fmt::Debug
21        + approx::AbsDiffEq<Epsilon = T::Real>
22        + std::fmt::Display
23        + approx::RelativeEq,
24    T::Real: std::fmt::Display,
25{
26    let (m, n) = (a.shape().0, a.shape().1);
27    let min_dim = m.min(n);
28
29    let SVDDecomp { s, u, vt } = bd.svd(&mut a.clone()).expect("SVD failed");
30
31    assert_eq!(*s.shape(), (n, n));
32    assert_eq!(*u.shape(), (m, m));
33    assert_eq!(*vt.shape(), (n, n));
34
35    let mut sigma = DTensor::<T, 2>::zeros([m, n]);
36    for i in 0..min_dim {
37        sigma[[i, i]] = s[[0, i]];
38    }
39
40    if debug_print {
41        println!("=== Σ (Sigma) ===");
42        pretty_print(&sigma);
43        println!("=== U ===");
44        pretty_print(&u);
45        println!("=== Vᵀ ===");
46        pretty_print(&vt);
47    }
48
49    let us = naive_matmul(&u, &sigma);
50    if debug_print {
51        println!("=== U × Σ ===");
52        pretty_print(&us);
53    }
54
55    let usvt = naive_matmul(&us, &vt);
56    if debug_print {
57        println!("=== U × Σ × Vᵀ  ===");
58        pretty_print(&usvt);
59        println!("=== A original ===");
60        pretty_print(a);
61    }
62
63    assert_matrix_eq!(*a, usvt);
64}
65
66pub fn test_svd_square_matrix(bd: &impl SVD<f64>) {
67    let n = 3;
68    let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * i[1]) as f64);
69    test_svd_reconstruction(bd, &a, true);
70}
71
72pub fn test_svd_rectangular_m_gt_n(bd: &impl SVD<f64>) {
73    let (m, n) = (4, 3);
74    let a = DTensor::<f64, 2>::from_fn([m, n], |i| (i[0] * i[1]) as f64);
75    test_svd_reconstruction(bd, &a, true);
76}
77
78pub fn test_svd_big_square_matrix(bd: &impl SVD<f64>) {
79    let n = 200;
80    let a = DTensor::<f64, 2>::from_fn([n, n], |i| (i[0] * i[1]) as f64);
81    test_svd_reconstruction(bd, &a, false);
82}
83
84pub fn test_svd_random_matrix(bd: &impl SVD<f64>) {
85    let mut rng = rand::rng();
86    let n = 10;
87    let a = DTensor::<f64, 2>::from_fn([n, n], |_| rng.random::<f64>());
88    test_svd_reconstruction(bd, &a, false);
89}
90
91pub fn test_svd_cplx_square_matrix(bd: &impl SVD<Complex<f64>>) {
92    let n = 3;
93    let a = DTensor::<Complex<f64>, 2>::from_fn([n, n], |i| {
94        Complex::new((i[0] * i[1]) as f64, i[1] as f64)
95    });
96
97    let SVDDecomp { s, u, vt } = bd.svd(&mut a.clone()).expect("SVD failed");
98
99    assert_eq!(*s.shape(), (n, n));
100    assert_eq!(*u.shape(), (n, n));
101    assert_eq!(*vt.shape(), (n, n));
102
103    let mut sigma = DTensor::<Complex<f64>, 2>::zeros([n, n]);
104    for i in 0..n {
105        sigma[[i, i]] = s[[0, i]];
106    }
107
108    println!("=== Σ (Sigma) ===");
109    pretty_print(&sigma);
110    println!("=== U ===");
111    pretty_print(&u);
112    println!("=== Vᵀ ===");
113    pretty_print(&vt);
114
115    let us = naive_matmul(&u, &sigma);
116    println!("=== U × Σ ===");
117    pretty_print(&us);
118    let usvt = naive_matmul(&us, &vt);
119    println!("=== U × Σ × Vᵀ  ===");
120    pretty_print(&usvt);
121    println!("=== A original ===");
122    pretty_print(&a);
123
124    assert_complex_matrix_eq!(a, usvt);
125}