mdarray_linalg/testing/qr/
mod.rs1use mdarray::DTensor;
2
3use crate::qr::QR;
4
5use approx::assert_relative_eq;
6use num_complex::Complex;
7use rand::prelude::*;
8
9use super::common::naive_matmul;
10use crate::{assert_complex_matrix_eq, assert_matrix_eq};
11use crate::pretty_print;
12
13pub fn test_qr_random_matrix(bd: &impl QR<f64>) {
14 let (m, n) = (5, 5);
15 let mut rng = rand::rng();
16
17 let a = DTensor::<f64, 2>::from_fn([m, n], |_| rng.random::<f64>());
18 test_qr_reconstruction(bd, &a);
19}
20
21pub fn test_qr_structured_matrix(bd: &impl QR<f64>) {
22 let (m, n) = (3, 3);
23
24 let a = DTensor::<f64, 2>::from_fn([m, n], |i| (i[0] * i[1] + 1) as f64);
25 test_qr_reconstruction(bd, &a);
26}
27
28pub fn test_qr_complex_matrix(bd: &impl QR<Complex<f64>>) {
29 let (m, n) = (3, 3);
30
31 let mut a = DTensor::<Complex<f64>, 2>::from_fn([m, n], |i| {
32 Complex::new((i[0] + 1) as f64, (i[1] + 1) as f64)
33 });
34
35 a[[1, 2]] = Complex::new(1., 5.); let mut q = DTensor::<Complex<f64>, 2>::zeros([m, m]);
38 let mut r = DTensor::<Complex<f64>, 2>::zeros([m, n]);
39
40 bd.qr_overwrite(&mut a.clone(), &mut q, &mut r);
41 let reconstructed = naive_matmul(&q, &r);
42 assert_complex_matrix_eq!(a, reconstructed);
43
44 let (q, r) = bd.qr(&mut a.clone());
45 let reconstructed = naive_matmul(&q, &r);
46 assert_complex_matrix_eq!(a, reconstructed);
47
48 pretty_print(&a);
49 pretty_print(&reconstructed);
50}
51
52pub fn test_qr_reconstruction<T>(bd: &impl QR<T>, a: &DTensor<T, 2>)
53where
54 T: num_traits::float::FloatConst
55 + Default
56 + Copy
57 + std::fmt::Debug
58 + approx::AbsDiffEq<Epsilon = f64>
59 + std::fmt::Display
60 + approx::RelativeEq
61 + num_traits::Float
62 + std::convert::From<i8>,
63{
64 let (m, n) = *a.shape();
65 let mut q = DTensor::<T, 2>::zeros([m, m]);
66 let mut r = DTensor::<T, 2>::zeros([m, n]);
67
68 bd.qr_overwrite(&mut a.clone(), &mut q, &mut r);
69 let reconstructed = naive_matmul(&q, &r);
70
71 pretty_print(&q);
72 pretty_print(&r);
73
74 pretty_print(a);
75 pretty_print(&reconstructed);
76
77 assert_matrix_eq!(a, reconstructed);
78
79 let (q, r) = bd.qr(&mut a.clone());
80 let reconstructed = naive_matmul(&q, &r);
81
82 pretty_print(a);
83 pretty_print(&reconstructed);
84
85 assert_matrix_eq!(a, reconstructed);
86}