1use oxiblas_core::scalar::{Field, Real, Scalar};
8use oxiblas_matrix::{Mat, MatRef};
9
10use super::hessenberg::Hessenberg;
11
/// Failure modes of the Schur decomposition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurError {
    /// The input matrix has zero rows or zero columns.
    EmptyMatrix,
    /// The input matrix is not square.
    NotSquare,
    /// The iterative Schur reduction did not converge.
    NotConverged,
}

impl core::fmt::Display for SchurError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::EmptyMatrix => "Matrix is empty",
            Self::NotSquare => "Matrix must be square",
            Self::NotConverged => "Schur decomposition did not converge",
        };
        f.write_str(message)
    }
}

// `Debug` + `Display` are all `std::error::Error` needs; there is no source error.
impl std::error::Error for SchurError {}
34
35#[derive(Debug, Clone, Copy, PartialEq)]
37pub struct Eigenvalue<T> {
38 pub real: T,
40 pub imag: T,
42}
43
44impl<T: Scalar> Eigenvalue<T> {
45 pub fn real_only(value: T) -> Self {
47 Self {
48 real: value,
49 imag: T::zero(),
50 }
51 }
52
53 pub fn complex(real: T, imag: T) -> Self {
55 Self { real, imag }
56 }
57
58 pub fn is_real(&self) -> bool {
60 self.imag == T::zero()
61 }
62}
63
64#[derive(Debug, Clone)]
70pub struct Schur<T: Scalar> {
71 q: Mat<T>,
73 t: Mat<T>,
75 eigenvalues: Vec<Eigenvalue<T>>,
77 n: usize,
79}
80
81impl<T: Field + Real + bytemuck::Zeroable> Schur<T> {
82 const MAX_ITERATIONS: usize = 100;
84
85 pub fn compute(a: MatRef<'_, T>) -> Result<Self, SchurError> {
105 let m = a.nrows();
106 let n = a.ncols();
107
108 if m == 0 || n == 0 {
109 return Err(SchurError::EmptyMatrix);
110 }
111
112 if m != n {
113 return Err(SchurError::NotSquare);
114 }
115
116 if n == 1 {
118 let mut t = Mat::zeros(1, 1);
119 t[(0, 0)] = a[(0, 0)];
120 let mut q = Mat::zeros(1, 1);
121 q[(0, 0)] = T::one();
122 let eigenvalues = vec![Eigenvalue::real_only(a[(0, 0)])];
123 return Ok(Self {
124 q,
125 t,
126 eigenvalues,
127 n,
128 });
129 }
130
131 if n == 2 {
133 return Self::compute_2x2(a);
134 }
135
136 let hess = Hessenberg::compute(a).map_err(|_| SchurError::NotSquare)?;
138 let mut t = Mat::zeros(n, n);
139 let h = hess.h();
140 for i in 0..n {
141 for j in 0..n {
142 t[(i, j)] = h[(i, j)];
143 }
144 }
145
146 let mut q = Mat::zeros(n, n);
147 let q_hess = hess.q();
148 for i in 0..n {
149 for j in 0..n {
150 q[(i, j)] = q_hess[(i, j)];
151 }
152 }
153
154 let eps = <T as Scalar>::epsilon();
156 let tol = eps * T::from_f64(100.0).unwrap_or(T::one());
157
158 let mut p = n;
160 let mut iter_count = 0;
161
162 while p > 2 && iter_count < Self::MAX_ITERATIONS * n {
163 iter_count += 1;
164
165 let mut q_idx = p - 1;
167 while q_idx > 0 {
168 let sub = Scalar::abs(t[(q_idx, q_idx - 1)]);
169 let diag_sum =
170 Scalar::abs(t[(q_idx - 1, q_idx - 1)]) + Scalar::abs(t[(q_idx, q_idx)]);
171 if sub <= tol * diag_sum {
172 t[(q_idx, q_idx - 1)] = T::zero();
173 break;
174 }
175 q_idx -= 1;
176 }
177
178 if q_idx == p - 1 {
179 p -= 1;
181 } else if q_idx == p - 2 {
182 let a11 = t[(p - 2, p - 2)];
184 let a12 = t[(p - 2, p - 1)];
185 let a21 = t[(p - 1, p - 2)];
186 let a22 = t[(p - 1, p - 1)];
187 let trace = a11 + a22;
188 let det = a11 * a22 - a12 * a21;
189 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
190
191 if disc < T::zero() {
192 p -= 2;
194 } else {
195 Self::francis_qr_step(&mut t, &mut q, q_idx, p);
197 }
198 } else {
199 Self::francis_qr_step(&mut t, &mut q, q_idx, p);
201 }
202 }
203
204 if p == 2 {
206 let sub = Scalar::abs(t[(1, 0)]);
207 let diag_sum = Scalar::abs(t[(0, 0)]) + Scalar::abs(t[(1, 1)]);
208 if sub <= tol * diag_sum {
209 t[(1, 0)] = T::zero();
210 }
211 }
212
213 let eigenvalues = Self::extract_eigenvalues(&t);
215
216 Ok(Self {
217 q,
218 t,
219 eigenvalues,
220 n,
221 })
222 }
223
224 fn compute_2x2(a: MatRef<'_, T>) -> Result<Self, SchurError> {
226 let mut t = Mat::zeros(2, 2);
227 for i in 0..2 {
228 for j in 0..2 {
229 t[(i, j)] = a[(i, j)];
230 }
231 }
232
233 let a11 = a[(0, 0)];
234 let a12 = a[(0, 1)];
235 let a21 = a[(1, 0)];
236 let a22 = a[(1, 1)];
237
238 let trace = a11 + a22;
239 let det = a11 * a22 - a12 * a21;
240 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
241
242 let mut q = Mat::zeros(2, 2);
243 let eigenvalues: Vec<Eigenvalue<T>>;
244
245 if disc >= T::zero() {
246 let sqrt_disc = Real::sqrt(disc);
248 let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
249 let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
250
251 if Scalar::abs(a21) > <T as Scalar>::epsilon() {
253 let theta = if Scalar::abs(a11 - lambda1) > <T as Scalar>::epsilon() {
254 Real::atan2(a21, a11 - lambda1)
255 } else {
256 T::from_f64(core::f64::consts::FRAC_PI_2).unwrap_or_else(T::zero)
257 };
258 let c = Real::cos(theta);
259 let s = Real::sin(theta);
260
261 q[(0, 0)] = c;
262 q[(0, 1)] = -s;
263 q[(1, 0)] = s;
264 q[(1, 1)] = c;
265
266 let mut temp = Mat::zeros(2, 2);
268 for i in 0..2 {
270 for j in 0..2 {
271 let mut sum = T::zero();
272 for k in 0..2 {
273 sum = sum + q[(k, i)] * a[(k, j)];
274 }
275 temp[(i, j)] = sum;
276 }
277 }
278 for i in 0..2 {
280 for j in 0..2 {
281 let mut sum = T::zero();
282 for k in 0..2 {
283 sum = sum + temp[(i, k)] * q[(k, j)];
284 }
285 t[(i, j)] = sum;
286 }
287 }
288 } else {
289 q[(0, 0)] = T::one();
291 q[(1, 1)] = T::one();
292 }
293
294 eigenvalues = vec![
295 Eigenvalue::real_only(lambda1),
296 Eigenvalue::real_only(lambda2),
297 ];
298 } else {
299 let sqrt_disc = Real::sqrt(-disc);
301 let real_part = trace / T::from_f64(2.0).unwrap_or_else(T::zero);
302 let imag_part = sqrt_disc / T::from_f64(2.0).unwrap_or_else(T::zero);
303
304 q[(0, 0)] = T::one();
305 q[(1, 1)] = T::one();
306
307 eigenvalues = vec![
308 Eigenvalue::complex(real_part, imag_part),
309 Eigenvalue::complex(real_part, -imag_part),
310 ];
311 }
312
313 Ok(Self {
314 q,
315 t,
316 eigenvalues,
317 n: 2,
318 })
319 }
320
321 fn francis_qr_step(t: &mut Mat<T>, q: &mut Mat<T>, start: usize, end: usize) {
323 let n = t.nrows();
324
325 if end - start < 2 {
326 return;
327 }
328
329 let h11 = t[(end - 2, end - 2)];
331 let h12 = t[(end - 2, end - 1)];
332 let h21 = t[(end - 1, end - 2)];
333 let h22 = t[(end - 1, end - 1)];
334
335 let s = h11 + h22; let p = h11 * h22 - h12 * h21; let h_00 = t[(start, start)];
340 let h_01 = t[(start, start + 1)];
341 let h_10 = t[(start + 1, start)];
342
343 let mut x = h_00 * h_00 + h_01 * h_10 - s * h_00 + p;
344 let mut y = h_10 * (h_00 + t[(start + 1, start + 1)] - s);
345 let mut z = if start + 2 < end {
346 h_10 * t[(start + 2, start + 1)]
347 } else {
348 T::zero()
349 };
350
351 for k in start..end.saturating_sub(1) {
353 let (v, tau) = householder_3(&[x, y, z]);
355
356 if tau != T::zero() {
357 let r = if k > start { k - 1 } else { k };
358
359 let col_start = r;
361 let col_end = n;
362 for j in col_start..col_end {
363 let rows = (k..(k + 3).min(end)).collect::<Vec<_>>();
364 let mut dot = T::zero();
365 for (vi, &row) in rows.iter().enumerate() {
366 dot = dot + v[vi] * t[(row, j)];
367 }
368 let scaled = tau * dot;
369 for (vi, &row) in rows.iter().enumerate() {
370 t[(row, j)] = t[(row, j)] - scaled * v[vi];
371 }
372 }
373
374 let row_end = (k + 4).min(end);
376 for i in 0..row_end {
377 let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
378 let mut dot = T::zero();
379 for (vi, &col) in cols.iter().enumerate() {
380 dot = dot + t[(i, col)] * v[vi];
381 }
382 let scaled = tau * dot;
383 for (vi, &col) in cols.iter().enumerate() {
384 t[(i, col)] = t[(i, col)] - scaled * v[vi];
385 }
386 }
387
388 for i in 0..n {
390 let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
391 let mut dot = T::zero();
392 for (vi, &col) in cols.iter().enumerate() {
393 dot = dot + q[(i, col)] * v[vi];
394 }
395 let scaled = tau * dot;
396 for (vi, &col) in cols.iter().enumerate() {
397 q[(i, col)] = q[(i, col)] - scaled * v[vi];
398 }
399 }
400 }
401
402 if k + 3 < end {
404 x = t[(k + 1, k)];
405 y = t[(k + 2, k)];
406 z = if k + 3 < end {
407 t[(k + 3, k)]
408 } else {
409 T::zero()
410 };
411 } else if k + 2 < end {
412 x = t[(k + 1, k)];
414 y = t[(k + 2, k)];
415 z = T::zero();
416 }
417 }
418 }
419
420 fn extract_eigenvalues(t: &Mat<T>) -> Vec<Eigenvalue<T>> {
422 let n = t.nrows();
423 let mut eigenvalues = Vec::with_capacity(n);
424 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
425
426 let mut i = 0;
427 while i < n {
428 if i == n - 1 {
429 eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
431 i += 1;
432 } else {
433 let sub = Scalar::abs(t[(i + 1, i)]);
435 let diag_sum = Scalar::abs(t[(i, i)]) + Scalar::abs(t[(i + 1, i + 1)]);
436
437 if sub <= eps * diag_sum {
438 eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
440 i += 1;
441 } else {
442 let a11 = t[(i, i)];
444 let a12 = t[(i, i + 1)];
445 let a21 = t[(i + 1, i)];
446 let a22 = t[(i + 1, i + 1)];
447
448 let trace = a11 + a22;
449 let det = a11 * a22 - a12 * a21;
450 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
451
452 if disc >= T::zero() {
453 let sqrt_disc = Real::sqrt(disc);
455 let lambda1 =
456 (trace + sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
457 let lambda2 =
458 (trace - sqrt_disc) / T::from_f64(2.0).unwrap_or_else(T::zero);
459 eigenvalues.push(Eigenvalue::real_only(lambda1));
460 eigenvalues.push(Eigenvalue::real_only(lambda2));
461 } else {
462 let sqrt_disc = Real::sqrt(-disc);
464 let real_part = trace / T::from_f64(2.0).unwrap_or_else(T::zero);
465 let imag_part = sqrt_disc / T::from_f64(2.0).unwrap_or_else(T::zero);
466 eigenvalues.push(Eigenvalue::complex(real_part, imag_part));
467 eigenvalues.push(Eigenvalue::complex(real_part, -imag_part));
468 }
469 i += 2;
470 }
471 }
472 }
473
474 eigenvalues
475 }
476
477 pub fn q(&self) -> MatRef<'_, T> {
479 self.q.as_ref()
480 }
481
482 pub fn t(&self) -> MatRef<'_, T> {
484 self.t.as_ref()
485 }
486
487 pub fn eigenvalues(&self) -> &[Eigenvalue<T>] {
489 &self.eigenvalues
490 }
491
492 pub fn eigenvalues_real(&self) -> Vec<T> {
494 self.eigenvalues.iter().map(|e| e.real).collect()
495 }
496
497 pub fn reconstruct(&self) -> Mat<T> {
499 let mut qt = Mat::zeros(self.n, self.n);
500 let mut a = Mat::zeros(self.n, self.n);
501
502 for i in 0..self.n {
504 for j in 0..self.n {
505 let mut sum = T::zero();
506 for k in 0..self.n {
507 sum = sum + self.q[(i, k)] * self.t[(k, j)];
508 }
509 qt[(i, j)] = sum;
510 }
511 }
512
513 for i in 0..self.n {
515 for j in 0..self.n {
516 let mut sum = T::zero();
517 for k in 0..self.n {
518 sum = sum + qt[(i, k)] * self.q[(j, k)];
519 }
520 a[(i, j)] = sum;
521 }
522 }
523
524 a
525 }
526
527 #[must_use]
553 pub fn right_eigenvectors(&self) -> Mat<T> {
554 trevc_right(&self.t)
555 }
556
557 #[must_use]
578 pub fn left_eigenvectors(&self) -> Mat<T> {
579 trevc_left(&self.t)
580 }
581
582 #[must_use]
590 pub fn eigenvectors(&self) -> (Mat<T>, Mat<T>) {
591 let vr_t = trevc_right(&self.t);
592 let vl_t = trevc_left(&self.t);
593
594 let mut vr_a = Mat::zeros(self.n, self.n);
596 let mut vl_a = Mat::zeros(self.n, self.n);
597
598 for i in 0..self.n {
599 for j in 0..self.n {
600 let mut sum_r = T::zero();
601 let mut sum_l = T::zero();
602 for k in 0..self.n {
603 sum_r = sum_r + self.q[(i, k)] * vr_t[(k, j)];
604 sum_l = sum_l + self.q[(i, k)] * vl_t[(k, j)];
605 }
606 vr_a[(i, j)] = sum_r;
607 vl_a[(i, j)] = sum_l;
608 }
609 }
610
611 (vr_a, vl_a)
612 }
613
614 #[must_use]
648 pub fn eigenvalue_condition_numbers(&self) -> Vec<T> {
649 trsna_s(&self.t)
650 }
651
652 #[must_use]
662 pub fn eigenvector_separation(&self) -> Vec<T> {
663 trsna_sep(&self.t)
664 }
665}
666
667pub fn trsna_s<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
680 let n = t.nrows();
681 if n == 0 {
682 return Vec::new();
683 }
684
685 let vr = trevc_right(t);
686 let vl = trevc_left(t);
687
688 let mut s = vec![T::zero(); n];
689 let eps = <T as Scalar>::epsilon();
690
691 let mut j = 0;
692 while j < n {
693 let is_2x2 = if j + 1 < n {
695 let sub = Scalar::abs(t[(j + 1, j)]);
696 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
697 sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
698 } else {
699 false
700 };
701
702 if is_2x2 {
703 let jp1 = j + 1;
707
708 let mut prod_rr = T::zero(); let mut prod_ii = T::zero(); let mut prod_ri = T::zero(); let mut prod_ir = T::zero(); for k in 0..n {
714 prod_rr = prod_rr + vl[(k, j)] * vr[(k, j)];
715 prod_ii = prod_ii + vl[(k, jp1)] * vr[(k, jp1)];
716 prod_ri = prod_ri + vl[(k, j)] * vr[(k, jp1)];
717 prod_ir = prod_ir + vl[(k, jp1)] * vr[(k, j)];
718 }
719
720 let real_part = prod_rr + prod_ii;
721 let imag_part = prod_ri - prod_ir;
722 let abs_inner = Real::sqrt(real_part * real_part + imag_part * imag_part);
723
724 s[j] = abs_inner;
726 s[jp1] = abs_inner;
727
728 j += 2;
729 } else {
730 let mut inner = T::zero();
732 for k in 0..n {
733 inner = inner + vl[(k, j)] * vr[(k, j)];
734 }
735 s[j] = Scalar::abs(inner);
736 j += 1;
737 }
738 }
739
740 s
741}
742
743pub fn trsna_sep<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
758 let n = t.nrows();
759 if n == 0 {
760 return Vec::new();
761 }
762
763 let mut sep = vec![T::zero(); n];
764 let eps = <T as Scalar>::epsilon();
765
766 let mut j = 0;
767 while j < n {
768 let is_2x2 = if j + 1 < n {
770 let sub = Scalar::abs(t[(j + 1, j)]);
771 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
772 sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
773 } else {
774 false
775 };
776
777 if is_2x2 {
778 let jp1 = j + 1;
780
781 let a11 = t[(j, j)];
783 let a22 = t[(jp1, jp1)];
784 let lambda_real = (a11 + a22) / T::from_f64(2.0).unwrap_or_else(T::zero);
785 let trace = a11 + a22;
786 let det = a11 * a22 - t[(j, jp1)] * t[(jp1, j)];
787 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
788 let lambda_imag = if disc < T::zero() {
789 Real::sqrt(-disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
790 } else {
791 T::zero()
792 };
793
794 let mut min_sep = T::one() / eps;
795
796 let mut k = 0;
798 while k < n {
799 if k == j || k == jp1 {
800 k += 1;
801 continue;
802 }
803
804 let adjacent = j > 0 && k == j - 1;
806 let k_is_2x2 = if k + 1 < n && !adjacent {
807 let sub = Scalar::abs(t[(k + 1, k)]);
808 let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
809 sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
810 } else {
811 false
812 };
813
814 let (other_real, other_imag) = if k_is_2x2 {
815 let kp1 = k + 1;
816 let b11 = t[(k, k)];
817 let b22 = t[(kp1, kp1)];
818 let other_trace = b11 + b22;
819 let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
820 let other_disc = other_trace * other_trace
821 - T::from_f64(4.0).unwrap_or_else(T::zero) * other_det;
822 let r = (b11 + b22) / T::from_f64(2.0).unwrap_or_else(T::zero);
823 let i = if other_disc < T::zero() {
824 Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
825 } else {
826 T::zero()
827 };
828 (r, i)
829 } else {
830 (t[(k, k)], T::zero())
831 };
832
833 let dr = lambda_real - other_real;
835 let di = lambda_imag - other_imag;
836 let dist = Real::sqrt(dr * dr + di * di);
837 if dist < min_sep && dist > T::zero() {
838 min_sep = dist;
839 }
840
841 if other_imag != T::zero() {
843 let di_conj = lambda_imag + other_imag;
844 let dist_conj = Real::sqrt(dr * dr + di_conj * di_conj);
845 if dist_conj < min_sep && dist_conj > T::zero() {
846 min_sep = dist_conj;
847 }
848 }
849
850 if k_is_2x2 {
851 k += 2;
852 } else {
853 k += 1;
854 }
855 }
856
857 sep[j] = min_sep;
858 sep[jp1] = min_sep;
859 j += 2;
860 } else {
861 let lambda = t[(j, j)];
863 let mut min_sep = T::one() / eps;
864
865 let mut k = 0;
867 while k < n {
868 if k == j {
869 k += 1;
870 continue;
871 }
872
873 let adjacent_to_j = (j > 0 && k == j - 1) || k == j + 1;
875 let k_is_2x2 = if k + 1 < n && !adjacent_to_j {
876 let sub = Scalar::abs(t[(k + 1, k)]);
877 let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
878 sub > eps * T::from_f64(100.0).unwrap_or_else(T::zero) * diag_sum
879 } else {
880 false
881 };
882
883 let (other_real, other_imag) = if k_is_2x2 {
884 let kp1 = k + 1;
885 let b11 = t[(k, k)];
886 let b22 = t[(kp1, kp1)];
887 let other_trace = b11 + b22;
888 let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
889 let other_disc = other_trace * other_trace
890 - T::from_f64(4.0).unwrap_or_else(T::zero) * other_det;
891 let r = (b11 + b22) / T::from_f64(2.0).unwrap_or_else(T::zero);
892 let i = if other_disc < T::zero() {
893 Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap_or_else(T::zero)
894 } else {
895 T::zero()
896 };
897 (r, i)
898 } else {
899 (t[(k, k)], T::zero())
900 };
901
902 let dr = lambda - other_real;
904 let dist = Real::sqrt(dr * dr + other_imag * other_imag);
905 if dist < min_sep && dist > T::zero() {
906 min_sep = dist;
907 }
908
909 if k_is_2x2 {
910 k += 2;
911 } else {
912 k += 1;
913 }
914 }
915
916 sep[j] = min_sep;
917 j += 1;
918 }
919 }
920
921 sep
922}
923
924pub fn trevc_right<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
938 let n = t.nrows();
939 let mut v = Mat::zeros(n, n);
940
941 for i in 0..n {
943 v[(i, i)] = T::one();
944 }
945
946 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
947
948 let mut j = n;
950 while j > 0 {
951 j -= 1;
952
953 let is_2x2 = if j > 0 {
955 let sub = Scalar::abs(t[(j, j - 1)]);
956 let diag_sum = Scalar::abs(t[(j - 1, j - 1)]) + Scalar::abs(t[(j, j)]);
957 sub > eps * diag_sum
958 } else {
959 false
960 };
961
962 if is_2x2 {
963 let jm1 = j - 1;
966
967 let a11 = t[(jm1, jm1)];
969 let a12 = t[(jm1, j)];
970 let a21 = t[(j, jm1)];
971 let a22 = t[(j, j)];
972
973 let trace = a11 + a22;
974 let det = a11 * a22 - a12 * a21;
975 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
976
977 let two = T::from_f64(2.0).unwrap_or_else(T::zero);
979 let real_part = trace / two;
980 let imag_part = Real::sqrt(-disc) / two;
981
982 v[(jm1, jm1)] = T::one();
985 v[(j, jm1)] = T::zero();
986 v[(jm1, j)] = T::zero();
987 v[(j, j)] = T::one();
988
989 let d11 = a11 - real_part;
1014 let d22 = a22 - real_part;
1015
1016 if Scalar::abs(a21) >= Scalar::abs(a12) && Scalar::abs(a21) > eps {
1018 let v1r = -d22 / a21;
1021 let v1i = imag_part / a21;
1022 v[(jm1, jm1)] = v1r;
1023 v[(j, jm1)] = T::one();
1024 v[(jm1, j)] = v1i;
1025 v[(j, j)] = T::zero();
1026 } else if Scalar::abs(a12) > eps {
1027 let v2r = -d11 / a12;
1030 let v2i = -imag_part / a12;
1031 v[(jm1, jm1)] = T::one();
1032 v[(j, jm1)] = v2r;
1033 v[(jm1, j)] = T::zero();
1034 v[(j, j)] = v2i;
1035 } else {
1036 v[(jm1, jm1)] = T::one();
1038 v[(j, jm1)] = T::zero();
1039 v[(jm1, j)] = T::zero();
1040 v[(j, j)] = T::one();
1041 }
1042
1043 for i in (0..jm1).rev() {
1045 let mut sum_r = T::zero();
1050 let mut sum_i = T::zero();
1051 for k in (i + 1)..=j {
1052 sum_r = sum_r + t[(i, k)] * v[(k, jm1)];
1053 sum_i = sum_i + t[(i, k)] * v[(k, j)];
1054 }
1055
1056 let d = t[(i, i)] - real_part;
1057 let det_2x2 = d * d + imag_part * imag_part;
1058
1059 if Scalar::abs(det_2x2) > eps {
1060 v[(i, jm1)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1064 v[(i, j)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1065 }
1066 }
1067
1068 let mut norm_r_sq = T::zero();
1070 let mut norm_i_sq = T::zero();
1071 for i in 0..n {
1072 norm_r_sq = norm_r_sq + v[(i, jm1)] * v[(i, jm1)];
1073 norm_i_sq = norm_i_sq + v[(i, j)] * v[(i, j)];
1074 }
1075 let norm = Real::sqrt(norm_r_sq + norm_i_sq);
1076 if norm > T::zero() {
1077 for i in 0..n {
1078 v[(i, jm1)] = v[(i, jm1)] / norm;
1079 v[(i, j)] = v[(i, j)] / norm;
1080 }
1081 }
1082
1083 j = jm1; } else {
1085 let lambda = t[(j, j)];
1087
1088 v[(j, j)] = T::one();
1090
1091 for i in (0..j).rev() {
1093 let mut sum = T::zero();
1094 for k in (i + 1)..=j {
1095 sum = sum + t[(i, k)] * v[(k, j)];
1096 }
1097
1098 let d = t[(i, i)] - lambda;
1099 if Scalar::abs(d) > eps {
1100 v[(i, j)] = -sum / d;
1101 } else {
1102 v[(i, j)] = -sum / eps;
1104 }
1105 }
1106
1107 let mut norm_sq = T::zero();
1109 for i in 0..n {
1110 norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1111 }
1112 let norm = Real::sqrt(norm_sq);
1113 if norm > T::zero() {
1114 for i in 0..n {
1115 v[(i, j)] = v[(i, j)] / norm;
1116 }
1117 }
1118 }
1119 }
1120
1121 v
1122}
1123
1124pub fn trevc_left<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
1137 let n = t.nrows();
1138 let mut v = Mat::zeros(n, n);
1139
1140 for i in 0..n {
1142 v[(i, i)] = T::one();
1143 }
1144
1145 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
1146
1147 let mut j = 0;
1149 while j < n {
1150 let is_2x2 = if j + 1 < n {
1152 let sub = Scalar::abs(t[(j + 1, j)]);
1153 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
1154 sub > eps * diag_sum
1155 } else {
1156 false
1157 };
1158
1159 if is_2x2 {
1160 let jp1 = j + 1;
1162
1163 let a11 = t[(j, j)];
1164 let a12 = t[(j, jp1)];
1165 let a21 = t[(jp1, j)];
1166 let a22 = t[(jp1, jp1)];
1167
1168 let trace = a11 + a22;
1169 let det = a11 * a22 - a12 * a21;
1170 let disc = trace * trace - T::from_f64(4.0).unwrap_or_else(T::zero) * det;
1171
1172 let two = T::from_f64(2.0).unwrap_or_else(T::zero);
1173 let real_part = trace / two;
1174 let imag_part = Real::sqrt(-disc) / two;
1175
1176 let d11 = a11 - real_part;
1178 let det_factor = d11 * d11 + imag_part * imag_part;
1179
1180 if Scalar::abs(det_factor) > eps {
1181 let vr2 = -a21 * d11 / det_factor;
1182 let vi2 = -imag_part * a21 / det_factor;
1183 v[(j, j)] = T::one();
1184 v[(jp1, j)] = vr2;
1185 v[(j, jp1)] = T::zero();
1186 v[(jp1, jp1)] = vi2;
1187 } else {
1188 v[(j, j)] = T::one();
1189 v[(jp1, j)] = T::zero();
1190 v[(j, jp1)] = T::zero();
1191 v[(jp1, jp1)] = T::one();
1192 }
1193
1194 for i in (jp1 + 1)..n {
1196 let mut sum_r = T::zero();
1197 let mut sum_i = T::zero();
1198 for k in j..i {
1199 sum_r = sum_r + t[(k, i)] * v[(k, j)];
1200 sum_i = sum_i + t[(k, i)] * v[(k, jp1)];
1201 }
1202
1203 let d = t[(i, i)] - real_part;
1204 let det_2x2 = d * d + imag_part * imag_part;
1205
1206 if Scalar::abs(det_2x2) > eps {
1207 v[(i, j)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1208 v[(i, jp1)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1209 }
1210 }
1211
1212 let mut norm_sq = T::zero();
1214 for i in 0..n {
1215 norm_sq = norm_sq + v[(i, j)] * v[(i, j)] + v[(i, jp1)] * v[(i, jp1)];
1216 }
1217 let norm = Real::sqrt(norm_sq);
1218 if norm > T::zero() {
1219 for i in 0..n {
1220 v[(i, j)] = v[(i, j)] / norm;
1221 v[(i, jp1)] = v[(i, jp1)] / norm;
1222 }
1223 }
1224
1225 j = jp1 + 1;
1226 } else {
1227 let lambda = t[(j, j)];
1229 v[(j, j)] = T::one();
1230
1231 for i in (j + 1)..n {
1233 let mut sum = T::zero();
1234 for k in j..i {
1235 sum = sum + t[(k, i)] * v[(k, j)];
1236 }
1237
1238 let d = t[(i, i)] - lambda;
1239 if Scalar::abs(d) > eps {
1240 v[(i, j)] = -sum / d;
1241 } else {
1242 v[(i, j)] = -sum / eps;
1243 }
1244 }
1245
1246 let mut norm_sq = T::zero();
1248 for i in 0..n {
1249 norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1250 }
1251 let norm = Real::sqrt(norm_sq);
1252 if norm > T::zero() {
1253 for i in 0..n {
1254 v[(i, j)] = v[(i, j)] / norm;
1255 }
1256 }
1257
1258 j += 1;
1259 }
1260 }
1261
1262 v
1263}
1264
1265fn householder_3<T: Field + Real>(x: &[T]) -> (Vec<T>, T) {
1267 let n = x.len().min(3);
1268 if n == 0 {
1269 return (Vec::new(), T::zero());
1270 }
1271
1272 let mut norm_sq = T::zero();
1273 for i in 0..n {
1274 norm_sq = norm_sq + x[i] * x[i];
1275 }
1276 let norm = Real::sqrt(norm_sq);
1277
1278 if norm == T::zero() {
1279 return (vec![T::zero(); n], T::zero());
1280 }
1281
1282 let mut v = vec![T::zero(); n];
1283 for i in 0..n {
1284 v[i] = x[i];
1285 }
1286
1287 let sign = if x[0] >= T::zero() {
1288 T::one()
1289 } else {
1290 -T::one()
1291 };
1292 v[0] = v[0] + sign * norm;
1293
1294 let mut v_norm_sq = T::zero();
1295 for i in 0..n {
1296 v_norm_sq = v_norm_sq + v[i] * v[i];
1297 }
1298
1299 if v_norm_sq > T::zero() {
1300 let tau = T::from_f64(2.0).unwrap_or_else(T::zero) / v_norm_sq;
1301 (v, tau)
1302 } else {
1303 (v, T::zero())
1304 }
1305}
1306
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Real parts of the eigenvalues reported by `schur`, sorted ascending.
    fn sorted_real_parts(schur: &Schur<f64>) -> Vec<f64> {
        let mut parts: Vec<f64> = schur.eigenvalues().iter().map(|e| e.real).collect();
        parts.sort_by(|x, y| x.partial_cmp(y).unwrap());
        parts
    }

    /// Asserts `|rebuilt[(i, j)] − a[(i, j)]| < tol` for all `i, j < dim`.
    fn assert_reconstruction(rebuilt: &Mat<f64>, a: &Mat<f64>, dim: usize, tol: f64) {
        for i in 0..dim {
            for j in 0..dim {
                let (got, want) = (rebuilt[(i, j)], a[(i, j)]);
                assert!(
                    approx_eq(got, want, tol),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    /// Asserts `(m·vecs[:, j])[i] ≈ lambda·vecs[(i, j)]` for every `i < dim`.
    /// `label` names `m` in the failure message.
    fn assert_eigenpair(
        m: &Mat<f64>,
        vecs: &Mat<f64>,
        j: usize,
        lambda: f64,
        dim: usize,
        tol: f64,
        label: &str,
    ) {
        for i in 0..dim {
            let product = (0..dim).fold(0.0, |acc, k| acc + m[(i, k)] * vecs[(k, j)]);
            let scaled = lambda * vecs[(i, j)];
            assert!(
                approx_eq(product, scaled, tol),
                "{}*v[{}] = {}, λ*v = {}",
                label,
                i,
                product,
                scaled
            );
        }
    }

    /// Asserts every diagonal entry of the `dim`×`dim` matrix `m` has magnitude ≈ 1.
    fn assert_unit_diagonal(m: &Mat<f64>, dim: usize, label: &str) {
        for i in 0..dim {
            assert!(
                approx_eq(m[(i, i)].abs(), 1.0, 1e-10),
                "{}[{},{}] = {}",
                label,
                i,
                i,
                m[(i, i)]
            );
        }
    }

    #[test]
    fn test_schur_upper_triangular() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0], &[0.0, 3.0]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigs = sorted_real_parts(&schur);
        assert!(approx_eq(eigs[0], 1.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
    }

    #[test]
    fn test_schur_diagonal() {
        let a = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigs = sorted_real_parts(&schur);
        for (idx, &want) in [2.0, 3.0, 5.0].iter().enumerate() {
            assert!(approx_eq(eigs[idx], want, 1e-10));
        }
    }

    #[test]
    fn test_schur_complex_eigenvalues() {
        let theta = core::f64::consts::FRAC_PI_4;
        let (c, s) = (theta.cos(), theta.sin());
        let a = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        assert_eq!(eigenvalues.len(), 2);
        let (first, second) = (eigenvalues[0], eigenvalues[1]);
        assert!(approx_eq(first.real, second.real, 1e-10));
        assert!(approx_eq(first.imag, -second.imag, 1e-10));
        assert!(approx_eq(first.real, c, 1e-10));
        assert!(approx_eq(first.imag.abs(), s, 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        let rebuilt = Schur::compute(a.as_ref()).unwrap().reconstruct();
        assert_reconstruction(&rebuilt, &a, 3, 1e-10);
    }

    #[test]
    fn test_schur_q_orthogonal() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let q = schur.q();

        let n = 3;
        for i in 0..n {
            for j in 0..n {
                let dot = (0..n).fold(0.0, |acc, k| acc + q[(k, i)] * q[(k, j)]);
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(dot, expected, 1e-10),
                    "Q^T*Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    dot,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_schur_single() {
        let a = Mat::from_rows(&[&[5.0f64]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();
        assert_eq!(eigenvalues.len(), 1);
        assert!(approx_eq(eigenvalues[0].real, 5.0, 1e-10));
    }

    #[test]
    fn test_schur_4x4() {
        let a = Mat::from_rows(&[
            &[4.0f64, 1.0, -2.0, 2.0],
            &[1.0, 2.0, 0.0, 1.0],
            &[-2.0, 0.0, 3.0, -2.0],
            &[2.0, 1.0, -2.0, -1.0],
        ]);
        let rebuilt = Schur::compute(a.as_ref()).unwrap().reconstruct();
        assert_reconstruction(&rebuilt, &a, 4, 1e-8);
    }

    #[test]
    fn test_schur_f32() {
        let a = Mat::from_rows(&[&[1.0f32, 2.0], &[3.0, 4.0]]);
        let rebuilt = Schur::compute(a.as_ref()).unwrap().reconstruct();

        for i in 0..2 {
            for j in 0..2 {
                let (got, want) = (rebuilt[(i, j)], a[(i, j)]);
                assert!(
                    (got - want).abs() < 1e-4,
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_trevc_right_upper_triangular() {
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);
        let v = trevc_right(&t);

        // The eigenvector of the leading eigenvalue is e₁.
        assert!(v[(1, 0)].abs() < 1e-10);
        assert!(v[(2, 0)].abs() < 1e-10);
        // The last eigenvector has a non-trivial last component.
        assert!(v[(2, 2)].abs() > 0.1);
    }

    #[test]
    fn test_trevc_right_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let v = trevc_right(&t);
        assert_unit_diagonal(&v, 3, "v");
    }

    #[test]
    fn test_trevc_eigenvector_equation() {
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);
        let v = trevc_right(&t);

        for (j, &lambda) in [1.0, 4.0, 6.0].iter().enumerate() {
            assert_eigenpair(&t, &v, j, lambda, 3, 1e-10, "T");
        }
    }

    #[test]
    fn test_trevc_left_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let u = trevc_left(&t);
        assert_unit_diagonal(&u, 3, "u");
    }

    #[test]
    fn test_schur_eigenvectors() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let (vr, _vl) = schur.eigenvectors();

        for j in 0..3 {
            let lambda = schur.eigenvalues()[j].real;
            assert_eigenpair(&a, &vr, j, lambda, 3, 1e-8, "A");
        }
    }

    #[test]
    fn test_trevc_2x2_block() {
        let theta = core::f64::consts::FRAC_PI_4;
        let (c, s) = (theta.cos(), theta.sin());
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let v = trevc_right(&t);

        // Real part in column 0, imaginary part in column 1: the combined norm is 1.
        let norm0_sq = v[(0, 0)] * v[(0, 0)] + v[(1, 0)] * v[(1, 0)];
        let norm1_sq = v[(0, 1)] * v[(0, 1)] + v[(1, 1)] * v[(1, 1)];
        let total_norm = (norm0_sq + norm1_sq).sqrt();

        assert!(
            approx_eq(total_norm, 1.0, 1e-10),
            "eigenvector norm = {}",
            total_norm
        );
    }

    #[test]
    fn test_trsna_s_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        for (i, &value) in s.iter().enumerate() {
            assert!(
                approx_eq(value, 1.0, 1e-10),
                "s[{}] = {}, expected 1.0",
                i,
                value
            );
        }
    }

    #[test]
    fn test_trsna_s_upper_triangular() {
        let t = Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]]);
        let s = trsna_s(&t);

        assert_eq!(s.len(), 3);
        for (i, &value) in s.iter().enumerate() {
            assert!(value > 0.0, "s[{}] = {} should be positive", i, value);
        }
    }

    #[test]
    fn test_trsna_sep_diagonal() {
        let t = Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 3);
        // Nearest neighbours: 2↔3, 5↔3, 3↔2.
        for (i, &want) in [1.0f64, 2.0, 1.0].iter().enumerate() {
            assert!(
                approx_eq(sep[i], want, 1e-10),
                "sep[{}] = {}, expected {:?}",
                i,
                sep[i],
                want
            );
        }
    }

    #[test]
    fn test_trsna_sep_close_eigenvalues() {
        let t = Mat::from_rows(&[&[1.0f64, 0.0], &[0.0, 1.001]]);
        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 2);
        for (i, &value) in sep.iter().enumerate() {
            assert!(
                approx_eq(value, 0.001, 1e-10),
                "sep[{}] = {}, expected 0.001",
                i,
                value
            );
        }
    }

    #[test]
    fn test_schur_eigenvalue_condition_numbers() {
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);
        let cond = Schur::compute(a.as_ref())
            .unwrap()
            .eigenvalue_condition_numbers();

        assert_eq!(cond.len(), 2);
        for (i, &value) in cond.iter().enumerate() {
            assert!(
                value > 0.5,
                "cond[{}] = {} should be > 0.5 for diagonal matrix",
                i,
                value
            );
        }
    }

    #[test]
    fn test_schur_eigenvector_separation() {
        let a = Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]]);
        let sep = Schur::compute(a.as_ref()).unwrap().eigenvector_separation();

        assert_eq!(sep.len(), 2);
        for (i, &value) in sep.iter().enumerate() {
            assert!(
                approx_eq(value, 2.0, 1e-10),
                "sep[{}] = {}, expected 2.0",
                i,
                value
            );
        }
    }

    #[test]
    fn test_trsna_complex_eigenvalues() {
        let theta = core::f64::consts::FRAC_PI_4;
        let (c, s) = (theta.cos(), theta.sin());
        let t = Mat::from_rows(&[&[c, -s], &[s, c]]);

        let cond = trsna_s(&t);
        let sep = trsna_sep(&t);

        assert_eq!(cond.len(), 2);
        assert_eq!(sep.len(), 2);

        // Both halves of a conjugate pair share their condition estimates.
        assert!(
            approx_eq(cond[0], cond[1], 1e-10),
            "cond[0]={}, cond[1]={} should be equal",
            cond[0],
            cond[1]
        );
        assert!(
            approx_eq(sep[0], sep[1], 1e-10),
            "sep[0]={}, sep[1]={} should be equal",
            sep[0],
            sep[1]
        );
    }
}