oxiblas_lapack/evd/
schur.rs

//! Schur decomposition.
//!
//! Computes the real Schur decomposition A = Q T Q^T, where Q is orthogonal
//! and T is quasi-upper triangular (upper triangular except for 2×2 blocks on
//! the diagonal that hold complex conjugate eigenvalue pairs).

7use oxiblas_core::scalar::{Field, Real, Scalar};
8use oxiblas_matrix::{Mat, MatRef};
9
10use super::hessenberg::Hessenberg;
11
/// Failure modes of the Schur decomposition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurError {
    /// The input has zero rows or zero columns.
    EmptyMatrix,
    /// The input has a different number of rows and columns.
    NotSquare,
    /// The QR iteration failed to converge.
    NotConverged,
}

impl std::fmt::Display for SchurError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its fixed, user-facing message.
        let message = match self {
            SchurError::EmptyMatrix => "Matrix is empty",
            SchurError::NotSquare => "Matrix must be square",
            SchurError::NotConverged => "Schur decomposition did not converge",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SchurError {}

35/// Represents a real or complex eigenvalue.
36#[derive(Debug, Clone, Copy, PartialEq)]
37pub struct Eigenvalue<T> {
38    /// Real part of the eigenvalue.
39    pub real: T,
40    /// Imaginary part of the eigenvalue (zero for real eigenvalues).
41    pub imag: T,
42}
43
44impl<T: Scalar> Eigenvalue<T> {
45    /// Creates a real eigenvalue.
46    pub fn real_only(value: T) -> Self {
47        Self {
48            real: value,
49            imag: T::zero(),
50        }
51    }
52
53    /// Creates a complex eigenvalue.
54    pub fn complex(real: T, imag: T) -> Self {
55        Self { real, imag }
56    }
57
58    /// Returns true if this is a real eigenvalue.
59    pub fn is_real(&self) -> bool {
60        self.imag == T::zero()
61    }
62}
63
64/// Schur decomposition of a matrix.
65///
66/// For a matrix A, computes A = Q T Q^T where:
67/// - Q is orthogonal (Q^T Q = I)
68/// - T is quasi-upper triangular (real Schur form)
69#[derive(Debug, Clone)]
70pub struct Schur<T: Scalar> {
71    /// The orthogonal matrix Q (Schur vectors).
72    q: Mat<T>,
73    /// The quasi-upper triangular matrix T.
74    t: Mat<T>,
75    /// Eigenvalues (real and complex pairs).
76    eigenvalues: Vec<Eigenvalue<T>>,
77    /// Matrix dimension.
78    n: usize,
79}
80
81impl<T: Field + Real + bytemuck::Zeroable> Schur<T> {
82    /// Maximum iterations for QR iteration.
83    const MAX_ITERATIONS: usize = 100;
84
85    /// Computes the Schur decomposition of a square matrix.
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// use oxiblas_lapack::evd::Schur;
91    /// use oxiblas_matrix::Mat;
92    ///
93    /// let a = Mat::from_rows(&[
94    ///     &[1.0f64, 2.0],
95    ///     &[0.0, 3.0],
96    /// ]);
97    ///
98    /// let schur = Schur::compute(a.as_ref()).unwrap();
99    /// let eigenvalues = schur.eigenvalues();
100    ///
101    /// // Eigenvalues are 1 and 3
102    /// assert!((eigenvalues[0].real - 1.0).abs() < 1e-10 || (eigenvalues[0].real - 3.0).abs() < 1e-10);
103    /// ```
104    pub fn compute(a: MatRef<'_, T>) -> Result<Self, SchurError> {
105        let m = a.nrows();
106        let n = a.ncols();
107
108        if m == 0 || n == 0 {
109            return Err(SchurError::EmptyMatrix);
110        }
111
112        if m != n {
113            return Err(SchurError::NotSquare);
114        }
115
116        // Handle 1×1 case
117        if n == 1 {
118            let mut t = Mat::zeros(1, 1);
119            t[(0, 0)] = a[(0, 0)];
120            let mut q = Mat::zeros(1, 1);
121            q[(0, 0)] = T::one();
122            let eigenvalues = vec![Eigenvalue::real_only(a[(0, 0)])];
123            return Ok(Self {
124                q,
125                t,
126                eigenvalues,
127                n,
128            });
129        }
130
131        // Handle 2×2 case
132        if n == 2 {
133            return Self::compute_2x2(a);
134        }
135
136        // Step 1: Reduce to upper Hessenberg form
137        let hess = Hessenberg::compute(a).map_err(|_| SchurError::NotSquare)?;
138        let mut t = Mat::zeros(n, n);
139        let h = hess.h();
140        for i in 0..n {
141            for j in 0..n {
142                t[(i, j)] = h[(i, j)];
143            }
144        }
145
146        let mut q = Mat::zeros(n, n);
147        let q_hess = hess.q();
148        for i in 0..n {
149            for j in 0..n {
150                q[(i, j)] = q_hess[(i, j)];
151            }
152        }
153
154        // Step 2: Apply QR iteration with implicit shifts
155        let eps = <T as Scalar>::epsilon();
156        let tol = eps * T::from_f64(100.0).unwrap_or(T::one());
157
158        // Process from bottom to top
159        let mut p = n;
160        let mut iter_count = 0;
161
162        while p > 2 && iter_count < Self::MAX_ITERATIONS * n {
163            iter_count += 1;
164
165            // Find the active block
166            let mut q_idx = p - 1;
167            while q_idx > 0 {
168                let sub = Scalar::abs(t[(q_idx, q_idx - 1)]);
169                let diag_sum =
170                    Scalar::abs(t[(q_idx - 1, q_idx - 1)]) + Scalar::abs(t[(q_idx, q_idx)]);
171                if sub <= tol * diag_sum {
172                    t[(q_idx, q_idx - 1)] = T::zero();
173                    break;
174                }
175                q_idx -= 1;
176            }
177
178            if q_idx == p - 1 {
179                // 1×1 block converged
180                p -= 1;
181            } else if q_idx == p - 2 {
182                // Check if 2×2 block has converged (complex eigenvalues)
183                let a11 = t[(p - 2, p - 2)];
184                let a12 = t[(p - 2, p - 1)];
185                let a21 = t[(p - 1, p - 2)];
186                let a22 = t[(p - 1, p - 1)];
187                let trace = a11 + a22;
188                let det = a11 * a22 - a12 * a21;
189                let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
190
191                if disc < T::zero() {
192                    // Complex eigenvalues, keep as 2×2 block
193                    p -= 2;
194                } else {
195                    // Real eigenvalues, continue iteration
196                    Self::francis_qr_step(&mut t, &mut q, q_idx, p);
197                }
198            } else {
199                // Apply Francis QR step
200                Self::francis_qr_step(&mut t, &mut q, q_idx, p);
201            }
202        }
203
204        // Handle remaining 2×2 block if needed
205        if p == 2 {
206            let sub = Scalar::abs(t[(1, 0)]);
207            let diag_sum = Scalar::abs(t[(0, 0)]) + Scalar::abs(t[(1, 1)]);
208            if sub <= tol * diag_sum {
209                t[(1, 0)] = T::zero();
210            }
211        }
212
213        // Extract eigenvalues
214        let eigenvalues = Self::extract_eigenvalues(&t);
215
216        Ok(Self {
217            q,
218            t,
219            eigenvalues,
220            n,
221        })
222    }
223
224    /// Computes Schur decomposition for 2×2 matrix.
225    fn compute_2x2(a: MatRef<'_, T>) -> Result<Self, SchurError> {
226        let mut t = Mat::zeros(2, 2);
227        for i in 0..2 {
228            for j in 0..2 {
229                t[(i, j)] = a[(i, j)];
230            }
231        }
232
233        let a11 = a[(0, 0)];
234        let a12 = a[(0, 1)];
235        let a21 = a[(1, 0)];
236        let a22 = a[(1, 1)];
237
238        let trace = a11 + a22;
239        let det = a11 * a22 - a12 * a21;
240        let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
241
242        let mut q = Mat::zeros(2, 2);
243        let eigenvalues: Vec<Eigenvalue<T>>;
244
245        if disc >= T::zero() {
246            // Real eigenvalues - triangularize
247            let sqrt_disc = Real::sqrt(disc);
248            let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap();
249            let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap();
250
251            // Find rotation to triangularize
252            if Scalar::abs(a21) > <T as Scalar>::epsilon() {
253                let theta = if Scalar::abs(a11 - lambda1) > <T as Scalar>::epsilon() {
254                    Real::atan2(a21, a11 - lambda1)
255                } else {
256                    T::from_f64(core::f64::consts::FRAC_PI_2).unwrap()
257                };
258                let c = Real::cos(theta);
259                let s = Real::sin(theta);
260
261                q[(0, 0)] = c;
262                q[(0, 1)] = -s;
263                q[(1, 0)] = s;
264                q[(1, 1)] = c;
265
266                // T = Q^T * A * Q
267                let mut temp = Mat::zeros(2, 2);
268                // temp = Q^T * A
269                for i in 0..2 {
270                    for j in 0..2 {
271                        let mut sum = T::zero();
272                        for k in 0..2 {
273                            sum = sum + q[(k, i)] * a[(k, j)];
274                        }
275                        temp[(i, j)] = sum;
276                    }
277                }
278                // t = temp * Q
279                for i in 0..2 {
280                    for j in 0..2 {
281                        let mut sum = T::zero();
282                        for k in 0..2 {
283                            sum = sum + temp[(i, k)] * q[(k, j)];
284                        }
285                        t[(i, j)] = sum;
286                    }
287                }
288            } else {
289                // Already triangular
290                q[(0, 0)] = T::one();
291                q[(1, 1)] = T::one();
292            }
293
294            eigenvalues = vec![
295                Eigenvalue::real_only(lambda1),
296                Eigenvalue::real_only(lambda2),
297            ];
298        } else {
299            // Complex eigenvalues - keep as 2×2 block
300            let sqrt_disc = Real::sqrt(-disc);
301            let real_part = trace / T::from_f64(2.0).unwrap();
302            let imag_part = sqrt_disc / T::from_f64(2.0).unwrap();
303
304            q[(0, 0)] = T::one();
305            q[(1, 1)] = T::one();
306
307            eigenvalues = vec![
308                Eigenvalue::complex(real_part, imag_part),
309                Eigenvalue::complex(real_part, -imag_part),
310            ];
311        }
312
313        Ok(Self {
314            q,
315            t,
316            eigenvalues,
317            n: 2,
318        })
319    }
320
321    /// Applies one step of Francis double-shift QR iteration.
322    fn francis_qr_step(t: &mut Mat<T>, q: &mut Mat<T>, start: usize, end: usize) {
323        let n = t.nrows();
324
325        if end - start < 2 {
326            return;
327        }
328
329        // Compute shift from bottom 2×2 block
330        let h11 = t[(end - 2, end - 2)];
331        let h12 = t[(end - 2, end - 1)];
332        let h21 = t[(end - 1, end - 2)];
333        let h22 = t[(end - 1, end - 1)];
334
335        let s = h11 + h22; // trace
336        let p = h11 * h22 - h12 * h21; // determinant
337
338        // First column of (H - s1*I)(H - s2*I) = H² - s*H + p*I
339        let h_00 = t[(start, start)];
340        let h_01 = t[(start, start + 1)];
341        let h_10 = t[(start + 1, start)];
342
343        let mut x = h_00 * h_00 + h_01 * h_10 - s * h_00 + p;
344        let mut y = h_10 * (h_00 + t[(start + 1, start + 1)] - s);
345        let mut z = if start + 2 < end {
346            h_10 * t[(start + 2, start + 1)]
347        } else {
348            T::zero()
349        };
350
351        // Chase the bulge
352        for k in start..end.saturating_sub(1) {
353            // Compute Householder to zero out y, z
354            let (v, tau) = householder_3(&[x, y, z]);
355
356            if tau != T::zero() {
357                let r = if k > start { k - 1 } else { k };
358
359                // Apply from left: T := (I - tau * v * v^T) * T
360                let col_start = r;
361                let col_end = n;
362                for j in col_start..col_end {
363                    let rows = (k..(k + 3).min(end)).collect::<Vec<_>>();
364                    let mut dot = T::zero();
365                    for (vi, &row) in rows.iter().enumerate() {
366                        dot = dot + v[vi] * t[(row, j)];
367                    }
368                    let scaled = tau * dot;
369                    for (vi, &row) in rows.iter().enumerate() {
370                        t[(row, j)] = t[(row, j)] - scaled * v[vi];
371                    }
372                }
373
374                // Apply from right: T := T * (I - tau * v * v^T)
375                let row_end = (k + 4).min(end);
376                for i in 0..row_end {
377                    let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
378                    let mut dot = T::zero();
379                    for (vi, &col) in cols.iter().enumerate() {
380                        dot = dot + t[(i, col)] * v[vi];
381                    }
382                    let scaled = tau * dot;
383                    for (vi, &col) in cols.iter().enumerate() {
384                        t[(i, col)] = t[(i, col)] - scaled * v[vi];
385                    }
386                }
387
388                // Accumulate Q
389                for i in 0..n {
390                    let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
391                    let mut dot = T::zero();
392                    for (vi, &col) in cols.iter().enumerate() {
393                        dot = dot + q[(i, col)] * v[vi];
394                    }
395                    let scaled = tau * dot;
396                    for (vi, &col) in cols.iter().enumerate() {
397                        q[(i, col)] = q[(i, col)] - scaled * v[vi];
398                    }
399                }
400            }
401
402            // Prepare for next iteration
403            if k + 3 < end {
404                x = t[(k + 1, k)];
405                y = t[(k + 2, k)];
406                z = if k + 3 < end {
407                    t[(k + 3, k)]
408                } else {
409                    T::zero()
410                };
411            } else if k + 2 < end {
412                // 2×2 Householder for the last step
413                x = t[(k + 1, k)];
414                y = t[(k + 2, k)];
415                z = T::zero();
416            }
417        }
418    }
419
420    /// Extracts eigenvalues from the quasi-upper triangular Schur form.
421    fn extract_eigenvalues(t: &Mat<T>) -> Vec<Eigenvalue<T>> {
422        let n = t.nrows();
423        let mut eigenvalues = Vec::with_capacity(n);
424        let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
425
426        let mut i = 0;
427        while i < n {
428            if i == n - 1 {
429                // Last element is a 1×1 block
430                eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
431                i += 1;
432            } else {
433                // Check if this is a 2×2 block
434                let sub = Scalar::abs(t[(i + 1, i)]);
435                let diag_sum = Scalar::abs(t[(i, i)]) + Scalar::abs(t[(i + 1, i + 1)]);
436
437                if sub <= eps * diag_sum {
438                    // 1×1 block
439                    eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
440                    i += 1;
441                } else {
442                    // 2×2 block - compute eigenvalues
443                    let a11 = t[(i, i)];
444                    let a12 = t[(i, i + 1)];
445                    let a21 = t[(i + 1, i)];
446                    let a22 = t[(i + 1, i + 1)];
447
448                    let trace = a11 + a22;
449                    let det = a11 * a22 - a12 * a21;
450                    let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
451
452                    if disc >= T::zero() {
453                        // Real eigenvalues
454                        let sqrt_disc = Real::sqrt(disc);
455                        let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap();
456                        let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap();
457                        eigenvalues.push(Eigenvalue::real_only(lambda1));
458                        eigenvalues.push(Eigenvalue::real_only(lambda2));
459                    } else {
460                        // Complex conjugate pair
461                        let sqrt_disc = Real::sqrt(-disc);
462                        let real_part = trace / T::from_f64(2.0).unwrap();
463                        let imag_part = sqrt_disc / T::from_f64(2.0).unwrap();
464                        eigenvalues.push(Eigenvalue::complex(real_part, imag_part));
465                        eigenvalues.push(Eigenvalue::complex(real_part, -imag_part));
466                    }
467                    i += 2;
468                }
469            }
470        }
471
472        eigenvalues
473    }
474
475    /// Returns the orthogonal matrix Q (Schur vectors).
476    pub fn q(&self) -> MatRef<'_, T> {
477        self.q.as_ref()
478    }
479
480    /// Returns the quasi-upper triangular matrix T (Schur form).
481    pub fn t(&self) -> MatRef<'_, T> {
482        self.t.as_ref()
483    }
484
485    /// Returns the eigenvalues.
486    pub fn eigenvalues(&self) -> &[Eigenvalue<T>] {
487        &self.eigenvalues
488    }
489
490    /// Returns only the real parts of eigenvalues.
491    pub fn eigenvalues_real(&self) -> Vec<T> {
492        self.eigenvalues.iter().map(|e| e.real).collect()
493    }
494
495    /// Reconstructs the original matrix: A = Q T Q^T.
496    pub fn reconstruct(&self) -> Mat<T> {
497        let mut qt = Mat::zeros(self.n, self.n);
498        let mut a = Mat::zeros(self.n, self.n);
499
500        // QT = Q * T
501        for i in 0..self.n {
502            for j in 0..self.n {
503                let mut sum = T::zero();
504                for k in 0..self.n {
505                    sum = sum + self.q[(i, k)] * self.t[(k, j)];
506                }
507                qt[(i, j)] = sum;
508            }
509        }
510
511        // A = QT * Q^T
512        for i in 0..self.n {
513            for j in 0..self.n {
514                let mut sum = T::zero();
515                for k in 0..self.n {
516                    sum = sum + qt[(i, k)] * self.q[(j, k)];
517                }
518                a[(i, j)] = sum;
519            }
520        }
521
522        a
523    }
524
525    /// Computes right eigenvectors from the Schur form (LAPACK TREVC).
526    ///
527    /// Returns the right eigenvectors V such that T V = V D where D is the
528    /// diagonal matrix of eigenvalues. The eigenvectors are normalized.
529    ///
530    /// For real eigenvalues, returns real eigenvectors.
531    /// For complex conjugate pairs, returns two columns: the real and imaginary
532    /// parts of the eigenvector.
533    ///
534    /// # Example
535    ///
536    /// ```
537    /// use oxiblas_lapack::evd::Schur;
538    /// use oxiblas_matrix::Mat;
539    ///
540    /// let a = Mat::from_rows(&[
541    ///     &[1.0f64, 2.0, 3.0],
542    ///     &[0.0, 4.0, 5.0],
543    ///     &[0.0, 0.0, 6.0],
544    /// ]);
545    ///
546    /// let schur = Schur::compute(a.as_ref()).unwrap();
547    /// let vr = schur.right_eigenvectors();
548    /// // Each column of vr is a right eigenvector
549    /// ```
550    #[must_use]
551    pub fn right_eigenvectors(&self) -> Mat<T> {
552        trevc_right(&self.t)
553    }
554
555    /// Computes left eigenvectors from the Schur form (LAPACK TREVC).
556    ///
557    /// Returns the left eigenvectors U such that U^T T = D U^T where D is the
558    /// diagonal matrix of eigenvalues. The eigenvectors are normalized.
559    ///
560    /// # Example
561    ///
562    /// ```
563    /// use oxiblas_lapack::evd::Schur;
564    /// use oxiblas_matrix::Mat;
565    ///
566    /// let a = Mat::from_rows(&[
567    ///     &[1.0f64, 2.0, 3.0],
568    ///     &[0.0, 4.0, 5.0],
569    ///     &[0.0, 0.0, 6.0],
570    /// ]);
571    ///
572    /// let schur = Schur::compute(a.as_ref()).unwrap();
573    /// let vl = schur.left_eigenvectors();
574    /// ```
575    #[must_use]
576    pub fn left_eigenvectors(&self) -> Mat<T> {
577        trevc_left(&self.t)
578    }
579
580    /// Computes eigenvectors of the original matrix A = Q T Q^T.
581    ///
582    /// The eigenvectors of A are Q * V where V are the eigenvectors of T.
583    ///
584    /// # Returns
585    ///
586    /// (right_eigenvectors, left_eigenvectors) of A.
587    #[must_use]
588    pub fn eigenvectors(&self) -> (Mat<T>, Mat<T>) {
589        let vr_t = trevc_right(&self.t);
590        let vl_t = trevc_left(&self.t);
591
592        // Transform to eigenvectors of A: A_vr = Q * T_vr, A_vl = Q * T_vl
593        let mut vr_a = Mat::zeros(self.n, self.n);
594        let mut vl_a = Mat::zeros(self.n, self.n);
595
596        for i in 0..self.n {
597            for j in 0..self.n {
598                let mut sum_r = T::zero();
599                let mut sum_l = T::zero();
600                for k in 0..self.n {
601                    sum_r = sum_r + self.q[(i, k)] * vr_t[(k, j)];
602                    sum_l = sum_l + self.q[(i, k)] * vl_t[(k, j)];
603                }
604                vr_a[(i, j)] = sum_r;
605                vl_a[(i, j)] = sum_l;
606            }
607        }
608
609        (vr_a, vl_a)
610    }
611
612    /// Computes reciprocal condition numbers for eigenvalues (LAPACK DTRSNA).
613    ///
614    /// For each eigenvalue λ, the reciprocal condition number s measures how
615    /// sensitive λ is to perturbations in the matrix. A small value indicates
616    /// a poorly conditioned eigenvalue.
617    ///
618    /// The condition number is computed as:
619    /// - For simple eigenvalues: s = 1 / |y^H x| where x is the right eigenvector
620    ///   and y is the left eigenvector, both normalized to unit length.
621    /// - For complex conjugate pairs: uses the average of the pair.
622    ///
623    /// # Returns
624    ///
625    /// Vector of reciprocal condition numbers, one per eigenvalue.
626    /// Smaller values indicate more sensitive eigenvalues.
627    ///
628    /// # Example
629    ///
630    /// ```
631    /// use oxiblas_lapack::evd::Schur;
632    /// use oxiblas_matrix::Mat;
633    ///
634    /// let a = Mat::from_rows(&[
635    ///     &[1.0f64, 0.0],
636    ///     &[0.0, 1000.0],
637    /// ]);
638    ///
639    /// let schur = Schur::compute(a.as_ref()).unwrap();
640    /// let cond = schur.eigenvalue_condition_numbers();
641    /// // Both eigenvalues of a diagonal matrix are well-conditioned
642    /// assert!(cond[0] > 0.9);
643    /// assert!(cond[1] > 0.9);
644    /// ```
645    #[must_use]
646    pub fn eigenvalue_condition_numbers(&self) -> Vec<T> {
647        trsna_s(&self.t)
648    }
649
650    /// Computes reciprocal condition numbers for eigenvectors (LAPACK DTRSNA).
651    ///
652    /// For each right eigenvector x_j, computes the separation sep_j which
653    /// measures how close the eigenvalue is to the rest of the spectrum.
654    ///
655    /// # Returns
656    ///
657    /// Vector of separation values, one per eigenvalue.
658    /// Smaller values indicate more sensitive eigenvectors.
659    #[must_use]
660    pub fn eigenvector_separation(&self) -> Vec<T> {
661        trsna_sep(&self.t)
662    }
663}
664
665/// Computes reciprocal condition numbers for eigenvalues (s values from LAPACK DTRSNA).
666///
667/// For each simple eigenvalue λ_j, s_j = |y_j^H * x_j| where x_j is the right
668/// eigenvector and y_j is the left eigenvector, both normalized.
669///
670/// # Arguments
671///
672/// * `t` - The quasi-upper triangular Schur matrix
673///
674/// # Returns
675///
676/// Vector of reciprocal condition numbers for each eigenvalue.
677pub fn trsna_s<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
678    let n = t.nrows();
679    if n == 0 {
680        return Vec::new();
681    }
682
683    let vr = trevc_right(t);
684    let vl = trevc_left(t);
685
686    let mut s = vec![T::zero(); n];
687    let eps = <T as Scalar>::epsilon();
688
689    let mut j = 0;
690    while j < n {
691        // Check for 2×2 block
692        let is_2x2 = if j + 1 < n {
693            let sub = Scalar::abs(t[(j + 1, j)]);
694            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
695            sub > eps * T::from_f64(100.0).unwrap() * diag_sum
696        } else {
697            false
698        };
699
700        if is_2x2 {
701            // Complex conjugate pair: compute s = |y^H * x| using both columns
702            // For complex eigenvector stored as (real, imag) in consecutive columns:
703            // y^H * x = (yr - i*yi)^T * (xr + i*xi) = (yr^T*xr + yi^T*xi) + i*(yr^T*xi - yi^T*xr)
704            let jp1 = j + 1;
705
706            let mut prod_rr = T::zero(); // yr^T * xr
707            let mut prod_ii = T::zero(); // yi^T * xi
708            let mut prod_ri = T::zero(); // yr^T * xi
709            let mut prod_ir = T::zero(); // yi^T * xr
710
711            for k in 0..n {
712                prod_rr = prod_rr + vl[(k, j)] * vr[(k, j)];
713                prod_ii = prod_ii + vl[(k, jp1)] * vr[(k, jp1)];
714                prod_ri = prod_ri + vl[(k, j)] * vr[(k, jp1)];
715                prod_ir = prod_ir + vl[(k, jp1)] * vr[(k, j)];
716            }
717
718            let real_part = prod_rr + prod_ii;
719            let imag_part = prod_ri - prod_ir;
720            let abs_inner = Real::sqrt(real_part * real_part + imag_part * imag_part);
721
722            // Both eigenvalues in the pair have the same condition number
723            s[j] = abs_inner;
724            s[jp1] = abs_inner;
725
726            j += 2;
727        } else {
728            // Simple real eigenvalue: s = |y^T * x|
729            let mut inner = T::zero();
730            for k in 0..n {
731                inner = inner + vl[(k, j)] * vr[(k, j)];
732            }
733            s[j] = Scalar::abs(inner);
734            j += 1;
735        }
736    }
737
738    s
739}
740
741/// Computes separation (sep) for eigenvectors (from LAPACK DTRSNA).
742///
743/// For each eigenvalue λ_j, sep_j = σ_min(T_22 - λ_j * I) where T_22 is the
744/// (n-1)×(n-1) trailing principal submatrix with λ_j removed.
745///
746/// This measures how separated λ_j is from the rest of the spectrum.
747///
748/// # Arguments
749///
750/// * `t` - The quasi-upper triangular Schur matrix
751///
752/// # Returns
753///
754/// Vector of separation values for each eigenvalue.
755pub fn trsna_sep<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
756    let n = t.nrows();
757    if n == 0 {
758        return Vec::new();
759    }
760
761    let mut sep = vec![T::zero(); n];
762    let eps = <T as Scalar>::epsilon();
763
764    let mut j = 0;
765    while j < n {
766        // Check for 2×2 block
767        let is_2x2 = if j + 1 < n {
768            let sub = Scalar::abs(t[(j + 1, j)]);
769            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
770            sub > eps * T::from_f64(100.0).unwrap() * diag_sum
771        } else {
772            false
773        };
774
775        if is_2x2 {
776            // Complex conjugate pair
777            let jp1 = j + 1;
778
779            // Compute approximate separation as minimum distance to other eigenvalues
780            let a11 = t[(j, j)];
781            let a22 = t[(jp1, jp1)];
782            let lambda_real = (a11 + a22) / T::from_f64(2.0).unwrap();
783            let trace = a11 + a22;
784            let det = a11 * a22 - t[(j, jp1)] * t[(jp1, j)];
785            let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
786            let lambda_imag = if disc < T::zero() {
787                Real::sqrt(-disc) / T::from_f64(2.0).unwrap()
788            } else {
789                T::zero()
790            };
791
792            let mut min_sep = T::one() / eps;
793
794            // Check distance to all other eigenvalues
795            let mut k = 0;
796            while k < n {
797                if k == j || k == jp1 {
798                    k += 1;
799                    continue;
800                }
801
802                // Check if k is part of a 2×2 block
803                let adjacent = j > 0 && k == j - 1;
804                let k_is_2x2 = if k + 1 < n && !adjacent {
805                    let sub = Scalar::abs(t[(k + 1, k)]);
806                    let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
807                    sub > eps * T::from_f64(100.0).unwrap() * diag_sum
808                } else {
809                    false
810                };
811
812                let (other_real, other_imag) = if k_is_2x2 {
813                    let kp1 = k + 1;
814                    let b11 = t[(k, k)];
815                    let b22 = t[(kp1, kp1)];
816                    let other_trace = b11 + b22;
817                    let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
818                    let other_disc =
819                        other_trace * other_trace - T::from_f64(4.0).unwrap() * other_det;
820                    let r = (b11 + b22) / T::from_f64(2.0).unwrap();
821                    let i = if other_disc < T::zero() {
822                        Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap()
823                    } else {
824                        T::zero()
825                    };
826                    (r, i)
827                } else {
828                    (t[(k, k)], T::zero())
829                };
830
831                // Distance between eigenvalues
832                let dr = lambda_real - other_real;
833                let di = lambda_imag - other_imag;
834                let dist = Real::sqrt(dr * dr + di * di);
835                if dist < min_sep && dist > T::zero() {
836                    min_sep = dist;
837                }
838
839                // Also check conjugate
840                if other_imag != T::zero() {
841                    let di_conj = lambda_imag + other_imag;
842                    let dist_conj = Real::sqrt(dr * dr + di_conj * di_conj);
843                    if dist_conj < min_sep && dist_conj > T::zero() {
844                        min_sep = dist_conj;
845                    }
846                }
847
848                if k_is_2x2 {
849                    k += 2;
850                } else {
851                    k += 1;
852                }
853            }
854
855            sep[j] = min_sep;
856            sep[jp1] = min_sep;
857            j += 2;
858        } else {
859            // Simple real eigenvalue
860            let lambda = t[(j, j)];
861            let mut min_sep = T::one() / eps;
862
863            // Check distance to all other eigenvalues
864            let mut k = 0;
865            while k < n {
866                if k == j {
867                    k += 1;
868                    continue;
869                }
870
871                // Check if k is part of a 2×2 block
872                let adjacent_to_j = (j > 0 && k == j - 1) || k == j + 1;
873                let k_is_2x2 = if k + 1 < n && !adjacent_to_j {
874                    let sub = Scalar::abs(t[(k + 1, k)]);
875                    let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
876                    sub > eps * T::from_f64(100.0).unwrap() * diag_sum
877                } else {
878                    false
879                };
880
881                let (other_real, other_imag) = if k_is_2x2 {
882                    let kp1 = k + 1;
883                    let b11 = t[(k, k)];
884                    let b22 = t[(kp1, kp1)];
885                    let other_trace = b11 + b22;
886                    let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
887                    let other_disc =
888                        other_trace * other_trace - T::from_f64(4.0).unwrap() * other_det;
889                    let r = (b11 + b22) / T::from_f64(2.0).unwrap();
890                    let i = if other_disc < T::zero() {
891                        Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap()
892                    } else {
893                        T::zero()
894                    };
895                    (r, i)
896                } else {
897                    (t[(k, k)], T::zero())
898                };
899
900                // Distance between eigenvalues
901                let dr = lambda - other_real;
902                let dist = Real::sqrt(dr * dr + other_imag * other_imag);
903                if dist < min_sep && dist > T::zero() {
904                    min_sep = dist;
905                }
906
907                if k_is_2x2 {
908                    k += 2;
909                } else {
910                    k += 1;
911                }
912            }
913
914            sep[j] = min_sep;
915            j += 1;
916        }
917    }
918
919    sep
920}
921
922/// Computes right eigenvectors of a quasi-upper triangular matrix (LAPACK DTREVC).
923///
924/// The input T must be in real Schur form (quasi-upper triangular with 1×1 and
925/// 2×2 diagonal blocks).
926///
927/// # Arguments
928///
929/// * `t` - The quasi-upper triangular Schur matrix
930///
931/// # Returns
932///
933/// Matrix V of right eigenvectors (column-wise). For complex conjugate pairs,
934/// consecutive columns contain the real and imaginary parts.
935pub fn trevc_right<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
936    let n = t.nrows();
937    let mut v = Mat::zeros(n, n);
938
939    // Initialize to identity for back-substitution starting point
940    for i in 0..n {
941        v[(i, i)] = T::one();
942    }
943
944    let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
945
946    // Process eigenvalues from last to first
947    let mut j = n;
948    while j > 0 {
949        j -= 1;
950
951        // Check if this is part of a 2×2 block
952        let is_2x2 = if j > 0 {
953            let sub = Scalar::abs(t[(j, j - 1)]);
954            let diag_sum = Scalar::abs(t[(j - 1, j - 1)]) + Scalar::abs(t[(j, j)]);
955            sub > eps * diag_sum
956        } else {
957            false
958        };
959
960        if is_2x2 {
961            // 2×2 block: complex conjugate eigenvalues
962            // Process columns j-1 and j together
963            let jm1 = j - 1;
964
965            // Eigenvalues of 2×2 block
966            let a11 = t[(jm1, jm1)];
967            let a12 = t[(jm1, j)];
968            let a21 = t[(j, jm1)];
969            let a22 = t[(j, j)];
970
971            let trace = a11 + a22;
972            let det = a11 * a22 - a12 * a21;
973            let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
974
975            // Complex eigenvalues: λ = (trace ± i*sqrt(-disc)) / 2
976            let two = T::from_f64(2.0).unwrap();
977            let real_part = trace / two;
978            let imag_part = Real::sqrt(-disc) / two;
979
980            // For the 2×2 block, set up eigenvector components
981            // v[jm1] = real part, v[j] = imag part for first eigenvector
982            v[(jm1, jm1)] = T::one();
983            v[(j, jm1)] = T::zero();
984            v[(jm1, j)] = T::zero();
985            v[(j, j)] = T::one();
986
987            // Compute actual eigenvector of 2×2 block
988            // (a11 - λ) v1 + a12 v2 = 0
989            // a21 v1 + (a22 - λ) v2 = 0
990            // For λ = real_part + i*imag_part:
991            // Let v = vr + i*vi
992            // Then: (T - λI)(vr + i*vi) = 0
993            // Real: (T - real_part*I)vr + imag_part*vi = 0
994            // Imag: (T - real_part*I)vi - imag_part*vr = 0
995
996            // For the 2×2 block itself, find normalized eigenvector
997            // The eigenvector satisfies (T - λI)v = 0 where λ = real_part + i*imag_part
998            // For complex eigenvector v = vr + i*vi:
999            // (T - real_part*I)vr + imag_part*vi = 0  (real part)
1000            // (T - real_part*I)vi - imag_part*vr = 0  (imag part)
1001            //
1002            // For the 2×2 case, we can set v2 = 1 + 0i and solve for v1
1003            // From row 2: a21*v1r + (a22-real)*v2r + imag*v2i = 0 (real)
1004            //             a21*v1i + (a22-real)*v2i - imag*v2r = 0 (imag)
1005            // With v2r=1, v2i=0:
1006            //   a21*v1r + (a22-real) = 0  =>  v1r = -(a22-real)/a21 = -d22/a21
1007            //   a21*v1i - imag = 0        =>  v1i = imag/a21
1008
1009            // Use the more stable formulation based on LAPACK
1010            // For standardized eigenvector, use row with larger coefficient
1011            let d11 = a11 - real_part;
1012            let d22 = a22 - real_part;
1013
1014            // Choose the row with larger off-diagonal to avoid division by small number
1015            if Scalar::abs(a21) >= Scalar::abs(a12) && Scalar::abs(a21) > eps {
1016                // Use row 2: a21*v1 + d22*v2 = 0 (with v2=1)
1017                // v1r = -d22/a21, v1i = imag/a21
1018                let v1r = -d22 / a21;
1019                let v1i = imag_part / a21;
1020                v[(jm1, jm1)] = v1r;
1021                v[(j, jm1)] = T::one();
1022                v[(jm1, j)] = v1i;
1023                v[(j, j)] = T::zero();
1024            } else if Scalar::abs(a12) > eps {
1025                // Use row 1: d11*v1 + a12*v2 = 0 (with v1=1)
1026                // v2r = -d11/a12, v2i = -imag/a12
1027                let v2r = -d11 / a12;
1028                let v2i = -imag_part / a12;
1029                v[(jm1, jm1)] = T::one();
1030                v[(j, jm1)] = v2r;
1031                v[(jm1, j)] = T::zero();
1032                v[(j, j)] = v2i;
1033            } else {
1034                // Fallback to standard basis
1035                v[(jm1, jm1)] = T::one();
1036                v[(j, jm1)] = T::zero();
1037                v[(jm1, j)] = T::zero();
1038                v[(j, j)] = T::one();
1039            }
1040
1041            // Back-substitute for rows above the 2×2 block
1042            for i in (0..jm1).rev() {
1043                // Solve for v[i,jm1] and v[i,j] (real and imag parts)
1044                // (T[i,i] - real_part) * vr[i] + imag_part * vi[i] = -sum of upper terms (real)
1045                // (T[i,i] - real_part) * vi[i] - imag_part * vr[i] = -sum of upper terms (imag)
1046
1047                let mut sum_r = T::zero();
1048                let mut sum_i = T::zero();
1049                for k in (i + 1)..=j {
1050                    sum_r = sum_r + t[(i, k)] * v[(k, jm1)];
1051                    sum_i = sum_i + t[(i, k)] * v[(k, j)];
1052                }
1053
1054                let d = t[(i, i)] - real_part;
1055                let det_2x2 = d * d + imag_part * imag_part;
1056
1057                if Scalar::abs(det_2x2) > eps {
1058                    // Solve 2×2 system:
1059                    // [d, imag] [vr]   [-sum_r]
1060                    // [-imag, d] [vi] = [-sum_i]
1061                    v[(i, jm1)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1062                    v[(i, j)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1063                }
1064            }
1065
1066            // Normalize the two columns
1067            let mut norm_r_sq = T::zero();
1068            let mut norm_i_sq = T::zero();
1069            for i in 0..n {
1070                norm_r_sq = norm_r_sq + v[(i, jm1)] * v[(i, jm1)];
1071                norm_i_sq = norm_i_sq + v[(i, j)] * v[(i, j)];
1072            }
1073            let norm = Real::sqrt(norm_r_sq + norm_i_sq);
1074            if norm > T::zero() {
1075                for i in 0..n {
1076                    v[(i, jm1)] = v[(i, jm1)] / norm;
1077                    v[(i, j)] = v[(i, j)] / norm;
1078                }
1079            }
1080
1081            j = jm1; // Skip the already-processed column
1082        } else {
1083            // 1×1 block: real eigenvalue
1084            let lambda = t[(j, j)];
1085
1086            // Initialize: v[j,j] = 1, others computed by back-substitution
1087            v[(j, j)] = T::one();
1088
1089            // Back-substitute: (T[i,i] - λ) v[i] = -sum_{k>i} T[i,k] v[k]
1090            for i in (0..j).rev() {
1091                let mut sum = T::zero();
1092                for k in (i + 1)..=j {
1093                    sum = sum + t[(i, k)] * v[(k, j)];
1094                }
1095
1096                let d = t[(i, i)] - lambda;
1097                if Scalar::abs(d) > eps {
1098                    v[(i, j)] = -sum / d;
1099                } else {
1100                    // Near-singular: use small perturbation
1101                    v[(i, j)] = -sum / eps;
1102                }
1103            }
1104
1105            // Normalize
1106            let mut norm_sq = T::zero();
1107            for i in 0..n {
1108                norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1109            }
1110            let norm = Real::sqrt(norm_sq);
1111            if norm > T::zero() {
1112                for i in 0..n {
1113                    v[(i, j)] = v[(i, j)] / norm;
1114                }
1115            }
1116        }
1117    }
1118
1119    v
1120}
1121
1122/// Computes left eigenvectors of a quasi-upper triangular matrix (LAPACK DTREVC).
1123///
1124/// The input T must be in real Schur form. Returns left eigenvectors U such that
1125/// U^T T = D U^T where D is diagonal.
1126///
1127/// # Arguments
1128///
1129/// * `t` - The quasi-upper triangular Schur matrix
1130///
1131/// # Returns
1132///
1133/// Matrix U of left eigenvectors (column-wise).
1134pub fn trevc_left<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
1135    let n = t.nrows();
1136    let mut v = Mat::zeros(n, n);
1137
1138    // Initialize
1139    for i in 0..n {
1140        v[(i, i)] = T::one();
1141    }
1142
1143    let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
1144
1145    // Process eigenvalues from first to last (forward substitution for left eigenvectors)
1146    let mut j = 0;
1147    while j < n {
1148        // Check if this is part of a 2×2 block
1149        let is_2x2 = if j + 1 < n {
1150            let sub = Scalar::abs(t[(j + 1, j)]);
1151            let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
1152            sub > eps * diag_sum
1153        } else {
1154            false
1155        };
1156
1157        if is_2x2 {
1158            // 2×2 block: complex conjugate eigenvalues
1159            let jp1 = j + 1;
1160
1161            let a11 = t[(j, j)];
1162            let a12 = t[(j, jp1)];
1163            let a21 = t[(jp1, j)];
1164            let a22 = t[(jp1, jp1)];
1165
1166            let trace = a11 + a22;
1167            let det = a11 * a22 - a12 * a21;
1168            let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
1169
1170            let two = T::from_f64(2.0).unwrap();
1171            let real_part = trace / two;
1172            let imag_part = Real::sqrt(-disc) / two;
1173
1174            // Initialize 2×2 block eigenvector
1175            let d11 = a11 - real_part;
1176            let det_factor = d11 * d11 + imag_part * imag_part;
1177
1178            if Scalar::abs(det_factor) > eps {
1179                let vr2 = -a21 * d11 / det_factor;
1180                let vi2 = -imag_part * a21 / det_factor;
1181                v[(j, j)] = T::one();
1182                v[(jp1, j)] = vr2;
1183                v[(j, jp1)] = T::zero();
1184                v[(jp1, jp1)] = vi2;
1185            } else {
1186                v[(j, j)] = T::one();
1187                v[(jp1, j)] = T::zero();
1188                v[(j, jp1)] = T::zero();
1189                v[(jp1, jp1)] = T::one();
1190            }
1191
1192            // Forward substitute for rows after the 2×2 block
1193            for i in (jp1 + 1)..n {
1194                let mut sum_r = T::zero();
1195                let mut sum_i = T::zero();
1196                for k in j..i {
1197                    sum_r = sum_r + t[(k, i)] * v[(k, j)];
1198                    sum_i = sum_i + t[(k, i)] * v[(k, jp1)];
1199                }
1200
1201                let d = t[(i, i)] - real_part;
1202                let det_2x2 = d * d + imag_part * imag_part;
1203
1204                if Scalar::abs(det_2x2) > eps {
1205                    v[(i, j)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1206                    v[(i, jp1)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1207                }
1208            }
1209
1210            // Normalize
1211            let mut norm_sq = T::zero();
1212            for i in 0..n {
1213                norm_sq = norm_sq + v[(i, j)] * v[(i, j)] + v[(i, jp1)] * v[(i, jp1)];
1214            }
1215            let norm = Real::sqrt(norm_sq);
1216            if norm > T::zero() {
1217                for i in 0..n {
1218                    v[(i, j)] = v[(i, j)] / norm;
1219                    v[(i, jp1)] = v[(i, jp1)] / norm;
1220                }
1221            }
1222
1223            j = jp1 + 1;
1224        } else {
1225            // 1×1 block: real eigenvalue
1226            let lambda = t[(j, j)];
1227            v[(j, j)] = T::one();
1228
1229            // Forward substitute: (T[i,i] - λ) v[i] = -sum_{k<i} T[k,i] v[k]
1230            for i in (j + 1)..n {
1231                let mut sum = T::zero();
1232                for k in j..i {
1233                    sum = sum + t[(k, i)] * v[(k, j)];
1234                }
1235
1236                let d = t[(i, i)] - lambda;
1237                if Scalar::abs(d) > eps {
1238                    v[(i, j)] = -sum / d;
1239                } else {
1240                    v[(i, j)] = -sum / eps;
1241                }
1242            }
1243
1244            // Normalize
1245            let mut norm_sq = T::zero();
1246            for i in 0..n {
1247                norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1248            }
1249            let norm = Real::sqrt(norm_sq);
1250            if norm > T::zero() {
1251                for i in 0..n {
1252                    v[(i, j)] = v[(i, j)] / norm;
1253                }
1254            }
1255
1256            j += 1;
1257        }
1258    }
1259
1260    v
1261}
1262
1263/// Computes a Householder vector for a 3-element (or smaller) vector.
1264fn householder_3<T: Field + Real>(x: &[T]) -> (Vec<T>, T) {
1265    let n = x.len().min(3);
1266    if n == 0 {
1267        return (Vec::new(), T::zero());
1268    }
1269
1270    let mut norm_sq = T::zero();
1271    for i in 0..n {
1272        norm_sq = norm_sq + x[i] * x[i];
1273    }
1274    let norm = Real::sqrt(norm_sq);
1275
1276    if norm == T::zero() {
1277        return (vec![T::zero(); n], T::zero());
1278    }
1279
1280    let mut v = vec![T::zero(); n];
1281    for i in 0..n {
1282        v[i] = x[i];
1283    }
1284
1285    let sign = if x[0] >= T::zero() {
1286        T::one()
1287    } else {
1288        -T::one()
1289    };
1290    v[0] = v[0] + sign * norm;
1291
1292    let mut v_norm_sq = T::zero();
1293    for i in 0..n {
1294        v_norm_sq = v_norm_sq + v[i] * v[i];
1295    }
1296
1297    if v_norm_sq > T::zero() {
1298        let tau = T::from_f64(2.0).unwrap() / v_norm_sq;
1299        (v, tau)
1300    } else {
1301        (v, T::zero())
1302    }
1303}
1304
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_schur_upper_triangular() {
        // Already upper triangular
        let a = Mat::from_rows(&[&[1.0f64, 2.0], &[0.0, 3.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        // Eigenvalues should be 1 and 3
        let mut eigs: Vec<f64> = eigenvalues.iter().map(|e| e.real).collect();
        eigs.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!(approx_eq(eigs[0], 1.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
    }

    #[test]
    fn test_schur_diagonal() {
        let a = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        let mut eigs: Vec<f64> = eigenvalues.iter().map(|e| e.real).collect();
        eigs.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!(approx_eq(eigs[0], 2.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
        assert!(approx_eq(eigs[2], 5.0, 1e-10));
    }

    #[test]
    fn test_schur_complex_eigenvalues() {
        // Rotation matrix - has complex eigenvalues
        let theta = core::f64::consts::FRAC_PI_4; // 45 degrees
        let c = theta.cos();
        let s = theta.sin();
        let a = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        // Eigenvalues should be cos(θ) ± i*sin(θ)
        assert_eq!(eigenvalues.len(), 2);
        // They should be complex conjugates
        assert!(approx_eq(eigenvalues[0].real, eigenvalues[1].real, 1e-10));
        assert!(approx_eq(eigenvalues[0].imag, -eigenvalues[1].imag, 1e-10));
        assert!(approx_eq(eigenvalues[0].real, c, 1e-10));
        assert!(approx_eq(eigenvalues[0].imag.abs(), s, 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    approx_eq(reconstructed[(i, j)], a[(i, j)], 1e-10),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_q_orthogonal() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let q = schur.q();

        // Check Q^T * Q = I
        let n = 3;
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0;
                for k in 0..n {
                    dot += q[(k, i)] * q[(k, j)];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(dot, expected, 1e-10),
                    "Q^T*Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    dot,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_schur_single() {
        let a = Mat::from_rows(&[&[5.0f64]]);
        let schur = Schur::compute(a.as_ref()).unwrap();

        assert_eq!(schur.eigenvalues().len(), 1);
        assert!(approx_eq(schur.eigenvalues()[0].real, 5.0, 1e-10));
    }

    #[test]
    fn test_schur_4x4() {
        let a = Mat::from_rows(&[
            &[4.0f64, 1.0, -2.0, 2.0],
            &[1.0, 2.0, 0.0, 1.0],
            &[-2.0, 0.0, 3.0, -2.0],
            &[2.0, 1.0, -2.0, -1.0],
        ]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    approx_eq(reconstructed[(i, j)], a[(i, j)], 1e-8),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_f32() {
        let a = Mat::from_rows(&[&[1.0f32, 2.0], &[3.0, 4.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let reconstructed = schur.reconstruct();

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[(i, j)] - a[(i, j)]).abs() < 1e-4,
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    reconstructed[(i, j)],
                    a[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_right_upper_triangular() {
        // Upper triangular matrix - eigenvectors should be standard basis vectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let v = trevc_right(&t);

        // First eigenvector for λ=1 should be proportional to [1, 0, 0]
        assert!(v[(1, 0)].abs() < 1e-10);
        assert!(v[(2, 0)].abs() < 1e-10);

        // Third eigenvector for λ=6 should be proportional to [*, *, 1]
        // (normalized, so third component is non-zero)
        assert!(v[(2, 2)].abs() > 0.1);
    }

    #[test]
    fn test_trevc_right_diagonal() {
        // Diagonal matrix - eigenvectors are standard basis
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let v = trevc_right(&t);

        // Should be identity (or close to it with normalization)
        for i in 0..3 {
            assert!(
                approx_eq(v[(i, i)].abs(), 1.0, 1e-10),
                "v[{},{}] = {}",
                i,
                i,
                v[(i, i)]
            );
        }
    }

    #[test]
    fn test_trevc_eigenvector_equation() {
        // Test that T * v = λ * v for computed eigenvectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let v = trevc_right(&t);

        // Eigenvalues are 1, 4, 6 (diagonal elements)
        let eigenvalues = [1.0, 4.0, 6.0];

        for (j, &lambda) in eigenvalues.iter().enumerate() {
            // Compute T * v[:,j]
            let mut tv = [0.0; 3];
            for i in 0..3 {
                for k in 0..3 {
                    tv[i] += t[(i, k)] * v[(k, j)];
                }
            }

            // Check T * v = λ * v
            for i in 0..3 {
                assert!(
                    approx_eq(tv[i], lambda * v[(i, j)], 1e-10),
                    "T*v[{}] = {}, λ*v = {}",
                    i,
                    tv[i],
                    lambda * v[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_left_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let u = trevc_left(&t);

        // Should be identity
        for i in 0..3 {
            assert!(
                approx_eq(u[(i, i)].abs(), 1.0, 1e-10),
                "u[{},{}] = {}",
                i,
                i,
                u[(i, i)]
            );
        }
    }

    #[test]
    fn test_trevc_left_eigenvector_equation() {
        // Test that u^T * T = λ * u^T for computed left eigenvectors
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let u = trevc_left(&t);

        // Eigenvalues are 1, 4, 6 (diagonal elements)
        let eigenvalues = [1.0, 4.0, 6.0];

        for (j, &lambda) in eigenvalues.iter().enumerate() {
            // Compute (u[:,j]^T * T)[i] = sum_k u[k,j] * T[k,i]
            let mut ut = [0.0; 3];
            for i in 0..3 {
                for k in 0..3 {
                    ut[i] += u[(k, j)] * t[(k, i)];
                }
            }

            for i in 0..3 {
                assert!(
                    approx_eq(ut[i], lambda * u[(i, j)], 1e-10),
                    "u^T*T[{}] = {}, λ*u = {}",
                    i,
                    ut[i],
                    lambda * u[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_schur_eigenvectors() {
        // Test eigenvectors through the Schur decomposition
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let (vr, _vl) = schur.eigenvectors();

        // For upper triangular matrix, A = T, so eigenvectors are the same
        // Check that A * v = λ * v
        for j in 0..3 {
            let lambda = schur.eigenvalues()[j].real;

            let mut av = [0.0; 3];
            for i in 0..3 {
                for k in 0..3 {
                    av[i] += a[(i, k)] * vr[(k, j)];
                }
            }

            for i in 0..3 {
                assert!(
                    approx_eq(av[i], lambda * vr[(i, j)], 1e-8),
                    "A*v[{}] = {}, λ*v = {}",
                    i,
                    av[i],
                    lambda * vr[(i, j)]
                );
            }
        }
    }

    #[test]
    fn test_trevc_2x2_block() {
        // Create a 2x2 block with complex eigenvalues: rotation matrix
        let theta = core::f64::consts::FRAC_PI_4;
        let c = theta.cos();
        let s = theta.sin();
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let v = trevc_right(&t);

        // For complex eigenvalues, columns hold the real and imaginary parts of
        // the eigenvector x = v[:,0] + i*v[:,1] belonging to λ = c + i*s.
        // Check that the complex vector is normalized.
        let norm0_sq = v[(0, 0)] * v[(0, 0)] + v[(1, 0)] * v[(1, 0)];
        let norm1_sq = v[(0, 1)] * v[(0, 1)] + v[(1, 1)] * v[(1, 1)];
        let total_norm = (norm0_sq + norm1_sq).sqrt();

        assert!(
            approx_eq(total_norm, 1.0, 1e-10),
            "eigenvector norm = {}",
            total_norm
        );

        // Check (T - λI)x = 0, split into real and imaginary parts:
        //   T*vr = c*vr - s*vi,   T*vi = s*vr + c*vi
        for i in 0..2 {
            let mut tvr = 0.0;
            let mut tvi = 0.0;
            for k in 0..2 {
                tvr += t[(i, k)] * v[(k, 0)];
                tvi += t[(i, k)] * v[(k, 1)];
            }
            assert!(
                approx_eq(tvr, c * v[(i, 0)] - s * v[(i, 1)], 1e-10),
                "real part of (T*x)[{}] = {}, expected {}",
                i,
                tvr,
                c * v[(i, 0)] - s * v[(i, 1)]
            );
            assert!(
                approx_eq(tvi, s * v[(i, 0)] + c * v[(i, 1)], 1e-10),
                "imag part of (T*x)[{}] = {}, expected {}",
                i,
                tvi,
                s * v[(i, 0)] + c * v[(i, 1)]
            );
        }
    }

    #[test]
    fn test_trsna_s_diagonal() {
        // Diagonal matrix: left and right eigenvectors are standard basis
        // so s = |e_i^T e_i| = 1 for all eigenvalues
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        for i in 0..3 {
            assert!(
                approx_eq(s[i], 1.0, 1e-10),
                "s[{}] = {}, expected 1.0",
                i,
                s[i]
            );
        }
    }

    #[test]
    fn test_trsna_s_upper_triangular() {
        // Upper triangular: eigenvalues are diagonal elements
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);

        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        // All condition numbers should be positive
        for i in 0..3 {
            assert!(s[i] > 0.0, "s[{}] = {} should be positive", i, s[i]);
        }
    }

    #[test]
    fn test_trsna_sep_diagonal() {
        // Diagonal matrix: separation is the minimum distance to other eigenvalues
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);

        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 3);
        // λ=2: min dist to 3,5 is 1
        assert!(
            approx_eq(sep[0], 1.0, 1e-10),
            "sep[0] = {}, expected 1.0",
            sep[0]
        );
        // λ=5: min dist to 2,3 is 2
        assert!(
            approx_eq(sep[1], 2.0, 1e-10),
            "sep[1] = {}, expected 2.0",
            sep[1]
        );
        // λ=3: min dist to 2,5 is 1
        assert!(
            approx_eq(sep[2], 1.0, 1e-10),
            "sep[2] = {}, expected 1.0",
            sep[2]
        );
    }

    #[test]
    fn test_trsna_sep_close_eigenvalues() {
        // Eigenvalues very close together should have small separation
        let t = Mat::from_rows(&[&[1.0f64, 0.0], &[0.0, 1.001]]);

        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 2);
        // Both should have separation close to 0.001
        assert!(
            approx_eq(sep[0], 0.001, 1e-10),
            "sep[0] = {}, expected 0.001",
            sep[0]
        );
        assert!(
            approx_eq(sep[1], 0.001, 1e-10),
            "sep[1] = {}, expected 0.001",
            sep[1]
        );
    }

    #[test]
    fn test_schur_eigenvalue_condition_numbers() {
        // Test through Schur decomposition
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let cond = schur.eigenvalue_condition_numbers();

        assert_eq!(cond.len(), 2);
        // Diagonal matrix should have well-conditioned eigenvalues
        for i in 0..2 {
            assert!(
                cond[i] > 0.5,
                "cond[{}] = {} should be > 0.5 for diagonal matrix",
                i,
                cond[i]
            );
        }
    }

    #[test]
    fn test_schur_eigenvector_separation() {
        // Test through Schur decomposition
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let sep = schur.eigenvector_separation();

        assert_eq!(sep.len(), 2);
        // Eigenvalues are 2 and 4, so separation should be 2
        for i in 0..2 {
            assert!(
                approx_eq(sep[i], 2.0, 1e-10),
                "sep[{}] = {}, expected 2.0",
                i,
                sep[i]
            );
        }
    }

    #[test]
    fn test_trsna_complex_eigenvalues() {
        // Rotation matrix with complex eigenvalues
        let theta = core::f64::consts::FRAC_PI_4;
        let c = theta.cos();
        let s = theta.sin();
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let cond = trsna_s(&t);
        let sep = trsna_sep(&t);

        assert_eq!(cond.len(), 2);
        assert_eq!(sep.len(), 2);

        // Both eigenvalues in a complex pair should have the same condition
        assert!(
            approx_eq(cond[0], cond[1], 1e-10),
            "cond[0]={}, cond[1]={} should be equal",
            cond[0],
            cond[1]
        );
        assert!(
            approx_eq(sep[0], sep[1], 1e-10),
            "sep[0]={}, sep[1]={} should be equal",
            sep[0],
            sep[1]
        );
    }
}