gam_solve/gpu/arrow_schur_gpu.rs
1//! Caller-facing thin wrapper around `crate::gpu_kernels::arrow_schur`.
2//!
//! The entire dense per-row factor + Schur reduce + back-sub pipeline lives
//! device-side. This module translates the device failure enum into the
//! `ArrowSchurError` variants the PIRLS outer loop already understands, or
//! reroutes the solve to the CPU solver when the device cannot take it. As a
//! result, call-sites do not need to learn the device-specific reason codes.
7//!
8//! ## Dispatch logic for matrix-free systems
9//!
//! When `solve_arrow_newton_step` returns `GpuRequiresDenseSystem`, the GPU
//! dense-Schur path is structurally incompatible with the supplied operators.
//! This wrapper re-solves such systems with `InexactPCG`, the mode designed
//! precisely for SAE-manifold scale callers that cannot materialise a dense
//! `K × K` block. The PCG matvec is GPU-backed when `gpu_schur_matvec_backend`
//! can build one; otherwise it falls back to the pure-CPU matvec.
//!
//! No information is lost. `GpuRequiresDenseSystem` is not a numerical
//! failure, just a capability mismatch, so the solver receives the full system
//! without escalating any ridge.
17
18use crate::gpu_kernels::arrow_schur::{
19 ArrowSchurGpuFailure, gpu_schur_matvec_backend, solve_arrow_newton_step,
20};
21use crate::arrow_schur::{
22 ArrowSchurError, ArrowSchurSystem, ArrowSolveOptions, ArrowSolverMode,
23};
24use ndarray::Array1;
25
26pub fn solve_arrow_newton_step_gpu(
27 sys: &ArrowSchurSystem,
28 ridge_t: f64,
29 ridge_beta: f64,
30) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
31 match solve_arrow_newton_step(sys, ridge_t, ridge_beta) {
32 Ok(solution) => Ok((solution.delta_t, solution.delta_beta)),
33 Err(ArrowSchurGpuFailure::Unavailable) => {
34 // Mirror the CPU path's failure variant so the outer loop falls
35 // through to its existing recovery logic.
36 sys.solve_with_options(ridge_t, ridge_beta, &ArrowSolveOptions::automatic(sys.k))
37 .map(|(dt, db, _diag)| (dt, db))
38 }
39 Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. }) => {
40 // Matrix-free H_ββ or H_tβ operators present — the dense GPU Schur
41 // path cannot consume them, but the reduced K-system PCG can.
42 // Build the GPU-backed reduced Schur matvec (row-procedural sparse
43 // Kronecker apply over active atoms; per-row latent eliminated via
44 // cached factors) and run `InexactPCG` against it. Only when the
45 // device matvec is genuinely `Unavailable` do we fall back to the
46 // pure-CPU `InexactPCG` matvec.
47 let mut opts = ArrowSolveOptions::automatic(sys.k);
48 opts.mode = ArrowSolverMode::InexactPCG;
49 match gpu_schur_matvec_backend(sys, ridge_t, ridge_beta) {
50 Ok(gpu_matvec) => {
51 opts.gpu_matvec = Some(gpu_matvec);
52 }
53 Err(ArrowSchurGpuFailure::Unavailable) => {
54 // No device matvec available; CPU InexactPCG owns the solve.
55 }
56 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
57 return Err(ArrowSchurError::PerRowFactorFailed {
58 row,
59 reason: format!(
60 "GPU row-procedural factor failed; suggested ridge bump {bump:.3e}"
61 ),
62 });
63 }
64 Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. }) => {
65 // The matvec builder cannot lift this system either; CPU
66 // InexactPCG matvec handles the reduction.
67 }
68 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
69 return Err(ArrowSchurError::SchurFactorFailed { reason });
70 }
71 }
72 sys.solve_with_options(ridge_t, ridge_beta, &opts)
73 .map(|(dt, db, _diag)| (dt, db))
74 }
75 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
76 Err(ArrowSchurError::PerRowFactorFailed {
77 row,
78 reason: format!("GPU Cholesky factor failed; suggested ridge bump {bump:.3e}"),
79 })
80 }
81 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
82 Err(ArrowSchurError::SchurFactorFailed { reason })
83 }
84 }
85}