1use crate::matrix::DenseMatrix;
8use crate::Scalar;
9use faer::linalg::solvers::SelfAdjointEigendecomposition;
10use faer::{ComplexField, Conjugate, Entity, SimpleEntity};
11use numra_core::LinalgError;
12
13pub struct SymEigenDecomposition<S: Scalar + Entity> {
17 evd: SelfAdjointEigendecomposition<S>,
18 n: usize,
19}
20
21impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SymEigenDecomposition<S> {
22 pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
26 let nrows = matrix.rows();
27 let ncols = matrix.cols();
28 if nrows != ncols {
29 return Err(LinalgError::NotSquare { nrows, ncols });
30 }
31 let n = nrows;
32 let evd = SelfAdjointEigendecomposition::new(matrix.as_faer(), faer::Side::Lower);
33 Ok(Self { evd, n })
34 }
35
36 pub fn eigenvalues(&self) -> Vec<S> {
38 let s = self.evd.s();
39 (0..self.n).map(|i| s.column_vector().read(i)).collect()
40 }
41
42 pub fn eigenvectors(&self) -> DenseMatrix<S> {
44 DenseMatrix::from_faer(self.evd.u().to_owned())
45 }
46
47 pub fn dim(&self) -> usize {
49 self.n
50 }
51}
52
53pub struct EigenDecomposition<S: Scalar> {
58 eigenvalues_re: Vec<S>,
59 eigenvalues_im: Vec<S>,
60 n: usize,
61}
62
63impl EigenDecomposition<f64> {
64 pub fn new(matrix: &DenseMatrix<f64>) -> Result<Self, LinalgError> {
69 let nrows = matrix.rows();
70 let ncols = matrix.cols();
71 if nrows != ncols {
72 return Err(LinalgError::NotSquare { nrows, ncols });
73 }
74 let n = nrows;
75
76 let evals: Vec<faer::complex_native::c64> =
77 matrix.as_faer().eigenvalues::<faer::complex_native::c64>();
78
79 let eigenvalues_re: Vec<f64> = evals.iter().map(|c| c.re).collect();
80 let eigenvalues_im: Vec<f64> = evals.iter().map(|c| c.im).collect();
81
82 Ok(Self {
83 eigenvalues_re,
84 eigenvalues_im,
85 n,
86 })
87 }
88}
89
90impl EigenDecomposition<f32> {
91 pub fn new(matrix: &DenseMatrix<f32>) -> Result<Self, LinalgError> {
93 let nrows = matrix.rows();
94 let ncols = matrix.cols();
95 if nrows != ncols {
96 return Err(LinalgError::NotSquare { nrows, ncols });
97 }
98 let n = nrows;
99
100 let evals: Vec<faer::complex_native::c32> =
101 matrix.as_faer().eigenvalues::<faer::complex_native::c32>();
102
103 let eigenvalues_re: Vec<f32> = evals.iter().map(|c| c.re).collect();
104 let eigenvalues_im: Vec<f32> = evals.iter().map(|c| c.im).collect();
105
106 Ok(Self {
107 eigenvalues_re,
108 eigenvalues_im,
109 n,
110 })
111 }
112}
113
114impl<S: Scalar> EigenDecomposition<S> {
115 pub fn eigenvalues_real(&self) -> &[S] {
117 &self.eigenvalues_re
118 }
119
120 pub fn eigenvalues_imag(&self) -> &[S] {
122 &self.eigenvalues_im
123 }
124
125 pub fn dim(&self) -> usize {
127 self.n
128 }
129}
130
131impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> DenseMatrix<S> {
133 pub fn eigh(&self) -> Result<SymEigenDecomposition<S>, LinalgError> {
135 SymEigenDecomposition::new(self)
136 }
137
138 pub fn eigvalsh(&self) -> Result<Vec<S>, LinalgError> {
140 let evd = SymEigenDecomposition::new(self)?;
141 Ok(evd.eigenvalues())
142 }
143}
144
145impl DenseMatrix<f64> {
146 pub fn eigvals(&self) -> Result<(Vec<f64>, Vec<f64>), LinalgError> {
148 let evd = EigenDecomposition::<f64>::new(self)?;
149 Ok((evd.eigenvalues_re, evd.eigenvalues_im))
150 }
151}
152
153impl DenseMatrix<f32> {
154 pub fn eigvals(&self) -> Result<(Vec<f32>, Vec<f32>), LinalgError> {
156 let evd = EigenDecomposition::<f32>::new(self)?;
157 Ok((evd.eigenvalues_re, evd.eigenvalues_im))
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Builds an `n x n` `f64` matrix from `n * n` values listed in row-major order.
    fn square_f64(n: usize, row_major: &[f64]) -> DenseMatrix<f64> {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(n, n);
        for (idx, &value) in row_major.iter().enumerate() {
            m.set(idx / n, idx % n, value);
        }
        m
    }

    /// `f32` counterpart of [`square_f64`].
    fn square_f32(n: usize, row_major: &[f32]) -> DenseMatrix<f32> {
        let mut m: DenseMatrix<f32> = DenseMatrix::zeros(n, n);
        for (idx, &value) in row_major.iter().enumerate() {
            m.set(idx / n, idx % n, value);
        }
        m
    }

    #[test]
    fn test_sym_eigen_2x2() {
        // [[2, 1], [1, 2]] has eigenvalues 1 and 3.
        let a = square_f64(2, &[2.0, 1.0, 1.0, 2.0]);

        let eigenvalues = SymEigenDecomposition::<f64>::new(&a).unwrap().eigenvalues();

        for (k, &want) in [1.0, 3.0].iter().enumerate() {
            assert!((eigenvalues[k] - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sym_eigen_3x3() {
        // A diagonal matrix: its eigenvalues are the diagonal entries, ascending.
        let a = square_f64(3, &[5.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 8.0]);

        let eigenvalues = SymEigenDecomposition::<f64>::new(&a).unwrap().eigenvalues();

        for (k, &want) in [2.0, 5.0, 8.0].iter().enumerate() {
            assert!((eigenvalues[k] - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_eigenvectors_orthogonal() {
        let a = square_f64(2, &[2.0, 1.0, 1.0, 2.0]);

        let v = SymEigenDecomposition::<f64>::new(&a).unwrap().eigenvectors();

        // The columns of V must be orthonormal, i.e. V^T V == I.
        let n = 2;
        for i in 0..n {
            for j in 0..n {
                let dot = (0..n).fold(0.0_f64, |acc, k| acc + v.get(k, i) * v.get(k, j));
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-10,
                    "V^T V not identity at ({}, {}): {} vs {}",
                    i,
                    j,
                    dot,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_sym_eigen_reconstruction() {
        let a = square_f64(3, &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0]);

        let evd = SymEigenDecomposition::<f64>::new(&a).unwrap();
        let (eigenvalues, v) = (evd.eigenvalues(), evd.eigenvectors());

        // Check A == V diag(λ) V^T entry by entry.
        let n = 3;
        for i in 0..n {
            for j in 0..n {
                let val = (0..n).fold(0.0_f64, |acc, k| {
                    acc + v.get(i, k) * eigenvalues[k] * v.get(j, k)
                });
                assert!(
                    (val - a.get(i, j)).abs() < 1e-10,
                    "Reconstruction failed at ({}, {}): {} vs {}",
                    i,
                    j,
                    val,
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_general_eigenvalues_rotation() {
        // 90-degree rotation: eigenvalues are ±i.
        let a = square_f64(2, &[0.0, -1.0, 1.0, 0.0]);

        let evd = EigenDecomposition::<f64>::new(&a).unwrap();
        let re = evd.eigenvalues_real();
        let im = evd.eigenvalues_imag();

        for x in &re[..2] {
            assert!(x.abs() < 1e-10);
        }

        let mut im_sorted = im[..2].to_vec();
        im_sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
        assert!((im_sorted[0] - (-1.0)).abs() < 1e-10);
        assert!((im_sorted[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_general_eigenvalues_diagonal() {
        let a = square_f64(3, &[2.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, -1.0]);

        let evd = EigenDecomposition::<f64>::new(&a).unwrap();
        let re = evd.eigenvalues_real();
        let im = evd.eigenvalues_imag();

        for v in im {
            assert!(v.abs() < 1e-10);
        }

        let mut re_sorted = re.to_vec();
        re_sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
        for (k, &want) in [-1.0, 2.0, 5.0].iter().enumerate() {
            assert!((re_sorted[k] - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_convenience_eigh() {
        let a = square_f64(2, &[3.0, 0.0, 0.0, 7.0]);

        let eigenvalues = a.eigvalsh().unwrap();

        assert!((eigenvalues[0] - 3.0).abs() < 1e-10);
        assert!((eigenvalues[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_convenience_eigvals() {
        let a = square_f64(2, &[1.0, 0.0, 0.0, 2.0]);

        let (re, im) = a.eigvals().unwrap();
        for v in &im {
            assert!(v.abs() < 1e-10);
        }

        let mut re_sorted = re;
        re_sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
        assert!((re_sorted[0] - 1.0).abs() < 1e-10);
        assert!((re_sorted[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_not_square_error() {
        let a: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(SymEigenDecomposition::new(&a).is_err());
        assert!(EigenDecomposition::<f64>::new(&a).is_err());
    }

    #[test]
    fn test_sym_eigen_f32() {
        let a = square_f32(2, &[4.0, 0.0, 0.0, 9.0]);

        let eigenvalues = SymEigenDecomposition::new(&a).unwrap().eigenvalues();

        assert!((eigenvalues[0] - 4.0f32).abs() < 1e-5);
        assert!((eigenvalues[1] - 9.0f32).abs() < 1e-5);
    }

    #[test]
    fn test_general_eigen_f32() {
        let a = square_f32(2, &[0.0, -1.0, 1.0, 0.0]);

        let evd = EigenDecomposition::<f32>::new(&a).unwrap();
        let re = evd.eigenvalues_real();
        let im = evd.eigenvalues_imag();

        assert!(re[0].abs() < 1e-5);
        assert!(re[1].abs() < 1e-5);

        let mut im_sorted: Vec<f32> = im[..2].to_vec();
        im_sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
        assert!((im_sorted[0] - (-1.0f32)).abs() < 1e-5);
        assert!((im_sorted[1] - 1.0f32).abs() < 1e-5);
    }
}