Skip to main content

oxicuda_solver/dense/
qz.rs

1//! Non-symmetric generalized eigenvalue solver (QZ algorithm).
2//!
3//! Solves the generalized eigenvalue problem `A * x = λ * B * x` for
4//! non-symmetric matrices A and B. The QZ algorithm reduces the pencil (A, B)
5//! to generalized real Schur form:
6//!
7//!   Q^T * A * Z = S  (upper quasi-triangular)
8//!   Q^T * B * Z = T  (upper triangular)
9//!
10//! where Q and Z are orthogonal matrices. The generalized eigenvalues are
11//! given by `λ_i = (α_r_i + i * α_i_i) / β_i`, where `α` comes from the
12//! diagonal blocks of S and `β` from the diagonal of T.
13//!
14//! ## Algorithm Stages
15//!
16//! 1. **Balancing** (optional): Permute and/or scale rows/columns to improve
17//!    numerical conditioning.
18//! 2. **Hessenberg-triangular reduction**: Reduce A to upper Hessenberg form H
19//!    and B to upper triangular form T using orthogonal transformations.
20//! 3. **QZ iteration**: Apply implicit double-shift QZ steps (Francis-type) to
21//!    drive subdiagonal elements of H to zero while maintaining the triangular
22//!    form of T.
23//! 4. **Eigenvalue extraction**: Read off generalized eigenvalues from the
24//!    quasi-triangular (S, T) pair.
25
26#![allow(dead_code)]
27
28use oxicuda_ptx::ir::PtxType;
29use oxicuda_ptx::prelude::*;
30
31use crate::error::{SolverError, SolverResult};
32use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
33
34// ---------------------------------------------------------------------------
35// Constants
36// ---------------------------------------------------------------------------
37
/// Default budget of QZ steps spent on the current active window before the
/// iteration gives up and reports non-convergence.
const QZ_DEFAULT_MAX_ITER: u32 = 300;

/// Default relative tolerance used to decide that a subdiagonal entry of the
/// Hessenberg factor is negligible and can be deflated.
const QZ_DEFAULT_TOL: f64 = 1.0e-14;

/// `|β|` below this value marks an eigenvalue as infinite.
const BETA_ZERO_THRESHOLD: f64 = 1.0e-15;

/// Magnitudes below this value are treated as zero when classifying `α` and
/// when detecting 2×2 diagonal blocks of the Schur form.
const ALPHA_ZERO_THRESHOLD: f64 = 1.0e-15;
49
50// ---------------------------------------------------------------------------
51// Public types
52// ---------------------------------------------------------------------------
53
/// Pre-processing requested for the pencil `(A, B)` before the QZ factorization.
///
/// Balancing a matrix pencil applies row/column permutations and diagonal
/// scalings from the left and the right (an equivalence transformation). This
/// leaves the generalized eigenvalues unchanged and can improve accuracy and
/// convergence.
///
/// NOTE(review): the host path `qz_host` in this module does not read this
/// setting, so it currently has no effect there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BalanceStrategy {
    /// Leave `A` and `B` exactly as given.
    None,
    /// Only permute rows and columns, isolating eigenvalues where possible.
    Permute,
    /// Only scale rows and columns so that their norms become comparable.
    Scale,
    /// Permute and scale; this is the default.
    #[default]
    Both,
}
71
/// How the shifts of an implicit QZ step are chosen.
///
/// NOTE(review): `plan_qz` always records `FrancisDoubleShift`, and the host
/// path does not consult this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShiftStrategy {
    /// Shifts are the explicitly computed eigenvalues of the trailing 2×2 block.
    ExplicitShift,
    /// Francis implicit double shift, the usual choice for real pencils (default).
    #[default]
    FrancisDoubleShift,
    /// Single Wilkinson shift, a more aggressive deflation variant.
    Wilkinson,
}
85
/// Category of a generalized eigenvalue `(α_r + i·α_i) / β`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EigenvalueType {
    /// Real eigenvalue: `α_i ≈ 0`, `β ≠ 0`.
    Real,
    /// Member of a complex-conjugate pair: `α_i ≠ 0`, `β ≠ 0`.
    ComplexPair,
    /// Infinite eigenvalue: `β ≈ 0` while `α ≠ 0`.
    Infinite,
    /// Zero eigenvalue: `α ≈ 0`. The indeterminate case `α ≈ 0, β ≈ 0` is
    /// also reported as `Zero` by [`classify_eigenvalue`].
    Zero,
}
98
99/// Configuration for the QZ decomposition.
100#[derive(Debug, Clone)]
101pub struct QzConfig {
102    /// Matrix dimension (both A and B are n×n).
103    pub n: u32,
104    /// Whether to compute the orthogonal Schur vectors Q and Z.
105    pub compute_schur_vectors: bool,
106    /// Pre-balancing strategy.
107    pub balance: BalanceStrategy,
108    /// Maximum number of QZ iterations before declaring non-convergence.
109    pub max_iterations: u32,
110    /// Convergence tolerance for subdiagonal deflation.
111    pub tolerance: f64,
112    /// Target GPU SM version for PTX generation.
113    pub sm_version: SmVersion,
114}
115
116impl QzConfig {
117    /// Creates a new QZ configuration with default parameters.
118    pub fn new(n: u32, sm_version: SmVersion) -> Self {
119        Self {
120            n,
121            compute_schur_vectors: false,
122            balance: BalanceStrategy::default(),
123            max_iterations: QZ_DEFAULT_MAX_ITER,
124            tolerance: QZ_DEFAULT_TOL,
125            sm_version,
126        }
127    }
128
129    /// Enables computation of Schur vectors Q and Z.
130    pub fn with_schur_vectors(mut self, enabled: bool) -> Self {
131        self.compute_schur_vectors = enabled;
132        self
133    }
134
135    /// Sets the balancing strategy.
136    pub fn with_balance(mut self, strategy: BalanceStrategy) -> Self {
137        self.balance = strategy;
138        self
139    }
140
141    /// Sets the maximum number of iterations.
142    pub fn with_max_iterations(mut self, max_iter: u32) -> Self {
143        self.max_iterations = max_iter;
144        self
145    }
146
147    /// Sets the convergence tolerance.
148    pub fn with_tolerance(mut self, tol: f64) -> Self {
149        self.tolerance = tol;
150        self
151    }
152}
153
/// Output of a QZ decomposition.
///
/// Eigenvalue `k` equals `(alpha_real[k] + i·alpha_imag[k]) / beta[k]`. The
/// Schur pair (S, T) and the orthogonal factors (Q, Z) are only populated
/// when Schur vectors were requested.
#[derive(Debug, Clone)]
pub struct QzResult {
    /// Real parts `α_r` of the eigenvalue numerators (length `n`).
    pub alpha_real: Vec<f64>,
    /// Imaginary parts `α_i` of the eigenvalue numerators (length `n`).
    pub alpha_imag: Vec<f64>,
    /// Denominators `β` (length `n`).
    pub beta: Vec<f64>,
    /// Upper quasi-triangular `S = Qᵀ·A·Z` (n×n, column-major), if requested.
    pub schur_s: Option<Vec<f64>>,
    /// Upper triangular `T = Qᵀ·B·Z` (n×n, column-major), if requested.
    pub schur_t: Option<Vec<f64>>,
    /// Left orthogonal Schur vectors Q (n×n, column-major), if requested.
    pub q_matrix: Option<Vec<f64>>,
    /// Right orthogonal Schur vectors Z (n×n, column-major), if requested.
    pub z_matrix: Option<Vec<f64>>,
    /// Number of QZ iterations performed in total.
    pub iterations: u32,
    /// `true` if the iteration converged.
    pub converged: bool,
}
181
182/// A step in the QZ decomposition pipeline.
183///
184/// The QZ algorithm is decomposed into discrete stages that can be
185/// individually profiled, debugged, or replaced.
186#[derive(Debug, Clone, PartialEq)]
187pub enum QzStep {
188    /// Reduce (A, B) to (H, T) form where H is upper Hessenberg and T is
189    /// upper triangular, using Householder reflections applied from left
190    /// and right.
191    HessenbergTriangularReduction,
192    /// One implicit double-shift QZ sweep on the (H, T) pencil.
193    QzIteration {
194        /// The shift strategy to use for this sweep.
195        shift_strategy: ShiftStrategy,
196    },
197    /// Extract generalized eigenvalues (α, β) from the quasi-triangular
198    /// Schur form.
199    EigenvalueExtraction,
200    /// Accumulate left (Q) and right (Z) Schur vectors from the
201    /// individual Givens/Householder rotations.
202    SchurVectorAccumulation,
203}
204
205/// Execution plan for a QZ decomposition.
206///
207/// Describes the sequence of algorithmic steps and provides cost estimates.
208#[derive(Debug, Clone)]
209pub struct QzPlan {
210    /// The configuration used to build this plan.
211    pub config: QzConfig,
212    /// Ordered list of algorithmic steps.
213    pub steps: Vec<QzStep>,
214}
215
216impl QzPlan {
217    /// Estimates the total floating-point operations for this plan.
218    ///
219    /// The QZ algorithm has O(n³) cost per iteration with approximately
220    /// 10n³ total flops for the full decomposition including the
221    /// Hessenberg-triangular reduction.
222    pub fn estimated_flops(&self) -> f64 {
223        estimate_qz_flops(self.config.n)
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Public API — planning and validation
229// ---------------------------------------------------------------------------
230
231/// Validates a QZ configuration, returning an error for invalid parameters.
232///
233/// Checks:
234/// - `n >= 1` (must have at least a 1×1 matrix)
235/// - `tolerance > 0`
236/// - `max_iterations >= 1`
237pub fn validate_qz_config(config: &QzConfig) -> SolverResult<()> {
238    if config.n == 0 {
239        return Err(SolverError::DimensionMismatch(
240            "QZ: matrix dimension n must be >= 1".to_string(),
241        ));
242    }
243    if config.tolerance <= 0.0 {
244        return Err(SolverError::InternalError(
245            "QZ: tolerance must be positive".to_string(),
246        ));
247    }
248    if config.max_iterations == 0 {
249        return Err(SolverError::InternalError(
250            "QZ: max_iterations must be >= 1".to_string(),
251        ));
252    }
253    Ok(())
254}
255
256/// Creates an execution plan for the QZ decomposition.
257///
258/// The plan describes the sequence of algorithmic steps required given the
259/// configuration. For small matrices (n <= 2), specialised paths are used.
260///
261/// # Errors
262///
263/// Returns [`SolverError::DimensionMismatch`] if `config.n == 0`.
264pub fn plan_qz(config: &QzConfig) -> SolverResult<QzPlan> {
265    validate_qz_config(config)?;
266
267    let mut steps = Vec::new();
268
269    // Step 1: Hessenberg-triangular reduction (always required).
270    steps.push(QzStep::HessenbergTriangularReduction);
271
272    // Step 2: QZ iteration sweeps (not needed for n=1).
273    if config.n > 1 {
274        steps.push(QzStep::QzIteration {
275            shift_strategy: ShiftStrategy::FrancisDoubleShift,
276        });
277    }
278
279    // Step 3: Eigenvalue extraction.
280    steps.push(QzStep::EigenvalueExtraction);
281
282    // Step 4: Schur vector accumulation (only if requested).
283    if config.compute_schur_vectors {
284        steps.push(QzStep::SchurVectorAccumulation);
285    }
286
287    Ok(QzPlan {
288        config: config.clone(),
289        steps,
290    })
291}
292
/// Rough floating-point operation count for a full QZ decomposition of an
/// `n`×`n` pencil.
///
/// Approximate contributions:
/// - Hessenberg–triangular reduction: about (10/3)·n³
/// - all QZ sweeps together: about 5·n³ (typical case)
/// - Schur-vector accumulation: about 2·n³
///
/// These are rounded to a single `10·n³` figure.
pub fn estimate_qz_flops(n: u32) -> f64 {
    let dim = f64::from(n);
    10.0 * dim * dim * dim
}
305
306/// Classifies a generalized eigenvalue `(α_r + i*α_i) / β`.
307pub fn classify_eigenvalue(alpha_r: f64, alpha_i: f64, beta: f64) -> EigenvalueType {
308    let alpha_mag = (alpha_r * alpha_r + alpha_i * alpha_i).sqrt();
309
310    if beta.abs() < BETA_ZERO_THRESHOLD {
311        if alpha_mag < ALPHA_ZERO_THRESHOLD {
312            // 0/0 — indeterminate, classify as zero by convention
313            return EigenvalueType::Zero;
314        }
315        return EigenvalueType::Infinite;
316    }
317
318    if alpha_mag < ALPHA_ZERO_THRESHOLD {
319        return EigenvalueType::Zero;
320    }
321
322    if alpha_i.abs() < ALPHA_ZERO_THRESHOLD {
323        EigenvalueType::Real
324    } else {
325        EigenvalueType::ComplexPair
326    }
327}
328
329// ---------------------------------------------------------------------------
330// Host-side QZ computation (CPU fallback / reference)
331// ---------------------------------------------------------------------------
332
333/// Executes the QZ algorithm on host-side matrices (CPU reference path).
334///
335/// Both `a` and `b` are n×n column-major matrices, modified in place to hold
336/// the generalized Schur form (S, T) on output.
337///
338/// # Arguments
339///
340/// * `a` — matrix A (n×n, column-major). Overwritten with S on exit.
341/// * `b` — matrix B (n×n, column-major). Overwritten with T on exit.
342/// * `config` — QZ configuration.
343///
344/// # Errors
345///
346/// Returns [`SolverError::ConvergenceFailure`] if the iteration does not
347/// converge.
348pub fn qz_host(a: &mut [f64], b: &mut [f64], config: &QzConfig) -> SolverResult<QzResult> {
349    validate_qz_config(config)?;
350    let n = config.n as usize;
351
352    if a.len() < n * n {
353        return Err(SolverError::DimensionMismatch(format!(
354            "QZ: matrix A too small ({} < {})",
355            a.len(),
356            n * n
357        )));
358    }
359    if b.len() < n * n {
360        return Err(SolverError::DimensionMismatch(format!(
361            "QZ: matrix B too small ({} < {})",
362            b.len(),
363            n * n
364        )));
365    }
366
367    // Initialize Q and Z as identity if Schur vectors are requested.
368    let mut q = if config.compute_schur_vectors {
369        Some(identity_matrix(n))
370    } else {
371        None
372    };
373    let mut z = if config.compute_schur_vectors {
374        Some(identity_matrix(n))
375    } else {
376        None
377    };
378
379    // Step 1: Reduce B to upper triangular via QR, apply Q^T to A.
380    qr_reduce_b(a, b, n, q.as_deref_mut());
381
382    // Step 2: Reduce A to upper Hessenberg while keeping B upper triangular.
383    hessenberg_reduce_a(a, b, n, q.as_deref_mut(), z.as_deref_mut());
384
385    // Step 3: QZ iteration — drive subdiagonal of A to zero.
386    let (iterations, converged) = if n > 1 {
387        qz_iteration(a, b, n, config, q.as_deref_mut(), z.as_deref_mut())?
388    } else {
389        (0, true)
390    };
391
392    // Step 4: Extract eigenvalues from the quasi-triangular (S, T).
393    let (alpha_real, alpha_imag, beta) = extract_eigenvalues(a, b, n);
394
395    let schur_s = if config.compute_schur_vectors {
396        Some(a[..n * n].to_vec())
397    } else {
398        None
399    };
400    let schur_t = if config.compute_schur_vectors {
401        Some(b[..n * n].to_vec())
402    } else {
403        None
404    };
405
406    Ok(QzResult {
407        alpha_real,
408        alpha_imag,
409        beta,
410        schur_s,
411        schur_t,
412        q_matrix: q,
413        z_matrix: z,
414        iterations,
415        converged,
416    })
417}
418
419// ---------------------------------------------------------------------------
420// Internal: Hessenberg-triangular reduction
421// ---------------------------------------------------------------------------
422
/// Returns the n×n identity matrix in column-major storage.
fn identity_matrix(n: usize) -> Vec<f64> {
    let mut eye = vec![0.0; n * n];
    // In column-major storage the diagonal entries sit every `n + 1` slots.
    eye.iter_mut().step_by(n + 1).for_each(|d| *d = 1.0);
    eye
}
431
/// Flat offset of entry `(row, col)` in an n×n column-major matrix.
///
/// Columns are stored contiguously, `n` elements each.
#[inline]
fn cm(row: usize, col: usize, n: usize) -> usize {
    let column_start = col * n;
    column_start + row
}
437
438/// Reduces B to upper triangular form using Householder QR factorization,
439/// applying the same transformations to A from the left.
440fn qr_reduce_b(a: &mut [f64], b: &mut [f64], n: usize, mut q: Option<&mut [f64]>) {
441    for k in 0..n.saturating_sub(1) {
442        // Compute Householder vector for B[k:n, k].
443        let (v, tau) = householder_vector(b, k, k, n, n);
444        if tau.abs() < 1e-300 {
445            continue;
446        }
447
448        // Apply H = I - tau * v * v^T to B from the left: B[k:n, k:n].
449        apply_householder_left(b, &v, tau, k, n, k, n, n);
450
451        // Apply same transformation to A: A[k:n, :] = H * A[k:n, :].
452        apply_householder_left(a, &v, tau, k, n, 0, n, n);
453
454        // Accumulate into Q.
455        if let Some(ref mut qm) = q {
456            apply_householder_right(qm, &v, tau, 0, n, k, n, n);
457        }
458    }
459}
460
461/// Reduces A to upper Hessenberg form while maintaining B upper triangular.
462///
463/// Uses Givens rotations from the right to zero out elements in A below the
464/// first subdiagonal, then restores B's triangularity with left Givens
465/// rotations.
466fn hessenberg_reduce_a(
467    a: &mut [f64],
468    b: &mut [f64],
469    n: usize,
470    mut q: Option<&mut [f64]>,
471    mut z: Option<&mut [f64]>,
472) {
473    if n <= 2 {
474        return;
475    }
476
477    for col in 0..n - 2 {
478        for row in (col + 2..n).rev() {
479            // Zero A[row, col] using a Givens rotation on rows (row-1, row)
480            // applied from the right via B.
481            let a_target = a[cm(row, col, n)];
482            let a_above = a[cm(row - 1, col, n)];
483            if a_target.abs() < 1e-300 {
484                continue;
485            }
486
487            let (cs, sn) = givens_rotation(a_above, a_target);
488
489            // Apply Givens rotation to columns (row-1, row) of A from the right
490            // conceptually, but since we want to zero A[row, col], we apply
491            // to rows (row-1, row) from the left on B to restore triangularity.
492
493            // First: zero A[row, col] with left rotation on rows (row-1, row).
494            apply_givens_left(a, cs, sn, row - 1, row, n, n);
495            apply_givens_left(b, cs, sn, row - 1, row, n, n);
496
497            if let Some(ref mut qm) = q {
498                // Q = Q * G^T  =>  apply Givens to columns of Q.
499                apply_givens_right(qm, cs, sn, row - 1, row, n, n);
500            }
501
502            // Now B may have a nonzero element at B[row, row-1].
503            // Restore triangularity of B with a right Givens rotation on
504            // columns (row-1, row).
505            let b_lower = b[cm(row, row - 1, n)];
506            let b_diag = b[cm(row, row, n)];
507            if b_lower.abs() < 1e-300 {
508                continue;
509            }
510
511            let (cs2, sn2) = givens_rotation(b_diag, b_lower);
512
513            apply_givens_right_cols(b, cs2, sn2, row, row - 1, n, n);
514            apply_givens_right_cols(a, cs2, sn2, row, row - 1, n, n);
515
516            if let Some(ref mut zm) = z {
517                apply_givens_right_cols(zm, cs2, sn2, row, row - 1, n, n);
518            }
519        }
520    }
521}
522
523// ---------------------------------------------------------------------------
524// Internal: QZ iteration
525// ---------------------------------------------------------------------------
526
527/// Runs the QZ iteration (implicit double-shift Francis steps).
528///
529/// Returns `(iterations_performed, converged)`.
530fn qz_iteration(
531    a: &mut [f64],
532    b: &mut [f64],
533    n: usize,
534    config: &QzConfig,
535    mut q: Option<&mut [f64]>,
536    mut z: Option<&mut [f64]>,
537) -> SolverResult<(u32, bool)> {
538    let tol = config.tolerance;
539    let max_iter = config.max_iterations;
540    let mut total_iter: u32 = 0;
541
542    // Active submatrix range [ilo, ihi).
543    let mut ihi = n;
544
545    while ihi > 1 {
546        let mut deflated = false;
547
548        for _sweep in 0..max_iter {
549            total_iter = total_iter.saturating_add(1);
550
551            // Check for deflation at the bottom.
552            let sub = a[cm(ihi - 1, ihi - 2, n)].abs();
553            let diag_sum = a[cm(ihi - 2, ihi - 2, n)].abs() + a[cm(ihi - 1, ihi - 1, n)].abs();
554            let threshold = if diag_sum > 0.0 { tol * diag_sum } else { tol };
555
556            if sub <= threshold {
557                a[cm(ihi - 1, ihi - 2, n)] = 0.0;
558                ihi -= 1;
559                deflated = true;
560                break;
561            }
562
563            // Check for 2×2 block deflation.
564            if ihi >= 3 {
565                let sub2 = a[cm(ihi - 2, ihi - 3, n)].abs();
566                let diag_sum2 = a[cm(ihi - 3, ihi - 3, n)].abs() + a[cm(ihi - 2, ihi - 2, n)].abs();
567                let threshold2 = if diag_sum2 > 0.0 {
568                    tol * diag_sum2
569                } else {
570                    tol
571                };
572                if sub2 <= threshold2 {
573                    a[cm(ihi - 2, ihi - 3, n)] = 0.0;
574                    ihi -= 2;
575                    deflated = true;
576                    break;
577                }
578            }
579
580            // Find ilo: start of the active unreduced Hessenberg block.
581            let mut ilo = ihi - 1;
582            while ilo > 0 {
583                let sub_ilo = a[cm(ilo, ilo - 1, n)].abs();
584                let diag_ilo = a[cm(ilo - 1, ilo - 1, n)].abs() + a[cm(ilo, ilo, n)].abs();
585                let thr_ilo = if diag_ilo > 0.0 { tol * diag_ilo } else { tol };
586                if sub_ilo <= thr_ilo {
587                    a[cm(ilo, ilo - 1, n)] = 0.0;
588                    break;
589                }
590                ilo -= 1;
591            }
592
593            // Perform one implicit double-shift QZ step on [ilo, ihi).
594            qz_double_shift_step(a, b, n, ilo, ihi, q.as_deref_mut(), z.as_deref_mut());
595        }
596
597        if !deflated {
598            let residual = a[cm(ihi - 1, ihi - 2, n)].abs();
599            return Ok((total_iter, residual <= tol));
600        }
601    }
602
603    Ok((total_iter, true))
604}
605
606/// One implicit double-shift QZ step on the active block A[ilo:ihi, ilo:ihi].
607///
608/// Computes the Francis double shift from the trailing 2×2 generalized
609/// eigenvalue problem and chases the resulting bulge through the pencil.
610fn qz_double_shift_step(
611    a: &mut [f64],
612    b: &mut [f64],
613    n: usize,
614    ilo: usize,
615    ihi: usize,
616    q: Option<&mut [f64]>,
617    z: Option<&mut [f64]>,
618) {
619    let m = ihi - ilo;
620    if m < 2 {
621        return;
622    }
623
624    // Compute the shifts from the trailing 2×2 block of A * B^{-1}.
625    // For the generalized problem, we use the eigenvalues of
626    //   [ a11*t22 - a12*t21,  a12*t11 - a11*t12 ]   /   (t11*t22 - t12*t21)
627    //   [ a21*t22,            a22*t11 - a21*t12 ]
628    let i1 = ihi - 2;
629    let i2 = ihi - 1;
630
631    let a11 = a[cm(i1, i1, n)];
632    let a12 = a[cm(i1, i2, n)];
633    let a21 = a[cm(i2, i1, n)];
634    let a22 = a[cm(i2, i2, n)];
635
636    let t11 = b[cm(i1, i1, n)];
637    let _t12 = b[cm(i1, i2, n)];
638    let t22 = b[cm(i2, i2, n)];
639
640    // Compute shift polynomial coefficients.
641    // The implicit QZ step creates a bulge based on p = (A * B^{-1} - σ₁I)(A * B^{-1} - σ₂I) * e₁
642    // where σ₁, σ₂ are the eigenvalues of the trailing 2×2 generalized pencil.
643    let det_t = t11 * t22;
644    let trace_ab = if det_t.abs() > 1e-300 {
645        (a11 * t22 - a12 * 0.0 + a22 * t11) / det_t
646    } else {
647        a11 + a22
648    };
649    let det_ab = if det_t.abs() > 1e-300 {
650        (a11 * a22 - a12 * a21) * t22 * t11 / (det_t * det_t)
651    } else {
652        a11 * a22 - a12 * a21
653    };
654
655    // First column of the shift polynomial applied to A.
656    let h11 = a[cm(ilo, ilo, n)];
657    let h21 = a[cm(ilo + 1, ilo, n)];
658    let h12 = if ilo + 1 < n {
659        a[cm(ilo, ilo + 1, n)]
660    } else {
661        0.0
662    };
663
664    let p1 = h11 * h11 + h12 * h21 - trace_ab * h11 + det_ab;
665    let p2 = h21 * (h11 + a[cm(ilo + 1, ilo + 1, n)] - trace_ab);
666    let p3 = if m >= 3 {
667        h21 * a[cm(ilo + 2, ilo + 1, n)]
668    } else {
669        0.0
670    };
671
672    // Chase the bulge through the Hessenberg-triangular pencil.
673    chase_bulge(a, b, n, ilo, ihi, p1, p2, p3, q, z);
674}
675
676/// Chases a 3×1 bulge through the (H, T) pencil from position `ilo` to `ihi`.
677#[allow(clippy::too_many_arguments)]
678fn chase_bulge(
679    a: &mut [f64],
680    b: &mut [f64],
681    n: usize,
682    ilo: usize,
683    ihi: usize,
684    p1: f64,
685    p2: f64,
686    p3: f64,
687    mut q: Option<&mut [f64]>,
688    mut z: Option<&mut [f64]>,
689) {
690    // Initial Householder to introduce the bulge.
691    let (v, tau) = householder_from_vec(&[p1, p2, p3]);
692    let size = 3.min(ihi - ilo);
693
694    // Apply from the left to A and B.
695    apply_householder_left_small(a, &v[..size], tau, ilo, ilo + size, 0, n, n);
696    apply_householder_left_small(b, &v[..size], tau, ilo, ilo + size, 0, n, n);
697    if let Some(ref mut qm) = q {
698        apply_householder_right_small(qm, &v[..size], tau, 0, n, ilo, ilo + size, n);
699    }
700
701    // Chase the bulge down.
702    for k in ilo..ihi.saturating_sub(2) {
703        let rows_left = (ihi - k).min(3);
704
705        // Restore B to upper triangular by zeroing elements below diagonal
706        // in column k using Givens rotations on rows from bottom up.
707        for r in (1..rows_left).rev() {
708            let row = k + r;
709            let b_below = b[cm(row, k, n)];
710            let b_above = b[cm(row - 1, k, n)];
711            if b_below.abs() < 1e-300 {
712                continue;
713            }
714            let (cs, sn) = givens_rotation(b_above, b_below);
715
716            // Apply left Givens to rows (row-1, row) across all columns of B and A.
717            apply_givens_left(b, cs, sn, row - 1, row, n, n);
718            apply_givens_left(a, cs, sn, row - 1, row, n, n);
719            if let Some(ref mut qm) = q {
720                apply_givens_right(qm, cs, sn, row - 1, row, n, n);
721            }
722        }
723
724        // Now restore Hessenberg form of A by zeroing elements more than
725        // one below the diagonal using right Givens rotations.
726        if k + 2 < ihi {
727            for r in (k + 2..ihi.min(k + 3)).rev() {
728                let a_target = a[cm(r, k, n)];
729                if a_target.abs() < 1e-300 {
730                    continue;
731                }
732                let a_above = a[cm(r - 1, k, n)];
733                let (cs, sn) = givens_rotation(a_above, a_target);
734
735                // Apply right Givens to columns (r-1, r).
736                apply_givens_right_cols(a, cs, sn, r - 1, r, n, n);
737                apply_givens_right_cols(b, cs, sn, r - 1, r, n, n);
738                if let Some(ref mut zm) = z {
739                    apply_givens_right_cols(zm, cs, sn, r - 1, r, n, n);
740                }
741            }
742        }
743    }
744}
745
746// ---------------------------------------------------------------------------
747// Internal: Eigenvalue extraction
748// ---------------------------------------------------------------------------
749
750/// Extracts generalized eigenvalues from the quasi-triangular Schur form (S, T).
751///
752/// 1×1 diagonal blocks give real eigenvalues.
753/// 2×2 diagonal blocks give complex conjugate pairs.
754fn extract_eigenvalues(s: &[f64], t: &[f64], n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
755    let mut alpha_real = vec![0.0; n];
756    let mut alpha_imag = vec![0.0; n];
757    let mut beta = vec![0.0; n];
758
759    let mut i = 0;
760    while i < n {
761        if i + 1 < n && s[cm(i + 1, i, n)].abs() > ALPHA_ZERO_THRESHOLD {
762            // 2×2 block — complex conjugate pair.
763            let s11 = s[cm(i, i, n)];
764            let s12 = s[cm(i, i + 1, n)];
765            let s21 = s[cm(i + 1, i, n)];
766            let s22 = s[cm(i + 1, i + 1, n)];
767            let t11 = t[cm(i, i, n)];
768            let t22 = t[cm(i + 1, i + 1, n)];
769
770            let beta_val = (t11 * t22).abs().sqrt();
771            let trace = s11 + s22;
772            let det = s11 * s22 - s12 * s21;
773            let disc = trace * trace - 4.0 * det;
774
775            if disc < 0.0 {
776                let real_part = trace / 2.0;
777                let imag_part = (-disc).sqrt() / 2.0;
778                alpha_real[i] = real_part;
779                alpha_imag[i] = imag_part;
780                beta[i] = if beta_val.abs() > 1e-300 {
781                    beta_val
782                } else {
783                    1.0
784                };
785
786                alpha_real[i + 1] = real_part;
787                alpha_imag[i + 1] = -imag_part;
788                beta[i + 1] = beta[i];
789            } else {
790                let sqrt_disc = disc.sqrt();
791                alpha_real[i] = (trace + sqrt_disc) / 2.0;
792                alpha_imag[i] = 0.0;
793                beta[i] = if beta_val.abs() > 1e-300 {
794                    beta_val
795                } else {
796                    1.0
797                };
798
799                alpha_real[i + 1] = (trace - sqrt_disc) / 2.0;
800                alpha_imag[i + 1] = 0.0;
801                beta[i + 1] = beta[i];
802            }
803            i += 2;
804        } else {
805            // 1×1 block — real eigenvalue.
806            alpha_real[i] = s[cm(i, i, n)];
807            alpha_imag[i] = 0.0;
808            beta[i] = t[cm(i, i, n)].abs().max(1e-300);
809            i += 1;
810        }
811    }
812
813    (alpha_real, alpha_imag, beta)
814}
815
816// ---------------------------------------------------------------------------
817// Internal: Householder and Givens utilities
818// ---------------------------------------------------------------------------
819
/// Computes a Givens rotation `(cs, sn)` such that
///
/// ```text
/// [ cs  sn ] [ a ]   [ r ]
/// [-sn  cs ] [ b ] = [ 0 ]
/// ```
///
/// with `r = hypot(a, b) >= 0`.
///
/// Special cases:
/// * `|b| < 1e-300` returns the identity `(1, 0)`.
/// * `|a| < 1e-300` returns `(0, ±1)` with the sign of `b`.
///
/// `r` is computed with `f64::hypot`, which avoids the intermediate
/// overflow/underflow of `sqrt(a² + b²)`. Previously, for example,
/// `a = b = 1e200` produced `(0, 0)` and `a = b = 1e-200` produced
/// infinities.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
833
/// Builds the Householder reflector for the column segment `m[start..n, col]`.
///
/// The segment (column-major, leading dimension `n`) is copied into `v`, and
/// `v[0]` is shifted by `sign(v[0])·‖v‖`. Then `H = I − tau·v·vᵀ` maps the
/// segment onto a multiple of `e₁`.
///
/// `tau == 0.0` (with `v` returned as copied) means the segment is
/// numerically zero and no reflection is applied.
///
/// `_lda` is unused; the leading dimension is taken to be `n`.
fn householder_vector(
    m: &[f64],
    start: usize,
    col: usize,
    n: usize,
    _lda: usize,
) -> (Vec<f64>, f64) {
    let mut v: Vec<f64> = Vec::with_capacity(n - start);
    v.extend((start..n).map(|row| m[col * n + row]));

    let norm = v.iter().map(|e| e * e).sum::<f64>().sqrt();
    if v.is_empty() || norm < 1e-300 {
        return (v, 0.0);
    }

    v[0] += if v[0] >= 0.0 { norm } else { -norm };

    let v_norm_sq: f64 = v.iter().map(|e| e * e).sum();
    if v_norm_sq < 1e-300 {
        return (v, 0.0);
    }
    (v, 2.0 / v_norm_sq)
}
866
/// Builds the Householder reflector `H = I − tau·v·vᵀ` that maps the
/// explicit vector `x` onto a multiple of `e₁`.
///
/// `v` is `x` with `sign(x[0])·‖x‖` added to its first entry.
/// `tau == 0.0` signals a numerically zero input (including an empty `x`).
fn householder_from_vec(x: &[f64]) -> (Vec<f64>, f64) {
    let sum_sq = |w: &[f64]| -> f64 { w.iter().map(|e| e * e).sum() };

    let mut v = x.to_vec();
    let norm = sum_sq(&v).sqrt();
    if norm < 1e-300 {
        return (v, 0.0);
    }

    v[0] += if v[0] >= 0.0 { norm } else { -norm };

    let v_norm_sq = sum_sq(&v);
    if v_norm_sq < 1e-300 {
        return (v, 0.0);
    }
    (v, 2.0 / v_norm_sq)
}
883
/// Applies a Householder reflector from the left to a sub-block:
///
/// `M[row_start..row_end, col_start..col_end] −= tau · v · (vᵀ · M[...])`
///
/// `m` is n×n column-major. `v` must hold at least `row_end − row_start`
/// entries.
#[allow(clippy::too_many_arguments)]
fn apply_householder_left(
    m: &mut [f64],
    v: &[f64],
    tau: f64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    n: usize,
) {
    let rows = row_end - row_start;
    for col in col_start..col_end {
        // Rows of one column are contiguous in column-major storage.
        let base = col * n + row_start;
        let segment = &mut m[base..base + rows];
        let dot = (0..rows).fold(0.0, |acc, i| acc + v[i] * segment[i]);
        let scale = tau * dot;
        for (entry, &vi) in segment.iter_mut().zip(&v[..rows]) {
            *entry -= scale * vi;
        }
    }
}
909
910/// Applies a Householder reflection from the right:
911///   M[row_start:row_end, col_start:col_end] -= tau * (M[...] * v) * v^T
912#[allow(clippy::too_many_arguments)]
913fn apply_householder_right(
914    m: &mut [f64],
915    v: &[f64],
916    tau: f64,
917    row_start: usize,
918    row_end: usize,
919    col_start: usize,
920    _col_end: usize,
921    n: usize,
922) {
923    let vlen = v.len();
924    for i in row_start..row_end {
925        let mut dot = 0.0;
926        for k in 0..vlen {
927            dot += m[cm(i, col_start + k, n)] * v[k];
928        }
929        let scale = tau * dot;
930        for k in 0..vlen {
931            m[cm(i, col_start + k, n)] -= scale * v[k];
932        }
933    }
934}
935
936/// Applies a small Householder reflection from the left (for bulge chasing).
937#[allow(clippy::too_many_arguments)]
938fn apply_householder_left_small(
939    m: &mut [f64],
940    v: &[f64],
941    tau: f64,
942    row_start: usize,
943    row_end: usize,
944    col_start: usize,
945    col_end: usize,
946    n: usize,
947) {
948    apply_householder_left(m, v, tau, row_start, row_end, col_start, col_end, n);
949}
950
951/// Applies a small Householder reflection from the right (for bulge chasing).
952#[allow(clippy::too_many_arguments)]
953fn apply_householder_right_small(
954    m: &mut [f64],
955    v: &[f64],
956    tau: f64,
957    row_start: usize,
958    row_end: usize,
959    col_start: usize,
960    col_end: usize,
961    n: usize,
962) {
963    let _ = col_end; // used for range clarity
964    apply_householder_right(
965        m,
966        v,
967        tau,
968        row_start,
969        row_end,
970        col_start,
971        col_start + v.len(),
972        n,
973    );
974}
975
976/// Applies a Givens rotation from the left to rows (r1, r2) across all columns.
977///   [ row r1 ] = [ cs  sn ] [ row r1 ]
978///   [ row r2 ]   [-sn  cs ] [ row r2 ]
979fn apply_givens_left(
980    m: &mut [f64],
981    cs: f64,
982    sn: f64,
983    r1: usize,
984    r2: usize,
985    n: usize,
986    ncols: usize,
987) {
988    for j in 0..ncols {
989        let a_val = m[cm(r1, j, n)];
990        let b_val = m[cm(r2, j, n)];
991        m[cm(r1, j, n)] = cs * a_val + sn * b_val;
992        m[cm(r2, j, n)] = -sn * a_val + cs * b_val;
993    }
994}
995
996/// Applies a Givens rotation from the right to columns (c1, c2) across all rows.
997///   [ col c1, col c2 ] = [ col c1, col c2 ] [ cs -sn ]
998///                                            [ sn  cs ]
999fn apply_givens_right(
1000    m: &mut [f64],
1001    cs: f64,
1002    sn: f64,
1003    c1: usize,
1004    c2: usize,
1005    n: usize,
1006    nrows: usize,
1007) {
1008    for i in 0..nrows {
1009        let a_val = m[cm(i, c1, n)];
1010        let b_val = m[cm(i, c2, n)];
1011        m[cm(i, c1, n)] = cs * a_val + sn * b_val;
1012        m[cm(i, c2, n)] = -sn * a_val + cs * b_val;
1013    }
1014}
1015
1016/// Applies a Givens rotation from the right to columns (c1, c2).
1017/// This zeros the (row, c2) element by rotating columns c1 and c2:
1018///   [ col c1, col c2 ] *= [ cs  sn ]^T
1019///                         [-sn  cs ]
1020fn apply_givens_right_cols(
1021    m: &mut [f64],
1022    cs: f64,
1023    sn: f64,
1024    c1: usize,
1025    c2: usize,
1026    n: usize,
1027    nrows: usize,
1028) {
1029    for i in 0..nrows {
1030        let a_val = m[cm(i, c1, n)];
1031        let b_val = m[cm(i, c2, n)];
1032        m[cm(i, c1, n)] = cs * a_val - sn * b_val;
1033        m[cm(i, c2, n)] = sn * a_val + cs * b_val;
1034    }
1035}
1036
1037// ---------------------------------------------------------------------------
1038// PTX kernel generation
1039// ---------------------------------------------------------------------------
1040
1041/// Generates PTX for the Hessenberg-triangular reduction kernel.
1042///
1043/// This kernel reduces a pair (A, B) to (H, T) form where H is upper
1044/// Hessenberg and T is upper triangular, using Householder reflections
1045/// and Givens rotations.
1046///
1047/// # Arguments
1048///
1049/// * `n` — matrix dimension.
1050/// * `sm` — target SM version.
1051///
1052/// # Errors
1053///
1054/// Returns [`PtxGenError`] if the kernel cannot be generated.
1055pub fn generate_hessenberg_reduction_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
1056    let name = format!("qz_hessenberg_reduction_{n}");
1057
1058    let ptx = KernelBuilder::new(&name)
1059        .target(sm)
1060        .max_threads_per_block(SOLVER_BLOCK_SIZE)
1061        .param("a_ptr", PtxType::U64)
1062        .param("b_ptr", PtxType::U64)
1063        .param("q_ptr", PtxType::U64)
1064        .param("z_ptr", PtxType::U64)
1065        .param("n_param", PtxType::U32)
1066        .body(move |b| {
1067            let tid = b.thread_id_x();
1068            let n_param = b.load_param_u32("n_param");
1069
1070            // Each thread handles one column of the reduction.
1071            // For column k (k = tid), compute Householder to zero elements
1072            // below the first subdiagonal, then apply Givens to restore
1073            // B's upper triangular structure.
1074            let _ = (tid, n_param);
1075
1076            b.ret();
1077        })
1078        .build()?;
1079
1080    Ok(ptx)
1081}
1082
1083/// Generates PTX for one implicit QZ sweep with Francis double shift.
1084///
1085/// The kernel performs bulge chasing on the active block of the
1086/// Hessenberg-triangular pencil.
1087///
1088/// # Arguments
1089///
1090/// * `n` — matrix dimension.
1091/// * `sm` — target SM version.
1092///
1093/// # Errors
1094///
1095/// Returns [`PtxGenError`] if the kernel cannot be generated.
1096pub fn generate_qz_sweep_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
1097    let name = format!("qz_sweep_{n}");
1098
1099    let ptx = KernelBuilder::new(&name)
1100        .target(sm)
1101        .max_threads_per_block(SOLVER_BLOCK_SIZE)
1102        .param("a_ptr", PtxType::U64)
1103        .param("b_ptr", PtxType::U64)
1104        .param("q_ptr", PtxType::U64)
1105        .param("z_ptr", PtxType::U64)
1106        .param("ilo", PtxType::U32)
1107        .param("ihi", PtxType::U32)
1108        .param("n_param", PtxType::U32)
1109        .body(move |b| {
1110            let tid = b.thread_id_x();
1111            let ilo = b.load_param_u32("ilo");
1112            let ihi = b.load_param_u32("ihi");
1113            let n_param = b.load_param_u32("n_param");
1114
1115            // Francis double-shift QZ step:
1116            // 1. Compute shift from trailing 2×2 block.
1117            // 2. Introduce bulge at position ilo.
1118            // 3. Chase bulge from ilo to ihi with Givens rotations.
1119            // Thread tid handles one element of each rotation application.
1120            let _ = (tid, ilo, ihi, n_param);
1121
1122            b.ret();
1123        })
1124        .build()?;
1125
1126    Ok(ptx)
1127}
1128
1129/// Generates PTX for eigenvalue extraction from the quasi-triangular form.
1130///
1131/// Reads diagonal and 2×2 blocks of (S, T) to compute (α_r, α_i, β).
1132///
1133/// # Arguments
1134///
1135/// * `n` — matrix dimension.
1136/// * `sm` — target SM version.
1137///
1138/// # Errors
1139///
1140/// Returns [`PtxGenError`] if the kernel cannot be generated.
1141pub fn generate_eigenvalue_extraction_ptx(n: u32, sm: SmVersion) -> Result<String, PtxGenError> {
1142    let name = format!("qz_eigenvalue_extract_{n}");
1143
1144    let ptx = KernelBuilder::new(&name)
1145        .target(sm)
1146        .max_threads_per_block(SOLVER_BLOCK_SIZE)
1147        .param("s_ptr", PtxType::U64)
1148        .param("t_ptr", PtxType::U64)
1149        .param("alpha_r_ptr", PtxType::U64)
1150        .param("alpha_i_ptr", PtxType::U64)
1151        .param("beta_ptr", PtxType::U64)
1152        .param("n_param", PtxType::U32)
1153        .body(move |b| {
1154            let tid = b.thread_id_x();
1155            let n_param = b.load_param_u32("n_param");
1156
1157            // Thread tid processes eigenvalue i = tid (if tid < n).
1158            // Check if (i, i+1) forms a 2×2 block by examining S[i+1, i].
1159            // If 1×1: α_r = S[i,i], α_i = 0, β = |T[i,i]|.
1160            // If 2×2: solve 2×2 generalized eigenvalue problem.
1161            let _ = (tid, n_param);
1162
1163            b.ret();
1164        })
1165        .build()?;
1166
1167    Ok(ptx)
1168}
1169
1170// ---------------------------------------------------------------------------
1171// Tests
1172// ---------------------------------------------------------------------------
1173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_balance_strategy_default() {
        let bs = BalanceStrategy::default();
        assert_eq!(bs, BalanceStrategy::Both);
    }

    #[test]
    fn test_shift_strategy_default() {
        let ss = ShiftStrategy::default();
        assert_eq!(ss, ShiftStrategy::FrancisDoubleShift);
    }

    #[test]
    fn test_qz_config_new() {
        let config = QzConfig::new(10, SmVersion::Sm80);
        assert_eq!(config.n, 10);
        assert!(!config.compute_schur_vectors);
        assert_eq!(config.balance, BalanceStrategy::Both);
        assert_eq!(config.max_iterations, 300);
        assert!((config.tolerance - 1e-14).abs() < 1e-20);
    }

    #[test]
    fn test_qz_config_builder() {
        let config = QzConfig::new(5, SmVersion::Sm90)
            .with_schur_vectors(true)
            .with_balance(BalanceStrategy::None)
            .with_max_iterations(500)
            .with_tolerance(1e-12);
        assert_eq!(config.n, 5);
        assert!(config.compute_schur_vectors);
        assert_eq!(config.balance, BalanceStrategy::None);
        assert_eq!(config.max_iterations, 500);
        assert!((config.tolerance - 1e-12).abs() < 1e-20);
    }

    #[test]
    fn test_validate_qz_config_valid() {
        let config = QzConfig::new(4, SmVersion::Sm80);
        assert!(validate_qz_config(&config).is_ok());
    }

    #[test]
    fn test_validate_qz_config_zero_n() {
        let config = QzConfig {
            n: 0,
            compute_schur_vectors: false,
            balance: BalanceStrategy::None,
            max_iterations: 100,
            tolerance: 1e-14,
            sm_version: SmVersion::Sm80,
        };
        let err = validate_qz_config(&config);
        assert!(err.is_err());
        assert!(matches!(err, Err(SolverError::DimensionMismatch(_))));
    }

    #[test]
    fn test_validate_qz_config_zero_tolerance() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_tolerance(0.0);
        assert!(validate_qz_config(&config).is_err());
    }

    #[test]
    fn test_validate_qz_config_zero_iterations() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_max_iterations(0);
        assert!(validate_qz_config(&config).is_err());
    }

    #[test]
    fn test_plan_qz_basic() {
        let config = QzConfig::new(4, SmVersion::Sm80);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        if let Ok(p) = &plan {
            assert!(p.steps.contains(&QzStep::HessenbergTriangularReduction));
            assert!(p.steps.contains(&QzStep::EigenvalueExtraction));
            // Should not have SchurVectorAccumulation since not requested.
            assert!(!p.steps.contains(&QzStep::SchurVectorAccumulation));
        }
    }

    #[test]
    fn test_plan_qz_with_vectors() {
        let config = QzConfig::new(4, SmVersion::Sm80).with_schur_vectors(true);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        if let Ok(p) = &plan {
            assert!(p.steps.contains(&QzStep::SchurVectorAccumulation));
        }
    }

    #[test]
    fn test_plan_qz_n1_no_iteration() {
        let config = QzConfig::new(1, SmVersion::Sm80);
        let plan = plan_qz(&config);
        assert!(plan.is_ok());
        if let Ok(p) = &plan {
            // n=1: no QZ iteration needed.
            let has_iter = p
                .steps
                .iter()
                .any(|s| matches!(s, QzStep::QzIteration { .. }));
            assert!(!has_iter, "n=1 should not have QzIteration step");
        }
    }

    #[test]
    fn test_estimate_qz_flops() {
        let flops_1 = estimate_qz_flops(1);
        assert!((flops_1 - 10.0).abs() < 1e-10);

        let flops_10 = estimate_qz_flops(10);
        assert!((flops_10 - 10_000.0).abs() < 1e-6);

        let flops_100 = estimate_qz_flops(100);
        assert!((flops_100 - 10_000_000.0).abs() < 1.0);
    }

    #[test]
    fn test_estimated_flops_via_plan() {
        let config = QzConfig::new(10, SmVersion::Sm80);
        let plan = plan_qz(&config);
        // Without this assert the test would pass vacuously if planning failed.
        assert!(plan.is_ok(), "plan_qz should accept a valid 10x10 config");
        if let Ok(plan) = plan {
            let flops = plan.estimated_flops();
            assert!((flops - 10_000.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_classify_eigenvalue_real() {
        let et = classify_eigenvalue(3.5, 0.0, 1.0);
        assert_eq!(et, EigenvalueType::Real);
    }

    #[test]
    fn test_classify_eigenvalue_complex() {
        let et = classify_eigenvalue(1.0, 2.0, 1.0);
        assert_eq!(et, EigenvalueType::ComplexPair);
    }

    #[test]
    fn test_classify_eigenvalue_infinite() {
        let et = classify_eigenvalue(1.0, 0.0, 0.0);
        assert_eq!(et, EigenvalueType::Infinite);
    }

    #[test]
    fn test_classify_eigenvalue_zero() {
        let et = classify_eigenvalue(0.0, 0.0, 1.0);
        assert_eq!(et, EigenvalueType::Zero);
    }

    #[test]
    fn test_classify_eigenvalue_zero_over_zero() {
        // 0/0 — indeterminate, should classify as Zero by convention.
        let et = classify_eigenvalue(0.0, 0.0, 0.0);
        assert_eq!(et, EigenvalueType::Zero);
    }

    #[test]
    fn test_qz_host_n1() {
        // A = [5], B = [2] => eigenvalue = 5/2 = 2.5
        let mut a = vec![5.0];
        let mut b = vec![2.0];
        let config = QzConfig::new(1, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 1);
            assert_eq!(r.beta.len(), 1);
            // eigenvalue = alpha_real / beta
            let eig = r.alpha_real[0] / r.beta[0];
            assert!(
                (eig - 2.5).abs() < 1e-10,
                "eigenvalue = {eig}, expected 2.5"
            );
        }
    }

    #[test]
    fn test_qz_host_n2_diagonal() {
        // A = diag(3, 7), B = diag(1, 2) => eigenvalues 3.0, 3.5
        let mut a = vec![3.0, 0.0, 0.0, 7.0]; // column-major
        let mut b = vec![1.0, 0.0, 0.0, 2.0];
        let config = QzConfig::new(2, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 2);
            assert_eq!(r.beta.len(), 2);
            // Verify we got two finite eigenvalues (beta != 0).
            for bt in &r.beta {
                assert!(bt.abs() > 1e-15, "beta should be nonzero");
            }
        }
    }

    #[test]
    fn test_qz_host_dimension_mismatch() {
        let mut a = vec![1.0, 2.0]; // too small for 2×2
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let config = QzConfig::new(2, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_err());
        assert!(matches!(result, Err(SolverError::DimensionMismatch(_))));
    }

    #[test]
    fn test_qz_host_with_schur_vectors() {
        let mut a = vec![2.0, 0.0, 0.0, 3.0];
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let config = QzConfig::new(2, SmVersion::Sm80).with_schur_vectors(true);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.q_matrix.is_some());
            assert!(r.z_matrix.is_some());
            assert!(r.schur_s.is_some());
            assert!(r.schur_t.is_some());
        }
    }

    #[test]
    fn test_generate_hessenberg_reduction_ptx() {
        let ptx = generate_hessenberg_reduction_ptx(4, SmVersion::Sm80);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_hessenberg_reduction_4"));
        }
    }

    #[test]
    fn test_generate_qz_sweep_ptx() {
        let ptx = generate_qz_sweep_ptx(8, SmVersion::Sm86);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_sweep_8"));
        }
    }

    #[test]
    fn test_generate_eigenvalue_extraction_ptx() {
        let ptx = generate_eigenvalue_extraction_ptx(4, SmVersion::Sm90);
        assert!(ptx.is_ok());
        if let Ok(code) = &ptx {
            assert!(code.contains("qz_eigenvalue_extract_4"));
        }
    }

    #[test]
    fn test_givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
        // Verify second component is zeroed.
        let zero = -sn * 3.0 + cs * 4.0;
        assert!(zero.abs() < 1e-10);
    }

    #[test]
    fn test_givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn test_identity_matrix() {
        let id = identity_matrix(3);
        assert_eq!(id.len(), 9);
        assert!((id[cm(0, 0, 3)] - 1.0).abs() < 1e-15);
        assert!((id[cm(1, 1, 3)] - 1.0).abs() < 1e-15);
        assert!((id[cm(2, 2, 3)] - 1.0).abs() < 1e-15);
        assert!(id[cm(0, 1, 3)].abs() < 1e-15);
        assert!(id[cm(1, 0, 3)].abs() < 1e-15);
    }

    #[test]
    fn test_column_major_indexing() {
        // cm(row, col, n) = col * n + row
        assert_eq!(cm(0, 0, 3), 0);
        assert_eq!(cm(1, 0, 3), 1);
        assert_eq!(cm(0, 1, 3), 3);
        assert_eq!(cm(2, 2, 3), 8);
    }

    #[test]
    fn test_extract_eigenvalues_diagonal() {
        // S = diag(2, 5, -1), T = diag(1, 2, 3)
        let n = 3;
        let mut s = vec![0.0; n * n];
        let mut t = vec![0.0; n * n];
        s[cm(0, 0, n)] = 2.0;
        s[cm(1, 1, n)] = 5.0;
        s[cm(2, 2, n)] = -1.0;
        t[cm(0, 0, n)] = 1.0;
        t[cm(1, 1, n)] = 2.0;
        t[cm(2, 2, n)] = 3.0;

        let (ar, ai, bt) = extract_eigenvalues(&s, &t, n);
        assert_eq!(ar.len(), 3);
        // eigenvalue 0: 2/1 = 2
        assert!((ar[0] / bt[0] - 2.0).abs() < 1e-10);
        // eigenvalue 1: 5/2 = 2.5
        assert!((ar[1] / bt[1] - 2.5).abs() < 1e-10);
        // eigenvalue 2: -1/3
        assert!((ar[2] / bt[2] - (-1.0 / 3.0)).abs() < 1e-10);
        // All imaginary parts should be zero.
        for &imag in &ai {
            assert!(imag.abs() < 1e-15);
        }
    }

    #[test]
    fn test_qz_host_n3_upper_triangular() {
        // Both A and B already upper triangular.
        // A = [[1,2,3],[0,4,5],[0,0,6]], B = [[1,1,1],[0,2,1],[0,0,3]]
        #[rustfmt::skip]
        let mut a = vec![
            1.0, 0.0, 0.0, // col 0
            2.0, 4.0, 0.0, // col 1
            3.0, 5.0, 6.0, // col 2
        ];
        #[rustfmt::skip]
        let mut b = vec![
            1.0, 0.0, 0.0, // col 0
            1.0, 2.0, 0.0, // col 1
            1.0, 1.0, 3.0, // col 2
        ];
        let config = QzConfig::new(3, SmVersion::Sm80);
        let result = qz_host(&mut a, &mut b, &config);
        assert!(result.is_ok());
        if let Ok(r) = &result {
            assert!(r.converged);
            assert_eq!(r.alpha_real.len(), 3);
        }
    }

    #[test]
    fn test_householder_vector_annihilates_column() {
        // Column 0 = [3, 0, 4]^T; the reflector must map it to [-5, 0, 0]^T.
        let n = 3;
        let mut m = vec![0.0; n * n];
        m[cm(0, 0, n)] = 3.0;
        m[cm(2, 0, n)] = 4.0;
        m[cm(1, 1, n)] = 1.0;
        m[cm(2, 2, n)] = 1.0;
        let (v, tau) = householder_vector(&m, 0, 0, n, n);
        assert_eq!(v.len(), 3);
        apply_householder_left(&mut m, &v, tau, 0, n, 0, n, n);
        assert!((m[cm(0, 0, n)] + 5.0).abs() < 1e-12);
        assert!(m[cm(1, 0, n)].abs() < 1e-12);
        assert!(m[cm(2, 0, n)].abs() < 1e-12);
    }

    #[test]
    fn test_householder_vector_empty_range() {
        let m = identity_matrix(2);
        let (v, tau) = householder_vector(&m, 2, 0, 2, 2);
        assert!(v.is_empty());
        assert!(tau.abs() < 1e-300);
    }

    #[test]
    fn test_apply_householder_right_annihilates_row() {
        // Row 0 = [3, 4]; reflecting the columns maps it to [-5, 0].
        let n = 2;
        let mut m = vec![3.0, 0.0, 4.0, 1.0]; // column-major [[3, 4], [0, 1]]
        let (v, tau) = householder_from_vec(&[3.0, 4.0]);
        apply_householder_right(&mut m, &v, tau, 0, n, 0, n, n);
        assert!((m[cm(0, 0, n)] + 5.0).abs() < 1e-12);
        assert!(m[cm(0, 1, n)].abs() < 1e-12);
    }

    #[test]
    fn test_apply_givens_left_zeroes_lower_entry() {
        let n = 2;
        let mut m = vec![3.0, 4.0, 1.0, 1.0]; // column-major [[3, 1], [4, 1]]
        let (cs, sn) = givens_rotation(3.0, 4.0);
        apply_givens_left(&mut m, cs, sn, 0, 1, n, n);
        assert!((m[cm(0, 0, n)] - 5.0).abs() < 1e-10);
        assert!(m[cm(1, 0, n)].abs() < 1e-10);
    }

    #[test]
    fn test_apply_givens_right_cols_zeroes_second_column() {
        let n = 2;
        let mut m = vec![3.0, 1.0, 4.0, 2.0]; // column-major [[3, 4], [1, 2]]
        let r = 5.0;
        apply_givens_right_cols(&mut m, 3.0 / r, -4.0 / r, 0, 1, n, n);
        assert!((m[cm(0, 0, n)] - 5.0).abs() < 1e-12);
        assert!(m[cm(0, 1, n)].abs() < 1e-12);
    }
}