1#![allow(dead_code)]
27
28use oxicuda_ptx::ir::PtxType;
29use oxicuda_ptx::prelude::*;
30
31use crate::error::{SolverError, SolverResult};
32use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
33
/// Default value for `QzConfig::max_iterations`, applied by `QzConfig::new`.
const QZ_DEFAULT_MAX_ITER: u32 = 300;

/// Default value for `QzConfig::tolerance`, applied by `QzConfig::new`.
const QZ_DEFAULT_TOL: f64 = 1e-14;

/// `classify_eigenvalue` treats `|beta|` below this value as zero.
const BETA_ZERO_THRESHOLD: f64 = 1e-15;

/// Magnitudes below this value are treated as zero: `|alpha|` and `|alpha_i|` in
/// `classify_eigenvalue`, and the sub-diagonal test in `extract_eigenvalues`.
const ALPHA_ZERO_THRESHOLD: f64 = 1e-15;
49
/// Balancing mode stored in `QzConfig::balance`. The default is [`BalanceStrategy::Both`].
///
/// NOTE(review): nothing in this file reads the selected variant (`qz_host` never
/// consults `config.balance`), so the per-variant notes below are inferred from the
/// names only — confirm against whatever consumes this setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BalanceStrategy {
    /// Presumably: no balancing.
    None,
    /// Presumably: permutation-only balancing.
    Permute,
    /// Presumably: scaling-only balancing.
    Scale,
    /// Presumably: permutation and scaling together (the default variant).
    #[default]
    Both,
}
71
/// Shift selection carried by `QzStep::QzIteration`. The default is
/// [`ShiftStrategy::FrancisDoubleShift`].
///
/// In this file `plan_qz` always emits `FrancisDoubleShift`; the other variants are
/// never constructed here, so their exact meaning is inferred from the names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShiftStrategy {
    /// Presumably: a single explicit shift — TODO confirm.
    ExplicitShift,
    /// Double-shift strategy (the default, and the one `plan_qz` selects).
    #[default]
    FrancisDoubleShift,
    /// Presumably: a Wilkinson shift — TODO confirm.
    Wilkinson,
}
85
/// Classification of a generalized eigenvalue `(alpha_r + i*alpha_i) / beta`,
/// as produced by `classify_eigenvalue`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EigenvalueType {
    /// `alpha` and `beta` are both non-negligible and the imaginary part of `alpha` is negligible.
    Real,
    /// `alpha` and `beta` are both non-negligible and the imaginary part of `alpha` is not.
    ComplexPair,
    /// `beta` is negligible while `alpha` is not.
    Infinite,
    /// `|alpha|` is negligible (this includes the indeterminate case where `beta` is negligible too).
    Zero,
}
98
/// Settings for a QZ run; build with `QzConfig::new` and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct QzConfig {
    /// Matrix dimension; `qz_host` expects `a` and `b` to hold at least `n * n` values.
    pub n: u32,
    /// When `true`, `qz_host` accumulates and returns Q, Z and copies of the reduced matrices.
    pub compute_schur_vectors: bool,
    /// Requested balancing mode. NOTE(review): not read by the host path in this file.
    pub balance: BalanceStrategy,
    /// Sweep budget used by `qz_iteration` for each deflation attempt; must be >= 1.
    pub max_iterations: u32,
    /// Relative deflation tolerance used by `qz_iteration`; must be positive.
    pub tolerance: f64,
    /// Target GPU architecture. NOTE(review): stored only; the PTX generators in this
    /// file take their own `sm` argument.
    pub sm_version: SmVersion,
}
115
116impl QzConfig {
117 pub fn new(n: u32, sm_version: SmVersion) -> Self {
119 Self {
120 n,
121 compute_schur_vectors: false,
122 balance: BalanceStrategy::default(),
123 max_iterations: QZ_DEFAULT_MAX_ITER,
124 tolerance: QZ_DEFAULT_TOL,
125 sm_version,
126 }
127 }
128
129 pub fn with_schur_vectors(mut self, enabled: bool) -> Self {
131 self.compute_schur_vectors = enabled;
132 self
133 }
134
135 pub fn with_balance(mut self, strategy: BalanceStrategy) -> Self {
137 self.balance = strategy;
138 self
139 }
140
141 pub fn with_max_iterations(mut self, max_iter: u32) -> Self {
143 self.max_iterations = max_iter;
144 self
145 }
146
147 pub fn with_tolerance(mut self, tol: f64) -> Self {
149 self.tolerance = tol;
150 self
151 }
152}
153
/// Output of `qz_host`.
///
/// Eigenvalue `i` is represented by the triple
/// `(alpha_real[i], alpha_imag[i], beta[i])`; the tests in this file recover the
/// value as `alpha / beta`.
#[derive(Debug, Clone)]
pub struct QzResult {
    /// Real parts of the `alpha` values, one per eigenvalue.
    pub alpha_real: Vec<f64>,
    /// Imaginary parts of the `alpha` values (zero for real eigenvalues).
    pub alpha_imag: Vec<f64>,
    /// `beta` values, one per eigenvalue.
    pub beta: Vec<f64>,
    /// Copy of the first `n * n` entries of the reduced `a`; `Some` only when
    /// Schur vectors were requested.
    pub schur_s: Option<Vec<f64>>,
    /// Copy of the first `n * n` entries of the reduced `b`; `Some` only when
    /// Schur vectors were requested.
    pub schur_t: Option<Vec<f64>>,
    /// Accumulated left transformation (`n * n` values, indexed via `cm`); `Some`
    /// only when Schur vectors were requested.
    pub q_matrix: Option<Vec<f64>>,
    /// Accumulated right transformation (`n * n` values, indexed via `cm`); `Some`
    /// only when Schur vectors were requested.
    pub z_matrix: Option<Vec<f64>>,
    /// Total number of inner-loop passes reported by `qz_iteration` (0 when `n == 1`).
    pub iterations: u32,
    /// Whether every deflation attempt succeeded within the sweep budget.
    pub converged: bool,
}
181
/// One stage of a `QzPlan`, in the order `plan_qz` emits them.
#[derive(Debug, Clone, PartialEq)]
pub enum QzStep {
    /// Reduction stage; always the first step of a plan.
    HessenbergTriangularReduction,
    /// Iterative stage; `plan_qz` includes it only when `n > 1`.
    QzIteration {
        /// Shift selection for the iteration.
        shift_strategy: ShiftStrategy,
    },
    /// Eigenvalue read-out stage; always present in a plan.
    EigenvalueExtraction,
    /// Schur-vector stage; present only when `compute_schur_vectors` is set.
    SchurVectorAccumulation,
}
204
/// A validated configuration together with the ordered steps `plan_qz` derived from it.
#[derive(Debug, Clone)]
pub struct QzPlan {
    /// Clone of the configuration the plan was built from.
    pub config: QzConfig,
    /// Steps in execution order.
    pub steps: Vec<QzStep>,
}
215
216impl QzPlan {
217 pub fn estimated_flops(&self) -> f64 {
223 estimate_qz_flops(self.config.n)
224 }
225}
226
227pub fn validate_qz_config(config: &QzConfig) -> SolverResult<()> {
238 if config.n == 0 {
239 return Err(SolverError::DimensionMismatch(
240 "QZ: matrix dimension n must be >= 1".to_string(),
241 ));
242 }
243 if config.tolerance <= 0.0 {
244 return Err(SolverError::InternalError(
245 "QZ: tolerance must be positive".to_string(),
246 ));
247 }
248 if config.max_iterations == 0 {
249 return Err(SolverError::InternalError(
250 "QZ: max_iterations must be >= 1".to_string(),
251 ));
252 }
253 Ok(())
254}
255
256pub fn plan_qz(config: &QzConfig) -> SolverResult<QzPlan> {
265 validate_qz_config(config)?;
266
267 let mut steps = Vec::new();
268
269 steps.push(QzStep::HessenbergTriangularReduction);
271
272 if config.n > 1 {
274 steps.push(QzStep::QzIteration {
275 shift_strategy: ShiftStrategy::FrancisDoubleShift,
276 });
277 }
278
279 steps.push(QzStep::EigenvalueExtraction);
281
282 if config.compute_schur_vectors {
284 steps.push(QzStep::SchurVectorAccumulation);
285 }
286
287 Ok(QzPlan {
288 config: config.clone(),
289 steps,
290 })
291}
292
/// Returns the cost model `10 * n^3` as a floating-point operation estimate.
pub fn estimate_qz_flops(n: u32) -> f64 {
    let side = f64::from(n);
    // Multiply left to right, starting from the constant factor.
    let mut flops = 10.0;
    for _ in 0..3 {
        flops *= side;
    }
    flops
}
305
306pub fn classify_eigenvalue(alpha_r: f64, alpha_i: f64, beta: f64) -> EigenvalueType {
308 let alpha_mag = (alpha_r * alpha_r + alpha_i * alpha_i).sqrt();
309
310 if beta.abs() < BETA_ZERO_THRESHOLD {
311 if alpha_mag < ALPHA_ZERO_THRESHOLD {
312 return EigenvalueType::Zero;
314 }
315 return EigenvalueType::Infinite;
316 }
317
318 if alpha_mag < ALPHA_ZERO_THRESHOLD {
319 return EigenvalueType::Zero;
320 }
321
322 if alpha_i.abs() < ALPHA_ZERO_THRESHOLD {
323 EigenvalueType::Real
324 } else {
325 EigenvalueType::ComplexPair
326 }
327}
328
329pub fn qz_host(a: &mut [f64], b: &mut [f64], config: &QzConfig) -> SolverResult<QzResult> {
349 validate_qz_config(config)?;
350 let n = config.n as usize;
351
352 if a.len() < n * n {
353 return Err(SolverError::DimensionMismatch(format!(
354 "QZ: matrix A too small ({} < {})",
355 a.len(),
356 n * n
357 )));
358 }
359 if b.len() < n * n {
360 return Err(SolverError::DimensionMismatch(format!(
361 "QZ: matrix B too small ({} < {})",
362 b.len(),
363 n * n
364 )));
365 }
366
367 let mut q = if config.compute_schur_vectors {
369 Some(identity_matrix(n))
370 } else {
371 None
372 };
373 let mut z = if config.compute_schur_vectors {
374 Some(identity_matrix(n))
375 } else {
376 None
377 };
378
379 qr_reduce_b(a, b, n, q.as_deref_mut());
381
382 hessenberg_reduce_a(a, b, n, q.as_deref_mut(), z.as_deref_mut());
384
385 let (iterations, converged) = if n > 1 {
387 qz_iteration(a, b, n, config, q.as_deref_mut(), z.as_deref_mut())?
388 } else {
389 (0, true)
390 };
391
392 let (alpha_real, alpha_imag, beta) = extract_eigenvalues(a, b, n);
394
395 let schur_s = if config.compute_schur_vectors {
396 Some(a[..n * n].to_vec())
397 } else {
398 None
399 };
400 let schur_t = if config.compute_schur_vectors {
401 Some(b[..n * n].to_vec())
402 } else {
403 None
404 };
405
406 Ok(QzResult {
407 alpha_real,
408 alpha_imag,
409 beta,
410 schur_s,
411 schur_t,
412 q_matrix: q,
413 z_matrix: z,
414 iterations,
415 converged,
416 })
417}
418
/// Builds an `n`-by-`n` identity matrix as a flat vector of `n * n` values.
fn identity_matrix(n: usize) -> Vec<f64> {
    let mut eye = vec![0.0; n * n];
    // Consecutive diagonal entries are `n + 1` slots apart in the flat layout.
    eye.iter_mut().step_by(n + 1).for_each(|diag| *diag = 1.0);
    eye
}
431
/// Flat index of element `(row, col)` in a column-major matrix with `n` rows.
#[inline]
fn cm(row: usize, col: usize, n: usize) -> usize {
    let column_offset = col * n;
    column_offset + row
}
437
438fn qr_reduce_b(a: &mut [f64], b: &mut [f64], n: usize, mut q: Option<&mut [f64]>) {
441 for k in 0..n.saturating_sub(1) {
442 let (v, tau) = householder_vector(b, k, k, n, n);
444 if tau.abs() < 1e-300 {
445 continue;
446 }
447
448 apply_householder_left(b, &v, tau, k, n, k, n, n);
450
451 apply_householder_left(a, &v, tau, k, n, 0, n, n);
453
454 if let Some(ref mut qm) = q {
456 apply_householder_right(qm, &v, tau, 0, n, k, n, n);
457 }
458 }
459}
460
461fn hessenberg_reduce_a(
467 a: &mut [f64],
468 b: &mut [f64],
469 n: usize,
470 mut q: Option<&mut [f64]>,
471 mut z: Option<&mut [f64]>,
472) {
473 if n <= 2 {
474 return;
475 }
476
477 for col in 0..n - 2 {
478 for row in (col + 2..n).rev() {
479 let a_target = a[cm(row, col, n)];
482 let a_above = a[cm(row - 1, col, n)];
483 if a_target.abs() < 1e-300 {
484 continue;
485 }
486
487 let (cs, sn) = givens_rotation(a_above, a_target);
488
489 apply_givens_left(a, cs, sn, row - 1, row, n, n);
495 apply_givens_left(b, cs, sn, row - 1, row, n, n);
496
497 if let Some(ref mut qm) = q {
498 apply_givens_right(qm, cs, sn, row - 1, row, n, n);
500 }
501
502 let b_lower = b[cm(row, row - 1, n)];
506 let b_diag = b[cm(row, row, n)];
507 if b_lower.abs() < 1e-300 {
508 continue;
509 }
510
511 let (cs2, sn2) = givens_rotation(b_diag, b_lower);
512
513 apply_givens_right_cols(b, cs2, sn2, row, row - 1, n, n);
514 apply_givens_right_cols(a, cs2, sn2, row, row - 1, n, n);
515
516 if let Some(ref mut zm) = z {
517 apply_givens_right_cols(zm, cs2, sn2, row, row - 1, n, n);
518 }
519 }
520 }
521}
522
/// Iterative phase: repeatedly applies `qz_double_shift_step` to the active block
/// of the pencil `(a, b)` and deflates from the bottom.
///
/// `ihi` is the exclusive end of the still-active leading block. For each value of
/// `ihi` up to `config.max_iterations` passes are made; each pass either deflates
/// (shrinking `ihi` by one or two) or performs one shifted step on rows/columns
/// `ilo..ihi`.
///
/// Returns `(total_iter, converged)`. `total_iter` counts every inner-loop pass,
/// including the pass that detects a deflation. If a budget is exhausted without a
/// deflation, the function returns immediately and `converged` is set by comparing
/// the last sub-diagonal entry against the absolute tolerance. The visible body has
/// no path that produces `Err`.
fn qz_iteration(
    a: &mut [f64],
    b: &mut [f64],
    n: usize,
    config: &QzConfig,
    mut q: Option<&mut [f64]>,
    mut z: Option<&mut [f64]>,
) -> SolverResult<(u32, bool)> {
    let tol = config.tolerance;
    let max_iter = config.max_iterations;
    let mut total_iter: u32 = 0;

    let mut ihi = n;

    while ihi > 1 {
        let mut deflated = false;

        // The sweep budget restarts for every new value of `ihi`.
        for _sweep in 0..max_iter {
            total_iter = total_iter.saturating_add(1);

            // Test the last sub-diagonal entry of the active block against the
            // tolerance scaled by its two neighbouring diagonal entries.
            let sub = a[cm(ihi - 1, ihi - 2, n)].abs();
            let diag_sum = a[cm(ihi - 2, ihi - 2, n)].abs() + a[cm(ihi - 1, ihi - 1, n)].abs();
            let threshold = if diag_sum > 0.0 { tol * diag_sum } else { tol };

            if sub <= threshold {
                // Single deflation: split off the last row/column.
                a[cm(ihi - 1, ihi - 2, n)] = 0.0;
                ihi -= 1;
                deflated = true;
                break;
            }

            if ihi >= 3 {
                // Same test one position higher; success splits off a trailing
                // 2-by-2 block whose own sub-diagonal entry is left as is.
                let sub2 = a[cm(ihi - 2, ihi - 3, n)].abs();
                let diag_sum2 = a[cm(ihi - 3, ihi - 3, n)].abs() + a[cm(ihi - 2, ihi - 2, n)].abs();
                let threshold2 = if diag_sum2 > 0.0 {
                    tol * diag_sum2
                } else {
                    tol
                };
                if sub2 <= threshold2 {
                    a[cm(ihi - 2, ihi - 3, n)] = 0.0;
                    ihi -= 2;
                    deflated = true;
                    break;
                }
            }

            // Walk upwards to find the start of the unreduced block: stop at the
            // first negligible sub-diagonal entry (and zero it), or at row 0.
            let mut ilo = ihi - 1;
            while ilo > 0 {
                let sub_ilo = a[cm(ilo, ilo - 1, n)].abs();
                let diag_ilo = a[cm(ilo - 1, ilo - 1, n)].abs() + a[cm(ilo, ilo, n)].abs();
                let thr_ilo = if diag_ilo > 0.0 { tol * diag_ilo } else { tol };
                if sub_ilo <= thr_ilo {
                    a[cm(ilo, ilo - 1, n)] = 0.0;
                    break;
                }
                ilo -= 1;
            }

            qz_double_shift_step(a, b, n, ilo, ihi, q.as_deref_mut(), z.as_deref_mut());
        }

        if !deflated {
            // Budget exhausted for this `ihi`: report using an absolute (unscaled)
            // comparison of the remaining sub-diagonal entry.
            let residual = a[cm(ihi - 1, ihi - 2, n)].abs();
            return Ok((total_iter, residual <= tol));
        }
    }

    Ok((total_iter, true))
}
605
/// Performs one shifted step on the active block `ilo..ihi` (end exclusive).
///
/// Shift quantities (`trace_ab`, `det_ab`) are derived from the trailing 2-by-2
/// block of `a`, scaled by the matching diagonal entries of `b`. They are combined
/// with the leading entries of `a` at `ilo` into a three-component start vector
/// `(p1, p2, p3)`, which is handed to `chase_bulge`. Blocks with fewer than two
/// rows are left untouched.
fn qz_double_shift_step(
    a: &mut [f64],
    b: &mut [f64],
    n: usize,
    ilo: usize,
    ihi: usize,
    q: Option<&mut [f64]>,
    z: Option<&mut [f64]>,
) {
    let m = ihi - ilo;
    if m < 2 {
        return;
    }

    // Indices of the trailing 2-by-2 block.
    let i1 = ihi - 2;
    let i2 = ihi - 1;

    let a11 = a[cm(i1, i1, n)];
    let a12 = a[cm(i1, i2, n)];
    let a21 = a[cm(i2, i1, n)];
    let a22 = a[cm(i2, i2, n)];

    let t11 = b[cm(i1, i1, n)];
    // NOTE(review): the off-diagonal entry of `b` is loaded but never used.
    let _t12 = b[cm(i1, i2, n)];
    let t22 = b[cm(i2, i2, n)];

    let det_t = t11 * t22;
    // When the diagonal product of `b` is usable this equals a11/t11 + a22/t22;
    // otherwise the plain trace of the `a` block is used.
    // NOTE(review): the `a12 * 0.0` term contributes nothing for finite `a12`.
    let trace_ab = if det_t.abs() > 1e-300 {
        (a11 * t22 - a12 * 0.0 + a22 * t11) / det_t
    } else {
        a11 + a22
    };
    // Determinant of the `a` block divided by `det_t` (algebraically), with the
    // same fallback when `det_t` is negligible.
    let det_ab = if det_t.abs() > 1e-300 {
        (a11 * a22 - a12 * a21) * t22 * t11 / (det_t * det_t)
    } else {
        a11 * a22 - a12 * a21
    };

    // Leading entries of the active block.
    let h11 = a[cm(ilo, ilo, n)];
    let h21 = a[cm(ilo + 1, ilo, n)];
    // Since `m >= 2` here, `ilo + 1 < ihi`; the guard can only fail if `ihi > n`.
    let h12 = if ilo + 1 < n {
        a[cm(ilo, ilo + 1, n)]
    } else {
        0.0
    };

    // Start vector built from entries of `a` only.
    // NOTE(review): `b` enters solely through `trace_ab` / `det_ab` — confirm intended.
    let p1 = h11 * h11 + h12 * h21 - trace_ab * h11 + det_ab;
    let p2 = h21 * (h11 + a[cm(ilo + 1, ilo + 1, n)] - trace_ab);
    // The third component needs a third row; it is zero for a 2-by-2 block.
    let p3 = if m >= 3 {
        h21 * a[cm(ilo + 2, ilo + 1, n)]
    } else {
        0.0
    };

    chase_bulge(a, b, n, ilo, ihi, p1, p2, p3, q, z);
}
675
/// Applies the initial reflector built from `(p1, p2, p3)` to the active block
/// `ilo..ihi` and then sweeps Givens rotations down the block.
///
/// Left-side transformations are applied to `a` and `b` and accumulated into `q`
/// from the right; column rotations are applied to `a` and `b` and accumulated
/// into `z`.
#[allow(clippy::too_many_arguments)]
fn chase_bulge(
    a: &mut [f64],
    b: &mut [f64],
    n: usize,
    ilo: usize,
    ihi: usize,
    p1: f64,
    p2: f64,
    p3: f64,
    mut q: Option<&mut [f64]>,
    mut z: Option<&mut [f64]>,
) {
    let (v, tau) = householder_from_vec(&[p1, p2, p3]);
    // Only the first `size` components are used. For a 2-by-2 block the caller
    // passes `p3 == 0`, so dropping the third component is consistent with `tau`.
    let size = 3.min(ihi - ilo);

    // Initial reflector on rows `ilo..ilo+size`, across all columns.
    apply_householder_left_small(a, &v[..size], tau, ilo, ilo + size, 0, n, n);
    apply_householder_left_small(b, &v[..size], tau, ilo, ilo + size, 0, n, n);
    if let Some(ref mut qm) = q {
        apply_householder_right_small(qm, &v[..size], tau, 0, n, ilo, ilo + size, n);
    }

    for k in ilo..ihi.saturating_sub(2) {
        let rows_left = (ihi - k).min(3);

        // Row rotations, bottom-up, chosen to zero the entries of column `k` of
        // `b` below the diagonal.
        for r in (1..rows_left).rev() {
            let row = k + r;
            let b_below = b[cm(row, k, n)];
            let b_above = b[cm(row - 1, k, n)];
            if b_below.abs() < 1e-300 {
                continue;
            }
            let (cs, sn) = givens_rotation(b_above, b_below);

            apply_givens_left(b, cs, sn, row - 1, row, n, n);
            apply_givens_left(a, cs, sn, row - 1, row, n, n);
            if let Some(ref mut qm) = q {
                apply_givens_right(qm, cs, sn, row - 1, row, n, n);
            }
        }

        if k + 2 < ihi {
            // Column rotation on columns `r-1`, `r` with `r == k + 2` (the range
            // yields at most that one value).
            // NOTE(review): the rotation is built from a[r-1, k] and a[r, k], which
            // sit in column `k`, while the rotation itself only mixes columns
            // `k+1` and `k+2`. Confirm that this eliminates the intended entry.
            for r in (k + 2..ihi.min(k + 3)).rev() {
                let a_target = a[cm(r, k, n)];
                if a_target.abs() < 1e-300 {
                    continue;
                }
                let a_above = a[cm(r - 1, k, n)];
                let (cs, sn) = givens_rotation(a_above, a_target);

                apply_givens_right_cols(a, cs, sn, r - 1, r, n, n);
                apply_givens_right_cols(b, cs, sn, r - 1, r, n, n);
                if let Some(ref mut zm) = z {
                    apply_givens_right_cols(zm, cs, sn, r - 1, r, n, n);
                }
            }
        }
    }
}
745
746fn extract_eigenvalues(s: &[f64], t: &[f64], n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
755 let mut alpha_real = vec![0.0; n];
756 let mut alpha_imag = vec![0.0; n];
757 let mut beta = vec![0.0; n];
758
759 let mut i = 0;
760 while i < n {
761 if i + 1 < n && s[cm(i + 1, i, n)].abs() > ALPHA_ZERO_THRESHOLD {
762 let s11 = s[cm(i, i, n)];
764 let s12 = s[cm(i, i + 1, n)];
765 let s21 = s[cm(i + 1, i, n)];
766 let s22 = s[cm(i + 1, i + 1, n)];
767 let t11 = t[cm(i, i, n)];
768 let t22 = t[cm(i + 1, i + 1, n)];
769
770 let beta_val = (t11 * t22).abs().sqrt();
771 let trace = s11 + s22;
772 let det = s11 * s22 - s12 * s21;
773 let disc = trace * trace - 4.0 * det;
774
775 if disc < 0.0 {
776 let real_part = trace / 2.0;
777 let imag_part = (-disc).sqrt() / 2.0;
778 alpha_real[i] = real_part;
779 alpha_imag[i] = imag_part;
780 beta[i] = if beta_val.abs() > 1e-300 {
781 beta_val
782 } else {
783 1.0
784 };
785
786 alpha_real[i + 1] = real_part;
787 alpha_imag[i + 1] = -imag_part;
788 beta[i + 1] = beta[i];
789 } else {
790 let sqrt_disc = disc.sqrt();
791 alpha_real[i] = (trace + sqrt_disc) / 2.0;
792 alpha_imag[i] = 0.0;
793 beta[i] = if beta_val.abs() > 1e-300 {
794 beta_val
795 } else {
796 1.0
797 };
798
799 alpha_real[i + 1] = (trace - sqrt_disc) / 2.0;
800 alpha_imag[i + 1] = 0.0;
801 beta[i + 1] = beta[i];
802 }
803 i += 2;
804 } else {
805 alpha_real[i] = s[cm(i, i, n)];
807 alpha_imag[i] = 0.0;
808 beta[i] = t[cm(i, i, n)].abs().max(1e-300);
809 i += 1;
810 }
811 }
812
813 (alpha_real, alpha_imag, beta)
814}
815
/// Computes `(cs, sn)` such that the rotation `[[cs, sn], [-sn, cs]]` maps
/// `(a, b)` to `(r, 0)` with `r = hypot(a, b)`.
///
/// Special cases: a negligible `b` yields the identity `(1, 0)`; a negligible `a`
/// yields `(0, ±1)` with the sign of `b`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    // `hypot` avoids the overflow (inputs around 1e155 and above) and underflow
    // (around 1e-155 and below) of the naive `sqrt(a*a + b*b)`. Those cases would
    // otherwise return (0, 0) or non-finite values instead of a rotation.
    let r = a.hypot(b);
    (a / r, b / r)
}
833
834fn householder_vector(
838 m: &[f64],
839 start: usize,
840 col: usize,
841 n: usize,
842 _lda: usize,
843) -> (Vec<f64>, f64) {
844 let len = n - start;
845 let mut v = vec![0.0; len];
846 for i in 0..len {
847 v[i] = m[cm(start + i, col, n)];
848 }
849
850 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
851 if norm < 1e-300 || len == 0 {
852 return (v, 0.0);
853 }
854
855 let sign = if v[0] >= 0.0 { 1.0 } else { -1.0 };
856 v[0] += sign * norm;
857
858 let v_norm_sq: f64 = v.iter().map(|x| x * x).sum();
859 if v_norm_sq < 1e-300 {
860 return (v, 0.0);
861 }
862 let tau = 2.0 / v_norm_sq;
863
864 (v, tau)
865}
866
/// Builds a Householder reflector `(v, tau)` directly from the components in `x`.
///
/// `tau` is `0.0` when `x` is empty or has negligible norm; in that case `v` is an
/// unmodified copy of `x`.
fn householder_from_vec(x: &[f64]) -> (Vec<f64>, f64) {
    let mut v = x.to_vec();

    let norm = v.iter().map(|c| c * c).sum::<f64>().sqrt();
    if norm < 1e-300 {
        return (v, 0.0);
    }

    // Shift the leading component away from zero to avoid cancellation.
    let shift = if v[0] >= 0.0 { norm } else { -norm };
    v[0] += shift;

    let v_norm_sq: f64 = v.iter().map(|c| c * c).sum();
    let tau = if v_norm_sq < 1e-300 {
        0.0
    } else {
        2.0 / v_norm_sq
    };

    (v, tau)
}
883
884#[allow(clippy::too_many_arguments)]
887fn apply_householder_left(
888 m: &mut [f64],
889 v: &[f64],
890 tau: f64,
891 row_start: usize,
892 row_end: usize,
893 col_start: usize,
894 col_end: usize,
895 n: usize,
896) {
897 let vlen = row_end - row_start;
898 for j in col_start..col_end {
899 let mut dot = 0.0;
900 for i in 0..vlen {
901 dot += v[i] * m[cm(row_start + i, j, n)];
902 }
903 let scale = tau * dot;
904 for i in 0..vlen {
905 m[cm(row_start + i, j, n)] -= scale * v[i];
906 }
907 }
908}
909
910#[allow(clippy::too_many_arguments)]
913fn apply_householder_right(
914 m: &mut [f64],
915 v: &[f64],
916 tau: f64,
917 row_start: usize,
918 row_end: usize,
919 col_start: usize,
920 _col_end: usize,
921 n: usize,
922) {
923 let vlen = v.len();
924 for i in row_start..row_end {
925 let mut dot = 0.0;
926 for k in 0..vlen {
927 dot += m[cm(i, col_start + k, n)] * v[k];
928 }
929 let scale = tau * dot;
930 for k in 0..vlen {
931 m[cm(i, col_start + k, n)] -= scale * v[k];
932 }
933 }
934}
935
/// Left application of a short reflector; forwards every argument unchanged to
/// `apply_householder_left`.
#[allow(clippy::too_many_arguments)]
fn apply_householder_left_small(
    m: &mut [f64],
    v: &[f64],
    tau: f64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    n: usize,
) {
    apply_householder_left(m, v, tau, row_start, row_end, col_start, col_end, n);
}
950
/// Right application of a short reflector; forwards to `apply_householder_right`.
///
/// The caller's `col_end` is discarded: the column extent passed on is derived
/// from the reflector length (`col_start + v.len()`).
#[allow(clippy::too_many_arguments)]
fn apply_householder_right_small(
    m: &mut [f64],
    v: &[f64],
    tau: f64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    n: usize,
) {
    // Explicitly mark the parameter as intentionally unused.
    let _ = col_end;
    apply_householder_right(
        m,
        v,
        tau,
        row_start,
        row_end,
        col_start,
        col_start + v.len(),
        n,
    );
}
975
976fn apply_givens_left(
980 m: &mut [f64],
981 cs: f64,
982 sn: f64,
983 r1: usize,
984 r2: usize,
985 n: usize,
986 ncols: usize,
987) {
988 for j in 0..ncols {
989 let a_val = m[cm(r1, j, n)];
990 let b_val = m[cm(r2, j, n)];
991 m[cm(r1, j, n)] = cs * a_val + sn * b_val;
992 m[cm(r2, j, n)] = -sn * a_val + cs * b_val;
993 }
994}
995
996fn apply_givens_right(
1000 m: &mut [f64],
1001 cs: f64,
1002 sn: f64,
1003 c1: usize,
1004 c2: usize,
1005 n: usize,
1006 nrows: usize,
1007) {
1008 for i in 0..nrows {
1009 let a_val = m[cm(i, c1, n)];
1010 let b_val = m[cm(i, c2, n)];
1011 m[cm(i, c1, n)] = cs * a_val + sn * b_val;
1012 m[cm(i, c2, n)] = -sn * a_val + cs * b_val;
1013 }
1014}
1015
1016fn apply_givens_right_cols(
1021 m: &mut [f64],
1022 cs: f64,
1023 sn: f64,
1024 c1: usize,
1025 c2: usize,
1026 n: usize,
1027 nrows: usize,
1028) {
1029 for i in 0..nrows {
1030 let a_val = m[cm(i, c1, n)];
1031 let b_val = m[cm(i, c2, n)];
1032 m[cm(i, c1, n)] = cs * a_val - sn * b_val;
1033 m[cm(i, c2, n)] = sn * a_val + cs * b_val;
1034 }
1035}
1036
/// Generates the PTX text for a kernel named `qz_hessenberg_reduction_{n}`.
///
/// The kernel declares four `u64` parameters (`a_ptr`, `b_ptr`, `q_ptr`, `z_ptr`)
/// and one `u32` parameter (`n_param`), targets `sm`, and is capped at
/// `SOLVER_BLOCK_SIZE` threads per block. `n` is used only to form the kernel name.
///
/// NOTE(review): the body is a placeholder — it reads the thread index and
/// `n_param`, discards both, and returns without touching memory.
///
/// # Errors
///
/// Propagates the error from `KernelBuilder::build`.
pub fn generate_hessenberg_reduction_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
    let name = format!("qz_hessenberg_reduction_{n}");

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("b_ptr", PtxType::U64)
        .param("q_ptr", PtxType::U64)
        .param("z_ptr", PtxType::U64)
        .param("n_param", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let n_param = b.load_param_u32("n_param");

            // Values are loaded but intentionally unused for now.
            let _ = (tid, n_param);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
1082
/// Generates the PTX text for a kernel named `qz_sweep_{n}`.
///
/// The kernel declares four `u64` parameters (`a_ptr`, `b_ptr`, `q_ptr`, `z_ptr`)
/// and three `u32` parameters (`ilo`, `ihi`, `n_param`), targets `sm`, and is
/// capped at `SOLVER_BLOCK_SIZE` threads per block. `n` is used only to form the
/// kernel name.
///
/// NOTE(review): the body is a placeholder — it reads the thread index and the
/// three `u32` parameters, discards them, and returns without touching memory.
///
/// # Errors
///
/// Propagates the error from `KernelBuilder::build`.
pub fn generate_qz_sweep_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
    let name = format!("qz_sweep_{n}");

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("b_ptr", PtxType::U64)
        .param("q_ptr", PtxType::U64)
        .param("z_ptr", PtxType::U64)
        .param("ilo", PtxType::U32)
        .param("ihi", PtxType::U32)
        .param("n_param", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let ilo = b.load_param_u32("ilo");
            let ihi = b.load_param_u32("ihi");
            let n_param = b.load_param_u32("n_param");

            // Values are loaded but intentionally unused for now.
            let _ = (tid, ilo, ihi, n_param);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
1128
/// Generates the PTX text for a kernel named `qz_eigenvalue_extract_{n}`.
///
/// The kernel declares five `u64` parameters (`s_ptr`, `t_ptr`, `alpha_r_ptr`,
/// `alpha_i_ptr`, `beta_ptr`) and one `u32` parameter (`n_param`), targets `sm`,
/// and is capped at `SOLVER_BLOCK_SIZE` threads per block. `n` is used only to
/// form the kernel name.
///
/// NOTE(review): the body is a placeholder — it reads the thread index and
/// `n_param`, discards both, and returns without touching memory.
///
/// # Errors
///
/// Propagates the error from `KernelBuilder::build`.
pub fn generate_eigenvalue_extraction_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
    let name = format!("qz_eigenvalue_extract_{n}");

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("s_ptr", PtxType::U64)
        .param("t_ptr", PtxType::U64)
        .param("alpha_r_ptr", PtxType::U64)
        .param("alpha_i_ptr", PtxType::U64)
        .param("beta_ptr", PtxType::U64)
        .param("n_param", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let n_param = b.load_param_u32("n_param");

            // Values are loaded but intentionally unused for now.
            let _ = (tid, n_param);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
1169
// Unit tests for the configuration types, the planner, the classifier, the host
// solver on small hand-built inputs, the PTX generators, and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Defaults and builder -------------------------------------------------

    // The derived default balancing mode is `Both`.
    #[test]
    fn test_balance_strategy_default() {
        let bs = BalanceStrategy::default();
        assert_eq!(bs, BalanceStrategy::Both);
    }

    // The derived default shift strategy is `FrancisDoubleShift`.
    #[test]
    fn test_shift_strategy_default() {
        let ss = ShiftStrategy::default();
        assert_eq!(ss, ShiftStrategy::FrancisDoubleShift);
    }

    // `QzConfig::new` fills in the module defaults.
    #[test]
    fn test_qz_config_new() {
        let config = QzConfig::new(10, SmVersion::Sm80);
        assert_eq!(config.n, 10);
        assert!(!config.compute_schur_vectors);
        assert_eq!(config.balance, BalanceStrategy::Both);
        assert_eq!(config.max_iterations, 300);
        assert!((config.tolerance - 1e-14).abs() < 1e-20);
    }

    // Every `with_*` method overrides exactly its own field.
    #[test]
    fn test_qz_config_builder() {
        let config = QzConfig::new(5, SmVersion::Sm90)
            .with_schur_vectors(true)
            .with_balance(BalanceStrategy::None)
            .with_max_iterations(500)
            .with_tolerance(1e-12);
        assert_eq!(config.n, 5);
        assert!(config.compute_schur_vectors);
        assert_eq!(config.balance, BalanceStrategy::None);
        assert_eq!(config.max_iterations, 500);
        assert!((config.tolerance - 1e-12).abs() < 1e-20);
    }

    // --- Validation -----------------------------------------------------------

    // A default configuration with n = 4 validates.
    #[test]
    fn test_validate_qz_config_valid() {
        let config = QzConfig::new(4, SmVersion::Sm80);
        assert!(validate_qz_config(&config).is_ok());
    }

    // n = 0 is rejected as a dimension mismatch.
    #[test]
    fn test_validate_qz_config_zero_n() {
        let config = QzConfig {
            n: 0,
            compute_schur_vectors: false,
            balance: BalanceStrategy::None,
            max_iterations: 100,
            tolerance: 1e-14,
            sm_version: SmVersion::Sm80,
        };
        let err = validate_qz_config(&config);
        assert!(err.is_err());
        assert!(matches!(err, Err(SolverError::DimensionMismatch(_))));
    }

    // A zero tolerance is rejected.
    #[test]
    fn test_validate_qz_config_zero_tolerance() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_tolerance(0.0);
        assert!(validate_qz_config(&config).is_err());
    }

    // A zero sweep budget is rejected.
    #[test]
    fn test_validate_qz_config_zero_iterations() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_max_iterations(0);
        assert!(validate_qz_config(&config).is_err());
    }

    // --- Planning -------------------------------------------------------------

    // Without Schur vectors the plan has the reduction and extraction steps but
    // no accumulation step.
    #[test]
    fn test_plan_qz_basic() {
        let config = QzConfig::new(4, SmVersion::Sm80);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        let plan = plan.ok();
        assert!(plan.is_some());
        let plan = plan.as_ref();
        let plan = plan.map(|p| &p.steps);
        if let Some(steps) = plan {
            assert!(steps.contains(&QzStep::HessenbergTriangularReduction));
            assert!(steps.contains(&QzStep::EigenvalueExtraction));
            assert!(!steps.contains(&QzStep::SchurVectorAccumulation));
        }
    }

    // Requesting Schur vectors adds the accumulation step.
    #[test]
    fn test_plan_qz_with_vectors() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_schur_vectors(true);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        if let Ok(p) = &plan {
            assert!(p.steps.contains(&QzStep::SchurVectorAccumulation));
        }
    }

    // A 1-by-1 problem needs no iteration step.
    #[test]
    fn test_plan_qz_n1_no_iteration() {
        let config = QzConfig::new(1, SmVersion::Sm80);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        if let Ok(p) = &plan {
            let has_iter = p
                .steps
                .iter()
                .any(|s| matches!(s, QzStep::QzIteration { .. }));
            assert!(!has_iter, "n=1 should not have QzIteration step");
        }
    }

    // --- Cost model -----------------------------------------------------------

    // The estimate is 10 * n^3.
    #[test]
    fn test_estimate_qz_flops() {
        let flops_1 = estimate_qz_flops(1);
        assert!((flops_1 - 10.0).abs() < 1e-10);

        let flops_10 = estimate_qz_flops(10);
        assert!((flops_10 - 10_000.0).abs() < 1e-6);

        let flops_100 = estimate_qz_flops(100);
        assert!((flops_100 - 10_000_000.0).abs() < 1.0);
    }

    // The plan reports the same estimate for its configured dimension.
    #[test]
    fn test_estimated_flops_via_plan() {
        let config = QzConfig::new(10, SmVersion::Sm80);
        if let Ok(plan) = plan_qz(&config) {
            let flops = plan.estimated_flops();
            assert!((flops - 10_000.0).abs() < 1e-6);
        }
    }

    // --- Classification -------------------------------------------------------

    // Non-zero alpha with zero imaginary part and non-zero beta is real.
    #[test]
    fn test_classify_eigenvalue_real() {
        let et = classify_eigenvalue(3.5, 0.0, 1.0);
        assert_eq!(et, EigenvalueType::Real);
    }

    // A non-zero imaginary part makes it a complex pair.
    #[test]
    fn test_classify_eigenvalue_complex() {
        let et = classify_eigenvalue(1.0, 2.0, 1.0);
        assert_eq!(et, EigenvalueType::ComplexPair);
    }

    // Zero beta with non-zero alpha is infinite.
    #[test]
    fn test_classify_eigenvalue_infinite() {
        let et = classify_eigenvalue(1.0, 0.0, 0.0);
        assert_eq!(et, EigenvalueType::Infinite);
    }

    // Zero alpha with non-zero beta is zero.
    #[test]
    fn test_classify_eigenvalue_zero() {
        let et = classify_eigenvalue(0.0, 0.0, 1.0);
        assert_eq!(et, EigenvalueType::Zero);
    }

    // The indeterminate 0/0 case is reported as zero.
    #[test]
    fn test_classify_eigenvalue_zero_over_zero() {
        let et = classify_eigenvalue(0.0, 0.0, 0.0);
        assert_eq!(et, EigenvalueType::Zero);
    }

    // --- Host solver ----------------------------------------------------------

    // 1-by-1 pencil (5, 2): the single eigenvalue is 5 / 2.
    #[test]
    fn test_qz_host_n1() {
        let mut a = vec![5.0];
        let mut b = vec![2.0];
        let config = QzConfig::new(1, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 1);
            assert_eq!(r.beta.len(), 1);
            let eig = r.alpha_real[0] / r.beta[0];
            assert!(
                (eig - 2.5).abs() < 1e-10,
                "eigenvalue = {eig}, expected 2.5"
            );
        }
    }

    // Diagonal 2-by-2 pencil: converges and reports non-zero betas.
    #[test]
    fn test_qz_host_n2_diagonal() {
        let mut a = vec![3.0, 0.0, 0.0, 7.0];
        let mut b = vec![1.0, 0.0, 0.0, 2.0];
        let config = QzConfig::new(2, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 2);
            assert_eq!(r.beta.len(), 2);
            for bt in &r.beta {
                assert!(bt.abs() > 1e-15, "beta should be nonzero");
            }
        }
    }

    // `a` holds 2 values where n * n = 4 are required.
    #[test]
    fn test_qz_host_dimension_mismatch() {
        let mut a = vec![1.0, 2.0];
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let config = QzConfig::new(2, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_err());
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    // Requesting Schur vectors populates all four optional outputs.
    #[test]
    fn test_qz_host_with_schur_vectors() {
        let mut a = vec![2.0, 0.0, 0.0, 3.0];
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let config = QzConfig::new(2, SmVersion::Sm80).with_schur_vectors(true);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.q_matrix.is_some());
            assert!(r.z_matrix.is_some());
            assert!(r.schur_s.is_some());
            assert!(r.schur_t.is_some());
        }
    }

    // --- PTX generators -------------------------------------------------------

    // The generated text contains the size-specific kernel name.
    #[test]
    fn test_generate_hessenberg_reduction_ptx() {
        let ptx = generate_hessenberg_reduction_ptx(4, SmVersion::Sm80);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_hessenberg_reduction_4"));
        }
    }

    // The generated text contains the size-specific kernel name.
    #[test]
    fn test_generate_qz_sweep_ptx() {
        let ptx = generate_qz_sweep_ptx(8, SmVersion::Sm86);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_sweep_8"));
        }
    }

    // The generated text contains the size-specific kernel name.
    #[test]
    fn test_generate_eigenvalue_extraction_ptx() {
        let ptx = generate_eigenvalue_extraction_ptx(4, SmVersion::Sm90);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_eigenvalue_extract_4"));
        }
    }

    // --- Numeric helpers ------------------------------------------------------

    // The rotation for (3, 4) maps the pair to (5, 0).
    #[test]
    fn test_givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
        let zero = -sn * 3.0 + cs * 4.0;
        assert!(zero.abs() < 1e-10);
    }

    // A zero second component yields the identity rotation.
    #[test]
    fn test_givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    // Ones on the diagonal, zeros elsewhere.
    #[test]
    fn test_identity_matrix() {
        let id = identity_matrix(3);
        assert_eq!(id.len(), 9);
        assert!((id[cm(0, 0, 3)] - 1.0).abs() < 1e-15);
        assert!((id[cm(1, 1, 3)] - 1.0).abs() < 1e-15);
        assert!((id[cm(2, 2, 3)] - 1.0).abs() < 1e-15);
        assert!(id[cm(0, 1, 3)].abs() < 1e-15);
        assert!(id[cm(1, 0, 3)].abs() < 1e-15);
    }

    // `cm(row, col, n)` is `col * n + row`.
    #[test]
    fn test_column_major_indexing() {
        assert_eq!(cm(0, 0, 3), 0);
        assert_eq!(cm(1, 0, 3), 1);
        assert_eq!(cm(0, 1, 3), 3);
        assert_eq!(cm(2, 2, 3), 8);
    }

    // Diagonal pencil: eigenvalues are the ratios s_ii / t_ii, all real.
    #[test]
    fn test_extract_eigenvalues_diagonal() {
        let n = 3;
        let mut s = vec![0.0; n * n];
        let mut t = vec![0.0; n * n];
        s[cm(0, 0, n)] = 2.0;
        s[cm(1, 1, n)] = 5.0;
        s[cm(2, 2, n)] = -1.0;
        t[cm(0, 0, n)] = 1.0;
        t[cm(1, 1, n)] = 2.0;
        t[cm(2, 2, n)] = 3.0;

        let (ar, ai, bt) = extract_eigenvalues(&s, &t, n);
        assert_eq!(ar.len(), 3);
        assert!((ar[0] / bt[0] - 2.0).abs() < 1e-10);
        assert!((ar[1] / bt[1] - 2.5).abs() < 1e-10);
        assert!((ar[2] / bt[2] - (-1.0 / 3.0)).abs() < 1e-10);
        for &imag in &ai {
            assert!(imag.abs() < 1e-15);
        }
    }

    // Upper-triangular 3-by-3 inputs (column-major literals): the solver
    // converges and returns three eigenvalues.
    #[test]
    fn test_qz_host_n3_upper_triangular() {
        #[rustfmt::skip]
        let mut a = vec![
            1.0, 0.0, 0.0, 2.0, 4.0, 0.0, 3.0, 5.0, 6.0, ];
        #[rustfmt::skip]
        let mut b = vec![
            1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 1.0, 1.0, 3.0, ];
        let config = QzConfig::new(3, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 3);
        }
    }
}