rlst/dense/linalg/lapack/
symmeig.rs

1//! Implementation of the symmetric eigenvalue decomposition using LAPACK.
2
3use crate::UnsafeRandom1DAccessByValue;
4use crate::base_types::{RlstResult, UpLo};
5use crate::dense::array::{Array, DynArray};
6use crate::dense::linalg::lapack::interface::ev::{self, Ev, EvUplo};
7use crate::traits::base_operations::Shape;
8use crate::traits::linalg::decompositions::SymmEig;
9use crate::traits::linalg::lapack::Lapack;
10
/// Selects what the symmetric eigenvalue decomposition computes.
///
/// In `eigh` below, this choice is translated into the LAPACK `JobZEv` flag.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SymmEigMode {
    /// Only the eigenvalues are produced (mapped to `JobZEv::None`).
    EigenvaluesOnly,
    /// Eigenvalues and eigenvectors are produced (mapped to `JobZEv::Compute`).
    EigenvaluesAndEigenvectors,
}
19
20impl<Item, ArrayImpl> SymmEig for Array<ArrayImpl, 2>
21where
22    ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item> + Shape<2>,
23    Item: Lapack,
24{
25    type Item = Item;
26
27    fn eigh(
28        &self,
29        uplo: UpLo,
30        mode: SymmEigMode,
31    ) -> RlstResult<(DynArray<Item::Real, 1>, Option<DynArray<Item, 2>>)> {
32        let m = self.shape()[0];
33        let n = self.shape()[1];
34        assert_eq!(
35            m, n,
36            "Matrix must be square for symmetric eigenvalue decomposition."
37        );
38
39        let mut a = DynArray::new_from(self);
40
41        let mut w = DynArray::from_shape([n]);
42
43        let uplo = match uplo {
44            UpLo::Upper => EvUplo::Upper,
45            UpLo::Lower => EvUplo::Lower,
46        };
47
48        let jobz = match mode {
49            SymmEigMode::EigenvaluesOnly => ev::JobZEv::None,
50            SymmEigMode::EigenvaluesAndEigenvectors => ev::JobZEv::Compute,
51        };
52
53        <Item as Ev>::ev(
54            jobz,
55            uplo,
56            n,
57            a.data_mut().unwrap(),
58            n,
59            w.data_mut().unwrap(),
60        )?;
61
62        match mode {
63            SymmEigMode::EigenvaluesOnly => Ok((w, None)),
64            SymmEigMode::EigenvaluesAndEigenvectors => Ok((w, Some(a))),
65        }
66    }
67}
68
#[cfg(test)]
mod test {

    use super::*;
    use crate::base_types::{c32, c64};
    use crate::dense::array::DynArray;
    use crate::dot;

    use crate::RlstScalar;
    use itertools::izip;
    use paste::paste;

    /// Generate a `symm_eig_test_<scalar>` test for one scalar type and relative tolerance.
    ///
    /// The generated test builds `A = B + B^H` from a seeded random `B`. For both
    /// `UpLo::Upper` and `UpLo::Lower` it checks that:
    /// 1. `EigenvaluesOnly` and `EigenvaluesAndEigenvectors` return the same eigenvalues.
    /// 2. `V * diag(w) * V^H` reproduces `A` within the given tolerance.
    macro_rules! implement_symm_eig_test {
        ($scalar:ty, $tol:expr) => {
            paste! {
                #[test]
                fn [<symm_eig_test_$scalar>]() {
                    let n = 10;
                    let mut b = DynArray::<$scalar, 2>::from_shape([n, n]);
                    b.fill_from_seed_equally_distributed(0);

                    // `B + B^H` is exactly Hermitian (symmetric for real scalars), so the
                    // upper and lower triangles describe the same matrix.
                    let a = DynArray::new_from(&(b.r() + b.r().conj().transpose()));

                    for use_lower in [false, true] {
                        // A closure hands out a fresh `UpLo` per call, so the test does not
                        // depend on `UpLo` being `Copy`.
                        let uplo = || if use_lower { UpLo::Lower } else { UpLo::Upper };

                        let (w1, _) = a
                            .eigh(uplo(), SymmEigMode::EigenvaluesOnly)
                            .unwrap();

                        let (w2, v) = a
                            .eigh(uplo(), SymmEigMode::EigenvaluesAndEigenvectors)
                            .unwrap();

                        let v = v.expect("eigenvectors were requested");

                        crate::assert_array_relative_eq!(w1, w2, $tol);

                        // Diagonal matrix of the (real) eigenvalues, lifted to `$scalar`.
                        let mut lambda = DynArray::<$scalar, 2>::from_shape([n, n]);
                        izip!(lambda.diag_iter_mut(), w1.iter_value()).for_each(
                            |(lambda_elem, w_elem)| {
                                *lambda_elem = RlstScalar::from_real(w_elem);
                            },
                        );

                        let vh = DynArray::new_from(&v.r().conj().transpose());

                        let actual = dot!(v.r(), dot!(lambda.r(), vh.r()));

                        crate::assert_array_relative_eq!(actual, a, $tol);
                    }
                }
            }
        };
    }

    implement_symm_eig_test!(f32, 1E-4);
    implement_symm_eig_test!(f64, 1E-10);
    implement_symm_eig_test!(c32, 1E-4);
    implement_symm_eig_test!(c64, 1E-10);
}