1use dyn_stack::{MemBuffer, MemStack};
10
11use super::simple::lu_faer;
12use faer_traits::ComplexField;
13use mdarray::{DSlice, DTensor, Layout, tensor};
14use mdarray_linalg::lu::{InvError, InvResult, LU};
15use num_complex::ComplexFloat;
16
17use crate::{Faer, into_faer_mut, into_mdarray};
18
19impl<T> LU<T> for Faer
20where
21 T: ComplexFloat
22 + ComplexField
23 + Default
24 + std::convert::From<<T as num_complex::ComplexFloat>::Real>
25 + 'static,
26{
27 fn lu<L: Layout>(
29 &self,
30 a: &mut DSlice<T, 2, L>,
31 ) -> (DTensor<T, 2>, DTensor<T, 2>, DTensor<T, 2>) {
32 let (m, n) = *a.shape();
33 let min_mn = m.min(n);
34 let mut l_mda = tensor![[T::default(); min_mn]; m ];
35 let mut u_mda = tensor![[T::default(); n ]; min_mn];
36 let mut p_mda = tensor![[T::default(); m]; m];
37
38 lu_faer(a, &mut l_mda, &mut u_mda, &mut p_mda);
39
40 (l_mda, u_mda, p_mda)
41 }
42
43 fn lu_overwrite<L: Layout, Ll: Layout, Lu: Layout, Lp: Layout>(
45 &self,
46 a: &mut DSlice<T, 2, L>,
47 l: &mut DSlice<T, 2, Ll>,
48 u: &mut DSlice<T, 2, Lu>,
49 p: &mut DSlice<T, 2, Lp>,
50 ) {
51 lu_faer::<T, L, Ll, Lu, Lp>(a, l, u, p);
52 }
53
54 fn inv<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
56 let (m, n) = *a.shape();
57
58 if m != n {
59 return Err(InvError::NotSquare {
60 rows: m as i32,
61 cols: n as i32,
62 });
63 }
64
65 let par = faer::get_global_parallelism();
66 let mut a_faer = into_faer_mut(a);
67
68 let mut row_perm_fwd = vec![0usize; m];
69 let mut row_perm_bwd = vec![0usize; m];
70
71 faer::linalg::lu::partial_pivoting::factor::lu_in_place(
72 a_faer.as_mut(),
73 &mut row_perm_fwd,
74 &mut row_perm_bwd,
75 par,
76 MemStack::new(&mut MemBuffer::new(
77 faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(
78 m,
79 n,
80 par,
81 faer::prelude::default(),
82 ),
83 )),
84 faer::prelude::default(),
85 );
86
87 let l_mat = a_faer.as_ref();
88 let u_mat = a_faer.as_ref();
89
90 let perm = unsafe {
91 faer::perm::Perm::new_unchecked(
92 row_perm_fwd.into_boxed_slice(),
93 row_perm_bwd.into_boxed_slice(),
94 )
95 };
96
97 let mut inv_mat = faer::Mat::<T>::zeros(m, n);
98
99 faer::linalg::lu::partial_pivoting::inverse::inverse(
100 inv_mat.as_mut(),
101 l_mat,
102 u_mat,
103 perm.as_ref(),
104 par,
105 MemStack::new(&mut MemBuffer::new(
106 faer::linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(m, par),
107 )),
108 );
109 Ok(into_mdarray(inv_mat))
110 }
111
112 fn inv_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
114 let (m, n) = *a.shape();
115
116 if m != n {
117 return Err(InvError::NotSquare {
118 rows: m as i32,
119 cols: n as i32,
120 });
121 }
122
123 let par = faer::get_global_parallelism();
124 let mut a_faer = into_faer_mut(a);
125
126 let mut row_perm_fwd = vec![0usize; m];
127 let mut row_perm_bwd = vec![0usize; m];
128
129 faer::linalg::lu::partial_pivoting::factor::lu_in_place(
130 a_faer.as_mut(),
131 &mut row_perm_fwd,
132 &mut row_perm_bwd,
133 par,
134 MemStack::new(&mut MemBuffer::new(
135 faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(
136 m,
137 n,
138 par,
139 faer::prelude::default(),
140 ),
141 )),
142 faer::prelude::default(),
143 );
144
145 let l_mat = a_faer.as_ref();
146 let u_mat = a_faer.as_ref();
147
148 let perm = unsafe {
149 faer::perm::Perm::new_unchecked(
150 row_perm_fwd.into_boxed_slice(),
151 row_perm_bwd.into_boxed_slice(),
152 )
153 };
154
155 let mut inv_mat = faer::Mat::<T>::zeros(m, n);
156
157 faer::linalg::lu::partial_pivoting::inverse::inverse(
158 inv_mat.as_mut(),
159 l_mat,
160 u_mat,
161 perm.as_ref(),
162 par,
163 MemStack::new(&mut MemBuffer::new(
164 faer::linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(m, par),
165 )),
166 );
167
168 for i in 0..m {
169 for j in 0..n {
170 a_faer[(i, j)] = inv_mat[(i, j)];
171 }
172 }
173
174 Ok(())
175 }
176
177 fn det<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> T {
179 let (m, n) = *a.shape();
180 assert_eq!(m, n, "determinant is only defined for square matrices");
181 let a_faer = into_faer_mut(a);
182 a_faer.determinant()
183 }
184
185 fn choleski<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> InvResult<T> {
187 todo!("choleski will be implemented later")
188 }
189
190 fn choleski_overwrite<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
192 todo!("choleski_overwrite will be implemented later")
193 }
194}