faer_evd/
lib.rs

1//! The eigenvalue decomposition of a square matrix $M$ of shape $(n, n)$ is a decomposition into
2//! two components $U$, $S$:
3//!
4//! - $U$ has shape $(n, n)$ and is invertible,
5//! - $S$ has shape $(n, n)$ and is a diagonal matrix,
6//! - and finally:
7//!
8//! $$M = U S U^{-1}.$$
9//!
10//! If $M$ is hermitian, then $U$ can be made unitary ($U^{-1} = U^H$), and $S$ is real valued.
11
12#![allow(clippy::type_complexity)]
13#![allow(clippy::too_many_arguments)]
14#![cfg_attr(not(feature = "std"), no_std)]
15
16use coe::Coerce;
17use dyn_stack::{PodStack, SizeOverflow, StackReq};
18use faer_core::{
19    assert,
20    householder::{
21        apply_block_householder_sequence_on_the_right_in_place_req,
22        apply_block_householder_sequence_on_the_right_in_place_with_conj,
23        upgrade_householder_factor,
24    },
25    mul::{
26        inner_prod::inner_prod_with_conj,
27        triangular::{self, BlockStructure},
28    },
29    temp_mat_req, temp_mat_uninit, temp_mat_zeroed, unzipped, zipped, ComplexField, Conj, MatMut,
30    MatRef, Parallelism, RealField,
31};
32use faer_qr::no_pivoting::compute::recommended_blocksize;
33pub use hessenberg_cplx_evd::EvdParams;
34use reborrow::*;
35
36#[doc(hidden)]
37pub mod tridiag_qr_algorithm;
38
39#[doc(hidden)]
40pub mod tridiag_real_evd;
41
42#[doc(hidden)]
43pub mod tridiag;
44
45#[doc(hidden)]
46pub mod hessenberg;
47
48#[doc(hidden)]
49pub mod hessenberg_cplx_evd;
50#[doc(hidden)]
51pub mod hessenberg_real_evd;
52
/// Indicates whether the eigenvectors are computed or skipped.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ComputeVectors {
    /// The eigenvectors are not computed.
    No,
    /// The eigenvectors are computed.
    Yes,
}
59
/// Tuning parameters for the hermitian eigenvalue decomposition.
///
/// The struct currently has no fields and is marked `#[non_exhaustive]`, so it cannot be built
/// with a struct literal outside of this crate; use [`Default::default`] instead.
#[derive(Default, Copy, Clone)]
#[non_exhaustive]
pub struct SymmetricEvdParams {}
63
64/// Computes the size and alignment of required workspace for performing a hermitian eigenvalue
65/// decomposition. The eigenvectors may be optionally computed.
66pub fn compute_hermitian_evd_req<E: ComplexField>(
67    n: usize,
68    compute_eigenvectors: ComputeVectors,
69    parallelism: Parallelism,
70    params: SymmetricEvdParams,
71) -> Result<StackReq, SizeOverflow> {
72    let _ = params;
73    let _ = compute_eigenvectors;
74    let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(n, n);
75
76    let cplx_storage = if coe::is_same::<E::Real, E>() {
77        StackReq::empty()
78    } else {
79        StackReq::try_all_of([
80            temp_mat_req::<E::Real>(n, n)?,
81            StackReq::try_new::<E::Real>(n)?,
82        ])?
83    };
84
85    StackReq::try_all_of([
86        temp_mat_req::<E>(n, n)?,
87        temp_mat_req::<E>(householder_blocksize, n - 1)?,
88        StackReq::try_any_of([
89            tridiag::tridiagonalize_in_place_req::<E>(n, parallelism)?,
90            StackReq::try_all_of([
91                StackReq::try_new::<E::Real>(n)?,
92                StackReq::try_new::<E::Real>(n - 1)?,
93                tridiag_real_evd::compute_tridiag_real_evd_req::<E>(n, parallelism)?,
94                cplx_storage,
95            ])?,
96            faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_req::<E>(
97                n - 1,
98                householder_blocksize,
99                n,
100            )?,
101        ])?,
102    ])
103}
104
105/// Computes the eigenvalue decomposition of a square hermitian `matrix`. Only the lower triangular
106/// half of the matrix is accessed.
107///
108/// `s` represents the diagonal of the matrix $S$, and must have size equal to the dimension of the
109/// matrix.
110///
111/// If `u` is `None`, then only the eigenvalues are computed. Otherwise, the eigenvectors are
112/// computed and stored in `u`.
113///
114/// # Panics
115/// Panics if any of the conditions described above is violated, or if the type `E` does not have a
116/// fixed precision at compile time, e.g. a dynamic multiprecision floating point type.
117///
118/// This can also panic if the provided memory in `stack` is insufficient (see
119/// [`compute_hermitian_evd_req`]).
120pub fn compute_hermitian_evd<E: ComplexField>(
121    matrix: MatRef<'_, E>,
122    s: MatMut<'_, E>,
123    u: Option<MatMut<'_, E>>,
124    parallelism: Parallelism,
125    stack: PodStack<'_>,
126    params: SymmetricEvdParams,
127) {
128    compute_hermitian_evd_custom_epsilon(
129        matrix,
130        s,
131        u,
132        E::Real::faer_epsilon().unwrap(),
133        E::Real::faer_zero_threshold().unwrap(),
134        parallelism,
135        stack,
136        params,
137    );
138}
139
140/// See [`compute_hermitian_evd`].
141///
142/// This function takes an additional `epsilon` and `zero_threshold` parameters. `epsilon`
143/// represents the precision of the values in the matrix, and `zero_threshold` is the value below
144/// which the precision starts to deteriorate, e.g. due to denormalized numbers.
145///
146/// These values need to be provided manually for types that do not have a known precision at
147/// compile time, e.g. a dynamic multiprecision floating point type.
148pub fn compute_hermitian_evd_custom_epsilon<E: ComplexField>(
149    matrix: MatRef<'_, E>,
150    s: MatMut<'_, E>,
151    u: Option<MatMut<'_, E>>,
152    epsilon: E::Real,
153    zero_threshold: E::Real,
154    parallelism: Parallelism,
155    stack: PodStack<'_>,
156    params: SymmetricEvdParams,
157) {
158    let _ = params;
159    let n = matrix.nrows();
160
161    assert!(all(
162        matrix.nrows() == matrix.ncols(),
163        s.nrows() == n,
164        s.ncols() == 1
165    ));
166    if let Some(u) = u.rb() {
167        assert!(all(u.nrows() == n, u.ncols() == n));
168    }
169
170    if n == 0 {
171        return;
172    }
173
174    #[cfg(feature = "perf-warn")]
175    if let Some(matrix) = u.rb() {
176        if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) {
177            if matrix.col_stride().unsigned_abs() == 1 {
178                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix.");
179            } else {
180                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found matrix with generic strides.");
181            }
182        }
183    }
184
185    let mut all_finite = true;
186    zipped!(matrix).for_each_triangular_lower(faer_core::zip::Diag::Include, |unzipped!(x)| {
187        all_finite &= x.read().faer_is_finite();
188    });
189
190    if !all_finite {
191        { s }.fill(E::faer_nan());
192        if let Some(mut u) = u {
193            u.fill(E::faer_nan());
194        }
195        return;
196    }
197
198    let (mut trid, stack) = temp_mat_uninit::<E>(n, n, stack);
199    let householder_blocksize =
200        faer_qr::no_pivoting::compute::recommended_blocksize::<E>(n - 1, n - 1);
201
202    let (mut householder, mut stack) = temp_mat_uninit::<E>(householder_blocksize, n - 1, stack);
203    let mut householder = householder.as_mut();
204
205    let mut trid = trid.as_mut();
206
207    zipped!(trid.rb_mut(), matrix)
208        .for_each_triangular_lower(faer_core::zip::Diag::Include, |unzipped!(mut dst, src)| {
209            dst.write(src.read())
210        });
211
212    tridiag::tridiagonalize_in_place(
213        trid.rb_mut(),
214        householder.rb_mut().transpose_mut(),
215        parallelism,
216        stack.rb_mut(),
217    );
218
219    let trid = trid.into_const();
220    let mut s = s;
221
222    let mut u = match u {
223        Some(u) => u,
224        None => {
225            let (diag, stack) = stack.rb_mut().make_with(n, |i| trid.read(i, i).faer_real());
226            let (offdiag, _) = stack.make_with(n - 1, |i| trid.read(i + 1, i).faer_abs());
227            tridiag_qr_algorithm::compute_tridiag_real_evd_qr_algorithm(
228                diag,
229                offdiag,
230                None,
231                epsilon,
232                zero_threshold,
233            );
234            for (i, &diag) in diag.iter().enumerate() {
235                s.write(i, 0, E::faer_from_real(diag));
236            }
237
238            return;
239        }
240    };
241
242    let mut j_base = 0;
243    while j_base < n - 1 {
244        let bs = Ord::min(householder_blocksize, n - 1 - j_base);
245        let mut householder = householder.rb_mut().submatrix_mut(0, j_base, bs, bs);
246        let full_essentials = trid.submatrix(1, 0, n - 1, n);
247        let essentials = full_essentials.submatrix(j_base, j_base, n - 1 - j_base, bs);
248        for j in 0..bs {
249            householder.write(j, j, householder.read(0, j));
250        }
251        upgrade_householder_factor(householder, essentials, bs, 1, parallelism);
252        j_base += bs;
253    }
254
255    {
256        let (diag, stack) = stack.rb_mut().make_with(n, |i| trid.read(i, i).faer_real());
257
258        if coe::is_same::<E::Real, E>() {
259            let (offdiag, stack) = stack.make_with(n - 1, |i| trid.read(i + 1, i).faer_real());
260
261            tridiag_real_evd::compute_tridiag_real_evd::<E::Real>(
262                diag,
263                offdiag,
264                u.rb_mut().coerce(),
265                epsilon,
266                zero_threshold,
267                parallelism,
268                stack,
269            );
270        } else {
271            let (offdiag, stack) = stack.make_with(n - 1, |i| trid.read(i + 1, i).faer_abs());
272
273            let (mut u_real, stack) = temp_mat_uninit::<E::Real>(n, n, stack);
274            let (mul, stack) = stack.make_with(n, |_| E::faer_zero());
275
276            let normalized = |x: E| {
277                if x == E::faer_zero() {
278                    E::faer_one()
279                } else {
280                    x.faer_scale_real(x.faer_abs().faer_inv())
281                }
282            };
283
284            mul[0] = E::faer_one();
285
286            let mut x = E::faer_one();
287            for (i, mul) in mul.iter_mut().enumerate().skip(1) {
288                x = normalized(trid.read(i, i - 1).faer_mul(x.faer_conj())).faer_conj();
289                *mul = x.faer_conj();
290            }
291
292            tridiag_real_evd::compute_tridiag_real_evd::<E::Real>(
293                diag,
294                offdiag,
295                u_real.rb_mut(),
296                epsilon,
297                zero_threshold,
298                parallelism,
299                stack,
300            );
301
302            for j in 0..n {
303                for (i, &mul) in mul.iter().enumerate() {
304                    unsafe {
305                        u.write_unchecked(i, j, mul.faer_scale_real(u_real.read_unchecked(i, j)))
306                    };
307                }
308            }
309        }
310
311        for (i, &diag) in diag.iter().enumerate() {
312            s.write(i, 0, E::faer_from_real(diag));
313        }
314    }
315
316    let mut m = faer_core::Mat::<E>::zeros(n, n);
317    for i in 0..n {
318        m.write(i, i, s.read(i, 0));
319    }
320
321    faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
322        trid.submatrix(1, 0, n - 1, n - 1),
323        householder.rb(),
324        Conj::No,
325        u.rb_mut().subrows_mut(1, n - 1),
326        parallelism,
327        stack.rb_mut(),
328    );
329}
330
331/// Computes the eigenvalue decomposition of a square real `matrix`.
332///
333/// `s_re` and `s_im` respectively represent the real and imaginary parts of the diagonal of the
334/// matrix $S$, and must have size equal to the dimension of the matrix.
335///
336/// If `u` is `None`, then only the eigenvalues are computed. Otherwise, the eigenvectors are
337/// computed and stored in `u`.
338///
339/// The eigenvectors are stored as follows, for each real eigenvalue, the corresponding column of
340/// the eigenvector matrix is the corresponding eigenvector.
341///
342/// For each complex eigenvalue pair $a + ib$ and $a - ib$ at indices `k` and `k + 1`, the
343/// eigenvalues are stored consecutively. And the real and imaginary parts of the eigenvector
344/// corresponding to the eigenvalue $a + ib$ are stored at indices `k` and `k+1`. The eigenvector
345/// corresponding to $a - ib$ can be computed as the conjugate of that vector.
346///
347/// # Panics
348/// Panics if any of the conditions described above is violated, or if the type `E` does not have a
349/// fixed precision at compile time, e.g. a dynamic multiprecision floating point type.
350///
351/// This can also panic if the provided memory in `stack` is insufficient (see [`compute_evd_req`]).
352pub fn compute_evd_real<E: RealField>(
353    matrix: MatRef<'_, E>,
354    s_re: MatMut<'_, E>,
355    s_im: MatMut<'_, E>,
356    u: Option<MatMut<'_, E>>,
357    parallelism: Parallelism,
358    stack: PodStack<'_>,
359    params: EvdParams,
360) {
361    compute_evd_real_custom_epsilon(
362        matrix,
363        s_re,
364        s_im,
365        u,
366        E::faer_epsilon().unwrap(),
367        E::faer_zero_threshold().unwrap(),
368        parallelism,
369        stack,
370        params,
371    );
372}
373
374fn dot2<E: RealField>(lhs0: MatRef<'_, E>, lhs1: MatRef<'_, E>, rhs: MatRef<'_, E>) -> (E, E) {
375    assert!(lhs0.ncols() == 1);
376    assert!(lhs1.ncols() == 1);
377    assert!(rhs.ncols() == 1);
378    let n = rhs.nrows();
379    assert!(lhs0.nrows() == n);
380    assert!(lhs1.nrows() == n);
381
382    let mut acc00 = E::faer_zero();
383    let mut acc01 = E::faer_zero();
384    let mut acc02 = E::faer_zero();
385    let mut acc03 = E::faer_zero();
386
387    let mut acc10 = E::faer_zero();
388    let mut acc11 = E::faer_zero();
389    let mut acc12 = E::faer_zero();
390    let mut acc13 = E::faer_zero();
391
392    let n4 = n / 4 * 4;
393
394    let mut i = 0;
395    unsafe {
396        while i < n4 {
397            acc00 = acc00.faer_add(E::faer_mul(
398                lhs0.read_unchecked(i, 0),
399                rhs.read_unchecked(i, 0),
400            ));
401            acc01 = acc01.faer_add(E::faer_mul(
402                lhs0.read_unchecked(i + 1, 0),
403                rhs.read_unchecked(i + 1, 0),
404            ));
405            acc02 = acc02.faer_add(E::faer_mul(
406                lhs0.read_unchecked(i + 2, 0),
407                rhs.read_unchecked(i + 2, 0),
408            ));
409            acc03 = acc03.faer_add(E::faer_mul(
410                lhs0.read_unchecked(i + 3, 0),
411                rhs.read_unchecked(i + 3, 0),
412            ));
413
414            acc10 = acc10.faer_add(E::faer_mul(
415                lhs1.read_unchecked(i, 0),
416                rhs.read_unchecked(i, 0),
417            ));
418            acc11 = acc11.faer_add(E::faer_mul(
419                lhs1.read_unchecked(i + 1, 0),
420                rhs.read_unchecked(i + 1, 0),
421            ));
422            acc12 = acc12.faer_add(E::faer_mul(
423                lhs1.read_unchecked(i + 2, 0),
424                rhs.read_unchecked(i + 2, 0),
425            ));
426            acc13 = acc13.faer_add(E::faer_mul(
427                lhs1.read_unchecked(i + 3, 0),
428                rhs.read_unchecked(i + 3, 0),
429            ));
430
431            i += 4;
432        }
433        while i < n {
434            acc00 = acc00.faer_add(E::faer_mul(
435                lhs0.read_unchecked(i, 0),
436                rhs.read_unchecked(i, 0),
437            ));
438            acc10 = acc10.faer_add(E::faer_mul(
439                lhs1.read_unchecked(i, 0),
440                rhs.read_unchecked(i, 0),
441            ));
442
443            i += 1;
444        }
445    }
446
447    (
448        E::faer_add(acc00.faer_add(acc01), acc02.faer_add(acc03)),
449        E::faer_add(acc10.faer_add(acc11), acc12.faer_add(acc13)),
450    )
451}
452
453fn dot4<E: RealField>(
454    lhs0: MatRef<'_, E>,
455    lhs1: MatRef<'_, E>,
456    rhs0: MatRef<'_, E>,
457    rhs1: MatRef<'_, E>,
458) -> (E, E, E, E) {
459    assert!(lhs0.ncols() == 1);
460    assert!(lhs1.ncols() == 1);
461    let n = rhs0.nrows();
462    assert!(lhs0.nrows() == n);
463    assert!(lhs1.nrows() == n);
464    assert!(rhs0.nrows() == n);
465    assert!(rhs1.nrows() == n);
466
467    let mut acc00 = E::faer_zero();
468    let mut acc01 = E::faer_zero();
469
470    let mut acc10 = E::faer_zero();
471    let mut acc11 = E::faer_zero();
472
473    let mut acc20 = E::faer_zero();
474    let mut acc21 = E::faer_zero();
475
476    let mut acc30 = E::faer_zero();
477    let mut acc31 = E::faer_zero();
478
479    let n2 = n / 2 * 2;
480
481    let mut i = 0;
482    unsafe {
483        while i < n2 {
484            acc00 = acc00.faer_add(E::faer_mul(
485                lhs0.read_unchecked(i, 0),
486                rhs0.read_unchecked(i, 0),
487            ));
488            acc01 = acc01.faer_add(E::faer_mul(
489                lhs0.read_unchecked(i + 1, 0),
490                rhs0.read_unchecked(i + 1, 0),
491            ));
492
493            acc10 = acc10.faer_add(E::faer_mul(
494                lhs1.read_unchecked(i, 0),
495                rhs0.read_unchecked(i, 0),
496            ));
497            acc11 = acc11.faer_add(E::faer_mul(
498                lhs1.read_unchecked(i + 1, 0),
499                rhs0.read_unchecked(i + 1, 0),
500            ));
501
502            acc20 = acc20.faer_add(E::faer_mul(
503                lhs0.read_unchecked(i, 0),
504                rhs1.read_unchecked(i, 0),
505            ));
506            acc21 = acc21.faer_add(E::faer_mul(
507                lhs0.read_unchecked(i + 1, 0),
508                rhs1.read_unchecked(i + 1, 0),
509            ));
510
511            acc30 = acc30.faer_add(E::faer_mul(
512                lhs1.read_unchecked(i, 0),
513                rhs1.read_unchecked(i, 0),
514            ));
515            acc31 = acc31.faer_add(E::faer_mul(
516                lhs1.read_unchecked(i + 1, 0),
517                rhs1.read_unchecked(i + 1, 0),
518            ));
519
520            i += 2;
521        }
522        while i < n {
523            acc00 = acc00.faer_add(E::faer_mul(
524                lhs0.read_unchecked(i, 0),
525                rhs0.read_unchecked(i, 0),
526            ));
527            acc10 = acc10.faer_add(E::faer_mul(
528                lhs1.read_unchecked(i, 0),
529                rhs0.read_unchecked(i, 0),
530            ));
531            acc20 = acc20.faer_add(E::faer_mul(
532                lhs0.read_unchecked(i, 0),
533                rhs1.read_unchecked(i, 0),
534            ));
535            acc30 = acc30.faer_add(E::faer_mul(
536                lhs1.read_unchecked(i, 0),
537                rhs1.read_unchecked(i, 0),
538            ));
539
540            i += 1;
541        }
542    }
543
544    (
545        acc00.faer_add(acc01),
546        acc10.faer_add(acc11),
547        acc20.faer_add(acc21),
548        acc30.faer_add(acc31),
549    )
550}
551
552/// See [`compute_evd_real`].
553///
554/// This function takes an additional `epsilon` and `zero_threshold` parameters. `epsilon`
555/// represents the precision of the values in the matrix, and `zero_threshold` is the value below
556/// which the precision starts to deteriorate, e.g. due to denormalized numbers.
557///
558/// These values need to be provided manually for types that do not have a known precision at
559/// compile time, e.g. a dynamic multiprecision floating point type.
560pub fn compute_evd_real_custom_epsilon<E: RealField>(
561    matrix: MatRef<'_, E>,
562    s_re: MatMut<'_, E>,
563    s_im: MatMut<'_, E>,
564    u: Option<MatMut<'_, E>>,
565    epsilon: E,
566    zero_threshold: E,
567    parallelism: Parallelism,
568    stack: PodStack<'_>,
569    params: EvdParams,
570) {
571    let n = matrix.nrows();
572
573    assert!(all(
574        matrix.nrows() == matrix.ncols(),
575        s_re.nrows() == n,
576        s_re.ncols() == 1,
577        s_im.nrows() == n,
578        s_im.ncols() == 1,
579    ));
580    if let Some(u) = u.rb() {
581        assert!(all(u.nrows() == n, u.ncols() == n));
582    }
583
584    if n == 0 {
585        return;
586    }
587
588    #[cfg(feature = "perf-warn")]
589    if let Some(matrix) = u.rb() {
590        if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) {
591            if matrix.col_stride().unsigned_abs() == 1 {
592                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix.");
593            } else {
594                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found matrix with generic strides.");
595            }
596        }
597    }
598
599    if !matrix.is_all_finite() {
600        { s_re }.fill(E::faer_nan());
601        { s_im }.fill(E::faer_nan());
602        if let Some(mut u) = u {
603            u.fill(E::faer_nan());
604        }
605        return;
606    }
607
608    let householder_blocksize = recommended_blocksize::<E>(n - 1, n - 1);
609
610    let mut u = u;
611    let mut s_re = s_re;
612    let mut s_im = s_im;
613
614    let (mut h, stack) = temp_mat_uninit(n, n, stack);
615
616    h.copy_from(matrix);
617
618    let (mut z, mut stack) = temp_mat_zeroed::<E>(n, if u.is_some() { n } else { 0 }, stack);
619    let mut z = z.as_mut();
620    z.rb_mut()
621        .diagonal_mut()
622        .column_vector_mut()
623        .fill(E::faer_one());
624
625    {
626        let (mut householder, mut stack) =
627            temp_mat_uninit(householder_blocksize, n - 1, stack.rb_mut());
628        let mut householder = householder.as_mut();
629
630        hessenberg::make_hessenberg_in_place(
631            h.rb_mut(),
632            householder.rb_mut().transpose_mut(),
633            parallelism,
634            stack.rb_mut(),
635        );
636        if u.is_some() {
637            apply_block_householder_sequence_on_the_right_in_place_with_conj(
638                h.rb().submatrix(1, 0, n - 1, n - 1),
639                householder.rb(),
640                Conj::No,
641                z.rb_mut().submatrix_mut(1, 1, n - 1, n - 1),
642                parallelism,
643                stack,
644            );
645        }
646
647        for j in 0..n {
648            for i in j + 2..n {
649                h.write(i, j, E::faer_zero());
650            }
651        }
652    }
653
654    if let Some(mut u) = u.rb_mut() {
655        hessenberg_real_evd::multishift_qr(
656            true,
657            h.rb_mut(),
658            Some(z.rb_mut()),
659            s_re.rb_mut(),
660            s_im.rb_mut(),
661            0,
662            n,
663            epsilon,
664            zero_threshold,
665            parallelism,
666            stack.rb_mut(),
667            params,
668        );
669
670        let (mut x, _) = temp_mat_zeroed::<E>(n, n, stack);
671        let mut x = x.as_mut();
672
673        let mut norm = zero_threshold;
674        zipped!(h.rb()).for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(x)| {
675            norm = norm.faer_add(x.read().faer_abs());
676        });
677        // subdiagonal
678        zipped!(h
679            .rb()
680            .submatrix(1, 0, n - 1, n - 1)
681            .diagonal()
682            .column_vector()
683            .as_2d())
684        .for_each(|unzipped!(x)| {
685            norm = norm.faer_add(x.read().faer_abs());
686        });
687
688        let mut h = h.transpose_mut();
689
690        for j in 1..n {
691            let upper = h.read(j - 1, j);
692            let lower = h.read(j, j - 1);
693
694            h.write(j - 1, j, lower);
695            h.write(j, j - 1, upper);
696        }
697
698        for j in 2..n {
699            for i in 0..j - 1 {
700                h.write(i, j, h.read(j, i));
701            }
702        }
703
704        {
705            let mut k = n;
706            loop {
707                if k == 0 {
708                    break;
709                }
710                k -= 1;
711
712                if k == 0 || h.read(k, k - 1) == E::faer_zero() {
713                    // real eigenvalue
714                    let p = h.read(k, k);
715
716                    x.write(k, k, E::faer_one());
717
718                    // solve (h[:k, :k] - p I) X = -h[:i, i]
719                    // form RHS
720                    for i in 0..k {
721                        x.write(i, k, h.read(i, k).faer_neg());
722                    }
723
724                    // solve in place
725                    let mut i = k;
726                    loop {
727                        if i == 0 {
728                            break;
729                        }
730                        i -= 1;
731
732                        if i == 0 || h.read(i, i - 1) == E::faer_zero() {
733                            // 1x1 block
734                            let dot = inner_prod_with_conj(
735                                h.rb().row(i).subcols(i + 1, k - i - 1).transpose().as_2d(),
736                                Conj::No,
737                                x.rb().col(k).subrows(i + 1, k - i - 1).as_2d(),
738                                Conj::No,
739                            );
740
741                            x.write(i, k, x.read(i, k).faer_sub(dot));
742                            let mut z = h.read(i, i).faer_sub(p);
743                            if z == E::faer_zero() {
744                                z = epsilon.faer_mul(norm);
745                            }
746                            let z_inv = z.faer_inv();
747                            let x_ = x.read(i, k);
748                            if x_ != E::faer_zero() {
749                                x.write(i, k, x.read(i, k).faer_mul(z_inv));
750                            }
751                        } else {
752                            // 2x2 block
753                            let dot0 = inner_prod_with_conj(
754                                h.rb()
755                                    .row(i - 1)
756                                    .subcols(i + 1, k - i - 1)
757                                    .transpose()
758                                    .as_2d(),
759                                Conj::No,
760                                x.rb().col(k).subrows(i + 1, k - i - 1).as_2d(),
761                                Conj::No,
762                            );
763                            let dot1 = inner_prod_with_conj(
764                                h.rb().row(i).subcols(i + 1, k - i - 1).transpose().as_2d(),
765                                Conj::No,
766                                x.rb().col(k).subrows(i + 1, k - i - 1).as_2d(),
767                                Conj::No,
768                            );
769
770                            x.write(i - 1, k, x.read(i - 1, k).faer_sub(dot0));
771                            x.write(i, k, x.read(i, k).faer_sub(dot1));
772
773                            // solve
774                            // [a b  [x0    [r0
775                            //  c a]× x1] =  r1]
776                            //
777                            //  [x0    [a  -b  [r0
778                            //   x1] =  -c  a]× r1] / det
779                            let a = h.read(i, i).faer_sub(p);
780                            let b = h.read(i - 1, i);
781                            let c = h.read(i, i - 1);
782
783                            let r0 = x.read(i - 1, k);
784                            let r1 = x.read(i, k);
785
786                            let inv_det = (a.faer_mul(a).faer_sub(b.faer_mul(c))).faer_inv();
787
788                            let x0 = a.faer_mul(r0).faer_sub(b.faer_mul(r1)).faer_mul(inv_det);
789                            let x1 = a.faer_mul(r1).faer_sub(c.faer_mul(r0)).faer_mul(inv_det);
790
791                            x.write(i - 1, k, x0);
792                            x.write(i, k, x1);
793
794                            i -= 1;
795                        }
796                    }
797                } else {
798                    // complex eigenvalue pair
799                    let p = h.read(k, k);
800                    let q = h
801                        .read(k, k - 1)
802                        .faer_abs()
803                        .faer_sqrt()
804                        .faer_mul(h.read(k - 1, k).faer_abs().faer_sqrt());
805
806                    if h.read(k - 1, k).faer_abs() >= h.read(k, k - 1) {
807                        x.write(k - 1, k - 1, E::faer_one());
808                        x.write(k, k, q.faer_div(h.read(k - 1, k)));
809                    } else {
810                        x.write(k - 1, k - 1, q.faer_neg().faer_div(h.read(k, k - 1)));
811                        x.write(k, k, E::faer_one());
812                    }
813                    x.write(k - 1, k, E::faer_zero());
814                    x.write(k, k - 1, E::faer_zero());
815
816                    // solve (h[:k-1, :k-1] - (p + iq) I) X = RHS
817                    // form RHS
818                    for i in 0..k - 1 {
819                        x.write(
820                            i,
821                            k - 1,
822                            x.read(k - 1, k - 1).faer_neg().faer_mul(h.read(i, k - 1)),
823                        );
824                        x.write(i, k, x.read(k, k).faer_neg().faer_mul(h.read(i, k)));
825                    }
826
827                    // solve in place
828                    let mut i = k - 1;
829                    loop {
830                        use num_complex::Complex;
831
832                        if i == 0 {
833                            break;
834                        }
835                        i -= 1;
836
837                        if i == 0 || h.read(i, i - 1) == E::faer_zero() {
838                            // 1x1 block
839                            let start = i + 1;
840                            let len = k - 1 - (i + 1);
841                            let (dot_re, dot_im) = dot2(
842                                x.rb().col(k - 1).subrows(start, len).as_2d(),
843                                x.rb().col(k).subrows(start, len).as_2d(),
844                                h.rb().transpose().col(i).subrows(start, len).as_2d(),
845                            );
846
847                            x.write(i, k - 1, x.read(i, k - 1).faer_sub(dot_re));
848                            x.write(i, k, x.read(i, k).faer_sub(dot_im));
849
850                            let z = Complex {
851                                re: h.read(i, i).faer_sub(p),
852                                im: q.faer_neg(),
853                            };
854                            let z_inv = z.faer_inv();
855                            let x_ = Complex {
856                                re: x.read(i, k - 1),
857                                im: x.read(i, k),
858                            };
859                            if x_ != Complex::<E>::faer_zero() {
860                                let x_ = z_inv.faer_mul(x_);
861                                x.write(i, k - 1, x_.re);
862                                x.write(i, k, x_.im);
863                            }
864                        } else {
865                            // 2x2 block
866                            let start = i + 1;
867                            let len = k - 1 - (i + 1);
868                            let (dot0_re, dot0_im, dot1_re, dot1_im) = dot4(
869                                x.rb().col(k - 1).subrows(start, len).as_2d(),
870                                x.rb().col(k).subrows(start, len).as_2d(),
871                                h.rb().transpose().col(i - 1).subrows(start, len).as_2d(),
872                                h.rb().transpose().col(i).subrows(start, len).as_2d(),
873                            );
874                            let mut dot0 = Complex::<E>::faer_zero();
875                            let mut dot1 = Complex::<E>::faer_zero();
876                            for j in i + 1..k - 1 {
877                                dot0 = dot0.faer_add(
878                                    Complex {
879                                        re: x.read(j, k - 1),
880                                        im: x.read(j, k),
881                                    }
882                                    .faer_scale_real(h.read(i - 1, j)),
883                                );
884                                dot1 = dot1.faer_add(
885                                    Complex {
886                                        re: x.read(j, k - 1),
887                                        im: x.read(j, k),
888                                    }
889                                    .faer_scale_real(h.read(i, j)),
890                                );
891                            }
892
893                            x.write(i - 1, k - 1, x.read(i - 1, k - 1).faer_sub(dot0_re));
894                            x.write(i - 1, k, x.read(i - 1, k).faer_sub(dot0_im));
895                            x.write(i, k - 1, x.read(i, k - 1).faer_sub(dot1_re));
896                            x.write(i, k, x.read(i, k).faer_sub(dot1_im));
897
898                            let a = Complex {
899                                re: h.read(i, i).faer_sub(p),
900                                im: q.faer_neg(),
901                            };
902                            let b = h.read(i - 1, i);
903                            let c = h.read(i, i - 1);
904
905                            let r0 = Complex {
906                                re: x.read(i - 1, k - 1),
907                                im: x.read(i - 1, k),
908                            };
909                            let r1 = Complex {
910                                re: x.read(i, k - 1),
911                                im: x.read(i, k),
912                            };
913
914                            let inv_det = (a
915                                .faer_mul(a)
916                                .faer_sub(Complex::<E>::faer_from_real(b.faer_mul(c))))
917                            .faer_inv();
918
919                            let x0 = a
920                                .faer_mul(r0)
921                                .faer_sub(r1.faer_scale_real(b))
922                                .faer_mul(inv_det);
923                            let x1 = a
924                                .faer_mul(r1)
925                                .faer_sub(r0.faer_scale_real(c))
926                                .faer_mul(inv_det);
927
928                            x.write(i - 1, k - 1, x0.re);
929                            x.write(i - 1, k, x0.im);
930                            x.write(i, k - 1, x1.re);
931                            x.write(i, k, x1.im);
932
933                            i -= 1;
934                        }
935                    }
936
937                    k -= 1;
938                }
939            }
940        }
941
942        triangular::matmul(
943            u.rb_mut(),
944            BlockStructure::Rectangular,
945            z.rb(),
946            BlockStructure::Rectangular,
947            x.rb(),
948            BlockStructure::TriangularUpper,
949            None,
950            E::faer_one(),
951            parallelism,
952        );
953    } else {
954        hessenberg_real_evd::multishift_qr(
955            false,
956            h.rb_mut(),
957            None,
958            s_re.rb_mut(),
959            s_im.rb_mut(),
960            0,
961            n,
962            epsilon,
963            zero_threshold,
964            parallelism,
965            stack.rb_mut(),
966            params,
967        );
968    }
969}
970
971/// Computes the size and alignment of required workspace for performing an eigenvalue
972/// decomposition. The eigenvectors may be optionally computed.
973pub fn compute_evd_req<E: ComplexField>(
974    n: usize,
975    compute_eigenvectors: ComputeVectors,
976    parallelism: Parallelism,
977    params: EvdParams,
978) -> Result<StackReq, SizeOverflow> {
979    if n == 0 {
980        return Ok(StackReq::empty());
981    }
982    let householder_blocksize = recommended_blocksize::<E>(n - 1, n - 1);
983    let compute_vecs = matches!(compute_eigenvectors, ComputeVectors::Yes);
984    StackReq::try_all_of([
985        // h
986        temp_mat_req::<E>(n, n)?,
987        // z
988        temp_mat_req::<E>(n, if compute_vecs { n } else { 0 })?,
989        StackReq::try_any_of([
990            StackReq::try_all_of([
991                temp_mat_req::<E>(householder_blocksize, n - 1)?,
992                StackReq::try_any_of([
993                    hessenberg::make_hessenberg_in_place_req::<E>(
994                        n,
995                        householder_blocksize,
996                        parallelism,
997                    )?,
998                    apply_block_householder_sequence_on_the_right_in_place_req::<E>(
999                        n - 1,
1000                        householder_blocksize,
1001                        n,
1002                    )?,
1003                ])?,
1004            ])?,
1005            StackReq::try_any_of([
1006                hessenberg_cplx_evd::multishift_qr_req::<E>(
1007                    n,
1008                    n,
1009                    compute_vecs,
1010                    compute_vecs,
1011                    parallelism,
1012                    params,
1013                )?,
1014                temp_mat_req::<E>(n, n)?,
1015            ])?,
1016        ])?,
1017    ])
1018}
1019
1020/// Computes the eigenvalue decomposition of a square complex `matrix`.
1021///
1022/// `s` represents the diagonal of the matrix $S$, and must have size equal to the dimension of the
1023/// matrix.
1024///
1025/// If `u` is `None`, then only the eigenvalues are computed. Otherwise, the eigenvectors are
1026/// computed and stored in `u`.
1027///
1028/// # Panics
1029/// Panics if any of the conditions described above is violated, or if the type `E` does not have a
1030/// fixed precision at compile time, e.g. a dynamic multiprecision floating point type.
1031///
1032/// This can also panic if the provided memory in `stack` is insufficient (see [`compute_evd_req`]).
1033pub fn compute_evd_complex<E: ComplexField>(
1034    matrix: MatRef<'_, E>,
1035    s: MatMut<'_, E>,
1036    u: Option<MatMut<'_, E>>,
1037    parallelism: Parallelism,
1038    stack: PodStack<'_>,
1039    params: EvdParams,
1040) {
1041    compute_evd_complex_custom_epsilon(
1042        matrix,
1043        s,
1044        u,
1045        E::Real::faer_epsilon().unwrap(),
1046        E::Real::faer_zero_threshold().unwrap(),
1047        parallelism,
1048        stack,
1049        params,
1050    );
1051}
1052
/// See [`compute_evd_complex`].
///
/// This function takes additional `epsilon` and `zero_threshold` parameters. `epsilon`
/// represents the precision of the values in the matrix, and `zero_threshold` is the value below
/// which the precision starts to deteriorate, e.g. due to denormalized numbers.
///
/// These values need to be provided manually for types that do not have a known precision at
/// compile time, e.g. a dynamic multiprecision floating point type.
///
/// If `matrix` contains a non-finite value, `s` (and `u`, when provided) are filled with NaN and
/// no decomposition is attempted. An empty matrix (`n == 0`) returns immediately.
///
/// # Panics
/// Panics if `E` is the same type as `E::Real`, if `matrix` is not square, if `s` does not have
/// shape `(n, 1)`, or if `u` is provided and does not have shape `(n, n)`.
pub fn compute_evd_complex_custom_epsilon<E: ComplexField>(
    matrix: MatRef<'_, E>,
    s: MatMut<'_, E>,
    u: Option<MatMut<'_, E>>,
    epsilon: E::Real,
    zero_threshold: E::Real,
    parallelism: Parallelism,
    stack: PodStack<'_>,
    params: EvdParams,
) {
    // This entry point rejects scalar types that coincide with their own real type.
    assert!(!coe::is_same::<E, E::Real>());
    let n = matrix.nrows();

    assert!(all(
        matrix.nrows() == matrix.ncols(),
        s.nrows() == n,
        s.ncols() == 1,
    ));
    if let Some(u) = u.rb() {
        assert!(all(u.nrows() == n, u.ncols() == n));
    }

    if n == 0 {
        return;
    }

    // Optional diagnostic: warn when the eigenvector output does not have unit row stride.
    #[cfg(feature = "perf-warn")]
    if let Some(matrix) = u.rb() {
        if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) {
            if matrix.col_stride().unsigned_abs() == 1 {
                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix.");
            } else {
                log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found matrix with generic strides.");
            }
        }
    }

    // Non-finite input: report NaN in every output instead of running the algorithm.
    if !matrix.is_all_finite() {
        { s }.fill(E::faer_nan());
        if let Some(mut u) = u {
            u.fill(E::faer_nan());
        }
        return;
    }

    let householder_blocksize = recommended_blocksize::<E>(n - 1, n - 1);

    let mut u = u;
    let mut s = s;

    // `h` is a scratch copy of the input that is modified in place by the steps below.
    let (mut h, stack) = temp_mat_uninit(n, n, stack);
    let mut h = h.as_mut();

    h.copy_from(matrix);

    // `z` starts out as the identity (zero-filled, then ones on the diagonal). It has zero
    // columns when no eigenvectors are requested.
    let (mut z, mut stack) = temp_mat_zeroed::<E>(n, if u.is_some() { n } else { 0 }, stack);
    let mut z = z.as_mut();
    z.rb_mut()
        .diagonal_mut()
        .column_vector_mut()
        .fill(E::faer_one());

    {
        // Scratch for the householder factor; it only lives for the Hessenberg phase.
        let (mut householder, mut stack) =
            temp_mat_uninit(n - 1, householder_blocksize, stack.rb_mut());
        let mut householder = householder.as_mut();

        hessenberg::make_hessenberg_in_place(
            h.rb_mut(),
            householder.rb_mut(),
            parallelism,
            stack.rb_mut(),
        );
        if u.is_some() {
            // Apply the householder sequence (read from `h` below its first row) on the right
            // of the trailing `(n - 1) × (n - 1)` block of `z`.
            apply_block_householder_sequence_on_the_right_in_place_with_conj(
                h.rb().submatrix(1, 0, n - 1, n - 1),
                householder.rb().transpose(),
                Conj::No,
                z.rb_mut().submatrix_mut(1, 1, n - 1, n - 1),
                parallelism,
                stack,
            );
        }

        // Zero everything below the first subdiagonal of `h`.
        for j in 0..n {
            for i in j + 2..n {
                h.write(i, j, E::faer_zero());
            }
        }
    }

    if let Some(mut u) = u.rb_mut() {
        // QR iteration with `z` passed along so it is updated together with `h`; eigenvalues
        // are written to `s`.
        // NOTE(review): the leading `true` presumably requests the full triangular factor in
        // `h` — confirm against `hessenberg_cplx_evd::multishift_qr`.
        hessenberg_cplx_evd::multishift_qr(
            true,
            h.rb_mut(),
            Some(z.rb_mut()),
            s.rb_mut(),
            0,
            n,
            epsilon,
            zero_threshold,
            parallelism,
            stack.rb_mut(),
            params,
        );

        // `x` will hold the strictly-upper part of the eigenvector matrix of `h`.
        let (mut x, _) = temp_mat_zeroed::<E>(n, n, stack);
        let mut x = x.as_mut();

        // norm = sqrt(zero_threshold + sum of |h(i, j)|^2 over the upper triangle, diagonal
        // included).
        let mut norm = zero_threshold;
        zipped!(h.rb()).for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(x)| {
            norm = norm.faer_add(x.read().faer_abs2());
        });
        let norm = norm.faer_sqrt();

        // From here on `h` is accessed through its transposed view.
        let mut h = h.transpose_mut();

        // Through the transposed view, overwrite each entry above the diagonal with its mirror
        // entry below the diagonal. In terms of the untransposed storage this copies the upper
        // triangle into the lower triangle, so `h.read(i, k)` with `i < k` below still yields
        // the upper-triangular values.
        // NOTE(review): presumably done to change the memory access pattern of the row reads
        // in the loop below — confirm.
        for j in 1..n {
            for i in 0..j {
                h.write(i, j, h.read(j, i));
            }
        }

        // Back substitution, one column `k` at a time, rows from `k - 1` up to `0`:
        //   x(i, k) = -(h(i, k) + sum_{j = i+1}^{k-1} h(i, j) * x(j, k)) / (h(i, i) - h(k, k))
        // The diagonal of `x` is stored as zero; the product below treats it as one.
        for k in (0..n).rev() {
            x.write(k, k, E::faer_zero());
            for i in (0..k).rev() {
                x.write(i, k, h.read(i, k).faer_neg());
                if k > i + 1 {
                    let dot = inner_prod_with_conj(
                        h.rb().row(i).subcols(i + 1, k - i - 1).transpose().as_2d(),
                        Conj::No,
                        x.rb().col(k).subrows(i + 1, k - i - 1).as_2d(),
                        Conj::No,
                    );
                    x.write(i, k, x.read(i, k).faer_sub(dot));
                }

                // An exactly zero denominator (equal diagonal entries) is replaced by
                // `epsilon * norm`.
                let mut z = h.read(i, i).faer_sub(h.read(k, k));
                if z == E::faer_zero() {
                    z = E::faer_from_real(epsilon.faer_mul(norm));
                }
                let z_inv = z.faer_inv();
                let x_ = x.read(i, k);
                // The scaling is skipped when the numerator is exactly zero.
                if x_ != E::faer_zero() {
                    x.write(i, k, x.read(i, k).faer_mul(z_inv));
                }
            }
        }

        // u = z * x, with `x` read as a unit upper triangular matrix.
        // NOTE(review): `None` / `E::faer_one()` presumably mean "overwrite `u`" with a unit
        // scaling factor — confirm against `triangular::matmul`.
        triangular::matmul(
            u.rb_mut(),
            BlockStructure::Rectangular,
            z.rb(),
            BlockStructure::Rectangular,
            x.rb(),
            BlockStructure::UnitTriangularUpper,
            None,
            E::faer_one(),
            parallelism,
        );
    } else {
        // Eigenvalues only: no accumulator is passed to the QR iteration.
        hessenberg_cplx_evd::multishift_qr(
            false,
            h.rb_mut(),
            None,
            s.rb_mut(),
            0,
            n,
            epsilon,
            zero_threshold,
            parallelism,
            stack.rb_mut(),
            params,
        );
    }
}
1237
#[cfg(test)]
mod herm_tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use faer_core::{assert, c64, Mat};

    /// Matrix dimensions exercised by every test in this module.
    const SIZES: [usize; 9] = [2, 3, 4, 5, 6, 7, 10, 15, 25];

    // Builds a `PodStack` over a freshly allocated buffer sized from `$req`. `$req` is
    // unwrapped, so a failed requirement computation panics.
    macro_rules! make_stack {
        ($req: expr) => {
            ::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
        };
    }

    /// Runs [`compute_hermitian_evd`] on the real matrix `mat` and checks that `U S Uᵀ`
    /// matches `mat` on and below the diagonal.
    ///
    /// Only entries with `i >= j` are compared.
    /// NOTE(review): `test_real` feeds a non-symmetric matrix, so the decomposition presumably
    /// reads the lower triangle only — confirm against `compute_hermitian_evd`.
    fn check_real(mat: Mat<f64>) {
        let n = mat.nrows();

        let mut s = Mat::zeros(n, n);
        let mut u = Mat::zeros(n, n);

        compute_hermitian_evd(
            mat.as_ref(),
            s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
            Some(u.as_mut()),
            Parallelism::None,
            make_stack!(compute_hermitian_evd_req::<f64>(
                n,
                ComputeVectors::Yes,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let reconstructed = &u * &s * u.transpose();

        for j in 0..n {
            for i in j..n {
                assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
            }
        }
    }

    /// Runs [`compute_hermitian_evd`] on the complex matrix `mat` and checks that `U S Uᴴ`
    /// matches `mat` on and below the diagonal.
    ///
    /// The intermediate matrices are printed with `dbgf!` before the comparison.
    fn check_cplx(mat: Mat<c64>) {
        let n = mat.nrows();

        let mut s = Mat::zeros(n, n);
        let mut u = Mat::zeros(n, n);

        compute_hermitian_evd(
            mat.as_ref(),
            s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
            Some(u.as_mut()),
            Parallelism::None,
            make_stack!(compute_hermitian_evd_req::<c64>(
                n,
                ComputeVectors::Yes,
                Parallelism::None,
                Default::default(),
            )),
            Default::default(),
        );

        let reconstructed = &u * &s * u.adjoint();
        dbgf::dbgf!("6.2?", &u, &reconstructed, &mat);

        for j in 0..n {
            for i in j..n {
                assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
            }
        }
    }

    #[test]
    fn test_real() {
        for n in SIZES {
            check_real(Mat::from_fn(n, n, |_, _| rand::random::<f64>()));
        }
    }

    #[test]
    fn test_cplx() {
        for n in SIZES {
            // The diagonal is kept real; off-diagonal entries are fully random.
            check_cplx(Mat::from_fn(n, n, |i, j| {
                c64::new(rand::random(), if i == j { 0.0 } else { rand::random() })
            }));
        }
    }

    #[test]
    fn test_real_identity() {
        for n in SIZES {
            check_real(Mat::from_fn(n, n, |i, j| {
                if i == j {
                    f64::faer_one()
                } else {
                    f64::faer_zero()
                }
            }));
        }
    }

    #[test]
    fn test_cplx_identity() {
        for n in SIZES {
            check_cplx(Mat::from_fn(n, n, |i, j| {
                if i == j {
                    c64::faer_one()
                } else {
                    c64::faer_zero()
                }
            }));
        }
    }

    #[test]
    fn test_real_zero() {
        for n in SIZES {
            check_real(Mat::from_fn(n, n, |_, _| f64::faer_zero()));
        }
    }

    #[test]
    fn test_cplx_zero() {
        for n in SIZES {
            check_cplx(Mat::from_fn(n, n, |_, _| c64::faer_zero()));
        }
    }
}
1459
1460#[cfg(test)]
1461mod tests {
1462    use super::*;
1463    use assert_approx_eq::assert_approx_eq;
1464    use faer_core::{assert, c64, Mat};
1465    use num_complex::Complex;
1466
    // Builds a `PodStack` over a freshly allocated `GlobalPodBuffer` sized from `$req`.
    // `$req` is unwrapped, so a failed requirement computation panics. The buffer is a
    // temporary: the resulting stack is only usable within the statement that invokes the
    // macro (e.g. directly as a function argument).
    macro_rules! make_stack {
        ($req: expr) => {
            ::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
        };
    }
1472
1473    #[test]
1474    fn test_real_3() {
1475        let mat = faer_core::mat![
1476            [0.03498524449256035, 0.5246466104879548, 0.20804192188707582,],
1477            [0.007467248113335545, 0.1723793560841066, 0.2677423170633869,],
1478            [
1479                0.5907508388039022,
1480                0.11540612644030279,
1481                0.2624452803216497f64,
1482            ],
1483        ];
1484
1485        let n = mat.nrows();
1486
1487        let mut s_re = Mat::zeros(n, n);
1488        let mut s_im = Mat::zeros(n, n);
1489        let mut u_re = Mat::zeros(n, n);
1490        let mut u_im = Mat::zeros(n, n);
1491
1492        compute_evd_real(
1493            mat.as_ref(),
1494            s_re.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1495            s_im.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1496            Some(u_re.as_mut()),
1497            Parallelism::None,
1498            make_stack!(compute_evd_req::<c64>(
1499                n,
1500                ComputeVectors::Yes,
1501                Parallelism::None,
1502                Default::default(),
1503            )),
1504            Default::default(),
1505        );
1506
1507        let mut j = 0;
1508        loop {
1509            if j == n {
1510                break;
1511            }
1512
1513            if s_im.read(j, j) != 0.0 {
1514                for i in 0..n {
1515                    u_im.write(i, j, u_re.read(i, j + 1));
1516                    u_im.write(i, j + 1, -u_re.read(i, j + 1));
1517                    u_re.write(i, j + 1, u_re.read(i, j));
1518                }
1519
1520                j += 1;
1521            }
1522
1523            j += 1;
1524        }
1525
1526        let u = Mat::from_fn(n, n, |i, j| Complex::new(u_re.read(i, j), u_im.read(i, j)));
1527        let s = Mat::from_fn(n, n, |i, j| Complex::new(s_re.read(i, j), s_im.read(i, j)));
1528        let mat = Mat::from_fn(n, n, |i, j| Complex::new(mat.read(i, j), 0.0));
1529
1530        let left = &mat * &u;
1531        let right = &u * &s;
1532
1533        for j in 0..n {
1534            for i in 0..n {
1535                assert_approx_eq!(left.read(i, j).re, right.read(i, j).re, 1e-10);
1536                assert_approx_eq!(left.read(i, j).im, right.read(i, j).im, 1e-10);
1537            }
1538        }
1539    }
1540
1541    #[test]
1542    fn test_real() {
1543        for n in [3, 2, 4, 5, 6, 7, 10, 15, 25] {
1544            for _ in 0..10 {
1545                let mat = Mat::from_fn(n, n, |_, _| rand::random::<f64>());
1546                dbg!(&mat);
1547
1548                let n = mat.nrows();
1549
1550                let mut s_re = Mat::zeros(n, n);
1551                let mut s_im = Mat::zeros(n, n);
1552                let mut u_re = Mat::zeros(n, n);
1553                let mut u_im = Mat::zeros(n, n);
1554
1555                compute_evd_real(
1556                    mat.as_ref(),
1557                    s_re.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1558                    s_im.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1559                    Some(u_re.as_mut()),
1560                    Parallelism::None,
1561                    make_stack!(compute_evd_req::<c64>(
1562                        n,
1563                        ComputeVectors::Yes,
1564                        Parallelism::None,
1565                        Default::default(),
1566                    )),
1567                    Default::default(),
1568                );
1569
1570                let mut j = 0;
1571                loop {
1572                    if j == n {
1573                        break;
1574                    }
1575
1576                    if s_im.read(j, j) != 0.0 {
1577                        for i in 0..n {
1578                            u_im.write(i, j, u_re.read(i, j + 1));
1579                            u_im.write(i, j + 1, -u_re.read(i, j + 1));
1580                            u_re.write(i, j + 1, u_re.read(i, j));
1581                        }
1582
1583                        j += 1;
1584                    }
1585
1586                    j += 1;
1587                }
1588
1589                let u = Mat::from_fn(n, n, |i, j| Complex::new(u_re.read(i, j), u_im.read(i, j)));
1590                let s = Mat::from_fn(n, n, |i, j| Complex::new(s_re.read(i, j), s_im.read(i, j)));
1591                let mat = Mat::from_fn(n, n, |i, j| Complex::new(mat.read(i, j), 0.0));
1592
1593                let left = &mat * &u;
1594                let right = &u * &s;
1595
1596                for j in 0..n {
1597                    for i in 0..n {
1598                        assert_approx_eq!(left.read(i, j).re, right.read(i, j).re, 1e-10);
1599                        assert_approx_eq!(left.read(i, j).im, right.read(i, j).im, 1e-10);
1600                    }
1601                }
1602            }
1603        }
1604    }
1605
1606    #[test]
1607    fn test_real_identity() {
1608        for n in [2, 3, 4, 5, 6, 7, 10, 15, 25] {
1609            let mat = Mat::from_fn(n, n, |i, j| {
1610                if i == j {
1611                    f64::faer_one()
1612                } else {
1613                    f64::faer_zero()
1614                }
1615            });
1616
1617            let n = mat.nrows();
1618
1619            let mut s_re = Mat::zeros(n, n);
1620            let mut s_im = Mat::zeros(n, n);
1621            let mut u_re = Mat::zeros(n, n);
1622            let mut u_im = Mat::zeros(n, n);
1623
1624            compute_evd_real(
1625                mat.as_ref(),
1626                s_re.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1627                s_im.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1628                Some(u_re.as_mut()),
1629                Parallelism::None,
1630                make_stack!(compute_evd_req::<c64>(
1631                    n,
1632                    ComputeVectors::Yes,
1633                    Parallelism::None,
1634                    Default::default(),
1635                )),
1636                Default::default(),
1637            );
1638
1639            let mut j = 0;
1640            loop {
1641                if j == n {
1642                    break;
1643                }
1644
1645                if s_im.read(j, j) != 0.0 {
1646                    for i in 0..n {
1647                        u_im.write(i, j, u_re.read(i, j + 1));
1648                        u_im.write(i, j + 1, -u_re.read(i, j + 1));
1649                        u_re.write(i, j + 1, u_re.read(i, j));
1650                    }
1651
1652                    j += 1;
1653                }
1654
1655                j += 1;
1656            }
1657
1658            let u = Mat::from_fn(n, n, |i, j| Complex::new(u_re.read(i, j), u_im.read(i, j)));
1659            let s = Mat::from_fn(n, n, |i, j| Complex::new(s_re.read(i, j), s_im.read(i, j)));
1660            let mat = Mat::from_fn(n, n, |i, j| Complex::new(mat.read(i, j), 0.0));
1661
1662            let left = &mat * &u;
1663            let right = &u * &s;
1664
1665            for j in 0..n {
1666                for i in 0..n {
1667                    assert_approx_eq!(left.read(i, j).re, right.read(i, j).re, 1e-10);
1668                    assert_approx_eq!(left.read(i, j).im, right.read(i, j).im, 1e-10);
1669                }
1670            }
1671        }
1672    }
1673
1674    #[test]
1675    fn test_real_zero() {
1676        for n in [2, 3, 4, 5, 6, 7, 10, 15, 25] {
1677            let mat = Mat::<f64>::zeros(n, n);
1678
1679            let n = mat.nrows();
1680
1681            let mut s_re = Mat::zeros(n, n);
1682            let mut s_im = Mat::zeros(n, n);
1683            let mut u_re = Mat::zeros(n, n);
1684            let mut u_im = Mat::zeros(n, n);
1685
1686            compute_evd_real(
1687                mat.as_ref(),
1688                s_re.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1689                s_im.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1690                Some(u_re.as_mut()),
1691                Parallelism::None,
1692                make_stack!(compute_evd_req::<c64>(
1693                    n,
1694                    ComputeVectors::Yes,
1695                    Parallelism::None,
1696                    Default::default(),
1697                )),
1698                Default::default(),
1699            );
1700
1701            let mut j = 0;
1702            loop {
1703                if j == n {
1704                    break;
1705                }
1706
1707                if s_im.read(j, j) != 0.0 {
1708                    for i in 0..n {
1709                        u_im.write(i, j, u_re.read(i, j + 1));
1710                        u_im.write(i, j + 1, -u_re.read(i, j + 1));
1711                        u_re.write(i, j + 1, u_re.read(i, j));
1712                    }
1713
1714                    j += 1;
1715                }
1716
1717                j += 1;
1718            }
1719
1720            let u = Mat::from_fn(n, n, |i, j| Complex::new(u_re.read(i, j), u_im.read(i, j)));
1721            let s = Mat::from_fn(n, n, |i, j| Complex::new(s_re.read(i, j), s_im.read(i, j)));
1722            let mat = Mat::from_fn(n, n, |i, j| Complex::new(mat.read(i, j), 0.0));
1723
1724            let left = &mat * &u;
1725            let right = &u * &s;
1726
1727            for j in 0..n {
1728                for i in 0..n {
1729                    assert_approx_eq!(left.read(i, j).re, right.read(i, j).re, 1e-10);
1730                    assert_approx_eq!(left.read(i, j).im, right.read(i, j).im, 1e-10);
1731                }
1732            }
1733        }
1734    }
1735
1736    #[test]
1737    fn test_cplx() {
1738        for n in [2, 3, 4, 5, 6, 7, 10, 15, 25] {
1739            let mat = Mat::from_fn(n, n, |_, _| c64::new(rand::random(), rand::random()));
1740
1741            let mut s = Mat::zeros(n, n);
1742            let mut u = Mat::zeros(n, n);
1743
1744            compute_evd_complex(
1745                mat.as_ref(),
1746                s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1747                Some(u.as_mut()),
1748                Parallelism::None,
1749                make_stack!(compute_evd_req::<c64>(
1750                    n,
1751                    ComputeVectors::Yes,
1752                    Parallelism::None,
1753                    Default::default(),
1754                )),
1755                Default::default(),
1756            );
1757
1758            let left = &mat * &u;
1759            let right = &u * &s;
1760
1761            dbgf::dbgf!("6.2?", &mat, &left, &right);
1762
1763            for j in 0..n {
1764                for i in 0..n {
1765                    assert_approx_eq!(left.read(i, j), right.read(i, j), 1e-10);
1766                }
1767            }
1768        }
1769    }
1770
1771    #[test]
1772    fn test_cplx_identity() {
1773        for n in [2, 3, 4, 5, 6, 7, 10, 15, 25] {
1774            let mat = Mat::from_fn(n, n, |i, j| {
1775                if i == j {
1776                    c64::faer_one()
1777                } else {
1778                    c64::faer_zero()
1779                }
1780            });
1781
1782            let mut s = Mat::zeros(n, n);
1783            let mut u = Mat::zeros(n, n);
1784
1785            compute_evd_complex(
1786                mat.as_ref(),
1787                s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1788                Some(u.as_mut()),
1789                Parallelism::None,
1790                make_stack!(compute_evd_req::<c64>(
1791                    n,
1792                    ComputeVectors::Yes,
1793                    Parallelism::None,
1794                    Default::default(),
1795                )),
1796                Default::default(),
1797            );
1798
1799            let left = &mat * &u;
1800            let right = &u * &s;
1801
1802            for j in 0..n {
1803                for i in 0..n {
1804                    assert_approx_eq!(left.read(i, j), right.read(i, j), 1e-10);
1805                }
1806            }
1807        }
1808    }
1809
1810    #[test]
1811    fn test_cplx_zero() {
1812        for n in [2, 3, 4, 5, 6, 7, 10, 15, 25] {
1813            let mat = Mat::from_fn(n, n, |_, _| c64::faer_zero());
1814
1815            let mut s = Mat::zeros(n, n);
1816            let mut u = Mat::zeros(n, n);
1817
1818            compute_evd_complex(
1819                mat.as_ref(),
1820                s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1821                Some(u.as_mut()),
1822                Parallelism::None,
1823                make_stack!(compute_evd_req::<c64>(
1824                    n,
1825                    ComputeVectors::Yes,
1826                    Parallelism::None,
1827                    Default::default(),
1828                )),
1829                Default::default(),
1830            );
1831
1832            let left = &mat * &u;
1833            let right = &u * &s;
1834
1835            for j in 0..n {
1836                for i in 0..n {
1837                    assert_approx_eq!(left.read(i, j), right.read(i, j), 1e-10);
1838                }
1839            }
1840        }
1841    }
1842
1843    // https://github.com/sarah-ek/faer-rs/issues/78
1844    #[test]
1845    fn test_cplx_gh78() {
1846        let i = c64::new(0.0, 1.0);
1847
1848        let mat = faer_core::mat![
1849            [
1850                0.0 + 0.0 * i,
1851                0.0 + 0.0 * i,
1852                0.0 + 0.0 * i,
1853                2.220446049250313e-16 + -1.0000000000000002 * i
1854            ],
1855            [
1856                0.0 + 0.0 * i,
1857                0.0 + 0.0 * i,
1858                2.220446049250313e-16 + -1.0000000000000002 * i,
1859                0.0 + 0.0 * i
1860            ],
1861            [
1862                0.0 + 0.0 * i,
1863                2.220446049250313e-16 + -1.0000000000000002 * i,
1864                0.0 + 0.0 * i,
1865                0.0 + 0.0 * i
1866            ],
1867            [
1868                2.220446049250313e-16 + -1.0000000000000002 * i,
1869                0.0 + 0.0 * i,
1870                0.0 + 0.0 * i,
1871                0.0 + 0.0 * i
1872            ],
1873        ];
1874        let n = mat.nrows();
1875
1876        let mut s = Mat::zeros(n, n);
1877        let mut u = Mat::zeros(n, n);
1878
1879        compute_evd_complex(
1880            mat.as_ref(),
1881            s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
1882            Some(u.as_mut()),
1883            Parallelism::None,
1884            make_stack!(compute_evd_req::<c64>(
1885                n,
1886                ComputeVectors::Yes,
1887                Parallelism::None,
1888                Default::default(),
1889            )),
1890            Default::default(),
1891        );
1892
1893        let left = &mat * &u;
1894        let right = &u * &s;
1895
1896        for j in 0..n {
1897            for i in 0..n {
1898                assert_approx_eq!(left.read(i, j), right.read(i, j), 1e-10);
1899            }
1900        }
1901    }
1902}