1use oxiblas_core::scalar::{Field, Real, Scalar};
8use oxiblas_matrix::{Mat, MatRef};
9
10use super::hessenberg::Hessenberg;
11
/// Errors reported by [`Schur::compute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurError {
    /// The input matrix has zero rows or zero columns.
    EmptyMatrix,
    /// The input matrix is not square. `Schur::compute` also maps a failure of
    /// the Hessenberg reduction to this variant.
    NotSquare,
    /// The shifted QR iteration did not converge
    /// ("Schur decomposition did not converge").
    NotConverged,
}
22
23impl core::fmt::Display for SchurError {
24 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
25 match self {
26 Self::EmptyMatrix => write!(f, "Matrix is empty"),
27 Self::NotSquare => write!(f, "Matrix must be square"),
28 Self::NotConverged => write!(f, "Schur decomposition did not converge"),
29 }
30 }
31}
32
33impl std::error::Error for SchurError {}
34
/// An eigenvalue written as `real + i·imag`.
///
/// Real eigenvalues have `imag == 0`. See `Eigenvalue::real_only` and
/// `Eigenvalue::is_real`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Eigenvalue<T> {
    /// Real part.
    pub real: T,
    /// Imaginary part (zero for a real eigenvalue).
    pub imag: T,
}
43
44impl<T: Scalar> Eigenvalue<T> {
45 pub fn real_only(value: T) -> Self {
47 Self {
48 real: value,
49 imag: T::zero(),
50 }
51 }
52
53 pub fn complex(real: T, imag: T) -> Self {
55 Self { real, imag }
56 }
57
58 pub fn is_real(&self) -> bool {
60 self.imag == T::zero()
61 }
62}
63
/// Real Schur decomposition `A = Q T Q^T` of a square matrix, produced by
/// [`Schur::compute`].
///
/// `reconstruct` rebuilds `A` as `Q * T * Q^T`. Eigenvalues are read from the
/// 1x1 and 2x2 diagonal blocks of `T`.
#[derive(Debug, Clone)]
pub struct Schur<T: Scalar> {
    /// Orthogonal factor `Q` (columns are Schur vectors).
    q: Mat<T>,
    /// Quasi-upper-triangular factor `T`.
    t: Mat<T>,
    /// Eigenvalues of the decomposed matrix.
    eigenvalues: Vec<Eigenvalue<T>>,
    /// Dimension of the (square) decomposed matrix.
    n: usize,
}
80
81impl<T: Field + Real + bytemuck::Zeroable> Schur<T> {
82 const MAX_ITERATIONS: usize = 100;
84
85 pub fn compute(a: MatRef<'_, T>) -> Result<Self, SchurError> {
105 let m = a.nrows();
106 let n = a.ncols();
107
108 if m == 0 || n == 0 {
109 return Err(SchurError::EmptyMatrix);
110 }
111
112 if m != n {
113 return Err(SchurError::NotSquare);
114 }
115
116 if n == 1 {
118 let mut t = Mat::zeros(1, 1);
119 t[(0, 0)] = a[(0, 0)];
120 let mut q = Mat::zeros(1, 1);
121 q[(0, 0)] = T::one();
122 let eigenvalues = vec![Eigenvalue::real_only(a[(0, 0)])];
123 return Ok(Self {
124 q,
125 t,
126 eigenvalues,
127 n,
128 });
129 }
130
131 if n == 2 {
133 return Self::compute_2x2(a);
134 }
135
136 let hess = Hessenberg::compute(a).map_err(|_| SchurError::NotSquare)?;
138 let mut t = Mat::zeros(n, n);
139 let h = hess.h();
140 for i in 0..n {
141 for j in 0..n {
142 t[(i, j)] = h[(i, j)];
143 }
144 }
145
146 let mut q = Mat::zeros(n, n);
147 let q_hess = hess.q();
148 for i in 0..n {
149 for j in 0..n {
150 q[(i, j)] = q_hess[(i, j)];
151 }
152 }
153
154 let eps = <T as Scalar>::epsilon();
156 let tol = eps * T::from_f64(100.0).unwrap_or(T::one());
157
158 let mut p = n;
160 let mut iter_count = 0;
161
162 while p > 2 && iter_count < Self::MAX_ITERATIONS * n {
163 iter_count += 1;
164
165 let mut q_idx = p - 1;
167 while q_idx > 0 {
168 let sub = Scalar::abs(t[(q_idx, q_idx - 1)]);
169 let diag_sum =
170 Scalar::abs(t[(q_idx - 1, q_idx - 1)]) + Scalar::abs(t[(q_idx, q_idx)]);
171 if sub <= tol * diag_sum {
172 t[(q_idx, q_idx - 1)] = T::zero();
173 break;
174 }
175 q_idx -= 1;
176 }
177
178 if q_idx == p - 1 {
179 p -= 1;
181 } else if q_idx == p - 2 {
182 let a11 = t[(p - 2, p - 2)];
184 let a12 = t[(p - 2, p - 1)];
185 let a21 = t[(p - 1, p - 2)];
186 let a22 = t[(p - 1, p - 1)];
187 let trace = a11 + a22;
188 let det = a11 * a22 - a12 * a21;
189 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
190
191 if disc < T::zero() {
192 p -= 2;
194 } else {
195 Self::francis_qr_step(&mut t, &mut q, q_idx, p);
197 }
198 } else {
199 Self::francis_qr_step(&mut t, &mut q, q_idx, p);
201 }
202 }
203
204 if p == 2 {
206 let sub = Scalar::abs(t[(1, 0)]);
207 let diag_sum = Scalar::abs(t[(0, 0)]) + Scalar::abs(t[(1, 1)]);
208 if sub <= tol * diag_sum {
209 t[(1, 0)] = T::zero();
210 }
211 }
212
213 let eigenvalues = Self::extract_eigenvalues(&t);
215
216 Ok(Self {
217 q,
218 t,
219 eigenvalues,
220 n,
221 })
222 }
223
224 fn compute_2x2(a: MatRef<'_, T>) -> Result<Self, SchurError> {
226 let mut t = Mat::zeros(2, 2);
227 for i in 0..2 {
228 for j in 0..2 {
229 t[(i, j)] = a[(i, j)];
230 }
231 }
232
233 let a11 = a[(0, 0)];
234 let a12 = a[(0, 1)];
235 let a21 = a[(1, 0)];
236 let a22 = a[(1, 1)];
237
238 let trace = a11 + a22;
239 let det = a11 * a22 - a12 * a21;
240 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
241
242 let mut q = Mat::zeros(2, 2);
243 let eigenvalues: Vec<Eigenvalue<T>>;
244
245 if disc >= T::zero() {
246 let sqrt_disc = Real::sqrt(disc);
248 let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap();
249 let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap();
250
251 if Scalar::abs(a21) > <T as Scalar>::epsilon() {
253 let theta = if Scalar::abs(a11 - lambda1) > <T as Scalar>::epsilon() {
254 Real::atan2(a21, a11 - lambda1)
255 } else {
256 T::from_f64(core::f64::consts::FRAC_PI_2).unwrap()
257 };
258 let c = Real::cos(theta);
259 let s = Real::sin(theta);
260
261 q[(0, 0)] = c;
262 q[(0, 1)] = -s;
263 q[(1, 0)] = s;
264 q[(1, 1)] = c;
265
266 let mut temp = Mat::zeros(2, 2);
268 for i in 0..2 {
270 for j in 0..2 {
271 let mut sum = T::zero();
272 for k in 0..2 {
273 sum = sum + q[(k, i)] * a[(k, j)];
274 }
275 temp[(i, j)] = sum;
276 }
277 }
278 for i in 0..2 {
280 for j in 0..2 {
281 let mut sum = T::zero();
282 for k in 0..2 {
283 sum = sum + temp[(i, k)] * q[(k, j)];
284 }
285 t[(i, j)] = sum;
286 }
287 }
288 } else {
289 q[(0, 0)] = T::one();
291 q[(1, 1)] = T::one();
292 }
293
294 eigenvalues = vec![
295 Eigenvalue::real_only(lambda1),
296 Eigenvalue::real_only(lambda2),
297 ];
298 } else {
299 let sqrt_disc = Real::sqrt(-disc);
301 let real_part = trace / T::from_f64(2.0).unwrap();
302 let imag_part = sqrt_disc / T::from_f64(2.0).unwrap();
303
304 q[(0, 0)] = T::one();
305 q[(1, 1)] = T::one();
306
307 eigenvalues = vec![
308 Eigenvalue::complex(real_part, imag_part),
309 Eigenvalue::complex(real_part, -imag_part),
310 ];
311 }
312
313 Ok(Self {
314 q,
315 t,
316 eigenvalues,
317 n: 2,
318 })
319 }
320
321 fn francis_qr_step(t: &mut Mat<T>, q: &mut Mat<T>, start: usize, end: usize) {
323 let n = t.nrows();
324
325 if end - start < 2 {
326 return;
327 }
328
329 let h11 = t[(end - 2, end - 2)];
331 let h12 = t[(end - 2, end - 1)];
332 let h21 = t[(end - 1, end - 2)];
333 let h22 = t[(end - 1, end - 1)];
334
335 let s = h11 + h22; let p = h11 * h22 - h12 * h21; let h_00 = t[(start, start)];
340 let h_01 = t[(start, start + 1)];
341 let h_10 = t[(start + 1, start)];
342
343 let mut x = h_00 * h_00 + h_01 * h_10 - s * h_00 + p;
344 let mut y = h_10 * (h_00 + t[(start + 1, start + 1)] - s);
345 let mut z = if start + 2 < end {
346 h_10 * t[(start + 2, start + 1)]
347 } else {
348 T::zero()
349 };
350
351 for k in start..end.saturating_sub(1) {
353 let (v, tau) = householder_3(&[x, y, z]);
355
356 if tau != T::zero() {
357 let r = if k > start { k - 1 } else { k };
358
359 let col_start = r;
361 let col_end = n;
362 for j in col_start..col_end {
363 let rows = (k..(k + 3).min(end)).collect::<Vec<_>>();
364 let mut dot = T::zero();
365 for (vi, &row) in rows.iter().enumerate() {
366 dot = dot + v[vi] * t[(row, j)];
367 }
368 let scaled = tau * dot;
369 for (vi, &row) in rows.iter().enumerate() {
370 t[(row, j)] = t[(row, j)] - scaled * v[vi];
371 }
372 }
373
374 let row_end = (k + 4).min(end);
376 for i in 0..row_end {
377 let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
378 let mut dot = T::zero();
379 for (vi, &col) in cols.iter().enumerate() {
380 dot = dot + t[(i, col)] * v[vi];
381 }
382 let scaled = tau * dot;
383 for (vi, &col) in cols.iter().enumerate() {
384 t[(i, col)] = t[(i, col)] - scaled * v[vi];
385 }
386 }
387
388 for i in 0..n {
390 let cols = (k..(k + 3).min(end)).collect::<Vec<_>>();
391 let mut dot = T::zero();
392 for (vi, &col) in cols.iter().enumerate() {
393 dot = dot + q[(i, col)] * v[vi];
394 }
395 let scaled = tau * dot;
396 for (vi, &col) in cols.iter().enumerate() {
397 q[(i, col)] = q[(i, col)] - scaled * v[vi];
398 }
399 }
400 }
401
402 if k + 3 < end {
404 x = t[(k + 1, k)];
405 y = t[(k + 2, k)];
406 z = if k + 3 < end {
407 t[(k + 3, k)]
408 } else {
409 T::zero()
410 };
411 } else if k + 2 < end {
412 x = t[(k + 1, k)];
414 y = t[(k + 2, k)];
415 z = T::zero();
416 }
417 }
418 }
419
420 fn extract_eigenvalues(t: &Mat<T>) -> Vec<Eigenvalue<T>> {
422 let n = t.nrows();
423 let mut eigenvalues = Vec::with_capacity(n);
424 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
425
426 let mut i = 0;
427 while i < n {
428 if i == n - 1 {
429 eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
431 i += 1;
432 } else {
433 let sub = Scalar::abs(t[(i + 1, i)]);
435 let diag_sum = Scalar::abs(t[(i, i)]) + Scalar::abs(t[(i + 1, i + 1)]);
436
437 if sub <= eps * diag_sum {
438 eigenvalues.push(Eigenvalue::real_only(t[(i, i)]));
440 i += 1;
441 } else {
442 let a11 = t[(i, i)];
444 let a12 = t[(i, i + 1)];
445 let a21 = t[(i + 1, i)];
446 let a22 = t[(i + 1, i + 1)];
447
448 let trace = a11 + a22;
449 let det = a11 * a22 - a12 * a21;
450 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
451
452 if disc >= T::zero() {
453 let sqrt_disc = Real::sqrt(disc);
455 let lambda1 = (trace + sqrt_disc) / T::from_f64(2.0).unwrap();
456 let lambda2 = (trace - sqrt_disc) / T::from_f64(2.0).unwrap();
457 eigenvalues.push(Eigenvalue::real_only(lambda1));
458 eigenvalues.push(Eigenvalue::real_only(lambda2));
459 } else {
460 let sqrt_disc = Real::sqrt(-disc);
462 let real_part = trace / T::from_f64(2.0).unwrap();
463 let imag_part = sqrt_disc / T::from_f64(2.0).unwrap();
464 eigenvalues.push(Eigenvalue::complex(real_part, imag_part));
465 eigenvalues.push(Eigenvalue::complex(real_part, -imag_part));
466 }
467 i += 2;
468 }
469 }
470 }
471
472 eigenvalues
473 }
474
475 pub fn q(&self) -> MatRef<'_, T> {
477 self.q.as_ref()
478 }
479
480 pub fn t(&self) -> MatRef<'_, T> {
482 self.t.as_ref()
483 }
484
485 pub fn eigenvalues(&self) -> &[Eigenvalue<T>] {
487 &self.eigenvalues
488 }
489
490 pub fn eigenvalues_real(&self) -> Vec<T> {
492 self.eigenvalues.iter().map(|e| e.real).collect()
493 }
494
495 pub fn reconstruct(&self) -> Mat<T> {
497 let mut qt = Mat::zeros(self.n, self.n);
498 let mut a = Mat::zeros(self.n, self.n);
499
500 for i in 0..self.n {
502 for j in 0..self.n {
503 let mut sum = T::zero();
504 for k in 0..self.n {
505 sum = sum + self.q[(i, k)] * self.t[(k, j)];
506 }
507 qt[(i, j)] = sum;
508 }
509 }
510
511 for i in 0..self.n {
513 for j in 0..self.n {
514 let mut sum = T::zero();
515 for k in 0..self.n {
516 sum = sum + qt[(i, k)] * self.q[(j, k)];
517 }
518 a[(i, j)] = sum;
519 }
520 }
521
522 a
523 }
524
525 #[must_use]
551 pub fn right_eigenvectors(&self) -> Mat<T> {
552 trevc_right(&self.t)
553 }
554
555 #[must_use]
576 pub fn left_eigenvectors(&self) -> Mat<T> {
577 trevc_left(&self.t)
578 }
579
580 #[must_use]
588 pub fn eigenvectors(&self) -> (Mat<T>, Mat<T>) {
589 let vr_t = trevc_right(&self.t);
590 let vl_t = trevc_left(&self.t);
591
592 let mut vr_a = Mat::zeros(self.n, self.n);
594 let mut vl_a = Mat::zeros(self.n, self.n);
595
596 for i in 0..self.n {
597 for j in 0..self.n {
598 let mut sum_r = T::zero();
599 let mut sum_l = T::zero();
600 for k in 0..self.n {
601 sum_r = sum_r + self.q[(i, k)] * vr_t[(k, j)];
602 sum_l = sum_l + self.q[(i, k)] * vl_t[(k, j)];
603 }
604 vr_a[(i, j)] = sum_r;
605 vl_a[(i, j)] = sum_l;
606 }
607 }
608
609 (vr_a, vl_a)
610 }
611
612 #[must_use]
646 pub fn eigenvalue_condition_numbers(&self) -> Vec<T> {
647 trsna_s(&self.t)
648 }
649
650 #[must_use]
660 pub fn eigenvector_separation(&self) -> Vec<T> {
661 trsna_sep(&self.t)
662 }
663}
664
665pub fn trsna_s<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
678 let n = t.nrows();
679 if n == 0 {
680 return Vec::new();
681 }
682
683 let vr = trevc_right(t);
684 let vl = trevc_left(t);
685
686 let mut s = vec![T::zero(); n];
687 let eps = <T as Scalar>::epsilon();
688
689 let mut j = 0;
690 while j < n {
691 let is_2x2 = if j + 1 < n {
693 let sub = Scalar::abs(t[(j + 1, j)]);
694 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
695 sub > eps * T::from_f64(100.0).unwrap() * diag_sum
696 } else {
697 false
698 };
699
700 if is_2x2 {
701 let jp1 = j + 1;
705
706 let mut prod_rr = T::zero(); let mut prod_ii = T::zero(); let mut prod_ri = T::zero(); let mut prod_ir = T::zero(); for k in 0..n {
712 prod_rr = prod_rr + vl[(k, j)] * vr[(k, j)];
713 prod_ii = prod_ii + vl[(k, jp1)] * vr[(k, jp1)];
714 prod_ri = prod_ri + vl[(k, j)] * vr[(k, jp1)];
715 prod_ir = prod_ir + vl[(k, jp1)] * vr[(k, j)];
716 }
717
718 let real_part = prod_rr + prod_ii;
719 let imag_part = prod_ri - prod_ir;
720 let abs_inner = Real::sqrt(real_part * real_part + imag_part * imag_part);
721
722 s[j] = abs_inner;
724 s[jp1] = abs_inner;
725
726 j += 2;
727 } else {
728 let mut inner = T::zero();
730 for k in 0..n {
731 inner = inner + vl[(k, j)] * vr[(k, j)];
732 }
733 s[j] = Scalar::abs(inner);
734 j += 1;
735 }
736 }
737
738 s
739}
740
741pub fn trsna_sep<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Vec<T> {
756 let n = t.nrows();
757 if n == 0 {
758 return Vec::new();
759 }
760
761 let mut sep = vec![T::zero(); n];
762 let eps = <T as Scalar>::epsilon();
763
764 let mut j = 0;
765 while j < n {
766 let is_2x2 = if j + 1 < n {
768 let sub = Scalar::abs(t[(j + 1, j)]);
769 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
770 sub > eps * T::from_f64(100.0).unwrap() * diag_sum
771 } else {
772 false
773 };
774
775 if is_2x2 {
776 let jp1 = j + 1;
778
779 let a11 = t[(j, j)];
781 let a22 = t[(jp1, jp1)];
782 let lambda_real = (a11 + a22) / T::from_f64(2.0).unwrap();
783 let trace = a11 + a22;
784 let det = a11 * a22 - t[(j, jp1)] * t[(jp1, j)];
785 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
786 let lambda_imag = if disc < T::zero() {
787 Real::sqrt(-disc) / T::from_f64(2.0).unwrap()
788 } else {
789 T::zero()
790 };
791
792 let mut min_sep = T::one() / eps;
793
794 let mut k = 0;
796 while k < n {
797 if k == j || k == jp1 {
798 k += 1;
799 continue;
800 }
801
802 let adjacent = j > 0 && k == j - 1;
804 let k_is_2x2 = if k + 1 < n && !adjacent {
805 let sub = Scalar::abs(t[(k + 1, k)]);
806 let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
807 sub > eps * T::from_f64(100.0).unwrap() * diag_sum
808 } else {
809 false
810 };
811
812 let (other_real, other_imag) = if k_is_2x2 {
813 let kp1 = k + 1;
814 let b11 = t[(k, k)];
815 let b22 = t[(kp1, kp1)];
816 let other_trace = b11 + b22;
817 let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
818 let other_disc =
819 other_trace * other_trace - T::from_f64(4.0).unwrap() * other_det;
820 let r = (b11 + b22) / T::from_f64(2.0).unwrap();
821 let i = if other_disc < T::zero() {
822 Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap()
823 } else {
824 T::zero()
825 };
826 (r, i)
827 } else {
828 (t[(k, k)], T::zero())
829 };
830
831 let dr = lambda_real - other_real;
833 let di = lambda_imag - other_imag;
834 let dist = Real::sqrt(dr * dr + di * di);
835 if dist < min_sep && dist > T::zero() {
836 min_sep = dist;
837 }
838
839 if other_imag != T::zero() {
841 let di_conj = lambda_imag + other_imag;
842 let dist_conj = Real::sqrt(dr * dr + di_conj * di_conj);
843 if dist_conj < min_sep && dist_conj > T::zero() {
844 min_sep = dist_conj;
845 }
846 }
847
848 if k_is_2x2 {
849 k += 2;
850 } else {
851 k += 1;
852 }
853 }
854
855 sep[j] = min_sep;
856 sep[jp1] = min_sep;
857 j += 2;
858 } else {
859 let lambda = t[(j, j)];
861 let mut min_sep = T::one() / eps;
862
863 let mut k = 0;
865 while k < n {
866 if k == j {
867 k += 1;
868 continue;
869 }
870
871 let adjacent_to_j = (j > 0 && k == j - 1) || k == j + 1;
873 let k_is_2x2 = if k + 1 < n && !adjacent_to_j {
874 let sub = Scalar::abs(t[(k + 1, k)]);
875 let diag_sum = Scalar::abs(t[(k, k)]) + Scalar::abs(t[(k + 1, k + 1)]);
876 sub > eps * T::from_f64(100.0).unwrap() * diag_sum
877 } else {
878 false
879 };
880
881 let (other_real, other_imag) = if k_is_2x2 {
882 let kp1 = k + 1;
883 let b11 = t[(k, k)];
884 let b22 = t[(kp1, kp1)];
885 let other_trace = b11 + b22;
886 let other_det = b11 * b22 - t[(k, kp1)] * t[(kp1, k)];
887 let other_disc =
888 other_trace * other_trace - T::from_f64(4.0).unwrap() * other_det;
889 let r = (b11 + b22) / T::from_f64(2.0).unwrap();
890 let i = if other_disc < T::zero() {
891 Real::sqrt(-other_disc) / T::from_f64(2.0).unwrap()
892 } else {
893 T::zero()
894 };
895 (r, i)
896 } else {
897 (t[(k, k)], T::zero())
898 };
899
900 let dr = lambda - other_real;
902 let dist = Real::sqrt(dr * dr + other_imag * other_imag);
903 if dist < min_sep && dist > T::zero() {
904 min_sep = dist;
905 }
906
907 if k_is_2x2 {
908 k += 2;
909 } else {
910 k += 1;
911 }
912 }
913
914 sep[j] = min_sep;
915 j += 1;
916 }
917 }
918
919 sep
920}
921
922pub fn trevc_right<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
936 let n = t.nrows();
937 let mut v = Mat::zeros(n, n);
938
939 for i in 0..n {
941 v[(i, i)] = T::one();
942 }
943
944 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
945
946 let mut j = n;
948 while j > 0 {
949 j -= 1;
950
951 let is_2x2 = if j > 0 {
953 let sub = Scalar::abs(t[(j, j - 1)]);
954 let diag_sum = Scalar::abs(t[(j - 1, j - 1)]) + Scalar::abs(t[(j, j)]);
955 sub > eps * diag_sum
956 } else {
957 false
958 };
959
960 if is_2x2 {
961 let jm1 = j - 1;
964
965 let a11 = t[(jm1, jm1)];
967 let a12 = t[(jm1, j)];
968 let a21 = t[(j, jm1)];
969 let a22 = t[(j, j)];
970
971 let trace = a11 + a22;
972 let det = a11 * a22 - a12 * a21;
973 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
974
975 let two = T::from_f64(2.0).unwrap();
977 let real_part = trace / two;
978 let imag_part = Real::sqrt(-disc) / two;
979
980 v[(jm1, jm1)] = T::one();
983 v[(j, jm1)] = T::zero();
984 v[(jm1, j)] = T::zero();
985 v[(j, j)] = T::one();
986
987 let d11 = a11 - real_part;
1012 let d22 = a22 - real_part;
1013
1014 if Scalar::abs(a21) >= Scalar::abs(a12) && Scalar::abs(a21) > eps {
1016 let v1r = -d22 / a21;
1019 let v1i = imag_part / a21;
1020 v[(jm1, jm1)] = v1r;
1021 v[(j, jm1)] = T::one();
1022 v[(jm1, j)] = v1i;
1023 v[(j, j)] = T::zero();
1024 } else if Scalar::abs(a12) > eps {
1025 let v2r = -d11 / a12;
1028 let v2i = -imag_part / a12;
1029 v[(jm1, jm1)] = T::one();
1030 v[(j, jm1)] = v2r;
1031 v[(jm1, j)] = T::zero();
1032 v[(j, j)] = v2i;
1033 } else {
1034 v[(jm1, jm1)] = T::one();
1036 v[(j, jm1)] = T::zero();
1037 v[(jm1, j)] = T::zero();
1038 v[(j, j)] = T::one();
1039 }
1040
1041 for i in (0..jm1).rev() {
1043 let mut sum_r = T::zero();
1048 let mut sum_i = T::zero();
1049 for k in (i + 1)..=j {
1050 sum_r = sum_r + t[(i, k)] * v[(k, jm1)];
1051 sum_i = sum_i + t[(i, k)] * v[(k, j)];
1052 }
1053
1054 let d = t[(i, i)] - real_part;
1055 let det_2x2 = d * d + imag_part * imag_part;
1056
1057 if Scalar::abs(det_2x2) > eps {
1058 v[(i, jm1)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1062 v[(i, j)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1063 }
1064 }
1065
1066 let mut norm_r_sq = T::zero();
1068 let mut norm_i_sq = T::zero();
1069 for i in 0..n {
1070 norm_r_sq = norm_r_sq + v[(i, jm1)] * v[(i, jm1)];
1071 norm_i_sq = norm_i_sq + v[(i, j)] * v[(i, j)];
1072 }
1073 let norm = Real::sqrt(norm_r_sq + norm_i_sq);
1074 if norm > T::zero() {
1075 for i in 0..n {
1076 v[(i, jm1)] = v[(i, jm1)] / norm;
1077 v[(i, j)] = v[(i, j)] / norm;
1078 }
1079 }
1080
1081 j = jm1; } else {
1083 let lambda = t[(j, j)];
1085
1086 v[(j, j)] = T::one();
1088
1089 for i in (0..j).rev() {
1091 let mut sum = T::zero();
1092 for k in (i + 1)..=j {
1093 sum = sum + t[(i, k)] * v[(k, j)];
1094 }
1095
1096 let d = t[(i, i)] - lambda;
1097 if Scalar::abs(d) > eps {
1098 v[(i, j)] = -sum / d;
1099 } else {
1100 v[(i, j)] = -sum / eps;
1102 }
1103 }
1104
1105 let mut norm_sq = T::zero();
1107 for i in 0..n {
1108 norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1109 }
1110 let norm = Real::sqrt(norm_sq);
1111 if norm > T::zero() {
1112 for i in 0..n {
1113 v[(i, j)] = v[(i, j)] / norm;
1114 }
1115 }
1116 }
1117 }
1118
1119 v
1120}
1121
1122pub fn trevc_left<T: Field + Real + bytemuck::Zeroable>(t: &Mat<T>) -> Mat<T> {
1135 let n = t.nrows();
1136 let mut v = Mat::zeros(n, n);
1137
1138 for i in 0..n {
1140 v[(i, i)] = T::one();
1141 }
1142
1143 let eps = <T as Scalar>::epsilon() * T::from_f64(100.0).unwrap_or(T::one());
1144
1145 let mut j = 0;
1147 while j < n {
1148 let is_2x2 = if j + 1 < n {
1150 let sub = Scalar::abs(t[(j + 1, j)]);
1151 let diag_sum = Scalar::abs(t[(j, j)]) + Scalar::abs(t[(j + 1, j + 1)]);
1152 sub > eps * diag_sum
1153 } else {
1154 false
1155 };
1156
1157 if is_2x2 {
1158 let jp1 = j + 1;
1160
1161 let a11 = t[(j, j)];
1162 let a12 = t[(j, jp1)];
1163 let a21 = t[(jp1, j)];
1164 let a22 = t[(jp1, jp1)];
1165
1166 let trace = a11 + a22;
1167 let det = a11 * a22 - a12 * a21;
1168 let disc = trace * trace - T::from_f64(4.0).unwrap() * det;
1169
1170 let two = T::from_f64(2.0).unwrap();
1171 let real_part = trace / two;
1172 let imag_part = Real::sqrt(-disc) / two;
1173
1174 let d11 = a11 - real_part;
1176 let det_factor = d11 * d11 + imag_part * imag_part;
1177
1178 if Scalar::abs(det_factor) > eps {
1179 let vr2 = -a21 * d11 / det_factor;
1180 let vi2 = -imag_part * a21 / det_factor;
1181 v[(j, j)] = T::one();
1182 v[(jp1, j)] = vr2;
1183 v[(j, jp1)] = T::zero();
1184 v[(jp1, jp1)] = vi2;
1185 } else {
1186 v[(j, j)] = T::one();
1187 v[(jp1, j)] = T::zero();
1188 v[(j, jp1)] = T::zero();
1189 v[(jp1, jp1)] = T::one();
1190 }
1191
1192 for i in (jp1 + 1)..n {
1194 let mut sum_r = T::zero();
1195 let mut sum_i = T::zero();
1196 for k in j..i {
1197 sum_r = sum_r + t[(k, i)] * v[(k, j)];
1198 sum_i = sum_i + t[(k, i)] * v[(k, jp1)];
1199 }
1200
1201 let d = t[(i, i)] - real_part;
1202 let det_2x2 = d * d + imag_part * imag_part;
1203
1204 if Scalar::abs(det_2x2) > eps {
1205 v[(i, j)] = (-d * sum_r - imag_part * sum_i) / det_2x2;
1206 v[(i, jp1)] = (imag_part * sum_r - d * sum_i) / det_2x2;
1207 }
1208 }
1209
1210 let mut norm_sq = T::zero();
1212 for i in 0..n {
1213 norm_sq = norm_sq + v[(i, j)] * v[(i, j)] + v[(i, jp1)] * v[(i, jp1)];
1214 }
1215 let norm = Real::sqrt(norm_sq);
1216 if norm > T::zero() {
1217 for i in 0..n {
1218 v[(i, j)] = v[(i, j)] / norm;
1219 v[(i, jp1)] = v[(i, jp1)] / norm;
1220 }
1221 }
1222
1223 j = jp1 + 1;
1224 } else {
1225 let lambda = t[(j, j)];
1227 v[(j, j)] = T::one();
1228
1229 for i in (j + 1)..n {
1231 let mut sum = T::zero();
1232 for k in j..i {
1233 sum = sum + t[(k, i)] * v[(k, j)];
1234 }
1235
1236 let d = t[(i, i)] - lambda;
1237 if Scalar::abs(d) > eps {
1238 v[(i, j)] = -sum / d;
1239 } else {
1240 v[(i, j)] = -sum / eps;
1241 }
1242 }
1243
1244 let mut norm_sq = T::zero();
1246 for i in 0..n {
1247 norm_sq = norm_sq + v[(i, j)] * v[(i, j)];
1248 }
1249 let norm = Real::sqrt(norm_sq);
1250 if norm > T::zero() {
1251 for i in 0..n {
1252 v[(i, j)] = v[(i, j)] / norm;
1253 }
1254 }
1255
1256 j += 1;
1257 }
1258 }
1259
1260 v
1261}
1262
1263fn householder_3<T: Field + Real>(x: &[T]) -> (Vec<T>, T) {
1265 let n = x.len().min(3);
1266 if n == 0 {
1267 return (Vec::new(), T::zero());
1268 }
1269
1270 let mut norm_sq = T::zero();
1271 for i in 0..n {
1272 norm_sq = norm_sq + x[i] * x[i];
1273 }
1274 let norm = Real::sqrt(norm_sq);
1275
1276 if norm == T::zero() {
1277 return (vec![T::zero(); n], T::zero());
1278 }
1279
1280 let mut v = vec![T::zero(); n];
1281 for i in 0..n {
1282 v[i] = x[i];
1283 }
1284
1285 let sign = if x[0] >= T::zero() {
1286 T::one()
1287 } else {
1288 -T::one()
1289 };
1290 v[0] = v[0] + sign * norm;
1291
1292 let mut v_norm_sq = T::zero();
1293 for i in 0..n {
1294 v_norm_sq = v_norm_sq + v[i] * v[i];
1295 }
1296
1297 if v_norm_sq > T::zero() {
1298 let tau = T::from_f64(2.0).unwrap() / v_norm_sq;
1299 (v, tau)
1300 } else {
1301 (v, T::zero())
1302 }
1303}
1304
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        tol > (a - b).abs()
    }

    /// Upper-triangular fixture with eigenvalues 1, 4 and 6.
    fn upper_triangular_3x3() -> Mat<f64> {
        Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[0.0, 4.0, 5.0], &[0.0, 0.0, 6.0]])
    }

    /// Diagonal fixture `diag(2, 5, 3)`.
    fn diagonal_3x3() -> Mat<f64> {
        Mat::from_rows(&[&[2.0f64, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]])
    }

    /// Dense fixture `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`.
    fn dense_3x3() -> Mat<f64> {
        Mat::from_rows(&[&[1.0f64, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]])
    }

    /// Diagonal fixture `diag(4, 2)`.
    fn diagonal_2x2() -> Mat<f64> {
        Mat::from_rows(&[&[4.0f64, 0.0], &[0.0, 2.0]])
    }

    /// Plane rotation by `theta`, returned with its cosine and sine.
    fn rotation(theta: f64) -> (Mat<f64>, f64, f64) {
        let (s, c) = theta.sin_cos();
        (Mat::from_rows(&[&[c, -s], &[s, c]]), c, s)
    }

    /// Real parts of the eigenvalues of `schur`, sorted ascending.
    fn sorted_real_parts(schur: &Schur<f64>) -> Vec<f64> {
        let mut parts: Vec<f64> = schur.eigenvalues().iter().map(|e| e.real).collect();
        parts.sort_by(|x, y| x.partial_cmp(y).unwrap());
        parts
    }

    /// Column `col` of the product `m · v` for `n x n` matrices.
    fn mat_times_column(m: &Mat<f64>, v: &Mat<f64>, col: usize, n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| (0..n).fold(0.0, |acc, k| acc + m[(i, k)] * v[(k, col)]))
            .collect()
    }

    /// Asserts that the Schur factors of `a` rebuild it entrywise within `tol`.
    fn assert_reconstructs(a: Mat<f64>, n: usize, tol: f64) {
        let schur = Schur::compute(a.as_ref()).unwrap();
        let rebuilt = schur.reconstruct();
        for i in 0..n {
            for j in 0..n {
                let (got, want) = (rebuilt[(i, j)], a[(i, j)]);
                assert!(
                    approx_eq(got, want, tol),
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_schur_upper_triangular() {
        let a = Mat::from_rows(&[&[1.0f64, 2.0], &[0.0, 3.0]]);
        let eigs = sorted_real_parts(&Schur::compute(a.as_ref()).unwrap());
        assert!(approx_eq(eigs[0], 1.0, 1e-10));
        assert!(approx_eq(eigs[1], 3.0, 1e-10));
    }

    #[test]
    fn test_schur_diagonal() {
        let a = diagonal_3x3();
        let eigs = sorted_real_parts(&Schur::compute(a.as_ref()).unwrap());
        for (idx, want) in [2.0, 3.0, 5.0].iter().enumerate() {
            assert!(approx_eq(eigs[idx], *want, 1e-10));
        }
    }

    #[test]
    fn test_schur_complex_eigenvalues() {
        let (a, c, s) = rotation(core::f64::consts::FRAC_PI_4);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let pair = schur.eigenvalues();

        assert_eq!(pair.len(), 2);
        let (first, second) = (pair[0], pair[1]);
        assert!(approx_eq(first.real, second.real, 1e-10));
        assert!(approx_eq(first.imag, -second.imag, 1e-10));
        assert!(approx_eq(first.real, c, 1e-10));
        assert!(approx_eq(first.imag.abs(), s, 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        assert_reconstructs(dense_3x3(), 3, 1e-10);
    }

    #[test]
    fn test_schur_q_orthogonal() {
        let a = dense_3x3();
        let schur = Schur::compute(a.as_ref()).unwrap();
        let q = schur.q();

        let n = 3;
        for i in 0..n {
            for j in 0..n {
                let dot = (0..n).fold(0.0, |acc, k| acc + q[(k, i)] * q[(k, j)]);
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    approx_eq(dot, expected, 1e-10),
                    "Q^T*Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    dot,
                    expected
                );
            }
        }
    }

    #[test]
    fn test_schur_single() {
        let a = Mat::from_rows(&[&[5.0f64]]);
        let schur = Schur::compute(a.as_ref()).unwrap();
        let eigenvalues = schur.eigenvalues();

        assert_eq!(eigenvalues.len(), 1);
        assert!(approx_eq(eigenvalues[0].real, 5.0, 1e-10));
    }

    #[test]
    fn test_schur_4x4() {
        let a = Mat::from_rows(&[
            &[4.0f64, 1.0, -2.0, 2.0],
            &[1.0, 2.0, 0.0, 1.0],
            &[-2.0, 0.0, 3.0, -2.0],
            &[2.0, 1.0, -2.0, -1.0],
        ]);
        assert_reconstructs(a, 4, 1e-8);
    }

    #[test]
    fn test_schur_f32() {
        let a = Mat::from_rows(&[&[1.0f32, 2.0], &[3.0, 4.0]]);
        let rebuilt = Schur::compute(a.as_ref()).unwrap().reconstruct();

        for i in 0..2 {
            for j in 0..2 {
                let (got, want) = (rebuilt[(i, j)], a[(i, j)]);
                assert!(
                    (got - want).abs() < 1e-4,
                    "reconstructed[{},{}] = {}, a = {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_trevc_right_upper_triangular() {
        let v = trevc_right(&upper_triangular_3x3());

        // The eigenvector of the leading eigenvalue is e₁.
        assert!(v[(1, 0)].abs() < 1e-10);
        assert!(v[(2, 0)].abs() < 1e-10);
        // The last eigenvector has a nonzero last component.
        assert!(v[(2, 2)].abs() > 0.1);
    }

    #[test]
    fn test_trevc_right_diagonal() {
        let v = trevc_right(&diagonal_3x3());
        for i in 0..3 {
            let entry = v[(i, i)];
            assert!(approx_eq(entry.abs(), 1.0, 1e-10), "v[{},{}] = {}", i, i, entry);
        }
    }

    #[test]
    fn test_trevc_eigenvector_equation() {
        let t = upper_triangular_3x3();
        let v = trevc_right(&t);

        for (j, &lambda) in [1.0, 4.0, 6.0].iter().enumerate() {
            let tv = mat_times_column(&t, &v, j, 3);
            for (i, &lhs) in tv.iter().enumerate() {
                let rhs = lambda * v[(i, j)];
                assert!(
                    approx_eq(lhs, rhs, 1e-10),
                    "T*v[{}] = {}, λ*v = {}",
                    i,
                    lhs,
                    rhs
                );
            }
        }
    }

    #[test]
    fn test_trevc_left_diagonal() {
        let u = trevc_left(&diagonal_3x3());
        for i in 0..3 {
            let entry = u[(i, i)];
            assert!(approx_eq(entry.abs(), 1.0, 1e-10), "u[{},{}] = {}", i, i, entry);
        }
    }

    #[test]
    fn test_schur_eigenvectors() {
        let a = upper_triangular_3x3();
        let schur = Schur::compute(a.as_ref()).unwrap();
        let (vr, _vl) = schur.eigenvectors();

        for j in 0..3 {
            let lambda = schur.eigenvalues()[j].real;
            let av = mat_times_column(&a, &vr, j, 3);
            for (i, &lhs) in av.iter().enumerate() {
                let rhs = lambda * vr[(i, j)];
                assert!(
                    approx_eq(lhs, rhs, 1e-8),
                    "A*v[{}] = {}, λ*v = {}",
                    i,
                    lhs,
                    rhs
                );
            }
        }
    }

    #[test]
    fn test_trevc_2x2_block() {
        let (t, _, _) = rotation(core::f64::consts::FRAC_PI_4);
        let v = trevc_right(&t);

        // Real and imaginary parts together have unit norm.
        let column_norm_sq = |col: usize| v[(0, col)] * v[(0, col)] + v[(1, col)] * v[(1, col)];
        let total_norm = (column_norm_sq(0) + column_norm_sq(1)).sqrt();

        assert!(
            approx_eq(total_norm, 1.0, 1e-10),
            "eigenvector norm = {}",
            total_norm
        );
    }

    #[test]
    fn test_trsna_s_diagonal() {
        let s = trsna_s(&diagonal_3x3());

        assert_eq!(s.len(), 3);
        for (i, &value) in s.iter().enumerate() {
            assert!(
                approx_eq(value, 1.0, 1e-10),
                "s[{}] = {}, expected 1.0",
                i,
                value
            );
        }
    }

    #[test]
    fn test_trsna_s_upper_triangular() {
        let s = trsna_s(&upper_triangular_3x3());

        assert_eq!(s.len(), 3);
        for (i, &value) in s.iter().enumerate() {
            assert!(value > 0.0, "s[{}] = {} should be positive", i, value);
        }
    }

    #[test]
    fn test_trsna_sep_diagonal() {
        let sep = trsna_sep(&diagonal_3x3());

        assert_eq!(sep.len(), 3);
        assert!(
            approx_eq(sep[0], 1.0, 1e-10),
            "sep[0] = {}, expected 1.0",
            sep[0]
        );
        assert!(
            approx_eq(sep[1], 2.0, 1e-10),
            "sep[1] = {}, expected 2.0",
            sep[1]
        );
        assert!(
            approx_eq(sep[2], 1.0, 1e-10),
            "sep[2] = {}, expected 1.0",
            sep[2]
        );
    }

    #[test]
    fn test_trsna_sep_close_eigenvalues() {
        let t = Mat::from_rows(&[&[1.0f64, 0.0], &[0.0, 1.001]]);
        let sep = trsna_sep(&t);

        assert_eq!(sep.len(), 2);
        assert!(
            approx_eq(sep[0], 0.001, 1e-10),
            "sep[0] = {}, expected 0.001",
            sep[0]
        );
        assert!(
            approx_eq(sep[1], 0.001, 1e-10),
            "sep[1] = {}, expected 0.001",
            sep[1]
        );
    }

    #[test]
    fn test_schur_eigenvalue_condition_numbers() {
        let a = diagonal_2x2();
        let schur = Schur::compute(a.as_ref()).unwrap();
        let cond = schur.eigenvalue_condition_numbers();

        assert_eq!(cond.len(), 2);
        for (i, &value) in cond.iter().enumerate() {
            assert!(
                value > 0.5,
                "cond[{}] = {} should be > 0.5 for diagonal matrix",
                i,
                value
            );
        }
    }

    #[test]
    fn test_schur_eigenvector_separation() {
        let a = diagonal_2x2();
        let schur = Schur::compute(a.as_ref()).unwrap();
        let sep = schur.eigenvector_separation();

        assert_eq!(sep.len(), 2);
        for (i, &value) in sep.iter().enumerate() {
            assert!(
                approx_eq(value, 2.0, 1e-10),
                "sep[{}] = {}, expected 2.0",
                i,
                value
            );
        }
    }

    #[test]
    fn test_trsna_complex_eigenvalues() {
        let (t, _, _) = rotation(core::f64::consts::FRAC_PI_4);
        let cond = trsna_s(&t);
        let sep = trsna_sep(&t);

        assert_eq!(cond.len(), 2);
        assert_eq!(sep.len(), 2);

        // Both members of a conjugate pair share their metrics.
        assert!(
            approx_eq(cond[0], cond[1], 1e-10),
            "cond[0]={}, cond[1]={} should be equal",
            cond[0],
            cond[1]
        );
        assert!(
            approx_eq(sep[0], sep[1], 1e-10),
            "sep[0]={}, sep[1]={} should be equal",
            sep[0],
            sep[1]
        );
    }
}