Skip to main content

gam_solve/gpu_kernels/
arrow_schur.rs

1//! Fully GPU-resident batched Arrow-Schur dense Cholesky solver.
2//!
3//! Implements the square-root Schur form: each local block `D_i = L_i L_i^T`
4//! is factored on device, `u_i = L_i^{-1} g_i` and `Y_i = L_i^{-1} B_i` are
5//! formed by triangular solves, the reduced shared system
6//!     `S_β = C + ρ_β I − Σ_i Y_i^T Y_i,  r_β = -g_β + Σ_i Y_i^T u_i`
7//! is assembled on device, factored once, and the back-substitution
8//!     `w_i = u_i + Y_i · δβ,  L_i^T x_i = w_i,  δt_i = -x_i`
9//! is run on device. Only the final `(δt, δβ, log|H|)` triple is downloaded.
10//!
11//! The current caller (Arrow-Schur Newton step inside PIRLS) feeds uniform
12//! local block size `d` and uniform shared width `k`, so the entire pipeline
13//! is dispatched as a single p-group; per-p grouping for heterogenous blocks
14//! is Layer D's NVRTC fused-kernel concern and lives in this module's
15//! follow-up implementation rather than in policy plumbing.
16//!
17//! On non-Linux builds the entire module degrades to a CPU-fallback shim.
18
19use ndarray::{Array1, Array2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24/// Outcome of a single Arrow-Schur Newton solve.
25pub struct ArrowSchurGpuSolution {
26    pub delta_t: Array1<f64>,
27    pub delta_beta: Array1<f64>,
28    /// Natural log of the determinant of the full bordered Hessian, computed
29    /// from the local Cholesky factors and the Schur factor on device.
30    pub log_det_hessian: f64,
31}
32
33/// Reason a device path declined to run; lets the host caller decide between
34/// CPU fallback and per-row escalation. `RidgeBumpRequired` carries the
35/// estimated diagonal bump needed to clear the failed pivot.
36#[derive(Debug, Clone)]
37pub enum ArrowSchurGpuFailure {
38    /// CUDA runtime unavailable, allocation failed, or workload below policy.
39    Unavailable,
40    /// A row block was not positive definite even after the requested ridge.
41    /// Caller may retry with `ridge_t + bump`.
42    RidgeBumpRequired { row: usize, bump: f64 },
43    /// Shared Schur factor failed; bordered system is rank-deficient at the
44    /// requested ridges and the CPU path should handle escalation.
45    SchurFactorFailed { reason: String },
46    /// The system carries matrix-free `H_ββ` or per-row `H_tβ` operators that
47    /// the dense GPU Schur path cannot consume. The caller should route to CPU
48    /// `InexactPCG` (or supply dense buffers) rather than treating this as a
49    /// numerical failure. See `gpu/arrow_schur.rs` Part B for the planned GPU
50    /// PCG path that will lift this restriction at K ≥ 5000.
51    GpuRequiresDenseSystem {
52        had_hbb_matvec: bool,
53        had_htbeta_matvec: bool,
54    },
55}
56
57/// Safety-margin multiplier on `√(machine ε)` for the diagonal ridge bump
58/// suggested when a local block fails Cholesky.
59///
60/// The estimated bump is `diag_scale · |pivot| · √ε · RIDGE_BUMP_EPS_MARGIN`.
61/// A bare `diag_scale · √ε` ridge is the smallest perturbation that makes a
62/// marginally-indefinite block PD in exact arithmetic, but a single retry at
63/// that magnitude is routinely re-rejected by the next POTRF because the
64/// rounding error of forming `D + ridge·I` and re-factoring is itself O(√ε).
65/// The 1024× headroom (≈ 2¹⁰, i.e. ten extra bits below the f64 mantissa's
66/// 52) clears the pivot on the first retry without materially perturbing the
67/// curvature the Newton step sees. Shared by the per-row scalar path and the
68/// batched-tile path so both suggest an identical bump.
69const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
70
71/// Entry point: attempt the fully device-resident Arrow-Schur Newton solve.
72/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` to indicate "device path
73/// declined, fall back to CPU" — never panics.
74pub fn solve_arrow_newton_step(
75    sys: &ArrowSchurSystem,
76    ridge_t: f64,
77    ridge_beta: f64,
78) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
79    let n = sys.rows.len();
80    let d = sys.d;
81    let k = sys.k;
82
83    // Detect matrix-free operators before any dim() checks so callers get a
84    // clear, actionable error instead of a generic SchurFactorFailed. The GPU
85    // dense-Schur path requires materialised H_ββ and per-row H_tβ slabs;
86    // CPU InexactPCG is the correct fallback when either operator is abstract.
87    let had_hbb_matvec = sys.hbb_matvec.is_some();
88    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
89    if had_hbb_matvec || had_htbeta_matvec {
90        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
91            had_hbb_matvec,
92            had_htbeta_matvec,
93        });
94    }
95
96    if sys.hbb.dim() != (k, k) {
97        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
98            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
99        });
100    }
101    if n == 0 || d == 0 {
102        return Err(ArrowSchurGpuFailure::Unavailable);
103    }
104    if sys
105        .rows
106        .iter()
107        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
108    {
109        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
110            reason: "row block dimension mismatch".to_string(),
111        });
112    }
113
114    #[cfg(not(target_os = "linux"))]
115    {
116        if ridge_t.is_nan() || ridge_beta.is_nan() {
117            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
118                reason: "ridge is NaN".to_string(),
119            });
120        }
121        Err(ArrowSchurGpuFailure::Unavailable)
122    }
123
124    #[cfg(target_os = "linux")]
125    {
126        // Multi-GPU: the arrow-Schur solve is row-block separable in its forward
127        // (per-row factor / whiten / partial-Schur) and backward (per-row
128        // back-sub) phases — only the small shared K×K reduce+factor+δβ is
129        // central. When more than one device is usable, split the WHOLE solve at
130        // row-block granularity across all GPUs. The POTRF stays fused with its
131        // dependent TRSM+GEMM on each tile's own stream, so no on-stream solve is
132        // orphaned. On `Unavailable` (one device, shape below policy, transient)
133        // fall through to the single-device fused / Layer-A paths below.
134        if gam_gpu::device_runtime::GpuRuntime::global()
135            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
136            .unwrap_or(0)
137            > 1
138        {
139            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
140                Ok(sol) => return Ok(sol),
141                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
142                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
143                }
144                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
145                    return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
146                }
147                // Unavailable / GpuRequiresDenseSystem: fall through to the
148                // single-device paths (already shape-validated above).
149                Err(_) => {}
150            }
151        }
152        // Layer D admission: when the system shape passes the
153        // (Σ p³ ≥ 1e5 OR R ≥ 16) heuristic and `p ≤ MAX_FUSED_P`, the fused
154        // NVRTC kernel replaces the cuSOLVER/cuBLAS Layer A+B+C path with a
155        // single per-row block. Layer C↔D parity (math block 3 §16 test 6)
156        // requires both paths to agree to 1e-10 on identical inputs.
157        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
158            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
159                Ok(sol) => return Ok(sol),
160                // RidgeBumpRequired must surface to the outer escalation loop —
161                // the fused path's pivot diagnostic is identical in semantics
162                // to the cuSOLVER batched POTRF info code.
163                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
164                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
165                }
166                // Any other failure (Unavailable, SchurFactorFailed) falls
167                // through to the unfused path so a flaky NVRTC compile or
168                // shared-mem allocation does not abort the outer Newton step.
169                Err(_) => {}
170            }
171        }
172        cuda::solve(sys, ridge_t, ridge_beta)
173    }
174}
175
176/// Build the stacked column-major D buffer (n local d×d blocks), the stacked
177/// stacked B buffer (n local d×k blocks), and the stacked g buffer
178/// (n local d-vectors) consumed by the device pipeline. Each block is laid
179/// out column-major so a single allocation + `cuMemcpyHtoD` reaches the
180/// device without per-row dispatch overhead.
181#[cfg(target_os = "linux")]
182fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
183    let n = sys.rows.len();
184    let d = sys.d;
185    let k = sys.k;
186    let mut d_buf = Vec::with_capacity(n * d * d);
187    let mut b_buf = Vec::with_capacity(n * d * k);
188    let mut g_buf = Vec::with_capacity(n * d);
189    for row in &sys.rows {
190        pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
191    }
192    (d_buf, b_buf, g_buf)
193}
194
195#[cfg(target_os = "linux")]
196#[inline]
197fn pack_block(
198    row: &crate::arrow_schur::ArrowRowBlock,
199    ridge_t: f64,
200    d: usize,
201    k: usize,
202    d_buf: &mut Vec<f64>,
203    b_buf: &mut Vec<f64>,
204    g_buf: &mut Vec<f64>,
205) {
206    for col in 0..d {
207        for r in 0..d {
208            let mut value = row.htt[[r, col]];
209            if r == col {
210                value += ridge_t;
211            }
212            d_buf.push(value);
213        }
214    }
215    for col in 0..k {
216        for r in 0..d {
217            b_buf.push(row.htbeta[[r, col]]);
218        }
219    }
220    for r in 0..d {
221        g_buf.push(row.gt[r]);
222    }
223}
224
225/// Test-only entry that forces the Layer D + E fused NVRTC path regardless
226/// of the admission heuristic. Used by the V100 Layer C↔D parity test to
227/// drive the fused kernel at small shapes the heuristic would otherwise
228/// route through the cuSOLVER/cuBLAS Layer A+B+C path.
229#[doc(hidden)]
230#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] // `sys` is consumed only by the linux branch
231pub fn solve_arrow_newton_step_fused_force(
232    sys: &ArrowSchurSystem,
233    ridge_t: f64,
234    ridge_beta: f64,
235) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
236    if ridge_t.is_nan() || ridge_beta.is_nan() {
237        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
238            reason: "ridge is NaN".to_string(),
239        });
240    }
241    #[cfg(not(target_os = "linux"))]
242    {
243        // No NVRTC toolchain off linux: the fused path is unconditionally
244        // unavailable. `sys` is consumed only by the linux branch below; the
245        // fn-level cfg_attr allows it to read as unused here without a banned
246        // `let _` binding or a no-op `drop` of the reference.
247        Err(ArrowSchurGpuFailure::Unavailable)
248    }
249    #[cfg(target_os = "linux")]
250    {
251        if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
252            .is_none()
253        {
254            return Err(ArrowSchurGpuFailure::Unavailable);
255        }
256        cuda::solve_fused(sys, ridge_t, ridge_beta)
257    }
258}
259
260/// #1017 Phase 3: a device-resident Arrow-Schur frame whose constant Hessian
261/// blocks (`D = H_tt`, `B = H_tβ`, border `H_ββ`) and their factors stay on the
262/// device across the inner Newton loop. Construct once per frozen gate/basis
263/// frame, then call [`ResidentArrowFrameHandle::solve_gradient`] once per
264/// iterate with the fresh residual gradient — only the `O(n·d + p)` gradient
265/// crosses to the device and only `δ` crosses back, in contrast to
266/// [`solve_arrow_newton_step`] which re-uploads and re-factors the full system
267/// every call. On a non-CUDA host construction returns
268/// `ArrowSchurGpuFailure::Unavailable`.
269pub struct ResidentArrowFrameHandle {
270    #[cfg(target_os = "linux")]
271    inner: cuda::ResidentArrowFrame,
272    #[cfg(not(target_os = "linux"))]
273    _never: std::convert::Infallible,
274}
275
276impl ResidentArrowFrameHandle {
277    /// Upload the constant Hessian blocks and perform the one-time factor work.
278    pub fn new(
279        sys: &ArrowSchurSystem,
280        ridge_t: f64,
281        ridge_beta: f64,
282    ) -> Result<Self, ArrowSchurGpuFailure> {
283        // The dense device path requires materialised blocks, same admission as
284        // `solve_arrow_newton_step`.
285        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
286            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
287                had_hbb_matvec: sys.hbb_matvec.is_some(),
288                had_htbeta_matvec: sys.htbeta_matvec.is_some(),
289            });
290        }
291        #[cfg(not(target_os = "linux"))]
292        {
293            if ridge_t.is_nan() || ridge_beta.is_nan() {
294                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
295                    reason: "ridge is NaN".to_string(),
296                });
297            }
298            Err(ArrowSchurGpuFailure::Unavailable)
299        }
300        #[cfg(target_os = "linux")]
301        {
302            Ok(Self {
303                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
304            })
305        }
306    }
307
308    /// Solve `H δ = −gradient` for a fresh gradient reusing the resident factors.
309    pub fn solve_gradient(
310        &self,
311        g_t: &[f64],
312        g_beta: &[f64],
313    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
314        #[cfg(not(target_os = "linux"))]
315        {
316            if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
317                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
318                    reason: "non-finite gradient entry".to_string(),
319                });
320            }
321            Err(ArrowSchurGpuFailure::Unavailable)
322        }
323        #[cfg(target_os = "linux")]
324        {
325            self.inner.solve_gradient(g_t, g_beta)
326        }
327    }
328
329    /// `log|H|` for the frame (constant; depends only on the factored Hessian).
330    #[must_use]
331    pub fn log_det_hessian(&self) -> f64 {
332        #[cfg(not(target_os = "linux"))]
333        {
334            // SAFETY: off-CUDA, `ResidentArrowFrameHandle::new` always returns
335            // `Err(Unavailable)`, so no handle of this type is ever constructed and
336            // this method is statically unreachable on non-Linux targets. A NaN
337            // sentinel would silently corrupt any consumer of the log-determinant,
338            // so fail loudly on the impossible path instead.
339            panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
340        }
341        #[cfg(target_os = "linux")]
342        {
343            self.inner.log_det_hessian()
344        }
345    }
346}
347
348/// Build a GPU-backed Schur matvec closure for CPU-driven PCG at K ≥ 5000.
349///
350/// Runs the fused NVRTC forward kernel once on the dense per-row `H_tβ` slabs
351/// to compute `Y_i = L_i^{-1} H_tβ^(i)` for all rows, persists the `Y_i`
352/// factors in a host-side buffer, and returns an `Arc<dyn Fn(...)>` closure
353/// that computes the full Schur matvec
354///
355/// ```text
356/// S·x = (H_ββ + ridge_beta·I)·x  −  Σ_i Y_i^T (Y_i·x)
357/// ```
358///
359/// each time it is called. At K ≥ 5000 the `Σ_i Y_i^T (Y_i·x)` term
360/// dominates over the host↔device transfer of the K-vector `x`, so the GPU
361/// path is a clear win even with per-iteration transfer.
362///
363/// `H_ββ·x` is evaluated on the CPU using `sys.hbb_matvec` when present (the
364/// matrix-free hook for SAE-manifold scale callers) or the dense `sys.hbb`
365/// block otherwise. The `Y_i` term uses cuBLAS batched GEMV device-side; only
366/// `x` (K doubles) and `out` (K doubles) cross the host↔device boundary per
367/// PCG iteration.
368///
369/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` if CUDA is unavailable or
370/// the system shape is outside the fused kernel's admission range (e.g.
371/// `d > MAX_FUSED_P = 32` or no CUDA context). Callers should fall back to CPU
372/// `InexactPCG` on `Unavailable`.
373///
374/// Returns `Err(ArrowSchurGpuFailure::RidgeBumpRequired)` if a per-row Cholesky
375/// factor failed at the requested `ridge_t`; the outer LM escalation should
376/// bump `ridge_t` and retry.
377///
378/// # Composition with the matrix-free SAE Kronecker operator
379///
380/// When `sys.htbeta_matvec` is set (matrix-free `H_tβ` Kronecker operator),
381/// the dense `H_tβ` slabs are absent — the dense forward kernel above cannot
382/// run, and at `K = 100K` the dense `Y_i = L_i^{-1} H_tβ^(i)` (`d × K` per row)
383/// could not be materialised anyway. Instead, `build_row_procedural_matvec`
384/// returns a row-procedural Schur matvec: per row it gathers
385/// `v_i = H_tβ^(i)·x` through the forward operator (sparse `O(m_i · p)`),
386/// solves `(H_tt^(i) + ρ_t·I)^{-1} v_i` through a pre-computed per-row Cholesky
387/// factor, and scatters `H_βt^(i)·w_i` through the sparse transpose operator
388/// (`O(m_i · p)`, replacing the old `O(K)` column-probe). This is the
389/// row-procedural `a_ik · Φ_k[i,m]` Kronecker apply over the active atoms only.
390pub fn gpu_schur_matvec_backend(
391    sys: &ArrowSchurSystem,
392    ridge_t: f64,
393    ridge_beta: f64,
394) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
395    // Matrix-free H_tβ operator present: drive the row-procedural sparse
396    // Kronecker apply (active atoms only) instead of the dense forward kernel.
397    if sys.htbeta_matvec.is_some() {
398        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
399    }
400
401    #[cfg(not(target_os = "linux"))]
402    {
403        // No CUDA runtime on non-Linux. NaN ridges are validated to ensure the
404        // same contract as the Linux path.
405        if ridge_t.is_nan() || ridge_beta.is_nan() {
406            return Err(ArrowSchurGpuFailure::Unavailable);
407        }
408        Err(ArrowSchurGpuFailure::Unavailable)
409    }
410
411    #[cfg(target_os = "linux")]
412    {
413        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
414    }
415}
416
417/// Build a row-procedural reduced-Schur matvec for matrix-free SAE Kronecker
418/// systems, eliminating the per-row latent block via cached per-row Cholesky
419/// factors and applying the cross-block through the sparse forward/transpose
420/// Kronecker operators (active atoms only).
421///
422/// The returned closure evaluates
423/// `S·x = (H_ββ + ρ_β·I)·x − Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x`,
424/// the same reduced Schur complement the dense path forms, but never
425/// materialises the `d × K` cross-block `H_tβ^(i)`: the forward operator
426/// (`out = H_tβ^(i)·x`) and transpose operator (`out += H_βt^(i)·v`) are the
427/// sparse Kronecker gather/scatter from `SaeKroneckerRows`. The per-row factor
428/// of `H_tt^(i) + ρ_t·I` is computed once when the closure is built and reused
429/// across every CG iteration.
430///
431/// Returns `RidgeBumpRequired` if a per-row block is not positive definite at
432/// the requested `ridge_t`; the outer LM escalation bumps `ridge_t` and retries.
433fn build_row_procedural_matvec(
434    sys: &ArrowSchurSystem,
435    ridge_t: f64,
436    ridge_beta: f64,
437) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
438    use std::sync::Arc;
439    let n = sys.rows.len();
440    let k = sys.k;
441    let forward = sys
442        .htbeta_matvec
443        .clone()
444        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
445    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
446        // A forward operator without its sparse adjoint cannot be applied
447        // row-procedurally; this is a wiring error, surfaced as a Schur failure
448        // so the caller routes to the dense CPU path rather than misreporting a
449        // numerical bump.
450        ArrowSchurGpuFailure::SchurFactorFailed {
451            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
452                     forward operator installed without its sparse adjoint"
453                .to_string(),
454        }
455    })?;
456
457    // Pre-factor each per-row block H_tt^(i) + ρ_t·I = L_i L_iᵀ on the host.
458    // The blocks are tiny (d_i ≲ 32) and the dense cross-block slabs are
459    // absent, so there is no device forward-kernel work to amortise here; the
460    // GPU win is the reduced K-system solve in `solve_reduced_beta_pcg`.
461    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
462    for (i, row) in sys.rows.iter().enumerate() {
463        let di = row.htt.nrows();
464        if row.htt.ncols() != di || row.gt.len() != di {
465            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
466                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
467            });
468        }
469        let mut block = row.htt.clone();
470        for r in 0..di {
471            block[[r, r]] += ridge_t;
472        }
473        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
474            .ok_or_else(|| {
475                let scale = row
476                    .htt
477                    .diag()
478                    .iter()
479                    .map(|v| v.abs())
480                    .fold(0.0_f64, f64::max)
481                    .max(1.0);
482                ArrowSchurGpuFailure::RidgeBumpRequired {
483                    row: i,
484                    bump: scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN,
485                }
486            })?;
487        factors.push(factor);
488    }
489
490    // The SAE-manifold β-Hessian lives in the structured penalty operator
491    // (data-fit Gauss-Newton `G ⊗ I_p` + smoothness Kronecker blocks + any
492    // dense analytic-β residual), NOT in the dense `hbb` accumulator — for
493    // matrix-free systems `hbb` is zero/absent. Capture the effective penalty
494    // operator so `H_ββ·x` matches the CPU `schur_matvec` path exactly. The
495    // operator's `matvec` adds (`y += P x`), so seed `out` from the ridge term.
496    let penalty_op = sys.effective_penalty_op();
497    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
498
499    let closure: crate::arrow_schur::GpuSchurMatvec =
500        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
501            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
502            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
503
504            // (H_ββ + ρ_β·I)·x into out. Seed with the ridge term, then add the
505            // structured penalty-side product (penalty_op.matvec is additive).
506            {
507                let x_slice = x.as_slice().expect("x must be contiguous");
508                let out_slice = out.as_slice_mut().expect("out must be contiguous");
509                for a in 0..k {
510                    out_slice[a] = ridge_beta * x_slice[a];
511                }
512                penalty_op.matvec(x_slice, out_slice);
513            }
514
515            // out -= Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x.
516            //
517            // #1017: this row-procedural reduced-Schur term is the matrix-free
518            // SAE path's matvec hot loop (`build_row_procedural_matvec` is the
519            // host backend `gpu_schur_matvec_backend` returns when the dense
520            // `H_tβ` slabs are absent — the production Qwen shape). At
521            // (n≈2000 rows) it ran SERIALLY on one core and allocated a fresh
522            // length-`K` `neg` plus per-row `v_i`/`w_i` on EVERY CG iteration —
523            // tens of thousands of tiny heap allocations across a solve. Each
524            // row contributes an independent length-`K` scatter, so the sum is
525            // embarrassingly parallel; fan it across rayon over fixed row chunks
526            // and fold the per-chunk length-`K` partials in chunk order so the
527            // f64 reduction is deterministic (bit-identical run-to-run)
528            // regardless of thread scheduling — it agrees with the serial sum up
529            // to ULP-scale chunk reassociation (the #1017 verification gate).
530            // Because that reassociation is a real (if tiny) departure from
531            // serial, the criterion ranking across topology candidates is stable
532            // except for candidates separated by less than the reassociation
533            // margin, where the near-tie winner can flip — not an exact no-move
534            // guarantee (#1211). Stay
535            // sequential below
536            // `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and when already inside a
537            // rayon worker (the topology race fans candidates with
538            // `run_topology_race_parallel`) — the same nested-rayon guard the
539            // CPU `schur_matvec` uses. Buffers (`v_i`, `neg`) are reused across
540            // rows within a chunk, so the per-row allocation churn is gone.
541            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
542                && rayon::current_thread_index().is_none();
543            if parallel {
544                use rayon::prelude::*;
545                const CHUNK: usize = 64;
546                let partials: Vec<Array1<f64>> = (0..n)
547                    .into_par_iter()
548                    .chunks(CHUNK)
549                    .map(|idxs| {
550                        // One length-`K` scatter accumulator per chunk; the
551                        // per-row latent vector `v_i` (length `d_i ≲ 32`) is the
552                        // only per-row buffer, sized to the row's own `d_i`.
553                        let mut neg = Array1::<f64>::zeros(k);
554                        for i in idxs {
555                            let di = row_dims[i];
556                            // v_i = H_tβ^(i)·x (sparse Kronecker gather).
557                            let mut v_i = Array1::<f64>::zeros(di);
558                            forward(i, x.view(), &mut v_i);
559                            // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
560                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
561                            // neg += H_βt^(i)·w_i (sparse scatter).
562                            transpose(i, w_i.view(), &mut neg);
563                        }
564                        neg
565                    })
566                    .collect();
567                // #1017/#1175 floating-point parity contract: each chunk's row
568                // sum is formed locally, then chunk partials are folded
569                // left-to-right. That makes the parallel row-procedural Schur
570                // term deterministic for a fixed input and chunking, but it is
571                // not required to be bit-identical to the serial path because
572                // the additions are reassociated at chunk boundaries. CPU/GPU
573                // validation should therefore allow ULP-scale drift while
574                // expecting stable run-to-run results.
575                let mut neg = Array1::<f64>::zeros(k);
576                for part in &partials {
577                    for a in 0..k {
578                        neg[a] += part[a];
579                    }
580                }
581                for a in 0..k {
582                    out[a] -= neg[a];
583                }
584            } else {
585                // Serial path: reuse one `neg` and one `v_i` across rows.
586                let mut neg = Array1::<f64>::zeros(k);
587                for i in 0..n {
588                    let di = row_dims[i];
589                    // v_i = H_tβ^(i)·x (sparse Kronecker gather, length d_i).
590                    let mut v_i = Array1::<f64>::zeros(di);
591                    forward(i, x.view(), &mut v_i);
592                    // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
593                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
594                    // neg += H_βt^(i)·w_i (sparse scatter); subtract once at end.
595                    transpose(i, w_i.view(), &mut neg);
596                }
597                for a in 0..k {
598                    out[a] -= neg[a];
599                }
600            }
601        });
602
603    Ok(closure)
604}
605
606/// Solve the reduced shared β-system `S·δβ = r` fully on device with a
607/// Jacobi-preconditioned conjugate-gradient (Steihaug truncated-CG) loop.
608///
609/// `S` is the already-reduced symmetric positive-definite `K × K` Schur
610/// complement the streaming SAE joint fit accumulates across minibatches
611/// (`StreamingArrowSchur::take_accumulators` summed over chunks, with the
612/// global β ridge folded in). The per-row latent blocks have already been
613/// eliminated into `S` on the host streaming path; the device's job is the
614/// dense `K`-dimensional solve, which is the dominant cost at `K = 100K`.
615///
616/// The dense `S·p` matvec runs on device via cuBLAS `Dgemv`, and the PCG state
617/// vectors (`x`, `r`, `z`, `p`, `S·p`) remain device-resident for the solve.
618/// Jacobi preconditioning is an elementwise CUDA kernel; only convergence
619/// scalars (`pᵀSp`, `rᵀz`, `‖r‖`) cross the host boundary per iteration, plus the
620/// final solution vector.
621///
622/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` when CUDA is unavailable
623/// or the workload is below the dispatch policy; the caller then runs the CPU
624/// reduced-β solve. Returns `Err(ArrowSchurGpuFailure::SchurFactorFailed)`
625/// when `S` carries a non-positive Jacobi diagonal (caller escalates the
626/// proximal ridge).
627pub fn solve_reduced_beta_pcg(
628    s_acc: &Array2<f64>,
629    rhs_beta: &Array1<f64>,
630    max_iterations: usize,
631    relative_tolerance: f64,
632) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
633    solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
634        .map(|(x, _)| x)
635}
636
637#[doc(hidden)]
638pub fn solve_reduced_beta_pcg_with_diagnostics(
639    s_acc: &Array2<f64>,
640    rhs_beta: &Array1<f64>,
641    max_iterations: usize,
642    relative_tolerance: f64,
643) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
644    let k = rhs_beta.len();
645    if s_acc.dim() != (k, k) {
646        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
647            reason: format!(
648                "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
649                s_acc.dim()
650            ),
651        });
652    }
653    if k == 0 {
654        return Err(ArrowSchurGpuFailure::Unavailable);
655    }
656
657    #[cfg(not(target_os = "linux"))]
658    {
659        if relative_tolerance.is_nan() || max_iterations == 0 {
660            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
661                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
662            });
663        }
664        Err(ArrowSchurGpuFailure::Unavailable)
665    }
666
667    #[cfg(target_os = "linux")]
668    {
669        cuda::solve_reduced_beta_pcg_with_diagnostics(
670            s_acc,
671            rhs_beta,
672            max_iterations,
673            relative_tolerance,
674        )
675    }
676}
677
678pub fn solve_sae_matrix_free_pcg(
679    sys: &ArrowSchurSystem,
680    data: &DeviceSaePcgData,
681    ridge_t: f64,
682    ridge_beta: f64,
683    rhs_beta: &Array1<f64>,
684    max_iterations: usize,
685    relative_tolerance: f64,
686) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
687    if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
688        return Err(ArrowSchurGpuFailure::Unavailable);
689    }
690    #[cfg(not(target_os = "linux"))]
691    {
692        if ridge_t.is_nan()
693            || ridge_beta.is_nan()
694            || relative_tolerance.is_nan()
695            || max_iterations == 0
696        {
697            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
698                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
699            });
700        }
701        Err(ArrowSchurGpuFailure::Unavailable)
702    }
703    #[cfg(target_os = "linux")]
704    {
705        // #1017/#1026 dispatch GUARD: framed data (frame metadata present) carries
706        // a factored β border `G ⊗ W_{ij}` data Hessian and dense per-row cross
707        // blocks the legacy `⊗ I_p` kernel CANNOT represent — feeding it framed
708        // data would silently return a WRONG Newton step (it returns Ok with no
709        // fallback). Route framed systems to the dedicated framed kernel and
710        // legacy full-`B` systems to the legacy kernel; the two never cross.
711        if data.frame.is_some() {
712            cuda::solve_sae_matrix_free_pcg_framed(
713                sys,
714                data,
715                ridge_t,
716                ridge_beta,
717                rhs_beta,
718                max_iterations,
719                relative_tolerance,
720            )
721        } else {
722            cuda::solve_sae_matrix_free_pcg(
723                sys,
724                data,
725                ridge_t,
726                ridge_beta,
727                rhs_beta,
728                max_iterations,
729                relative_tolerance,
730            )
731        }
732    }
733}
734
735/// Reference dense back-end used by tests and as the fallback when the
736/// GPU declines. Kept here (not in `arrow_schur_gpu.rs`) so the validation
737/// suite has one canonical baseline.
738#[doc(hidden)]
739pub fn solve_arrow_newton_step_dense_reference(
740    sys: &ArrowSchurSystem,
741    ridge_t: f64,
742    ridge_beta: f64,
743) -> Result<ArrowSchurGpuSolution, String> {
744    let n = sys.rows.len();
745    let d = sys.d;
746    let k = sys.k;
747    let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
748    let mut h = Array2::<f64>::zeros((total, total));
749    let mut rhs = Array1::<f64>::zeros(total);
750    for (i, row) in sys.rows.iter().enumerate() {
751        let base = i * d;
752        for c in 0..d {
753            for r in 0..d {
754                h[[base + r, base + c]] = row.htt[[r, c]];
755            }
756            h[[base + c, base + c]] += ridge_t;
757        }
758        for c in 0..k {
759            for r in 0..d {
760                let value = row.htbeta[[r, c]];
761                h[[base + r, n * d + c]] = value;
762                h[[n * d + c, base + r]] = value;
763            }
764        }
765        for r in 0..d {
766            rhs[base + r] = -row.gt[r];
767        }
768    }
769    for c in 0..k {
770        for r in 0..k {
771            h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
772        }
773        h[[n * d + c, n * d + c]] += ridge_beta;
774        rhs[n * d + c] = -sys.gb[c];
775    }
776    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
777        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
778    let mut log_det = 0.0_f64;
779    for i in 0..total {
780        log_det += factor[[i, i]].ln();
781    }
782    log_det *= 2.0;
783    let solved = cholesky_solve_vector(factor.view(), rhs.view());
784    let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
785    let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
786    Ok(ArrowSchurGpuSolution {
787        delta_t,
788        delta_beta,
789        log_det_hessian: log_det,
790    })
791}
792
793/// Frames-engaged reduced-Schur penalty-side matvec `out = (P_ββ + ρ_β I)·x`,
794/// computed purely from the factored device data (issue #1017/#1026). This is
795/// the CPU bit-parity ORACLE for the GPU `arrow_sae_*` penalty kernels on the
796/// frames path: smooth `λ S_k ⊗ I_{r_k}` (each `smooth_blocks[i]` at its
797/// `global_offset` with right-width `frame.smooth_ranks[i]`) plus data-fit
798/// `G_{ij} ⊗ W_{ij}` (each `frame.frame_blocks` entry, with the `μ`-major /
799/// frame-minor index `border_offset[atom] + basis·r + frame_coord`). The
800/// accumulation order matches the device kernels exactly.
801///
802/// `out` is OVERWRITTEN: first set to `ρ_β·x`, then the penalty blocks add in.
803#[doc(hidden)]
804pub fn sae_framed_penalty_matvec_cpu(
805    data: &DeviceSaePcgData,
806    ridge_beta: f64,
807    x: &[f64],
808    out: &mut [f64],
809) {
810    let frame = data
811        .frame
812        .as_ref()
813        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
814    let k = data.beta_dim;
815    for a in 0..k {
816        out[a] = ridge_beta * x[a];
817    }
818    // Smooth penalty `λ S_k ⊗ I_{r_k}`: y[off + ia·r + ib] += Σ_ja S[ia,ja]·x[off + ja·r + ib].
819    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
820        let off = blk.global_offset;
821        let m = blk.factor_a.nrows();
822        for i_a in 0..m {
823            for i_b in 0..r {
824                let mut acc = 0.0_f64;
825                for j_a in 0..m {
826                    let s = blk.factor_a[[i_a, j_a]];
827                    if s == 0.0 {
828                        continue;
829                    }
830                    acc += s * x[off + j_a * r + i_b];
831                }
832                out[off + i_a * r + i_b] += acc;
833            }
834        }
835    }
836    // Data-fit penalty `G_{ij} ⊗ W_{ij}`.
837    for blk in &frame.frame_blocks {
838        let r_i = frame.ranks[blk.atom_i];
839        let r_j = frame.ranks[blk.atom_j];
840        let off_i = frame.border_offsets[blk.atom_i];
841        let off_j = frame.border_offsets[blk.atom_j];
842        let (m_i, m_j) = blk.g.dim();
843        for li in 0..m_i {
844            let yi_base = off_i + li * r_i;
845            for lj in 0..m_j {
846                let g = blk.g[[li, lj]];
847                if g == 0.0 {
848                    continue;
849                }
850                let xj_base = off_j + lj * r_j;
851                for a in 0..r_i {
852                    let mut acc = 0.0_f64;
853                    for b in 0..r_j {
854                        acc += blk.w[[a, b]] * x[xj_base + b];
855                    }
856                    out[yi_base + a] += g * acc;
857                }
858            }
859        }
860    }
861}
862
863/// Frames-engaged FULL reduced-Schur matvec `out = S·x` purely from the device
864/// data, where `S = (P_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)`
865/// (issue #1017/#1026). The penalty side is [`sae_framed_penalty_matvec_cpu`];
866/// the per-row reduced term reads the dense `frame.row_htbeta[i]`
867/// (`q_i × border_dim`, row-major), solves against the row's
868/// `H_tt^(i)+ρ_t I` Cholesky factor, and scatters the transpose back. This is
869/// the size-independent bit-parity oracle the device kernel mirrors; it is also
870/// the matvec the GPU PCG iterates.
871#[doc(hidden)]
872pub fn sae_framed_schur_matvec_cpu(
873    sys: &ArrowSchurSystem,
874    data: &DeviceSaePcgData,
875    ridge_t: f64,
876    ridge_beta: f64,
877    x: &[f64],
878    out: &mut [f64],
879) -> Result<(), String> {
880    let frame = data
881        .frame
882        .as_ref()
883        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
884    let k = data.beta_dim;
885    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
886    if frame.row_htbeta.len() != sys.rows.len() {
887        return Err(format!(
888            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
889            frame.row_htbeta.len(),
890            sys.rows.len()
891        ));
892    }
893    for (i, row) in sys.rows.iter().enumerate() {
894        let slab = &frame.row_htbeta[i];
895        if slab.is_empty() {
896            continue;
897        }
898        let qi = sys.row_dims[i];
899        if qi == 0 || slab.len() != qi * k {
900            continue;
901        }
902        // h = H_tβ^(i) · x  (length q_i).
903        let mut h = vec![0.0_f64; qi];
904        for c in 0..qi {
905            let base = c * k;
906            let mut acc = 0.0_f64;
907            for a in 0..k {
908                acc += slab[base + a] * x[a];
909            }
910            h[c] = acc;
911        }
912        // solve (H_tt^(i)+ρ_t I) s = h.
913        let mut block = row.htt.clone();
914        for d in 0..qi {
915            block[[d, d]] += ridge_t;
916        }
917        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
918            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
919        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
920        // out -= H_βt^(i) · s = (H_tβ^(i))ᵀ · s.
921        for c in 0..qi {
922            let sc = s[c];
923            if sc == 0.0 {
924                continue;
925            }
926            let base = c * k;
927            for a in 0..k {
928                out[a] -= slab[base + a] * sc;
929            }
930        }
931    }
932    Ok(())
933}
934
935#[cfg(target_os = "linux")]
936mod cuda {
937    use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
938    use gam_gpu::driver::to_i32;
939    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
940    use crate::arrow_schur::{
941        ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
942    };
943    use cudarc::cublas::sys::{
944        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
945    };
946    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
947    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
948    use cudarc::driver::{
949        CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
950        PushKernelArg,
951    };
952    use ndarray::Array1;
953    use std::sync::{Arc, OnceLock};
954
955    /// Per-row work slot for the row-block-granular multi-GPU solve. Inputs are
956    /// the packed single-row buffers (`d×d` D block + ρ_t ridge, `d×k` B block,
957    /// `d` g vector); the forward pass fills the whitened factors `l/u/y` and the
958    /// per-tile reduction lands in the tile's leading slot.
959    struct RowSlot {
960        // Inputs (packed once on the host, column-major).
961        d_block: Vec<f64>, // d*d
962        b_block: Vec<f64>, // d*k
963        g_vec: Vec<f64>,   // d
964        diag_scale: f64,   // |diag(H_tt)| scale for the ridge-bump diagnostic
965        // Forward outputs, kept on the host for the back-sub pass.
966        l_block: Vec<f64>, // d*d lower factor, column-major
967        u_vec: Vec<f64>,   // d   (= L^{-1} g)
968        y_block: Vec<f64>, // d*k (= L^{-1} B), column-major
969        log_det_local: f64,
970        // Set on a non-PD pivot so the orchestrator can raise RidgeBumpRequired
971        // for the offending global row instead of silently falling back.
972        bump: Option<f64>,
973        // Tile-level reduction, written into the tile's first slot only.
974        tile_partial_schur: Option<Vec<f64>>, // k*k col-major, = Σ Y_iᵀY_i
975        tile_partial_rhs: Option<Vec<f64>>,   // k, = Σ Y_iᵀu_i
976        // Back-sub output for this row.
977        delta_t_block: Vec<f64>, // d
978    }
979
980    /// Row-block-granular multi-GPU Arrow-Schur Newton solve.
981    ///
982    /// The solve is separable across row blocks in both phases:
983    ///   * forward — each row's local Cholesky `L_i`, whitening
984    ///     `u_i = L_i⁻¹g_i`, `Y_i = L_i⁻¹B_i`, and partial Schur
985    ///     `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` are independent;
986    ///   * backward — `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)` is independent.
987    /// Only the small shared `K×K` reduce + factor + `δβ` solve is central.
988    ///
989    /// `gam_gpu::pool::scatter_batched` hands each device a contiguous row
990    /// tile on its own bound context/stream; the per-tile forward keeps the
991    /// POTRF fused with its dependent TRSM + Schur GEMM on that one stream, so no
992    /// on-stream solve is orphaned. Tile partials and per-tile `log|L|` are
993    /// reduced on the host (in tile/row order), `S_β` is factored on the primary
994    /// device, and the back-sub is scattered back across the same tiles.
995    ///
996    /// Returns `Unavailable` (caller uses a single-device path) when the system
997    /// carries matrix-free operators, the shared block is not dense `K×K`, the
998    /// pool is single-device, or any tile's device work declines. A non-PD tip
999    /// block surfaces as `RidgeBumpRequired` for the precise global row.
1000    pub(super) fn solve_multi_gpu(
1001        sys: &ArrowSchurSystem,
1002        ridge_t: f64,
1003        ridge_beta: f64,
1004    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1005        let n = sys.rows.len();
1006        let d = sys.d;
1007        let k = sys.k;
1008        if n == 0 || d == 0 || k == 0 {
1009            return Err(ArrowSchurGpuFailure::Unavailable);
1010        }
1011        // Dense shared block + materialised per-row slabs are required; the
1012        // public entry already rejected matrix-free operators, but re-check so
1013        // this routine is safe in isolation.
1014        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
1015            return Err(ArrowSchurGpuFailure::Unavailable);
1016        }
1017
1018        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
1019            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1020        if runtime.device_count() < 2 {
1021            return Err(ArrowSchurGpuFailure::Unavailable);
1022        }
1023
1024        // Pack one slot per row (column-major), folding ρ_t into each D block.
1025        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
1026        for row in &sys.rows {
1027            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
1028                return Err(ArrowSchurGpuFailure::Unavailable);
1029            }
1030            let mut d_block = Vec::with_capacity(d * d);
1031            let mut b_block = Vec::with_capacity(d * k);
1032            let mut g_vec = Vec::with_capacity(d);
1033            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
1034            let diag_scale = row
1035                .htt
1036                .diag()
1037                .iter()
1038                .map(|v| v.abs())
1039                .fold(0.0_f64, f64::max)
1040                .max(1.0);
1041            slots.push(RowSlot {
1042                d_block,
1043                b_block,
1044                g_vec,
1045                diag_scale,
1046                l_block: Vec::new(),
1047                u_vec: Vec::new(),
1048                y_block: Vec::new(),
1049                log_det_local: 0.0,
1050                bump: None,
1051                tile_partial_schur: None,
1052                tile_partial_rhs: None,
1053                delta_t_block: vec![0.0; d],
1054            });
1055        }
1056
1057        // ---- Forward pass: per-device row tile, fused on its own stream ----
1058        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1059            forward_tile(ordinal, d, k, tile)
1060        });
1061        if forward_ok.is_none() {
1062            return Err(ArrowSchurGpuFailure::Unavailable);
1063        }
1064
1065        // Surface a non-PD tip block as a precise per-row ridge bump.
1066        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
1067        if let Some((row, bump)) = slots
1068            .iter()
1069            .enumerate()
1070            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
1071        {
1072            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
1073        }
1074
1075        // ---- Central: reduce tile partials → S_β, r_β; factor; solve δβ ----
1076        // Seed S_β with H_ββ + ρ_β I (column-major) and r_β with -g_β, then fold
1077        // in the per-tile partials in tile order so the reduction order tracks
1078        // the single-device accumulation (up to inter-tile reassociation).
1079        let mut schur_host = vec![0.0_f64; k * k];
1080        for col in 0..k {
1081            for row in 0..k {
1082                let mut v = sys.hbb[[row, col]];
1083                if row == col {
1084                    v += ridge_beta;
1085                }
1086                schur_host[col * k + row] = v;
1087            }
1088        }
1089        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1090        let mut log_det = 0.0_f64;
1091        for start in tile_starts(&row_base_of_tile) {
1092            let slot = &slots[start];
1093            let partial_schur = slot
1094                .tile_partial_schur
1095                .as_ref()
1096                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1097            let partial_rhs = slot
1098                .tile_partial_rhs
1099                .as_ref()
1100                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1101            // `accumulate_schur` writes `partial_schur = -Σ_tile Y_iᵀY_i` (GEMM
1102            // α=-1, β=1 into a zero seed) and `partial_rhs = +Σ_tile Y_iᵀu_i`.
1103            // The reduced Schur is `S = (H_ββ+ρI) − Σ_all Y_iᵀY_i`, so adding the
1104            // (already-negated) partials reproduces the single-device sign.
1105            for idx in 0..k * k {
1106                schur_host[idx] += partial_schur[idx];
1107            }
1108            for a in 0..k {
1109                rhs_host[a] += partial_rhs[a];
1110            }
1111        }
1112        for slot in &slots {
1113            log_det += slot.log_det_local;
1114        }
1115
1116        // Factor S_β and solve δβ on the primary device (small K×K leaf). The
1117        // stream carries the primary context (same pattern as `solve()`); no
1118        // thread bind is needed for the cuSOLVER/cuBLAS handles created from it.
1119        let primary = runtime.selected_device().ordinal;
1120        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
1121            .and_then(|ctx| ctx.new_stream().ok())
1122            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1123        let solver =
1124            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1125        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1126        let mut schur_dev = stream
1127            .clone_htod(&schur_host)
1128            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1129        let mut rhs_dev = stream
1130            .clone_htod(&rhs_host)
1131            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1132        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1133        if info != 0 {
1134            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1135                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
1136            });
1137        }
1138        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1139        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1140        let delta_beta_host = stream
1141            .clone_dtoh(&rhs_dev)
1142            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1143        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1144        let l_schur_host = stream
1145            .clone_dtoh(&schur_dev)
1146            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1147        for j in 0..k {
1148            log_det += l_schur_host[j * k + j].ln();
1149        }
1150        log_det *= 2.0;
1151
1152        // ---- Backward pass: δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ), per-device tile ----
1153        let delta_beta_ref = &delta_beta_host;
1154        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1155            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
1156        });
1157        if back_ok.is_none() {
1158            return Err(ArrowSchurGpuFailure::Unavailable);
1159        }
1160
1161        // Stitch per-row δt into the stacked (n*d) result.
1162        let mut delta_t = Array1::<f64>::zeros(n * d);
1163        for (i, slot) in slots.iter().enumerate() {
1164            let base = i * d;
1165            for r in 0..d {
1166                delta_t[base + r] = slot.delta_t_block[r];
1167            }
1168        }
1169
1170        Ok(ArrowSchurGpuSolution {
1171            delta_t,
1172            delta_beta,
1173            log_det_hessian: log_det,
1174        })
1175    }
1176
1177    /// Tile starts: the leading global row index of each device tile (where the
1178    /// tile-level partial reduction was written by the forward pass).
1179    fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1180        tiles.iter().map(|(_, range)| range.start)
1181    }
1182
1183    /// Forward pass for one device row tile, running on `ordinal`'s bound stream.
1184    /// Factors each row block, whitens `u`/`Y`, accumulates the tile's partial
1185    /// Schur `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` into the tile's leading slot, keeps the
1186    /// per-row `L`/`u`/`Y` on the host for back-sub, and records the per-row
1187    /// `Σ_j log L_jj`. A non-PD pivot is recorded in `slot.bump` (the tile still
1188    /// returns `Some(())` so the orchestrator raises a precise `RidgeBumpRequired`
1189    /// rather than collapsing the whole batch to CPU).
1190    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1191        if tile.is_empty() {
1192            return Some(());
1193        }
1194        // `scatter_batched` has already bound this ordinal's context on this
1195        // worker thread; the stream below targets that same device.
1196        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1197            .and_then(|ctx| ctx.new_stream().ok())?;
1198        let solver = DnHandle::new(stream.clone()).ok()?;
1199        let blas = CudaBlas::new(stream.clone()).ok()?;
1200        let m = tile.len();
1201
1202        // Stack the tile's D, B, g into contiguous device buffers (same layout
1203        // the single-device path packs for `m` rows).
1204        let mut d_host = Vec::with_capacity(m * d * d);
1205        let mut b_host = Vec::with_capacity(m * d * k);
1206        let mut g_host = Vec::with_capacity(m * d);
1207        for slot in tile.iter() {
1208            d_host.extend_from_slice(&slot.d_block);
1209            b_host.extend_from_slice(&slot.b_block);
1210            g_host.extend_from_slice(&slot.g_vec);
1211        }
1212        let mut d_dev = stream.clone_htod(&d_host).ok()?;
1213        let mut b_dev = stream.clone_htod(&b_host).ok()?;
1214        let mut g_dev = stream.clone_htod(&g_host).ok()?;
1215
1216        // Batched POTRF; a non-PD block records its bump and stops the tile.
1217        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1218        if let Some(local) = info_host.iter().position(|info| *info != 0) {
1219            let pivot = info_host[local];
1220            tile[local].bump = Some(
1221                tile[local].diag_scale
1222                    * (f64::from(pivot).abs()).max(1.0)
1223                    * f64::EPSILON.sqrt()
1224                    * super::RIDGE_BUMP_EPS_MARGIN,
1225            );
1226            return Some(());
1227        }
1228
1229        // Whiten: u = L⁻¹ g, Y = L⁻¹ B.
1230        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1231        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1232
1233        // Tile partial Schur: zero-seeded so the host adds the H_ββ seed once.
1234        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1235        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1236        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1237
1238        // Download L, u, Y, and the tile partials.
1239        let l_host = stream.clone_dtoh(&d_dev).ok()?;
1240        let u_host = stream.clone_dtoh(&g_dev).ok()?;
1241        let y_host = stream.clone_dtoh(&b_dev).ok()?;
1242        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1243        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1244
1245        for (local, slot) in tile.iter_mut().enumerate() {
1246            let l_base = local * d * d;
1247            let u_base = local * d;
1248            let y_base = local * d * k;
1249            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1250            slot.u_vec = u_host[u_base..u_base + d].to_vec();
1251            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1252            let mut log_det_local = 0.0_f64;
1253            for j in 0..d {
1254                log_det_local += l_host[l_base + j * d + j].ln();
1255            }
1256            slot.log_det_local = log_det_local;
1257        }
1258        tile[0].tile_partial_schur = Some(partial_schur);
1259        tile[0].tile_partial_rhs = Some(partial_rhs);
1260        Some(())
1261    }
1262
1263    /// Back-substitution for one device row tile: `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)`.
1264    /// Re-uploads the tile's kept `L`/`u`/`Y` to `ordinal`, applies the GEMV
1265    /// accumulate + transposed TRSM, and writes each row's `δt` into its slot.
1266    fn back_sub_tile(
1267        ordinal: usize,
1268        d: usize,
1269        k: usize,
1270        delta_beta: &[f64],
1271        tile: &mut [RowSlot],
1272    ) -> Option<()> {
1273        if tile.is_empty() {
1274            return Some(());
1275        }
1276        // `scatter_batched` has already bound this ordinal's context on this
1277        // worker thread; the stream below targets that same device.
1278        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1279            .and_then(|ctx| ctx.new_stream().ok())?;
1280        let blas = CudaBlas::new(stream.clone()).ok()?;
1281        let m = tile.len();
1282
1283        let mut l_host = Vec::with_capacity(m * d * d);
1284        let mut u_host = Vec::with_capacity(m * d);
1285        let mut y_host = Vec::with_capacity(m * d * k);
1286        for slot in tile.iter() {
1287            l_host.extend_from_slice(&slot.l_block);
1288            u_host.extend_from_slice(&slot.u_vec);
1289            y_host.extend_from_slice(&slot.y_block);
1290        }
1291        let d_dev = stream.clone_htod(&l_host).ok()?;
1292        let mut g_dev = stream.clone_htod(&u_host).ok()?;
1293        let b_dev = stream.clone_htod(&y_host).ok()?;
1294        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1295
1296        // g ← u + Y·δβ, then x = L⁻ᵀ g; δt = -x.
1297        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1298        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1299        let x_host = stream.clone_dtoh(&g_dev).ok()?;
1300        for (local, slot) in tile.iter_mut().enumerate() {
1301            let base = local * d;
1302            for r in 0..d {
1303                slot.delta_t_block[r] = -x_host[base + r];
1304            }
1305        }
1306        Some(())
1307    }
1308
1309    pub(super) fn solve(
1310        sys: &ArrowSchurSystem,
1311        ridge_t: f64,
1312        ridge_beta: f64,
1313    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1314        let n = sys.rows.len();
1315        let d = sys.d;
1316        let k = sys.k;
1317        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1318            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1319
1320        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1321            .and_then(|ctx| ctx.new_stream().ok())
1322            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1323        let solver =
1324            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1325        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1326
1327        // ----- Pack + upload D, B, g -----
1328        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1329        let mut d_dev = stream
1330            .clone_htod(&d_host)
1331            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1332        let mut b_dev = stream
1333            .clone_htod(&b_host)
1334            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1335        let mut g_dev = stream
1336            .clone_htod(&g_host)
1337            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1338
1339        // ----- Layer A: batched lower Cholesky of D in place -----
1340        // This POTRF is fused with the downstream TRSM + Schur GEMM + back-sub
1341        // on this one stream, so splitting only the POTRF across devices would
1342        // orphan the dependent on-stream solves. Multi-GPU here is the
1343        // whole-solve row-block split in `solve_arrow_newton_step` (see
1344        // `solve_multi_gpu`), not a per-layer split — this device-resident path
1345        // is the single-device leaf the split dispatches per tile.
1346        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1347        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1348            let pivot = info_host[idx];
1349            let scale = sys.rows[idx]
1350                .htt
1351                .diag()
1352                .iter()
1353                .map(|v| v.abs())
1354                .fold(0.0_f64, f64::max)
1355                .max(1.0);
1356            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1357                row: idx,
1358                bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
1359            });
1360        }
1361
1362        // ----- Layer B (1/2): in-place triangular solves -----
1363        // u_i = L_i^{-1} g_i, packed as a stacked (n*d) column-vector.
1364        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1365        // Y_i = L_i^{-1} B_i, in place over the (n*d) × k buffer (laid out as
1366        // n stacked column-major d×k tiles).
1367        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1368
1369        // ----- Layer B (2/2): Schur reduction via single big GEMM / GEMV -----
1370        // Y_all is (n*d) × k column-major: viewing all n stacked d×k tiles as
1371        // one big matrix is bit-exact because each tile is column-major with
1372        // leading dim d and the tiles are contiguous in memory, so the
1373        // combined leading dim is n*d only for the *outer* matrix view. To
1374        // make the single-GEMM equivalence hold we must treat the stacked
1375        // buffer as (n*d) × k column-major with leading dim = n*d, which
1376        // means columns of Y_all are interleaved by row across blocks.
1377        // That is NOT what we packed. So we use the cuBLAS stride pattern
1378        // instead: stride-by-block, transpose-A, and *accumulate* into one
1379        // S_β buffer via beta=1 across batches. Equivalent flop count, no
1380        // extra reduction kernel, and correct layout.
1381        //
1382        // Concretely: schur ← C + ρ_β I; rhs ← -g_β; then for each block
1383        //   schur -= Y_i^T Y_i      (k×k)
1384        //   rhs   += Y_i^T u_i      (k)
1385        // We launch this as `n` sequential GEMMs/GEMVs with beta=1 on the
1386        // accumulator. Layer D fuses these into one NVRTC launch.
1387        let schur_init: Vec<f64> = {
1388            let mut tmp = Vec::with_capacity(k * k);
1389            for col in 0..k {
1390                for row in 0..k {
1391                    let mut v = sys.hbb[[row, col]];
1392                    if row == col {
1393                        v += ridge_beta;
1394                    }
1395                    tmp.push(v);
1396                }
1397            }
1398            tmp
1399        };
1400        let mut schur_dev = stream
1401            .clone_htod(&schur_init)
1402            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1403        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1404        let mut rhs_dev = stream
1405            .clone_htod(&rhs_init)
1406            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1407
1408        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1409
1410        // ----- Layer C (1/2): factor S_β and solve for δβ -----
1411        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1412        if info != 0 {
1413            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1414                reason: format!("Schur Cholesky failed at pivot {info}"),
1415            });
1416        }
1417        // δβ ← L_S^{-T} L_S^{-1} rhs
1418        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1419        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1420        let delta_beta_host = stream
1421            .clone_dtoh(&rhs_dev)
1422            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1423        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1424
1425        // ----- Layer C (2/2): back-sub δt_i = -L_i^{-T} (u_i + Y_i δβ) -----
1426        // Already on device:
1427        //   g_dev holds u_i stacked (n*d).
1428        //   b_dev holds Y_i stacked column-major n×(d×k) tiles.
1429        // Compute g_dev ← g_dev + Y_block · δβ per block (cuBLAS gemv with beta=1),
1430        // then in-place trsm with L_i^T (CUBLAS_OP_T) to obtain x_i, and finally
1431        // δt_i = -x_i on host after download.
1432        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1433        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1434
1435        let x_host = stream
1436            .clone_dtoh(&g_dev)
1437            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1438        let mut delta_t = Array1::<f64>::zeros(n * d);
1439        for (i, v) in x_host.iter().enumerate() {
1440            delta_t[i] = -*v;
1441        }
1442
1443        // ----- log|H| = 2 Σ log L_{i,jj} + 2 Σ log R_{β,aa} -----
1444        let l_local_host = stream
1445            .clone_dtoh(&d_dev)
1446            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1447        let l_schur_host = stream
1448            .clone_dtoh(&schur_dev)
1449            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1450        let mut log_det = 0.0_f64;
1451        for i in 0..n {
1452            let base = i * d * d;
1453            for j in 0..d {
1454                log_det += l_local_host[base + j * d + j].ln();
1455            }
1456        }
1457        for j in 0..k {
1458            log_det += l_schur_host[j * k + j].ln();
1459        }
1460        log_det *= 2.0;
1461
1462        Ok(ArrowSchurGpuSolution {
1463            delta_t,
1464            delta_beta,
1465            log_det_hessian: log_det,
1466        })
1467    }
1468
1469    fn potrf_batched(
1470        solver: &DnHandle,
1471        stream: &Arc<CudaStream>,
1472        p: usize,
1473        batch: usize,
1474        matrices: &mut CudaSlice<f64>,
1475    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1476        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1477        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1478        let matrix_len = p * p;
1479        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1480        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1481        let mut ptrs = Vec::with_capacity(batch);
1482        for idx in 0..batch {
1483            ptrs.push(base_ptr + (idx as u64) * bytes_per);
1484        }
1485        let mut ptrs_dev = stream
1486            .clone_htod(&ptrs)
1487            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1488        let mut info_dev = stream
1489            .alloc_zeros::<i32>(batch)
1490            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1491        let status = {
1492            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1493            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1494            // SAFETY: pointer array and info buffer live on the device,
1495            // matrices_dev holds `batch` contiguous p×p column-major blocks.
1496            unsafe {
1497                cusolver_sys::cusolverDnDpotrfBatched(
1498                    solver.cu(),
1499                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1500                    p_i,
1501                    ptrs_ptr as *mut *mut f64,
1502                    p_i,
1503                    info_ptr as *mut i32,
1504                    batch_i,
1505                )
1506            }
1507        };
1508        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1509            return Err(ArrowSchurGpuFailure::Unavailable);
1510        }
1511        stream
1512            .clone_dtoh(&info_dev)
1513            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1514    }
1515
1516    fn potrf_single(
1517        solver: &DnHandle,
1518        stream: &Arc<CudaStream>,
1519        p: usize,
1520        matrix: &mut CudaSlice<f64>,
1521    ) -> Result<i32, ArrowSchurGpuFailure> {
1522        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1523        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1524        let mut lwork = 0_i32;
1525        {
1526            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1527            // SAFETY: buffer query against a live p-by-p column-major device matrix.
1528            let status = unsafe {
1529                cusolver_sys::cusolverDnDpotrf_bufferSize(
1530                    solver.cu(),
1531                    uplo,
1532                    p_i,
1533                    mat_ptr as *mut f64,
1534                    p_i,
1535                    &mut lwork,
1536                )
1537            };
1538            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1539                return Err(ArrowSchurGpuFailure::Unavailable);
1540            }
1541        }
1542        let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1543        let mut workspace = stream
1544            .alloc_zeros::<f64>(lwork_usize.max(1))
1545            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1546        let mut info_dev = stream
1547            .alloc_zeros::<i32>(1)
1548            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1549        {
1550            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1551            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1552            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1553            // SAFETY: all three pointers refer to live, correctly sized device buffers.
1554            let status = unsafe {
1555                cusolver_sys::cusolverDnDpotrf(
1556                    solver.cu(),
1557                    uplo,
1558                    p_i,
1559                    mat_ptr as *mut f64,
1560                    p_i,
1561                    work_ptr as *mut f64,
1562                    lwork,
1563                    info_ptr as *mut i32,
1564                )
1565            };
1566            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1567                return Err(ArrowSchurGpuFailure::Unavailable);
1568            }
1569        }
1570        let info_host = stream
1571            .clone_dtoh(&info_dev)
1572            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1573        Ok(info_host[0])
1574    }
1575
1576    /// In-place lower-triangular solves `X_i ← L_i^{-1} X_i` over the n stacked
1577    /// d×nrhs RHS tiles in `rhs`. Uses `cublasDtrsmBatched` so all n solves
1578    /// hit the device in one launch.
1579    fn trsm_batched_lower_inplace(
1580        blas: &CudaBlas,
1581        stream: &Arc<CudaStream>,
1582        d: usize,
1583        n: usize,
1584        nrhs: usize,
1585        l_stack: &CudaSlice<f64>,
1586        rhs_stack: &mut CudaSlice<f64>,
1587    ) -> Result<(), ArrowSchurGpuFailure> {
1588        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1589    }
1590
1591    /// As above but with `L_i^T` instead of `L_i`.
1592    fn trsm_batched_lower_inplace_transposed(
1593        blas: &CudaBlas,
1594        stream: &Arc<CudaStream>,
1595        d: usize,
1596        n: usize,
1597        nrhs: usize,
1598        l_stack: &CudaSlice<f64>,
1599        rhs_stack: &mut CudaSlice<f64>,
1600    ) -> Result<(), ArrowSchurGpuFailure> {
1601        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1602    }
1603
1604    fn trsm_batched_inplace_inner(
1605        blas: &CudaBlas,
1606        stream: &Arc<CudaStream>,
1607        d: usize,
1608        n: usize,
1609        nrhs: usize,
1610        l_stack: &CudaSlice<f64>,
1611        rhs_stack: &mut CudaSlice<f64>,
1612        transposed: bool,
1613    ) -> Result<(), ArrowSchurGpuFailure> {
1614        let alpha = 1.0_f64;
1615        let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1616        let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1617        let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1618        let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1619        let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1620        let (l_base, _l_record) = l_stack.device_ptr(stream);
1621        let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1622        let mut l_ptrs = Vec::with_capacity(n);
1623        let mut rhs_ptrs = Vec::with_capacity(n);
1624        for i in 0..n {
1625            l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1626            rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1627        }
1628        let mut l_ptrs_dev = stream
1629            .clone_htod(&l_ptrs)
1630            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1631        let mut rhs_ptrs_dev = stream
1632            .clone_htod(&rhs_ptrs)
1633            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1634        let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1635        let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1636        let op = if transposed {
1637            cublasOperation_t::CUBLAS_OP_T
1638        } else {
1639            cublasOperation_t::CUBLAS_OP_N
1640        };
1641        let handle = *blas.handle();
1642        // SAFETY: pointer arrays and base buffers were just constructed from
1643        // live device allocations covering the entire batch.
1644        let status = unsafe {
1645            cudarc::cublas::sys::cublasDtrsmBatched(
1646                handle,
1647                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1648                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1649                op,
1650                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1651                d_i,
1652                nrhs_i,
1653                &alpha,
1654                l_ptrs_ptr as *const *const f64,
1655                d_i,
1656                rhs_ptrs_ptr as *const *mut f64,
1657                d_i,
1658                batch_i,
1659            )
1660        };
1661        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1662            return Err(ArrowSchurGpuFailure::Unavailable);
1663        }
1664        Ok(())
1665    }
1666
1667    /// Single-matrix lower-triangular solve: `rhs ← L^{-1} rhs` (or
1668    /// `L^{-T} rhs` if `transposed`). For the Schur Cholesky back-sub.
1669    fn trsm_single(
1670        blas: &CudaBlas,
1671        stream: &Arc<CudaStream>,
1672        n: usize,
1673        l: &CudaSlice<f64>,
1674        rhs: &mut CudaSlice<f64>,
1675        upper: bool,
1676        transposed: bool,
1677    ) -> Result<(), ArrowSchurGpuFailure> {
1678        let alpha = 1.0_f64;
1679        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1680        let handle = *blas.handle();
1681        let (l_ptr, _l_rec) = l.device_ptr(stream);
1682        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1683        // SAFETY: single n×n lower factor and n-vector RHS on device.
1684        let status = unsafe {
1685            cudarc::cublas::sys::cublasDtrsm_v2(
1686                handle,
1687                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1688                if upper {
1689                    cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1690                } else {
1691                    cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1692                },
1693                if transposed {
1694                    cublasOperation_t::CUBLAS_OP_T
1695                } else {
1696                    cublasOperation_t::CUBLAS_OP_N
1697                },
1698                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1699                n_i,
1700                1,
1701                &alpha,
1702                l_ptr as *const f64,
1703                n_i,
1704                rhs_ptr as *mut f64,
1705                n_i,
1706            )
1707        };
1708        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1709            return Err(ArrowSchurGpuFailure::Unavailable);
1710        }
1711        Ok(())
1712    }
1713
1714    /// Accumulate `schur ← schur − Σ_i Y_i^T Y_i` and `rhs ← rhs + Σ_i Y_i^T u_i`
1715    /// using one GEMM and one GEMV per block. Each call uses beta=1 to chain
1716    /// the accumulation device-side.
1717    fn accumulate_schur(
1718        blas: &CudaBlas,
1719        d: usize,
1720        k: usize,
1721        n: usize,
1722        y_stack: &CudaSlice<f64>,
1723        u_stack: &CudaSlice<f64>,
1724        schur: &mut CudaSlice<f64>,
1725        rhs: &mut CudaSlice<f64>,
1726    ) -> Result<(), ArrowSchurGpuFailure> {
1727        let y_block_elems = d * k;
1728        let u_block_elems = d;
1729        for i in 0..n {
1730            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1731            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1732            // GEMM: schur += (-1) · Y_i^T · Y_i  (Y_i is d×k col-major; out is k×k)
1733            let gemm_cfg = GemmConfig::<f64> {
1734                transa: cublasOperation_t::CUBLAS_OP_T,
1735                transb: cublasOperation_t::CUBLAS_OP_N,
1736                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1737                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1738                k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1739                alpha: -1.0,
1740                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1741                ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1742                beta: 1.0,
1743                ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1744            };
1745            // SAFETY: y_slice is d×k col-major, schur is k×k col-major; alpha/beta scalars set above.
1746            unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1747                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1748            // GEMV: rhs += 1 · Y_i^T · u_i
1749            let gemv_cfg = GemvConfig::<f64> {
1750                trans: cublasOperation_t::CUBLAS_OP_T,
1751                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1752                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1753                alpha: 1.0,
1754                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1755                incx: 1,
1756                beta: 1.0,
1757                incy: 1,
1758            };
1759            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1760            // device buffers; `rhs` is the length-k accumulator.
1761            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1762                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1763        }
1764        Ok(())
1765    }
1766
1767    /// `#1017` resident gradient path: accumulate ONLY the Schur RHS term
1768    /// `rhs += Σ_i Y_iᵀ u_i`, skipping the `−Σ_i Y_iᵀ Y_i` matrix GEMM that the
1769    /// resident frame already folded into its persistent `L_S` factor. This is
1770    /// the per-iterate-cheap counterpart of [`accumulate_schur`]: the GEMV here
1771    /// is bit-identical to the GEMV inside `accumulate_schur` (same config, same
1772    /// `beta=1` accumulation order over rows), so the resident frame's `δβ`
1773    /// matches a full `solve()` at the same gradient.
1774    fn accumulate_schur_rhs_only(
1775        blas: &CudaBlas,
1776        d: usize,
1777        k: usize,
1778        n: usize,
1779        y_stack: &CudaSlice<f64>,
1780        u_stack: &CudaSlice<f64>,
1781        rhs: &mut CudaSlice<f64>,
1782    ) -> Result<(), ArrowSchurGpuFailure> {
1783        let y_block_elems = d * k;
1784        let u_block_elems = d;
1785        for i in 0..n {
1786            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1787            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1788            let gemv_cfg = GemvConfig::<f64> {
1789                trans: cublasOperation_t::CUBLAS_OP_T,
1790                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1791                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1792                alpha: 1.0,
1793                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1794                incx: 1,
1795                beta: 1.0,
1796                incy: 1,
1797            };
1798            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1799            // device buffers; `rhs` is the length-k accumulator.
1800            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1801                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1802        }
1803        Ok(())
1804    }
1805
1806    /// Accumulate `g_dev[i] ← u_i + Y_i · δβ` per block. This is the
1807    /// pre-trsm RHS for the back-substitution `L_i^T x_i = w_i`.
1808    fn accumulate_back_sub_rhs(
1809        blas: &CudaBlas,
1810        d: usize,
1811        k: usize,
1812        n: usize,
1813        y_stack: &CudaSlice<f64>,
1814        delta_beta: &CudaSlice<f64>,
1815        u_stack: &mut CudaSlice<f64>,
1816    ) -> Result<(), ArrowSchurGpuFailure> {
1817        let y_block_elems = d * k;
1818        let u_block_elems = d;
1819        for i in 0..n {
1820            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1821            let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1822            let gemv_cfg = GemvConfig::<f64> {
1823                trans: cublasOperation_t::CUBLAS_OP_N,
1824                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1825                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1826                alpha: 1.0,
1827                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1828                incx: 1,
1829                beta: 1.0,
1830                incy: 1,
1831            };
1832            // SAFETY: y_slice / delta_beta / u_slice are live device buffers
1833            // of the expected sizes (d×k, k, d).
1834            unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1835                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1836        }
1837        Ok(())
1838    }
1839
1840    // ────────────────────────────────────────────────────────────────────
1841    // Layer D + E — fused NVRTC dispatch.
1842    //
1843    // The forward kernel (`arrow_schur_forward_pgroup`) is a single launch
1844    // that, per row block, factors `D_i + ρI = L_i L_iᵀ` in shared memory,
1845    // forward-solves `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i`, and emits the
1846    // per-block Schur partials `partial_S[i] = Yᵀ Y` (R×R) and
1847    // `partial_r[i] = Yᵀ u` (R). The host reduces partials on the CPU after
1848    // dtoh (one fused sum across `n` blocks of R²+R doubles; cheap because
1849    // n·R² ≲ 5M doubles at large scale), assembles `S_β`, factors it via
1850    // cuSOLVER, and launches the back-substitution kernel
1851    // `arrow_schur_back_sub_pgroup` to recover `δt_i = -L_i⁻ᵀ(u_i + Y_i δβ)`
1852    // without re-uploading the local factors.
1853    // ────────────────────────────────────────────────────────────────────
1854
1855    use std::collections::HashMap;
1856    use std::sync::Mutex;
1857
1858    /// One compiled NVRTC module per `(cc_major, cc_minor, p_max, r_template)`.
1859    /// `cc_*` lets one process drive multiple device generations; the
1860    /// `(p_max, r_template)` pair selects the shared-memory layout baked into
1861    /// the kernel source.
1862    struct FusedModuleCache {
1863        modules: Mutex<
1864            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
1865        >,
1866    }
1867
1868    fn fused_module_cache() -> &'static FusedModuleCache {
1869        static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1870        CACHE.get_or_init(|| FusedModuleCache {
1871            modules: Mutex::new(HashMap::new()),
1872        })
1873    }
1874
1875    fn fused_module_for(
1876        ctx: &Arc<CudaContext>,
1877        key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
1878    ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
1879        let cache = fused_module_cache();
1880        if let Ok(guard) = cache.modules.lock() {
1881            if let Some(existing) = guard.get(&key) {
1882                return Ok(existing.clone());
1883            }
1884        }
1885        let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
1886            key.p_max as usize,
1887            key.r_template as usize,
1888        );
1889        let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
1890            ArrowSchurGpuFailure::SchurFactorFailed {
1891                reason: format!(
1892                    "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
1893                    key.p_max, key.r_template
1894                ),
1895            }
1896        })?;
1897        let module = ctx
1898            .load_module(ptx)
1899            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1900        if let Ok(mut guard) = cache.modules.lock() {
1901            guard.entry(key).or_insert_with(|| module.clone());
1902        }
1903        Ok(module)
1904    }
1905
1906    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
1907extern "C" __global__ void arrow_pcg_jacobi_mul(
1908    const double* __restrict__ inv_diag,
1909    const double* __restrict__ r,
1910    double* __restrict__ z,
1911    int n
1912) {
1913    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1914    if (idx < n) {
1915        z[idx] = inv_diag[idx] * r[idx];
1916    }
1917}
1918
1919extern "C" __global__ void arrow_pcg_update_p(
1920    const double* __restrict__ z,
1921    double beta,
1922    double* __restrict__ p,
1923    int n
1924) {
1925    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1926    if (idx < n) {
1927        p[idx] = z[idx] + beta * p[idx];
1928    }
1929}
1930
1931extern "C" __global__ void arrow_sae_init(
1932    double* __restrict__ out,
1933    const double* __restrict__ x,
1934    double ridge,
1935    int n
1936) {
1937    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1938    if (idx < n) {
1939        out[idx] = ridge * x[idx];
1940    }
1941}
1942
1943extern "C" __global__ void arrow_sae_smooth_matvec(
1944    const double* __restrict__ x,
1945    double* __restrict__ out,
1946    const int* __restrict__ block_offsets,
1947    const int* __restrict__ block_m,
1948    const int* __restrict__ factor_ptr,
1949    const double* __restrict__ factors,
1950    int p,
1951    int n_blocks
1952) {
1953    int block_id = blockIdx.y;
1954    int linear = blockIdx.x * blockDim.x + threadIdx.x;
1955    if (block_id >= n_blocks) {
1956        return;
1957    }
1958    int m = block_m[block_id];
1959    int total = m * p;
1960    if (linear >= total) {
1961        return;
1962    }
1963    int li = linear / p;
1964    int oc = linear - li * p;
1965    int off = block_offsets[block_id];
1966    int fbase = factor_ptr[block_id];
1967    double acc = 0.0;
1968    for (int lj = 0; lj < m; ++lj) {
1969        double a = factors[fbase + li * m + lj];
1970        acc += a * x[off + lj * p + oc];
1971    }
1972    out[off + li * p + oc] += acc;
1973}
1974
1975extern "C" __global__ void arrow_sae_sparse_g_matvec(
1976    const double* __restrict__ x,
1977    double* __restrict__ out,
1978    const int* __restrict__ row_off,
1979    const int* __restrict__ col_off,
1980    const int* __restrict__ rows,
1981    const int* __restrict__ cols,
1982    const int* __restrict__ data_ptr,
1983    const double* __restrict__ data,
1984    int p,
1985    int n_blocks
1986) {
1987    int block_id = blockIdx.y;
1988    int linear = blockIdx.x * blockDim.x + threadIdx.x;
1989    if (block_id >= n_blocks) {
1990        return;
1991    }
1992    int m_i = rows[block_id];
1993    int m_j = cols[block_id];
1994    int total = m_i * p;
1995    if (linear >= total) {
1996        return;
1997    }
1998    int li = linear / p;
1999    int oc = linear - li * p;
2000    int rbase = row_off[block_id];
2001    int cbase = col_off[block_id];
2002    int dbase = data_ptr[block_id];
2003    double acc = 0.0;
2004    for (int lj = 0; lj < m_j; ++lj) {
2005        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
2006    }
2007    // #1017 — a row atom co-occurs with multiple column atoms, so several
2008    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
2009    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
2010    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
2011    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
2012    atomicAdd(&out[(rbase + li) * p + oc], acc);
2013}
2014
2015extern "C" __global__ void arrow_sae_gather_u(
2016    const double* __restrict__ x,
2017    const int* __restrict__ row_ptr,
2018    const int* __restrict__ beta_base,
2019    const double* __restrict__ phi,
2020    double* __restrict__ u,
2021    int p,
2022    int n_rows
2023) {
2024    int row = blockIdx.y;
2025    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2026    if (row >= n_rows || oc >= p) {
2027        return;
2028    }
2029    double acc = 0.0;
2030    int start = row_ptr[row];
2031    int end = row_ptr[row + 1];
2032    for (int e = start; e < end; ++e) {
2033        acc += phi[e] * x[beta_base[e] + oc];
2034    }
2035    u[row * p + oc] = acc;
2036}
2037
2038extern "C" __global__ void arrow_sae_apply_l(
2039    const double* __restrict__ u,
2040    const int* __restrict__ jac_ptr,
2041    const double* __restrict__ jac,
2042    double* __restrict__ w,
2043    int p,
2044    int max_q,
2045    int n_rows
2046) {
2047    int row = blockIdx.y;
2048    int c = blockIdx.x * blockDim.x + threadIdx.x;
2049    if (row >= n_rows) {
2050        return;
2051    }
2052    int jstart = jac_ptr[row];
2053    int q = (jac_ptr[row + 1] - jstart) / p;
2054    if (c >= q) {
2055        return;
2056    }
2057    double acc = 0.0;
2058    for (int oc = 0; oc < p; ++oc) {
2059        acc += jac[jstart + c * p + oc] * u[row * p + oc];
2060    }
2061    w[row * max_q + c] = acc;
2062}
2063
2064extern "C" __global__ void arrow_sae_apply_ainv(
2065    const double* __restrict__ ainv,
2066    const double* __restrict__ w,
2067    double* __restrict__ v,
2068    int max_q,
2069    int n_rows
2070) {
2071    int row = blockIdx.y;
2072    int c = blockIdx.x * blockDim.x + threadIdx.x;
2073    if (row >= n_rows || c >= max_q) {
2074        return;
2075    }
2076    double acc = 0.0;
2077    int base = row * max_q * max_q;
2078    for (int j = 0; j < max_q; ++j) {
2079        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
2080    }
2081    v[row * max_q + c] = acc;
2082}
2083
2084extern "C" __global__ void arrow_sae_scatter_sub(
2085    const double* __restrict__ v,
2086    const int* __restrict__ jac_ptr,
2087    const double* __restrict__ jac,
2088    const int* __restrict__ row_ptr,
2089    const int* __restrict__ beta_base,
2090    const double* __restrict__ phi,
2091    double* __restrict__ out,
2092    int p,
2093    int max_q,
2094    int n_rows
2095) {
2096    int row = blockIdx.y;
2097    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2098    if (row >= n_rows || oc >= p) {
2099        return;
2100    }
2101    int jstart = jac_ptr[row];
2102    int q = (jac_ptr[row + 1] - jstart) / p;
2103    double lt_v = 0.0;
2104    for (int c = 0; c < q; ++c) {
2105        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
2106    }
2107    int start = row_ptr[row];
2108    int end = row_ptr[row + 1];
2109    for (int e = start; e < end; ++e) {
2110        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
2111    }
2112}
2113
2114extern "C" __global__ void arrow_sae_diag_sub(
2115    double* __restrict__ diag,
2116    const double* __restrict__ ainv,
2117    const int* __restrict__ jac_ptr,
2118    const double* __restrict__ jac,
2119    const int* __restrict__ row_ptr,
2120    const int* __restrict__ beta_base,
2121    const double* __restrict__ phi,
2122    int p,
2123    int max_q,
2124    int n_rows
2125) {
2126    int row = blockIdx.y;
2127    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2128    if (row >= n_rows || oc >= p) {
2129        return;
2130    }
2131    int jstart = jac_ptr[row];
2132    int q = (jac_ptr[row + 1] - jstart) / p;
2133    int abase = row * max_q * max_q;
2134    double quad = 0.0;
2135    for (int c = 0; c < q; ++c) {
2136        double lc = jac[jstart + c * p + oc];
2137        for (int d = 0; d < q; ++d) {
2138            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
2139        }
2140    }
2141    int start = row_ptr[row];
2142    int end = row_ptr[row + 1];
2143    for (int e = start; e < end; ++e) {
2144        double pe = phi[e];
2145        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
2146    }
2147}
2148
2149/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
2150 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
2151 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
2152 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
2153 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */
2154
2155extern "C" __global__ void arrow_sae_frame_smooth_matvec(
2156    const double* __restrict__ x,
2157    double* __restrict__ out,
2158    const int* __restrict__ block_offsets,
2159    const int* __restrict__ block_m,
2160    const int* __restrict__ block_r,
2161    const int* __restrict__ factor_ptr,
2162    const double* __restrict__ factors,
2163    int n_blocks
2164) {
2165    int block_id = blockIdx.y;
2166    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2167    if (block_id >= n_blocks) {
2168        return;
2169    }
2170    int m = block_m[block_id];
2171    int r = block_r[block_id];
2172    int total = m * r;
2173    if (linear >= total) {
2174        return;
2175    }
2176    int li = linear / r;
2177    int ib = linear - li * r;
2178    int off = block_offsets[block_id];
2179    int fbase = factor_ptr[block_id];
2180    double acc = 0.0;
2181    for (int lj = 0; lj < m; ++lj) {
2182        double a = factors[fbase + li * m + lj];
2183        acc += a * x[off + lj * r + ib];
2184    }
2185    out[off + li * r + ib] += acc;
2186}
2187
2188extern "C" __global__ void arrow_sae_frame_g_matvec(
2189    const double* __restrict__ x,
2190    double* __restrict__ out,
2191    const int* __restrict__ off_i,
2192    const int* __restrict__ off_j,
2193    const int* __restrict__ r_i,
2194    const int* __restrict__ r_j,
2195    const int* __restrict__ m_i,
2196    const int* __restrict__ m_j,
2197    const int* __restrict__ g_ptr,
2198    const double* __restrict__ g_data,
2199    const int* __restrict__ w_ptr,
2200    const double* __restrict__ w_data,
2201    int n_blocks
2202) {
2203    int block_id = blockIdx.y;
2204    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2205    if (block_id >= n_blocks) {
2206        return;
2207    }
2208    int ri = r_i[block_id];
2209    int rj = r_j[block_id];
2210    int mi = m_i[block_id];
2211    int mj = m_j[block_id];
2212    int total = mi * ri;
2213    if (linear >= total) {
2214        return;
2215    }
2216    int li = linear / ri;       // basis row in atom i
2217    int a = linear - li * ri;   // frame coord in atom i
2218    int oi = off_i[block_id];
2219    int oj = off_j[block_id];
2220    int gbase = g_ptr[block_id];
2221    int wbase = w_ptr[block_id];
2222    double acc = 0.0;
2223    for (int lj = 0; lj < mj; ++lj) {
2224        double g = g_data[gbase + li * mj + lj];
2225        if (g == 0.0) { continue; }
2226        int xj_base = oj + lj * rj;
2227        double inner = 0.0;
2228        for (int b = 0; b < rj; ++b) {
2229            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
2230        }
2231        acc += g * inner;
2232    }
2233    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
2234    // multiple co-occurring (i,j) frame blocks running concurrently on
2235    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
2236    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
2237    atomicAdd(&out[oi + li * ri + a], acc);
2238}
2239
2240/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
2241 *   h_i   = H_tβ^(i) · x                (length q_i)
2242 *   s_i   = (H_tt^(i)+ρ_t I)⁻¹ h_i      (apply cached ainv, length q_i)
2243 *   out  -= (H_tβ^(i))ᵀ · s_i           (scatter into border_dim)
2244 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
2245 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
2246extern "C" __global__ void arrow_sae_frame_apply_h(
2247    const double* __restrict__ x,
2248    const int* __restrict__ htb_ptr,
2249    const double* __restrict__ htb,
2250    const int* __restrict__ q_of,
2251    double* __restrict__ hvec,
2252    int k,
2253    int max_q,
2254    int n_rows
2255) {
2256    int row = blockIdx.y;
2257    int c = blockIdx.x * blockDim.x + threadIdx.x;
2258    if (row >= n_rows) { return; }
2259    int q = q_of[row];
2260    if (c >= q) { return; }
2261    int base = htb_ptr[row] + c * k;
2262    double acc = 0.0;
2263    for (int a = 0; a < k; ++a) {
2264        acc += htb[base + a] * x[a];
2265    }
2266    hvec[row * max_q + c] = acc;
2267}
2268
2269extern "C" __global__ void arrow_sae_frame_apply_ainv(
2270    const double* __restrict__ ainv,
2271    const double* __restrict__ hvec,
2272    const int* __restrict__ q_of,
2273    double* __restrict__ svec,
2274    int max_q,
2275    int n_rows
2276) {
2277    int row = blockIdx.y;
2278    int c = blockIdx.x * blockDim.x + threadIdx.x;
2279    if (row >= n_rows || c >= max_q) { return; }
2280    int q = q_of[row];
2281    double acc = 0.0;
2282    int abase = row * max_q * max_q;
2283    for (int j = 0; j < q; ++j) {
2284        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
2285    }
2286    svec[row * max_q + c] = acc;
2287}
2288
2289extern "C" __global__ void arrow_sae_frame_scatter_h(
2290    const double* __restrict__ svec,
2291    const int* __restrict__ htb_ptr,
2292    const double* __restrict__ htb,
2293    const int* __restrict__ q_of,
2294    double* __restrict__ out,
2295    int k,
2296    int max_q,
2297    int n_rows
2298) {
2299    int row = blockIdx.y;
2300    int a = blockIdx.x * blockDim.x + threadIdx.x;
2301    if (row >= n_rows || a >= k) { return; }
2302    int q = q_of[row];
2303    int hbase = htb_ptr[row];
2304    double acc = 0.0;
2305    for (int c = 0; c < q; ++c) {
2306        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
2307    }
2308    atomicAdd(&out[a], -acc);
2309}
2310
2311/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
2312extern "C" __global__ void arrow_sae_frame_diag_sub(
2313    double* __restrict__ diag,
2314    const double* __restrict__ ainv,
2315    const int* __restrict__ htb_ptr,
2316    const double* __restrict__ htb,
2317    const int* __restrict__ q_of,
2318    int k,
2319    int max_q,
2320    int n_rows
2321) {
2322    int row = blockIdx.y;
2323    int a = blockIdx.x * blockDim.x + threadIdx.x;
2324    if (row >= n_rows || a >= k) { return; }
2325    int q = q_of[row];
2326    int hbase = htb_ptr[row];
2327    int abase = row * max_q * max_q;
2328    double quad = 0.0;
2329    for (int c = 0; c < q; ++c) {
2330        double hc = htb[hbase + c * k + a];
2331        for (int d = 0; d < q; ++d) {
2332            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
2333        }
2334    }
2335    atomicAdd(&diag[a], -quad);
2336}
2337"#;
2338
2339    fn pcg_vector_module(
2340        ctx: &Arc<CudaContext>,
2341    ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2342        static CACHE: gam_gpu::device_cache::PtxModuleCache =
2343            gam_gpu::device_cache::PtxModuleCache::new();
2344        CACHE
2345            .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2346            .map_err(|err| {
2347                // #1551: an NVRTC compile / module-load failure of
2348                // PCG_VECTOR_KERNEL_SOURCE means the device SAE PCG cannot run;
2349                // log it (the historical silent collapse to `Unavailable` is what
2350                // masked the missing `--gpu-architecture` for so long) and fall
2351                // back to the CPU.
2352                log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2353                ArrowSchurGpuFailure::Unavailable
2354            })
2355    }
2356
2357    fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2358        let threads = 256u32;
2359        let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2360        Ok(LaunchConfig {
2361            grid_dim: (blocks, 1, 1),
2362            block_dim: (threads, 1, 1),
2363            shared_mem_bytes: 0,
2364        })
2365    }
2366
2367    fn launch_jacobi_mul(
2368        stream: &Arc<CudaStream>,
2369        module: &Arc<CudaModule>,
2370        inv_diag: &CudaSlice<f64>,
2371        r: &CudaSlice<f64>,
2372        z: &mut CudaSlice<f64>,
2373        n: usize,
2374    ) -> Result<(), ArrowSchurGpuFailure> {
2375        let kernel = module
2376            .load_function("arrow_pcg_jacobi_mul")
2377            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2378        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2379        let mut builder = stream.launch_builder(&kernel);
2380        builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2381        // SAFETY: all buffers have length n and belong to `stream`; the kernel only
2382        // reads/writes indices `< n`.
2383        unsafe { builder.launch(pcg_launch_config(n)?) }
2384            .map(drop)
2385            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2386    }
2387
2388    fn launch_update_p(
2389        stream: &Arc<CudaStream>,
2390        module: &Arc<CudaModule>,
2391        z: &CudaSlice<f64>,
2392        beta: f64,
2393        p: &mut CudaSlice<f64>,
2394        n: usize,
2395    ) -> Result<(), ArrowSchurGpuFailure> {
2396        let kernel = module
2397            .load_function("arrow_pcg_update_p")
2398            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2399        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2400        let mut builder = stream.launch_builder(&kernel);
2401        builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2402        // SAFETY: z/p both have length n and belong to `stream`; the kernel only
2403        // reads/writes indices `< n`.
2404        unsafe { builder.launch(pcg_launch_config(n)?) }
2405            .map(drop)
2406            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2407    }
2408
2409    struct DeviceSaePcgBuffers {
2410        row_ptr: CudaSlice<i32>,
2411        beta_base: CudaSlice<i32>,
2412        phi: CudaSlice<f64>,
2413        jac_ptr: CudaSlice<i32>,
2414        jac: CudaSlice<f64>,
2415        smooth_offsets: CudaSlice<i32>,
2416        smooth_m: CudaSlice<i32>,
2417        smooth_ptr: CudaSlice<i32>,
2418        smooth_data: CudaSlice<f64>,
2419        g_row_off: CudaSlice<i32>,
2420        g_col_off: CudaSlice<i32>,
2421        g_rows: CudaSlice<i32>,
2422        g_cols: CudaSlice<i32>,
2423        g_ptr: CudaSlice<i32>,
2424        g_data: CudaSlice<f64>,
2425        ainv: CudaSlice<f64>,
2426        u: CudaSlice<f64>,
2427        w: CudaSlice<f64>,
2428        v: CudaSlice<f64>,
2429        n_rows: usize,
2430        p: usize,
2431        k: usize,
2432        max_q: usize,
2433        smooth_blocks: usize,
2434        g_blocks: usize,
2435    }
2436
2437    fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2438        to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2439    }
2440
2441    fn sae_penalty_diag_host(
2442        data: &DeviceSaePcgData,
2443        ridge_beta: f64,
2444    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2445        let mut diag = vec![ridge_beta; data.beta_dim];
2446        for block in &data.smooth_blocks {
2447            let (rows, cols) = block.factor_a.dim();
2448            if rows != cols {
2449                return Err(ArrowSchurGpuFailure::Unavailable);
2450            }
2451            for row in 0..rows {
2452                let coeff = block.factor_a[[row, row]];
2453                let base = block
2454                    .global_offset
2455                    .checked_add(
2456                        row.checked_mul(data.p)
2457                            .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2458                    )
2459                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2460                let end = base
2461                    .checked_add(data.p)
2462                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2463                if end > diag.len() {
2464                    return Err(ArrowSchurGpuFailure::Unavailable);
2465                }
2466                for channel in 0..data.p {
2467                    diag[base + channel] += coeff;
2468                }
2469            }
2470        }
2471        for block in &data.sparse_g_blocks {
2472            if block.row_off != block.col_off {
2473                continue;
2474            }
2475            let (rows, cols) = block.data.dim();
2476            for row in 0..rows.min(cols) {
2477                let coeff = block.data[[row, row]];
2478                let beta_row = block
2479                    .row_off
2480                    .checked_add(row)
2481                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2482                let base = beta_row
2483                    .checked_mul(data.p)
2484                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2485                let end = base
2486                    .checked_add(data.p)
2487                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2488                if end > diag.len() {
2489                    return Err(ArrowSchurGpuFailure::Unavailable);
2490                }
2491                for channel in 0..data.p {
2492                    diag[base + channel] += coeff;
2493                }
2494            }
2495        }
2496        Ok(diag)
2497    }
2498
2499    fn flatten_device_sae_data(
2500        sys: &ArrowSchurSystem,
2501        data: &DeviceSaePcgData,
2502        ridge_t: f64,
2503        stream: &Arc<CudaStream>,
2504    ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2505        let n_rows = sys.rows.len();
2506        let p = data.p;
2507        let k = data.beta_dim;
2508        if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2509            return Err(ArrowSchurGpuFailure::Unavailable);
2510        }
2511
2512        let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2513        let mut beta_base_host = Vec::<i32>::new();
2514        let mut phi_host = Vec::<f64>::new();
2515        row_ptr_host.push(0_i32);
2516        for row in data.a_phi.iter() {
2517            for &(base, phi) in row {
2518                beta_base_host.push(checked_i32(base)?);
2519                phi_host.push(phi);
2520            }
2521            row_ptr_host.push(checked_i32(beta_base_host.len())?);
2522        }
2523
2524        let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2525        let mut jac_host = Vec::<f64>::new();
2526        let mut max_q = 0usize;
2527        jac_ptr_host.push(0_i32);
2528        for row_jac in data.local_jac.iter() {
2529            if row_jac.len() % p != 0 {
2530                return Err(ArrowSchurGpuFailure::Unavailable);
2531            }
2532            max_q = max_q.max(row_jac.len() / p);
2533            jac_host.extend_from_slice(row_jac);
2534            jac_ptr_host.push(checked_i32(jac_host.len())?);
2535        }
2536        if max_q == 0 {
2537            return Err(ArrowSchurGpuFailure::Unavailable);
2538        }
2539
2540        let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2541        let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2542        let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2543        let mut smooth_data_host = Vec::<f64>::new();
2544        smooth_ptr_host.push(0_i32);
2545        for block in &data.smooth_blocks {
2546            let (rows, cols) = block.factor_a.dim();
2547            if rows != cols {
2548                return Err(ArrowSchurGpuFailure::Unavailable);
2549            }
2550            smooth_offsets_host.push(checked_i32(block.global_offset)?);
2551            smooth_m_host.push(checked_i32(rows)?);
2552            for r in 0..rows {
2553                for c in 0..cols {
2554                    smooth_data_host.push(block.factor_a[[r, c]]);
2555                }
2556            }
2557            smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2558        }
2559
2560        let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2561        let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2562        let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2563        let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2564        let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2565        let mut g_data_host = Vec::<f64>::new();
2566        g_ptr_host.push(0_i32);
2567        for block in &data.sparse_g_blocks {
2568            let (rows, cols) = block.data.dim();
2569            g_row_off_host.push(checked_i32(block.row_off)?);
2570            g_col_off_host.push(checked_i32(block.col_off)?);
2571            g_rows_host.push(checked_i32(rows)?);
2572            g_cols_host.push(checked_i32(cols)?);
2573            for r in 0..rows {
2574                for c in 0..cols {
2575                    g_data_host.push(block.data[[r, c]]);
2576                }
2577            }
2578            g_ptr_host.push(checked_i32(g_data_host.len())?);
2579        }
2580
2581        let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2582        for (row_idx, row) in sys.rows.iter().enumerate() {
2583            let q = data.local_jac[row_idx].len() / p;
2584            if row.htt.dim() != (q, q) {
2585                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2586                    reason: format!(
2587                        "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2588                        row.htt.dim()
2589                    ),
2590                });
2591            }
2592            let mut block = row.htt.clone();
2593            for d in 0..q {
2594                block[[d, d]] += ridge_t;
2595            }
2596            let factor = gam_linalg::triangular::cholesky_factor_in_place(
2597                block.view(),
2598                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2599            )
2600            .ok_or_else(|| {
2601                let scale = row
2602                    .htt
2603                    .diag()
2604                    .iter()
2605                    .map(|v| v.abs())
2606                    .fold(0.0_f64, f64::max)
2607                    .max(1.0);
2608                ArrowSchurGpuFailure::RidgeBumpRequired {
2609                    row: row_idx,
2610                    bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
2611                }
2612            })?;
2613            for col in 0..q {
2614                let mut e = Array1::<f64>::zeros(q);
2615                e[col] = 1.0;
2616                let solved =
2617                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2618                for r in 0..q {
2619                    ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2620                }
2621            }
2622        }
2623
2624        Ok(DeviceSaePcgBuffers {
2625            row_ptr: stream
2626                .clone_htod(&row_ptr_host)
2627                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2628            beta_base: stream
2629                .clone_htod(&beta_base_host)
2630                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2631            phi: stream
2632                .clone_htod(&phi_host)
2633                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2634            jac_ptr: stream
2635                .clone_htod(&jac_ptr_host)
2636                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2637            jac: stream
2638                .clone_htod(&jac_host)
2639                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2640            smooth_offsets: stream
2641                .clone_htod(&smooth_offsets_host)
2642                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2643            smooth_m: stream
2644                .clone_htod(&smooth_m_host)
2645                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2646            smooth_ptr: stream
2647                .clone_htod(&smooth_ptr_host)
2648                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2649            smooth_data: stream
2650                .clone_htod(&smooth_data_host)
2651                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2652            g_row_off: stream
2653                .clone_htod(&g_row_off_host)
2654                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2655            g_col_off: stream
2656                .clone_htod(&g_col_off_host)
2657                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2658            g_rows: stream
2659                .clone_htod(&g_rows_host)
2660                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2661            g_cols: stream
2662                .clone_htod(&g_cols_host)
2663                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2664            g_ptr: stream
2665                .clone_htod(&g_ptr_host)
2666                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2667            g_data: stream
2668                .clone_htod(&g_data_host)
2669                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2670            ainv: stream
2671                .clone_htod(&ainv_host)
2672                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2673            u: stream
2674                .alloc_zeros::<f64>(n_rows * p)
2675                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2676            w: stream
2677                .alloc_zeros::<f64>(n_rows * max_q)
2678                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2679            v: stream
2680                .alloc_zeros::<f64>(n_rows * max_q)
2681                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2682            n_rows,
2683            p,
2684            k,
2685            max_q,
2686            smooth_blocks: data.smooth_blocks.len(),
2687            g_blocks: data.sparse_g_blocks.len(),
2688        })
2689    }
2690
2691    fn launch_sae_init(
2692        stream: &Arc<CudaStream>,
2693        module: &Arc<CudaModule>,
2694        out: &mut CudaSlice<f64>,
2695        x: &CudaSlice<f64>,
2696        ridge: f64,
2697        n: usize,
2698    ) -> Result<(), ArrowSchurGpuFailure> {
2699        let kernel = module
2700            .load_function("arrow_sae_init")
2701            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2702        let n_i32 = checked_i32(n)?;
2703        let mut builder = stream.launch_builder(&kernel);
2704        builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2705        // SAFETY: `out` and `x` are live device buffers with at least `n`
2706        // entries on `stream`; the kernel writes one in-bounds element per
2707        // launched index below `n`.
2708        unsafe { builder.launch(pcg_launch_config(n)?) }
2709            .map(drop)
2710            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2711    }
2712
2713    fn launch_sae_penalty_matvec(
2714        stream: &Arc<CudaStream>,
2715        module: &Arc<CudaModule>,
2716        buffers: &mut DeviceSaePcgBuffers,
2717        x: &CudaSlice<f64>,
2718        out: &mut CudaSlice<f64>,
2719        ridge_beta: f64,
2720    ) -> Result<(), ArrowSchurGpuFailure> {
2721        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2722        if buffers.smooth_blocks > 0 {
2723            let kernel = module
2724                .load_function("arrow_sae_smooth_matvec")
2725                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2726            let max_m = buffers.k;
2727            let p_i32 = checked_i32(buffers.p)?;
2728            let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2729            let cfg = LaunchConfig {
2730                grid_dim: (
2731                    ((max_m as u32).saturating_add(255) / 256).max(1),
2732                    checked_i32(buffers.smooth_blocks)? as u32,
2733                    1,
2734                ),
2735                block_dim: (256, 1, 1),
2736                shared_mem_bytes: 0,
2737            };
2738            let mut builder = stream.launch_builder(&kernel);
2739            builder
2740                .arg(x)
2741                .arg(&mut *out)
2742                .arg(&buffers.smooth_offsets)
2743                .arg(&buffers.smooth_m)
2744                .arg(&buffers.smooth_ptr)
2745                .arg(&buffers.smooth_data)
2746                .arg(&p_i32)
2747                .arg(&blocks_i32);
2748            // SAFETY: smooth block metadata and dense smooth data were flattened
2749            // into live device buffers; the 2D grid covers only declared block
2750            // and coefficient-channel work items, and the kernel bounds-checks
2751            // against each block's stored size.
2752            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2753        }
2754        if buffers.g_blocks > 0 {
2755            let kernel = module
2756                .load_function("arrow_sae_sparse_g_matvec")
2757                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2758            let max_work = buffers
2759                .k
2760                .checked_div(buffers.p)
2761                .unwrap_or(0)
2762                .saturating_mul(buffers.p);
2763            let p_i32 = checked_i32(buffers.p)?;
2764            let blocks_i32 = checked_i32(buffers.g_blocks)?;
2765            let cfg = LaunchConfig {
2766                grid_dim: (
2767                    ((max_work as u32).saturating_add(255) / 256).max(1),
2768                    checked_i32(buffers.g_blocks)? as u32,
2769                    1,
2770                ),
2771                block_dim: (256, 1, 1),
2772                shared_mem_bytes: 0,
2773            };
2774            let mut builder = stream.launch_builder(&kernel);
2775            builder
2776                .arg(x)
2777                .arg(&mut *out)
2778                .arg(&buffers.g_row_off)
2779                .arg(&buffers.g_col_off)
2780                .arg(&buffers.g_rows)
2781                .arg(&buffers.g_cols)
2782                .arg(&buffers.g_ptr)
2783                .arg(&buffers.g_data)
2784                .arg(&p_i32)
2785                .arg(&blocks_i32);
2786            // SAFETY: sparse G block metadata/data are live device buffers built
2787            // from host CSR-like block descriptors; the launch dimensions cover
2788            // declared block work only and the kernel checks row/column bounds
2789            // before reading or accumulating.
2790            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2791        }
2792        Ok(())
2793    }
2794
2795    fn launch_sae_row_schur_sub(
2796        stream: &Arc<CudaStream>,
2797        module: &Arc<CudaModule>,
2798        buffers: &mut DeviceSaePcgBuffers,
2799        x: &CudaSlice<f64>,
2800        out: &mut CudaSlice<f64>,
2801    ) -> Result<(), ArrowSchurGpuFailure> {
2802        let p_i32 = checked_i32(buffers.p)?;
2803        let max_q_i32 = checked_i32(buffers.max_q)?;
2804        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2805        let cfg_p_rows = LaunchConfig {
2806            grid_dim: (
2807                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2808                checked_i32(buffers.n_rows)? as u32,
2809                1,
2810            ),
2811            block_dim: (256, 1, 1),
2812            shared_mem_bytes: 0,
2813        };
2814        let gather = module
2815            .load_function("arrow_sae_gather_u")
2816            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2817        {
2818            let mut builder = stream.launch_builder(&gather);
2819            builder
2820                .arg(x)
2821                .arg(&buffers.row_ptr)
2822                .arg(&buffers.beta_base)
2823                .arg(&buffers.phi)
2824                .arg(&mut buffers.u)
2825                .arg(&p_i32)
2826                .arg(&n_rows_i32);
2827            // SAFETY: `x`, row pointers, beta offsets, basis rows, and `u` are
2828            // live device buffers sized for `n_rows` by `p`; the kernel guards
2829            // row/channel indices before gathering.
2830            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2831        }
2832
2833        let cfg_q_rows = LaunchConfig {
2834            grid_dim: (
2835                ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2836                checked_i32(buffers.n_rows)? as u32,
2837                1,
2838            ),
2839            block_dim: (256, 1, 1),
2840            shared_mem_bytes: 0,
2841        };
2842        let apply_l = module
2843            .load_function("arrow_sae_apply_l")
2844            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2845        {
2846            let mut builder = stream.launch_builder(&apply_l);
2847            builder
2848                .arg(&buffers.u)
2849                .arg(&buffers.jac_ptr)
2850                .arg(&buffers.jac)
2851                .arg(&mut buffers.w)
2852                .arg(&p_i32)
2853                .arg(&max_q_i32)
2854                .arg(&n_rows_i32);
2855            // SAFETY: `u`, Jacobian row pointers/data, and `w` are live buffers
2856            // sized for the `(n_rows, p)` to `(n_rows, max_q)` multiply; the
2857            // kernel checks row and local-coordinate bounds.
2858            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2859        }
2860
2861        let apply_ainv = module
2862            .load_function("arrow_sae_apply_ainv")
2863            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2864        {
2865            let mut builder = stream.launch_builder(&apply_ainv);
2866            builder
2867                .arg(&buffers.ainv)
2868                .arg(&buffers.w)
2869                .arg(&mut buffers.v)
2870                .arg(&max_q_i32)
2871                .arg(&n_rows_i32);
2872            // SAFETY: `ainv`, `w`, and `v` are live device buffers sized for
2873            // `n_rows * max_q`; the kernel guards all row/local-coordinate
2874            // indices before reading or writing.
2875            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2876        }
2877
2878        let scatter = module
2879            .load_function("arrow_sae_scatter_sub")
2880            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2881        {
2882            let mut builder = stream.launch_builder(&scatter);
2883            builder
2884                .arg(&buffers.v)
2885                .arg(&buffers.jac_ptr)
2886                .arg(&buffers.jac)
2887                .arg(&buffers.row_ptr)
2888                .arg(&buffers.beta_base)
2889                .arg(&buffers.phi)
2890                .arg(out)
2891                .arg(&p_i32)
2892                .arg(&max_q_i32)
2893                .arg(&n_rows_i32);
2894            // SAFETY: `v`, Jacobian metadata, row pointers, beta offsets, basis
2895            // rows, and `out` are live buffers for `n_rows` by `p`; scatter
2896            // indices are checked against row and channel bounds in the kernel.
2897            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2898        }
2899        Ok(())
2900    }
2901
2902    fn launch_sae_diag_sub(
2903        stream: &Arc<CudaStream>,
2904        module: &Arc<CudaModule>,
2905        buffers: &DeviceSaePcgBuffers,
2906        diag: &mut CudaSlice<f64>,
2907    ) -> Result<(), ArrowSchurGpuFailure> {
2908        let kernel = module
2909            .load_function("arrow_sae_diag_sub")
2910            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2911        let p_i32 = checked_i32(buffers.p)?;
2912        let max_q_i32 = checked_i32(buffers.max_q)?;
2913        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2914        let cfg = LaunchConfig {
2915            grid_dim: (
2916                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2917                checked_i32(buffers.n_rows)? as u32,
2918                1,
2919            ),
2920            block_dim: (256, 1, 1),
2921            shared_mem_bytes: 0,
2922        };
2923        let mut builder = stream.launch_builder(&kernel);
2924        builder
2925            .arg(diag)
2926            .arg(&buffers.ainv)
2927            .arg(&buffers.jac_ptr)
2928            .arg(&buffers.jac)
2929            .arg(&buffers.row_ptr)
2930            .arg(&buffers.beta_base)
2931            .arg(&buffers.phi)
2932            .arg(&p_i32)
2933            .arg(&max_q_i32)
2934            .arg(&n_rows_i32);
2935        // SAFETY: diagonal output and all read-only SAE row metadata buffers are
2936        // live on `stream` with sizes matching `n_rows`, `p`, and `max_q`; the
2937        // kernel bounds-checks its flattened work index.
2938        unsafe { builder.launch(cfg) }
2939            .map(drop)
2940            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2941    }
2942
2943    fn launch_sae_matvec(
2944        stream: &Arc<CudaStream>,
2945        module: &Arc<CudaModule>,
2946        buffers: &mut DeviceSaePcgBuffers,
2947        x: &CudaSlice<f64>,
2948        out: &mut CudaSlice<f64>,
2949        ridge_beta: f64,
2950    ) -> Result<(), ArrowSchurGpuFailure> {
2951        launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
2952        launch_sae_row_schur_sub(stream, module, buffers, x, out)
2953    }
2954
2955    /// Pack `D + ρ_t I`, `B`, and `g` into the strided `(n × P_MAX × P_MAX)`
2956    /// / `(n × P_MAX × R_TEMPLATE)` / `(n × P_MAX)` layout the fused kernel
2957    /// expects. Entries outside the runtime `(p, r)` window stay at zero so
2958    /// the kernel's per-element loops are safe to no-op there.
2959    fn pack_fused_host(
2960        sys: &ArrowSchurSystem,
2961        ridge_t: f64,
2962        p_max: usize,
2963        r_template: usize,
2964    ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
2965        let n = sys.rows.len();
2966        let d = sys.d;
2967        let k = sys.k;
2968        let mut d_buf = vec![0.0_f64; n * p_max * p_max];
2969        let mut b_buf = vec![0.0_f64; n * p_max * r_template];
2970        let mut g_buf = vec![0.0_f64; n * p_max];
2971        for (i, row) in sys.rows.iter().enumerate() {
2972            // D_i + ρI, column-major in P_MAX×P_MAX strided block.
2973            for col in 0..d {
2974                let base = (i * p_max + col) * p_max;
2975                for r in 0..d {
2976                    let mut value = row.htt[[r, col]];
2977                    if r == col {
2978                        value += ridge_t;
2979                    }
2980                    d_buf[base + r] = value;
2981                }
2982            }
2983            // B_i in P_MAX×R_TEMPLATE strided block. The per-row (per-i) block
2984            // stride is `p_max · r_template` (matching the `b_buf` allocation
2985            // above and the kernel's `b_stack`/`y_out` layout), NOT
2986            // `p_max · p_max`: using the D-block multiplier here overflows the
2987            // buffer whenever `p_max > r_template` (e.g. d=30→p_max=32,
2988            // k=5→r_template=5). The within-block element offset stays
2989            // column-major `col·p_max + r` (P_MAX rows per column).
2990            for col in 0..k {
2991                let base = (i * r_template + col) * p_max;
2992                for r in 0..d {
2993                    b_buf[base + r] = row.htbeta[[r, col]];
2994                }
2995            }
2996            // g_i in P_MAX strided vector.
2997            let g_base = i * p_max;
2998            for r in 0..d {
2999                g_buf[g_base + r] = row.gt[r];
3000            }
3001        }
3002        (d_buf, b_buf, g_buf)
3003    }
3004
3005    // -----------------------------------------------------------------------
3006    // #1017 Phase 3: across-iteration device residency.
3007    //
3008    // `solve()` re-packs and re-uploads `D` (`H_tt`), `B` (`H_tβ`) and `g`,
3009    // then re-runs the per-row POTRF and the border Schur factorization on
3010    // EVERY call. For the SAE joint inner Newton at a frozen gate/basis frame
3011    // the Hessian blocks `D`, `B`, `H_ββ` are CONSTANT across the inner loop —
3012    // only the gradient `g = r(z) = H z − g₀` changes per iterate. So the
3013    // factor work (`O(n·d³ + p³)`) and the dominant `O(n·d·p)` cross-block
3014    // upload are pure waste when repeated per iterate.
3015    //
3016    // `ResidentArrowFrame` performs that constant work ONCE at construction:
3017    // upload+ridge+POTRF of `D` (keeping `L_i` resident in `l_dev`), the
3018    // forward solve `Y_i = L_i^{-1} B_i` (kept resident in `y_dev`), and the
3019    // Schur assembly + border POTRF (keeping `L_S` resident in `schur_dev`).
3020    // Each subsequent `solve_gradient(g)` uploads only the `n·d` row gradient,
3021    // runs the cheap residual path — `u_i = L_i^{-1} g_i` (one batched TRSM),
3022    // Schur RHS `−g_β + Σ Y_iᵀ u_i`, `δβ = L_S^{-T} L_S^{-1} rhs` (two TRSM,
3023    // NO POTRF), back-sub `δt_i = −L_i^{-T}(u_i + Y_i δβ)` — and reads back only
3024    // `δ` and the cached log|H|. The heavy buffers never leave the device
3025    // across iterations; the per-iterate host transfer is `O(n·d + p)`, not
3026    // `O(n·d·p)`. Numerics are bit-identical to a `solve()` at the same
3027    // `(D, B, H_ββ, g, ridge_t, ridge_beta)` because the factor buffers and the
3028    // helper kernels are the same; the resident path merely SKIPS re-deriving
3029    // the parts that do not depend on `g`. The CPU dense reference
3030    // (`solve_arrow_newton_step_dense_reference`) is the parity oracle.
3031    pub(super) struct ResidentArrowFrame {
3032        n: usize,
3033        d: usize,
3034        k: usize,
3035        stream: Arc<CudaStream>,
3036        blas: CudaBlas,
3037        /// Per-row lower Cholesky factors `L_i` of `H_tt + ρ_t I`, stacked
3038        /// column-major (`n` tiles of `d×d`). Resident across iterations.
3039        l_dev: CudaSlice<f64>,
3040        /// Whitened cross blocks `Y_i = L_i^{-1} H_tβ^(i)`, stacked column-major
3041        /// (`n` tiles of `d×k`). Resident across iterations.
3042        y_dev: CudaSlice<f64>,
3043        /// Lower Cholesky factor `L_S` of the reduced Schur complement
3044        /// `S_β = H_ββ + ρ_β I − Σ_i Y_iᵀ Y_i`. Resident across iterations.
3045        schur_dev: CudaSlice<f64>,
3046        /// `log|H| = 2 Σ log L_{i,jj} + 2 Σ log L_{S,aa}`, constant for the
3047        /// frame (depends only on the factored Hessian, not on `g`).
3048        log_det_hessian: f64,
3049    }
3050
3051    impl ResidentArrowFrame {
3052        /// Upload the constant Hessian blocks and perform the one-time factor
3053        /// work (`POTRF(D)`, `Y_i = L_i^{-1} B_i`, Schur assembly + border
3054        /// `POTRF`). The frame then serves cheap per-gradient solves.
3055        pub(super) fn new(
3056            sys: &ArrowSchurSystem,
3057            ridge_t: f64,
3058            ridge_beta: f64,
3059        ) -> Result<Self, ArrowSchurGpuFailure> {
3060            if ridge_t.is_nan() || ridge_beta.is_nan() {
3061                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3062                    reason: "ridge is NaN".to_string(),
3063                });
3064            }
3065            let n = sys.rows.len();
3066            let d = sys.d;
3067            let k = sys.k;
3068            let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3069                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3070            let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3071                .and_then(|ctx| ctx.new_stream().ok())
3072                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3073            let solver =
3074                DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3075            let blas =
3076                CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3077
3078            // Upload the constant blocks. `g` is uploaded per-gradient, not here.
3079            let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3080            let mut l_dev = stream
3081                .clone_htod(&d_host)
3082                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3083            let mut y_dev = stream
3084                .clone_htod(&b_host)
3085                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3086
3087            // POTRF(D) → L_i, kept resident in l_dev.
3088            let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3089            if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3090                let pivot = info_host[idx];
3091                let scale = sys.rows[idx]
3092                    .htt
3093                    .diag()
3094                    .iter()
3095                    .map(|v| v.abs())
3096                    .fold(0.0_f64, f64::max)
3097                    .max(1.0);
3098                return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3099                    row: idx,
3100                    bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3101                });
3102            }
3103
3104            // Y_i = L_i^{-1} B_i, in place over y_dev. Kept resident.
3105            trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3106
3107            // Schur assembly S_β = (H_ββ + ρ_β I) − Σ Y_iᵀ Y_i, then POTRF → L_S.
3108            // The RHS accumulation is folded into the gradient path; here we
3109            // only need the (gradient-independent) Schur factor, so accumulate
3110            // into a throwaway rhs buffer.
3111            let schur_init: Vec<f64> = {
3112                let mut tmp = Vec::with_capacity(k * k);
3113                for col in 0..k {
3114                    for row in 0..k {
3115                        let mut v = sys.hbb[[row, col]];
3116                        if row == col {
3117                            v += ridge_beta;
3118                        }
3119                        tmp.push(v);
3120                    }
3121                }
3122                tmp
3123            };
3124            let mut schur_dev = stream
3125                .clone_htod(&schur_init)
3126                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3127            // A zero u-stack makes `Σ Y_iᵀ u_i = 0`, so only the `−Σ Y_iᵀ Y_i`
3128            // Schur term is accumulated (the rhs is rebuilt per gradient).
3129            let zero_u = stream
3130                .clone_htod(&vec![0.0_f64; n * d])
3131                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3132            let mut throwaway_rhs = stream
3133                .clone_htod(&vec![0.0_f64; k])
3134                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3135            accumulate_schur(
3136                &blas,
3137                d,
3138                k,
3139                n,
3140                &y_dev,
3141                &zero_u,
3142                &mut schur_dev,
3143                &mut throwaway_rhs,
3144            )?;
3145            let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3146            if info != 0 {
3147                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3148                    reason: format!("Schur Cholesky failed at pivot {info}"),
3149                });
3150            }
3151
3152            // log|H| from the resident factors (constant for the frame).
3153            let l_local_host = stream
3154                .clone_dtoh(&l_dev)
3155                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3156            let l_schur_host = stream
3157                .clone_dtoh(&schur_dev)
3158                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3159            let mut log_det = 0.0_f64;
3160            for i in 0..n {
3161                let base = i * d * d;
3162                for j in 0..d {
3163                    log_det += l_local_host[base + j * d + j].ln();
3164                }
3165            }
3166            for j in 0..k {
3167                log_det += l_schur_host[j * k + j].ln();
3168            }
3169            log_det *= 2.0;
3170
3171            Ok(Self {
3172                n,
3173                d,
3174                k,
3175                stream,
3176                blas,
3177                l_dev,
3178                y_dev,
3179                schur_dev,
3180                log_det_hessian: log_det,
3181            })
3182        }
3183
3184        #[inline]
3185        pub(super) fn log_det_hessian(&self) -> f64 {
3186            self.log_det_hessian
3187        }
3188
3189        /// Solve `H δ = −gradient` for a fresh gradient `(g_t, g_β)` reusing the
3190        /// resident factors. Uploads only `g_t` (`n·d` scalars); reads back only
3191        /// `δ`. No POTRF runs here — all factorization is amortized into `new`.
3192        pub(super) fn solve_gradient(
3193            &self,
3194            g_t: &[f64],
3195            g_beta: &[f64],
3196        ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3197            let n = self.n;
3198            let d = self.d;
3199            let k = self.k;
3200            if g_t.len() != n * d || g_beta.len() != k {
3201                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3202                    reason: format!(
3203                        "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3204                        g_t.len(),
3205                        n * d,
3206                        g_beta.len(),
3207                        k
3208                    ),
3209                });
3210            }
3211            // Upload the per-iterate row gradient → u_i = L_i^{-1} g_i in place.
3212            let mut u_dev = self
3213                .stream
3214                .clone_htod(&g_t.to_vec())
3215                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3216            trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3217
3218            // Schur RHS = −g_β + Σ_i Y_iᵀ u_i. Reuse the resident Schur factor
3219            // (no POTRF, and skip the −Σ Y_iᵀ Y_i GEMM already baked into L_S).
3220            let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3221            let mut rhs_dev = self
3222                .stream
3223                .clone_htod(&rhs_init)
3224                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3225            accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3226
3227            // δβ ← L_S^{-T} L_S^{-1} rhs using the resident border factor.
3228            trsm_single(
3229                &self.blas,
3230                &self.stream,
3231                k,
3232                &self.schur_dev,
3233                &mut rhs_dev,
3234                false,
3235                false,
3236            )?;
3237            trsm_single(
3238                &self.blas,
3239                &self.stream,
3240                k,
3241                &self.schur_dev,
3242                &mut rhs_dev,
3243                false,
3244                true,
3245            )?;
3246            let delta_beta_host = self
3247                .stream
3248                .clone_dtoh(&rhs_dev)
3249                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3250            let delta_beta = Array1::from_vec(delta_beta_host);
3251
3252            // Back-sub δt_i = −L_i^{-T}(u_i + Y_i δβ).
3253            accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3254            trsm_batched_lower_inplace_transposed(
3255                &self.blas,
3256                &self.stream,
3257                d,
3258                n,
3259                1,
3260                &self.l_dev,
3261                &mut u_dev,
3262            )?;
3263            let x_host = self
3264                .stream
3265                .clone_dtoh(&u_dev)
3266                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3267            let mut delta_t = Array1::<f64>::zeros(n * d);
3268            for (i, v) in x_host.iter().enumerate() {
3269                delta_t[i] = -*v;
3270            }
3271
3272            Ok(ArrowSchurGpuSolution {
3273                delta_t,
3274                delta_beta,
3275                log_det_hessian: self.log_det_hessian,
3276            })
3277        }
3278    }
3279
3280    pub(super) fn solve_fused(
3281        sys: &ArrowSchurSystem,
3282        ridge_t: f64,
3283        ridge_beta: f64,
3284    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3285        let n = sys.rows.len();
3286        let d = sys.d;
3287        let k = sys.k;
3288        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3289            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3290        let p_max = plan.p_max;
3291        let r_template = plan.r_template;
3292
3293        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3294            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3295        )
3296        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3297        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3298            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3299        let stream = ctx
3300            .new_stream()
3301            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3302        let cap = &runtime.device.capability;
3303        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3304            cc_major: cap.compute_major,
3305            cc_minor: cap.compute_minor,
3306            p_max: p_max as u32,
3307            r_template: r_template as u32,
3308        };
3309        let module = fused_module_for(&ctx, key)?;
3310        let forward = module
3311            .load_function("arrow_schur_forward_pgroup")
3312            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3313        let back_sub = module
3314            .load_function("arrow_schur_back_sub_pgroup")
3315            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3316
3317        // ----- Upload packed D, B, g -----
3318        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3319        let d_dev = stream
3320            .clone_htod(&d_host)
3321            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3322        let b_dev = stream
3323            .clone_htod(&b_host)
3324            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3325        let g_dev = stream
3326            .clone_htod(&g_host)
3327            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3328        let mut l_out = stream
3329            .alloc_zeros::<f64>(n * p_max * p_max)
3330            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3331        let mut u_out = stream
3332            .alloc_zeros::<f64>(n * p_max)
3333            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3334        let mut y_out = stream
3335            .alloc_zeros::<f64>(n * p_max * r_template)
3336            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3337        let mut partial_s = stream
3338            .alloc_zeros::<f64>(plan.partial_s_doubles)
3339            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3340        let mut partial_r = stream
3341            .alloc_zeros::<f64>(plan.partial_r_doubles)
3342            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3343        let mut status_dev = stream
3344            .alloc_zeros::<i32>(n)
3345            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3346
3347        // ----- Launch forward kernel: 1 block per row, P_MAX threads -----
3348        let cfg = LaunchConfig {
3349            grid_dim: (plan.blocks, 1, 1),
3350            block_dim: (plan.threads_per_block, 1, 1),
3351            shared_mem_bytes: 0,
3352        };
3353        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3354        let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3355        let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3356        let ridge_arg = ridge_t;
3357        {
3358            let mut builder = stream.launch_builder(&forward);
3359            builder
3360                .arg(&d_dev)
3361                .arg(&b_dev)
3362                .arg(&g_dev)
3363                .arg(&n_i32)
3364                .arg(&p_i32)
3365                .arg(&r_i32)
3366                .arg(&ridge_arg)
3367                .arg(&mut l_out)
3368                .arg(&mut u_out)
3369                .arg(&mut y_out)
3370                .arg(&mut partial_s)
3371                .arg(&mut partial_r)
3372                .arg(&mut status_dev);
3373            // SAFETY: all buffers were just allocated on `stream` with sizes
3374            // derived from `plan`; kernel parameter list matches the
3375            // FORWARD_KERNEL_SOURCE signature.
3376            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3377        }
3378        stream
3379            .synchronize()
3380            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3381
3382        // ----- Check per-block pivot status -----
3383        let status_host = stream
3384            .clone_dtoh(&status_dev)
3385            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3386        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3387            let pivot = status_host[row];
3388            let scale = sys.rows[row]
3389                .htt
3390                .diag()
3391                .iter()
3392                .map(|v| v.abs())
3393                .fold(0.0_f64, f64::max)
3394                .max(1.0);
3395            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3396                row,
3397                bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3398            });
3399        }
3400
3401        // ----- Reduce partials on host into S_β and r_β -----
3402        let partial_s_host = stream
3403            .clone_dtoh(&partial_s)
3404            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3405        let partial_r_host = stream
3406            .clone_dtoh(&partial_r)
3407            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3408        let mut schur_host = vec![0.0_f64; k * k];
3409        for col in 0..k {
3410            for row in 0..k {
3411                let mut v = sys.hbb[[row, col]];
3412                if row == col {
3413                    v += ridge_beta;
3414                }
3415                schur_host[col * k + row] = v;
3416            }
3417        }
3418        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3419        for i in 0..n {
3420            // partial_S[i] stride is R_TEMPLATE × R_TEMPLATE column-major; we
3421            // only read the leading (k × k) sub-block.
3422            let s_base = i * r_template * r_template;
3423            for col in 0..k {
3424                let col_base = s_base + col * r_template;
3425                let dst_col_base = col * k;
3426                for row in 0..k {
3427                    schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3428                }
3429            }
3430            let r_base = i * r_template;
3431            for a in 0..k {
3432                rhs_host[a] += partial_r_host[r_base + a];
3433            }
3434        }
3435
3436        // ----- Factor S_β on device (cuSOLVER), solve for δβ -----
3437        let mut schur_dev = stream
3438            .clone_htod(&schur_host)
3439            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3440        let mut rhs_dev = stream
3441            .clone_htod(&rhs_host)
3442            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3443        let solver =
3444            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3446        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3447        if info != 0 {
3448            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3449                reason: format!("fused Schur Cholesky failed at pivot {info}"),
3450            });
3451        }
3452        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3453        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3454        let delta_beta_host = stream
3455            .clone_dtoh(&rhs_dev)
3456            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457        let delta_beta = Array1::from_vec(delta_beta_host.clone());
3458
3459        // ----- Layer E: launch back-sub kernel using persisted L, u, Y -----
3460        let mut delta_t_dev = stream
3461            .alloc_zeros::<f64>(n * p_max)
3462            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463        let back_cfg = LaunchConfig {
3464            grid_dim: (plan.blocks, 1, 1),
3465            block_dim: (plan.threads_per_block, 1, 1),
3466            shared_mem_bytes: 0,
3467        };
3468        {
3469            let mut builder = stream.launch_builder(&back_sub);
3470            builder
3471                .arg(&l_out)
3472                .arg(&u_out)
3473                .arg(&y_out)
3474                .arg(&rhs_dev)
3475                .arg(&n_i32)
3476                .arg(&p_i32)
3477                .arg(&r_i32)
3478                .arg(&mut delta_t_dev);
3479            // SAFETY: kernel parameter list matches FORWARD_KERNEL_SOURCE
3480            // back-sub signature; `rhs_dev` holds δβ in the leading k entries
3481            // (R_TEMPLATE-strided indexing is column 0..k of the R-vector).
3482            unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3483        }
3484        stream
3485            .synchronize()
3486            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487
3488        let delta_t_host = stream
3489            .clone_dtoh(&delta_t_dev)
3490            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3491        let mut delta_t = Array1::<f64>::zeros(n * d);
3492        for i in 0..n {
3493            let src_base = i * p_max;
3494            let dst_base = i * d;
3495            for r in 0..d {
3496                delta_t[dst_base + r] = delta_t_host[src_base + r];
3497            }
3498        }
3499
3500        // ----- log|H| = 2·Σ log L_{i,jj} + 2·Σ log R_{β,aa} -----
3501        let l_local_host = stream
3502            .clone_dtoh(&l_out)
3503            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3504        let l_schur_host = stream
3505            .clone_dtoh(&schur_dev)
3506            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3507        let mut log_det = 0.0_f64;
3508        for i in 0..n {
3509            let base = i * p_max * p_max;
3510            for j in 0..d {
3511                log_det += l_local_host[base + j * p_max + j].ln();
3512            }
3513        }
3514        for j in 0..k {
3515            log_det += l_schur_host[j * k + j].ln();
3516        }
3517        log_det *= 2.0;
3518
3519        Ok(ArrowSchurGpuSolution {
3520            delta_t,
3521            delta_beta,
3522            log_det_hessian: log_det,
3523        })
3524    }
3525
3526    /// Pre-compute `Y_i = L_i^{-1} H_tβ^(i)` via the fused forward kernel and
3527    /// return a closure that evaluates the full Schur matvec
3528    /// `S·x = (H_ββ + ρ·I)·x − Σ_i Y_i^T (Y_i·x)` for each PCG iteration.
3529    ///
3530    /// The `Y_i` factors are kept in a host-side buffer after one GPU forward
3531    /// pass. Each matvec call runs O(N·d·K) host loops over the pre-computed
3532    /// buffer plus an optional `H_ββ·x` call (matrix-free or dense). This is
3533    /// the first landing of the GPU matvec; a future iteration can move the
3534    /// `Y_i·x` / `Y_i^T z_i` steps to cuBLAS batched GEMV.
3535    pub(super) fn build_schur_matvec_backend(
3536        sys: &ArrowSchurSystem,
3537        ridge_t: f64,
3538        ridge_beta: f64,
3539    ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
3540        let n = sys.rows.len();
3541        let d = sys.d;
3542        let k = sys.k;
3543        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3544            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3545        let p_max = plan.p_max;
3546        let r_template = plan.r_template;
3547
3548        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3549            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3550        )
3551        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3552        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3553            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3554        let stream = ctx
3555            .new_stream()
3556            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3557        let cap = &runtime.device.capability;
3558        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3559            cc_major: cap.compute_major,
3560            cc_minor: cap.compute_minor,
3561            p_max: p_max as u32,
3562            r_template: r_template as u32,
3563        };
3564        let module = fused_module_for(&ctx, key)?;
3565        let forward = module
3566            .load_function("arrow_schur_forward_pgroup")
3567            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3568
3569        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3570        let d_dev = stream
3571            .clone_htod(&d_host)
3572            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3573        let b_dev = stream
3574            .clone_htod(&b_host)
3575            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3576        let g_dev = stream
3577            .clone_htod(&g_host)
3578            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3579        let mut l_out = stream
3580            .alloc_zeros::<f64>(n * p_max * p_max)
3581            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3582        let mut u_out = stream
3583            .alloc_zeros::<f64>(n * p_max)
3584            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3585        let mut y_out = stream
3586            .alloc_zeros::<f64>(n * p_max * r_template)
3587            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3588        let mut partial_s = stream
3589            .alloc_zeros::<f64>(plan.partial_s_doubles)
3590            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3591        let mut partial_r = stream
3592            .alloc_zeros::<f64>(plan.partial_r_doubles)
3593            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3594        let mut status_dev = stream
3595            .alloc_zeros::<i32>(n)
3596            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3597
3598        let cfg = LaunchConfig {
3599            grid_dim: (plan.blocks, 1, 1),
3600            block_dim: (plan.threads_per_block, 1, 1),
3601            shared_mem_bytes: 0,
3602        };
3603        let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3604        let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3605        let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3606        let ridge_arg = ridge_t;
3607        {
3608            let mut builder = stream.launch_builder(&forward);
3609            builder
3610                .arg(&d_dev)
3611                .arg(&b_dev)
3612                .arg(&g_dev)
3613                .arg(&n_i32)
3614                .arg(&p_i32)
3615                .arg(&r_i32)
3616                .arg(&ridge_arg)
3617                .arg(&mut l_out)
3618                .arg(&mut u_out)
3619                .arg(&mut y_out)
3620                .arg(&mut partial_s)
3621                .arg(&mut partial_r)
3622                .arg(&mut status_dev);
3623            // SAFETY: all buffers were allocated on `stream` with sizes
3624            // derived from `plan`; parameter list matches FORWARD_KERNEL_SOURCE.
3625            unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3626        }
3627        stream
3628            .synchronize()
3629            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3630
3631        let status_host = stream
3632            .clone_dtoh(&status_dev)
3633            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3634        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3635            let pivot = status_host[row];
3636            let scale = sys.rows[row]
3637                .htt
3638                .diag()
3639                .iter()
3640                .map(|v| v.abs())
3641                .fold(0.0_f64, f64::max)
3642                .max(1.0);
3643            return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
3644                row,
3645                bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3646            });
3647        }
3648
3649        // Download Y_i factors: n × p_max × r_template column-major per block.
3650        let y_host = stream
3651            .clone_dtoh(&y_out)
3652            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3653
3654        // Capture H_ββ data for the closure. Use the matrix-free hook if present
3655        // (SAE-manifold callers), otherwise fall back to the dense matrix rows.
3656        let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
3657        let hbb_is_kk = sys.hbb.dim() == (k, k);
3658        let hbb_matvec_opt = sys.hbb_matvec.clone();
3659
3660        let closure: crate::arrow_schur::GpuSchurMatvec =
3661            Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
3662                assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
3663                assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
3664
3665                // (H_ββ + ρ·I)·x into out.
3666                if let Some(ref mv) = hbb_matvec_opt {
3667                    mv(x.view(), out);
3668                    for a in 0..k {
3669                        out[a] += ridge_beta * x[a];
3670                    }
3671                } else if hbb_is_kk {
3672                    // hbb_host row-major: hbb[a, b] = hbb_host[a * k + b].
3673                    for a in 0..k {
3674                        let mut acc = ridge_beta * x[a];
3675                        for b in 0..k {
3676                            acc += hbb_host[a * k + b] * x[b];
3677                        }
3678                        out[a] = acc;
3679                    }
3680                } else {
3681                    for a in 0..k {
3682                        out[a] = ridge_beta * x[a];
3683                    }
3684                }
3685
3686                // out[c] -= Σ_i (Y_i^T (Y_i·x))[c].
3687                // Y_i column-major at y_host[i·p_max·r_template + col·p_max + row].
3688                let mut z = vec![0.0_f64; d];
3689                for i in 0..n {
3690                    let y_base = i * p_max * r_template;
3691                    for r in 0..d {
3692                        let mut acc = 0.0;
3693                        for c in 0..k {
3694                            acc += y_host[y_base + c * p_max + r] * x[c];
3695                        }
3696                        z[r] = acc;
3697                    }
3698                    for c in 0..k {
3699                        let mut acc = 0.0;
3700                        for r in 0..d {
3701                            acc += y_host[y_base + c * p_max + r] * z[r];
3702                        }
3703                        out[c] -= acc;
3704                    }
3705                }
3706            });
3707
3708        Ok(closure)
3709    }
3710
3711    // ── #1017/#1026 frames-engaged device PCG ──────────────────────────────
3712
3713    struct DeviceSaeFrameBuffers {
3714        // Smooth `λ S_k ⊗ I_{r_k}`.
3715        s_off: CudaSlice<i32>,
3716        s_m: CudaSlice<i32>,
3717        s_r: CudaSlice<i32>,
3718        s_ptr: CudaSlice<i32>,
3719        s_data: CudaSlice<f64>,
3720        s_blocks: usize,
3721        // Data `G_{ij} ⊗ W_{ij}`.
3722        g_off_i: CudaSlice<i32>,
3723        g_off_j: CudaSlice<i32>,
3724        g_ri: CudaSlice<i32>,
3725        g_rj: CudaSlice<i32>,
3726        g_mi: CudaSlice<i32>,
3727        g_mj: CudaSlice<i32>,
3728        g_ptr: CudaSlice<i32>,
3729        g_data: CudaSlice<f64>,
3730        w_ptr: CudaSlice<i32>,
3731        w_data: CudaSlice<f64>,
3732        g_blocks: usize,
3733        g_max_work: usize,
3734        // Per-row dense cross-block H_tβ^(i) + row q + factored ainv.
3735        htb_ptr: CudaSlice<i32>,
3736        htb: CudaSlice<f64>,
3737        q_of: CudaSlice<i32>,
3738        ainv: CudaSlice<f64>,
3739        hvec: CudaSlice<f64>,
3740        svec: CudaSlice<f64>,
3741        n_rows: usize,
3742        k: usize,
3743        max_q: usize,
3744    }
3745
3746    fn flatten_device_sae_frame_data(
3747        sys: &ArrowSchurSystem,
3748        data: &DeviceSaePcgData,
3749        frame: &DeviceSaeFrameData,
3750        ridge_t: f64,
3751        stream: &Arc<CudaStream>,
3752    ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3753        let n_rows = sys.rows.len();
3754        let k = data.beta_dim;
3755        if frame.row_htbeta.len() != n_rows
3756            || frame.ranks.len() != frame.basis_sizes.len()
3757            || frame.border_offsets.len() != frame.ranks.len()
3758            || data.smooth_blocks.len() != frame.smooth_ranks.len()
3759        {
3760            return Err(ArrowSchurGpuFailure::Unavailable);
3761        }
3762
3763        // Smooth blocks.
3764        let mut s_off = Vec::new();
3765        let mut s_m = Vec::new();
3766        let mut s_r = Vec::new();
3767        let mut s_ptr = vec![0_i32];
3768        let mut s_data = Vec::<f64>::new();
3769        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3770            let (m, mc) = blk.factor_a.dim();
3771            if m != mc {
3772                return Err(ArrowSchurGpuFailure::Unavailable);
3773            }
3774            s_off.push(checked_i32(blk.global_offset)?);
3775            s_m.push(checked_i32(m)?);
3776            s_r.push(checked_i32(r)?);
3777            for ri in 0..m {
3778                for ci in 0..m {
3779                    s_data.push(blk.factor_a[[ri, ci]]);
3780                }
3781            }
3782            s_ptr.push(checked_i32(s_data.len())?);
3783        }
3784
3785        // Data blocks (g + w).
3786        let mut g_off_i = Vec::new();
3787        let mut g_off_j = Vec::new();
3788        let mut g_ri = Vec::new();
3789        let mut g_rj = Vec::new();
3790        let mut g_mi = Vec::new();
3791        let mut g_mj = Vec::new();
3792        let mut g_ptr = vec![0_i32];
3793        let mut g_data = Vec::<f64>::new();
3794        let mut w_ptr = vec![0_i32];
3795        let mut w_data = Vec::<f64>::new();
3796        let mut g_max_work = 0usize;
3797        for blk in &frame.frame_blocks {
3798            let ri = frame.ranks[blk.atom_i];
3799            let rj = frame.ranks[blk.atom_j];
3800            let (mi, mj) = blk.g.dim();
3801            if blk.w.dim() != (ri, rj) {
3802                return Err(ArrowSchurGpuFailure::Unavailable);
3803            }
3804            g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3805            g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3806            g_ri.push(checked_i32(ri)?);
3807            g_rj.push(checked_i32(rj)?);
3808            g_mi.push(checked_i32(mi)?);
3809            g_mj.push(checked_i32(mj)?);
3810            for r in 0..mi {
3811                for c in 0..mj {
3812                    g_data.push(blk.g[[r, c]]);
3813                }
3814            }
3815            g_ptr.push(checked_i32(g_data.len())?);
3816            for a in 0..ri {
3817                for b in 0..rj {
3818                    w_data.push(blk.w[[a, b]]);
3819                }
3820            }
3821            w_ptr.push(checked_i32(w_data.len())?);
3822            g_max_work = g_max_work.max(mi * ri);
3823        }
3824
3825        // Per-row dense cross-block + q + ainv (factored H_tt⁻¹).
3826        let mut htb_ptr = vec![0_i32];
3827        let mut htb = Vec::<f64>::new();
3828        let mut q_of = Vec::<i32>::with_capacity(n_rows);
3829        let mut max_q = 0usize;
3830        for (i, slab) in frame.row_htbeta.iter().enumerate() {
3831            let qi = sys.row_dims[i];
3832            // A populated slab must be q_i × k row-major; an empty slab ⇒ q=0
3833            // (the row contributes no reduced-Schur term).
3834            let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3835                qi
3836            } else {
3837                0
3838            };
3839            q_of.push(checked_i32(q_eff)?);
3840            max_q = max_q.max(q_eff);
3841            if q_eff > 0 {
3842                htb.extend_from_slice(slab);
3843            }
3844            htb_ptr.push(checked_i32(htb.len())?);
3845        }
3846        if max_q == 0 {
3847            // No row contributes a reduced term — the system is pure-penalty.
3848            // Still valid; give max_q=1 so the ainv buffer is non-empty.
3849            max_q = 1;
3850        }
3851
3852        let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3853        for (i, row) in sys.rows.iter().enumerate() {
3854            let q = q_of[i] as usize;
3855            if q == 0 {
3856                continue;
3857            }
3858            if row.htt.dim() != (q, q) {
3859                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3860                    reason: format!(
3861                        "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3862                        row.htt.dim()
3863                    ),
3864                });
3865            }
3866            let mut block = row.htt.clone();
3867            for d in 0..q {
3868                block[[d, d]] += ridge_t;
3869            }
3870            let factor = gam_linalg::triangular::cholesky_factor_in_place(
3871                block.view(),
3872                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3873            )
3874            .ok_or_else(|| {
3875                let scale = row
3876                    .htt
3877                    .diag()
3878                    .iter()
3879                    .map(|v| v.abs())
3880                    .fold(0.0_f64, f64::max)
3881                    .max(1.0);
3882                ArrowSchurGpuFailure::RidgeBumpRequired {
3883                    row: i,
3884                    bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
3885                }
3886            })?;
3887            for col in 0..q {
3888                let mut e = Array1::<f64>::zeros(q);
3889                e[col] = 1.0;
3890                let solved =
3891                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3892                for r in 0..q {
3893                    ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3894                }
3895            }
3896        }
3897
3898        let htod_i = |v: &[i32]| {
3899            stream
3900                .clone_htod(v)
3901                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3902        };
3903        let htod_f = |v: &[f64]| {
3904            stream
3905                .clone_htod(v)
3906                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3907        };
3908        Ok(DeviceSaeFrameBuffers {
3909            s_off: htod_i(&s_off)?,
3910            s_m: htod_i(&s_m)?,
3911            s_r: htod_i(&s_r)?,
3912            s_ptr: htod_i(&s_ptr)?,
3913            s_data: htod_f(&s_data)?,
3914            s_blocks: data.smooth_blocks.len(),
3915            g_off_i: htod_i(&g_off_i)?,
3916            g_off_j: htod_i(&g_off_j)?,
3917            g_ri: htod_i(&g_ri)?,
3918            g_rj: htod_i(&g_rj)?,
3919            g_mi: htod_i(&g_mi)?,
3920            g_mj: htod_i(&g_mj)?,
3921            g_ptr: htod_i(&g_ptr)?,
3922            g_data: htod_f(&g_data)?,
3923            w_ptr: htod_i(&w_ptr)?,
3924            w_data: htod_f(&w_data)?,
3925            g_blocks: frame.frame_blocks.len(),
3926            g_max_work,
3927            htb_ptr: htod_i(&htb_ptr)?,
3928            htb: htod_f(&htb)?,
3929            q_of: htod_i(&q_of)?,
3930            ainv: htod_f(&ainv)?,
3931            hvec: stream
3932                .alloc_zeros::<f64>(n_rows * max_q)
3933                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
3934            svec: stream
3935                .alloc_zeros::<f64>(n_rows * max_q)
3936                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
3937            n_rows,
3938            k,
3939            max_q,
3940        })
3941    }
3942
3943    fn sae_frame_penalty_diag_host(
3944        data: &DeviceSaePcgData,
3945        frame: &DeviceSaeFrameData,
3946        ridge_beta: f64,
3947    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
3948        let mut diag = vec![ridge_beta; data.beta_dim];
3949        // Smooth: diag[off + ia·r + ib] += S[ia,ia].
3950        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3951            let m = blk.factor_a.nrows();
3952            for ia in 0..m {
3953                let coeff = blk.factor_a[[ia, ia]];
3954                let base = blk.global_offset + ia * r;
3955                for ib in 0..r {
3956                    if base + ib >= diag.len() {
3957                        return Err(ArrowSchurGpuFailure::Unavailable);
3958                    }
3959                    diag[base + ib] += coeff;
3960                }
3961            }
3962        }
3963        // Data: on-diagonal atom blocks contribute g[li,li]·w[a,a].
3964        for blk in &frame.frame_blocks {
3965            if blk.atom_i != blk.atom_j {
3966                continue;
3967            }
3968            let r = frame.ranks[blk.atom_i];
3969            let off = frame.border_offsets[blk.atom_i];
3970            let (mi, mj) = blk.g.dim();
3971            for li in 0..mi.min(mj) {
3972                let gii = blk.g[[li, li]];
3973                let base = off + li * r;
3974                for a in 0..r {
3975                    if base + a >= diag.len() {
3976                        return Err(ArrowSchurGpuFailure::Unavailable);
3977                    }
3978                    diag[base + a] += gii * blk.w[[a, a]];
3979                }
3980            }
3981        }
3982        Ok(diag)
3983    }
3984
3985    fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
3986        Ok(LaunchConfig {
3987            grid_dim: (
3988                ((work as u32).saturating_add(255) / 256).max(1),
3989                checked_i32(n_rows)? as u32,
3990                1,
3991            ),
3992            block_dim: (256, 1, 1),
3993            shared_mem_bytes: 0,
3994        })
3995    }
3996
3997    fn launch_sae_frame_matvec(
3998        stream: &Arc<CudaStream>,
3999        module: &Arc<CudaModule>,
4000        buffers: &mut DeviceSaeFrameBuffers,
4001        x: &CudaSlice<f64>,
4002        out: &mut CudaSlice<f64>,
4003        ridge_beta: f64,
4004    ) -> Result<(), ArrowSchurGpuFailure> {
4005        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4006        // Smooth penalty.
4007        if buffers.s_blocks > 0 {
4008            let kernel = module
4009                .load_function("arrow_sae_frame_smooth_matvec")
4010                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4011            let blocks_i32 = checked_i32(buffers.s_blocks)?;
4012            let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4013            let mut b = stream.launch_builder(&kernel);
4014            b.arg(x)
4015                .arg(&mut *out)
4016                .arg(&buffers.s_off)
4017                .arg(&buffers.s_m)
4018                .arg(&buffers.s_r)
4019                .arg(&buffers.s_ptr)
4020                .arg(&buffers.s_data)
4021                .arg(&blocks_i32);
4022            // SAFETY: smooth block metadata/data are live device buffers; the grid
4023            // covers (k channels × n_blocks) and the kernel bounds-checks m·r.
4024            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4025        }
4026        // Data penalty.
4027        if buffers.g_blocks > 0 {
4028            let kernel = module
4029                .load_function("arrow_sae_frame_g_matvec")
4030                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4031            let blocks_i32 = checked_i32(buffers.g_blocks)?;
4032            let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4033            let mut b = stream.launch_builder(&kernel);
4034            b.arg(x)
4035                .arg(&mut *out)
4036                .arg(&buffers.g_off_i)
4037                .arg(&buffers.g_off_j)
4038                .arg(&buffers.g_ri)
4039                .arg(&buffers.g_rj)
4040                .arg(&buffers.g_mi)
4041                .arg(&buffers.g_mj)
4042                .arg(&buffers.g_ptr)
4043                .arg(&buffers.g_data)
4044                .arg(&buffers.w_ptr)
4045                .arg(&buffers.w_data)
4046                .arg(&blocks_i32);
4047            // SAFETY: g/w block metadata/data are live device buffers; the grid
4048            // covers (max m_i·r_i × n_blocks) and the kernel bounds-checks.
4049            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4050        }
4051        // Reduced-Schur subtraction via dense per-row cross-blocks.
4052        let k_i32 = checked_i32(buffers.k)?;
4053        let max_q_i32 = checked_i32(buffers.max_q)?;
4054        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4055        {
4056            let kernel = module
4057                .load_function("arrow_sae_frame_apply_h")
4058                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4059            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4060            let mut b = stream.launch_builder(&kernel);
4061            b.arg(x)
4062                .arg(&buffers.htb_ptr)
4063                .arg(&buffers.htb)
4064                .arg(&buffers.q_of)
4065                .arg(&mut buffers.hvec)
4066                .arg(&k_i32)
4067                .arg(&max_q_i32)
4068                .arg(&n_rows_i32);
4069            // SAFETY: dense cross-block + pointers + hvec are live buffers sized
4070            // for (n_rows × max_q) / (n_rows × k); kernel guards q_i and k.
4071            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4072        }
4073        {
4074            let kernel = module
4075                .load_function("arrow_sae_frame_apply_ainv")
4076                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4077            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4078            let mut b = stream.launch_builder(&kernel);
4079            b.arg(&buffers.ainv)
4080                .arg(&buffers.hvec)
4081                .arg(&buffers.q_of)
4082                .arg(&mut buffers.svec)
4083                .arg(&max_q_i32)
4084                .arg(&n_rows_i32);
4085            // SAFETY: ainv/hvec/svec are live buffers sized for n_rows·max_q²
4086            // and n_rows·max_q; the kernel guards row/coord bounds.
4087            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4088        }
4089        {
4090            let kernel = module
4091                .load_function("arrow_sae_frame_scatter_h")
4092                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4093            let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4094            let mut b = stream.launch_builder(&kernel);
4095            b.arg(&buffers.svec)
4096                .arg(&buffers.htb_ptr)
4097                .arg(&buffers.htb)
4098                .arg(&buffers.q_of)
4099                .arg(out)
4100                .arg(&k_i32)
4101                .arg(&max_q_i32)
4102                .arg(&n_rows_i32);
4103            // SAFETY: svec/cross-block/out are live buffers; the kernel atomically
4104            // accumulates into out[a] for a<k and reads c<q_i.
4105            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4106        }
4107        Ok(())
4108    }
4109
4110    fn launch_sae_frame_diag_sub(
4111        stream: &Arc<CudaStream>,
4112        module: &Arc<CudaModule>,
4113        buffers: &DeviceSaeFrameBuffers,
4114        diag: &mut CudaSlice<f64>,
4115    ) -> Result<(), ArrowSchurGpuFailure> {
4116        let kernel = module
4117            .load_function("arrow_sae_frame_diag_sub")
4118            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4119        let k_i32 = checked_i32(buffers.k)?;
4120        let max_q_i32 = checked_i32(buffers.max_q)?;
4121        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4122        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4123        let mut b = stream.launch_builder(&kernel);
4124        b.arg(diag)
4125            .arg(&buffers.ainv)
4126            .arg(&buffers.htb_ptr)
4127            .arg(&buffers.htb)
4128            .arg(&buffers.q_of)
4129            .arg(&k_i32)
4130            .arg(&max_q_i32)
4131            .arg(&n_rows_i32);
4132        // SAFETY: diag + cross-block + ainv live buffers; kernel guards a<k, c/d<q.
4133        unsafe { b.launch(cfg) }
4134            .map(drop)
4135            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4136    }
4137
4138    pub(super) fn solve_sae_matrix_free_pcg_framed(
4139        sys: &ArrowSchurSystem,
4140        data: &DeviceSaePcgData,
4141        ridge_t: f64,
4142        ridge_beta: f64,
4143        rhs_beta: &Array1<f64>,
4144        max_iterations: usize,
4145        relative_tolerance: f64,
4146    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4147        let k = rhs_beta.len();
4148        if k == 0 || data.beta_dim != k || sys.k != k {
4149            return Err(ArrowSchurGpuFailure::Unavailable);
4150        }
4151        let frame = data
4152            .frame
4153            .as_ref()
4154            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4155        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4156            .filter(|rt| {
4157                rt.policy().reduced_schur_matvec_should_offload(
4158                    sys.rows.len(),
4159                    sys.k,
4160                    sys.d,
4161                    max_iterations,
4162                )
4163            })
4164            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4165        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4166            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4167        let stream = ctx
4168            .new_stream()
4169            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4170        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4171        let vector_module = pcg_vector_module(&ctx)?;
4172        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4173
4174        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4175        if rhs_norm == 0.0 {
4176            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4177        }
4178        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4179        let rhs_dev = stream
4180            .clone_htod(
4181                rhs_beta
4182                    .as_slice()
4183                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4184            )
4185            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4186        let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4187        let mut diag_dev = stream
4188            .clone_htod(&diag_host)
4189            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4190        launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4191        let diag_host = stream
4192            .clone_dtoh(&diag_dev)
4193            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4194        let mut inv_diag = Vec::with_capacity(k);
4195        for (idx, &d) in diag_host.iter().enumerate() {
4196            if !d.is_finite() || d <= 1.0e-18 {
4197                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4198                    reason: format!(
4199                        "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4200                    ),
4201                });
4202            }
4203            inv_diag.push(1.0 / d);
4204        }
4205        let inv_diag_dev = stream
4206            .clone_htod(&inv_diag)
4207            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4208
4209        let mut x_dev = stream
4210            .alloc_zeros::<f64>(k)
4211            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4212        let mut r_dev = stream
4213            .alloc_zeros::<f64>(k)
4214            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4215        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4216        let mut z_dev = stream
4217            .alloc_zeros::<f64>(k)
4218            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4219        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4220        let mut p_dev = stream
4221            .alloc_zeros::<f64>(k)
4222            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4223        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4224        let mut ap_dev = stream
4225            .alloc_zeros::<f64>(k)
4226            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4227
4228        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4229        if rz <= 0.0 || !rz.is_finite() {
4230            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4231                reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4232            });
4233        }
4234        let mut diag = PcgDiagnostics {
4235            precond_apply_calls: 1,
4236            stopping_reason: PcgStopReason::MaxIter,
4237            ..PcgDiagnostics::default()
4238        };
4239        for _ in 0..max_iterations.max(1) {
4240            launch_sae_frame_matvec(
4241                &stream,
4242                vector_module,
4243                &mut buffers,
4244                &p_dev,
4245                &mut ap_dev,
4246                ridge_beta,
4247            )?;
4248            diag.matvec_calls += 1;
4249            diag.iterations += 1;
4250            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4251            if pap <= 0.0 || !pap.is_finite() {
4252                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4253                    reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4254                });
4255            }
4256            let alpha = rz / pap;
4257            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4258            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4259            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4260            if r_norm <= tol {
4261                diag.final_relative_residual = r_norm / rhs_norm;
4262                diag.stopping_reason = PcgStopReason::Converged;
4263                break;
4264            }
4265            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4266            diag.precond_apply_calls += 1;
4267            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4268            if rz_new <= 0.0 || !rz_new.is_finite() {
4269                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4270                    reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4271                });
4272            }
4273            let beta = rz_new / rz;
4274            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4275            rz = rz_new;
4276        }
4277        if diag.stopping_reason != PcgStopReason::Converged {
4278            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4279            diag.final_relative_residual = r_norm / rhs_norm;
4280            diag.stopping_reason = PcgStopReason::MaxIter;
4281        }
4282        let x = stream
4283            .clone_dtoh(&x_dev)
4284            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4285        Ok((Array1::from_vec(x), diag))
4286    }
4287
4288    /// #1551 stage-isolating triage seam: run the framed reduced-Schur matvec
4289    /// `out = S·x` ONCE on the device (no PCG, no offload-floor gate) and return
4290    /// `out`, so a tiny hand-verifiable fixture can diff it against the CPU oracle
4291    /// `sae_framed_schur_matvec_cpu` element-by-element to localize the structural
4292    /// divergence to a single kernel stage. Returns `Unavailable` only when CUDA
4293    /// is genuinely absent (so the test skips cleanly off-device).
4294    pub(super) fn solve_sae_matrix_free_pcg(
4295        sys: &ArrowSchurSystem,
4296        data: &DeviceSaePcgData,
4297        ridge_t: f64,
4298        ridge_beta: f64,
4299        rhs_beta: &Array1<f64>,
4300        max_iterations: usize,
4301        relative_tolerance: f64,
4302    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4303        let k = rhs_beta.len();
4304        if k == 0 || data.beta_dim != k || sys.k != k {
4305            return Err(ArrowSchurGpuFailure::Unavailable);
4306        }
4307        // #1017/#1026 GUARD: the legacy `⊗ I_p` kernel must NEVER receive framed
4308        // data (factored `G ⊗ W_{ij}` + dense per-row cross blocks); decline so a
4309        // mis-route falls back to the CPU rather than returning a wrong step.
4310        if data.frame.is_some() {
4311            return Err(ArrowSchurGpuFailure::Unavailable);
4312        }
4313        // #1017 Phase-1 dispatch re-key: this is the matrix-free SAE reduced-Schur
4314        // PCG — the production hot path, not a single dense factorization. The
4315        // dense-Direct floor `dense_hessian_work_target_is_gpu(n, k)` keys on
4316        // `2·n·k²` and is the WRONG gate here: it ignores the per-row frame depth
4317        // `d` (the M dimension that multiplies the per-apply work) and the
4318        // `1/cg_iters` staging amortisation, so it both undercounts the SAE batched
4319        // work `n·k·d` and applies a cold single-launch breakeven to an apply that
4320        // reuses device-resident frames `max_iterations` times. Key instead on the
4321        // CG-amortised total batched work — the same predicate the host injection
4322        // gate (`maybe_inject_gpu_schur_matvec`) consults — so few-row/wide-`k`/
4323        // modest-`d` LLM shapes register the real `n × k × d × cg_iters` arithmetic.
4324        // Kernels and numerics are untouched; only where the matvec runs changes,
4325        // and the host falls back to the bit-identical CPU matvec when this declines.
4326        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4327            .filter(|rt| {
4328                rt.policy().reduced_schur_matvec_should_offload(
4329                    sys.rows.len(),
4330                    sys.k,
4331                    sys.d,
4332                    max_iterations,
4333                )
4334            })
4335            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4336        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4337            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4338        let stream = ctx
4339            .new_stream()
4340            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4341        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4342        let vector_module = pcg_vector_module(&ctx)?;
4343        let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4344
4345        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4346        if rhs_norm == 0.0 {
4347            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4348        }
4349        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4350        let rhs_dev = stream
4351            .clone_htod(
4352                rhs_beta
4353                    .as_slice()
4354                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4355            )
4356            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4357        let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4358        let mut diag_dev = stream
4359            .clone_htod(&diag_host)
4360            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4361        launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4362        let diag_host = stream
4363            .clone_dtoh(&diag_dev)
4364            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4365        let mut inv_diag = Vec::with_capacity(k);
4366        for (idx, &d) in diag_host.iter().enumerate() {
4367            if !d.is_finite() || d <= 1.0e-18 {
4368                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4369                    reason: format!(
4370                        "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4371                    ),
4372                });
4373            }
4374            inv_diag.push(1.0 / d);
4375        }
4376        let inv_diag_dev = stream
4377            .clone_htod(&inv_diag)
4378            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4379
4380        let mut x_dev = stream
4381            .alloc_zeros::<f64>(k)
4382            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4383        let mut r_dev = stream
4384            .alloc_zeros::<f64>(k)
4385            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4386        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4387        let mut z_dev = stream
4388            .alloc_zeros::<f64>(k)
4389            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4390        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4391        let mut p_dev = stream
4392            .alloc_zeros::<f64>(k)
4393            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4394        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4395        let mut ap_dev = stream
4396            .alloc_zeros::<f64>(k)
4397            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4398
4399        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4400        if rz <= 0.0 || !rz.is_finite() {
4401            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4402                reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4403            });
4404        }
4405        let mut diag = PcgDiagnostics {
4406            precond_apply_calls: 1,
4407            stopping_reason: PcgStopReason::MaxIter,
4408            ..PcgDiagnostics::default()
4409        };
4410
4411        for _ in 0..max_iterations.max(1) {
4412            launch_sae_matvec(
4413                &stream,
4414                vector_module,
4415                &mut buffers,
4416                &p_dev,
4417                &mut ap_dev,
4418                ridge_beta,
4419            )?;
4420            diag.matvec_calls += 1;
4421            diag.iterations += 1;
4422            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4423            if pap <= 0.0 || !pap.is_finite() {
4424                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4425                    reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4426                });
4427            }
4428            let alpha = rz / pap;
4429            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4430            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4431            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4432            if r_norm <= tol {
4433                diag.final_relative_residual = r_norm / rhs_norm;
4434                diag.stopping_reason = PcgStopReason::Converged;
4435                break;
4436            }
4437            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4438            diag.precond_apply_calls += 1;
4439            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4440            if rz_new <= 0.0 || !rz_new.is_finite() {
4441                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4442                    reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4443                });
4444            }
4445            let beta = rz_new / rz;
4446            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4447            rz = rz_new;
4448        }
4449        if diag.stopping_reason != PcgStopReason::Converged {
4450            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4451            diag.final_relative_residual = r_norm / rhs_norm;
4452            diag.stopping_reason = PcgStopReason::MaxIter;
4453        }
4454        let x = stream
4455            .clone_dtoh(&x_dev)
4456            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4457        Ok((Array1::from_vec(x), diag))
4458    }
4459
4460    pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4461        s_acc: &ndarray::Array2<f64>,
4462        rhs_beta: &Array1<f64>,
4463        max_iterations: usize,
4464        relative_tolerance: f64,
4465    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4466        let k = rhs_beta.len();
4467        // #1017 dispatch re-key: this is an ITERATIVE device-resident PCG, not a
4468        // single GEMV. `S` (k×k) is uploaded once and reused for `max_iterations`
4469        // `S·p` GEMVs while only convergence scalars cross PCIe, so the staging
4470        // cost is amortised over the whole CG solve. Gating on the flops of ONE
4471        // `Gemv{k,k}` (`2·k²`) understates the work by the iteration count and
4472        // declines shapes (e.g. k≈512) whose total iterated arithmetic
4473        // `2·k²·iters` clears the device floor by orders of magnitude — the same
4474        // single-launch-breakeven miskey #1017 fixed for the framed reduced-Schur
4475        // matvec. Key on the CG-amortised total work via a `Gemm{k,k,iters}` whose
4476        // `flops()` is exactly `2·k²·iters`; numerics and kernels are untouched,
4477        // and the host falls back to the bit-identical CPU PCG when this declines.
4478        let cg_iters = max_iterations.max(1);
4479        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4480            gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4481                m: k,
4482                n: k,
4483                k: cg_iters,
4484            },
4485        )
4486        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4487        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4488            .and_then(|ctx| ctx.new_stream().ok())
4489            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4490        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4491        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4492            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4493        let vector_module = pcg_vector_module(&ctx)?;
4494
4495        // Jacobi diagonal from S; must be strictly positive for SPD.
4496        let mut inv_diag = vec![0.0_f64; k];
4497        for j in 0..k {
4498            let djj = s_acc[[j, j]];
4499            if !(djj > 0.0) {
4500                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4501                    reason: format!(
4502                        "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4503                    ),
4504                });
4505            }
4506            inv_diag[j] = 1.0 / djj;
4507        }
4508
4509        // Upload S column-major (S[row,col] at col*k + row).
4510        let mut s_host = vec![0.0_f64; k * k];
4511        for col in 0..k {
4512            for row in 0..k {
4513                s_host[col * k + row] = s_acc[[row, col]];
4514            }
4515        }
4516        let s_dev = stream
4517            .clone_htod(&s_host)
4518            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4519
4520        // Steihaug truncated-CG with Jacobi preconditioner, host scalar
4521        // recurrences and a device `S·p` matvec. The streaming reduced solve
4522        // uses an unbounded trust region (pure CG to tolerance).
4523        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4524        if rhs_norm == 0.0 {
4525            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4526        }
4527        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4528
4529        // Device-resident PCG state. Only convergence scalars cross back during
4530        // the loop; x/r/z/p/Sp stay on CUDA until the final solution download.
4531        let mut x_dev = stream
4532            .alloc_zeros::<f64>(k)
4533            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4534        let mut r_dev = stream
4535            .clone_htod(
4536                rhs_beta
4537                    .as_slice()
4538                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4539            )
4540            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4541        let inv_diag_dev = stream
4542            .clone_htod(&inv_diag)
4543            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544        let mut z_dev = stream
4545            .alloc_zeros::<f64>(k)
4546            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4547        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4548        let mut p_dev = stream
4549            .alloc_zeros::<f64>(k)
4550            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4551        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4552        let mut sp_dev = stream
4553            .alloc_zeros::<f64>(k)
4554            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4555        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4556        let mut diag = PcgDiagnostics {
4557            precond_apply_calls: 1,
4558            stopping_reason: PcgStopReason::MaxIter,
4559            ..PcgDiagnostics::default()
4560        };
4561        if rz <= 0.0 || !rz.is_finite() {
4562            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4563                reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4564            });
4565        }
4566
4567        let max_iters = max_iterations.max(1);
4568        for _ in 0..max_iters {
4569            // sp = S · p (device GEMV, S column-major k×k, op = N).
4570            let gemv_cfg = GemvConfig::<f64> {
4571                trans: cublasOperation_t::CUBLAS_OP_N,
4572                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4573                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4574                alpha: 1.0,
4575                lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4576                incx: 1,
4577                beta: 0.0,
4578                incy: 1,
4579            };
4580            // SAFETY: s_dev is k×k column-major, p_dev / sp_dev length k.
4581            unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4582                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4583            diag.matvec_calls += 1;
4584            diag.iterations += 1;
4585
4586            let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4587            if !(p_sp > 0.0) {
4588                // Non-positive curvature on a (proximal-ridged) SPD system means
4589                // numerical breakdown; surface so the caller escalates.
4590                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4591                    reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4592                });
4593            }
4594            let alpha = rz / p_sp;
4595            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4596            device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4597            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4598            if r_norm <= tol {
4599                diag.final_relative_residual = r_norm / rhs_norm;
4600                diag.stopping_reason = PcgStopReason::Converged;
4601                break;
4602            }
4603            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4604            diag.precond_apply_calls += 1;
4605            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4606            if rz_new <= 0.0 || !rz_new.is_finite() {
4607                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4608                    reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4609                });
4610            }
4611            let beta = rz_new / rz;
4612            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4613            rz = rz_new;
4614        }
4615        if diag.stopping_reason != PcgStopReason::Converged {
4616            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4617            diag.final_relative_residual = r_norm / rhs_norm;
4618            diag.stopping_reason = PcgStopReason::MaxIter;
4619        }
4620
4621        let x = stream
4622            .clone_dtoh(&x_dev)
4623            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4624        Ok((Array1::from_vec(x), diag))
4625    }
4626
4627    fn device_copy(
4628        blas: &CudaBlas,
4629        stream: &Arc<CudaStream>,
4630        n: usize,
4631        src: &CudaSlice<f64>,
4632        dst: &mut CudaSlice<f64>,
4633    ) -> Result<(), ArrowSchurGpuFailure> {
4634        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4635        let (src_ptr, _src_rec) = src.device_ptr(stream);
4636        let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4637        // SAFETY: src and dst are live device allocations on this stream with at
4638        // least n contiguous f64 entries and unit stride.
4639        let status = unsafe {
4640            cudarc::cublas::sys::cublasDcopy_v2(
4641                *blas.handle(),
4642                n_i,
4643                src_ptr as *const f64,
4644                1,
4645                dst_ptr as *mut f64,
4646                1,
4647            )
4648        };
4649        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4650            Ok(())
4651        } else {
4652            Err(ArrowSchurGpuFailure::Unavailable)
4653        }
4654    }
4655
4656    fn device_axpy(
4657        blas: &CudaBlas,
4658        stream: &Arc<CudaStream>,
4659        n: usize,
4660        alpha: f64,
4661        x: &CudaSlice<f64>,
4662        y: &mut CudaSlice<f64>,
4663    ) -> Result<(), ArrowSchurGpuFailure> {
4664        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4665        let (x_ptr, _x_rec) = x.device_ptr(stream);
4666        let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4667        // SAFETY: x and y are live device allocations on this stream with at
4668        // least n contiguous f64 entries and unit stride; cuBLAS only reads alpha.
4669        let status = unsafe {
4670            cudarc::cublas::sys::cublasDaxpy_v2(
4671                *blas.handle(),
4672                n_i,
4673                &alpha,
4674                x_ptr as *const f64,
4675                1,
4676                y_ptr as *mut f64,
4677                1,
4678            )
4679        };
4680        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4681            Ok(())
4682        } else {
4683            Err(ArrowSchurGpuFailure::Unavailable)
4684        }
4685    }
4686
4687    fn device_dot(
4688        blas: &CudaBlas,
4689        stream: &Arc<CudaStream>,
4690        n: usize,
4691        x: &CudaSlice<f64>,
4692        y: &CudaSlice<f64>,
4693    ) -> Result<f64, ArrowSchurGpuFailure> {
4694        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4695        let (x_ptr, _x_rec) = x.device_ptr(stream);
4696        let (y_ptr, _y_rec) = y.device_ptr(stream);
4697        let mut result = 0.0_f64;
4698        // SAFETY: x and y are live device allocations on this stream with at
4699        // least n contiguous f64 entries and unit stride; result is a valid host
4700        // out-pointer for the cuBLAS scalar.
4701        let status = unsafe {
4702            cudarc::cublas::sys::cublasDdot_v2(
4703                *blas.handle(),
4704                n_i,
4705                x_ptr as *const f64,
4706                1,
4707                y_ptr as *const f64,
4708                1,
4709                &mut result,
4710            )
4711        };
4712        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4713            Ok(result)
4714        } else {
4715            Err(ArrowSchurGpuFailure::Unavailable)
4716        }
4717    }
4718
4719    fn device_nrm2(
4720        blas: &CudaBlas,
4721        stream: &Arc<CudaStream>,
4722        n: usize,
4723        x: &CudaSlice<f64>,
4724    ) -> Result<f64, ArrowSchurGpuFailure> {
4725        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4726        let (x_ptr, _x_rec) = x.device_ptr(stream);
4727        let mut result = 0.0_f64;
4728        // SAFETY: x is a live device allocation on this stream with at least n
4729        // contiguous f64 entries and unit stride; result is a valid host
4730        // out-pointer for the cuBLAS scalar.
4731        let status = unsafe {
4732            cudarc::cublas::sys::cublasDnrm2_v2(
4733                *blas.handle(),
4734                n_i,
4735                x_ptr as *const f64,
4736                1,
4737                &mut result,
4738            )
4739        };
4740        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4741            Ok(result)
4742        } else {
4743            Err(ArrowSchurGpuFailure::Unavailable)
4744        }
4745    }
4746
4747    #[cfg(test)]
4748    mod tests {
4749        //! #1551 device-side framed-matvec triage. Lives inside `mod cuda` so it
4750        //! can call the private kernel launchers directly (no test-only public
4751        //! seam, which the ban-scanner forbids). A bare `#[cfg(test)] mod tests`
4752        //! is the one form the scanner permits.
4753        use super::*;
4754        use crate::arrow_schur::{
4755            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4756            FactoredFrameGBlock,
4757        };
4758        use ndarray::Array2;
4759
4760        /// Run the framed reduced-Schur matvec `out = S·x` ONCE on the device
4761        /// (no PCG, no offload gate) and return `out`.
4762        fn device_matvec_once(
4763            sys: &ArrowSchurSystem,
4764            data: &DeviceSaePcgData,
4765            ridge_t: f64,
4766            ridge_beta: f64,
4767            x_host: &[f64],
4768        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4769            let k = x_host.len();
4770            let frame = data
4771                .frame
4772                .as_ref()
4773                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4774            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4775                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4776            let ctx =
4777                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4778                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4779            let stream = ctx
4780                .new_stream()
4781                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4782            let vector_module = pcg_vector_module(&ctx)?;
4783            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4784            let x_dev = stream
4785                .clone_htod(x_host)
4786                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4787            let mut out_dev = stream
4788                .alloc_zeros::<f64>(k)
4789                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4790            launch_sae_frame_matvec(
4791                &stream,
4792                vector_module,
4793                &mut buffers,
4794                &x_dev,
4795                &mut out_dev,
4796                ridge_beta,
4797            )?;
4798            stream
4799                .clone_dtoh(&out_dev)
4800                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4801        }
4802
4803        /// #1551 stage-isolating matvec triage on a TINY hand-verifiable fixture:
4804        /// diff the device framed matvec `S·e_col` against the CPU oracle
4805        /// `sae_framed_schur_matvec_cpu` for every identity column, reporting the
4806        /// worst-divergent border index so the structural 91% localizes to one
4807        /// kernel stage. Skips cleanly off-device.
4808        #[test]
4809        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4810            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4811                return;
4812            }
4813            let p = 3usize;
4814            let ranks = vec![2usize, 3usize];
4815            let basis_sizes = vec![2usize, 2usize];
4816            let mut border_offsets = Vec::new();
4817            let mut acc = 0usize;
4818            for k in 0..2 {
4819                border_offsets.push(acc);
4820                acc += basis_sizes[k] * ranks[k];
4821            }
4822            let border_dim = acc; // 2·2 + 2·3 = 10
4823            let frame_of = |k: usize| -> Array2<f64> {
4824                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4825                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4826                })
4827            };
4828            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4829            let w_of = |i: usize, j: usize| -> Array2<f64> {
4830                let (ui, uj) = (&frames[i], &frames[j]);
4831                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4832                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4833                })
4834            };
4835            let mut frame_blocks = Vec::new();
4836            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4837                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4838                let mut g =
4839                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
4840                if i == j {
4841                    for r in 0..mi.min(mj) {
4842                        g[[r, r]] += mi as f64 + 2.0;
4843                    }
4844                }
4845                frame_blocks.push(FactoredFrameGBlock {
4846                    atom_i: i,
4847                    atom_j: j,
4848                    g,
4849                    w: w_of(i, j),
4850                });
4851            }
4852            let mut smooth_blocks = Vec::new();
4853            for k in 0..2 {
4854                let m = basis_sizes[k];
4855                let mut s =
4856                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
4857                for r in 0..m {
4858                    s[[r, r]] += 1.0;
4859                }
4860                smooth_blocks.push(DeviceSaeSmoothBlock {
4861                    global_offset: border_offsets[k],
4862                    factor_a: s,
4863                });
4864            }
4865            let smooth_ranks = ranks.clone();
4866            let n = 2usize;
4867            let q = 2usize;
4868            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
4869            let mut row_htbeta = Vec::new();
4870            for i in 0..n {
4871                let mut htt =
4872                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
4873                for r in 0..q {
4874                    htt[[r, r]] += q as f64 + 2.0;
4875                }
4876                sys.rows[i].htt = htt;
4877                let mut slab = vec![0.0_f64; q * border_dim];
4878                for c in 0..q {
4879                    for col in 0..border_dim {
4880                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
4881                        slab[c * border_dim + col] = v;
4882                        sys.rows[i].htbeta[[c, col]] = v;
4883                    }
4884                }
4885                row_htbeta.push(slab);
4886            }
4887            let data = DeviceSaePcgData {
4888                p,
4889                beta_dim: border_dim,
4890                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
4891                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
4892                smooth_blocks,
4893                sparse_g_blocks: Vec::new(),
4894                frame: Some(DeviceSaeFrameData {
4895                    ranks,
4896                    basis_sizes,
4897                    border_offsets,
4898                    frame_blocks,
4899                    smooth_ranks,
4900                    row_htbeta,
4901                }),
4902            };
4903            let ridge_t = 1e-7;
4904            let ridge_beta = 1e-6;
4905            let mut first_bad: Option<usize> = None;
4906            let mut worst = 0.0_f64;
4907            let mut worst_at = 0usize;
4908            let mut worst_dev = 0.0_f64;
4909            let mut worst_cpu = 0.0_f64;
4910            for col in 0..border_dim {
4911                let mut x = vec![0.0_f64; border_dim];
4912                x[col] = 1.0;
4913                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
4914                    Ok(v) => v,
4915                    Err(_) => return,
4916                };
4917                let mut cpu = vec![0.0_f64; border_dim];
4918                super::super::sae_framed_schur_matvec_cpu(
4919                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
4920                )
4921                .expect("cpu matvec");
4922                for r in 0..border_dim {
4923                    let d = (dev[r] - cpu[r]).abs();
4924                    if d > 1e-9 && first_bad.is_none() {
4925                        first_bad = Some(r * border_dim + col);
4926                    }
4927                    if d > worst {
4928                        worst = d;
4929                        worst_at = r * border_dim + col;
4930                        worst_dev = dev[r];
4931                        worst_cpu = cpu[r];
4932                    }
4933                }
4934            }
4935            assert!(
4936                worst <= 1e-9,
4937                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
4938                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
4939                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
4940                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
4941                 G⊗W=cross, reduced-Schur=dense per-row)",
4942            );
4943        }
4944    }
4945}
4946
4947#[cfg(test)]
4948mod tests {
4949    use super::*;
4950    use crate::arrow_schur::ArrowSchurSystem;
4951    use ndarray::{Array2, ArrayView1};
4952
4953    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
4954        let mut sys = ArrowSchurSystem::new(n, d, k);
4955        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
4956        let mut sample = || -> f64 {
4957            state = state
4958                .wrapping_mul(6364136223846793005)
4959                .wrapping_add(1442695040888963407);
4960            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
4961        };
4962        for row in &mut sys.rows {
4963            let mut a = Array2::<f64>::zeros((d, d));
4964            for r in 0..d {
4965                for c in 0..d {
4966                    a[[r, c]] = sample();
4967                }
4968            }
4969            let mut htt = a.t().dot(&a);
4970            for r in 0..d {
4971                htt[[r, r]] += d as f64 + 1.0;
4972            }
4973            row.htt = htt;
4974            for r in 0..d {
4975                for c in 0..k {
4976                    row.htbeta[[r, c]] = 0.1 * sample();
4977                }
4978                row.gt[r] = sample();
4979            }
4980        }
4981        let mut hbb_a = Array2::<f64>::zeros((k, k));
4982        for r in 0..k {
4983            for c in 0..k {
4984                hbb_a[[r, c]] = sample();
4985            }
4986        }
4987        let mut hbb = hbb_a.t().dot(&hbb_a);
4988        for r in 0..k {
4989            hbb[[r, r]] += k as f64 + 1.0;
4990        }
4991        sys.hbb = hbb;
4992        for r in 0..k {
4993            sys.gb[r] = sample();
4994        }
4995        sys
4996    }
4997
4998    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
4999        let mut s = Array2::<f64>::zeros((k, k));
5000        for row in 0..k {
5001            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5002            if row + 1 < k {
5003                s[[row, row + 1]] = -0.05;
5004                s[[row + 1, row]] = -0.05;
5005            }
5006            if row + 7 < k {
5007                s[[row, row + 7]] = 0.01;
5008                s[[row + 7, row]] = 0.01;
5009            }
5010        }
5011        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5012        (s, rhs)
5013    }
5014
5015    fn dense_pcg_cpu_reference(
5016        s: &Array2<f64>,
5017        rhs: &Array1<f64>,
5018        max_iterations: usize,
5019        relative_tolerance: f64,
5020    ) -> Array1<f64> {
5021        let k = rhs.len();
5022        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5023        if rhs_norm == 0.0 {
5024            return Array1::<f64>::zeros(k);
5025        }
5026        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5027        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5028        let mut x = Array1::<f64>::zeros(k);
5029        let mut r = rhs.clone();
5030        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5031        let mut p = z.clone();
5032        let mut sp = Array1::<f64>::zeros(k);
5033        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5034        for _ in 0..max_iterations.max(1) {
5035            for row in 0..k {
5036                let mut acc = 0.0;
5037                for col in 0..k {
5038                    acc += s[[row, col]] * p[col];
5039                }
5040                sp[row] = acc;
5041            }
5042            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5043            let alpha = rz / p_sp;
5044            for idx in 0..k {
5045                x[idx] += alpha * p[idx];
5046                r[idx] -= alpha * sp[idx];
5047            }
5048            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5049            if r_norm <= tol {
5050                break;
5051            }
5052            for idx in 0..k {
5053                z[idx] = inv_diag[idx] * r[idx];
5054            }
5055            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5056            let beta = rz_next / rz;
5057            for idx in 0..k {
5058                p[idx] = z[idx] + beta * p[idx];
5059            }
5060            rz = rz_next;
5061        }
5062        x
5063    }
5064
5065    #[test]
5066    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5067        let (s, rhs) = device_pcg_fixture(512);
5068        let max_iterations = 200usize;
5069        let relative_tolerance = 1.0e-12;
5070        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5071        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5072            &s,
5073            &rhs,
5074            max_iterations,
5075            relative_tolerance,
5076        ) {
5077            Ok(result) => result,
5078            // #1017 — fail loud, never skip-pass: this fixture clears the device
5079            // offload floor, so a CUDA device that is PRESENT yet declines/returns
5080            // Err means the device PCG kernel does not run on GPU (a real fault that
5081            // must not masquerade as a pass via this skip). Legit skip ONLY when no
5082            // usable CUDA device exists (CPU CI). The exact `ArrowSchurGpuFailure`
5083            // variant is folded into the assert message as the diagnostic.
5084            Err(failure) => {
5085                assert!(
5086                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5087                    "#1017: CUDA device present but the device reduced-beta PCG \
5088                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5089                     the kernel does not run correctly on GPU"
5090                );
5091                return;
5092            }
5093        };
5094        let max_err = cpu
5095            .iter()
5096            .zip(device.iter())
5097            .map(|(a, b)| (a - b).abs())
5098            .fold(0.0_f64, f64::max);
5099        assert!(
5100            max_err <= 1.0e-10,
5101            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5102        );
5103        assert!(diag.matvec_calls > 0);
5104        assert_eq!(diag.matvec_calls, diag.iterations);
5105    }
5106
5107    #[test]
5108    fn dense_reference_matches_independent_solve() {
5109        let sys = build_fixture(4, 5, 3, 7);
5110        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5111        // Re-solve by an independent matrix build and a textbook
5112        // Gaussian-elimination Cholesky to guard against typos in the
5113        // reference implementation itself.
5114        let n = sys.rows.len();
5115        let d = sys.d;
5116        let k = sys.k;
5117        let total = n * d + k;
5118        let mut h = Array2::<f64>::zeros((total, total));
5119        let mut g = ndarray::Array1::<f64>::zeros(total);
5120        for (i, row) in sys.rows.iter().enumerate() {
5121            let base = i * d;
5122            for c in 0..d {
5123                for r in 0..d {
5124                    h[[base + r, base + c]] = row.htt[[r, c]];
5125                }
5126            }
5127            for c in 0..k {
5128                for r in 0..d {
5129                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5130                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5131                }
5132            }
5133            for r in 0..d {
5134                g[base + r] = row.gt[r];
5135            }
5136        }
5137        for c in 0..k {
5138            for r in 0..k {
5139                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5140            }
5141            g[n * d + c] = sys.gb[c];
5142        }
5143        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5144        let rhs = g.mapv(|v| -v);
5145        let expected = cholesky_solve_vector(l.view(), rhs.view());
5146        for i in 0..n * d {
5147            assert!(
5148                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5149                "delta_t[{i}] mismatch: got {} expected {}",
5150                solution.delta_t[i],
5151                expected[i]
5152            );
5153        }
5154        for a in 0..k {
5155            assert!(
5156                (solution.delta_beta[a] - expected[n * d + a]).abs()
5157                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
5158                "delta_beta[{a}] mismatch"
5159            );
5160        }
5161    }
5162
5163    /// #1017: the row-procedural reduced-Schur matvec (the matrix-free SAE
5164    /// host backend) auto-fans its per-row point-elimination sum across rayon
5165    /// over fixed row chunks when at the top level (`n ≥
5166    /// SCHUR_MATVEC_PARALLEL_ROW_MIN`), and stays serial when already inside a
5167    /// rayon worker. The chunk-ordered fold makes the parallel result
5168    /// **deterministic** (two parallel calls are bit-identical — scheduling
5169    /// cannot change the numbers) and it agrees with the serial accumulation up
5170    /// to ULP-scale chunk reassociation (the #1017 verification gate). That
5171    /// reassociation is a genuine f64 departure from serial, so the criterion
5172    /// ranking across topology candidates is stable only up to the reassociation
5173    /// margin: a near-tie winner inside that margin can flip. This is NOT an
5174    /// exact no-move guarantee (#1211); for that, the ranking path must use the
5175    /// fixed-order serial accumulation.
5176    #[test]
5177    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5178        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5179        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; // trips the parallel path
5180        let d = 3usize;
5181        let k = 24usize;
5182        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5183        // Install a matrix-free forward/transpose pair that reads the dense
5184        // `htbeta` slabs the fixture already populated, so the procedural
5185        // backend has a well-defined operator to apply (and exercises exactly
5186        // the sparse gather/scatter the SAE Kronecker path drives).
5187        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5188        let forward_slabs = slabs.clone();
5189        let transpose_slabs = slabs;
5190        sys.set_row_htbeta_operator(
5191            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5192                let h = &forward_slabs[row];
5193                for r in 0..h.nrows() {
5194                    let mut acc = 0.0_f64;
5195                    for c in 0..h.ncols() {
5196                        acc += h[[r, c]] * x[c];
5197                    }
5198                    out[r] = acc;
5199                }
5200            },
5201            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5202                let h = &transpose_slabs[row];
5203                for r in 0..h.nrows() {
5204                    for c in 0..h.ncols() {
5205                        out[c] += h[[r, c]] * v[r];
5206                    }
5207                }
5208            },
5209        );
5210
5211        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5212            .expect("row-procedural matvec backend builds for matrix-free system");
5213        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5214
5215        // Top-level call: auto-selects the parallel chunk-fold. Run twice and
5216        // assert bit-identity — the chunk-ordered reduction must not depend on
5217        // thread scheduling.
5218        let mut out_parallel_a = Array1::<f64>::zeros(k);
5219        matvec(&x, &mut out_parallel_a);
5220        let mut out_parallel_b = Array1::<f64>::zeros(k);
5221        matvec(&x, &mut out_parallel_b);
5222        for a in 0..k {
5223            assert_eq!(
5224                out_parallel_a[a].to_bits(),
5225                out_parallel_b[a].to_bits(),
5226                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5227            );
5228        }
5229
5230        // Inside a rayon worker: auto-selects the serial path (nested-rayon
5231        // guard). `install` runs the closure on a pool thread, so
5232        // `current_thread_index()` is `Some`. The serial running sum and the
5233        // chunk-ordered parallel fold differ only by f64 reassociation.
5234        let mut out_serial = Array1::<f64>::zeros(k);
5235        rayon::ThreadPoolBuilder::new()
5236            .num_threads(2)
5237            .build()
5238            .expect("build rayon pool")
5239            .install(|| matvec(&x, &mut out_serial));
5240
5241        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5242        for a in 0..k {
5243            let diff = (out_parallel_a[a] - out_serial[a]).abs();
5244            assert!(
5245                diff <= 1e-12 * (1.0 + max_abs),
5246                "row-procedural matvec parallel vs serial diverged beyond reassociation \
5247                 at index {a}: {} vs {} (diff={diff:e})",
5248                out_parallel_a[a],
5249                out_serial[a]
5250            );
5251        }
5252    }
5253
5254    /// #1017/#1026 — the frames-engaged CPU reduced-Schur matvec
5255    /// [`sae_framed_schur_matvec_cpu`] (the bit-parity oracle the GPU kernel
5256    /// mirrors) must equal the dense reduced Schur `S = (P_ββ + ρ_β I) −
5257    /// Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)` formed by the canonical dense
5258    /// reference, on a small framed system with mixed per-atom ranks
5259    /// (`r_k < p` framed + `r_k = p` un-framed). Size-independent gate.
5260    #[test]
5261    fn framed_sae_schur_matvec_matches_dense_reference() {
5262        use crate::arrow_schur::{
5263            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5264            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5265        };
5266
5267        let p = 4usize;
5268        // Three atoms: ranks 2 (framed), 4 (un-framed), 3 (framed).
5269        let ranks = vec![2usize, 4usize, 3usize];
5270        let basis_sizes = vec![2usize, 1usize, 2usize];
5271        let n_atoms = ranks.len();
5272        let mut border_offsets = Vec::with_capacity(n_atoms);
5273        let mut acc = 0usize;
5274        for k in 0..n_atoms {
5275            border_offsets.push(acc);
5276            acc += basis_sizes[k] * ranks[k];
5277        }
5278        let border_dim = acc; // 2*2 + 1*4 + 2*3 = 14
5279
5280        let mut state = 0x1234_5678_9abc_def0u64;
5281        let mut sample = || -> f64 {
5282            state = state
5283                .wrapping_mul(6364136223846793005)
5284                .wrapping_add(1442695040888963407);
5285            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5286        };
5287
5288        // Per-atom orthonormal-ish frames U_k (p × r_k) for the W = U_iᵀU_j
5289        // factors; un-framed atom (r=p) uses U = I_p.
5290        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5291        for k in 0..n_atoms {
5292            let r = ranks[k];
5293            let mut u = Array2::<f64>::zeros((p, r));
5294            for i in 0..p {
5295                for j in 0..r {
5296                    u[[i, j]] = if r == p && i == j {
5297                        1.0
5298                    } else if r == p {
5299                        0.0
5300                    } else {
5301                        sample()
5302                    };
5303                }
5304            }
5305            frames.push(u);
5306        }
5307        let w_of = |i: usize, j: usize| -> Array2<f64> {
5308            let (ui, uj) = (&frames[i], &frames[j]);
5309            let (ri, rj) = (ranks[i], ranks[j]);
5310            let mut w = Array2::<f64>::zeros((ri, rj));
5311            for a in 0..ri {
5312                for b in 0..rj {
5313                    let mut s = 0.0;
5314                    for c in 0..p {
5315                        s += ui[[c, a]] * uj[[c, b]];
5316                    }
5317                    w[[a, b]] = s;
5318                }
5319            }
5320            w
5321        };
5322
5323        // Co-occurring data-fit blocks: all diagonal pairs + one cross (0,2).
5324        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5325        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5326        pairs.sort();
5327        for &(i, j) in &pairs {
5328            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5329            let mut g = Array2::<f64>::zeros((mi, mj));
5330            for r in 0..mi {
5331                for c in 0..mj {
5332                    g[[r, c]] = 0.3 * sample();
5333                }
5334            }
5335            // Make diagonal blocks SPD-leaning so S stays PD.
5336            if i == j {
5337                for r in 0..mi.min(mj) {
5338                    g[[r, r]] += mi as f64 + 2.0;
5339                }
5340            }
5341            frame_blocks.push(FactoredFrameGBlock {
5342                atom_i: i,
5343                atom_j: j,
5344                g,
5345                w: w_of(i, j),
5346            });
5347        }
5348
5349        // Smooth blocks λ S_k (M_k × M_k), SPD.
5350        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5351        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5352        for k in 0..n_atoms {
5353            let m = basis_sizes[k];
5354            let mut a = Array2::<f64>::zeros((m, m));
5355            for r in 0..m {
5356                for c in 0..m {
5357                    a[[r, c]] = 0.2 * sample();
5358                }
5359            }
5360            let mut s = a.t().dot(&a);
5361            for r in 0..m {
5362                s[[r, r]] += 1.0;
5363            }
5364            smooth_blocks.push(DeviceSaeSmoothBlock {
5365                global_offset: border_offsets[k],
5366                factor_a: s,
5367            });
5368            smooth_ranks.push(ranks[k]);
5369        }
5370
5371        // Build the system: n rows, dense htbeta slabs (q_i × border_dim).
5372        let n = 6usize;
5373        let q = 3usize;
5374        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5375        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5376        for i in 0..n {
5377            // SPD htt.
5378            let mut a = Array2::<f64>::zeros((q, q));
5379            for r in 0..q {
5380                for c in 0..q {
5381                    a[[r, c]] = sample();
5382                }
5383            }
5384            let mut htt = a.t().dot(&a);
5385            for r in 0..q {
5386                htt[[r, r]] += q as f64 + 1.0;
5387            }
5388            sys.rows[i].htt = htt;
5389            let mut slab = vec![0.0_f64; q * border_dim];
5390            for c in 0..q {
5391                for col in 0..border_dim {
5392                    let v = 0.15 * sample();
5393                    slab[c * border_dim + col] = v;
5394                    sys.rows[i].htbeta[[c, col]] = v;
5395                }
5396            }
5397            row_htbeta.push(slab);
5398        }
5399
5400        // Dense H_ββ from the SAME penalty ops (so the dense reference's S
5401        // matches the device penalty side exactly).
5402        let data_op =
5403            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5404                .expect("frame op");
5405        let mut hbb = data_op.to_dense();
5406        for k in 0..n_atoms {
5407            let op = IdentityRightKroneckerPenaltyOp {
5408                factor_a: smooth_blocks[k].factor_a.clone(),
5409                p: ranks[k],
5410                global_offset: border_offsets[k],
5411                k: border_dim,
5412            };
5413            let d = op.to_dense();
5414            for r in 0..border_dim {
5415                for c in 0..border_dim {
5416                    hbb[[r, c]] += d[[r, c]];
5417                }
5418            }
5419        }
5420        sys.hbb = hbb;
5421
5422        let data = DeviceSaePcgData {
5423            p,
5424            beta_dim: border_dim,
5425            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5426            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5427            smooth_blocks,
5428            sparse_g_blocks: Vec::new(),
5429            frame: Some(DeviceSaeFrameData {
5430                ranks: ranks.clone(),
5431                basis_sizes: basis_sizes.clone(),
5432                border_offsets: border_offsets.clone(),
5433                frame_blocks,
5434                smooth_ranks,
5435                row_htbeta,
5436            }),
5437        };
5438
5439        let ridge_t = 1e-7;
5440        let ridge_beta = 1e-6;
5441
5442        // Dense reference reduced Schur S (border_dim × border_dim), formed
5443        // exactly as solve_arrow_newton_step_dense_reference assembles the
5444        // bordered Hessian and eliminates the t-block.
5445        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5446        for r in 0..border_dim {
5447            for c in 0..border_dim {
5448                s_dense[[r, c]] = sys.hbb[[r, c]];
5449            }
5450            s_dense[[r, r]] += ridge_beta;
5451        }
5452        for row in &sys.rows {
5453            let mut htt = row.htt.clone();
5454            for d in 0..q {
5455                htt[[d, d]] += ridge_t;
5456            }
5457            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5458                .expect("htt PD");
5459            // Y = (htt)⁻¹ htbeta  (q × border_dim); S -= htbetaᵀ Y.
5460            let mut y = Array2::<f64>::zeros((q, border_dim));
5461            for col in 0..border_dim {
5462                let mut e = Array1::<f64>::zeros(q);
5463                for r in 0..q {
5464                    e[r] = row.htbeta[[r, col]];
5465                }
5466                let solved = cholesky_solve_vector(factor.view(), e.view());
5467                for r in 0..q {
5468                    y[[r, col]] = solved[r];
5469                }
5470            }
5471            for r in 0..border_dim {
5472                for c in 0..border_dim {
5473                    let mut acc = 0.0;
5474                    for d in 0..q {
5475                        acc += row.htbeta[[d, r]] * y[[d, c]];
5476                    }
5477                    s_dense[[r, c]] -= acc;
5478                }
5479            }
5480        }
5481
5482        // Probe vectors: compare S·x from the device-data CPU oracle vs dense S·x.
5483        let mut max_rel = 0.0_f64;
5484        for trial in 0..4 {
5485            let x: Vec<f64> = (0..border_dim)
5486                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5487                .collect();
5488            let mut got = vec![0.0_f64; border_dim];
5489            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5490                .expect("framed matvec");
5491            let mut want = vec![0.0_f64; border_dim];
5492            for r in 0..border_dim {
5493                let mut acc = 0.0;
5494                for c in 0..border_dim {
5495                    acc += s_dense[[r, c]] * x[c];
5496                }
5497                want[r] = acc;
5498            }
5499            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5500            for a in 0..border_dim {
5501                let rel = (got[a] - want[a]).abs() / scale;
5502                max_rel = max_rel.max(rel);
5503            }
5504        }
5505        assert!(
5506            max_rel <= 1e-10,
5507            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5508        );
5509    }
5510
5511    /// #1017/#1026 GPU arm: when a CUDA device admits the framed SAE PCG, its
5512    /// solved `δβ` must match the CPU dense reduced-system solve of the SAME
5513    /// framed system (size-independent — a small device validates the kernel).
5514    /// Skips cleanly (returns) when no device is available or the policy
5515    /// declines (`solve_sae_matrix_free_pcg` → `Unavailable`).
5516    #[test]
5517    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
5518        use crate::arrow_schur::{
5519            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5520            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5521        };
5522
5523        // Large enough to clear the device-offload policy floor (k ≥ 32 and
5524        // n·k·d·iters ≥ MATVEC_OFFLOAD_FLOPS_MIN) so the GPU kernel actually
5525        // runs on a device rather than the policy declining.
5526        let p = 6usize;
5527        let n_atoms = 8usize;
5528        let ranks: Vec<usize> = (0..n_atoms)
5529            .map(|k| if k % 2 == 0 { 3usize } else { p })
5530            .collect();
5531        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
5532        let mut border_offsets = Vec::with_capacity(n_atoms);
5533        let mut acc = 0usize;
5534        for k in 0..n_atoms {
5535            border_offsets.push(acc);
5536            acc += basis_sizes[k] * ranks[k];
5537        }
5538        let border_dim = acc; // Σ M_k·r_k = 4·(3·3) + 4·(3·6) = 36 + 72 = 108
5539
5540        let mut state = 0xfeed_face_dead_beefu64;
5541        let mut sample = || -> f64 {
5542            state = state
5543                .wrapping_mul(6364136223846793005)
5544                .wrapping_add(1442695040888963407);
5545            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5546        };
5547        let mut frames: Vec<Array2<f64>> = Vec::new();
5548        for k in 0..n_atoms {
5549            let r = ranks[k];
5550            let mut u = Array2::<f64>::zeros((p, r));
5551            for i in 0..p {
5552                for j in 0..r {
5553                    u[[i, j]] = if r == p && i == j {
5554                        1.0
5555                    } else if r == p {
5556                        0.0
5557                    } else {
5558                        sample()
5559                    };
5560                }
5561            }
5562            frames.push(u);
5563        }
5564        let w_of = |i: usize, j: usize| {
5565            let (ui, uj) = (&frames[i], &frames[j]);
5566            let (ri, rj) = (ranks[i], ranks[j]);
5567            let mut w = Array2::<f64>::zeros((ri, rj));
5568            for a in 0..ri {
5569                for b in 0..rj {
5570                    let mut s = 0.0;
5571                    for c in 0..p {
5572                        s += ui[[c, a]] * uj[[c, b]];
5573                    }
5574                    w[[a, b]] = s;
5575                }
5576            }
5577            w
5578        };
5579        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
5580        // A few off-diagonal cross blocks (symmetric pairs).
5581        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
5582            pairs.push((i, j));
5583            pairs.push((j, i));
5584        }
5585        let mut frame_blocks = Vec::new();
5586        for &(i, j) in &pairs {
5587            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5588            let mut g = Array2::<f64>::zeros((mi, mj));
5589            for r in 0..mi {
5590                for c in 0..mj {
5591                    g[[r, c]] = 0.25 * sample();
5592                }
5593            }
5594            if i == j {
5595                for r in 0..mi.min(mj) {
5596                    g[[r, r]] += mi as f64 + 2.0;
5597                }
5598            }
5599            frame_blocks.push(FactoredFrameGBlock {
5600                atom_i: i,
5601                atom_j: j,
5602                g,
5603                w: w_of(i, j),
5604            });
5605        }
5606        let mut smooth_blocks = Vec::new();
5607        let mut smooth_ranks = Vec::new();
5608        for k in 0..n_atoms {
5609            let m = basis_sizes[k];
5610            let mut a = Array2::<f64>::zeros((m, m));
5611            for r in 0..m {
5612                for c in 0..m {
5613                    a[[r, c]] = 0.2 * sample();
5614                }
5615            }
5616            let mut s = a.t().dot(&a);
5617            for r in 0..m {
5618                s[[r, r]] += 1.0;
5619            }
5620            smooth_blocks.push(DeviceSaeSmoothBlock {
5621                global_offset: border_offsets[k],
5622                factor_a: s,
5623            });
5624            smooth_ranks.push(ranks[k]);
5625        }
5626        let n = 400usize;
5627        let q = 4usize;
5628        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5629        let mut row_htbeta = Vec::new();
5630        for i in 0..n {
5631            let mut a = Array2::<f64>::zeros((q, q));
5632            for r in 0..q {
5633                for c in 0..q {
5634                    a[[r, c]] = sample();
5635                }
5636            }
5637            let mut htt = a.t().dot(&a);
5638            for r in 0..q {
5639                htt[[r, r]] += q as f64 + 1.0;
5640            }
5641            sys.rows[i].htt = htt;
5642            let mut slab = vec![0.0_f64; q * border_dim];
5643            for c in 0..q {
5644                for col in 0..border_dim {
5645                    // Small entries: with 400 rows the reduced-Schur subtraction
5646                    // Σ_i H_βtᵀ H_tt⁻¹ H_tβ must not overwhelm the PD penalty.
5647                    let v = 0.02 * sample();
5648                    slab[c * border_dim + col] = v;
5649                    sys.rows[i].htbeta[[c, col]] = v;
5650                }
5651            }
5652            row_htbeta.push(slab);
5653        }
5654        let data_op =
5655            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5656                .expect("frame op");
5657        let mut hbb = data_op.to_dense();
5658        for k in 0..n_atoms {
5659            let op = IdentityRightKroneckerPenaltyOp {
5660                factor_a: smooth_blocks[k].factor_a.clone(),
5661                p: ranks[k],
5662                global_offset: border_offsets[k],
5663                k: border_dim,
5664            };
5665            let d = op.to_dense();
5666            for r in 0..border_dim {
5667                for c in 0..border_dim {
5668                    hbb[[r, c]] += d[[r, c]];
5669                }
5670            }
5671        }
5672        sys.hbb = hbb;
5673        let data = DeviceSaePcgData {
5674            p,
5675            beta_dim: border_dim,
5676            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5677            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5678            smooth_blocks,
5679            sparse_g_blocks: Vec::new(),
5680            frame: Some(DeviceSaeFrameData {
5681                ranks: ranks.clone(),
5682                basis_sizes: basis_sizes.clone(),
5683                border_offsets: border_offsets.clone(),
5684                frame_blocks,
5685                smooth_ranks,
5686                row_htbeta,
5687            }),
5688        };
5689        let ridge_t = 1e-7;
5690        let ridge_beta = 1e-6;
5691        let rhs: Array1<f64> =
5692            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
5693
5694        let (device, diag) =
5695            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
5696                Ok(result) => result,
5697                // #1017 — fail loud, never skip-pass: this fixture clears the device
5698                // offload floor, so a CUDA device that is PRESENT yet declines means the
5699                // framed device PCG kernel does not run on GPU (the fault must not pass
5700                // silently). Legit skip ONLY when no usable CUDA device exists (CPU CI).
5701                // The exact `ArrowSchurGpuFailure` variant is folded into the assert.
5702                Err(failure) => {
5703                    assert!(
5704                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5705                        "#1017: CUDA device present but the framed device SAE PCG \
5706                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5707                     the kernel does not run correctly on GPU"
5708                    );
5709                    return;
5710                }
5711            };
5712
5713        // CPU dense reduced-system solve of the SAME framed system: form S via
5714        // the CPU oracle matvec on the identity, then solve S·δβ = rhs.
5715        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5716        for col in 0..border_dim {
5717            let mut e = vec![0.0_f64; border_dim];
5718            e[col] = 1.0;
5719            let mut sc = vec![0.0_f64; border_dim];
5720            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &e, &mut sc)
5721                .expect("cpu matvec");
5722            for r in 0..border_dim {
5723                s_dense[[r, col]] = sc[r];
5724            }
5725        }
5726        let factor = cholesky_factor_in_place(s_dense.view(), CholeskyGuard::NonnegativePivot)
5727            .expect("S PD");
5728        let cpu = cholesky_solve_vector(factor.view(), rhs.view());
5729
5730        let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5731        let mut max_rel = 0.0_f64;
5732        for a in 0..border_dim {
5733            max_rel = max_rel.max((device[a] - cpu[a]).abs() / scale);
5734        }
5735        // #1551 divergence triage (max_rel=0.91 once the device actually engages):
5736        // disambiguate matvec-bug vs PCG-non-convergence vs operator-mismatch.
5737        // Residual of the DEVICE solution against the CPU operator S_cpu: if the
5738        // device solved the SAME operator and converged, ‖S_cpu·device − rhs‖ ≈ 0;
5739        // a large residual means the device matvec is a DIFFERENT operator (kernel
5740        // bug), whereas a small residual with large max_rel would indicate a
5741        // (near-)singular S where both solves are valid. Also surface what the
5742        // device PCG thought it did (stopping reason / iters / final residual).
5743        let mut s_dev_resid = 0.0_f64;
5744        {
5745            let sx = s_dense.dot(&device);
5746            for a in 0..border_dim {
5747                s_dev_resid = s_dev_resid.max((sx[a] - rhs[a]).abs());
5748            }
5749        }
5750        let s_cpu_resid = {
5751            let sc = s_dense.dot(&cpu);
5752            let mut m = 0.0_f64;
5753            for a in 0..border_dim {
5754                m = m.max((sc[a] - rhs[a]).abs());
5755            }
5756            m
5757        };
5758        assert!(
5759            max_rel <= 1e-7,
5760            "[#1551 framed-triage] max_rel={max_rel:e} | device-vs-CPU-operator residual \
5761             ‖S_cpu·device−rhs‖={s_dev_resid:e} (CPU's own ={s_cpu_resid:e}) | device PCG \
5762             stop={:?} iters={} final_rel_resid={:e} — large operator-residual ⇒ device matvec \
5763             is a different operator (kernel bug); small ⇒ PCG/precond or singular-S issue",
5764            diag.stopping_reason,
5765            diag.iterations,
5766            diag.final_relative_residual,
5767        );
5768    }
5769}