Skip to main content

gam_solve/gpu/
arrow_schur_gpu.rs

//! Caller-facing thin wrapper around `crate::gpu_kernels::arrow_schur`.
//!
//! The entire dense per-row factor + Schur reduce + back-sub pipeline lives
//! device-side. This module picks a fallback solver when the device path cannot
//! run, and translates the device failure enum into the `ArrowSchurError`
//! variants the PIRLS outer loop already understands, so call-sites do not need
//! to learn the device-specific reason codes. When no device is available at
//! all, the step is solved on the CPU with the automatic solver options.
//!
//! ## Dispatch logic for matrix-free systems
//!
//! When `solve_arrow_newton_step` returns `GpuRequiresDenseSystem`, the GPU
//! dense-Schur path is structurally incompatible with the supplied operators.
//! This wrapper routes such systems to `InexactPCG` — the mode that was
//! designed precisely for SAE-manifold scale callers that cannot materialise
//! a dense `K × K` block. The PCG uses a GPU-backed reduced Schur matvec when
//! `gpu_schur_matvec_backend` can build one, and the CPU matvec otherwise. No
//! information is lost: `GpuRequiresDenseSystem` is not a numerical failure,
//! just a capability mismatch, so the solver receives the full system without
//! escalating any ridge.
17
18use crate::gpu_kernels::arrow_schur::{
19    ArrowSchurGpuFailure, gpu_schur_matvec_backend, solve_arrow_newton_step,
20};
21use crate::arrow_schur::{
22    ArrowSchurError, ArrowSchurSystem, ArrowSolveOptions, ArrowSolverMode,
23};
24use ndarray::Array1;
25
26pub fn solve_arrow_newton_step_gpu(
27    sys: &ArrowSchurSystem,
28    ridge_t: f64,
29    ridge_beta: f64,
30) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
31    match solve_arrow_newton_step(sys, ridge_t, ridge_beta) {
32        Ok(solution) => Ok((solution.delta_t, solution.delta_beta)),
33        Err(ArrowSchurGpuFailure::Unavailable) => {
34            // Mirror the CPU path's failure variant so the outer loop falls
35            // through to its existing recovery logic.
36            sys.solve_with_options(ridge_t, ridge_beta, &ArrowSolveOptions::automatic(sys.k))
37                .map(|(dt, db, _diag)| (dt, db))
38        }
39        Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. }) => {
40            // Matrix-free H_ββ or H_tβ operators present — the dense GPU Schur
41            // path cannot consume them, but the reduced K-system PCG can.
42            // Build the GPU-backed reduced Schur matvec (row-procedural sparse
43            // Kronecker apply over active atoms; per-row latent eliminated via
44            // cached factors) and run `InexactPCG` against it. Only when the
45            // device matvec is genuinely `Unavailable` do we fall back to the
46            // pure-CPU `InexactPCG` matvec.
47            let mut opts = ArrowSolveOptions::automatic(sys.k);
48            opts.mode = ArrowSolverMode::InexactPCG;
49            match gpu_schur_matvec_backend(sys, ridge_t, ridge_beta) {
50                Ok(gpu_matvec) => {
51                    opts.gpu_matvec = Some(gpu_matvec);
52                }
53                Err(ArrowSchurGpuFailure::Unavailable) => {
54                    // No device matvec available; CPU InexactPCG owns the solve.
55                }
56                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
57                    return Err(ArrowSchurError::PerRowFactorFailed {
58                        row,
59                        reason: format!(
60                            "GPU row-procedural factor failed; suggested ridge bump {bump:.3e}"
61                        ),
62                    });
63                }
64                Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. }) => {
65                    // The matvec builder cannot lift this system either; CPU
66                    // InexactPCG matvec handles the reduction.
67                }
68                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
69                    return Err(ArrowSchurError::SchurFactorFailed { reason });
70                }
71            }
72            sys.solve_with_options(ridge_t, ridge_beta, &opts)
73                .map(|(dt, db, _diag)| (dt, db))
74        }
75        Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
76            Err(ArrowSchurError::PerRowFactorFailed {
77                row,
78                reason: format!("GPU Cholesky factor failed; suggested ridge bump {bump:.3e}"),
79            })
80        }
81        Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
82            Err(ArrowSchurError::SchurFactorFailed { reason })
83        }
84    }
85}