Skip to main content

oxiblas_lapack/evd/
schur.rs

1//! Schur decomposition.
2//!
3//! Computes the Schur decomposition A = Q T Q^T where Q is orthogonal
4//! and T is quasi-upper triangular (upper triangular with possible 2×2 blocks
5//! on the diagonal for complex eigenvalue pairs).
6
7use oxiblas_core::scalar::{Field, Real, Scalar};
8use oxiblas_matrix::{Mat, MatRef};
9
10use super::hessenberg::Hessenberg;
11
/// Failure modes of the Schur decomposition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurError {
    /// The input matrix has zero rows or zero columns.
    EmptyMatrix,
    /// The input matrix is not square.
    NotSquare,
    /// The iterative algorithm did not converge.
    NotConverged,
}

impl core::fmt::Display for SchurError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // One fixed, human-readable message per variant.
        let message = match *self {
            SchurError::EmptyMatrix => "Matrix is empty",
            SchurError::NotSquare => "Matrix must be square",
            SchurError::NotConverged => "Schur decomposition did not converge",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SchurError {}
34
35/// Represents a real or complex eigenvalue.
36#[derive(Debug, Clone, Copy, PartialEq)]
37pub struct Eigenvalue<T> {
38    /// Real part of the eigenvalue.
39    pub real: T,
40    /// Imaginary part of the eigenvalue (zero for real eigenvalues).
41    pub imag: T,
42}
43
44impl<T: Scalar> Eigenvalue<T> {
45    /// Creates a real eigenvalue.
46    pub fn real_only(value: T) -> Self {
47        Self {
48            real: value,
49            imag: T::zero(),
50        }
51    }
52
53    /// Creates a complex eigenvalue.
54    pub fn complex(real: T, imag: T) -> Self {
55        Self { real, imag }
56    }
57
58    /// Returns true if this is a real eigenvalue.
59    pub fn is_real(&self) -> bool {
60        self.imag == T::zero()
61    }
62}
63
/// Schur decomposition of a matrix.
///
/// For a matrix A, computes A = Q T Q^T where:
/// - Q is orthogonal (Q^T Q = I)
/// - T is quasi-upper triangular (real Schur form): upper triangular except
///   for 2×2 diagonal blocks that carry complex conjugate eigenvalue pairs
///
/// Instances are produced by `Schur::compute`; `Schur::reconstruct` forms
/// Q T Q^T from the stored factors.
#[derive(Debug, Clone)]
pub struct Schur<T: Scalar> {
    /// The orthogonal matrix Q (Schur vectors), `n × n`.
    q: Mat<T>,
    /// The quasi-upper triangular matrix T, `n × n`.
    t: Mat<T>,
    /// Eigenvalues (real values and complex conjugate pairs), one entry per row of T.
    eigenvalues: Vec<Eigenvalue<T>>,
    /// Matrix dimension (number of rows and columns of `q` and `t`).
    n: usize,
}
80
81impl<T: Field + Real + bytemuck::Zeroable> Schur<T> {
82    /// Maximum iterations for QR iteration.
83    const MAX_ITERATIONS: usize = 100;
84
85    /// Computes the Schur decomposition of a square matrix.
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// use oxiblas_lapack::evd::Schur;
91    /// use oxiblas_matrix::Mat;
92    ///
93    /// let a = Mat::from_rows(&[
94    ///     &[1.0f64, 2.0],
95    ///     &[0.0, 3.0],
96    /// ]);
97    ///
98    /// let schur = Schur::compute(a.as_ref()).unwrap();
99    /// let eigenvalues = schur.eigenvalues();
100    ///
101    /// // Eigenvalues are 1 and 3
102    /// assert!((eigenvalues[0].real - 1.0).abs() < 1e-10 || (eigenvalues[0].real - 3.0).abs() < 1e-10);
103    /// ```
104    pub fn compute(a: MatRef<'_, T>) -> Result<Self, SchurError> {
105        let m = a.nrows();
106        let n = a.ncols();
107
108        if m == 0 || n == 0 {
109            return Err(SchurError::EmptyMatrix);
110        }
111
112        if m != n {
113            return Err(SchurError::NotSquare);
114        }
115
116        // Handle 1×1 case
117        if n == 1 {
118            let mut t = Mat::zeros(1, 1);
119            t[(0, 0)] = a[(0, 0)];
120            let mut q = Mat::zeros(1, 1);
121            q[(0, 0)] = T::one();
122            let eigenvalues = vec![Eigenvalue::real_only(a[(0, 0)])];
123            return Ok(Self {
124                q,
125                t,
126                eigenvalues,
127                n,
128            });
129        }
130
131        // Handle 2×2 case
132        if n == 2 {
133            return Self::compute_2x2(a);
134        }
135
136        // Step 1: Reduce to upper Hessenberg form
137        let hess = Hessenberg::compute(a).map_err(|_| SchurError::NotSquare)?;
138        let mut t = Mat::zeros(n, n);
139        let h = hess.h();
140        for i in 0..n {
141            for j in 0..n {
142                t[(i, j)] = h[(i, j)];
143            }
144        }
145
146        let mut q = Mat::zeros(n, n);
147        let q_hess = hess.q();
148        for i in 0..n {
149            for j in 0..n {
150                q[(i, j)] = q_hess[(i, j)];
151            }
152        }
153
154        // Step 2: Apply QR iteration with implicit shifts
155        let eps = <T as Scalar>::epsilon();
156        let tol = eps * T::from_f64(100.0).unwrap_or(T::one());
157
158        // Process from bottom to top
159        let mut p = n;
160        let mut iter_count = 0;
161
162        while p > 2 && iter_count < Self::MAX_ITERATIONS * n {
163            iter_count += 1;
164
165            // Find the active block
166            let mut q_idx = p - 1;
167            while q_idx > 0 {
168                let sub = Scalar::abs(t[(q_idx, q_idx - 1)]);
169                let diag_sum =
170                    Scalar::abs(t[(q_idx - 1, q_idx - 1)]) + Scalar::abs(t[(q_idx, q_idx)]);
171                if sub <= tol * diag_sum {
172                    t[(q_idx, q_idx - 1)] = T::zero();
173                    break;
174                }
175                q_idx -= 1;
176            }
177
178            if q_idx == p - 1 {
179                // 1×1 block converged
180                p -= 1;
181            } else if q_idx == p - 2 {
182                // Check if 2×2 block has converged (complex eigenvalues)
183                let a11 = t[(p - 2, p - 2)];
184                let a12 = t[(p - 2, p - 1)];
185                let a21 = t[(p - 1, p - 2)];
186                let a22 = t[(p - 1, p - 1)];
187                let trace = a11 + a22;
188                let det = a11 * a22 - a12 * a21;
189                let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
190
191                if disc < T::zero() {
192                    // Complex eigenvalues, keep as 2×2 block
193                    p -= 2;
194                } else {
195                    // Real eigenvalues, continue iteration
196                    Self::francis_qr_step(&mut t, &mut q, q_idx, p);
197                }
198            } else {
199                // Apply Francis QR step
200                Self::francis_qr_step(&mut t, &mut q, q_idx, p);
201            }
202        }
203
204        // Handle remaining 2×2 block if needed
205        if p == 2 {
206            let sub = Scalar::abs(t[(1, 0)]);
207            let diag_sum = Scalar::abs(t[(0, 0)]) + Scalar::abs(t[(1, 1)]);
208            if sub <= tol * diag_sum {
209                t[(1, 0)] = T::zero();
210            }
211        }
212
213        // Extract eigenvalues
214        let eigenvalues = Self::extract_eigenvalues(&t);
215
216        Ok(Self {
217            q,
218            t,
219            eigenvalues,
220            n,
221        })
222    }
223
224    /// Computes Schur decomposition for 2×2 matrix.
225    fn compute_2x2(a: MatRef<'_, T>) -> Result<Self, SchurError> {
226        let mut t = Mat::zeros(2, 2);
227        for i in 0..2 {
228            for j in 0..2 {
229                t[(i, j)] = a[(i, j)];
230            }
231        }
232
233        let a11 = a[(0, 0)];
234        let a12 = a[(0, 1)];
235        let a21 = a[(1, 0)];
236        let a22 = a[(1, 1)];
237
238        let trace = a11 + a22;
239        let det = a11 * a22 - a12 * a21;
240        let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
241
242        let mut q = Mat::zeros(2, 2);
243        let eigenvalues: Vec<Eigenvalue<T>>;
244
245        if disc >= T::zero() {
246            // Real eigenvalues - triangularize
247            let sqrt_disc = Real::sqrt(disc);
248            let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
249            let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
250
251            // Find rotation to triangularize
252            if Scalar::abs(a21) > <T as Scalar>::epsilon() {
253                let theta = if Scalar::abs(a11 - lambda1) > <T as Scalar>::epsilon() {
254                    Real::atan2(a21, a11 - lambda1)
255                } else {
256                    T::from_f64(core::f64::consts::FRAC_PI_2).unwrap_or_else(T::zero)
257                };
258                let c = Real::cos(theta);
259                let s = Real::sin(theta);
260
261                q[(0, 0)] = c;
262                q[(0, 1)] = -s;
263                q[(1, 0)] = s;
264                q[(1, 1)] = c;
265
266                // T = Q^T * A * Q
267                let mut temp = Mat::zeros(2, 2);
268                // temp = Q^T * A
269                for i in 0..2 {
270                    for j in 0..2 {
271                        let mut sum = T::zero();
272                        for k in 0..2 {
273                            sum = sum + q[(k, i)] * a[(k, j)];
274                        }
275                        temp[(i, j)] = sum;
276                    }
277                }
278                // t = temp * Q
279                for i in 0..2 {
280                    for j in 0..2 {
281                        let mut sum = T::zero();
282                        for k in 0..2 {
283                            sum = sum + temp[(i, k)] * q[(k, j)];
284                        }
285                        t[(i, j)] = sum;
286                    }
287                }
288            } else {
289                // Already triangular
290                q[(0, 0)] = T::one();
291                q[(1, 1)] = T::one();
292            }
293
294            eigenvalues = vec![
295                Eigenvalue::real_only(lambda1),
296                Eigenvalue::real_only(lambda2),
297            ];
298        } else {
299            // Complex eigenvalues - keep as 2×2 block
300            let sqrt_disc = Real::sqrt(-disc);
301            let real_part = trace / T::from_f64(2.0).unwrap_or_else(T::zero);
302            let imag_part = sqrt_disc / T::from_f64(2.0).unwrap_or_else(T::zero);
303
304            q[(0, 0)] = T::one();
305            q[(1, 1)] = T::one();
306
307            eigenvalues = vec![
308                Eigenvalue::complex(real_part, imag_part),
309                Eigenvalue::complex(real_part, -imag_part),
310            ];
311        }
312
313        Ok(Self {
314            q,
315            t,
316            eigenvalues,
317            n: 2,
318        })
319    }
320
321    /// Applies one step of Francis double-shift QR iteration.
322    fn francis_qr_step(t: &mut Mat<T>, q: &mut Mat<T>, start: usize, end: usize) {
323        let n = t.nrows();
324
325        if end - start < 2 {
326            return;
327        }
328
329        // Compute shift from bottom 2×2 block
330        let h11 = t[(end - 2, end - 2)];
331        let h12 = t[(end - 2, end - 1)];
332        let h21 = t[(end - 1, end - 2)];
333        let h22 = t[(end - 1, end - 1)];
334
335        let s = h11 + h22; // trace
336        let p = h11 * h22 - h12 * h21; // determinant
337
338        // First column of (H - s1*I)(H - s2*I) = H² - s*H + p*I
339        let h_00 = t[(start, start)];
340        let h_01 = t[(start, start + 1)];
341        let h_10 = t[(start + 1, start)];
342
343        let mut x = h_00 * h_00 + h_01 * h_10 - s * h_00 + p;
344        let mut y = h_10 * (h_00 + t[(start + 1, start + 1)] - s);
345        let mut z = if start + 2 < end {
346            h_10 * t[(start + 2, start + 1)]
347        } else {
348            T::zero()
349        };
350
351        // Chase the bulge
352        for k in start..end.saturating_sub(1) {
353            // Compute Householder to zero out y, z
354            let (v, tau) = householder_3(&[x, y, z]);
355
356            if tau != T::zero() {
357                let r = if k > start { k - 1 } else { k };
358
359                // Apply from left: T := (I - tau * v * v^T) * T
360                let col_start = r;
361                let col_end = n;
362                for j in col_start..col_end {
363                    let rows = (k..(k + 3).min(end)).collect::<Vec<_>>();
364                    let mut dot = T::zero();
365                    for (vi, &row) in rows.iter().enumerate() {
366                        dot = dot + v[vi] * t[(row, j)];
367                    }
368                    let scaled = tau * dot;
369                    for (vi, &row) in rows.iter().enumerate() {
370                        t[(row, j)] = t[(row, j)] - scaled * v[vi];
371                    }
372                }
373
374                // Apply from right: T := T * (I - tau * v * v^T)
375                let row_end = (k + 4).min(end);
376                for i in 0..row_end {
377                    let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
378                    let mut dot = T::zero();
379                    for (vi, &col) in cols.iter().enumerate() {
380                        dot = dot + t[(i, col)] * v[vi];
381                    }
382                    let scaled = tau * dot;
383                    for (vi, &col) in cols.iter().enumerate() {
384                        t[(i, col)] = t[(i, col)] - scaled * v[vi];
385                    }
386                }
387
388                // Accumulate Q
389                for i in 0..n {
390                    let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
391                    let mut dot = T::zero();
392                    for (vi, &col) in cols.iter().enumerate() {
393                        dot = dot + q[(i, col)] * v[vi];
394                    }
395                    let scaled = tau * dot;
396                    for (vi, &col) in cols.iter().enumerate() {
397                        q[(i, col)] = q[(i, col)] - scaled * v[vi];
398                    }
399                }
400            }
401
402            // Prepare for next iteration
403            if k + 3 < end {
404                x = t[(k + 1, k)];
405                y = t[(k + 2, k)];
406                z = if k + 3 < end {
407                    t[(k + 3, k)]
408                } else {
409                    T::zero()
410                };
411            } else if k + 2 < end {
412                // 2×2 Householder for the last step
413                x = t[(k + 1, k)];
414                y = t[(k + 2, k)];
415                z = T::zero();
416            }
417        }
418    }
419
420    /// Extracts eigenvalues from the quasi-upper triangular Schur form.
421    fn extract_eigenvalues(t: &Mat<T>) -> Vec<Eigenvalue<T>> {
422        let n = t.nrows();
423        let mut eigenvalues = Vec::with_capacity(n);
424        let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
425
426        let mut i = 0;
427        while i < n {
428            if i == n - 1 {
429                // Last element is a 1×1 block
430                eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
431                i += 1;
432            } else {
433                // Check if this is a 2×2 block
434                let sub = Scalar::abs(t[(i + 1, i)]);
435                let diag_sum = Scalar::abs(t[(i, i)]) + Scalar::abs(t[(i + 1, i + 1)]);
436
437                if sub <= eps * diag_sum {
438                    // 1×1 block
439                    eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
440                    i += 1;
441                } else {
442                    // 2×2 block - compute eigenvalues
443                    let a11 = t[(i, i)];
444                    let a12 = t[(i, i + 1)];
445                    let a21 = t[(i + 1, i)];
446                    let a22 = t[(i + 1, i + 1)];
447
448                    let trace = a11 + a22;
449                    let det = a11 * a22 - a12 * a21;
450                    let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
451
452                    if disc >= T::zero() {
453                        // Real eigenvalues
454                        let sqrt_disc = Real::sqrt(disc);
455                        let lambda1 =
456                            (trace + sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
457                        let lambda2 =
458                            (trace - sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
459                        eigenvalues.push(Eigenvalue::real_only(lambda1));
460                        eigenvalues.push(Eigenvalue::real_only(lambda2));
461                    } else {
462                        // Complex conjugate pair
463                        let sqrt_disc = Real::sqrt(-disc);
464                        let real_part = trace / T::from_f64(2.0).unwrap_or_else(T::zero);
465                        let imag_part = sqrt_disc / T::from_f64(2.0).unwrap_or_else(T::zero);
466                        eigenvalues.push(Eigenvalue::complex(real_part, imag_part));
467                        eigenvalues.push(Eigenvalue::complex(real_part, -imag_part));
468                    }
469                    i += 2;
470                }
471            }
472        }
473
474        eigenvalues
475    }
476
477    /// Returns the orthogonal matrix Q (Schur vectors).
478    pub fn q(&self) -> MatRef<'_, T> {
479        self.q.as_ref()
480    }
481
482    /// Returns the quasi-upper triangular matrix T (Schur form).
483    pub fn t(&self) -> MatRef<'_, T> {
484        self.t.as_ref()
485    }
486
487    /// Returns the eigenvalues.
488    pub fn eigenvalues(&self) -> &[Eigenvalue<T>] {
489        &self.eigenvalues
490    }
491
492    /// Returns only the real parts of eigenvalues.
493    pub fn eigenvalues_real(&self) -> Vec<T> {
494        self.eigenvalues.iter().map(|e| e.real).collect()
495    }
496
497    /// Reconstructs the original matrix: A = Q T Q^T.
498    pub fn reconstruct(&self) -> Mat<T> {
499        let mut qt = Mat::zeros(self.n, self.n);
500        let mut a = Mat::zeros(self.n, self.n);
501
502        // QT = Q * T
503        for i in 0..self.n {
504            for j in 0..self.n {
505                let mut sum = T::zero();
506                for k in 0..self.n {
507                    sum = sum + self.q[(i, k)] * self.t[(k, j)];
508                }
509                qt[(i, j)] = sum;
510            }
511        }
512
513        // A = QT * Q^T
514        for i in 0..self.n {
515            for j in 0..self.n {
516                let mut sum = T::zero();
517                for k in 0..self.n {
518                    sum = sum + qt[(i, k)] * self.q[(j, k)];
519                }
520                a[(i, j)] = sum;
521            }
522        }
523
524        a
525    }
526
527    /// Computes right eigenvectors from the Schur form (LAPACK TREVC).
528    ///
529    /// Returns the right eigenvectors V such that T V = V D where D is the
530    /// diagonal matrix of eigenvalues. The eigenvectors are normalized.
531    ///
532    /// For real eigenvalues, returns real eigenvectors.
533    /// For complex conjugate pairs, returns two columns: the real and imaginary
534    /// parts of the eigenvector.
535    ///
536    /// # Example
537    ///
538    /// ```
539    /// use oxiblas_lapack::evd::Schur;
540    /// use oxiblas_matrix::Mat;
541    ///
542    /// let a = Mat::from_rows(&[
543    ///     &[1.0f64, 2.0, 3.0],
544    ///     &[0.0, 4.0, 5.0],
545    ///     &[0.0, 0.0, 6.0],
546    /// ]);
547    ///
548    /// let schur = Schur::compute(a.as_ref()).unwrap();
549    /// let vr = schur.right_eigenvectors();
550    /// // Each column of vr is a right eigenvector
551    /// ```
552    #[must_use]
553    pub fn right_eigenvectors(&self) -> Mat<T> {
554        trevc_right(&self.t)
555    }
556
557    /// Computes left eigenvectors from the Schur form (LAPACK TREVC).
558    ///
559    /// Returns the left eigenvectors U such that U^T T = D U^T where D is the
560    /// diagonal matrix of eigenvalues. The eigenvectors are normalized.
561    ///
562    /// # Example
563    ///
564    /// ```
565    /// use oxiblas_lapack::evd::Schur;
566    /// use oxiblas_matrix::Mat;
567    ///
568    /// let a = Mat::from_rows(&[
569    ///     &[1.0f64, 2.0, 3.0],
570    ///     &[0.0, 4.0, 5.0],
571    ///     &[0.0, 0.0, 6.0],
572    /// ]);
573    ///
574    /// let schur = Schur::compute(a.as_ref()).unwrap();
575    /// let vl = schur.left_eigenvectors();
576    /// ```
577    #[must_use]
578    pub fn left_eigenvectors(&self) -> Mat<T> {
579        trevc_left(&self.t)
580    }
581
582    /// Computes eigenvectors of the original matrix A = Q T Q^T.
583    ///
584    /// The eigenvectors of A are Q * V where V are the eigenvectors of T.
585    ///
586    /// # Returns
587    ///
588    /// (right_eigenvectors, left_eigenvectors) of A.
589    #[must_use]
590    pub fn eigenvectors(&self) -> (Mat<T>, Mat<T>) {
591        let vr_t = trevc_right(&self.t);
592        let vl_t = trevc_left(&self.t);
593
594        // Transform to eigenvectors of A: A_vr = Q * T_vr, A_vl = Q * T_vl
595        let mut vr_a = Mat::zeros(self.n, self.n);
596        let mut vl_a = Mat::zeros(self.n, self.n);
597
598        for i in 0..self.n {
599            for j in 0..self.n {
600                let mut sum_r = T::zero();
601                let mut sum_l = T::zero();
602                for k in 0..self.n {
603                    sum_r = sum_r + self.q[(i, k)] * vr_t[(k, j)];
604                    sum_l = sum_l + self.q[(i, k)] * vl_t[(k, j)];
605                }
606                vr_a[(i, j)] = sum_r;
607                vl_a[(i, j)] = sum_l;
608            }
609        }
610
611        (vr_a, vl_a)
612    }
613
614    /// Computes reciprocal condition numbers for eigenvalues (LAPACK DTRSNA).
615    ///
616    /// For each eigenvalue λ, the reciprocal condition number s measures how
617    /// sensitive λ is to perturbations in the matrix. A small value indicates
618    /// a poorly conditioned eigenvalue.
619    ///
620    /// The condition number is computed as:
621    /// - For simple eigenvalues: s = 1 / |y^H x| where x is the right eigenvector
622    ///   and y is the left eigenvector, both normalized to unit length.
623    /// - For complex conjugate pairs: uses the average of the pair.
624    ///
625    /// # Returns
626    ///
627    /// Vector of reciprocal condition numbers, one per eigenvalue.
628    /// Smaller values indicate more sensitive eigenvalues.
629    ///
630    /// # Example
631    ///
632    /// ```
633    /// use oxiblas_lapack::evd::Schur;
634    /// use oxiblas_matrix::Mat;
635    ///
636    /// let a = Mat::from_rows(&[
637    ///     &[1.0f64, 0.0],
638    ///     &[0.0, 1000.0],
639    /// ]);
640    ///
641    /// let schur = Schur::compute(a.as_ref()).unwrap();
642    /// let cond = schur.eigenvalue_condition_numbers();
643    /// // Both eigenvalues of a diagonal matrix are well-conditioned
644    /// assert!(cond[0] > 0.9);
645    /// assert!(cond[1] > 0.9);
646    /// ```
647    #[must_use]
648    pub fn eigenvalue_condition_numbers(&self) -> Vec<T> {
649        trsna_s(&self.t)
650    }
651
652    /// Computes reciprocal condition numbers for eigenvectors (LAPACK DTRSNA).
653    ///
654    /// For each right eigenvector x_j, computes the separation sep_j which
655    /// measures how close the eigenvalue is to the rest of the spectrum.
656    ///
657    /// # Returns
658    ///
659    /// Vector of separation values, one per eigenvalue.
660    /// Smaller values indicate more sensitive eigenvectors.
661    #[must_use]
662    pub fn eigenvector_separation(&self) -> Vec<T> {
663        trsna_sep(&self.t)
664    }
665}
666
667/// Computes reciprocal condition numbers for eigenvalues (s values from LAPACK DTRSNA).
668///
669/// For each simple eigenvalue λ_j, s_j = |y_j^H * x_j| where x_j is the right
670/// eigenvector and y_j is the left eigenvector, both normalized.
671///
672/// # Arguments
673///
674/// * `t` - The quasi-upper triangular Schur matrix
675///
676/// # Returns
677///
678/// Vector of reciprocal condition numbers for each eigenvalue.
679pub fn trsna_s<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
680    let n = t.nrows();
681    if n == 0 {
682        return Vec::new();
683    }
684
685    let vr = trevc_right(t);
686    let vl = trevc_left(t);
687
688    let mut s = vec![T::zero(); n];
689    let eps = <T as Scalar>::epsilon();
690
691    let mut j = 0;
692    while j < n {
693        // Check for 2×2 block
694        let is_2x2 = if j + 1 < n {
695            let sub = Scalar::abs(t[(j + 1, j)]);
696            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
697            sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
698        } else {
699            false
700        };
701
702        if is_2x2 {
703            // Complex conjugate pair: compute s = |y^H * x| using both columns
704            // For complex eigenvector stored as (real, imag) in consecutive columns:
705            // y^H * x = (yr - i*yi)^T * (xr + i*xi) = (yr^T*xr + yi^T*xi) + i*(yr^T*xi - yi^T*xr)
706            let jp1 = j + 1;
707
708            let mut prod_rr = T::zero(); // yr^T * xr
709            let mut prod_ii = T::zero(); // yi^T * xi
710            let mut prod_ri = T::zero(); // yr^T * xi
711            let mut prod_ir = T::zero(); // yi^T * xr
712
713            for k in 0..n {
714                prod_rr = prod_rr + vl[(k, j)] * vr[(k, j)];
715                prod_ii = prod_ii + vl[(k, jp1)] * vr[(k, jp1)];
716                prod_ri = prod_ri + vl[(k, j)] * vr[(k, jp1)];
717                prod_ir = prod_ir + vl[(k, jp1)] * vr[(k, j)];
718            }
719
720            let real_part = prod_rr + prod_ii;
721            let imag_part = prod_ri - prod_ir;
722            let abs_inner = Real::sqrt(real_part * real_part + imag_part * imag_part);
723
724            // Both eigenvalues in the pair have the same condition number
725            s[j] = abs_inner;
726            s[jp1] = abs_inner;
727
728            j += 2;
729        } else {
730            // Simple real eigenvalue: s = |y^T * x|
731            let mut inner = T::zero();
732            for k in 0..n {
733                inner = inner + vl[(k, j)] * vr[(k, j)];
734            }
735            s[j] = Scalar::abs(inner);
736            j += 1;
737        }
738    }
739
740    s
741}
742
743/// Computes separation (sep) for eigenvectors (from LAPACK DTRSNA).
744///
745/// For each eigenvalue λ_j, sep_j = σ_min(T_22 - λ_j * I) where T_22 is the
746/// (n-1)×(n-1) trailing principal submatrix with λ_j removed.
747///
748/// This measures how separated λ_j is from the rest of the spectrum.
749///
750/// # Arguments
751///
752/// * `t` - The quasi-upper triangular Schur matrix
753///
754/// # Returns
755///
756/// Vector of separation values for each eigenvalue.
757pub fn trsna_sep<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
758    let n = t.nrows();
759    if n == 0 {
760        return Vec::new();
761    }
762
763    let mut sep = vec![T::zero(); n];
764    let eps = <T as Scalar>::epsilon();
765
766    let mut j = 0;
767    while j < n {
768        // Check for 2×2 block
769        let is_2x2 = if j + 1 < n {
770            let sub = Scalar::abs(t[(j + 1, j)]);
771            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
772            sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
773        } else {
774            false
775        };
776
777        if is_2x2 {
778            // Complex conjugate pair
779            let jp1 = j + 1;
780
781            // Compute approximate separation as minimum distance to other eigenvalues
782            let a11 = t[(j, j)];
783            let a22 = t[(jp1, jp1)];
784            let lambda_real = (a11 + a22) / T::from_f64(2.0).unwrap_or_else(T::zero);
785            let trace = a11 + a22;
786            let det = a11 * a22 - t[(j, jp1)] * t[(jp1, j)];
787            let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
788            let lambda_imag = if disc < T::zero() {
789                Real::sqrt(-disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
790            } else {
791                T::zero()
792            };
793
794            let mut min_sep = T::one() / eps;
795
796            // Check distance to all other eigenvalues
797            let mut k = 0;
798            while k < n {
799                if k == j || k == jp1 {
800                    k += 1;
801                    continue;
802                }
803
804                // Check if k is part of a 2×2 block
805                let adjacent = j > 0 && k == j - 1;
806                let k_is_2x2 = if k + 1 < n && !adjacent {
807                    let sub = Scalar::abs(t[(k + 1, k)]);
808                    let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
809                    sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
810                } else {
811                    false
812                };
813
814                let (other_real, other_imag) = if k_is_2x2 {
815                    let kp1 = k + 1;
816                    let b11 = t[(k, k)];
817                    let b22 = t[(kp1, kp1)];
818                    let other_trace = b11 + b22;
819                    let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
820                    let other_disc = other_trace * other_trace
821                        - T::from_f64(4.0).unwrap_or_else(T::zero) * other_det;
822                    let r = (b11 + b22) / T::from_f64(2.0).unwrap_or_else(T::zero);
823                    let i = if other_disc < T::zero() {
824                        Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
825                    } else {
826                        T::zero()
827                    };
828                    (r, i)
829                } else {
830                    (t[(k, k)], T::zero())
831                };
832
833                // Distance between eigenvalues
834                let dr = lambda_real - other_real;
835                let di = lambda_imag - other_imag;
836                let dist = Real::sqrt(dr * dr + di * di);
837                if dist < min_sep && dist > T::zero() {
838                    min_sep = dist;
839                }
840
841                // Also check conjugate
842                if other_imag != T::zero() {
843                    let di_conj = lambda_imag + other_imag;
844                    let dist_conj = Real::sqrt(dr * dr + di_conj * di_conj);
845                    if dist_conj < min_sep && dist_conj > T::zero() {
846                        min_sep = dist_conj;
847                    }
848                }
849
850                if k_is_2x2 {
851                    k += 2;
852                } else {
853                    k += 1;
854                }
855            }
856
857            sep[j] = min_sep;
858            sep[jp1] = min_sep;
859            j += 2;
860        } else {
861            // Simple real eigenvalue
862            let lambda = t[(j, j)];
863            let mut min_sep = T::one() / eps;
864
865            // Check distance to all other eigenvalues
866            let mut k = 0;
867            while k < n {
868                if k == j {
869                    k += 1;
870                    continue;
871                }
872
873                // Check if k is part of a 2×2 block
874                let adjacent_to_j = (j > 0 && k == j - 1) || k == j + 1;
875                let k_is_2x2 = if k + 1 < n && !adjacent_to_j {
876                    let sub = Scalar::abs(t[(k + 1, k)]);
877                    let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
878                    sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
879                } else {
880                    false
881                };
882
883                let (other_real, other_imag) = if k_is_2x2 {
884                    let kp1 = k + 1;
885                    let b11 = t[(k, k)];
886                    let b22 = t[(kp1, kp1)];
887                    let other_trace = b11 + b22;
888                    let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
889                    let other_disc = other_trace * other_trace
890                        - T::from_f64(4.0).unwrap_or_else(T::zero) * other_det;
891                    let r = (b11 + b22) / T::from_f64(2.0).unwrap_or_else(T::zero);
892                    let i = if other_disc < T::zero() {
893                        Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
894                    } else {
895                        T::zero()
896                    };
897                    (r, i)
898                } else {
899                    (t[(k, k)], T::zero())
900                };
901
902                // Distance between eigenvalues
903                let dr = lambda - other_real;
904                let dist = Real::sqrt(dr * dr + other_imag * other_imag);
905                if dist < min_sep && dist > T::zero() {
906                    min_sep = dist;
907                }
908
909                if k_is_2x2 {
910                    k += 2;
911                } else {
912                    k += 1;
913                }
914            }
915
916            sep[j] = min_sep;
917            j += 1;
918        }
919    }
920
921    sep
922}
923
924/// Computes right eigenvectors of a quasi-upper triangular matrix (LAPACK DTREVC).
925///
926/// The input T must be in real Schur form (quasi-upper triangular with 1×1 and
927/// 2×2 diagonal blocks).
928///
929/// # Arguments
930///
931/// * `t` - The quasi-upper triangular Schur matrix
932///
933/// # Returns
934///
935/// Matrix V of right eigenvectors (column-wise). For complex conjugate pairs,
936/// consecutive columns contain the real and imaginary parts.
937pub fn trevc_right<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
938    let n = t.nrows();
939    let mut v = Mat::zeros(n, n);
940
941    // Initialize to identity for back-substitution starting point
942    for i in 0..n {
943        v[(i, i)] = T::one();
944    }
945
946    let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
947
948    // Process eigenvalues from last to first
949    let mut j = n;
950    while j > 0 {
951        j -= 1;
952
953        // Check if this is part of a 2×2 block
954        let is_2x2 = if j > 0 {
955            let sub = Scalar::abs(t[(j, j - 1)]);
956            let diag_sum = Scalar::abs(t[(j - 1, j - 1)]) + Scalar::abs(t[(j, j)]);
957            sub > eps * diag_sum
958        } else {
959            false
960        };
961
962        if is_2x2 {
963            // 2×2 block: complex conjugate eigenvalues
964            // Process columns j-1 and j together
965            let jm1 = j - 1;
966
967            // Eigenvalues of 2×2 block
968            let a11 = t[(jm1, jm1)];
969            let a12 = t[(jm1, j)];
970            let a21 = t[(j, jm1)];
971            let a22 = t[(j, j)];
972
973            let trace = a11 + a22;
974            let det = a11 * a22 - a12 * a21;
975            let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
976
977            // Complex eigenvalues: λ = (trace ± i*sqrt(-disc)) / 2
978            let two = T::from_f64(2.0).unwrap_or_else(T::zero);
979            let real_part = trace / two;
980            let imag_part = Real::sqrt(-disc) / two;
981
982            // For the 2×2 block, set up eigenvector components
983            // v[jm1] = real part, v[j] = imag part for first eigenvector
984            v[(jm1, jm1)] = T::one();
985            v[(j, jm1)] = T::zero();
986            v[(jm1, j)] = T::zero();
987            v[(j, j)] = T::one();
988
989            // Compute actual eigenvector of 2×2 block
990            // (a11 - λ) v1 + a12 v2 = 0
991            // a21 v1 + (a22 - λ) v2 = 0
992            // For λ = real_part + i*imag_part:
993            // Let v = vr + i*vi
994            // Then: (T - λI)(vr + i*vi) = 0
995            // Real: (T - real_part*I)vr + imag_part*vi = 0
996            // Imag: (T - real_part*I)vi - imag_part*vr = 0
997
998            // For the 2×2 block itself, find normalized eigenvector
999            // The eigenvector satisfies (T - λI)v = 0 where λ = real_part + i*imag_part
1000            // For complex eigenvector v = vr + i*vi:
1001            // (T - real_part*I)vr + imag_part*vi = 0  (real part)
1002            // (T - real_part*I)vi - imag_part*vr = 0  (imag part)
1003            //
1004            // For the 2×2 case, we can set v2 = 1 + 0i and solve for v1
1005            // From row 2: a21*v1r + (a22-real)*v2r + imag*v2i = 0 (real)
1006            //             a21*v1i + (a22-real)*v2i - imag*v2r = 0 (imag)
1007            // With v2r=1, v2i=0:
1008            //   a21*v1r + (a22-real) = 0  =>  v1r = -(a22-real)/a21 = -d22/a21
1009            //   a21*v1i - imag = 0        =>  v1i = imag/a21
1010
1011            // Use the more stable formulation based on LAPACK
1012            // For standardized eigenvector, use row with larger coefficient
1013            let d11 = a11 - real_part;
1014            let d22 = a22 - real_part;
1015
1016            // Choose the row with larger off-diagonal to avoid division by small number
1017            if Scalar::abs(a21) >= Scalar::abs(a12) && Scalar::abs(a21) > eps {
1018                // Use row 2: a21*v1 + d22*v2 = 0 (with v2=1)
1019                // v1r = -d22/a21, v1i = imag/a21
1020                let v1r = -d22 / a21;
1021                let v1i = imag_part / a21;
1022                v[(jm1, jm1)] = v1r;
1023                v[(j, jm1)] = T::one();
1024                v[(jm1, j)] = v1i;
1025                v[(j, j)] = T::zero();
1026            } else if Scalar::abs(a12) > eps {
1027                // Use row 1: d11*v1 + a12*v2 = 0 (with v1=1)
1028                // v2r = -d11/a12, v2i = -imag/a12
1029                let v2r = -d11 / a12;
1030                let v2i = -imag_part / a12;
1031                v[(jm1, jm1)] = T::one();
1032                v[(j, jm1)] = v2r;
1033                v[(jm1, j)] = T::zero();
1034                v[(j, j)] = v2i;
1035            } else {
1036                // Fallback to standard basis
1037                v[(jm1, jm1)] = T::one();
1038                v[(j, jm1)] = T::zero();
1039                v[(jm1, j)] = T::zero();
1040                v[(j, j)] = T::one();
1041            }
1042
1043            // Back-substitute for rows above the 2×2 block
1044            for i in (0..jm1).rev() {
1045                // Solve for v[i,jm1] and v[i,j] (real and imag parts)
1046                // (T[i,i] - real_part) * vr[i] + imag_part * vi[i] = -sum of upper terms (real)
1047                // (T[i,i] - real_part) * vi[i] - imag_part * vr[i] = -sum of upper terms (imag)
1048
1049                let mut sum_r = T::zero();
1050                let mut sum_i = T::zero();
1051                for k in (i + 1)..=j {
1052                    sum_r = sum_r + t[(i, k)] * v[(k, jm1)];
1053                    sum_i = sum_i + t[(i, k)] * v[(k, j)];
1054                }
1055
1056                let d = t[(i, i)] - real_part;
1057                let det_2x2 = d * d + imag_part * imag_part;
1058
1059                if Scalar::abs(det_2x2) > eps {
1060                    // Solve 2×2 system:
1061                    // [d, imag] [vr]   [-sum_r]
1062                    // [-imag, d] [vi] = [-sum_i]
1063                    v[(i, jm1)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1064                    v[(i, j)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1065                }
1066            }
1067
1068            // Normalize the two columns
1069            let mut norm_r_sq = T::zero();
1070            let mut norm_i_sq = T::zero();
1071            for i in 0..n {
1072                norm_r_sq = norm_r_sq + v[(i, jm1)] * v[(i, jm1)];
1073                norm_i_sq = norm_i_sq + v[(i, j)] * v[(i, j)];
1074            }
1075            let norm = Real::sqrt(norm_r_sq + norm_i_sq);
1076            if norm > T::zero() {
1077                for i in 0..n {
1078                    v[(i, jm1)] = v[(i, jm1)] / norm;
1079                    v[(i, j)] = v[(i, j)] / norm;
1080                }
1081            }
1082
1083            j = jm1; // Skip the already-processed column
1084        } else {
1085            // 1×1 block: real eigenvalue
1086            let lambda = t[(j, j)];
1087
1088            // Initialize: v[j,j] = 1, others computed by back-substitution
1089            v[(j, j)] = T::one();
1090
1091            // Back-substitute: (T[i,i] - λ) v[i] = -sum_{k>i} T[i,k] v[k]
1092            for i in (0..j).rev() {
1093                let mut sum = T::zero();
1094                for k in (i + 1)..=j {
1095                    sum = sum + t[(i, k)] * v[(k, j)];
1096                }
1097
1098                let d = t[(i, i)] - lambda;
1099                if Scalar::abs(d) > eps {
1100                    v[(i, j)] = -sum / d;
1101                } else {
1102                    // Near-singular: use small perturbation
1103                    v[(i, j)] = -sum / eps;
1104                }
1105            }
1106
1107            // Normalize
1108            let mut norm_sq = T::zero();
1109            for i in 0..n {
1110                norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1111            }
1112            let norm = Real::sqrt(norm_sq);
1113            if norm > T::zero() {
1114                for i in 0..n {
1115                    v[(i, j)] = v[(i, j)] / norm;
1116                }
1117            }
1118        }
1119    }
1120
1121    v
1122}
1123
1124/// Computes left eigenvectors of a quasi-upper triangular matrix (LAPACK DTREVC).
1125///
1126/// The input T must be in real Schur form. Returns left eigenvectors U such that
1127/// U^T T = D U^T where D is diagonal.
1128///
1129/// # Arguments
1130///
1131/// * `t` - The quasi-upper triangular Schur matrix
1132///
1133/// # Returns
1134///
1135/// Matrix U of left eigenvectors (column-wise).
1136pub fn trevc_left<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
1137    let n = t.nrows();
1138    let mut v = Mat::zeros(n, n);
1139
1140    // Initialize
1141    for i in 0..n {
1142        v[(i, i)] = T::one();
1143    }
1144
1145    let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
1146
1147    // Process eigenvalues from first to last (forward substitution for left eigenvectors)
1148    let mut j = 0;
1149    while j < n {
1150        // Check if this is part of a 2×2 block
1151        let is_2x2 = if j + 1 < n {
1152            let sub = Scalar::abs(t[(j + 1, j)]);
1153            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
1154            sub > eps * diag_sum
1155        } else {
1156            false
1157        };
1158
1159        if is_2x2 {
1160            // 2×2 block: complex conjugate eigenvalues
1161            let jp1 = j + 1;
1162
1163            let a11 = t[(j, j)];
1164            let a12 = t[(j, jp1)];
1165            let a21 = t[(jp1, j)];
1166            let a22 = t[(jp1, jp1)];
1167
1168            let trace = a11 + a22;
1169            let det = a11 * a22 - a12 * a21;
1170            let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
1171
1172            let two = T::from_f64(2.0).unwrap_or_else(T::zero);
1173            let real_part = trace / two;
1174            let imag_part = Real::sqrt(-disc) / two;
1175
1176            // Initialize 2×2 block eigenvector
1177            let d11 = a11 - real_part;
1178            let det_factor = d11 * d11 + imag_part * imag_part;
1179
1180            if Scalar::abs(det_factor) > eps {
1181                let vr2 = -a21 * d11 / det_factor;
1182                let vi2 = -imag_part * a21 / det_factor;
1183                v[(j, j)] = T::one();
1184                v[(jp1, j)] = vr2;
1185                v[(j, jp1)] = T::zero();
1186                v[(jp1, jp1)] = vi2;
1187            } else {
1188                v[(j, j)] = T::one();
1189                v[(jp1, j)] = T::zero();
1190                v[(j, jp1)] = T::zero();
1191                v[(jp1, jp1)] = T::one();
1192            }
1193
1194            // Forward substitute for rows after the 2×2 block
1195            for i in (jp1 + 1)..n {
1196                let mut sum_r = T::zero();
1197                let mut sum_i = T::zero();
1198                for k in j..i {
1199                    sum_r = sum_r + t[(k, i)] * v[(k, j)];
1200                    sum_i = sum_i + t[(k, i)] * v[(k, jp1)];
1201                }
1202
1203                let d = t[(i, i)] - real_part;
1204                let det_2x2 = d * d + imag_part * imag_part;
1205
1206                if Scalar::abs(det_2x2) > eps {
1207                    v[(i, j)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1208                    v[(i, jp1)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1209                }
1210            }
1211
1212            // Normalize
1213            let mut norm_sq = T::zero();
1214            for i in 0..n {
1215                norm_sq = norm_sq + v[(i, j)] * v[(i, j)] + v[(i, jp1)] * v[(i, jp1)];
1216            }
1217            let norm = Real::sqrt(norm_sq);
1218            if norm > T::zero() {
1219                for i in 0..n {
1220                    v[(i, j)] = v[(i, j)] / norm;
1221                    v[(i, jp1)] = v[(i, jp1)] / norm;
1222                }
1223            }
1224
1225            j = jp1 + 1;
1226        } else {
1227            // 1×1 block: real eigenvalue
1228            let lambda = t[(j, j)];
1229            v[(j, j)] = T::one();
1230
1231            // Forward substitute: (T[i,i] - λ) v[i] = -sum_{k<i} T[k,i] v[k]
1232            for i in (j + 1)..n {
1233                let mut sum = T::zero();
1234                for k in j..i {
1235                    sum = sum + t[(k, i)] * v[(k, j)];
1236                }
1237
1238                let d = t[(i, i)] - lambda;
1239                if Scalar::abs(d) > eps {
1240                    v[(i, j)] = -sum / d;
1241                } else {
1242                    v[(i, j)] = -sum / eps;
1243                }
1244            }
1245
1246            // Normalize
1247            let mut norm_sq = T::zero();
1248            for i in 0..n {
1249                norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1250            }
1251            let norm = Real::sqrt(norm_sq);
1252            if norm > T::zero() {
1253                for i in 0..n {
1254                    v[(i, j)] = v[(i, j)] / norm;
1255                }
1256            }
1257
1258            j += 1;
1259        }
1260    }
1261
1262    v
1263}
1264
1265/// Computes a Householder vector for a 3-element (or smaller) vector.
1266fn householder_3<T: Field + Real>(x: &[T]) -> (Vec<T>, T) {
1267    let n = x.len().min(3);
1268    if n == 0 {
1269        return (Vec::new(), T::zero());
1270    }
1271
1272    let mut norm_sq = T::zero();
1273    for i in 0..n {
1274        norm_sq = norm_sq + x[i] * x[i];
1275    }
1276    let norm = Real::sqrt(norm_sq);
1277
1278    if norm == T::zero() {
1279        return (vec![T::zero(); n], T::zero());
1280    }
1281
1282    let mut v = vec![T::zero(); n];
1283    for i in 0..n {
1284        v[i] = x[i];
1285    }
1286
1287    let sign = if x[0] >= T::zero() {
1288        T::one()
1289    } else {
1290        -T::one()
1291    };
1292    v[0] = v[0] + sign * norm;
1293
1294    let mut v_norm_sq = T::zero();
1295    for i in 0..n {
1296        v_norm_sq = v_norm_sq + v[i] * v[i];
1297    }
1298
1299    if v_norm_sq > T::zero() {
1300        let tau = T::from_f64(2.0).unwrap_or_else(T::zero) / v_norm_sq;
1301        (v, tau)
1302    } else {
1303        (v, T::zero())
1304    }
1305}
1306
/// Unit tests for the Schur decomposition, eigenvector back-substitution
/// (`trevc_right` / `trevc_left`) and the condition estimators
/// (`trsna_s` / `trsna_sep`).
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_schur_upper_triangular() {
        // Already upper triangular
        let a = Mat::from_rows(&[&[1.0f64, 2.0], &[0.0, 3.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        // Eigenvalues should be 1 and 3
        let mut eigs: Vec<f64> = eigenvalues.iter().map(|e| e.real).collect();
        eigs.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!(approx_eq(eigs[0], 1.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
    }

    #[test]
    fn test_schur_diagonal() {
        let a = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        let mut eigs: Vec<f64> = eigenvalues.iter().map(|e| e.real).collect();
        eigs.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!(approx_eq(eigs[0], 2.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
        assert!(approx_eq(eigs[2], 5.0, 1e-10));
    }

    #[test]
    fn test_schur_complex_eigenvalues() {
        // Rotation matrix - has complex eigenvalues
        let theta = core::f64::consts::FRAC_PI_4; // 45 degrees
        let c = theta.cos();
        let s = theta.sin();
        let a = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        // Eigenvalues should be cos(θ) ± i*sin(θ)
        assert_eq!(eigenvalues.len(), 2);
        // They should be complex conjugates
        assert!(approx_eq(eigenvalues[0].real, eigenvalues[1].real, 1e-10));
        assert!(approx_eq(eigenvalues[0].imag, -eigenvalues[1].imag, 1e-10));
        assert!(approx_eq(eigenvalues[0].real, c, 1e-10));
        assert!(approx_eq(eigenvalues[0].imag.abs(), s, 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    approx_eq(reconstructed[(i, j)], a[(i, j)], 1e-10),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_q_orthogonal() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let q = schur.q();

        // Check Q^T * Q = I
        let n = 3;
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += q[(k, i)] * q[(k, j)];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(dot, expected, 1e-10),
                    "Q^T*Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    dot,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_schur_single() {
        let a = Mat::from_rows(&[&[5.0f64]]);
        let schur = Schur::compute(a.as_ref()).unwrap();

        assert_eq!(schur.eigenvalues().len(), 1);
        assert!(approx_eq(schur.eigenvalues()[0].real, 5.0, 1e-10));
    }

    #[test]
    fn test_schur_4x4() {
        let a = Mat::from_rows(&[
            &[4.0f64, 1.0, -2.0, 2.0],
            &[1.0, 2.0, 0.0, 1.0],
            &[-2.0, 0.0, 3.0, -2.0],
            &[2.0, 1.0, -2.0, -1.0],
        ]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    approx_eq(reconstructed[(i, j)], a[(i, j)], 1e-8),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_f32() {
        let a = Mat::from_rows(&[&[1.0f32, 2.0], &[3.0, 4.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-4,
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_right_upper_triangular() {
        // Upper triangular matrix - eigenvectors should be standard basis vectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let v = trevc_right(&t);

        // First eigenvector for λ=1 should be proportional to [1, 0, 0]
        assert!(v[(1, 0)].abs() < 1e-10);
        assert!(v[(2, 0)].abs() < 1e-10);

        // Third eigenvector for λ=6 should be proportional to [*, *, 1]
        // (normalized, so third component is non-zero)
        assert!(v[(2, 2)].abs() > 0.1);
    }

    #[test]
    fn test_trevc_right_diagonal() {
        // Diagonal matrix - eigenvectors are standard basis
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let v = trevc_right(&t);

        // Should be identity (or close to it with normalization)
        for i in 0..3 {
            assert!(
                approx_eq(v[(i, i)].abs(), 1.0, 1e-10),
                "v[{},{}] = {}",
                i,
                i,
                v[(i, i)]
            );
        }
    }

    #[test]
    fn test_trevc_eigenvector_equation() {
        // Test that T * v = λ * v for computed eigenvectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let v = trevc_right(&t);

        // Eigenvalues are 1, 4, 6 (diagonal elements)
        let eigenvalues = [1.0, 4.0, 6.0];

        for (j, &lambda) in eigenvalues.iter().enumerate() {
            // Compute T * v[:,j]
            let mut tv = [0.0; 3];
            for i in 0..3 {
                for k in 0..3 {
                    tv[i] += t[(i, k)] * v[(k, j)];
                }
            }

            // Check T * v = λ * v
            for i in 0..3 {
                assert!(
                    approx_eq(tv[i], lambda * v[(i, j)], 1e-10),
                    "T*v[{}] = {}, λ*v = {}",
                    i,
                    tv[i],
                    lambda * v[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_left_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let u = trevc_left(&t);

        // Should be identity
        for i in 0..3 {
            assert!(
                approx_eq(u[(i, i)].abs(), 1.0, 1e-10),
                "u[{},{}] = {}",
                i,
                i,
                u[(i, i)]
            );
        }
    }

    #[test]
    fn test_trevc_left_eigenvector_equation() {
        // Test that u^T * T = λ * u^T for computed left eigenvectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let u = trevc_left(&t);

        // Eigenvalues are 1, 4, 6 (diagonal elements)
        let eigenvalues = [1.0, 4.0, 6.0];

        for (j, &lambda) in eigenvalues.iter().enumerate() {
            for i in 0..3 {
                // (u^T T)_i = Σ_k u_k T[k, i]
                let ut: f64 = (0..3).map(|k| u[(k, j)] * t[(k, i)]).sum();
                assert!(
                    approx_eq(ut, lambda * u[(i, j)], 1e-10),
                    "column {}: (u^T*T)[{}] = {}, λ*u = {}",
                    j,
                    i,
                    ut,
                    lambda * u[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_eigenvectors() {
        // Test eigenvectors through the Schur decomposition
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let (vr, _vl) = schur.eigenvectors();

        // For upper triangular matrix, A = T, so eigenvectors are the same
        // Check that A * v = λ * v
        for j in 0..3 {
            let lambda = schur.eigenvalues()[j].real;

            let mut av = [0.0; 3];
            for i in 0..3 {
                for k in 0..3 {
                    av[i] += a[(i, k)] * vr[(k, j)];
                }
            }

            for i in 0..3 {
                assert!(
                    approx_eq(av[i], lambda * vr[(i, j)], 1e-8),
                    "A*v[{}] = {}, λ*v = {}",
                    i,
                    av[i],
                    lambda * vr[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_2x2_block() {
        // Create a 2x2 block with complex eigenvalues: rotation matrix
        let theta = core::f64::consts::FRAC_PI_4;
        let c = theta.cos();
        let s = theta.sin();
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let v = trevc_right(&t);

        // For complex eigenvalues, columns contain the real and imaginary parts
        // of the eigenvector for λ = c + i*s. Check the pair is normalized.
        let norm0_sq = v[(0, 0)] * v[(0, 0)] + v[(1, 0)] * v[(1, 0)];
        let norm1_sq = v[(0, 1)] * v[(0, 1)] + v[(1, 1)] * v[(1, 1)];
        let total_norm = (norm0_sq + norm1_sq).sqrt();

        assert!(
            approx_eq(total_norm, 1.0, 1e-10),
            "eigenvector norm = {}",
            total_norm
        );

        // (T - λI)(vr + i*vi) = 0 in real arithmetic:
        // T*vr = c*vr - s*vi and T*vi = s*vr + c*vi.
        for i in 0..2 {
            let t_vr = t[(i, 0)] * v[(0, 0)] + t[(i, 1)] * v[(1, 0)];
            let t_vi = t[(i, 0)] * v[(0, 1)] + t[(i, 1)] * v[(1, 1)];
            let expected_r = c * v[(i, 0)] - s * v[(i, 1)];
            let expected_i = s * v[(i, 0)] + c * v[(i, 1)];
            assert!(
                approx_eq(t_vr, expected_r, 1e-10),
                "(T*vr)[{}] = {}, expected {}",
                i,
                t_vr,
                expected_r
            );
            assert!(
                approx_eq(t_vi, expected_i, 1e-10),
                "(T*vi)[{}] = {}, expected {}",
                i,
                t_vi,
                expected_i
            );
        }
    }

    #[test]
    fn test_trsna_s_diagonal() {
        // Diagonal matrix: left and right eigenvectors are standard basis
        // so s = |e_i^T e_i| = 1 for all eigenvalues
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        for i in 0..3 {
            assert!(
                approx_eq(s[i], 1.0, 1e-10),
                "s[{}] = {}, expected 1.0",
                i,
                s[i]
            );
        }
    }

    #[test]
    fn test_trsna_s_upper_triangular() {
        // Upper triangular: eigenvalues are diagonal elements
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        // All condition numbers should be positive
        for i in 0..3 {
            assert!(s[i] > 0.0, "s[{}] = {} should be positive", i, s[i]);
        }
    }

    #[test]
    fn test_trsna_sep_diagonal() {
        // Diagonal matrix: separation is the minimum distance to other eigenvalues
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 3);
        // λ=2: min dist to 3,5 is 1
        assert!(
            approx_eq(sep[0], 1.0, 1e-10),
            "sep[0] = {}, expected 1.0",
            sep[0]
        );
        // λ=5: min dist to 2,3 is 2
        assert!(
            approx_eq(sep[1], 2.0, 1e-10),
            "sep[1] = {}, expected 2.0",
            sep[1]
        );
        // λ=3: min dist to 2,5 is 1
        assert!(
            approx_eq(sep[2], 1.0, 1e-10),
            "sep[2] = {}, expected 1.0",
            sep[2]
        );
    }

    #[test]
    fn test_trsna_sep_close_eigenvalues() {
        // Eigenvalues very close together should have small separation
        let t = Mat::from_rows(&[&[1.0f64, 0.0], &[0.0, 1.001]]);

        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 2);
        // Both should have separation close to 0.001
        assert!(
            approx_eq(sep[0], 0.001, 1e-10),
            "sep[0] = {}, expected 0.001",
            sep[0]
        );
        assert!(
            approx_eq(sep[1], 0.001, 1e-10),
            "sep[1] = {}, expected 0.001",
            sep[1]
        );
    }

    #[test]
    fn test_schur_eigenvalue_condition_numbers() {
        // Test through Schur decomposition
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let cond = schur.eigenvalue_condition_numbers();

        assert_eq!(cond.len(), 2);
        // Diagonal matrix should have well-conditioned eigenvalues
        for i in 0..2 {
            assert!(
                cond[i] > 0.5,
                "cond[{}] = {} should be > 0.5 for diagonal matrix",
                i,
                cond[i]
            );
        }
    }

    #[test]
    fn test_schur_eigenvector_separation() {
        // Test through Schur decomposition
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let sep = schur.eigenvector_separation();

        assert_eq!(sep.len(), 2);
        // Eigenvalues are 2 and 4, so separation should be 2
        for i in 0..2 {
            assert!(
                approx_eq(sep[i], 2.0, 1e-10),
                "sep[{}] = {}, expected 2.0",
                i,
                sep[i]
            );
        }
    }

    #[test]
    fn test_trsna_complex_eigenvalues() {
        // Rotation matrix with complex eigenvalues
        let theta = core::f64::consts::FRAC_PI_4;
        let c = theta.cos();
        let s = theta.sin();
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let cond = trsna_s(&t);
        let sep = trsna_sep(&t);

        assert_eq!(cond.len(), 2);
        assert_eq!(sep.len(), 2);

        // Both eigenvalues in a complex pair should have the same condition
        assert!(
            approx_eq(cond[0], cond[1], 1e-10),
            "cond[0]={}, cond[1]={} should be equal",
            cond[0],
            cond[1]
        );
        assert!(
            approx_eq(sep[0], sep[1], 1e-10),
            "sep[0]={}, sep[1]={} should be equal",
            sep[0],
            sep[1]
        );
    }
}