bem/core/solver/
direct.rs1use ndarray::{Array1, Array2};
6use num_complex::Complex64;
7
8#[derive(Debug)]
10pub struct DirectSolution {
11 pub x: Array1<Complex64>,
13 pub success: bool,
15}
16
17pub fn direct_solve(a: &Array2<Complex64>, b: &Array1<Complex64>) -> DirectSolution {
34 use ndarray_linalg::Solve;
35
36 match a.solve(b) {
37 Ok(x) => DirectSolution { x, success: true },
38 Err(_) => DirectSolution {
39 x: Array1::zeros(b.len()),
40 success: false,
41 },
42 }
43}
44
45pub fn direct_solve_lu(a: &Array2<Complex64>, b: &Array1<Complex64>) -> DirectSolution {
49 direct_solve(a, b)
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum accepted Euclidean norm of a residual / error vector.
    const TOL: f64 = 1e-10;

    /// Shorthand for building a complex scalar.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Builds a 2×2 matrix from its four entries in row-major order.
    fn mat2(entries: [Complex64; 4]) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).unwrap()
    }

    /// Euclidean norm of `u - v` (both vectors must have the same length).
    fn distance(u: &Array1<Complex64>, v: &Array1<Complex64>) -> f64 {
        u.iter()
            .zip(v.iter())
            .map(|(p, q)| (*p - *q).norm_sqr())
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn test_direct_solve_simple() {
        let a = mat2([c(2.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(3.0, 0.0)]);
        let b = Array1::from(vec![c(4.0, 0.0), c(5.0, 0.0)]);

        let sol = direct_solve(&a, &b);

        assert!(sol.success);
        assert!(distance(&a.dot(&sol.x), &b) < TOL);
    }

    #[test]
    fn test_direct_solve_complex() {
        let a = mat2([c(1.0, 1.0), c(0.0, 1.0), c(1.0, 0.0), c(1.0, -1.0)]);
        let b = Array1::from(vec![c(2.0, 1.0), c(1.0, 0.0)]);

        let sol = direct_solve(&a, &b);

        assert!(sol.success);
        assert!(distance(&a.dot(&sol.x), &b) < TOL);
    }

    #[test]
    fn test_direct_solve_identity() {
        let n = 5;
        let a = Array2::<Complex64>::eye(n);
        let b: Array1<Complex64> = (1..=n).map(|k| c(k as f64, 0.0)).collect();

        let sol = direct_solve(&a, &b);

        assert!(sol.success);
        // With A = I the solution must reproduce the right-hand side.
        assert!(distance(&sol.x, &b) < TOL);
    }
}