mdarray_linalg/testing/solve/
mod.rs1use approx::assert_relative_eq;
2
3use super::common::random_matrix;
4use crate::solve::{Solve, SolveResult};
5use mdarray::DTensor;
6
7fn test_solve_verification<T>(original_a: &DTensor<T, 2>, x: &DTensor<T, 2>, b: &DTensor<T, 2>)
8where
9 T: Default
10 + std::fmt::Debug
11 + Copy
12 + std::ops::Mul<Output = T>
13 + std::ops::Add<Output = T>
14 + std::ops::Sub<Output = T>,
15 f64: From<T>,
16{
17 let (n, nrhs) = *b.shape();
18
19 let mut ax = DTensor::<T, 2>::zeros([n, nrhs]);
20 for i in 0..n {
21 for j in 0..nrhs {
22 let mut sum = T::default();
23 for k in 0..n {
24 sum = sum + original_a[[i, k]] * x[[k, j]];
25 }
26 ax[[i, j]] = sum;
27 }
28 }
29
30 for i in 0..n {
31 for j in 0..nrhs {
32 let diff = f64::from(ax[[i, j]]) - f64::from(b[[i, j]]);
33 assert_relative_eq!(diff, 0.0, epsilon = 1e-10);
34 }
35 }
36}
37
38pub fn test_solve_single_rhs(bd: &impl Solve<f64>) {
39 let n = 4;
40 let a = random_matrix(n, n);
41 let original_a = a.clone();
42 let b = random_matrix(n, 1);
43
44 let SolveResult { x, .. } = bd.solve(&mut a.clone(), &b).expect("");
45
46 test_solve_verification(&original_a, &x, &b);
47}
48
49pub fn test_solve_multiple_rhs(bd: &impl Solve<f64>) {
50 let n = 5;
51 let nrhs = 3;
52 let mut a = random_matrix(n, n);
53 let original_a = a.clone();
54 let b = random_matrix(n, nrhs);
55
56 let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
57
58 test_solve_verification(&original_a, &x, &b);
59}
60
61pub fn test_solve_overwrite(bd: &impl Solve<f64>) {
62 let n = 4;
63 let nrhs = 2;
64 let mut a = random_matrix(n, n);
65 let original_a = a.clone();
66 let mut b = random_matrix(n, nrhs);
67 let original_b = b.clone();
68 let mut p = DTensor::<f64, 2>::zeros([n, n]);
69
70 let _ = bd.solve_overwrite(&mut a, &mut b, &mut p);
71
72 test_solve_verification(&original_a, &b, &original_b);
74}
75
76pub fn test_solve_identity_matrix(bd: &impl Solve<f64>) {
77 let n = 3;
78 let nrhs = 2;
79
80 let mut a = DTensor::<f64, 2>::zeros([n, n]);
81 for i in 0..n {
82 a[[i, i]] = 1.0;
83 }
84 let original_a = a.clone();
85
86 let b = random_matrix(n, nrhs);
87
88 let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
89
90 for i in 0..n {
91 for j in 0..nrhs {
92 let diff = x[[i, j]] - b[[i, j]];
93 assert_relative_eq!(diff, 0.0, epsilon = 1e-14);
94 }
95 }
96
97 test_solve_verification(&original_a, &x, &b);
98}
99
100pub fn test_solve_complex(bd: &impl Solve<num_complex::Complex<f64>>) {
101 use num_complex::Complex;
102
103 let n = 4;
104 let nrhs = 2;
105
106 let re = random_matrix(n, n);
107 let im = random_matrix(n, n);
108
109 let mut a = DTensor::<Complex<f64>, 2>::from_fn([n, n], |i| {
110 Complex::new(re[[i[0], i[1]]], im[[i[0], i[1]]])
111 });
112 println!("a={a:?}");
113 let original_a = a.clone();
114
115 let b = DTensor::<Complex<f64>, 2>::from_fn([n, nrhs], |i| {
117 Complex::new((i[0] + 2 * i[1] + 1) as f64, (2 * i[0] + i[1] + 1) as f64)
118 });
119 println!("b={b:?}");
120
121 let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
122
123 println!("{x:?}");
124
125 for i in 0..n {
127 for j in 0..nrhs {
128 let mut sum = Complex::new(0.0, 0.0);
129 for k in 0..n {
130 sum += original_a[[i, k]] * x[[k, j]];
131 }
132 let diff_real = sum.re - b[[i, j]].re;
133 let diff_imag = sum.im - b[[i, j]].im;
134 assert_relative_eq!(diff_real, 0.0, epsilon = 1e-10);
135 assert_relative_eq!(diff_imag, 0.0, epsilon = 1e-10);
136 }
137 }
138}