// mdarray_linalg_faer/eig/context.rs

// Eigenvalue decomposition:
//     A * V = V * Λ        (right eigenvectors)
//     W^H * A = Λ * W^H    (left eigenvectors)
// where:
//     - A is n × n         (input square matrix)
//     - V is n × n         (right eigenvectors as columns)
//     - W is n × n         (left eigenvectors as columns)
//     - Λ is n × n         (diagonal matrix of eigenvalues)
//
// For Hermitian/symmetric matrices:
//     A = Q * Λ * Q^H
// where:
//     - Q is n × n         (orthogonal/unitary matrix of eigenvectors)
//     - Λ is n × n         (diagonal matrix of real eigenvalues)
//
// Schur decomposition:
//     A = Z * T * Z^H
// where:
//     - Z is n × n         (orthogonal/unitary matrix of Schur vectors)
//     - T is n × n         (upper triangular for complex, quasi-upper triangular for real)
21
22use faer_traits::ComplexField;
23use mdarray::{DSlice, Dense, Layout, tensor};
24use mdarray_linalg::eig::{Eig, EigDecomp, EigError, EigResult, SchurError, SchurResult};
25use num_complex::{Complex, ComplexFloat};
26
27use crate::{Faer, into_faer, into_faer_mut};
28
/// Converts a faer complex scalar into `Complex<<$t as ComplexFloat>::Real>`.
///
/// `$val` may be any expression whose result (or referent) has public `re` and
/// `im` fields, e.g. a faer complex value or a reference to one. It is
/// evaluated exactly once.
///
/// The fields are reinterpreted with `std::mem::transmute_copy` because the
/// generic bounds do not let the compiler prove that faer's real scalar type and
/// `<$t as ComplexFloat>::Real` are the same type.
macro_rules! complex_from_faer {
    ($val:expr, $t:ty) => {{
        // Bind once so `$val` (typically an indexing expression) is not evaluated
        // a second time for the imaginary part.
        let val = &$val;
        debug_assert_eq!(
            std::mem::size_of_val(&val.re),
            std::mem::size_of::<<$t as ComplexFloat>::Real>(),
            "complex_from_faer: source and target real types differ in size"
        );
        // SAFETY: assumes the field type of `val.re` / `val.im` is the very same
        // real scalar as `<$t as ComplexFloat>::Real` (f32/f64 for the scalar types
        // this backend is used with) — TODO confirm when adding a new scalar type.
        // The debug assertion above checks that the two sizes agree.
        let re: <$t as ComplexFloat>::Real = unsafe { std::mem::transmute_copy(&val.re) };
        let im: <$t as ComplexFloat>::Real = unsafe { std::mem::transmute_copy(&val.im) };
        Complex::new(re, im)
    }};
}
36
37impl<T> Eig<T> for Faer
38where
39    T: ComplexFloat
40        + ComplexField
41        + Default
42        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
43        + 'static,
44{
45    /// Compute eigenvalues and right eigenvectors with new allocated matrices
46    /// The matrix `A` satisfies: `A * v = λ * v` where v are the right eigenvectors
47    fn eig<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
48        let (m, n) = *a.shape();
49
50        if m != n {
51            return Err(EigError::NotSquareMatrix);
52        }
53
54        let a_faer = into_faer(a);
55        let eig_result = a_faer.eigen();
56
57        match eig_result {
58            Ok(eig) => {
59                let eigenvalues = eig.S();
60                let right_vecs = eig.U();
61
62                let x = T::default();
63                let mut eigenvalues_mda = tensor![[Complex::new(x.re(), x.re()); n]; 1];
64                let mut right_vecs_mda = tensor![[Complex::new(x.re(), x.re());n]; n];
65
66                for i in 0..n {
67                    eigenvalues_mda[[0, i]] = complex_from_faer!(&eigenvalues[i], T);
68                }
69
70                for i in 0..n {
71                    for j in 0..n {
72                        right_vecs_mda[[i, j]] = complex_from_faer!(&right_vecs[(i, j)], T);
73                    }
74                }
75
76                Ok(EigDecomp {
77                    eigenvalues: eigenvalues_mda,
78                    left_eigenvectors: None,
79                    right_eigenvectors: Some(right_vecs_mda),
80                })
81            }
82            Err(_) => Err(EigError::BackendDidNotConverge { iterations: 0 }),
83        }
84    }
85
86    // /// Compute eigenvalues and both left/right eigenvectors with new allocated matrices
87    // /// The matrix A satisfies: `A * vr = λ * vr` and `vl^H * A = λ * vl^H`
88    // /// where `vr` are right eigenvectors and `vl` are left eigenvectors
89    // fn eig_full<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
90    //     let (m, n) = *a.shape();
91
92    //     if m != n {
93    //         return Err(EigError::NotSquareMatrix);
94    //     }
95
96    //     let a_faer = into_faer(a);
97
98    //     let eig_result = a_faer.eigen();
99
100    //     match eig_result {
101    //         Ok(eig) => {
102    //             let eigenvalues = eig.S();
103    //             let right_vecs = eig.U();
104
105    //             let x = T::default();
106    //             let mut eigenvalues_mda = tensor![[Complex::new(x.re(), x.re()); n]; 1];
107    //             let mut right_vecs_mda = tensor![[Complex::new(x.re(), x.re());n]; n];
108
109    //             for i in 0..n {
110    //                 eigenvalues_mda[[0, i]] = complex_from_faer!(&eigenvalues[i], T);
111    //             }
112
113    //             let mut right_vecs_faer = into_faer_mut(&mut right_vecs_mda);
114    //             for i in 0..n {
115    //                 for j in 0..n {
116    //                     right_vecs_faer[(i, j)] = complex_from_faer!(right_vecs[(i, j)], T);
117    //                 }
118    //             }
119
120    //             let mut left_vecs_mda = tensor![[Complex::new(x.re(), x.re());n]; n];
121
122    //             let mut left_vecs_faer = into_faer_mut(&mut left_vecs_mda);
123    //             for i in 0..n {
124    //                 for j in 0..n {
125    //                     left_vecs_faer[(i, j)] = complex_from_faer!(right_vecs[(i, j)].conj(), T);
126    //                 }
127    //             }
128
129    //             Ok(EigDecomp {
130    //                 eigenvalues: eigenvalues_mda,
131    //                 left_eigenvectors: Some(left_vecs_mda),
132    //                 right_eigenvectors: Some(right_vecs_mda),
133    //             })
134    //         }
135    //         Err(_) => Err(EigError::BackendDidNotConverge { iterations: 0 }),
136    //     }
137    // }
138
139    fn eig_full<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> Result<EigDecomp<T>, EigError> {
140        todo!();
141        // let (m, n) = *a.shape();
142        // if m != n {
143        //     return Err(EigError::NotSquareMatrix);
144        // }
145
146        // let par = faer::get_global_parallelism();
147
148        // let x = T::default();
149
150        // let mut eigenvalues_mda = tensor![[Complex::new(x.re(), x.re()); n]; 1];
151        // let mut right_vecs_mda = tensor![[Complex::new(x.re(), x.re()); n]; n];
152        // let mut left_vecs_mda = tensor![[Complex::new(x.re(), x.re()); n]; n];
153
154        // let a_faer = into_faer_mut(a);
155
156        // let params = <faer::linalg::evd::EvdParams as faer::Auto<T>>::auto();
157
158        // // let eig_result = if TypeId::of::<T>() == TypeId::of::<Complex<f32>>()
159        // //     || TypeId::of::<T>() == TypeId::of::<Complex<f64>>()
160        // // {
161        // let eig_result = if true {
162        //     let mut eigenvalues_faer = into_faer_mut(&mut eigenvalues_mda);
163        //     let mut right_vecs_faer = into_faer_mut(&mut right_vecs_mda);
164        //     let mut left_vecs_faer = into_faer_mut(&mut left_vecs_mda);
165
166        //     let a_faer_complex: MatRef<'_, Complex<<T as faer::traits::ComplexField>::Real>> = unsafe {
167        //         faer::hacks::coerce::<_, MatRef<'_, Complex<<T as faer::traits::ComplexField>::Real>>>(
168        //             a_faer,
169        //         )
170        //     };
171
172        //     let mut left_vecs_complex: MatMut<
173        //         '_,
174        //         Complex<<T as faer::traits::ComplexField>::Real>,
175        //     > = unsafe {
176        //         faer::hacks::coerce::<_, MatMut<'_, Complex<<T as faer::traits::ComplexField>::Real>>>(
177        //             left_vecs_faer,
178        //         )
179        //     };
180
181        //     let mut right_vecs_complex: MatMut<
182        //         '_,
183        //         Complex<<T as faer::traits::ComplexField>::Real>,
184        //     > = unsafe {
185        //         faer::hacks::coerce::<_, MatMut<'_, Complex<<T as faer::traits::ComplexField>::Real>>>(
186        //             right_vecs_faer,
187        //         )
188        //     };
189
190        //     let mut stack_buf = MemBuffer::new(faer::linalg::evd::evd_scratch::<T>(
191        //         n,
192        //         ComputeEigenvectors::Yes,
193        //         ComputeEigenvectors::Yes,
194        //         par,
195        //         params.into(),
196        //     ));
197        //     let stack = MemStack::new(&mut stack_buf);
198
199        //     let col0 = eigenvalues_faer.col_mut(0);
200
201        //     let col0_as_matmut: MatMut<'_, Complex<<T as faer::traits::ComplexField>::Real>> = unsafe {
202        //         faer::hacks::coerce::<_, MatMut<'_, Complex<<T as faer::traits::ComplexField>::Real>>>(
203        //             col0,
204        //         )
205        //     };
206
207        //     let diag_mut = col0_as_matmut.diagonal_mut();
208
209        //     faer::linalg::evd::evd_cplx::<<T as faer::traits::ComplexField>::Real>(
210        //         a_faer_complex,
211        //         diag_mut,
212        //         Some(left_vecs_complex.as_mut()),
213        //         Some(right_vecs_complex.as_mut()),
214        //         par,
215        //         stack,
216        //         params.into(),
217        //     )
218        // } else {
219        //     let mut s_re_mda = tensor![[x.re(); n]; 1];
220        //     let mut s_im_mda = tensor![[x.re(); n]; 1];
221
222        //     let mut right_vecs_faer = into_faer_mut(&mut right_vecs_mda);
223        //     let mut left_vecs_faer = into_faer_mut(&mut left_vecs_mda);
224        //     let mut s_re_faer = into_faer_mut(&mut s_re_mda);
225        //     let mut s_im_faer = into_faer_mut(&mut s_im_mda);
226
227        //     let a_faer_real: MatRef<'_, <T as faer::traits::ComplexField>::Real> = unsafe {
228        //         faer::hacks::coerce::<_, MatRef<'_, <T as faer::traits::ComplexField>::Real>>(
229        //             a_faer,
230        //         )
231        //     };
232
233        //     println!("ici");
234
235        //     let mut left_vecs_real: MatMut<'_, <T as faer::traits::ComplexField>::Real> = unsafe {
236        //         faer::hacks::coerce::<_, MatMut<'_, <T as faer::traits::ComplexField>::Real>>(
237        //             left_vecs_faer,
238        //         )
239        //     };
240
241        //     let mut right_vecs_real: MatMut<'_, <T as faer::traits::ComplexField>::Real> = unsafe {
242        //         faer::hacks::coerce::<_, MatMut<'_, <T as faer::traits::ComplexField>::Real>>(
243        //             right_vecs_faer,
244        //         )
245        //     };
246
247        //     let mut stack_buf = MemBuffer::new(faer::linalg::evd::evd_scratch::<T>(
248        //         n,
249        //         ComputeEigenvectors::Yes,
250        //         ComputeEigenvectors::Yes,
251        //         par,
252        //         params.into(),
253        //     ));
254        //     let stack = MemStack::new(&mut stack_buf);
255
256        //     let s_re_col0 = s_re_faer.col_mut(0);
257        //     let s_re_as_matmut: MatMut<'_, <T as faer::traits::ComplexField>::Real> = unsafe {
258        //         faer::hacks::coerce::<_, MatMut<'_, <T as faer::traits::ComplexField>::Real>>(
259        //             s_re_col0,
260        //         )
261        //     };
262        //     let s_re_diag = s_re_as_matmut.diagonal_mut();
263
264        //     let s_im_col0 = s_im_faer.col_mut(0);
265        //     let s_im_as_matmut: MatMut<'_, <T as faer::traits::ComplexField>::Real> = unsafe {
266        //         faer::hacks::coerce::<_, MatMut<'_, <T as faer::traits::ComplexField>::Real>>(
267        //             s_im_col0,
268        //         )
269        //     };
270        //     let s_im_diag = s_im_as_matmut.diagonal_mut();
271
272        //     let result = faer::linalg::evd::evd_real::<<T as faer::traits::ComplexField>::Real>(
273        //         a_faer_real,
274        //         s_re_diag,
275        //         s_im_diag,
276        //         Some(left_vecs_real.as_mut()),
277        //         Some(right_vecs_real.as_mut()),
278        //         par,
279        //         stack,
280        //         params.into(),
281        //     );
282
283        //     if result.is_ok() {
284        //         for i in 0..n {
285        //             eigenvalues_mda[[0, i]] = Complex::new(s_re_mda[[0, i]], s_im_mda[[0, i]]);
286        //         }
287        //     }
288
289        //     result
290        // };
291
292        // match eig_result {
293        //     Ok(_) => Ok(EigDecomp {
294        //         eigenvalues: eigenvalues_mda,
295        //         left_eigenvectors: Some(left_vecs_mda),
296        //         right_eigenvectors: Some(right_vecs_mda),
297        //     }),
298        //     Err(_) => Err(EigError::BackendDidNotConverge { iterations: 0 }),
299        // }
300    }
301    /// Compute only eigenvalues with new allocated vectors
302    fn eig_values<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
303        let (m, n) = *a.shape();
304
305        if m != n {
306            return Err(EigError::NotSquareMatrix);
307        }
308
309        let a_faer = into_faer(a);
310
311        let eigenvalues_result = a_faer.eigenvalues();
312
313        match eigenvalues_result {
314            Ok(eigenvalues) => {
315                let x = T::default();
316                let mut eigenvalues_mda = tensor![[Complex::new(x.re(), x.re()); n]; 1];
317
318                for i in 0..n {
319                    eigenvalues_mda[[0, i]] = complex_from_faer!(&eigenvalues[i], T);
320                }
321
322                Ok(EigDecomp {
323                    eigenvalues: eigenvalues_mda,
324                    left_eigenvectors: None,
325                    right_eigenvectors: None,
326                })
327            }
328            Err(_) => Err(EigError::BackendDidNotConverge { iterations: 0 }),
329        }
330    }
331
332    /// Compute eigenvalues and eigenvectors of a Hermitian matrix (input should be complex)
333    fn eigh<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
334        let (m, n) = *a.shape();
335
336        if m != n {
337            return Err(EigError::NotSquareMatrix);
338        }
339
340        let a_faer = into_faer(a);
341
342        let eig_result = a_faer.self_adjoint_eigen(faer::Side::Lower);
343
344        match eig_result {
345            Ok(eig) => {
346                let eigenvalues = eig.S();
347                let eigenvectors = eig.U();
348
349                let x = T::default();
350                let mut eigenvalues_mda = tensor![[Complex::new(x.re(), x.re()); n]; 1];
351
352                let mut eigenvalues_faer = into_faer_mut(&mut eigenvalues_mda);
353                for i in 0..n {
354                    eigenvalues_faer[(0, i)] = Complex::new(eigenvalues[i].re(), x.re());
355                }
356
357                let mut right_vecs_mda = tensor![[Complex::new(x.re(), x.re());n]; n];
358
359                let mut eigenvectors_faer = into_faer_mut(&mut right_vecs_mda);
360                for i in 0..n {
361                    for j in 0..n {
362                        let val = eigenvectors[(i, j)];
363                        eigenvectors_faer[(i, j)] = Complex::new(val.re(), val.im());
364                    }
365                }
366
367                Ok(EigDecomp {
368                    eigenvalues: eigenvalues_mda,
369                    left_eigenvectors: None,
370                    right_eigenvectors: Some(right_vecs_mda),
371                })
372            }
373            Err(_) => Err(EigError::BackendDidNotConverge { iterations: 0 }),
374        }
375    }
376
377    /// Compute eigenvalues and eigenvectors of a symmetric matrix (input should be real)
378    fn eigs<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
379        self.eigh(a)
380    }
381
382    /// Compute Schur decomposition with new allocated matrices
383    fn schur<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> SchurResult<T> {
384        todo!();
385    }
386
387    /// Compute Schur decomposition overwriting existing matrices
388    fn schur_overwrite<L: Layout>(
389        &self,
390        _a: &mut DSlice<T, 2, L>,
391        _t: &mut DSlice<T, 2, Dense>,
392        _z: &mut DSlice<T, 2, Dense>,
393    ) -> Result<(), SchurError> {
394        todo!();
395    }
396
397    /// Compute Schur (complex) decomposition with new allocated matrices
398    fn schur_complex<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> SchurResult<T> {
399        todo!();
400    }
401
402    /// Compute Schur (complex) decomposition overwriting existing matrices
403    fn schur_complex_overwrite<L: Layout>(
404        &self,
405        _a: &mut DSlice<T, 2, L>,
406        _t: &mut DSlice<T, 2, Dense>,
407        _z: &mut DSlice<T, 2, Dense>,
408    ) -> Result<(), SchurError> {
409        todo!();
410    }
411}