//! Backend-agnostic QR decomposition test routines (mdarray_linalg/testing/qr/mod.rs).

use approx::assert_relative_eq;
use mdarray::DTensor;
use num_complex::Complex;
use rand::prelude::*;

use super::common::naive_matmul;
use crate::pretty_print;
use crate::qr::QR;
use crate::{assert_complex_matrix_eq, assert_matrix_eq};

13pub fn test_qr_random_matrix(bd: &impl QR<f64>) {
14    let (m, n) = (5, 5);
15    let mut rng = rand::rng();
16
17    let a = DTensor::<f64, 2>::from_fn([m, n], |_| rng.random::<f64>());
18    test_qr_reconstruction(bd, &a);
19}
20
21pub fn test_qr_structured_matrix(bd: &impl QR<f64>) {
22    let (m, n) = (3, 3);
23
24    let a = DTensor::<f64, 2>::from_fn([m, n], |i| (i[0] * i[1] + 1) as f64);
25    test_qr_reconstruction(bd, &a);
26}
27
28pub fn test_qr_complex_matrix(bd: &impl QR<Complex<f64>>) {
29    let (m, n) = (3, 3);
30
31    let mut a = DTensor::<Complex<f64>, 2>::from_fn([m, n], |i| {
32        Complex::new((i[0] + 1) as f64, (i[1] + 1) as f64)
33    });
34
35    a[[1, 2]] = Complex::new(1., 5.); // destroy symmetry
36
37    let mut q = DTensor::<Complex<f64>, 2>::zeros([m, m]);
38    let mut r = DTensor::<Complex<f64>, 2>::zeros([m, n]);
39
40    bd.qr_overwrite(&mut a.clone(), &mut q, &mut r);
41    let reconstructed = naive_matmul(&q, &r);
42    assert_complex_matrix_eq!(a, reconstructed);
43
44    let (q, r) = bd.qr(&mut a.clone());
45    let reconstructed = naive_matmul(&q, &r);
46    assert_complex_matrix_eq!(a, reconstructed);
47
48    pretty_print(&a);
49    pretty_print(&reconstructed);
50}
51
52pub fn test_qr_reconstruction<T>(bd: &impl QR<T>, a: &DTensor<T, 2>)
53where
54    T: num_traits::float::FloatConst
55        + Default
56        + Copy
57        + std::fmt::Debug
58        + approx::AbsDiffEq<Epsilon = f64>
59        + std::fmt::Display
60        + approx::RelativeEq
61        + num_traits::Float
62        + std::convert::From<i8>,
63{
64    let (m, n) = *a.shape();
65    let mut q = DTensor::<T, 2>::zeros([m, m]);
66    let mut r = DTensor::<T, 2>::zeros([m, n]);
67
68    bd.qr_overwrite(&mut a.clone(), &mut q, &mut r);
69    let reconstructed = naive_matmul(&q, &r);
70
71    pretty_print(&q);
72    pretty_print(&r);
73
74    pretty_print(a);
75    pretty_print(&reconstructed);
76
77    assert_matrix_eq!(a, reconstructed);
78
79    let (q, r) = bd.qr(&mut a.clone());
80    let reconstructed = naive_matmul(&q, &r);
81
82    pretty_print(a);
83    pretty_print(&reconstructed);
84
85    assert_matrix_eq!(a, reconstructed);
86}