rlst/dense/linalg/lapack/
symmeig.rs1use crate::UnsafeRandom1DAccessByValue;
4use crate::base_types::{RlstResult, UpLo};
5use crate::dense::array::{Array, DynArray};
6use crate::dense::linalg::lapack::interface::ev::{self, Ev, EvUplo};
7use crate::traits::base_operations::Shape;
8use crate::traits::linalg::decompositions::SymmEig;
9use crate::traits::linalg::lapack::Lapack;
10
/// Selects what a symmetric/Hermitian eigenvalue decomposition computes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymmEigMode {
    /// Compute only the eigenvalues; no eigenvector matrix is returned.
    EigenvaluesOnly,
    /// Compute the eigenvalues together with the matrix of eigenvectors.
    EigenvaluesAndEigenvectors,
}
19
20impl<Item, ArrayImpl> SymmEig for Array<ArrayImpl, 2>
21where
22 ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item> + Shape<2>,
23 Item: Lapack,
24{
25 type Item = Item;
26
27 fn eigh(
28 &self,
29 uplo: UpLo,
30 mode: SymmEigMode,
31 ) -> RlstResult<(DynArray<Item::Real, 1>, Option<DynArray<Item, 2>>)> {
32 let m = self.shape()[0];
33 let n = self.shape()[1];
34 assert_eq!(
35 m, n,
36 "Matrix must be square for symmetric eigenvalue decomposition."
37 );
38
39 let mut a = DynArray::new_from(self);
40
41 let mut w = DynArray::from_shape([n]);
42
43 let uplo = match uplo {
44 UpLo::Upper => EvUplo::Upper,
45 UpLo::Lower => EvUplo::Lower,
46 };
47
48 let jobz = match mode {
49 SymmEigMode::EigenvaluesOnly => ev::JobZEv::None,
50 SymmEigMode::EigenvaluesAndEigenvectors => ev::JobZEv::Compute,
51 };
52
53 <Item as Ev>::ev(
54 jobz,
55 uplo,
56 n,
57 a.data_mut().unwrap(),
58 n,
59 w.data_mut().unwrap(),
60 )?;
61
62 match mode {
63 SymmEigMode::EigenvaluesOnly => Ok((w, None)),
64 SymmEigMode::EigenvaluesAndEigenvectors => Ok((w, Some(a))),
65 }
66 }
67}
68
#[cfg(test)]
mod test {

    use super::*;
    use crate::base_types::{c32, c64};
    use crate::dense::array::DynArray;
    use crate::dot;

    use crate::RlstScalar;
    use itertools::izip;
    use paste::paste;

    /// Generates one test per (scalar type, triangle) combination.
    ///
    /// Each generated test does three things:
    /// - builds a random Hermitian matrix `A = B + B^H`;
    /// - checks that the eigenvalue-only and full decompositions agree on the
    ///   eigenvalues;
    /// - checks that `V * diag(w) * V^H` reconstructs `A` within `$tol`.
    ///
    /// The public arm expands to an `upper` and a `lower` test, so both
    /// `UpLo` branches of `eigh` are exercised. `$scalar` is an `ident` so it
    /// can be forwarded to the `@case` arm and pasted into the test name.
    macro_rules! implement_symm_eig_test {
        (@case $scalar:ident, $tol:expr, $label:ident, $uplo:ident) => {
            paste! {

                #[test]
                fn [<symm_eig_test_ $label _ $scalar>]() {
                    let n = 10;
                    let mut a = DynArray::<$scalar, 2>::from_shape([n, n]);
                    a.fill_from_seed_equally_distributed(0);

                    // Make the matrix symmetric/Hermitian so that either
                    // triangle describes the same operator.
                    let a = DynArray::new_from(&(a.r() + a.r().conj().transpose()));

                    let (w1, _) = a
                        .eigh(UpLo::$uplo, SymmEigMode::EigenvaluesOnly)
                        .unwrap();

                    let (w2, v) = a
                        .eigh(UpLo::$uplo, SymmEigMode::EigenvaluesAndEigenvectors)
                        .unwrap();

                    let v = v.unwrap();

                    crate::assert_array_relative_eq!(w1, w2, $tol);

                    let mut lambda = DynArray::<$scalar, 2>::from_shape([n, n]);

                    izip!(lambda.diag_iter_mut(), w1.iter_value()).for_each(|(v_elem, w_elem)| {
                        *v_elem = RlstScalar::from_real(w_elem);
                    });

                    let vt = DynArray::new_from(&v.r().conj().transpose());

                    let actual = dot!(v.r(), dot!(lambda.r(), vt.r()));

                    crate::assert_array_relative_eq!(actual, a, $tol);
                }

            }
        };
        ($scalar:ident, $tol:expr) => {
            implement_symm_eig_test!(@case $scalar, $tol, upper, Upper);
            implement_symm_eig_test!(@case $scalar, $tol, lower, Lower);
        };
    }

    implement_symm_eig_test!(f32, 1E-4);
    implement_symm_eig_test!(f64, 1E-10);
    implement_symm_eig_test!(c32, 1E-4);
    implement_symm_eig_test!(c64, 1E-10);
}