// mdarray_linalg/testing/solve/mod.rs

1use approx::assert_relative_eq;
2
3use super::common::random_matrix;
4use crate::solve::{Solve, SolveResult};
5use mdarray::DTensor;
6
7fn test_solve_verification<T>(original_a: &DTensor<T, 2>, x: &DTensor<T, 2>, b: &DTensor<T, 2>)
8where
9    T: Default
10        + std::fmt::Debug
11        + Copy
12        + std::ops::Mul<Output = T>
13        + std::ops::Add<Output = T>
14        + std::ops::Sub<Output = T>,
15    f64: From<T>,
16{
17    let (n, nrhs) = *b.shape();
18
19    let mut ax = DTensor::<T, 2>::zeros([n, nrhs]);
20    for i in 0..n {
21        for j in 0..nrhs {
22            let mut sum = T::default();
23            for k in 0..n {
24                sum = sum + original_a[[i, k]] * x[[k, j]];
25            }
26            ax[[i, j]] = sum;
27        }
28    }
29
30    for i in 0..n {
31        for j in 0..nrhs {
32            let diff = f64::from(ax[[i, j]]) - f64::from(b[[i, j]]);
33            assert_relative_eq!(diff, 0.0, epsilon = 1e-10);
34        }
35    }
36}
37
38pub fn test_solve_single_rhs(bd: &impl Solve<f64>) {
39    let n = 4;
40    let a = random_matrix(n, n);
41    let original_a = a.clone();
42    let b = random_matrix(n, 1);
43
44    let SolveResult { x, .. } = bd.solve(&mut a.clone(), &b).expect("");
45
46    test_solve_verification(&original_a, &x, &b);
47}
48
49pub fn test_solve_multiple_rhs(bd: &impl Solve<f64>) {
50    let n = 5;
51    let nrhs = 3;
52    let mut a = random_matrix(n, n);
53    let original_a = a.clone();
54    let b = random_matrix(n, nrhs);
55
56    let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
57
58    test_solve_verification(&original_a, &x, &b);
59}
60
61pub fn test_solve_overwrite(bd: &impl Solve<f64>) {
62    let n = 4;
63    let nrhs = 2;
64    let mut a = random_matrix(n, n);
65    let original_a = a.clone();
66    let mut b = random_matrix(n, nrhs);
67    let original_b = b.clone();
68    let mut p = DTensor::<f64, 2>::zeros([n, n]);
69
70    let _ = bd.solve_overwrite(&mut a, &mut b, &mut p);
71
72    // b now contains the solution x
73    test_solve_verification(&original_a, &b, &original_b);
74}
75
76pub fn test_solve_identity_matrix(bd: &impl Solve<f64>) {
77    let n = 3;
78    let nrhs = 2;
79
80    let mut a = DTensor::<f64, 2>::zeros([n, n]);
81    for i in 0..n {
82        a[[i, i]] = 1.0;
83    }
84    let original_a = a.clone();
85
86    let b = random_matrix(n, nrhs);
87
88    let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
89
90    for i in 0..n {
91        for j in 0..nrhs {
92            let diff = x[[i, j]] - b[[i, j]];
93            assert_relative_eq!(diff, 0.0, epsilon = 1e-14);
94        }
95    }
96
97    test_solve_verification(&original_a, &x, &b);
98}
99
100pub fn test_solve_complex(bd: &impl Solve<num_complex::Complex<f64>>) {
101    use num_complex::Complex;
102
103    let n = 4;
104    let nrhs = 2;
105
106    let re = random_matrix(n, n);
107    let im = random_matrix(n, n);
108
109    let mut a = DTensor::<Complex<f64>, 2>::from_fn([n, n], |i| {
110        Complex::new(re[[i[0], i[1]]], im[[i[0], i[1]]])
111    });
112    println!("a={a:?}");
113    let original_a = a.clone();
114
115    // Create random complex right-hand side
116    let b = DTensor::<Complex<f64>, 2>::from_fn([n, nrhs], |i| {
117        Complex::new((i[0] + 2 * i[1] + 1) as f64, (2 * i[0] + i[1] + 1) as f64)
118    });
119    println!("b={b:?}");
120
121    let SolveResult { x, .. } = bd.solve(&mut a, &b).expect("");
122
123    println!("{x:?}");
124
125    // Verify A * X = B for complex matrices
126    for i in 0..n {
127        for j in 0..nrhs {
128            let mut sum = Complex::new(0.0, 0.0);
129            for k in 0..n {
130                sum += original_a[[i, k]] * x[[k, j]];
131            }
132            let diff_real = sum.re - b[[i, j]].re;
133            let diff_imag = sum.im - b[[i, j]].im;
134            assert_relative_eq!(diff_real, 0.0, epsilon = 1e-10);
135            assert_relative_eq!(diff_imag, 0.0, epsilon = 1e-10);
136        }
137    }
138}