mdarray_linalg_faer/solve/
context.rs1use faer_traits::ComplexField;
17use mdarray::{DSlice, Layout, tensor};
18use mdarray_linalg::identity;
19use mdarray_linalg::solve::{Solve, SolveError, SolveResult, SolveResultType};
20use num_complex::ComplexFloat;
21
22use faer::linalg::solvers::Solve as FaerSolve;
23
24use crate::{Faer, into_faer, into_faer_mut};
25
26impl<T> Solve<T> for Faer
27where
28 T: ComplexFloat
29 + ComplexField
30 + Default
31 + std::convert::From<<T as num_complex::ComplexFloat>::Real>
32 + 'static,
33{
34 fn solve<La: Layout, Lb: Layout>(
38 &self,
39 a: &mut DSlice<T, 2, La>,
40 b: &DSlice<T, 2, Lb>,
41 ) -> SolveResultType<T> {
42 let (m, n) = *a.shape();
43 let (b_m, b_n) = *b.shape();
44
45 if m != n {
46 return Err(SolveError::InvalidDimensions);
47 }
48
49 if b_m != m {
50 return Err(SolveError::InvalidDimensions);
51 }
52
53 let a_faer = into_faer_mut(a);
54
55 let solver = a_faer.partial_piv_lu();
56
57 let b_faer = into_faer(b);
58 let x_faer = solver.solve(b_faer);
59
60 let mut x_mda = tensor![[T::default(); b_n]; m];
61 let mut x_faer_mut = into_faer_mut(&mut x_mda);
62 for i in 0..m {
63 for j in 0..b_n {
64 x_faer_mut[(i, j)] = x_faer[(i, j)];
65 }
66 }
67
68 let p_mda = identity(m); Ok(SolveResult { x: x_mda, p: p_mda })
71 }
72
73 fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
79 &self,
80 a: &mut DSlice<T, 2, La>,
81 b: &mut DSlice<T, 2, Lb>,
82 p: &mut DSlice<T, 2, Lp>,
83 ) -> Result<(), SolveError> {
84 let (m, n) = *a.shape();
85 let (b_m, b_n) = *b.shape();
86
87 if m != n {
88 return Err(SolveError::InvalidDimensions);
89 }
90
91 if b_m != m {
92 return Err(SolveError::InvalidDimensions);
93 }
94
95 let _par = faer::get_global_parallelism();
96 let a_faer = into_faer(a);
97
98 let solver = a_faer.partial_piv_lu();
99
100 let b_faer = into_faer(b).to_owned();
101 let x_faer = solver.solve(b_faer);
102
103 let mut b_faer_mut = into_faer_mut(b);
104 for i in 0..m {
105 for j in 0..b_n {
106 b_faer_mut[(i, j)] = x_faer[(i, j)];
107 }
108 }
109
110 let mut p_faer = into_faer_mut(p);
111 for i in 0..m {
112 for j in 0..m {
113 if i != j {
114 p_faer[(i, j)] = T::zero();
115 } else {
116 p_faer[(i, j)] = T::one();
117 }
118 }
119 }
120
121 Ok(())
122 }
123}