Skip to main content

gam_solve/gpu_kernels/
arrow_schur.rs

1//! Fully GPU-resident batched Arrow-Schur dense Cholesky solver.
2//!
3//! Implements the square-root Schur form: each local block `D_i = L_i L_i^T`
4//! is factored on device, `u_i = L_i^{-1} g_i` and `Y_i = L_i^{-1} B_i` are
5//! formed by triangular solves, the reduced shared system
6//!     `S_β = C + ρ_β I − Σ_i Y_i^T Y_i,  r_β = -g_β + Σ_i Y_i^T u_i`
7//! is assembled on device, factored once, and the back-substitution
8//!     `w_i = u_i + Y_i · δβ,  L_i^T x_i = w_i,  δt_i = -x_i`
9//! is run on device. Only the final `(δt, δβ, log|H|)` triple is downloaded.
10//!
11//! The current caller (Arrow-Schur Newton step inside PIRLS) feeds uniform
12//! local block size `d` and uniform shared width `k`, so the entire pipeline
//! is dispatched as a single p-group; per-p grouping for heterogeneous blocks
14//! is Layer D's NVRTC fused-kernel concern and lives in this module's
15//! follow-up implementation rather than in policy plumbing.
16//!
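//! On non-Linux builds the CUDA entry points degrade to shims that validate
//! their inputs and then report `ArrowSchurGpuFailure::Unavailable`, so callers
//! take the CPU path; the host-side row-procedural Schur matvec still runs.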
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24/// Outcome of a single Arrow-Schur Newton solve.
25pub struct ArrowSchurGpuSolution {
26    pub delta_t: Array1<f64>,
27    pub delta_beta: Array1<f64>,
28    /// Natural log of the determinant of the full bordered Hessian, computed
29    /// from the local Cholesky factors and the Schur factor on device.
30    pub log_det_hessian: f64,
31}
32
/// Why a device path declined to run.
///
/// The host caller uses the variant to choose between a plain CPU fallback and
/// per-row ridge escalation. `RidgeBumpRequired` additionally carries an
/// estimate of the diagonal bump that clears the failed pivot.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// No usable CUDA runtime, a failed allocation, or a workload below the
    /// admission policy.
    Unavailable,
    /// Row block `row` was not positive definite even at the requested ridge;
    /// the caller may retry with `ridge_t + bump`.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// Factoring the shared Schur complement failed: the bordered system is
    /// rank-deficient at the requested ridges, so escalation belongs on the
    /// CPU path. `reason` is a human-readable diagnostic.
    SchurFactorFailed { reason: String },
    /// The system carries matrix-free `H_ββ` and/or per-row `H_tβ` operators,
    /// which the dense GPU Schur path cannot consume. This is a routing
    /// signal, not a numerical failure: send the solve to CPU `InexactPCG`
    /// (or supply dense buffers). See `gpu/arrow_schur.rs` Part B for the
    /// planned GPU PCG path that will lift this restriction at K ≥ 5000.
    GpuRequiresDenseSystem {
        /// `true` when a matrix-free `H_ββ` operator was present.
        had_hbb_matvec: bool,
        /// `true` when a matrix-free per-row `H_tβ` operator was present.
        had_htbeta_matvec: bool,
    },
}
56
/// Relative rounding margin (multiplier on `diag_scale · √ε`) added on top of
/// the deficit-clearing shift in [`ridge_bump_to_make_pd`].
///
/// The exact shift `-(λ_min)` makes a block PD in exact arithmetic, but a
/// single retry at precisely that magnitude is routinely re-rejected by the
/// next POTRF because the rounding error of forming `D + ridge·I` and
/// re-factoring is itself O(√ε). The 1024× headroom (exactly 2¹⁰, i.e. ten
/// extra bits on top of √ε = 2⁻²⁶, so the margin is `diag_scale · 2⁻¹⁶ ≈
/// 1.5e-5 · diag_scale`) clears the pivot on the first retry without
/// materially perturbing the curvature the Newton step sees. Shared by every
/// per-row / batched / fused producer so they suggest a consistent bump.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
68
69/// Diagonal ridge bump that is GUARANTEED to make `H_tt + (ridge_t + bump)·I`
70/// positive definite for a *symmetric* per-row block, sized from the block's
71/// own entries rather than from the factorization's pivot index.
72///
73/// # Why the old `scale · |pivot| · √ε · 1024` estimate is wrong
74///
75/// The batched/fused device paths derive the suggested bump from the
76/// factorization "pivot" — but cuSOLVER's `potrf` (and the NVRTC kernel's
77/// status code) report the failing pivot as a **1-based row index**, NOT the
78/// magnitude of the negative pivot. A block that is indefinite by `O(1)`
79/// (e.g. `H_tt = -I`, whose smallest eigenvalue is `-1`) then yields the same
80/// `bump ≈ √ε · 1024 ≈ 1.5e-5` as a block that is indefinite by `O(√ε)`. The
81/// outer LM escalation, which retries at `ridge_t + bump` and grows
82/// geometrically with a bounded step count, can never lift a strongly
83/// indefinite block out of the negative regime, so the solve fails to recover
84/// even though the block is trivially regularizable. (Surfaced by the V100
85/// `ridge_bump_required_on_non_pd_row_recovers_after_bump` validation test.)
86///
87/// # The bound
88///
89/// By the Gershgorin circle theorem every eigenvalue `λ` of the symmetric
90/// matrix `A = H_tt` satisfies, for some row `i`,
91///   `λ ≥ A[i,i] − Σ_{j≠i} |A[i,j]|`,
92/// so `λ_min(A) ≥ min_i ( A[i,i] − Σ_{j≠i} |A[i,j]| ) =: g` (the most negative
93/// Gershgorin left edge). Adding `t·I` shifts every eigenvalue up by `t`, so
94/// `A + t·I` is PD as soon as `t > -g`. We are already sitting at `ridge_t`, so
95/// the ADDITIONAL bump needed is `-(g + ridge_t)` when that is positive. We add
96/// a relative safety margin (`√ε · scale · 1024`, the same headroom the legacy
97/// estimate used) so the re-factored, rounding-perturbed block clears the pivot
98/// on the first retry, and a `max(1)`-scaled floor so a marginally-indefinite
99/// block still gets a strictly positive, non-vanishing bump.
100///
101/// The returned value is the bump to ADD to the current `ridge_t`. It is always
102/// strictly positive (the caller only constructs `RidgeBumpRequired` on an
103/// actual non-PD failure, but the bound is defensive regardless).
104#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106    let d = htt.nrows();
107    // Diagonal magnitude scale (also the legacy `scale`), and the most-negative
108    // Gershgorin left edge `g = min_i (A_ii − Σ_{j≠i} |A_ij|)`.
109    let mut scale = 1.0_f64;
110    let mut min_gershgorin_edge = f64::INFINITY;
111    for i in 0..d {
112        let diag = htt[[i, i]];
113        scale = scale.max(diag.abs());
114        let mut off_sum = 0.0_f64;
115        for j in 0..d {
116            if j != i {
117                off_sum += htt[[i, j]].abs();
118            }
119        }
120        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121    }
122    if !min_gershgorin_edge.is_finite() {
123        // d == 0 (no rows) or non-finite entries: fall back to the scale-only
124        // floor so the caller still gets a strictly positive bump.
125        return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126    }
127    // Additional shift needed so `λ_min(A) + ridge_t + bump > 0`, i.e.
128    // `bump > -(min_gershgorin_edge + ridge_t)`.
129    let deficit = -(min_gershgorin_edge + ridge_t);
130    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131    // Lift past the deficit (when positive) plus a rounding margin; never below
132    // the scale-relative floor so a marginal block still moves.
133    deficit.max(0.0) + margin
134}
135
136/// [`ridge_bump_to_make_pd`] for a `d × d` symmetric block stored column-major
137/// in a flat slice with the current ridge ALREADY baked into the diagonal
138/// (the device packers emit `D = H_tt + ridge_t·I` this way). Because the shift
139/// is already present, the Gershgorin bound is taken at `ridge_t = 0` and the
140/// returned value is still the ADDITIONAL bump to add on top of the current
141/// ridge. Returns the scale-only floor when `block` is mis-sized.
142#[must_use]
143fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
144    if d == 0 || block.len() < d * d {
145        return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
146    }
147    // Column-major: element (row r, col c) at block[c*d + r]. The matrix is
148    // symmetric, so reading by column gives the same Gershgorin edges as by row.
149    let mut scale = 1.0_f64;
150    let mut min_gershgorin_edge = f64::INFINITY;
151    for i in 0..d {
152        let diag = block[i * d + i];
153        scale = scale.max(diag.abs());
154        let mut off_sum = 0.0_f64;
155        for j in 0..d {
156            if j != i {
157                off_sum += block[j * d + i].abs();
158            }
159        }
160        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
161    }
162    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
163    if !min_gershgorin_edge.is_finite() {
164        return margin;
165    }
166    (-min_gershgorin_edge).max(0.0) + margin
167}
168
169/// Entry point: attempt the fully device-resident Arrow-Schur Newton solve.
170/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` to indicate "device path
171/// declined, fall back to CPU" — never panics.
172pub fn solve_arrow_newton_step(
173    sys: &ArrowSchurSystem,
174    ridge_t: f64,
175    ridge_beta: f64,
176) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
177    let n = sys.rows.len();
178    let d = sys.d;
179    let k = sys.k;
180
181    // Detect matrix-free operators before any dim() checks so callers get a
182    // clear, actionable error instead of a generic SchurFactorFailed. The GPU
183    // dense-Schur path requires materialised H_ββ and per-row H_tβ slabs;
184    // CPU InexactPCG is the correct fallback when either operator is abstract.
185    let had_hbb_matvec = sys.hbb_matvec.is_some();
186    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
187    if had_hbb_matvec || had_htbeta_matvec {
188        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
189            had_hbb_matvec,
190            had_htbeta_matvec,
191        });
192    }
193
194    if sys.hbb.dim() != (k, k) {
195        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
196            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
197        });
198    }
199    if n == 0 || d == 0 {
200        return Err(ArrowSchurGpuFailure::Unavailable);
201    }
202    if sys
203        .rows
204        .iter()
205        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
206    {
207        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
208            reason: "row block dimension mismatch".to_string(),
209        });
210    }
211
212    #[cfg(not(target_os = "linux"))]
213    {
214        if ridge_t.is_nan() || ridge_beta.is_nan() {
215            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
216                reason: "ridge is NaN".to_string(),
217            });
218        }
219        Err(ArrowSchurGpuFailure::Unavailable)
220    }
221
222    #[cfg(target_os = "linux")]
223    {
224        // Multi-GPU: the arrow-Schur solve is row-block separable in its forward
225        // (per-row factor / whiten / partial-Schur) and backward (per-row
226        // back-sub) phases — only the small shared K×K reduce+factor+δβ is
227        // central. When more than one device is usable, split the WHOLE solve at
228        // row-block granularity across all GPUs. The POTRF stays fused with its
229        // dependent TRSM+GEMM on each tile's own stream, so no on-stream solve is
230        // orphaned. On `Unavailable` (one device, shape below policy, transient)
231        // fall through to the single-device fused / Layer-A paths below.
232        if gam_gpu::device_runtime::GpuRuntime::global()
233            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
234            .unwrap_or(0)
235            > 1
236        {
237            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
238                Ok(sol) => return Ok(sol),
239                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
240                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
241                }
242                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
243                    return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
244                }
245                // Unavailable / GpuRequiresDenseSystem: fall through to the
246                // single-device paths (already shape-validated above).
247                Err(_) => {}
248            }
249        }
250        // Layer D admission: when the system shape passes the
251        // (Σ p³ ≥ 1e5 OR R ≥ 16) heuristic and `p ≤ MAX_FUSED_P`, the fused
252        // NVRTC kernel replaces the cuSOLVER/cuBLAS Layer A+B+C path with a
253        // single per-row block. Layer C↔D parity (math block 3 §16 test 6)
254        // requires both paths to agree to 1e-10 on identical inputs.
255        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
256            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
257                Ok(sol) => return Ok(sol),
258                // RidgeBumpRequired must surface to the outer escalation loop —
259                // the fused path's pivot diagnostic is identical in semantics
260                // to the cuSOLVER batched POTRF info code.
261                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
262                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
263                }
264                // Any other failure (Unavailable, SchurFactorFailed) falls
265                // through to the unfused path so a flaky NVRTC compile or
266                // shared-mem allocation does not abort the outer Newton step.
267                Err(_) => {}
268            }
269        }
270        cuda::solve(sys, ridge_t, ridge_beta)
271    }
272}
273
274/// Build the stacked column-major D buffer (n local d×d blocks), the stacked
275/// stacked B buffer (n local d×k blocks), and the stacked g buffer
276/// (n local d-vectors) consumed by the device pipeline. Each block is laid
277/// out column-major so a single allocation + `cuMemcpyHtoD` reaches the
278/// device without per-row dispatch overhead.
279#[cfg(target_os = "linux")]
280fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
281    let n = sys.rows.len();
282    let d = sys.d;
283    let k = sys.k;
284    let mut d_buf = Vec::with_capacity(n * d * d);
285    let mut b_buf = Vec::with_capacity(n * d * k);
286    let mut g_buf = Vec::with_capacity(n * d);
287    for row in &sys.rows {
288        pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
289    }
290    (d_buf, b_buf, g_buf)
291}
292
293#[cfg(target_os = "linux")]
294#[inline]
295fn pack_block(
296    row: &crate::arrow_schur::ArrowRowBlock,
297    ridge_t: f64,
298    d: usize,
299    k: usize,
300    d_buf: &mut Vec<f64>,
301    b_buf: &mut Vec<f64>,
302    g_buf: &mut Vec<f64>,
303) {
304    for col in 0..d {
305        for r in 0..d {
306            let mut value = row.htt[[r, col]];
307            if r == col {
308                value += ridge_t;
309            }
310            d_buf.push(value);
311        }
312    }
313    for col in 0..k {
314        for r in 0..d {
315            b_buf.push(row.htbeta[[r, col]]);
316        }
317    }
318    for r in 0..d {
319        g_buf.push(row.gt[r]);
320    }
321}
322
323/// Test-only entry that forces the Layer D + E fused NVRTC path regardless
324/// of the admission heuristic. Used by the V100 Layer C↔D parity test to
325/// drive the fused kernel at small shapes the heuristic would otherwise
326/// route through the cuSOLVER/cuBLAS Layer A+B+C path.
327#[doc(hidden)]
328#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] // `sys` is consumed only by the linux branch
329pub fn solve_arrow_newton_step_fused_force(
330    sys: &ArrowSchurSystem,
331    ridge_t: f64,
332    ridge_beta: f64,
333) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
334    if ridge_t.is_nan() || ridge_beta.is_nan() {
335        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
336            reason: "ridge is NaN".to_string(),
337        });
338    }
339    #[cfg(not(target_os = "linux"))]
340    {
341        // No NVRTC toolchain off linux: the fused path is unconditionally
342        // unavailable. `sys` is consumed only by the linux branch below; the
343        // fn-level cfg_attr allows it to read as unused here without a banned
344        // `let _` binding or a no-op `drop` of the reference.
345        Err(ArrowSchurGpuFailure::Unavailable)
346    }
347    #[cfg(target_os = "linux")]
348    {
349        if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
350            .is_none()
351        {
352            return Err(ArrowSchurGpuFailure::Unavailable);
353        }
354        cuda::solve_fused(sys, ridge_t, ridge_beta)
355    }
356}
357
358/// #1017 Phase 3: a device-resident Arrow-Schur frame whose constant Hessian
359/// blocks (`D = H_tt`, `B = H_tβ`, border `H_ββ`) and their factors stay on the
360/// device across the inner Newton loop. Construct once per frozen gate/basis
361/// frame, then call [`ResidentArrowFrameHandle::solve_gradient`] once per
362/// iterate with the fresh residual gradient — only the `O(n·d + p)` gradient
363/// crosses to the device and only `δ` crosses back, in contrast to
364/// [`solve_arrow_newton_step`] which re-uploads and re-factors the full system
365/// every call. On a non-CUDA host construction returns
366/// `ArrowSchurGpuFailure::Unavailable`.
367pub struct ResidentArrowFrameHandle {
368    #[cfg(target_os = "linux")]
369    inner: cuda::ResidentArrowFrame,
370    #[cfg(not(target_os = "linux"))]
371    _never: std::convert::Infallible,
372}
373
374impl ResidentArrowFrameHandle {
375    /// Upload the constant Hessian blocks and perform the one-time factor work.
376    pub fn new(
377        sys: &ArrowSchurSystem,
378        ridge_t: f64,
379        ridge_beta: f64,
380    ) -> Result<Self, ArrowSchurGpuFailure> {
381        // The dense device path requires materialised blocks, same admission as
382        // `solve_arrow_newton_step`.
383        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
384            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
385                had_hbb_matvec: sys.hbb_matvec.is_some(),
386                had_htbeta_matvec: sys.htbeta_matvec.is_some(),
387            });
388        }
389        #[cfg(not(target_os = "linux"))]
390        {
391            if ridge_t.is_nan() || ridge_beta.is_nan() {
392                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
393                    reason: "ridge is NaN".to_string(),
394                });
395            }
396            Err(ArrowSchurGpuFailure::Unavailable)
397        }
398        #[cfg(target_os = "linux")]
399        {
400            Ok(Self {
401                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
402            })
403        }
404    }
405
406    /// Solve `H δ = −gradient` for a fresh gradient reusing the resident factors.
407    pub fn solve_gradient(
408        &self,
409        g_t: &[f64],
410        g_beta: &[f64],
411    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
412        #[cfg(not(target_os = "linux"))]
413        {
414            if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
415                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
416                    reason: "non-finite gradient entry".to_string(),
417                });
418            }
419            Err(ArrowSchurGpuFailure::Unavailable)
420        }
421        #[cfg(target_os = "linux")]
422        {
423            self.inner.solve_gradient(g_t, g_beta)
424        }
425    }
426
427    /// `log|H|` for the frame (constant; depends only on the factored Hessian).
428    #[must_use]
429    pub fn log_det_hessian(&self) -> f64 {
430        #[cfg(not(target_os = "linux"))]
431        {
432            // SAFETY: off-CUDA, `ResidentArrowFrameHandle::new` always returns
433            // `Err(Unavailable)`, so no handle of this type is ever constructed and
434            // this method is statically unreachable on non-Linux targets. A NaN
435            // sentinel would silently corrupt any consumer of the log-determinant,
436            // so fail loudly on the impossible path instead.
437            panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
438        }
439        #[cfg(target_os = "linux")]
440        {
441            self.inner.log_det_hessian()
442        }
443    }
444}
445
446/// Build a GPU-backed Schur matvec closure for CPU-driven PCG at K ≥ 5000.
447///
448/// Runs the fused NVRTC forward kernel once on the dense per-row `H_tβ` slabs
449/// to compute `Y_i = L_i^{-1} H_tβ^(i)` for all rows, persists the `Y_i`
450/// factors in a host-side buffer, and returns an `Arc<dyn Fn(...)>` closure
451/// that computes the full Schur matvec
452///
453/// ```text
454/// S·x = (H_ββ + ridge_beta·I)·x  −  Σ_i Y_i^T (Y_i·x)
455/// ```
456///
457/// each time it is called. At K ≥ 5000 the `Σ_i Y_i^T (Y_i·x)` term
458/// dominates over the host↔device transfer of the K-vector `x`, so the GPU
459/// path is a clear win even with per-iteration transfer.
460///
461/// `H_ββ·x` is evaluated on the CPU using `sys.hbb_matvec` when present (the
462/// matrix-free hook for SAE-manifold scale callers) or the dense `sys.hbb`
463/// block otherwise. The `Y_i` term uses cuBLAS batched GEMV device-side; only
464/// `x` (K doubles) and `out` (K doubles) cross the host↔device boundary per
465/// PCG iteration.
466///
467/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` if CUDA is unavailable or
468/// the system shape is outside the fused kernel's admission range (e.g.
469/// `d > MAX_FUSED_P = 32` or no CUDA context). Callers should fall back to CPU
470/// `InexactPCG` on `Unavailable`.
471///
472/// Returns `Err(ArrowSchurGpuFailure::RidgeBumpRequired)` if a per-row Cholesky
473/// factor failed at the requested `ridge_t`; the outer LM escalation should
474/// bump `ridge_t` and retry.
475///
476/// # Composition with the matrix-free SAE Kronecker operator
477///
478/// When `sys.htbeta_matvec` is set (matrix-free `H_tβ` Kronecker operator),
479/// the dense `H_tβ` slabs are absent — the dense forward kernel above cannot
480/// run, and at `K = 100K` the dense `Y_i = L_i^{-1} H_tβ^(i)` (`d × K` per row)
481/// could not be materialised anyway. Instead, `build_row_procedural_matvec`
482/// returns a row-procedural Schur matvec: per row it gathers
483/// `v_i = H_tβ^(i)·x` through the forward operator (sparse `O(m_i · p)`),
484/// solves `(H_tt^(i) + ρ_t·I)^{-1} v_i` through a pre-computed per-row Cholesky
485/// factor, and scatters `H_βt^(i)·w_i` through the sparse transpose operator
486/// (`O(m_i · p)`, replacing the old `O(K)` column-probe). This is the
487/// row-procedural `a_ik · Φ_k[i,m]` Kronecker apply over the active atoms only.
488pub fn gpu_schur_matvec_backend(
489    sys: &ArrowSchurSystem,
490    ridge_t: f64,
491    ridge_beta: f64,
492) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
493    // Matrix-free H_tβ operator present: drive the row-procedural sparse
494    // Kronecker apply (active atoms only) instead of the dense forward kernel.
495    if sys.htbeta_matvec.is_some() {
496        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
497    }
498
499    #[cfg(not(target_os = "linux"))]
500    {
501        // No CUDA runtime on non-Linux. NaN ridges are validated to ensure the
502        // same contract as the Linux path.
503        if ridge_t.is_nan() || ridge_beta.is_nan() {
504            return Err(ArrowSchurGpuFailure::Unavailable);
505        }
506        Err(ArrowSchurGpuFailure::Unavailable)
507    }
508
509    #[cfg(target_os = "linux")]
510    {
511        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
512    }
513}
514
515/// Build a row-procedural reduced-Schur matvec for matrix-free SAE Kronecker
516/// systems, eliminating the per-row latent block via cached per-row Cholesky
517/// factors and applying the cross-block through the sparse forward/transpose
518/// Kronecker operators (active atoms only).
519///
520/// The returned closure evaluates
521/// `S·x = (H_ββ + ρ_β·I)·x − Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x`,
522/// the same reduced Schur complement the dense path forms, but never
523/// materialises the `d × K` cross-block `H_tβ^(i)`: the forward operator
524/// (`out = H_tβ^(i)·x`) and transpose operator (`out += H_βt^(i)·v`) are the
525/// sparse Kronecker gather/scatter from `SaeKroneckerRows`. The per-row factor
526/// of `H_tt^(i) + ρ_t·I` is computed once when the closure is built and reused
527/// across every CG iteration.
528///
529/// Returns `RidgeBumpRequired` if a per-row block is not positive definite at
530/// the requested `ridge_t`; the outer LM escalation bumps `ridge_t` and retries.
531fn build_row_procedural_matvec(
532    sys: &ArrowSchurSystem,
533    ridge_t: f64,
534    ridge_beta: f64,
535) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
536    use std::sync::Arc;
537    let n = sys.rows.len();
538    let k = sys.k;
539    let forward = sys
540        .htbeta_matvec
541        .clone()
542        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
543    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
544        // A forward operator without its sparse adjoint cannot be applied
545        // row-procedurally; this is a wiring error, surfaced as a Schur failure
546        // so the caller routes to the dense CPU path rather than misreporting a
547        // numerical bump.
548        ArrowSchurGpuFailure::SchurFactorFailed {
549            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
550                     forward operator installed without its sparse adjoint"
551                .to_string(),
552        }
553    })?;
554
555    // Pre-factor each per-row block H_tt^(i) + ρ_t·I = L_i L_iᵀ on the host.
556    // The blocks are tiny (d_i ≲ 32) and the dense cross-block slabs are
557    // absent, so there is no device forward-kernel work to amortise here; the
558    // GPU win is the reduced K-system solve in `solve_reduced_beta_pcg`.
559    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
560    for (i, row) in sys.rows.iter().enumerate() {
561        let di = row.htt.nrows();
562        if row.htt.ncols() != di || row.gt.len() != di {
563            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
564                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
565            });
566        }
567        let mut block = row.htt.clone();
568        for r in 0..di {
569            block[[r, r]] += ridge_t;
570        }
571        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
572            .ok_or_else(|| {
573                // Deficit-aware bump from the block's own entries (Gershgorin),
574                // so the outer LM escalation lifts a strongly-indefinite block
575                // out of the negative regime in one retry.
576                ArrowSchurGpuFailure::RidgeBumpRequired {
577                    row: i,
578                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
579                }
580            })?;
581        factors.push(factor);
582    }
583
584    // The SAE-manifold β-Hessian lives in the structured penalty operator
585    // (data-fit Gauss-Newton `G ⊗ I_p` + smoothness Kronecker blocks + any
586    // dense analytic-β residual), NOT in the dense `hbb` accumulator — for
587    // matrix-free systems `hbb` is zero/absent. Capture the effective penalty
588    // operator so `H_ββ·x` matches the CPU `schur_matvec` path exactly. The
589    // operator's `matvec` adds (`y += P x`), so seed `out` from the ridge term.
590    let penalty_op = sys.effective_penalty_op();
591    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
592
593    let closure: crate::arrow_schur::GpuSchurMatvec =
594        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
595            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
596            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
597
598            // (H_ββ + ρ_β·I)·x into out. Seed with the ridge term, then add the
599            // structured penalty-side product (penalty_op.matvec is additive).
600            {
601                let x_slice = x.as_slice().expect("x must be contiguous");
602                let out_slice = out.as_slice_mut().expect("out must be contiguous");
603                for a in 0..k {
604                    out_slice[a] = ridge_beta * x_slice[a];
605                }
606                penalty_op.matvec(x_slice, out_slice);
607            }
608
609            // out -= Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x.
610            //
611            // #1017: this row-procedural reduced-Schur term is the matrix-free
612            // SAE path's matvec hot loop (`build_row_procedural_matvec` is the
613            // host backend `gpu_schur_matvec_backend` returns when the dense
614            // `H_tβ` slabs are absent — the production Qwen shape). At
615            // (n≈2000 rows) it ran SERIALLY on one core and allocated a fresh
616            // length-`K` `neg` plus per-row `v_i`/`w_i` on EVERY CG iteration —
617            // tens of thousands of tiny heap allocations across a solve. Each
618            // row contributes an independent length-`K` scatter, so the sum is
619            // embarrassingly parallel; fan it across rayon over fixed row chunks
620            // and fold the per-chunk length-`K` partials in chunk order so the
621            // f64 reduction is deterministic (bit-identical run-to-run)
622            // regardless of thread scheduling — it agrees with the serial sum up
623            // to ULP-scale chunk reassociation (the #1017 verification gate).
624            // Because that reassociation is a real (if tiny) departure from
625            // serial, the criterion ranking across topology candidates is stable
626            // except for candidates separated by less than the reassociation
627            // margin, where the near-tie winner can flip — not an exact no-move
628            // guarantee (#1211). Stay
629            // sequential below
630            // `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and when already inside a
631            // rayon worker (the topology race fans candidates with
632            // `run_topology_race_parallel`) — the same nested-rayon guard the
633            // CPU `schur_matvec` uses. Buffers (`v_i`, `neg`) are reused across
634            // rows within a chunk, so the per-row allocation churn is gone.
635            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
636                && rayon::current_thread_index().is_none();
637            if parallel {
638                use rayon::prelude::*;
639                const CHUNK: usize = 64;
640                let partials: Vec<Array1<f64>> = (0..n)
641                    .into_par_iter()
642                    .chunks(CHUNK)
643                    .map(|idxs| {
644                        // One length-`K` scatter accumulator per chunk; the
645                        // per-row latent vector `v_i` (length `d_i ≲ 32`) is the
646                        // only per-row buffer, sized to the row's own `d_i`.
647                        let mut neg = Array1::<f64>::zeros(k);
648                        for i in idxs {
649                            let di = row_dims[i];
650                            // v_i = H_tβ^(i)·x (sparse Kronecker gather).
651                            let mut v_i = Array1::<f64>::zeros(di);
652                            forward(i, x.view(), &mut v_i);
653                            // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
654                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
655                            // neg += H_βt^(i)·w_i (sparse scatter).
656                            transpose(i, w_i.view(), &mut neg);
657                        }
658                        neg
659                    })
660                    .collect();
661                // #1017/#1175 floating-point parity contract: each chunk's row
662                // sum is formed locally, then chunk partials are folded
663                // left-to-right. That makes the parallel row-procedural Schur
664                // term deterministic for a fixed input and chunking, but it is
665                // not required to be bit-identical to the serial path because
666                // the additions are reassociated at chunk boundaries. CPU/GPU
667                // validation should therefore allow ULP-scale drift while
668                // expecting stable run-to-run results.
669                let mut neg = Array1::<f64>::zeros(k);
670                for part in &partials {
671                    for a in 0..k {
672                        neg[a] += part[a];
673                    }
674                }
675                for a in 0..k {
676                    out[a] -= neg[a];
677                }
678            } else {
679                // Serial path: reuse one `neg` and one `v_i` across rows.
680                let mut neg = Array1::<f64>::zeros(k);
681                for i in 0..n {
682                    let di = row_dims[i];
683                    // v_i = H_tβ^(i)·x (sparse Kronecker gather, length d_i).
684                    let mut v_i = Array1::<f64>::zeros(di);
685                    forward(i, x.view(), &mut v_i);
686                    // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
687                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
688                    // neg += H_βt^(i)·w_i (sparse scatter); subtract once at end.
689                    transpose(i, w_i.view(), &mut neg);
690                }
691                for a in 0..k {
692                    out[a] -= neg[a];
693                }
694            }
695        });
696
697    Ok(closure)
698}
699
700/// Solve the reduced shared β-system `S·δβ = r` fully on device with a
701/// Jacobi-preconditioned conjugate-gradient (Steihaug truncated-CG) loop.
702///
703/// `S` is the already-reduced symmetric positive-definite `K × K` Schur
704/// complement the streaming SAE joint fit accumulates across minibatches
705/// (`StreamingArrowSchur::take_accumulators` summed over chunks, with the
706/// global β ridge folded in). The per-row latent blocks have already been
707/// eliminated into `S` on the host streaming path; the device's job is the
708/// dense `K`-dimensional solve, which is the dominant cost at `K = 100K`.
709///
710/// The dense `S·p` matvec runs on device via cuBLAS `Dgemv`, and the PCG state
711/// vectors (`x`, `r`, `z`, `p`, `S·p`) remain device-resident for the solve.
712/// Jacobi preconditioning is an elementwise CUDA kernel; only convergence
713/// scalars (`pᵀSp`, `rᵀz`, `‖r‖`) cross the host boundary per iteration, plus the
714/// final solution vector.
715///
716/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` when CUDA is unavailable
717/// or the workload is below the dispatch policy; the caller then runs the CPU
718/// reduced-β solve. Returns `Err(ArrowSchurGpuFailure::SchurFactorFailed)`
719/// when `S` carries a non-positive Jacobi diagonal (caller escalates the
720/// proximal ridge).
721pub fn solve_reduced_beta_pcg(
722    s_acc: &Array2<f64>,
723    rhs_beta: &Array1<f64>,
724    max_iterations: usize,
725    relative_tolerance: f64,
726) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
727    solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
728        .map(|(x, _)| x)
729}
730
731#[doc(hidden)]
732pub fn solve_reduced_beta_pcg_with_diagnostics(
733    s_acc: &Array2<f64>,
734    rhs_beta: &Array1<f64>,
735    max_iterations: usize,
736    relative_tolerance: f64,
737) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
738    let k = rhs_beta.len();
739    if s_acc.dim() != (k, k) {
740        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
741            reason: format!(
742                "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
743                s_acc.dim()
744            ),
745        });
746    }
747    if k == 0 {
748        return Err(ArrowSchurGpuFailure::Unavailable);
749    }
750
751    #[cfg(not(target_os = "linux"))]
752    {
753        if relative_tolerance.is_nan() || max_iterations == 0 {
754            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
755                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
756            });
757        }
758        Err(ArrowSchurGpuFailure::Unavailable)
759    }
760
761    #[cfg(target_os = "linux")]
762    {
763        cuda::solve_reduced_beta_pcg_with_diagnostics(
764            s_acc,
765            rhs_beta,
766            max_iterations,
767            relative_tolerance,
768        )
769    }
770}
771
772pub fn solve_sae_matrix_free_pcg(
773    sys: &ArrowSchurSystem,
774    data: &DeviceSaePcgData,
775    ridge_t: f64,
776    ridge_beta: f64,
777    rhs_beta: &Array1<f64>,
778    max_iterations: usize,
779    relative_tolerance: f64,
780) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
781    if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
782        return Err(ArrowSchurGpuFailure::Unavailable);
783    }
784    #[cfg(not(target_os = "linux"))]
785    {
786        if ridge_t.is_nan()
787            || ridge_beta.is_nan()
788            || relative_tolerance.is_nan()
789            || max_iterations == 0
790        {
791            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
792                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
793            });
794        }
795        Err(ArrowSchurGpuFailure::Unavailable)
796    }
797    #[cfg(target_os = "linux")]
798    {
799        // #1017/#1026 dispatch GUARD: framed data (frame metadata present) carries
800        // a factored β border `G ⊗ W_{ij}` data Hessian and dense per-row cross
801        // blocks the legacy `⊗ I_p` kernel CANNOT represent — feeding it framed
802        // data would silently return a WRONG Newton step (it returns Ok with no
803        // fallback). Route framed systems to the dedicated framed kernel and
804        // legacy full-`B` systems to the legacy kernel; the two never cross.
805        if data.frame.is_some() {
806            cuda::solve_sae_matrix_free_pcg_framed(
807                sys,
808                data,
809                ridge_t,
810                ridge_beta,
811                rhs_beta,
812                max_iterations,
813                relative_tolerance,
814            )
815        } else {
816            cuda::solve_sae_matrix_free_pcg(
817                sys,
818                data,
819                ridge_t,
820                ridge_beta,
821                rhs_beta,
822                max_iterations,
823                relative_tolerance,
824            )
825        }
826    }
827}
828
829/// #1551 kernel-isolating parity probe: run the framed reduced-Schur matvec
830/// `out = S·x` exactly once on the device and return it (no PCG, no offload-floor
831/// gate). The test suite diffs this element-wise against the CPU oracle
832/// [`sae_framed_schur_matvec_cpu`] to prove the GPU kernel computes the SAME
833/// operator — a check that is independent of solver conditioning (unlike a
834/// solved-`δβ` comparison, which can diverge purely because dense Cholesky and
835/// iterative PCG resolve an ill-conditioned `S` to different accuracies). On a
836/// non-CUDA host this returns `Unavailable` so the caller skips cleanly.
837#[doc(hidden)]
838pub fn framed_schur_matvec_once_on_device(
839    sys: &ArrowSchurSystem,
840    data: &DeviceSaePcgData,
841    ridge_t: f64,
842    ridge_beta: f64,
843    x: &Array1<f64>,
844) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
845    if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
846        return Err(ArrowSchurGpuFailure::Unavailable);
847    }
848    if data.frame.is_none() {
849        return Err(ArrowSchurGpuFailure::Unavailable);
850    }
851    #[cfg(not(target_os = "linux"))]
852    {
853        // CUDA is linux-only; the framed device solve is unavailable here. Read
854        // the ridge params (a real, cheap use) so the values the shared signature
855        // carries are not flagged unused on this target — without an `#[allow]`
856        // or `let _` dodge, both of which the build.rs scanner bans.
857        if ridge_t.is_finite() && ridge_beta.is_finite() {
858            return Err(ArrowSchurGpuFailure::Unavailable);
859        }
860        Err(ArrowSchurGpuFailure::Unavailable)
861    }
862    #[cfg(target_os = "linux")]
863    {
864        cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
865    }
866}
867
868/// Reference dense back-end used by tests and as the fallback when the
869/// GPU declines. Kept here (not in `arrow_schur_gpu.rs`) so the validation
870/// suite has one canonical baseline.
871#[doc(hidden)]
872pub fn solve_arrow_newton_step_dense_reference(
873    sys: &ArrowSchurSystem,
874    ridge_t: f64,
875    ridge_beta: f64,
876) -> Result<ArrowSchurGpuSolution, String> {
877    let n = sys.rows.len();
878    let d = sys.d;
879    let k = sys.k;
880    let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
881    let mut h = Array2::<f64>::zeros((total, total));
882    let mut rhs = Array1::<f64>::zeros(total);
883    for (i, row) in sys.rows.iter().enumerate() {
884        let base = i * d;
885        for c in 0..d {
886            for r in 0..d {
887                h[[base + r, base + c]] = row.htt[[r, c]];
888            }
889            h[[base + c, base + c]] += ridge_t;
890        }
891        for c in 0..k {
892            for r in 0..d {
893                let value = row.htbeta[[r, c]];
894                h[[base + r, n * d + c]] = value;
895                h[[n * d + c, base + r]] = value;
896            }
897        }
898        for r in 0..d {
899            rhs[base + r] = -row.gt[r];
900        }
901    }
902    for c in 0..k {
903        for r in 0..k {
904            h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
905        }
906        h[[n * d + c, n * d + c]] += ridge_beta;
907        rhs[n * d + c] = -sys.gb[c];
908    }
909    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
910        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
911    let mut log_det = 0.0_f64;
912    for i in 0..total {
913        log_det += factor[[i, i]].ln();
914    }
915    log_det *= 2.0;
916    let solved = cholesky_solve_vector(factor.view(), rhs.view());
917    let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
918    let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
919    Ok(ArrowSchurGpuSolution {
920        delta_t,
921        delta_beta,
922        log_det_hessian: log_det,
923    })
924}
925
/// Frames-engaged reduced-Schur penalty-side matvec `out = (P_ββ + ρ_β I)·x`,
/// computed purely from the factored device data (issue #1017/#1026).
///
/// This is the CPU bit-parity ORACLE for the GPU `arrow_sae_*` penalty kernels
/// on the frames path. It applies two penalty terms:
/// * smooth `λ S_k ⊗ I_{r_k}` — each `smooth_blocks[i]` at its `global_offset`,
///   with right-width `frame.smooth_ranks[i]`;
/// * data-fit `G_{ij} ⊗ W_{ij}` — each `frame.frame_blocks` entry, using the
///   `μ`-major / frame-minor index
///   `border_offsets[atom] + basis·r + frame_coord`.
///
/// The accumulation order matches the device kernels exactly, so any change
/// to the loop nesting or summation order breaks bit-parity.
///
/// `out` is OVERWRITTEN: first set to `ρ_β·x`, then the penalty blocks add in.
///
/// # Panics
/// Panics in these cases:
/// * `data.frame` is `None`.
/// * `x` or `out` is shorter than `data.beta_dim`.
/// * A block offset/rank indexes past the end of `x`/`out`.
#[doc(hidden)]
pub fn sae_framed_penalty_matvec_cpu(
    data: &DeviceSaePcgData,
    ridge_beta: f64,
    x: &[f64],
    out: &mut [f64],
) {
    let frame = data
        .frame
        .as_ref()
        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
    let k = data.beta_dim;
    // Ridge term: this is the overwrite that seeds `out`.
    for a in 0..k {
        out[a] = ridge_beta * x[a];
    }
    // Smooth penalty `λ S_k ⊗ I_{r_k}`: y[off + ia·r + ib] += Σ_ja S[ia,ja]·x[off + ja·r + ib].
    // `zip` stops at the shorter of `smooth_blocks` / `smooth_ranks`; a length
    // mismatch silently drops the trailing blocks.
    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
        let off = blk.global_offset;
        // Only `nrows()` is read, so `factor_a` is indexed as an `m × m` matrix.
        // NOTE(review): assumes it is square — confirm.
        let m = blk.factor_a.nrows();
        for i_a in 0..m {
            for i_b in 0..r {
                let mut acc = 0.0_f64;
                for j_a in 0..m {
                    let s = blk.factor_a[[i_a, j_a]];
                    // Structural zeros are skipped rather than multiplied, so a
                    // non-finite `x` entry under a zero coefficient does not
                    // propagate into `acc`.
                    if s == 0.0 {
                        continue;
                    }
                    acc += s * x[off + j_a * r + i_b];
                }
                out[off + i_a * r + i_b] += acc;
            }
        }
    }
    // Data-fit penalty `G_{ij} ⊗ W_{ij}`.
    // Each stored block is applied only in its stored (atom_i, atom_j)
    // orientation; no transpose is added here.
    // NOTE(review): presumably `frame_blocks` already lists both orientations
    // of an off-diagonal coupling — confirm.
    for blk in &frame.frame_blocks {
        let r_i = frame.ranks[blk.atom_i];
        let r_j = frame.ranks[blk.atom_j];
        let off_i = frame.border_offsets[blk.atom_i];
        let off_j = frame.border_offsets[blk.atom_j];
        let (m_i, m_j) = blk.g.dim();
        for li in 0..m_i {
            let yi_base = off_i + li * r_i;
            for lj in 0..m_j {
                let g = blk.g[[li, lj]];
                // Same structural-zero skip as the smooth term.
                if g == 0.0 {
                    continue;
                }
                let xj_base = off_j + lj * r_j;
                // out[yi_base..+r_i] += g · (W · x[xj_base..+r_j]).
                for a in 0..r_i {
                    let mut acc = 0.0_f64;
                    for b in 0..r_j {
                        acc += blk.w[[a, b]] * x[xj_base + b];
                    }
                    out[yi_base + a] += g * acc;
                }
            }
        }
    }
}
995
996/// Frames-engaged FULL reduced-Schur matvec `out = S·x` purely from the device
997/// data, where `S = (P_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)`
998/// (issue #1017/#1026). The penalty side is [`sae_framed_penalty_matvec_cpu`];
999/// the per-row reduced term reads the dense `frame.row_htbeta[i]`
1000/// (`q_i × border_dim`, row-major), solves against the row's
1001/// `H_tt^(i)+ρ_t I` Cholesky factor, and scatters the transpose back. This is
1002/// the size-independent bit-parity oracle the device kernel mirrors; it is also
1003/// the matvec the GPU PCG iterates.
1004#[doc(hidden)]
1005pub fn sae_framed_schur_matvec_cpu(
1006    sys: &ArrowSchurSystem,
1007    data: &DeviceSaePcgData,
1008    ridge_t: f64,
1009    ridge_beta: f64,
1010    x: &[f64],
1011    out: &mut [f64],
1012) -> Result<(), String> {
1013    let frame = data
1014        .frame
1015        .as_ref()
1016        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1017    let k = data.beta_dim;
1018    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1019    if frame.row_htbeta.len() != sys.rows.len() {
1020        return Err(format!(
1021            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1022            frame.row_htbeta.len(),
1023            sys.rows.len()
1024        ));
1025    }
1026    for (i, row) in sys.rows.iter().enumerate() {
1027        let slab = &frame.row_htbeta[i];
1028        if slab.is_empty() {
1029            continue;
1030        }
1031        let qi = sys.row_dims[i];
1032        if qi == 0 || slab.len() != qi * k {
1033            continue;
1034        }
1035        // h = H_tβ^(i) · x  (length q_i).
1036        let mut h = vec![0.0_f64; qi];
1037        for c in 0..qi {
1038            let base = c * k;
1039            let mut acc = 0.0_f64;
1040            for a in 0..k {
1041                acc += slab[base + a] * x[a];
1042            }
1043            h[c] = acc;
1044        }
1045        // solve (H_tt^(i)+ρ_t I) s = h.
1046        let mut block = row.htt.clone();
1047        for d in 0..qi {
1048            block[[d, d]] += ridge_t;
1049        }
1050        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1051            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1052        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1053        // out -= H_βt^(i) · s = (H_tβ^(i))ᵀ · s.
1054        for c in 0..qi {
1055            let sc = s[c];
1056            if sc == 0.0 {
1057                continue;
1058            }
1059            let base = c * k;
1060            for a in 0..k {
1061                out[a] -= slab[base + a] * sc;
1062            }
1063        }
1064    }
1065    Ok(())
1066}
1067
1068#[cfg(target_os = "linux")]
1069mod cuda {
1070    use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1071    use gam_gpu::driver::to_i32;
1072    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1073    use crate::arrow_schur::{
1074        ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1075    };
1076    use cudarc::cublas::sys::{
1077        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1078    };
1079    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1080    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1081    use cudarc::driver::{
1082        CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1083        PushKernelArg,
1084    };
1085    use ndarray::Array1;
1086    use std::sync::{Arc, OnceLock};
1087
    /// Per-row work slot for the row-block-granular multi-GPU solve.
    ///
    /// Inputs are the packed single-row buffers:
    /// * the `d×d` D block with the ρ_t ridge folded in;
    /// * the `d×k` B block;
    /// * the `d` g vector.
    ///
    /// The forward pass fills the whitened factors `l/u/y`, and the per-tile
    /// reduction lands in the tile's leading slot. The back-substitution pass
    /// works from the kept `L`/`u`/`Y` and writes `delta_t_block`. If `bump` is
    /// set, the orchestrator reports `RidgeBumpRequired` for this slot's global
    /// row instead of finishing the solve.
    struct RowSlot {
        // Inputs (packed once on the host, column-major).
        d_block: Vec<f64>, // d*d
        b_block: Vec<f64>, // d*k
        g_vec: Vec<f64>,   // d
        // Forward outputs, kept on the host for the back-sub pass.
        l_block: Vec<f64>, // d*d lower factor, column-major
        u_vec: Vec<f64>,   // d   (= L^{-1} g)
        y_block: Vec<f64>, // d*k (= L^{-1} B), column-major
        // Σ_j ln L[j,j] for this row's factor (the caller doubles the grand
        // total to get the log-determinant).
        log_det_local: f64,
        // Set on a non-PD pivot so the orchestrator can raise RidgeBumpRequired
        // for the offending global row instead of silently falling back.
        bump: Option<f64>,
        // Tile-level reduction, written into the tile's first slot only.
        tile_partial_schur: Option<Vec<f64>>, // k*k col-major, = Σ Y_iᵀY_i
        tile_partial_rhs: Option<Vec<f64>>,   // k, = Σ Y_iᵀu_i
        // Back-sub output for this row.
        delta_t_block: Vec<f64>, // d
    }
1111
    /// Row-block-granular multi-GPU Arrow-Schur Newton solve.
    ///
    /// The solve is separable across row blocks in both phases:
    ///   * forward — each row's local Cholesky `L_i`, whitening
    ///     `u_i = L_i⁻¹g_i`, `Y_i = L_i⁻¹B_i`, and partial Schur
    ///     `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` are independent;
    ///   * backward — `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)` is independent.
    ///
    /// Only the small shared `K×K` reduce + factor + `δβ` solve is central.
    ///
    /// `gam_gpu::pool::scatter_batched` hands each device a contiguous row
    /// tile on its own bound context/stream. The per-tile forward keeps the
    /// POTRF fused with its dependent TRSM + Schur GEMM on that one stream, so
    /// no on-stream solve is orphaned. The remaining steps are:
    /// * tile partials and per-tile `log|L|` are reduced on the host, in
    ///   tile/row order;
    /// * `S_β` is factored on the primary device;
    /// * the back-sub is scattered back across the same tiles.
    ///
    /// `log_det_hessian` is `2·(Σ_i Σ_j ln L_i[j,j] + Σ_j ln L_S[j,j])`. This is
    /// the block-Cholesky identity `log|H| = Σ_i log|D_i| + log|S_β|`.
    ///
    /// # Errors
    /// * `Unavailable` — the caller then uses a single-device path. Returned
    ///   when:
    ///   - `n`, `d` or `k` is zero;
    ///   - the system carries matrix-free operators;
    ///   - the shared block is not dense `K×K`;
    ///   - a row's `htt`/`htbeta`/`gt` is not `d×d`/`d×k`/`d`;
    ///   - the pool is single-device;
    ///   - any tile's device work declines;
    ///   - a primary-device handle or transfer fails.
    /// * `RidgeBumpRequired` — a non-PD tip block, reported for the precise
    ///   global row.
    /// * `SchurFactorFailed` — the reduced `S_β` Cholesky reports a non-zero
    ///   `info`.
    pub(super) fn solve_multi_gpu(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        if n == 0 || d == 0 || k == 0 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        // Dense shared block + materialised per-row slabs are required; the
        // public entry already rejected matrix-free operators, but re-check so
        // this routine is safe in isolation.
        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        if runtime.device_count() < 2 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Pack one slot per row (column-major), folding ρ_t into each D block.
        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
        for row in &sys.rows {
            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
            let mut d_block = Vec::with_capacity(d * d);
            let mut b_block = Vec::with_capacity(d * k);
            let mut g_vec = Vec::with_capacity(d);
            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
            slots.push(RowSlot {
                d_block,
                b_block,
                g_vec,
                l_block: Vec::new(),
                u_vec: Vec::new(),
                y_block: Vec::new(),
                log_det_local: 0.0,
                bump: None,
                tile_partial_schur: None,
                tile_partial_rhs: None,
                delta_t_block: vec![0.0; d],
            });
        }

        // ---- Forward pass: per-device row tile, fused on its own stream ----
        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            forward_tile(ordinal, d, k, tile)
        });
        if forward_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Surface a non-PD tip block as a precise per-row ridge bump.
        // NOTE(review): the tile-partial lookup further down assumes that
        // `balanced_partition` reproduces exactly the row tiles
        // `scatter_batched` handed to `forward_tile` — confirm in `gam_gpu::pool`.
        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
        // `slots` is in global row order, so the enumerate index is the global row.
        if let Some((row, bump)) = slots
            .iter()
            .enumerate()
            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
        {
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
        }

        // ---- Central: reduce tile partials → S_β, r_β; factor; solve δβ ----
        // Seed S_β with H_ββ + ρ_β I (column-major) and r_β with -g_β, then fold
        // in the per-tile partials in tile order so the reduction order tracks
        // the single-device accumulation (up to inter-tile reassociation).
        let mut schur_host = vec![0.0_f64; k * k];
        for col in 0..k {
            for row in 0..k {
                let mut v = sys.hbb[[row, col]];
                if row == col {
                    v += ridge_beta;
                }
                schur_host[col * k + row] = v;
            }
        }
        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
        let mut log_det = 0.0_f64;
        for start in tile_starts(&row_base_of_tile) {
            let slot = &slots[start];
            let partial_schur = slot
                .tile_partial_schur
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let partial_rhs = slot
                .tile_partial_rhs
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            // `accumulate_schur` writes `partial_schur = -Σ_tile Y_iᵀY_i` (GEMM
            // α=-1, β=1 into a zero seed) and `partial_rhs = +Σ_tile Y_iᵀu_i`.
            // The reduced Schur is `S = (H_ββ+ρI) − Σ_all Y_iᵀY_i`, so adding the
            // (already-negated) partials reproduces the single-device sign.
            for idx in 0..k * k {
                schur_host[idx] += partial_schur[idx];
            }
            for a in 0..k {
                rhs_host[a] += partial_rhs[a];
            }
        }
        // Local half of the log-determinant: Σ over all rows of Σ_j ln L_i[j,j].
        for slot in &slots {
            log_det += slot.log_det_local;
        }

        // Factor S_β and solve δβ on the primary device (small K×K leaf). The
        // stream carries the primary context (same pattern as `solve()`); no
        // thread bind is needed for the cuSOLVER/cuBLAS handles created from it.
        let primary = runtime.selected_device().ordinal;
        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut schur_dev = stream
            .clone_htod(&schur_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_dev = stream
            .clone_htod(&rhs_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
            });
        }
        // Two triangular solves against the Schur factor, in place on rhs_dev
        // (the flag pair differs between the two calls; see `trsm_single`).
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
        let delta_beta_host = stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // Owned copy for the returned `Array1`; the host `Vec` stays borrowed by
        // the back-sub tiles below.
        let delta_beta = Array1::from_vec(delta_beta_host.clone());
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // Schur half of the log-determinant: diagonal of the column-major factor.
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;

        // ---- Backward pass: δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ), per-device tile ----
        let delta_beta_ref = &delta_beta_host;
        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
        });
        if back_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Stitch per-row δt into the stacked (n*d) result.
        let mut delta_t = Array1::<f64>::zeros(n * d);
        for (i, slot) in slots.iter().enumerate() {
            let base = i * d;
            for r in 0..d {
                delta_t[base + r] = slot.delta_t_block[r];
            }
        }

        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: log_det,
        })
    }
1300
1301    /// Tile starts: the leading global row index of each device tile (where the
1302    /// tile-level partial reduction was written by the forward pass).
1303    fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1304        tiles.iter().map(|(_, range)| range.start)
1305    }
1306
    /// Forward pass for one device row tile, running on `ordinal`'s bound stream.
    ///
    /// Per tile it:
    /// * factors each row block;
    /// * whitens `u`/`Y`;
    /// * accumulates the tile's partial Schur `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` into the
    ///   tile's leading slot;
    /// * keeps the per-row `L`/`u`/`Y` on the host for back-sub;
    /// * records the per-row `Σ_j log L_jj`.
    ///
    /// The Schur part comes from `accumulate_schur`, which, per the
    /// orchestrator's reduction comment, stores it negated as `-Σ Y_iᵀY_i`.
    ///
    /// A non-PD pivot is recorded in `slot.bump`. The tile still returns
    /// `Some(())`, so the orchestrator raises a precise `RidgeBumpRequired`
    /// rather than collapsing the whole batch to CPU. Only the first non-PD
    /// block of the tile is recorded. The tile then stops, leaving its
    /// remaining slots and the tile partial unfilled.
    ///
    /// Returns `None` when any device call fails; each is propagated with
    /// `.ok()?`. These are: stream/handle creation, a host↔device copy or
    /// allocation, the batched POTRF/TRSM, or the Schur accumulation.
    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        // `scatter_batched` has already bound this ordinal's context on this
        // worker thread; the stream below targets that same device.
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        // Stack the tile's D, B, g into contiguous device buffers (same layout
        // the single-device path packs for `m` rows).
        let mut d_host = Vec::with_capacity(m * d * d);
        let mut b_host = Vec::with_capacity(m * d * k);
        let mut g_host = Vec::with_capacity(m * d);
        for slot in tile.iter() {
            d_host.extend_from_slice(&slot.d_block);
            b_host.extend_from_slice(&slot.b_block);
            g_host.extend_from_slice(&slot.g_vec);
        }
        let mut d_dev = stream.clone_htod(&d_host).ok()?;
        let mut b_dev = stream.clone_htod(&b_host).ok()?;
        let mut g_dev = stream.clone_htod(&g_host).ok()?;

        // Batched POTRF; a non-PD block records its bump and stops the tile.
        // The bump is deficit-aware (Gershgorin lower bound on λ_min of the
        // already-ridged `d_block`), NOT derived from the cuSOLVER `info` —
        // which is a 1-based pivot ROW INDEX, not a pivot magnitude — so a
        // strongly-indefinite block recovers in one outer-loop retry.
        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
        if let Some(local) = info_host.iter().position(|info| *info != 0) {
            tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
                &tile[local].d_block,
                d,
            ));
            return Some(());
        }

        // Whiten: u = L⁻¹ g, Y = L⁻¹ B.
        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;

        // Tile partial Schur: zero-seeded so the host adds the H_ββ seed once.
        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;

        // Download L, u, Y, and the tile partials.
        let l_host = stream.clone_dtoh(&d_dev).ok()?;
        let u_host = stream.clone_dtoh(&g_dev).ok()?;
        let y_host = stream.clone_dtoh(&b_dev).ok()?;
        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;

        // Split the stacked downloads back into per-row slots.
        for (local, slot) in tile.iter_mut().enumerate() {
            let l_base = local * d * d;
            let u_base = local * d;
            let y_base = local * d * k;
            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
            slot.u_vec = u_host[u_base..u_base + d].to_vec();
            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
            let mut log_det_local = 0.0_f64;
            for j in 0..d {
                // Diagonal entry (j, j) of this row's column-major d×d factor.
                log_det_local += l_host[l_base + j * d + j].ln();
            }
            slot.log_det_local = log_det_local;
        }
        // The tile-level reduction lives only in the tile's leading slot.
        tile[0].tile_partial_schur = Some(partial_schur);
        tile[0].tile_partial_rhs = Some(partial_rhs);
        Some(())
    }
1387
1388    /// Back-substitution for one device row tile: `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)`.
1389    /// Re-uploads the tile's kept `L`/`u`/`Y` to `ordinal`, applies the GEMV
1390    /// accumulate + transposed TRSM, and writes each row's `δt` into its slot.
1391    fn back_sub_tile(
1392        ordinal: usize,
1393        d: usize,
1394        k: usize,
1395        delta_beta: &[f64],
1396        tile: &mut [RowSlot],
1397    ) -> Option<()> {
1398        if tile.is_empty() {
1399            return Some(());
1400        }
1401        // `scatter_batched` has already bound this ordinal's context on this
1402        // worker thread; the stream below targets that same device.
1403        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1404            .and_then(|ctx| ctx.new_stream().ok())?;
1405        let blas = CudaBlas::new(stream.clone()).ok()?;
1406        let m = tile.len();
1407
1408        let mut l_host = Vec::with_capacity(m * d * d);
1409        let mut u_host = Vec::with_capacity(m * d);
1410        let mut y_host = Vec::with_capacity(m * d * k);
1411        for slot in tile.iter() {
1412            l_host.extend_from_slice(&slot.l_block);
1413            u_host.extend_from_slice(&slot.u_vec);
1414            y_host.extend_from_slice(&slot.y_block);
1415        }
1416        let d_dev = stream.clone_htod(&l_host).ok()?;
1417        let mut g_dev = stream.clone_htod(&u_host).ok()?;
1418        let b_dev = stream.clone_htod(&y_host).ok()?;
1419        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1420
1421        // g ← u + Y·δβ, then x = L⁻ᵀ g; δt = -x.
1422        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1423        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1424        let x_host = stream.clone_dtoh(&g_dev).ok()?;
1425        for (local, slot) in tile.iter_mut().enumerate() {
1426            let base = local * d;
1427            for r in 0..d {
1428                slot.delta_t_block[r] = -x_host[base + r];
1429            }
1430        }
1431        Some(())
1432    }
1433
1434    pub(super) fn solve(
1435        sys: &ArrowSchurSystem,
1436        ridge_t: f64,
1437        ridge_beta: f64,
1438    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1439        let n = sys.rows.len();
1440        let d = sys.d;
1441        let k = sys.k;
1442        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1443            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1444
1445        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1446            .and_then(|ctx| ctx.new_stream().ok())
1447            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1448        let solver =
1449            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1450        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1451
1452        // ----- Pack + upload D, B, g -----
1453        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1454        let mut d_dev = stream
1455            .clone_htod(&d_host)
1456            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1457        let mut b_dev = stream
1458            .clone_htod(&b_host)
1459            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1460        let mut g_dev = stream
1461            .clone_htod(&g_host)
1462            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1463
1464        // ----- Layer A: batched lower Cholesky of D in place -----
1465        // This POTRF is fused with the downstream TRSM + Schur GEMM + back-sub
1466        // on this one stream, so splitting only the POTRF across devices would
1467        // orphan the dependent on-stream solves. Multi-GPU here is the
1468        // whole-solve row-block split in `solve_arrow_newton_step` (see
1469        // `solve_multi_gpu`), not a per-layer split — this device-resident path
1470        // is the single-device leaf the split dispatches per tile.
1471        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1472        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1473            // `info` is cuSOLVER's 1-based pivot ROW INDEX, not a magnitude;
1474            // size the bump from the block's own entries (Gershgorin λ_min
1475            // bound) so a strongly-indefinite block recovers in one retry.
1476            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1477                row: idx,
1478                bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1479            });
1480        }
1481
1482        // ----- Layer B (1/2): in-place triangular solves -----
1483        // u_i = L_i^{-1} g_i, packed as a stacked (n*d) column-vector.
1484        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1485        // Y_i = L_i^{-1} B_i, in place over the (n*d) × k buffer (laid out as
1486        // n stacked column-major d×k tiles).
1487        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1488
1489        // ----- Layer B (2/2): Schur reduction via single big GEMM / GEMV -----
1490        // Y_all is (n*d) × k column-major: viewing all n stacked d×k tiles as
1491        // one big matrix is bit-exact because each tile is column-major with
1492        // leading dim d and the tiles are contiguous in memory, so the
1493        // combined leading dim is n*d only for the *outer* matrix view. To
1494        // make the single-GEMM equivalence hold we must treat the stacked
1495        // buffer as (n*d) × k column-major with leading dim = n*d, which
1496        // means columns of Y_all are interleaved by row across blocks.
1497        // That is NOT what we packed. So we use the cuBLAS stride pattern
1498        // instead: stride-by-block, transpose-A, and *accumulate* into one
1499        // S_β buffer via beta=1 across batches. Equivalent flop count, no
1500        // extra reduction kernel, and correct layout.
1501        //
1502        // Concretely: schur ← C + ρ_β I; rhs ← -g_β; then for each block
1503        //   schur -= Y_i^T Y_i      (k×k)
1504        //   rhs   += Y_i^T u_i      (k)
1505        // We launch this as `n` sequential GEMMs/GEMVs with beta=1 on the
1506        // accumulator. Layer D fuses these into one NVRTC launch.
1507        let schur_init: Vec<f64> = {
1508            let mut tmp = Vec::with_capacity(k * k);
1509            for col in 0..k {
1510                for row in 0..k {
1511                    let mut v = sys.hbb[[row, col]];
1512                    if row == col {
1513                        v += ridge_beta;
1514                    }
1515                    tmp.push(v);
1516                }
1517            }
1518            tmp
1519        };
1520        let mut schur_dev = stream
1521            .clone_htod(&schur_init)
1522            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1523        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1524        let mut rhs_dev = stream
1525            .clone_htod(&rhs_init)
1526            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1527
1528        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1529
1530        // ----- Layer C (1/2): factor S_β and solve for δβ -----
1531        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1532        if info != 0 {
1533            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1534                reason: format!("Schur Cholesky failed at pivot {info}"),
1535            });
1536        }
1537        // δβ ← L_S^{-T} L_S^{-1} rhs
1538        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1539        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1540        let delta_beta_host = stream
1541            .clone_dtoh(&rhs_dev)
1542            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1543        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1544
1545        // ----- Layer C (2/2): back-sub δt_i = -L_i^{-T} (u_i + Y_i δβ) -----
1546        // Already on device:
1547        //   g_dev holds u_i stacked (n*d).
1548        //   b_dev holds Y_i stacked column-major n×(d×k) tiles.
1549        // Compute g_dev ← g_dev + Y_block · δβ per block (cuBLAS gemv with beta=1),
1550        // then in-place trsm with L_i^T (CUBLAS_OP_T) to obtain x_i, and finally
1551        // δt_i = -x_i on host after download.
1552        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1553        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1554
1555        let x_host = stream
1556            .clone_dtoh(&g_dev)
1557            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1558        let mut delta_t = Array1::<f64>::zeros(n * d);
1559        for (i, v) in x_host.iter().enumerate() {
1560            delta_t[i] = -*v;
1561        }
1562
1563        // ----- log|H| = 2 Σ log L_{i,jj} + 2 Σ log R_{β,aa} -----
1564        let l_local_host = stream
1565            .clone_dtoh(&d_dev)
1566            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1567        let l_schur_host = stream
1568            .clone_dtoh(&schur_dev)
1569            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1570        let mut log_det = 0.0_f64;
1571        for i in 0..n {
1572            let base = i * d * d;
1573            for j in 0..d {
1574                log_det += l_local_host[base + j * d + j].ln();
1575            }
1576        }
1577        for j in 0..k {
1578            log_det += l_schur_host[j * k + j].ln();
1579        }
1580        log_det *= 2.0;
1581
1582        Ok(ArrowSchurGpuSolution {
1583            delta_t,
1584            delta_beta,
1585            log_det_hessian: log_det,
1586        })
1587    }
1588
1589    fn potrf_batched(
1590        solver: &DnHandle,
1591        stream: &Arc<CudaStream>,
1592        p: usize,
1593        batch: usize,
1594        matrices: &mut CudaSlice<f64>,
1595    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1596        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1597        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1598        let matrix_len = p * p;
1599        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1600        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1601        let mut ptrs = Vec::with_capacity(batch);
1602        for idx in 0..batch {
1603            ptrs.push(base_ptr + (idx as u64) * bytes_per);
1604        }
1605        let mut ptrs_dev = stream
1606            .clone_htod(&ptrs)
1607            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1608        let mut info_dev = stream
1609            .alloc_zeros::<i32>(batch)
1610            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1611        let status = {
1612            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1613            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1614            // SAFETY: pointer array and info buffer live on the device,
1615            // matrices_dev holds `batch` contiguous p×p column-major blocks.
1616            unsafe {
1617                cusolver_sys::cusolverDnDpotrfBatched(
1618                    solver.cu(),
1619                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1620                    p_i,
1621                    ptrs_ptr as *mut *mut f64,
1622                    p_i,
1623                    info_ptr as *mut i32,
1624                    batch_i,
1625                )
1626            }
1627        };
1628        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1629            return Err(ArrowSchurGpuFailure::Unavailable);
1630        }
1631        stream
1632            .clone_dtoh(&info_dev)
1633            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1634    }
1635
1636    fn potrf_single(
1637        solver: &DnHandle,
1638        stream: &Arc<CudaStream>,
1639        p: usize,
1640        matrix: &mut CudaSlice<f64>,
1641    ) -> Result<i32, ArrowSchurGpuFailure> {
1642        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1643        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1644        let mut lwork = 0_i32;
1645        {
1646            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1647            // SAFETY: buffer query against a live p-by-p column-major device matrix.
1648            let status = unsafe {
1649                cusolver_sys::cusolverDnDpotrf_bufferSize(
1650                    solver.cu(),
1651                    uplo,
1652                    p_i,
1653                    mat_ptr as *mut f64,
1654                    p_i,
1655                    &mut lwork,
1656                )
1657            };
1658            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1659                return Err(ArrowSchurGpuFailure::Unavailable);
1660            }
1661        }
1662        let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1663        let mut workspace = stream
1664            .alloc_zeros::<f64>(lwork_usize.max(1))
1665            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1666        let mut info_dev = stream
1667            .alloc_zeros::<i32>(1)
1668            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1669        {
1670            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1671            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1672            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1673            // SAFETY: all three pointers refer to live, correctly sized device buffers.
1674            let status = unsafe {
1675                cusolver_sys::cusolverDnDpotrf(
1676                    solver.cu(),
1677                    uplo,
1678                    p_i,
1679                    mat_ptr as *mut f64,
1680                    p_i,
1681                    work_ptr as *mut f64,
1682                    lwork,
1683                    info_ptr as *mut i32,
1684                )
1685            };
1686            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1687                return Err(ArrowSchurGpuFailure::Unavailable);
1688            }
1689        }
1690        let info_host = stream
1691            .clone_dtoh(&info_dev)
1692            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1693        Ok(info_host[0])
1694    }
1695
1696    /// In-place lower-triangular solves `X_i ← L_i^{-1} X_i` over the n stacked
1697    /// d×nrhs RHS tiles in `rhs`. Uses `cublasDtrsmBatched` so all n solves
1698    /// hit the device in one launch.
1699    fn trsm_batched_lower_inplace(
1700        blas: &CudaBlas,
1701        stream: &Arc<CudaStream>,
1702        d: usize,
1703        n: usize,
1704        nrhs: usize,
1705        l_stack: &CudaSlice<f64>,
1706        rhs_stack: &mut CudaSlice<f64>,
1707    ) -> Result<(), ArrowSchurGpuFailure> {
1708        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1709    }
1710
1711    /// As above but with `L_i^T` instead of `L_i`.
1712    fn trsm_batched_lower_inplace_transposed(
1713        blas: &CudaBlas,
1714        stream: &Arc<CudaStream>,
1715        d: usize,
1716        n: usize,
1717        nrhs: usize,
1718        l_stack: &CudaSlice<f64>,
1719        rhs_stack: &mut CudaSlice<f64>,
1720    ) -> Result<(), ArrowSchurGpuFailure> {
1721        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1722    }
1723
1724    fn trsm_batched_inplace_inner(
1725        blas: &CudaBlas,
1726        stream: &Arc<CudaStream>,
1727        d: usize,
1728        n: usize,
1729        nrhs: usize,
1730        l_stack: &CudaSlice<f64>,
1731        rhs_stack: &mut CudaSlice<f64>,
1732        transposed: bool,
1733    ) -> Result<(), ArrowSchurGpuFailure> {
1734        let alpha = 1.0_f64;
1735        let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1736        let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1737        let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1738        let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1739        let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1740        let (l_base, _l_record) = l_stack.device_ptr(stream);
1741        let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1742        let mut l_ptrs = Vec::with_capacity(n);
1743        let mut rhs_ptrs = Vec::with_capacity(n);
1744        for i in 0..n {
1745            l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1746            rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1747        }
1748        let mut l_ptrs_dev = stream
1749            .clone_htod(&l_ptrs)
1750            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1751        let mut rhs_ptrs_dev = stream
1752            .clone_htod(&rhs_ptrs)
1753            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1754        let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1755        let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1756        let op = if transposed {
1757            cublasOperation_t::CUBLAS_OP_T
1758        } else {
1759            cublasOperation_t::CUBLAS_OP_N
1760        };
1761        let handle = *blas.handle();
1762        // SAFETY: pointer arrays and base buffers were just constructed from
1763        // live device allocations covering the entire batch.
1764        let status = unsafe {
1765            cudarc::cublas::sys::cublasDtrsmBatched(
1766                handle,
1767                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1768                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1769                op,
1770                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1771                d_i,
1772                nrhs_i,
1773                &alpha,
1774                l_ptrs_ptr as *const *const f64,
1775                d_i,
1776                rhs_ptrs_ptr as *const *mut f64,
1777                d_i,
1778                batch_i,
1779            )
1780        };
1781        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1782            return Err(ArrowSchurGpuFailure::Unavailable);
1783        }
1784        Ok(())
1785    }
1786
1787    /// Single-matrix lower-triangular solve: `rhs ← L^{-1} rhs` (or
1788    /// `L^{-T} rhs` if `transposed`). For the Schur Cholesky back-sub.
1789    fn trsm_single(
1790        blas: &CudaBlas,
1791        stream: &Arc<CudaStream>,
1792        n: usize,
1793        l: &CudaSlice<f64>,
1794        rhs: &mut CudaSlice<f64>,
1795        upper: bool,
1796        transposed: bool,
1797    ) -> Result<(), ArrowSchurGpuFailure> {
1798        let alpha = 1.0_f64;
1799        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1800        let handle = *blas.handle();
1801        let (l_ptr, _l_rec) = l.device_ptr(stream);
1802        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1803        // SAFETY: single n×n lower factor and n-vector RHS on device.
1804        let status = unsafe {
1805            cudarc::cublas::sys::cublasDtrsm_v2(
1806                handle,
1807                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1808                if upper {
1809                    cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1810                } else {
1811                    cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1812                },
1813                if transposed {
1814                    cublasOperation_t::CUBLAS_OP_T
1815                } else {
1816                    cublasOperation_t::CUBLAS_OP_N
1817                },
1818                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1819                n_i,
1820                1,
1821                &alpha,
1822                l_ptr as *const f64,
1823                n_i,
1824                rhs_ptr as *mut f64,
1825                n_i,
1826            )
1827        };
1828        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1829            return Err(ArrowSchurGpuFailure::Unavailable);
1830        }
1831        Ok(())
1832    }
1833
1834    /// Accumulate `schur ← schur − Σ_i Y_i^T Y_i` and `rhs ← rhs + Σ_i Y_i^T u_i`
1835    /// using one GEMM and one GEMV per block. Each call uses beta=1 to chain
1836    /// the accumulation device-side.
1837    fn accumulate_schur(
1838        blas: &CudaBlas,
1839        d: usize,
1840        k: usize,
1841        n: usize,
1842        y_stack: &CudaSlice<f64>,
1843        u_stack: &CudaSlice<f64>,
1844        schur: &mut CudaSlice<f64>,
1845        rhs: &mut CudaSlice<f64>,
1846    ) -> Result<(), ArrowSchurGpuFailure> {
1847        let y_block_elems = d * k;
1848        let u_block_elems = d;
1849        for i in 0..n {
1850            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1851            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1852            // GEMM: schur += (-1) · Y_i^T · Y_i  (Y_i is d×k col-major; out is k×k)
1853            let gemm_cfg = GemmConfig::<f64> {
1854                transa: cublasOperation_t::CUBLAS_OP_T,
1855                transb: cublasOperation_t::CUBLAS_OP_N,
1856                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1857                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1858                k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1859                alpha: -1.0,
1860                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1861                ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1862                beta: 1.0,
1863                ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1864            };
1865            // SAFETY: y_slice is d×k col-major, schur is k×k col-major; alpha/beta scalars set above.
1866            unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1867                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1868            // GEMV: rhs += 1 · Y_i^T · u_i
1869            let gemv_cfg = GemvConfig::<f64> {
1870                trans: cublasOperation_t::CUBLAS_OP_T,
1871                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1872                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1873                alpha: 1.0,
1874                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1875                incx: 1,
1876                beta: 1.0,
1877                incy: 1,
1878            };
1879            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1880            // device buffers; `rhs` is the length-k accumulator.
1881            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1882                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1883        }
1884        Ok(())
1885    }
1886
1887    /// `#1017` resident gradient path: accumulate ONLY the Schur RHS term
1888    /// `rhs += Σ_i Y_iᵀ u_i`, skipping the `−Σ_i Y_iᵀ Y_i` matrix GEMM that the
1889    /// resident frame already folded into its persistent `L_S` factor. This is
1890    /// the per-iterate-cheap counterpart of [`accumulate_schur`]: the GEMV here
1891    /// is bit-identical to the GEMV inside `accumulate_schur` (same config, same
1892    /// `beta=1` accumulation order over rows), so the resident frame's `δβ`
1893    /// matches a full `solve()` at the same gradient.
1894    fn accumulate_schur_rhs_only(
1895        blas: &CudaBlas,
1896        d: usize,
1897        k: usize,
1898        n: usize,
1899        y_stack: &CudaSlice<f64>,
1900        u_stack: &CudaSlice<f64>,
1901        rhs: &mut CudaSlice<f64>,
1902    ) -> Result<(), ArrowSchurGpuFailure> {
1903        let y_block_elems = d * k;
1904        let u_block_elems = d;
1905        for i in 0..n {
1906            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1907            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1908            let gemv_cfg = GemvConfig::<f64> {
1909                trans: cublasOperation_t::CUBLAS_OP_T,
1910                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1911                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1912                alpha: 1.0,
1913                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1914                incx: 1,
1915                beta: 1.0,
1916                incy: 1,
1917            };
1918            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1919            // device buffers; `rhs` is the length-k accumulator.
1920            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1921                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1922        }
1923        Ok(())
1924    }
1925
1926    /// Accumulate `g_dev[i] ← u_i + Y_i · δβ` per block. This is the
1927    /// pre-trsm RHS for the back-substitution `L_i^T x_i = w_i`.
1928    fn accumulate_back_sub_rhs(
1929        blas: &CudaBlas,
1930        d: usize,
1931        k: usize,
1932        n: usize,
1933        y_stack: &CudaSlice<f64>,
1934        delta_beta: &CudaSlice<f64>,
1935        u_stack: &mut CudaSlice<f64>,
1936    ) -> Result<(), ArrowSchurGpuFailure> {
1937        let y_block_elems = d * k;
1938        let u_block_elems = d;
1939        for i in 0..n {
1940            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1941            let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1942            let gemv_cfg = GemvConfig::<f64> {
1943                trans: cublasOperation_t::CUBLAS_OP_N,
1944                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1945                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1946                alpha: 1.0,
1947                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1948                incx: 1,
1949                beta: 1.0,
1950                incy: 1,
1951            };
1952            // SAFETY: y_slice / delta_beta / u_slice are live device buffers
1953            // of the expected sizes (d×k, k, d).
1954            unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1955                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1956        }
1957        Ok(())
1958    }
1959
1960    // ────────────────────────────────────────────────────────────────────
1961    // Layer D + E — fused NVRTC dispatch.
1962    //
1963    // The forward kernel (`arrow_schur_forward_pgroup`) is a single launch
1964    // that, per row block, factors `D_i + ρI = L_i L_iᵀ` in shared memory,
1965    // forward-solves `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i`, and emits the
1966    // per-block Schur partials `partial_S[i] = Yᵀ Y` (R×R) and
1967    // `partial_r[i] = Yᵀ u` (R). The host reduces partials on the CPU after
1968    // dtoh (one fused sum across `n` blocks of R²+R doubles; cheap because
1969    // n·R² ≲ 5M doubles at large scale), assembles `S_β`, factors it via
1970    // cuSOLVER, and launches the back-substitution kernel
1971    // `arrow_schur_back_sub_pgroup` to recover `δt_i = -L_i⁻ᵀ(u_i + Y_i δβ)`
1972    // without re-uploading the local factors.
1973    // ────────────────────────────────────────────────────────────────────
1974
1975    use std::collections::HashMap;
1976    use std::sync::Mutex;
1977
1978    /// One compiled NVRTC module per `(cc_major, cc_minor, p_max, r_template)`.
1979    /// `cc_*` lets one process drive multiple device generations; the
1980    /// `(p_max, r_template)` pair selects the shared-memory layout baked into
1981    /// the kernel source.
1982    struct FusedModuleCache {
1983        modules: Mutex<
1984            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
1985        >,
1986    }
1987
1988    fn fused_module_cache() -> &'static FusedModuleCache {
1989        static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1990        CACHE.get_or_init(|| FusedModuleCache {
1991            modules: Mutex::new(HashMap::new()),
1992        })
1993    }
1994
1995    fn fused_module_for(
1996        ctx: &Arc<CudaContext>,
1997        key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
1998    ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
1999        let cache = fused_module_cache();
2000        if let Ok(guard) = cache.modules.lock() {
2001            if let Some(existing) = guard.get(&key) {
2002                return Ok(existing.clone());
2003            }
2004        }
2005        let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2006            key.p_max as usize,
2007            key.r_template as usize,
2008        );
2009        let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
2010            ArrowSchurGpuFailure::SchurFactorFailed {
2011                reason: format!(
2012                    "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2013                    key.p_max, key.r_template
2014                ),
2015            }
2016        })?;
2017        let module = ctx
2018            .load_module(ptx)
2019            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2020        if let Ok(mut guard) = cache.modules.lock() {
2021            guard.entry(key).or_insert_with(|| module.clone());
2022        }
2023        Ok(module)
2024    }
2025
    /// CUDA C source for the device-resident SAE PCG kernels. `pcg_vector_module`
    /// compiles it with NVRTC and caches the module.
    ///
    /// Kernel groups (one thread per output element; `blockIdx.y` selects a block
    /// or row where noted). All index arithmetic is 32-bit `int`, so every
    /// flattened length and offset the host passes must fit in `i32`.
    /// - PCG vector ops: `arrow_pcg_jacobi_mul` (`z = inv_diag ⊙ r`) and
    ///   `arrow_pcg_update_p` (`p = z + beta·p`).
    /// - Penalty matvec:
    ///   - `arrow_sae_init` (`out = ridge·x`).
    ///   - `arrow_sae_smooth_matvec`: per smooth block, adds the dense factor
    ///     times the `p`-interleaved slice with a plain `+=`.
    ///   - `arrow_sae_sparse_g_matvec`: per G block, accumulated with `atomicAdd`
    ///     because several blocks share a row atom.
    /// - Per-row reduced-Schur subtraction (`blockIdx.y` = row):
    ///   - `arrow_sae_gather_u` (`u = Φ·x`).
    ///   - `arrow_sae_apply_l` (`w = J·u`).
    ///   - `arrow_sae_apply_ainv` (`v = ainv·w`).
    ///   - `arrow_sae_scatter_sub` (`out -= Φᵀ·Jᵀ·v`, atomic).
    ///   - `arrow_sae_diag_sub`: the matching Jacobi-diagonal correction.
    /// - Frame-engaged (#1017/#1026) variants with a per-block right width `r_k`
    ///   and a dense per-row `H_tβ`: `arrow_sae_frame_smooth_matvec`,
    ///   `arrow_sae_frame_g_matvec`, `arrow_sae_frame_apply_h`,
    ///   `arrow_sae_frame_apply_ainv`, `arrow_sae_frame_scatter_h`,
    ///   `arrow_sae_frame_diag_sub`.
    ///
    /// `atomicAdd` on `double` needs sm_60+. The in-source comments attribute this
    /// to the NVRTC arch pin (#1551).
    ///
    /// This string is executable device code. Keep kernel names and parameter
    /// order in sync with the Rust `launch_*` wrappers that look them up with
    /// `load_function`.
    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri;       // basis row in atom i
    int a = linear - li * ri;   // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 *   h_i   = H_tβ^(i) · x                (length q_i)
 *   s_i   = (H_tt^(i)+ρ_t I)⁻¹ h_i      (apply cached ainv, length q_i)
 *   out  -= (H_tβ^(i))ᵀ · s_i           (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2458
2459    fn pcg_vector_module(
2460        ctx: &Arc<CudaContext>,
2461    ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2462        static CACHE: gam_gpu::device_cache::PtxModuleCache =
2463            gam_gpu::device_cache::PtxModuleCache::new();
2464        CACHE
2465            .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2466            .map_err(|err| {
2467                // #1551: an NVRTC compile / module-load failure of
2468                // PCG_VECTOR_KERNEL_SOURCE means the device SAE PCG cannot run;
2469                // log it (the historical silent collapse to `Unavailable` is what
2470                // masked the missing `--gpu-architecture` for so long) and fall
2471                // back to the CPU.
2472                log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2473                ArrowSchurGpuFailure::Unavailable
2474            })
2475    }
2476
2477    fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2478        let threads = 256u32;
2479        let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2480        Ok(LaunchConfig {
2481            grid_dim: (blocks, 1, 1),
2482            block_dim: (threads, 1, 1),
2483            shared_mem_bytes: 0,
2484        })
2485    }
2486
2487    fn launch_jacobi_mul(
2488        stream: &Arc<CudaStream>,
2489        module: &Arc<CudaModule>,
2490        inv_diag: &CudaSlice<f64>,
2491        r: &CudaSlice<f64>,
2492        z: &mut CudaSlice<f64>,
2493        n: usize,
2494    ) -> Result<(), ArrowSchurGpuFailure> {
2495        let kernel = module
2496            .load_function("arrow_pcg_jacobi_mul")
2497            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2498        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2499        let mut builder = stream.launch_builder(&kernel);
2500        builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2501        // SAFETY: all buffers have length n and belong to `stream`; the kernel only
2502        // reads/writes indices `< n`.
2503        unsafe { builder.launch(pcg_launch_config(n)?) }
2504            .map(drop)
2505            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2506    }
2507
2508    fn launch_update_p(
2509        stream: &Arc<CudaStream>,
2510        module: &Arc<CudaModule>,
2511        z: &CudaSlice<f64>,
2512        beta: f64,
2513        p: &mut CudaSlice<f64>,
2514        n: usize,
2515    ) -> Result<(), ArrowSchurGpuFailure> {
2516        let kernel = module
2517            .load_function("arrow_pcg_update_p")
2518            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2519        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2520        let mut builder = stream.launch_builder(&kernel);
2521        builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2522        // SAFETY: z/p both have length n and belong to `stream`; the kernel only
2523        // reads/writes indices `< n`.
2524        unsafe { builder.launch(pcg_launch_config(n)?) }
2525            .map(drop)
2526            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2527    }
2528
2529    struct DeviceSaePcgBuffers {
2530        row_ptr: CudaSlice<i32>,
2531        beta_base: CudaSlice<i32>,
2532        phi: CudaSlice<f64>,
2533        jac_ptr: CudaSlice<i32>,
2534        jac: CudaSlice<f64>,
2535        smooth_offsets: CudaSlice<i32>,
2536        smooth_m: CudaSlice<i32>,
2537        smooth_ptr: CudaSlice<i32>,
2538        smooth_data: CudaSlice<f64>,
2539        g_row_off: CudaSlice<i32>,
2540        g_col_off: CudaSlice<i32>,
2541        g_rows: CudaSlice<i32>,
2542        g_cols: CudaSlice<i32>,
2543        g_ptr: CudaSlice<i32>,
2544        g_data: CudaSlice<f64>,
2545        ainv: CudaSlice<f64>,
2546        u: CudaSlice<f64>,
2547        w: CudaSlice<f64>,
2548        v: CudaSlice<f64>,
2549        n_rows: usize,
2550        p: usize,
2551        k: usize,
2552        max_q: usize,
2553        smooth_blocks: usize,
2554        g_blocks: usize,
2555    }
2556
2557    fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2558        to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2559    }
2560
2561    fn sae_penalty_diag_host(
2562        data: &DeviceSaePcgData,
2563        ridge_beta: f64,
2564    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2565        let mut diag = vec![ridge_beta; data.beta_dim];
2566        for block in &data.smooth_blocks {
2567            let (rows, cols) = block.factor_a.dim();
2568            if rows != cols {
2569                return Err(ArrowSchurGpuFailure::Unavailable);
2570            }
2571            for row in 0..rows {
2572                let coeff = block.factor_a[[row, row]];
2573                let base = block
2574                    .global_offset
2575                    .checked_add(
2576                        row.checked_mul(data.p)
2577                            .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2578                    )
2579                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2580                let end = base
2581                    .checked_add(data.p)
2582                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2583                if end > diag.len() {
2584                    return Err(ArrowSchurGpuFailure::Unavailable);
2585                }
2586                for channel in 0..data.p {
2587                    diag[base + channel] += coeff;
2588                }
2589            }
2590        }
2591        for block in &data.sparse_g_blocks {
2592            if block.row_off != block.col_off {
2593                continue;
2594            }
2595            let (rows, cols) = block.data.dim();
2596            for row in 0..rows.min(cols) {
2597                let coeff = block.data[[row, row]];
2598                let beta_row = block
2599                    .row_off
2600                    .checked_add(row)
2601                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2602                let base = beta_row
2603                    .checked_mul(data.p)
2604                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2605                let end = base
2606                    .checked_add(data.p)
2607                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2608                if end > diag.len() {
2609                    return Err(ArrowSchurGpuFailure::Unavailable);
2610                }
2611                for channel in 0..data.p {
2612                    diag[base + channel] += coeff;
2613                }
2614            }
2615        }
2616        Ok(diag)
2617    }
2618
2619    fn flatten_device_sae_data(
2620        sys: &ArrowSchurSystem,
2621        data: &DeviceSaePcgData,
2622        ridge_t: f64,
2623        stream: &Arc<CudaStream>,
2624    ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2625        let n_rows = sys.rows.len();
2626        let p = data.p;
2627        let k = data.beta_dim;
2628        if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2629            return Err(ArrowSchurGpuFailure::Unavailable);
2630        }
2631
2632        let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2633        let mut beta_base_host = Vec::<i32>::new();
2634        let mut phi_host = Vec::<f64>::new();
2635        row_ptr_host.push(0_i32);
2636        for row in data.a_phi.iter() {
2637            for &(base, phi) in row {
2638                beta_base_host.push(checked_i32(base)?);
2639                phi_host.push(phi);
2640            }
2641            row_ptr_host.push(checked_i32(beta_base_host.len())?);
2642        }
2643
2644        let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2645        let mut jac_host = Vec::<f64>::new();
2646        let mut max_q = 0usize;
2647        jac_ptr_host.push(0_i32);
2648        for row_jac in data.local_jac.iter() {
2649            if row_jac.len() % p != 0 {
2650                return Err(ArrowSchurGpuFailure::Unavailable);
2651            }
2652            max_q = max_q.max(row_jac.len() / p);
2653            jac_host.extend_from_slice(row_jac);
2654            jac_ptr_host.push(checked_i32(jac_host.len())?);
2655        }
2656        if max_q == 0 {
2657            return Err(ArrowSchurGpuFailure::Unavailable);
2658        }
2659
2660        let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2661        let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2662        let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2663        let mut smooth_data_host = Vec::<f64>::new();
2664        smooth_ptr_host.push(0_i32);
2665        for block in &data.smooth_blocks {
2666            let (rows, cols) = block.factor_a.dim();
2667            if rows != cols {
2668                return Err(ArrowSchurGpuFailure::Unavailable);
2669            }
2670            smooth_offsets_host.push(checked_i32(block.global_offset)?);
2671            smooth_m_host.push(checked_i32(rows)?);
2672            for r in 0..rows {
2673                for c in 0..cols {
2674                    smooth_data_host.push(block.factor_a[[r, c]]);
2675                }
2676            }
2677            smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2678        }
2679
2680        let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2681        let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2682        let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2683        let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2684        let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2685        let mut g_data_host = Vec::<f64>::new();
2686        g_ptr_host.push(0_i32);
2687        for block in &data.sparse_g_blocks {
2688            let (rows, cols) = block.data.dim();
2689            g_row_off_host.push(checked_i32(block.row_off)?);
2690            g_col_off_host.push(checked_i32(block.col_off)?);
2691            g_rows_host.push(checked_i32(rows)?);
2692            g_cols_host.push(checked_i32(cols)?);
2693            for r in 0..rows {
2694                for c in 0..cols {
2695                    g_data_host.push(block.data[[r, c]]);
2696                }
2697            }
2698            g_ptr_host.push(checked_i32(g_data_host.len())?);
2699        }
2700
2701        let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2702        for (row_idx, row) in sys.rows.iter().enumerate() {
2703            let q = data.local_jac[row_idx].len() / p;
2704            if row.htt.dim() != (q, q) {
2705                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2706                    reason: format!(
2707                        "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2708                        row.htt.dim()
2709                    ),
2710                });
2711            }
2712            let mut block = row.htt.clone();
2713            for d in 0..q {
2714                block[[d, d]] += ridge_t;
2715            }
2716            let factor = gam_linalg::triangular::cholesky_factor_in_place(
2717                block.view(),
2718                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2719            )
2720            .ok_or_else(|| {
2721                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
2722                // indefinite per-row block recovers in one outer-loop retry.
2723                ArrowSchurGpuFailure::RidgeBumpRequired {
2724                    row: row_idx,
2725                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2726                }
2727            })?;
2728            for col in 0..q {
2729                let mut e = Array1::<f64>::zeros(q);
2730                e[col] = 1.0;
2731                let solved =
2732                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2733                for r in 0..q {
2734                    ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2735                }
2736            }
2737        }
2738
2739        Ok(DeviceSaePcgBuffers {
2740            row_ptr: stream
2741                .clone_htod(&row_ptr_host)
2742                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2743            beta_base: stream
2744                .clone_htod(&beta_base_host)
2745                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2746            phi: stream
2747                .clone_htod(&phi_host)
2748                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2749            jac_ptr: stream
2750                .clone_htod(&jac_ptr_host)
2751                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2752            jac: stream
2753                .clone_htod(&jac_host)
2754                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2755            smooth_offsets: stream
2756                .clone_htod(&smooth_offsets_host)
2757                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2758            smooth_m: stream
2759                .clone_htod(&smooth_m_host)
2760                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2761            smooth_ptr: stream
2762                .clone_htod(&smooth_ptr_host)
2763                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2764            smooth_data: stream
2765                .clone_htod(&smooth_data_host)
2766                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2767            g_row_off: stream
2768                .clone_htod(&g_row_off_host)
2769                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2770            g_col_off: stream
2771                .clone_htod(&g_col_off_host)
2772                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2773            g_rows: stream
2774                .clone_htod(&g_rows_host)
2775                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2776            g_cols: stream
2777                .clone_htod(&g_cols_host)
2778                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2779            g_ptr: stream
2780                .clone_htod(&g_ptr_host)
2781                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2782            g_data: stream
2783                .clone_htod(&g_data_host)
2784                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2785            ainv: stream
2786                .clone_htod(&ainv_host)
2787                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2788            u: stream
2789                .alloc_zeros::<f64>(n_rows * p)
2790                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2791            w: stream
2792                .alloc_zeros::<f64>(n_rows * max_q)
2793                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2794            v: stream
2795                .alloc_zeros::<f64>(n_rows * max_q)
2796                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2797            n_rows,
2798            p,
2799            k,
2800            max_q,
2801            smooth_blocks: data.smooth_blocks.len(),
2802            g_blocks: data.sparse_g_blocks.len(),
2803        })
2804    }
2805
2806    fn launch_sae_init(
2807        stream: &Arc<CudaStream>,
2808        module: &Arc<CudaModule>,
2809        out: &mut CudaSlice<f64>,
2810        x: &CudaSlice<f64>,
2811        ridge: f64,
2812        n: usize,
2813    ) -> Result<(), ArrowSchurGpuFailure> {
2814        let kernel = module
2815            .load_function("arrow_sae_init")
2816            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2817        let n_i32 = checked_i32(n)?;
2818        let mut builder = stream.launch_builder(&kernel);
2819        builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2820        // SAFETY: `out` and `x` are live device buffers with at least `n`
2821        // entries on `stream`; the kernel writes one in-bounds element per
2822        // launched index below `n`.
2823        unsafe { builder.launch(pcg_launch_config(n)?) }
2824            .map(drop)
2825            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2826    }
2827
2828    fn launch_sae_penalty_matvec(
2829        stream: &Arc<CudaStream>,
2830        module: &Arc<CudaModule>,
2831        buffers: &mut DeviceSaePcgBuffers,
2832        x: &CudaSlice<f64>,
2833        out: &mut CudaSlice<f64>,
2834        ridge_beta: f64,
2835    ) -> Result<(), ArrowSchurGpuFailure> {
2836        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2837        if buffers.smooth_blocks > 0 {
2838            let kernel = module
2839                .load_function("arrow_sae_smooth_matvec")
2840                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2841            let max_m = buffers.k;
2842            let p_i32 = checked_i32(buffers.p)?;
2843            let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2844            let cfg = LaunchConfig {
2845                grid_dim: (
2846                    ((max_m as u32).saturating_add(255) / 256).max(1),
2847                    checked_i32(buffers.smooth_blocks)? as u32,
2848                    1,
2849                ),
2850                block_dim: (256, 1, 1),
2851                shared_mem_bytes: 0,
2852            };
2853            let mut builder = stream.launch_builder(&kernel);
2854            builder
2855                .arg(x)
2856                .arg(&mut *out)
2857                .arg(&buffers.smooth_offsets)
2858                .arg(&buffers.smooth_m)
2859                .arg(&buffers.smooth_ptr)
2860                .arg(&buffers.smooth_data)
2861                .arg(&p_i32)
2862                .arg(&blocks_i32);
2863            // SAFETY: smooth block metadata and dense smooth data were flattened
2864            // into live device buffers; the 2D grid covers only declared block
2865            // and coefficient-channel work items, and the kernel bounds-checks
2866            // against each block's stored size.
2867            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2868        }
2869        if buffers.g_blocks > 0 {
2870            let kernel = module
2871                .load_function("arrow_sae_sparse_g_matvec")
2872                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2873            let max_work = buffers
2874                .k
2875                .checked_div(buffers.p)
2876                .unwrap_or(0)
2877                .saturating_mul(buffers.p);
2878            let p_i32 = checked_i32(buffers.p)?;
2879            let blocks_i32 = checked_i32(buffers.g_blocks)?;
2880            let cfg = LaunchConfig {
2881                grid_dim: (
2882                    ((max_work as u32).saturating_add(255) / 256).max(1),
2883                    checked_i32(buffers.g_blocks)? as u32,
2884                    1,
2885                ),
2886                block_dim: (256, 1, 1),
2887                shared_mem_bytes: 0,
2888            };
2889            let mut builder = stream.launch_builder(&kernel);
2890            builder
2891                .arg(x)
2892                .arg(&mut *out)
2893                .arg(&buffers.g_row_off)
2894                .arg(&buffers.g_col_off)
2895                .arg(&buffers.g_rows)
2896                .arg(&buffers.g_cols)
2897                .arg(&buffers.g_ptr)
2898                .arg(&buffers.g_data)
2899                .arg(&p_i32)
2900                .arg(&blocks_i32);
2901            // SAFETY: sparse G block metadata/data are live device buffers built
2902            // from host CSR-like block descriptors; the launch dimensions cover
2903            // declared block work only and the kernel checks row/column bounds
2904            // before reading or accumulating.
2905            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2906        }
2907        Ok(())
2908    }
2909
2910    fn launch_sae_row_schur_sub(
2911        stream: &Arc<CudaStream>,
2912        module: &Arc<CudaModule>,
2913        buffers: &mut DeviceSaePcgBuffers,
2914        x: &CudaSlice<f64>,
2915        out: &mut CudaSlice<f64>,
2916    ) -> Result<(), ArrowSchurGpuFailure> {
2917        let p_i32 = checked_i32(buffers.p)?;
2918        let max_q_i32 = checked_i32(buffers.max_q)?;
2919        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2920        let cfg_p_rows = LaunchConfig {
2921            grid_dim: (
2922                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2923                checked_i32(buffers.n_rows)? as u32,
2924                1,
2925            ),
2926            block_dim: (256, 1, 1),
2927            shared_mem_bytes: 0,
2928        };
2929        let gather = module
2930            .load_function("arrow_sae_gather_u")
2931            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2932        {
2933            let mut builder = stream.launch_builder(&gather);
2934            builder
2935                .arg(x)
2936                .arg(&buffers.row_ptr)
2937                .arg(&buffers.beta_base)
2938                .arg(&buffers.phi)
2939                .arg(&mut buffers.u)
2940                .arg(&p_i32)
2941                .arg(&n_rows_i32);
2942            // SAFETY: `x`, row pointers, beta offsets, basis rows, and `u` are
2943            // live device buffers sized for `n_rows` by `p`; the kernel guards
2944            // row/channel indices before gathering.
2945            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2946        }
2947
2948        let cfg_q_rows = LaunchConfig {
2949            grid_dim: (
2950                ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2951                checked_i32(buffers.n_rows)? as u32,
2952                1,
2953            ),
2954            block_dim: (256, 1, 1),
2955            shared_mem_bytes: 0,
2956        };
2957        let apply_l = module
2958            .load_function("arrow_sae_apply_l")
2959            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2960        {
2961            let mut builder = stream.launch_builder(&apply_l);
2962            builder
2963                .arg(&buffers.u)
2964                .arg(&buffers.jac_ptr)
2965                .arg(&buffers.jac)
2966                .arg(&mut buffers.w)
2967                .arg(&p_i32)
2968                .arg(&max_q_i32)
2969                .arg(&n_rows_i32);
2970            // SAFETY: `u`, Jacobian row pointers/data, and `w` are live buffers
2971            // sized for the `(n_rows, p)` to `(n_rows, max_q)` multiply; the
2972            // kernel checks row and local-coordinate bounds.
2973            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2974        }
2975
2976        let apply_ainv = module
2977            .load_function("arrow_sae_apply_ainv")
2978            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2979        {
2980            let mut builder = stream.launch_builder(&apply_ainv);
2981            builder
2982                .arg(&buffers.ainv)
2983                .arg(&buffers.w)
2984                .arg(&mut buffers.v)
2985                .arg(&max_q_i32)
2986                .arg(&n_rows_i32);
2987            // SAFETY: `ainv`, `w`, and `v` are live device buffers sized for
2988            // `n_rows * max_q`; the kernel guards all row/local-coordinate
2989            // indices before reading or writing.
2990            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2991        }
2992
2993        let scatter = module
2994            .load_function("arrow_sae_scatter_sub")
2995            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2996        {
2997            let mut builder = stream.launch_builder(&scatter);
2998            builder
2999                .arg(&buffers.v)
3000                .arg(&buffers.jac_ptr)
3001                .arg(&buffers.jac)
3002                .arg(&buffers.row_ptr)
3003                .arg(&buffers.beta_base)
3004                .arg(&buffers.phi)
3005                .arg(out)
3006                .arg(&p_i32)
3007                .arg(&max_q_i32)
3008                .arg(&n_rows_i32);
3009            // SAFETY: `v`, Jacobian metadata, row pointers, beta offsets, basis
3010            // rows, and `out` are live buffers for `n_rows` by `p`; scatter
3011            // indices are checked against row and channel bounds in the kernel.
3012            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3013        }
3014        Ok(())
3015    }
3016
3017    fn launch_sae_diag_sub(
3018        stream: &Arc<CudaStream>,
3019        module: &Arc<CudaModule>,
3020        buffers: &DeviceSaePcgBuffers,
3021        diag: &mut CudaSlice<f64>,
3022    ) -> Result<(), ArrowSchurGpuFailure> {
3023        let kernel = module
3024            .load_function("arrow_sae_diag_sub")
3025            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3026        let p_i32 = checked_i32(buffers.p)?;
3027        let max_q_i32 = checked_i32(buffers.max_q)?;
3028        let n_rows_i32 = checked_i32(buffers.n_rows)?;
3029        let cfg = LaunchConfig {
3030            grid_dim: (
3031                ((buffers.p as u32).saturating_add(255) / 256).max(1),
3032                checked_i32(buffers.n_rows)? as u32,
3033                1,
3034            ),
3035            block_dim: (256, 1, 1),
3036            shared_mem_bytes: 0,
3037        };
3038        let mut builder = stream.launch_builder(&kernel);
3039        builder
3040            .arg(diag)
3041            .arg(&buffers.ainv)
3042            .arg(&buffers.jac_ptr)
3043            .arg(&buffers.jac)
3044            .arg(&buffers.row_ptr)
3045            .arg(&buffers.beta_base)
3046            .arg(&buffers.phi)
3047            .arg(&p_i32)
3048            .arg(&max_q_i32)
3049            .arg(&n_rows_i32);
3050        // SAFETY: diagonal output and all read-only SAE row metadata buffers are
3051        // live on `stream` with sizes matching `n_rows`, `p`, and `max_q`; the
3052        // kernel bounds-checks its flattened work index.
3053        unsafe { builder.launch(cfg) }
3054            .map(drop)
3055            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3056    }
3057
3058    fn launch_sae_matvec(
3059        stream: &Arc<CudaStream>,
3060        module: &Arc<CudaModule>,
3061        buffers: &mut DeviceSaePcgBuffers,
3062        x: &CudaSlice<f64>,
3063        out: &mut CudaSlice<f64>,
3064        ridge_beta: f64,
3065    ) -> Result<(), ArrowSchurGpuFailure> {
3066        launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3067        launch_sae_row_schur_sub(stream, module, buffers, x, out)
3068    }
3069
3070    /// Pack `D + ρ_t I`, `B`, and `g` into the strided `(n × P_MAX × P_MAX)`
3071    /// / `(n × P_MAX × R_TEMPLATE)` / `(n × P_MAX)` layout the fused kernel
3072    /// expects. Entries outside the runtime `(p, r)` window stay at zero so
3073    /// the kernel's per-element loops are safe to no-op there.
3074    fn pack_fused_host(
3075        sys: &ArrowSchurSystem,
3076        ridge_t: f64,
3077        p_max: usize,
3078        r_template: usize,
3079    ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3080        let n = sys.rows.len();
3081        let d = sys.d;
3082        let k = sys.k;
3083        let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3084        let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3085        let mut g_buf = vec![0.0_f64; n * p_max];
3086        for (i, row) in sys.rows.iter().enumerate() {
3087            // D_i + ρI, column-major in P_MAX×P_MAX strided block.
3088            for col in 0..d {
3089                let base = (i * p_max + col) * p_max;
3090                for r in 0..d {
3091                    let mut value = row.htt[[r, col]];
3092                    if r == col {
3093                        value += ridge_t;
3094                    }
3095                    d_buf[base + r] = value;
3096                }
3097            }
3098            // B_i in P_MAX×R_TEMPLATE strided block. The per-row (per-i) block
3099            // stride is `p_max · r_template` (matching the `b_buf` allocation
3100            // above and the kernel's `b_stack`/`y_out` layout), NOT
3101            // `p_max · p_max`: using the D-block multiplier here overflows the
3102            // buffer whenever `p_max > r_template` (e.g. d=30→p_max=32,
3103            // k=5→r_template=5). The within-block element offset stays
3104            // column-major `col·p_max + r` (P_MAX rows per column).
3105            for col in 0..k {
3106                let base = (i * r_template + col) * p_max;
3107                for r in 0..d {
3108                    b_buf[base + r] = row.htbeta[[r, col]];
3109                }
3110            }
3111            // g_i in P_MAX strided vector.
3112            let g_base = i * p_max;
3113            for r in 0..d {
3114                g_buf[g_base + r] = row.gt[r];
3115            }
3116        }
3117        (d_buf, b_buf, g_buf)
3118    }
3119
3120    // -----------------------------------------------------------------------
3121    // #1017 Phase 3: across-iteration device residency.
3122    //
3123    // `solve()` re-packs and re-uploads `D` (`H_tt`), `B` (`H_tβ`) and `g`,
3124    // then re-runs the per-row POTRF and the border Schur factorization on
3125    // EVERY call. For the SAE joint inner Newton at a frozen gate/basis frame
3126    // the Hessian blocks `D`, `B`, `H_ββ` are CONSTANT across the inner loop —
3127    // only the gradient `g = r(z) = H z − g₀` changes per iterate. So the
3128    // factor work (`O(n·d³ + p³)`) and the dominant `O(n·d·p)` cross-block
3129    // upload are pure waste when repeated per iterate.
3130    //
3131    // `ResidentArrowFrame` performs that constant work ONCE at construction:
3132    // upload+ridge+POTRF of `D` (keeping `L_i` resident in `l_dev`), the
3133    // forward solve `Y_i = L_i^{-1} B_i` (kept resident in `y_dev`), and the
3134    // Schur assembly + border POTRF (keeping `L_S` resident in `schur_dev`).
3135    // Each subsequent `solve_gradient(g)` uploads only the `n·d` row gradient,
3136    // runs the cheap residual path — `u_i = L_i^{-1} g_i` (one batched TRSM),
3137    // Schur RHS `−g_β + Σ Y_iᵀ u_i`, `δβ = L_S^{-T} L_S^{-1} rhs` (two TRSM,
3138    // NO POTRF), back-sub `δt_i = −L_i^{-T}(u_i + Y_i δβ)` — and reads back only
3139    // `δ` and the cached log|H|. The heavy buffers never leave the device
3140    // across iterations; the per-iterate host transfer is `O(n·d + p)`, not
3141    // `O(n·d·p)`. Numerics are bit-identical to a `solve()` at the same
3142    // `(D, B, H_ββ, g, ridge_t, ridge_beta)` because the factor buffers and the
3143    // helper kernels are the same; the resident path merely SKIPS re-deriving
3144    // the parts that do not depend on `g`. The CPU dense reference
3145    // (`solve_arrow_newton_step_dense_reference`) is the parity oracle.
3146    pub(super) struct ResidentArrowFrame {
3147        n: usize,
3148        d: usize,
3149        k: usize,
3150        stream: Arc<CudaStream>,
3151        blas: CudaBlas,
3152        /// Per-row lower Cholesky factors `L_i` of `H_tt + ρ_t I`, stacked
3153        /// column-major (`n` tiles of `d×d`). Resident across iterations.
3154        l_dev: CudaSlice<f64>,
3155        /// Whitened cross blocks `Y_i = L_i^{-1} H_tβ^(i)`, stacked column-major
3156        /// (`n` tiles of `d×k`). Resident across iterations.
3157        y_dev: CudaSlice<f64>,
3158        /// Lower Cholesky factor `L_S` of the reduced Schur complement
3159        /// `S_β = H_ββ + ρ_β I − Σ_i Y_iᵀ Y_i`. Resident across iterations.
3160        schur_dev: CudaSlice<f64>,
3161        /// `log|H| = 2 Σ log L_{i,jj} + 2 Σ log L_{S,aa}`, constant for the
3162        /// frame (depends only on the factored Hessian, not on `g`).
3163        log_det_hessian: f64,
3164    }
3165
3166    impl ResidentArrowFrame {
3167        /// Upload the constant Hessian blocks and perform the one-time factor
3168        /// work (`POTRF(D)`, `Y_i = L_i^{-1} B_i`, Schur assembly + border
3169        /// `POTRF`). The frame then serves cheap per-gradient solves.
3170        pub(super) fn new(
3171            sys: &ArrowSchurSystem,
3172            ridge_t: f64,
3173            ridge_beta: f64,
3174        ) -> Result<Self, ArrowSchurGpuFailure> {
3175            if ridge_t.is_nan() || ridge_beta.is_nan() {
3176                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3177                    reason: "ridge is NaN".to_string(),
3178                });
3179            }
3180            let n = sys.rows.len();
3181            let d = sys.d;
3182            let k = sys.k;
3183            let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3184                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3185            let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3186                .and_then(|ctx| ctx.new_stream().ok())
3187                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3188            let solver =
3189                DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3190            let blas =
3191                CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3192
3193            // Upload the constant blocks. `g` is uploaded per-gradient, not here.
3194            let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3195            let mut l_dev = stream
3196                .clone_htod(&d_host)
3197                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3198            let mut y_dev = stream
3199                .clone_htod(&b_host)
3200                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3201
3202            // POTRF(D) → L_i, kept resident in l_dev.
3203            let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3204            if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3205                // cuSOLVER `info` is a 1-based pivot row index; size the bump
3206                // from the block (Gershgorin λ_min bound) so a strongly
3207                // indefinite block recovers in one retry.
3208                return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3209                    row: idx,
3210                    bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3211                });
3212            }
3213
3214            // Y_i = L_i^{-1} B_i, in place over y_dev. Kept resident.
3215            trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3216
3217            // Schur assembly S_β = (H_ββ + ρ_β I) − Σ Y_iᵀ Y_i, then POTRF → L_S.
3218            // The RHS accumulation is folded into the gradient path; here we
3219            // only need the (gradient-independent) Schur factor, so accumulate
3220            // into a throwaway rhs buffer.
3221            let schur_init: Vec<f64> = {
3222                let mut tmp = Vec::with_capacity(k * k);
3223                for col in 0..k {
3224                    for row in 0..k {
3225                        let mut v = sys.hbb[[row, col]];
3226                        if row == col {
3227                            v += ridge_beta;
3228                        }
3229                        tmp.push(v);
3230                    }
3231                }
3232                tmp
3233            };
3234            let mut schur_dev = stream
3235                .clone_htod(&schur_init)
3236                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3237            // A zero u-stack makes `Σ Y_iᵀ u_i = 0`, so only the `−Σ Y_iᵀ Y_i`
3238            // Schur term is accumulated (the rhs is rebuilt per gradient).
3239            let zero_u = stream
3240                .clone_htod(&vec![0.0_f64; n * d])
3241                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3242            let mut throwaway_rhs = stream
3243                .clone_htod(&vec![0.0_f64; k])
3244                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3245            accumulate_schur(
3246                &blas,
3247                d,
3248                k,
3249                n,
3250                &y_dev,
3251                &zero_u,
3252                &mut schur_dev,
3253                &mut throwaway_rhs,
3254            )?;
3255            let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3256            if info != 0 {
3257                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3258                    reason: format!("Schur Cholesky failed at pivot {info}"),
3259                });
3260            }
3261
3262            // log|H| from the resident factors (constant for the frame).
3263            let l_local_host = stream
3264                .clone_dtoh(&l_dev)
3265                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3266            let l_schur_host = stream
3267                .clone_dtoh(&schur_dev)
3268                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3269            let mut log_det = 0.0_f64;
3270            for i in 0..n {
3271                let base = i * d * d;
3272                for j in 0..d {
3273                    log_det += l_local_host[base + j * d + j].ln();
3274                }
3275            }
3276            for j in 0..k {
3277                log_det += l_schur_host[j * k + j].ln();
3278            }
3279            log_det *= 2.0;
3280
3281            Ok(Self {
3282                n,
3283                d,
3284                k,
3285                stream,
3286                blas,
3287                l_dev,
3288                y_dev,
3289                schur_dev,
3290                log_det_hessian: log_det,
3291            })
3292        }
3293
3294        #[inline]
3295        pub(super) fn log_det_hessian(&self) -> f64 {
3296            self.log_det_hessian
3297        }
3298
3299        /// Solve `H δ = −gradient` for a fresh gradient `(g_t, g_β)` reusing the
3300        /// resident factors. Uploads only `g_t` (`n·d` scalars); reads back only
3301        /// `δ`. No POTRF runs here — all factorization is amortized into `new`.
3302        pub(super) fn solve_gradient(
3303            &self,
3304            g_t: &[f64],
3305            g_beta: &[f64],
3306        ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3307            let n = self.n;
3308            let d = self.d;
3309            let k = self.k;
3310            if g_t.len() != n * d || g_beta.len() != k {
3311                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3312                    reason: format!(
3313                        "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3314                        g_t.len(),
3315                        n * d,
3316                        g_beta.len(),
3317                        k
3318                    ),
3319                });
3320            }
3321            // Upload the per-iterate row gradient → u_i = L_i^{-1} g_i in place.
3322            let mut u_dev = self
3323                .stream
3324                .clone_htod(&g_t.to_vec())
3325                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3326            trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3327
3328            // Schur RHS = −g_β + Σ_i Y_iᵀ u_i. Reuse the resident Schur factor
3329            // (no POTRF, and skip the −Σ Y_iᵀ Y_i GEMM already baked into L_S).
3330            let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3331            let mut rhs_dev = self
3332                .stream
3333                .clone_htod(&rhs_init)
3334                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3335            accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3336
3337            // δβ ← L_S^{-T} L_S^{-1} rhs using the resident border factor.
3338            trsm_single(
3339                &self.blas,
3340                &self.stream,
3341                k,
3342                &self.schur_dev,
3343                &mut rhs_dev,
3344                false,
3345                false,
3346            )?;
3347            trsm_single(
3348                &self.blas,
3349                &self.stream,
3350                k,
3351                &self.schur_dev,
3352                &mut rhs_dev,
3353                false,
3354                true,
3355            )?;
3356            let delta_beta_host = self
3357                .stream
3358                .clone_dtoh(&rhs_dev)
3359                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3360            let delta_beta = Array1::from_vec(delta_beta_host);
3361
3362            // Back-sub δt_i = −L_i^{-T}(u_i + Y_i δβ).
3363            accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3364            trsm_batched_lower_inplace_transposed(
3365                &self.blas,
3366                &self.stream,
3367                d,
3368                n,
3369                1,
3370                &self.l_dev,
3371                &mut u_dev,
3372            )?;
3373            let x_host = self
3374                .stream
3375                .clone_dtoh(&u_dev)
3376                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3377            let mut delta_t = Array1::<f64>::zeros(n * d);
3378            for (i, v) in x_host.iter().enumerate() {
3379                delta_t[i] = -*v;
3380            }
3381
3382            Ok(ArrowSchurGpuSolution {
3383                delta_t,
3384                delta_beta,
3385                log_det_hessian: self.log_det_hessian,
3386            })
3387        }
3388    }
3389
3390    pub(super) fn solve_fused(
3391        sys: &ArrowSchurSystem,
3392        ridge_t: f64,
3393        ridge_beta: f64,
3394    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3395        let n = sys.rows.len();
3396        let d = sys.d;
3397        let k = sys.k;
3398        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3399            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3400        let p_max = plan.p_max;
3401        let r_template = plan.r_template;
3402
3403        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3404            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3405        )
3406        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3407        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3408            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3409        let stream = ctx
3410            .new_stream()
3411            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3412        let cap = &runtime.device.capability;
3413        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3414            cc_major: cap.compute_major,
3415            cc_minor: cap.compute_minor,
3416            p_max: p_max as u32,
3417            r_template: r_template as u32,
3418        };
3419        let module = fused_module_for(&ctx, key)?;
3420        let forward = module
3421            .load_function("arrow_schur_forward_pgroup")
3422            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3423        let back_sub = module
3424            .load_function("arrow_schur_back_sub_pgroup")
3425            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3426
3427        // ----- Upload packed D, B, g -----
3428        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3429        let d_dev = stream
3430            .clone_htod(&d_host)
3431            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3432        let b_dev = stream
3433            .clone_htod(&b_host)
3434            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3435        let g_dev = stream
3436            .clone_htod(&g_host)
3437            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3438        let mut l_out = stream
3439            .alloc_zeros::<f64>(n * p_max * p_max)
3440            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3441        let mut u_out = stream
3442            .alloc_zeros::<f64>(n * p_max)
3443            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3444        let mut y_out = stream
3445            .alloc_zeros::<f64>(n * p_max * r_template)
3446            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3447        let mut partial_s = stream
3448            .alloc_zeros::<f64>(plan.partial_s_doubles)
3449            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3450        let mut partial_r = stream
3451            .alloc_zeros::<f64>(plan.partial_r_doubles)
3452            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3453        let mut status_dev = stream
3454            .alloc_zeros::<i32>(n)
3455            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3456
3457        // ----- Launch forward kernel: 1 block per row, P_MAX threads -----
3458        let cfg = LaunchConfig {
3459            grid_dim: (plan.blocks, 1, 1),
3460            block_dim: (plan.threads_per_block, 1, 1),
3461            shared_mem_bytes: 0,
3462        };
3463        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3464        let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3465        let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3466        let ridge_arg = ridge_t;
3467        {
3468            let mut builder = stream.launch_builder(&forward);
3469            builder
3470                .arg(&d_dev)
3471                .arg(&b_dev)
3472                .arg(&g_dev)
3473                .arg(&n_i32)
3474                .arg(&p_i32)
3475                .arg(&r_i32)
3476                .arg(&ridge_arg)
3477                .arg(&mut l_out)
3478                .arg(&mut u_out)
3479                .arg(&mut y_out)
3480                .arg(&mut partial_s)
3481                .arg(&mut partial_r)
3482                .arg(&mut status_dev);
3483            // SAFETY: all buffers were just allocated on `stream` with sizes
3484            // derived from `plan`; kernel parameter list matches the
3485            // FORWARD_KERNEL_SOURCE signature.
3486            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487        }
3488        stream
3489            .synchronize()
3490            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3491
3492        // ----- Check per-block pivot status -----
3493        let status_host = stream
3494            .clone_dtoh(&status_dev)
3495            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3496        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3497            // The NVRTC kernel's status code is a 1-based pivot row index, not
3498            // a magnitude; size the bump from the block (Gershgorin λ_min
3499            // bound) so a strongly indefinite block recovers in one retry.
3500            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3501                row,
3502                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3503            });
3504        }
3505
3506        // ----- Reduce partials on host into S_β and r_β -----
3507        let partial_s_host = stream
3508            .clone_dtoh(&partial_s)
3509            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3510        let partial_r_host = stream
3511            .clone_dtoh(&partial_r)
3512            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3513        let mut schur_host = vec![0.0_f64; k * k];
3514        for col in 0..k {
3515            for row in 0..k {
3516                let mut v = sys.hbb[[row, col]];
3517                if row == col {
3518                    v += ridge_beta;
3519                }
3520                schur_host[col * k + row] = v;
3521            }
3522        }
3523        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3524        for i in 0..n {
3525            // partial_S[i] stride is R_TEMPLATE × R_TEMPLATE column-major; we
3526            // only read the leading (k × k) sub-block.
3527            let s_base = i * r_template * r_template;
3528            for col in 0..k {
3529                let col_base = s_base + col * r_template;
3530                let dst_col_base = col * k;
3531                for row in 0..k {
3532                    schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3533                }
3534            }
3535            let r_base = i * r_template;
3536            for a in 0..k {
3537                rhs_host[a] += partial_r_host[r_base + a];
3538            }
3539        }
3540
3541        // ----- Factor S_β on device (cuSOLVER), solve for δβ -----
3542        let mut schur_dev = stream
3543            .clone_htod(&schur_host)
3544            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3545        let mut rhs_dev = stream
3546            .clone_htod(&rhs_host)
3547            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3548        let solver =
3549            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3550        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3551        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3552        if info != 0 {
3553            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3554                reason: format!("fused Schur Cholesky failed at pivot {info}"),
3555            });
3556        }
3557        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3558        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3559        let delta_beta_host = stream
3560            .clone_dtoh(&rhs_dev)
3561            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3562        let delta_beta = Array1::from_vec(delta_beta_host.clone());
3563
3564        // ----- Layer E: launch back-sub kernel using persisted L, u, Y -----
3565        let mut delta_t_dev = stream
3566            .alloc_zeros::<f64>(n * p_max)
3567            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3568        let back_cfg = LaunchConfig {
3569            grid_dim: (plan.blocks, 1, 1),
3570            block_dim: (plan.threads_per_block, 1, 1),
3571            shared_mem_bytes: 0,
3572        };
3573        {
3574            let mut builder = stream.launch_builder(&back_sub);
3575            builder
3576                .arg(&l_out)
3577                .arg(&u_out)
3578                .arg(&y_out)
3579                .arg(&rhs_dev)
3580                .arg(&n_i32)
3581                .arg(&p_i32)
3582                .arg(&r_i32)
3583                .arg(&mut delta_t_dev);
3584            // SAFETY: kernel parameter list matches FORWARD_KERNEL_SOURCE
3585            // back-sub signature; `rhs_dev` holds δβ in the leading k entries
3586            // (R_TEMPLATE-strided indexing is column 0..k of the R-vector).
3587            unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3588        }
3589        stream
3590            .synchronize()
3591            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3592
3593        let delta_t_host = stream
3594            .clone_dtoh(&delta_t_dev)
3595            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3596        let mut delta_t = Array1::<f64>::zeros(n * d);
3597        for i in 0..n {
3598            let src_base = i * p_max;
3599            let dst_base = i * d;
3600            for r in 0..d {
3601                delta_t[dst_base + r] = delta_t_host[src_base + r];
3602            }
3603        }
3604
3605        // ----- log|H| = 2·Σ log L_{i,jj} + 2·Σ log R_{β,aa} -----
3606        let l_local_host = stream
3607            .clone_dtoh(&l_out)
3608            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3609        let l_schur_host = stream
3610            .clone_dtoh(&schur_dev)
3611            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3612        let mut log_det = 0.0_f64;
3613        for i in 0..n {
3614            let base = i * p_max * p_max;
3615            for j in 0..d {
3616                log_det += l_local_host[base + j * p_max + j].ln();
3617            }
3618        }
3619        for j in 0..k {
3620            log_det += l_schur_host[j * k + j].ln();
3621        }
3622        log_det *= 2.0;
3623
3624        Ok(ArrowSchurGpuSolution {
3625            delta_t,
3626            delta_beta,
3627            log_det_hessian: log_det,
3628        })
3629    }
3630
    /// Pre-compute `Y_i = L_i^{-1} H_tβ^(i)` via the fused forward kernel and
    /// return a closure that evaluates the full Schur matvec
    /// `S·x = (H_ββ + ρ·I)·x − Σ_i Y_i^T (Y_i·x)` for each PCG iteration.
    ///
    /// The `Y_i` factors are kept in a host-side buffer after one GPU forward
    /// pass. Each matvec call runs O(N·d·K) host loops over the pre-computed
    /// buffer plus an optional `H_ββ·x` call (matrix-free or dense). This is
    /// the first landing of the GPU matvec; a future iteration can move the
    /// `Y_i·x` / `Y_i^T z_i` steps to cuBLAS batched GEMV.
    ///
    /// `ridge_t` is forwarded to `pack_fused_host` and to the forward kernel.
    /// `ridge_beta` is the `ρ` added to the `H_ββ` diagonal inside the closure.
    /// Only `y_out` and the per-block status are downloaded. The forward
    /// kernel's `l_out`, `u_out`, `partial_s` and `partial_r` outputs are not
    /// used by this function.
    ///
    /// `H_ββ·x` is chosen in this order:
    /// 1. `sys.hbb_matvec`, if present.
    /// 2. Otherwise the dense `sys.hbb`, if it is exactly `k × k`.
    /// 3. Otherwise `H_ββ` is silently treated as zero, so only `ρ·x` remains.
    ///
    /// # Errors
    /// * `Unavailable` — no fused launch plan, no GPU route/context/stream,
    ///   module/function load failure, index conversion failure (`to_i32`),
    ///   or any device allocation/transfer/launch/synchronize failure.
    /// * `RidgeBumpRequired` — the first block whose forward status is
    ///   non-zero. It carries a Gershgorin-sized bump for that block's `H_tt`.
    ///
    /// # Panics
    /// The returned closure asserts `x.len() == k` and `out.len() == k`.
    pub(super) fn build_schur_matvec_backend(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        // Padded per-block strides. Presumably p_max >= d and r_template >= k,
        // since the closure below indexes Y_i with these strides. TODO confirm
        // against plan_fused_launch.
        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_max = plan.p_max;
        let r_template = plan.r_template;

        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
        )
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let stream = ctx
            .new_stream()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let cap = &runtime.device.capability;
        // The NVRTC module is cached per (compute capability, p_max, r_template).
        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
            cc_major: cap.compute_major,
            cc_minor: cap.compute_minor,
            p_max: p_max as u32,
            r_template: r_template as u32,
        };
        let module = fused_module_for(&ctx, key)?;
        let forward = module
            .load_function("arrow_schur_forward_pgroup")
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
        let d_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let b_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let g_dev = stream
            .clone_htod(&g_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // Forward-kernel outputs. All are sized from the padded plan strides.
        let mut l_out = stream
            .alloc_zeros::<f64>(n * p_max * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut u_out = stream
            .alloc_zeros::<f64>(n * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut y_out = stream
            .alloc_zeros::<f64>(n * p_max * r_template)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_s = stream
            .alloc_zeros::<f64>(plan.partial_s_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_r = stream
            .alloc_zeros::<f64>(plan.partial_r_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // One status word per local block. Non-zero marks a failed local
        // factorization; it is checked after the launch.
        let mut status_dev = stream
            .alloc_zeros::<i32>(n)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let cfg = LaunchConfig {
            grid_dim: (plan.blocks, 1, 1),
            block_dim: (plan.threads_per_block, 1, 1),
            shared_mem_bytes: 0,
        };
        // The kernel receives the logical sizes (n, d, k). The padding comes
        // from the module's compiled p_max / r_template.
        let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ridge_arg = ridge_t;
        {
            let mut builder = stream.launch_builder(&forward);
            builder
                .arg(&d_dev)
                .arg(&b_dev)
                .arg(&g_dev)
                .arg(&n_i32)
                .arg(&p_i32)
                .arg(&r_i32)
                .arg(&ridge_arg)
                .arg(&mut l_out)
                .arg(&mut u_out)
                .arg(&mut y_out)
                .arg(&mut partial_s)
                .arg(&mut partial_r)
                .arg(&mut status_dev);
            // SAFETY: all buffers were allocated on `stream` with sizes
            // derived from `plan`; parameter list matches FORWARD_KERNEL_SOURCE.
            unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        }
        stream
            .synchronize()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let status_host = stream
            .clone_dtoh(&status_dev)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // `row` is the index of the first block with a non-zero status, and
        // so a valid index into `sys.rows`.
        if let Some(row) = status_host.iter().position(|s| *s != 0) {
            // Status code is a 1-based pivot row index, not a magnitude; size
            // the bump from the block (Gershgorin λ_min bound) so a strongly
            // indefinite block recovers in one retry.
            return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
                row,
                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
            });
        }

        // Download Y_i factors: n × p_max × r_template column-major per block.
        let y_host = stream
            .clone_dtoh(&y_out)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Capture H_ββ data for the closure. Use the matrix-free hook if present
        // (SAE-manifold callers), otherwise fall back to the dense matrix rows.
        // `iter()` walks `hbb` in logical (row-major) order, whatever its
        // memory layout.
        let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
        let hbb_is_kk = sys.hbb.dim() == (k, k);
        let hbb_matvec_opt = sys.hbb_matvec.clone();

        let closure: crate::arrow_schur::GpuSchurMatvec =
            Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
                assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
                assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");

                // (H_ββ + ρ·I)·x into out.
                if let Some(ref mv) = hbb_matvec_opt {
                    // NOTE(review): this assumes `mv` overwrites `out` with
                    // H_ββ·x rather than accumulating into it. Confirm against
                    // the hbb_matvec contract.
                    mv(x.view(), out);
                    for a in 0..k {
                        out[a] += ridge_beta * x[a];
                    }
                } else if hbb_is_kk {
                    // hbb_host row-major: hbb[a, b] = hbb_host[a * k + b].
                    for a in 0..k {
                        let mut acc = ridge_beta * x[a];
                        for b in 0..k {
                            acc += hbb_host[a * k + b] * x[b];
                        }
                        out[a] = acc;
                    }
                } else {
                    // No usable H_ββ: only the ridge term remains.
                    for a in 0..k {
                        out[a] = ridge_beta * x[a];
                    }
                }

                // out[c] -= Σ_i (Y_i^T (Y_i·x))[c].
                // Y_i column-major at y_host[i·p_max·r_template + col·p_max + row].
                // `z` holds Y_i·x for the current block and is reused across blocks.
                let mut z = vec![0.0_f64; d];
                for i in 0..n {
                    let y_base = i * p_max * r_template;
                    for r in 0..d {
                        let mut acc = 0.0;
                        for c in 0..k {
                            acc += y_host[y_base + c * p_max + r] * x[c];
                        }
                        z[r] = acc;
                    }
                    for c in 0..k {
                        let mut acc = 0.0;
                        for r in 0..d {
                            acc += y_host[y_base + c * p_max + r] * z[r];
                        }
                        out[c] -= acc;
                    }
                }
            });

        Ok(closure)
    }
3810
3811    // ── #1017/#1026 frames-engaged device PCG ──────────────────────────────
3812
    /// Device-resident, flattened form of the framed SAE operator.
    ///
    /// `flatten_device_sae_frame_data` builds it and the framed matvec /
    /// diagonal kernels consume it. Every `*_ptr` array holds prefix offsets
    /// (length `blocks + 1`, starting at 0) into its matching `*_data` buffer.
    /// Dense payloads are stored row-major.
    struct DeviceSaeFrameBuffers {
        // Smooth `λ S_k ⊗ I_{r_k}`.
        /// Global β offset of each smooth block.
        s_off: CudaSlice<i32>,
        /// Side length `m` of each (square) smooth factor.
        s_m: CudaSlice<i32>,
        /// Rank `r` of each smooth block (the width of the `I_r` Kronecker factor).
        s_r: CudaSlice<i32>,
        /// Prefix offsets into `s_data`.
        s_ptr: CudaSlice<i32>,
        /// Concatenated row-major `m × m` smooth factors.
        s_data: CudaSlice<f64>,
        /// Number of smooth blocks.
        s_blocks: usize,
        // Data `G_{ij} ⊗ W_{ij}`.
        /// β border offset of atom `i` for each data block.
        g_off_i: CudaSlice<i32>,
        /// β border offset of atom `j` for each data block.
        g_off_j: CudaSlice<i32>,
        /// Rank of atom `i` (rows of `w`).
        g_ri: CudaSlice<i32>,
        /// Rank of atom `j` (columns of `w`).
        g_rj: CudaSlice<i32>,
        /// Rows of `g`.
        g_mi: CudaSlice<i32>,
        /// Columns of `g`.
        g_mj: CudaSlice<i32>,
        /// Prefix offsets into `g_data`.
        g_ptr: CudaSlice<i32>,
        /// Concatenated row-major `m_i × m_j` `g` blocks.
        g_data: CudaSlice<f64>,
        /// Prefix offsets into `w_data`.
        w_ptr: CudaSlice<i32>,
        /// Concatenated row-major `r_i × r_j` `w` blocks.
        w_data: CudaSlice<f64>,
        /// Number of data blocks.
        g_blocks: usize,
        /// Largest `m_i · r_i` over data blocks. This sizes the grid x-dimension.
        g_max_work: usize,
        // Per-row dense cross-block H_tβ^(i) + row q + explicit (H_tt + ridge_t·I)^{-1}.
        /// Prefix offsets into `htb`. A row with `q = 0` contributes no entries.
        htb_ptr: CudaSlice<i32>,
        /// Concatenated row-major `q_i × k` cross-block slabs.
        htb: CudaSlice<f64>,
        /// Effective local dimension `q_i` per row (0 when the row has no slab).
        q_of: CudaSlice<i32>,
        /// Explicit inverse of `H_tt^(i) + ridge_t·I` per row, stored row-major
        /// in a zero-padded `max_q × max_q` tile.
        ainv: CudaSlice<f64>,
        /// `n_rows × max_q` scratch. The `arrow_sae_frame_apply_h` launch writes it.
        hvec: CudaSlice<f64>,
        /// `n_rows × max_q` scratch. The `arrow_sae_frame_apply_ainv` launch writes it.
        svec: CudaSlice<f64>,
        /// Number of rows (`sys.rows.len()`).
        n_rows: usize,
        /// Shared β width (`data.beta_dim`).
        k: usize,
        /// Largest effective `q_i`, clamped to at least 1.
        max_q: usize,
    }
3845
3846    fn flatten_device_sae_frame_data(
3847        sys: &ArrowSchurSystem,
3848        data: &DeviceSaePcgData,
3849        frame: &DeviceSaeFrameData,
3850        ridge_t: f64,
3851        stream: &Arc<CudaStream>,
3852    ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3853        let n_rows = sys.rows.len();
3854        let k = data.beta_dim;
3855        if frame.row_htbeta.len() != n_rows
3856            || frame.ranks.len() != frame.basis_sizes.len()
3857            || frame.border_offsets.len() != frame.ranks.len()
3858            || data.smooth_blocks.len() != frame.smooth_ranks.len()
3859        {
3860            return Err(ArrowSchurGpuFailure::Unavailable);
3861        }
3862
3863        // Smooth blocks.
3864        let mut s_off = Vec::new();
3865        let mut s_m = Vec::new();
3866        let mut s_r = Vec::new();
3867        let mut s_ptr = vec![0_i32];
3868        let mut s_data = Vec::<f64>::new();
3869        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3870            let (m, mc) = blk.factor_a.dim();
3871            if m != mc {
3872                return Err(ArrowSchurGpuFailure::Unavailable);
3873            }
3874            s_off.push(checked_i32(blk.global_offset)?);
3875            s_m.push(checked_i32(m)?);
3876            s_r.push(checked_i32(r)?);
3877            for ri in 0..m {
3878                for ci in 0..m {
3879                    s_data.push(blk.factor_a[[ri, ci]]);
3880                }
3881            }
3882            s_ptr.push(checked_i32(s_data.len())?);
3883        }
3884
3885        // Data blocks (g + w).
3886        let mut g_off_i = Vec::new();
3887        let mut g_off_j = Vec::new();
3888        let mut g_ri = Vec::new();
3889        let mut g_rj = Vec::new();
3890        let mut g_mi = Vec::new();
3891        let mut g_mj = Vec::new();
3892        let mut g_ptr = vec![0_i32];
3893        let mut g_data = Vec::<f64>::new();
3894        let mut w_ptr = vec![0_i32];
3895        let mut w_data = Vec::<f64>::new();
3896        let mut g_max_work = 0usize;
3897        for blk in &frame.frame_blocks {
3898            let ri = frame.ranks[blk.atom_i];
3899            let rj = frame.ranks[blk.atom_j];
3900            let (mi, mj) = blk.g.dim();
3901            if blk.w.dim() != (ri, rj) {
3902                return Err(ArrowSchurGpuFailure::Unavailable);
3903            }
3904            g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3905            g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3906            g_ri.push(checked_i32(ri)?);
3907            g_rj.push(checked_i32(rj)?);
3908            g_mi.push(checked_i32(mi)?);
3909            g_mj.push(checked_i32(mj)?);
3910            for r in 0..mi {
3911                for c in 0..mj {
3912                    g_data.push(blk.g[[r, c]]);
3913                }
3914            }
3915            g_ptr.push(checked_i32(g_data.len())?);
3916            for a in 0..ri {
3917                for b in 0..rj {
3918                    w_data.push(blk.w[[a, b]]);
3919                }
3920            }
3921            w_ptr.push(checked_i32(w_data.len())?);
3922            g_max_work = g_max_work.max(mi * ri);
3923        }
3924
3925        // Per-row dense cross-block + q + ainv (factored H_tt⁻¹).
3926        let mut htb_ptr = vec![0_i32];
3927        let mut htb = Vec::<f64>::new();
3928        let mut q_of = Vec::<i32>::with_capacity(n_rows);
3929        let mut max_q = 0usize;
3930        for (i, slab) in frame.row_htbeta.iter().enumerate() {
3931            let qi = sys.row_dims[i];
3932            // A populated slab must be q_i × k row-major; an empty slab ⇒ q=0
3933            // (the row contributes no reduced-Schur term).
3934            let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3935                qi
3936            } else {
3937                0
3938            };
3939            q_of.push(checked_i32(q_eff)?);
3940            max_q = max_q.max(q_eff);
3941            if q_eff > 0 {
3942                htb.extend_from_slice(slab);
3943            }
3944            htb_ptr.push(checked_i32(htb.len())?);
3945        }
3946        if max_q == 0 {
3947            // No row contributes a reduced term — the system is pure-penalty.
3948            // Still valid; give max_q=1 so the ainv buffer is non-empty.
3949            max_q = 1;
3950        }
3951
3952        let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3953        for (i, row) in sys.rows.iter().enumerate() {
3954            let q = q_of[i] as usize;
3955            if q == 0 {
3956                continue;
3957            }
3958            if row.htt.dim() != (q, q) {
3959                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3960                    reason: format!(
3961                        "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3962                        row.htt.dim()
3963                    ),
3964                });
3965            }
3966            let mut block = row.htt.clone();
3967            for d in 0..q {
3968                block[[d, d]] += ridge_t;
3969            }
3970            let factor = gam_linalg::triangular::cholesky_factor_in_place(
3971                block.view(),
3972                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3973            )
3974            .ok_or_else(|| {
3975                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
3976                // indefinite per-row block recovers in one outer-loop retry.
3977                ArrowSchurGpuFailure::RidgeBumpRequired {
3978                    row: i,
3979                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3980                }
3981            })?;
3982            for col in 0..q {
3983                let mut e = Array1::<f64>::zeros(q);
3984                e[col] = 1.0;
3985                let solved =
3986                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3987                for r in 0..q {
3988                    ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3989                }
3990            }
3991        }
3992
3993        let htod_i = |v: &[i32]| {
3994            stream
3995                .clone_htod(v)
3996                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3997        };
3998        let htod_f = |v: &[f64]| {
3999            stream
4000                .clone_htod(v)
4001                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4002        };
4003        Ok(DeviceSaeFrameBuffers {
4004            s_off: htod_i(&s_off)?,
4005            s_m: htod_i(&s_m)?,
4006            s_r: htod_i(&s_r)?,
4007            s_ptr: htod_i(&s_ptr)?,
4008            s_data: htod_f(&s_data)?,
4009            s_blocks: data.smooth_blocks.len(),
4010            g_off_i: htod_i(&g_off_i)?,
4011            g_off_j: htod_i(&g_off_j)?,
4012            g_ri: htod_i(&g_ri)?,
4013            g_rj: htod_i(&g_rj)?,
4014            g_mi: htod_i(&g_mi)?,
4015            g_mj: htod_i(&g_mj)?,
4016            g_ptr: htod_i(&g_ptr)?,
4017            g_data: htod_f(&g_data)?,
4018            w_ptr: htod_i(&w_ptr)?,
4019            w_data: htod_f(&w_data)?,
4020            g_blocks: frame.frame_blocks.len(),
4021            g_max_work,
4022            htb_ptr: htod_i(&htb_ptr)?,
4023            htb: htod_f(&htb)?,
4024            q_of: htod_i(&q_of)?,
4025            ainv: htod_f(&ainv)?,
4026            hvec: stream
4027                .alloc_zeros::<f64>(n_rows * max_q)
4028                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4029            svec: stream
4030                .alloc_zeros::<f64>(n_rows * max_q)
4031                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4032            n_rows,
4033            k,
4034            max_q,
4035        })
4036    }
4037
4038    fn sae_frame_penalty_diag_host(
4039        data: &DeviceSaePcgData,
4040        frame: &DeviceSaeFrameData,
4041        ridge_beta: f64,
4042    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4043        let mut diag = vec![ridge_beta; data.beta_dim];
4044        // Smooth: diag[off + ia·r + ib] += S[ia,ia].
4045        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4046            let m = blk.factor_a.nrows();
4047            for ia in 0..m {
4048                let coeff = blk.factor_a[[ia, ia]];
4049                let base = blk.global_offset + ia * r;
4050                for ib in 0..r {
4051                    if base + ib >= diag.len() {
4052                        return Err(ArrowSchurGpuFailure::Unavailable);
4053                    }
4054                    diag[base + ib] += coeff;
4055                }
4056            }
4057        }
4058        // Data: on-diagonal atom blocks contribute g[li,li]·w[a,a].
4059        for blk in &frame.frame_blocks {
4060            if blk.atom_i != blk.atom_j {
4061                continue;
4062            }
4063            let r = frame.ranks[blk.atom_i];
4064            let off = frame.border_offsets[blk.atom_i];
4065            let (mi, mj) = blk.g.dim();
4066            for li in 0..mi.min(mj) {
4067                let gii = blk.g[[li, li]];
4068                let base = off + li * r;
4069                for a in 0..r {
4070                    if base + a >= diag.len() {
4071                        return Err(ArrowSchurGpuFailure::Unavailable);
4072                    }
4073                    diag[base + a] += gii * blk.w[[a, a]];
4074                }
4075            }
4076        }
4077        Ok(diag)
4078    }
4079
4080    fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4081        Ok(LaunchConfig {
4082            grid_dim: (
4083                ((work as u32).saturating_add(255) / 256).max(1),
4084                checked_i32(n_rows)? as u32,
4085                1,
4086            ),
4087            block_dim: (256, 1, 1),
4088            shared_mem_bytes: 0,
4089        })
4090    }
4091
    /// Enqueue the framed reduced-Schur matvec `out = S·x` on `stream`.
    ///
    /// Launch order (all on the same stream, so each stage sees the previous
    /// one's writes):
    /// 1. `launch_sae_init`. It is handed `out`, `x` and `ridge_beta`;
    ///    presumably it writes `out = ridge_beta·x` (TODO confirm against its
    ///    definition).
    /// 2. `arrow_sae_frame_smooth_matvec`, only when there are smooth blocks.
    /// 3. `arrow_sae_frame_g_matvec`, only when there are data blocks.
    /// 4. `arrow_sae_frame_apply_h`, which writes `buffers.hvec` from `x` and
    ///    the per-row cross-blocks.
    /// 5. `arrow_sae_frame_apply_ainv`, which writes `buffers.svec` from
    ///    `ainv` and `hvec`.
    /// 6. `arrow_sae_frame_scatter_h`, which accumulates into `out` from `svec`
    ///    and the cross-blocks.
    ///
    /// `buffers` is taken `&mut` only because `hvec`/`svec` are written as
    /// scratch. No `synchronize` happens here; the caller must order any host
    /// read of `out` after the stream. Kernel functions are looked up with
    /// `load_function` on every call.
    ///
    /// # Errors
    /// `Unavailable` on kernel lookup or launch failure. Errors from
    /// `launch_sae_init`, `checked_i32` and `frame_grid` are propagated.
    fn launch_sae_frame_matvec(
        stream: &Arc<CudaStream>,
        module: &Arc<CudaModule>,
        buffers: &mut DeviceSaeFrameBuffers,
        x: &CudaSlice<f64>,
        out: &mut CudaSlice<f64>,
        ridge_beta: f64,
    ) -> Result<(), ArrowSchurGpuFailure> {
        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
        // Smooth penalty.
        if buffers.s_blocks > 0 {
            let kernel = module
                .load_function("arrow_sae_frame_smooth_matvec")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let blocks_i32 = checked_i32(buffers.s_blocks)?;
            // grid.x spans k, grid.y spans the smooth blocks.
            let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&mut *out)
                .arg(&buffers.s_off)
                .arg(&buffers.s_m)
                .arg(&buffers.s_r)
                .arg(&buffers.s_ptr)
                .arg(&buffers.s_data)
                .arg(&blocks_i32);
            // SAFETY: smooth block metadata/data are live device buffers; the grid
            // covers (k channels × n_blocks) and the kernel bounds-checks m·r.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        // Data penalty.
        if buffers.g_blocks > 0 {
            let kernel = module
                .load_function("arrow_sae_frame_g_matvec")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let blocks_i32 = checked_i32(buffers.g_blocks)?;
            // grid.x spans the largest m_i·r_i, grid.y spans the data blocks.
            let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&mut *out)
                .arg(&buffers.g_off_i)
                .arg(&buffers.g_off_j)
                .arg(&buffers.g_ri)
                .arg(&buffers.g_rj)
                .arg(&buffers.g_mi)
                .arg(&buffers.g_mj)
                .arg(&buffers.g_ptr)
                .arg(&buffers.g_data)
                .arg(&buffers.w_ptr)
                .arg(&buffers.w_data)
                .arg(&blocks_i32);
            // SAFETY: g/w block metadata/data are live device buffers; the grid
            // covers (max m_i·r_i × n_blocks) and the kernel bounds-checks.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        // Reduced-Schur subtraction via dense per-row cross-blocks.
        // Three stages: hvec ← rows × x, svec ← ainv·hvec, out ← scatter(svec).
        let k_i32 = checked_i32(buffers.k)?;
        let max_q_i32 = checked_i32(buffers.max_q)?;
        let n_rows_i32 = checked_i32(buffers.n_rows)?;
        {
            let kernel = module
                .load_function("arrow_sae_frame_apply_h")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&buffers.htb_ptr)
                .arg(&buffers.htb)
                .arg(&buffers.q_of)
                .arg(&mut buffers.hvec)
                .arg(&k_i32)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: dense cross-block + pointers + hvec are live buffers sized
            // for (n_rows × max_q) / (n_rows × k); kernel guards q_i and k.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        {
            let kernel = module
                .load_function("arrow_sae_frame_apply_ainv")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(&buffers.ainv)
                .arg(&buffers.hvec)
                .arg(&buffers.q_of)
                .arg(&mut buffers.svec)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: ainv/hvec/svec are live buffers sized for n_rows·max_q²
            // and n_rows·max_q; the kernel guards row/coord bounds.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        {
            let kernel = module
                .load_function("arrow_sae_frame_scatter_h")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.k, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(&buffers.svec)
                .arg(&buffers.htb_ptr)
                .arg(&buffers.htb)
                .arg(&buffers.q_of)
                .arg(out)
                .arg(&k_i32)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: svec/cross-block/out are live buffers; the kernel atomically
            // accumulates into out[a] for a<k and reads c<q_i.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        Ok(())
    }
4204
4205    fn launch_sae_frame_diag_sub(
4206        stream: &Arc<CudaStream>,
4207        module: &Arc<CudaModule>,
4208        buffers: &DeviceSaeFrameBuffers,
4209        diag: &mut CudaSlice<f64>,
4210    ) -> Result<(), ArrowSchurGpuFailure> {
4211        let kernel = module
4212            .load_function("arrow_sae_frame_diag_sub")
4213            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4214        let k_i32 = checked_i32(buffers.k)?;
4215        let max_q_i32 = checked_i32(buffers.max_q)?;
4216        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4217        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4218        let mut b = stream.launch_builder(&kernel);
4219        b.arg(diag)
4220            .arg(&buffers.ainv)
4221            .arg(&buffers.htb_ptr)
4222            .arg(&buffers.htb)
4223            .arg(&buffers.q_of)
4224            .arg(&k_i32)
4225            .arg(&max_q_i32)
4226            .arg(&n_rows_i32);
4227        // SAFETY: diag + cross-block + ainv live buffers; kernel guards a<k, c/d<q.
4228        unsafe { b.launch(cfg) }
4229            .map(drop)
4230            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4231    }
4232
4233    /// #1551 kernel-isolating seam: evaluate the framed reduced-Schur matvec
4234    /// `out = S·x` EXACTLY ONCE on the device (no PCG, no offload-floor gate) and
4235    /// return `out`. This is the parity probe the test harness diffs against the
4236    /// CPU oracle [`super::sae_framed_schur_matvec_cpu`] element-by-element, so a
4237    /// kernel/marshalling defect is exposed directly — independent of how the
4238    /// iterative solver behaves on an ill-conditioned assembled `S` (where dense
4239    /// Cholesky and PCG legitimately disagree at the solution level). Declines
4240    /// (`Unavailable`) only when CUDA is genuinely absent so the test skips
4241    /// cleanly off-device; it deliberately does NOT consult the offload policy so
4242    /// even a tiny verifiable fixture runs on the GPU.
4243    pub(super) fn framed_schur_matvec_once_on_device(
4244        sys: &ArrowSchurSystem,
4245        data: &DeviceSaePcgData,
4246        ridge_t: f64,
4247        ridge_beta: f64,
4248        x: &Array1<f64>,
4249    ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4250        let k = x.len();
4251        if k == 0 || data.beta_dim != k || sys.k != k {
4252            return Err(ArrowSchurGpuFailure::Unavailable);
4253        }
4254        let frame = data
4255            .frame
4256            .as_ref()
4257            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4258        // No offload-policy filter here: the seam exists to validate the kernel on
4259        // ANY device, including the smallest hand-checkable fixture.
4260        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4261            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4262        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4263            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4264        let stream = ctx
4265            .new_stream()
4266            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4267        let vector_module = pcg_vector_module(&ctx)?;
4268        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4269        let x_dev = stream
4270            .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4271            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4272        let mut out_dev = stream
4273            .alloc_zeros::<f64>(k)
4274            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4275        launch_sae_frame_matvec(
4276            &stream,
4277            vector_module,
4278            &mut buffers,
4279            &x_dev,
4280            &mut out_dev,
4281            ridge_beta,
4282        )?;
4283        let out = stream
4284            .clone_dtoh(&out_dev)
4285            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4286        Ok(Array1::from_vec(out))
4287    }
4288
4289    pub(super) fn solve_sae_matrix_free_pcg_framed(
4290        sys: &ArrowSchurSystem,
4291        data: &DeviceSaePcgData,
4292        ridge_t: f64,
4293        ridge_beta: f64,
4294        rhs_beta: &Array1<f64>,
4295        max_iterations: usize,
4296        relative_tolerance: f64,
4297    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4298        let k = rhs_beta.len();
4299        if k == 0 || data.beta_dim != k || sys.k != k {
4300            return Err(ArrowSchurGpuFailure::Unavailable);
4301        }
4302        let frame = data
4303            .frame
4304            .as_ref()
4305            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4306        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4307            .filter(|rt| {
4308                rt.policy().reduced_schur_matvec_should_offload(
4309                    sys.rows.len(),
4310                    sys.k,
4311                    sys.d,
4312                    max_iterations,
4313                )
4314            })
4315            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4316        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4317            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4318        let stream = ctx
4319            .new_stream()
4320            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4321        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4322        let vector_module = pcg_vector_module(&ctx)?;
4323        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4324
4325        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4326        if rhs_norm == 0.0 {
4327            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4328        }
4329        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4330        let rhs_dev = stream
4331            .clone_htod(
4332                rhs_beta
4333                    .as_slice()
4334                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4335            )
4336            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4337        let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4338        let mut diag_dev = stream
4339            .clone_htod(&diag_host)
4340            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4341        launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4342        let diag_host = stream
4343            .clone_dtoh(&diag_dev)
4344            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4345        let mut inv_diag = Vec::with_capacity(k);
4346        for (idx, &d) in diag_host.iter().enumerate() {
4347            if !d.is_finite() || d <= 1.0e-18 {
4348                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4349                    reason: format!(
4350                        "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4351                    ),
4352                });
4353            }
4354            inv_diag.push(1.0 / d);
4355        }
4356        let inv_diag_dev = stream
4357            .clone_htod(&inv_diag)
4358            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4359
4360        let mut x_dev = stream
4361            .alloc_zeros::<f64>(k)
4362            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4363        let mut r_dev = stream
4364            .alloc_zeros::<f64>(k)
4365            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4366        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4367        let mut z_dev = stream
4368            .alloc_zeros::<f64>(k)
4369            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4370        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4371        let mut p_dev = stream
4372            .alloc_zeros::<f64>(k)
4373            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4374        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4375        let mut ap_dev = stream
4376            .alloc_zeros::<f64>(k)
4377            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4378
4379        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4380        if rz <= 0.0 || !rz.is_finite() {
4381            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4382                reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4383            });
4384        }
4385        let mut diag = PcgDiagnostics {
4386            precond_apply_calls: 1,
4387            stopping_reason: PcgStopReason::MaxIter,
4388            ..PcgDiagnostics::default()
4389        };
4390        for _ in 0..max_iterations.max(1) {
4391            launch_sae_frame_matvec(
4392                &stream,
4393                vector_module,
4394                &mut buffers,
4395                &p_dev,
4396                &mut ap_dev,
4397                ridge_beta,
4398            )?;
4399            diag.matvec_calls += 1;
4400            diag.iterations += 1;
4401            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4402            if pap <= 0.0 || !pap.is_finite() {
4403                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4404                    reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4405                });
4406            }
4407            let alpha = rz / pap;
4408            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4409            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4410            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4411            if r_norm <= tol {
4412                diag.final_relative_residual = r_norm / rhs_norm;
4413                diag.stopping_reason = PcgStopReason::Converged;
4414                break;
4415            }
4416            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4417            diag.precond_apply_calls += 1;
4418            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4419            if rz_new <= 0.0 || !rz_new.is_finite() {
4420                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4421                    reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4422                });
4423            }
4424            let beta = rz_new / rz;
4425            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4426            rz = rz_new;
4427        }
4428        if diag.stopping_reason != PcgStopReason::Converged {
4429            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4430            diag.final_relative_residual = r_norm / rhs_norm;
4431            diag.stopping_reason = PcgStopReason::MaxIter;
4432        }
4433        let x = stream
4434            .clone_dtoh(&x_dev)
4435            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4436        Ok((Array1::from_vec(x), diag))
4437    }
4438
4439    /// #1551 stage-isolating triage seam: run the framed reduced-Schur matvec
4440    /// `out = S·x` ONCE on the device (no PCG, no offload-floor gate) and return
4441    /// `out`, so a tiny hand-verifiable fixture can diff it against the CPU oracle
4442    /// `sae_framed_schur_matvec_cpu` element-by-element to localize the structural
4443    /// divergence to a single kernel stage. Returns `Unavailable` only when CUDA
4444    /// is genuinely absent (so the test skips cleanly off-device).
4445    pub(super) fn solve_sae_matrix_free_pcg(
4446        sys: &ArrowSchurSystem,
4447        data: &DeviceSaePcgData,
4448        ridge_t: f64,
4449        ridge_beta: f64,
4450        rhs_beta: &Array1<f64>,
4451        max_iterations: usize,
4452        relative_tolerance: f64,
4453    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4454        let k = rhs_beta.len();
4455        if k == 0 || data.beta_dim != k || sys.k != k {
4456            return Err(ArrowSchurGpuFailure::Unavailable);
4457        }
4458        // #1017/#1026 GUARD: the legacy `⊗ I_p` kernel must NEVER receive framed
4459        // data (factored `G ⊗ W_{ij}` + dense per-row cross blocks); decline so a
4460        // mis-route falls back to the CPU rather than returning a wrong step.
4461        if data.frame.is_some() {
4462            return Err(ArrowSchurGpuFailure::Unavailable);
4463        }
4464        // #1017 Phase-1 dispatch re-key: this is the matrix-free SAE reduced-Schur
4465        // PCG — the production hot path, not a single dense factorization. The
4466        // dense-Direct floor `dense_hessian_work_target_is_gpu(n, k)` keys on
4467        // `2·n·k²` and is the WRONG gate here: it ignores the per-row frame depth
4468        // `d` (the M dimension that multiplies the per-apply work) and the
4469        // `1/cg_iters` staging amortisation, so it both undercounts the SAE batched
4470        // work `n·k·d` and applies a cold single-launch breakeven to an apply that
4471        // reuses device-resident frames `max_iterations` times. Key instead on the
4472        // CG-amortised total batched work — the same predicate the host injection
4473        // gate (`maybe_inject_gpu_schur_matvec`) consults — so few-row/wide-`k`/
4474        // modest-`d` LLM shapes register the real `n × k × d × cg_iters` arithmetic.
4475        // Kernels and numerics are untouched; only where the matvec runs changes,
4476        // and the host falls back to the bit-identical CPU matvec when this declines.
4477        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4478            .filter(|rt| {
4479                rt.policy().reduced_schur_matvec_should_offload(
4480                    sys.rows.len(),
4481                    sys.k,
4482                    sys.d,
4483                    max_iterations,
4484                )
4485            })
4486            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4487        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4488            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4489        let stream = ctx
4490            .new_stream()
4491            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4492        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4493        let vector_module = pcg_vector_module(&ctx)?;
4494        let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4495
4496        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4497        if rhs_norm == 0.0 {
4498            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4499        }
4500        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4501        let rhs_dev = stream
4502            .clone_htod(
4503                rhs_beta
4504                    .as_slice()
4505                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4506            )
4507            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4508        let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4509        let mut diag_dev = stream
4510            .clone_htod(&diag_host)
4511            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4512        launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4513        let diag_host = stream
4514            .clone_dtoh(&diag_dev)
4515            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4516        let mut inv_diag = Vec::with_capacity(k);
4517        for (idx, &d) in diag_host.iter().enumerate() {
4518            if !d.is_finite() || d <= 1.0e-18 {
4519                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4520                    reason: format!(
4521                        "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4522                    ),
4523                });
4524            }
4525            inv_diag.push(1.0 / d);
4526        }
4527        let inv_diag_dev = stream
4528            .clone_htod(&inv_diag)
4529            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4530
4531        let mut x_dev = stream
4532            .alloc_zeros::<f64>(k)
4533            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4534        let mut r_dev = stream
4535            .alloc_zeros::<f64>(k)
4536            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4537        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4538        let mut z_dev = stream
4539            .alloc_zeros::<f64>(k)
4540            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4541        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4542        let mut p_dev = stream
4543            .alloc_zeros::<f64>(k)
4544            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4545        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4546        let mut ap_dev = stream
4547            .alloc_zeros::<f64>(k)
4548            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4549
4550        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4551        if rz <= 0.0 || !rz.is_finite() {
4552            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4553                reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4554            });
4555        }
4556        let mut diag = PcgDiagnostics {
4557            precond_apply_calls: 1,
4558            stopping_reason: PcgStopReason::MaxIter,
4559            ..PcgDiagnostics::default()
4560        };
4561
4562        for _ in 0..max_iterations.max(1) {
4563            launch_sae_matvec(
4564                &stream,
4565                vector_module,
4566                &mut buffers,
4567                &p_dev,
4568                &mut ap_dev,
4569                ridge_beta,
4570            )?;
4571            diag.matvec_calls += 1;
4572            diag.iterations += 1;
4573            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4574            if pap <= 0.0 || !pap.is_finite() {
4575                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4576                    reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4577                });
4578            }
4579            let alpha = rz / pap;
4580            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4581            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4582            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4583            if r_norm <= tol {
4584                diag.final_relative_residual = r_norm / rhs_norm;
4585                diag.stopping_reason = PcgStopReason::Converged;
4586                break;
4587            }
4588            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4589            diag.precond_apply_calls += 1;
4590            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4591            if rz_new <= 0.0 || !rz_new.is_finite() {
4592                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4593                    reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4594                });
4595            }
4596            let beta = rz_new / rz;
4597            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4598            rz = rz_new;
4599        }
4600        if diag.stopping_reason != PcgStopReason::Converged {
4601            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4602            diag.final_relative_residual = r_norm / rhs_norm;
4603            diag.stopping_reason = PcgStopReason::MaxIter;
4604        }
4605        let x = stream
4606            .clone_dtoh(&x_dev)
4607            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4608        Ok((Array1::from_vec(x), diag))
4609    }
4610
4611    pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4612        s_acc: &ndarray::Array2<f64>,
4613        rhs_beta: &Array1<f64>,
4614        max_iterations: usize,
4615        relative_tolerance: f64,
4616    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4617        let k = rhs_beta.len();
4618        // #1017 dispatch re-key: this is an ITERATIVE device-resident PCG, not a
4619        // single GEMV. `S` (k×k) is uploaded once and reused for `max_iterations`
4620        // `S·p` GEMVs while only convergence scalars cross PCIe, so the staging
4621        // cost is amortised over the whole CG solve. Gating on the flops of ONE
4622        // `Gemv{k,k}` (`2·k²`) understates the work by the iteration count and
4623        // declines shapes (e.g. k≈512) whose total iterated arithmetic
4624        // `2·k²·iters` clears the device floor by orders of magnitude — the same
4625        // single-launch-breakeven miskey #1017 fixed for the framed reduced-Schur
4626        // matvec. Key on the CG-amortised total work via a `Gemm{k,k,iters}` whose
4627        // `flops()` is exactly `2·k²·iters`; numerics and kernels are untouched,
4628        // and the host falls back to the bit-identical CPU PCG when this declines.
4629        let cg_iters = max_iterations.max(1);
4630        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4631            gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4632                m: k,
4633                n: k,
4634                k: cg_iters,
4635            },
4636        )
4637        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4638        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4639            .and_then(|ctx| ctx.new_stream().ok())
4640            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4641        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4642        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4643            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4644        let vector_module = pcg_vector_module(&ctx)?;
4645
4646        // Jacobi diagonal from S; must be strictly positive for SPD.
4647        let mut inv_diag = vec![0.0_f64; k];
4648        for j in 0..k {
4649            let djj = s_acc[[j, j]];
4650            if !(djj > 0.0) {
4651                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4652                    reason: format!(
4653                        "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4654                    ),
4655                });
4656            }
4657            inv_diag[j] = 1.0 / djj;
4658        }
4659
4660        // Upload S column-major (S[row,col] at col*k + row).
4661        let mut s_host = vec![0.0_f64; k * k];
4662        for col in 0..k {
4663            for row in 0..k {
4664                s_host[col * k + row] = s_acc[[row, col]];
4665            }
4666        }
4667        let s_dev = stream
4668            .clone_htod(&s_host)
4669            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4670
4671        // Steihaug truncated-CG with Jacobi preconditioner, host scalar
4672        // recurrences and a device `S·p` matvec. The streaming reduced solve
4673        // uses an unbounded trust region (pure CG to tolerance).
4674        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4675        if rhs_norm == 0.0 {
4676            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4677        }
4678        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4679
4680        // Device-resident PCG state. Only convergence scalars cross back during
4681        // the loop; x/r/z/p/Sp stay on CUDA until the final solution download.
4682        let mut x_dev = stream
4683            .alloc_zeros::<f64>(k)
4684            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4685        let mut r_dev = stream
4686            .clone_htod(
4687                rhs_beta
4688                    .as_slice()
4689                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4690            )
4691            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4692        let inv_diag_dev = stream
4693            .clone_htod(&inv_diag)
4694            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4695        let mut z_dev = stream
4696            .alloc_zeros::<f64>(k)
4697            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4698        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4699        let mut p_dev = stream
4700            .alloc_zeros::<f64>(k)
4701            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4703        let mut sp_dev = stream
4704            .alloc_zeros::<f64>(k)
4705            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4706        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4707        let mut diag = PcgDiagnostics {
4708            precond_apply_calls: 1,
4709            stopping_reason: PcgStopReason::MaxIter,
4710            ..PcgDiagnostics::default()
4711        };
4712        if rz <= 0.0 || !rz.is_finite() {
4713            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4714                reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4715            });
4716        }
4717
4718        let max_iters = max_iterations.max(1);
4719        for _ in 0..max_iters {
4720            // sp = S · p (device GEMV, S column-major k×k, op = N).
4721            let gemv_cfg = GemvConfig::<f64> {
4722                trans: cublasOperation_t::CUBLAS_OP_N,
4723                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4724                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4725                alpha: 1.0,
4726                lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4727                incx: 1,
4728                beta: 0.0,
4729                incy: 1,
4730            };
4731            // SAFETY: s_dev is k×k column-major, p_dev / sp_dev length k.
4732            unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4733                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4734            diag.matvec_calls += 1;
4735            diag.iterations += 1;
4736
4737            let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4738            if !(p_sp > 0.0) {
4739                // Non-positive curvature on a (proximal-ridged) SPD system means
4740                // numerical breakdown; surface so the caller escalates.
4741                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4742                    reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4743                });
4744            }
4745            let alpha = rz / p_sp;
4746            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4747            device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4748            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4749            if r_norm <= tol {
4750                diag.final_relative_residual = r_norm / rhs_norm;
4751                diag.stopping_reason = PcgStopReason::Converged;
4752                break;
4753            }
4754            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4755            diag.precond_apply_calls += 1;
4756            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4757            if rz_new <= 0.0 || !rz_new.is_finite() {
4758                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4759                    reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4760                });
4761            }
4762            let beta = rz_new / rz;
4763            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4764            rz = rz_new;
4765        }
4766        if diag.stopping_reason != PcgStopReason::Converged {
4767            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4768            diag.final_relative_residual = r_norm / rhs_norm;
4769            diag.stopping_reason = PcgStopReason::MaxIter;
4770        }
4771
4772        let x = stream
4773            .clone_dtoh(&x_dev)
4774            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4775        Ok((Array1::from_vec(x), diag))
4776    }
4777
4778    fn device_copy(
4779        blas: &CudaBlas,
4780        stream: &Arc<CudaStream>,
4781        n: usize,
4782        src: &CudaSlice<f64>,
4783        dst: &mut CudaSlice<f64>,
4784    ) -> Result<(), ArrowSchurGpuFailure> {
4785        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4786        let (src_ptr, _src_rec) = src.device_ptr(stream);
4787        let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4788        // SAFETY: src and dst are live device allocations on this stream with at
4789        // least n contiguous f64 entries and unit stride.
4790        let status = unsafe {
4791            cudarc::cublas::sys::cublasDcopy_v2(
4792                *blas.handle(),
4793                n_i,
4794                src_ptr as *const f64,
4795                1,
4796                dst_ptr as *mut f64,
4797                1,
4798            )
4799        };
4800        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4801            Ok(())
4802        } else {
4803            Err(ArrowSchurGpuFailure::Unavailable)
4804        }
4805    }
4806
4807    fn device_axpy(
4808        blas: &CudaBlas,
4809        stream: &Arc<CudaStream>,
4810        n: usize,
4811        alpha: f64,
4812        x: &CudaSlice<f64>,
4813        y: &mut CudaSlice<f64>,
4814    ) -> Result<(), ArrowSchurGpuFailure> {
4815        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4816        let (x_ptr, _x_rec) = x.device_ptr(stream);
4817        let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4818        // SAFETY: x and y are live device allocations on this stream with at
4819        // least n contiguous f64 entries and unit stride; cuBLAS only reads alpha.
4820        let status = unsafe {
4821            cudarc::cublas::sys::cublasDaxpy_v2(
4822                *blas.handle(),
4823                n_i,
4824                &alpha,
4825                x_ptr as *const f64,
4826                1,
4827                y_ptr as *mut f64,
4828                1,
4829            )
4830        };
4831        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4832            Ok(())
4833        } else {
4834            Err(ArrowSchurGpuFailure::Unavailable)
4835        }
4836    }
4837
4838    fn device_dot(
4839        blas: &CudaBlas,
4840        stream: &Arc<CudaStream>,
4841        n: usize,
4842        x: &CudaSlice<f64>,
4843        y: &CudaSlice<f64>,
4844    ) -> Result<f64, ArrowSchurGpuFailure> {
4845        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4846        let (x_ptr, _x_rec) = x.device_ptr(stream);
4847        let (y_ptr, _y_rec) = y.device_ptr(stream);
4848        let mut result = 0.0_f64;
4849        // SAFETY: x and y are live device allocations on this stream with at
4850        // least n contiguous f64 entries and unit stride; result is a valid host
4851        // out-pointer for the cuBLAS scalar.
4852        let status = unsafe {
4853            cudarc::cublas::sys::cublasDdot_v2(
4854                *blas.handle(),
4855                n_i,
4856                x_ptr as *const f64,
4857                1,
4858                y_ptr as *const f64,
4859                1,
4860                &mut result,
4861            )
4862        };
4863        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4864            Ok(result)
4865        } else {
4866            Err(ArrowSchurGpuFailure::Unavailable)
4867        }
4868    }
4869
4870    fn device_nrm2(
4871        blas: &CudaBlas,
4872        stream: &Arc<CudaStream>,
4873        n: usize,
4874        x: &CudaSlice<f64>,
4875    ) -> Result<f64, ArrowSchurGpuFailure> {
4876        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4877        let (x_ptr, _x_rec) = x.device_ptr(stream);
4878        let mut result = 0.0_f64;
4879        // SAFETY: x is a live device allocation on this stream with at least n
4880        // contiguous f64 entries and unit stride; result is a valid host
4881        // out-pointer for the cuBLAS scalar.
4882        let status = unsafe {
4883            cudarc::cublas::sys::cublasDnrm2_v2(
4884                *blas.handle(),
4885                n_i,
4886                x_ptr as *const f64,
4887                1,
4888                &mut result,
4889            )
4890        };
4891        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4892            Ok(result)
4893        } else {
4894            Err(ArrowSchurGpuFailure::Unavailable)
4895        }
4896    }
4897
4898    #[cfg(test)]
4899    mod tests {
4900        //! #1551 device-side framed-matvec triage. Lives inside `mod cuda` so it
4901        //! can call the private kernel launchers directly (no test-only public
4902        //! seam, which the ban-scanner forbids). A bare `#[cfg(test)] mod tests`
4903        //! is the one form the scanner permits.
4904        use super::*;
4905        use crate::arrow_schur::{
4906            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4907            FactoredFrameGBlock,
4908        };
4909        use ndarray::Array2;
4910
4911        /// Run the framed reduced-Schur matvec `out = S·x` ONCE on the device
4912        /// (no PCG, no offload gate) and return `out`.
4913        fn device_matvec_once(
4914            sys: &ArrowSchurSystem,
4915            data: &DeviceSaePcgData,
4916            ridge_t: f64,
4917            ridge_beta: f64,
4918            x_host: &[f64],
4919        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4920            let k = x_host.len();
4921            let frame = data
4922                .frame
4923                .as_ref()
4924                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4925            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4926                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4927            let ctx =
4928                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4929                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4930            let stream = ctx
4931                .new_stream()
4932                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4933            let vector_module = pcg_vector_module(&ctx)?;
4934            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4935            let x_dev = stream
4936                .clone_htod(x_host)
4937                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4938            let mut out_dev = stream
4939                .alloc_zeros::<f64>(k)
4940                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4941            launch_sae_frame_matvec(
4942                &stream,
4943                vector_module,
4944                &mut buffers,
4945                &x_dev,
4946                &mut out_dev,
4947                ridge_beta,
4948            )?;
4949            stream
4950                .clone_dtoh(&out_dev)
4951                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4952        }
4953
4954        /// #1551 stage-isolating matvec triage on a TINY hand-verifiable fixture:
4955        /// diff the device framed matvec `S·e_col` against the CPU oracle
4956        /// `sae_framed_schur_matvec_cpu` for every identity column, reporting the
4957        /// worst-divergent border index so the structural 91% localizes to one
4958        /// kernel stage. Skips cleanly off-device.
4959        #[test]
4960        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4961            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4962                return;
4963            }
4964            let p = 3usize;
4965            let ranks = vec![2usize, 3usize];
4966            let basis_sizes = vec![2usize, 2usize];
4967            let mut border_offsets = Vec::new();
4968            let mut acc = 0usize;
4969            for k in 0..2 {
4970                border_offsets.push(acc);
4971                acc += basis_sizes[k] * ranks[k];
4972            }
4973            let border_dim = acc; // 2·2 + 2·3 = 10
4974            let frame_of = |k: usize| -> Array2<f64> {
4975                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4976                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4977                })
4978            };
4979            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4980            let w_of = |i: usize, j: usize| -> Array2<f64> {
4981                let (ui, uj) = (&frames[i], &frames[j]);
4982                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4983                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4984                })
4985            };
4986            let mut frame_blocks = Vec::new();
4987            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4988                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4989                let mut g =
4990                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
4991                if i == j {
4992                    for r in 0..mi.min(mj) {
4993                        g[[r, r]] += mi as f64 + 2.0;
4994                    }
4995                }
4996                frame_blocks.push(FactoredFrameGBlock {
4997                    atom_i: i,
4998                    atom_j: j,
4999                    g,
5000                    w: w_of(i, j),
5001                });
5002            }
5003            let mut smooth_blocks = Vec::new();
5004            for k in 0..2 {
5005                let m = basis_sizes[k];
5006                let mut s =
5007                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5008                for r in 0..m {
5009                    s[[r, r]] += 1.0;
5010                }
5011                smooth_blocks.push(DeviceSaeSmoothBlock {
5012                    global_offset: border_offsets[k],
5013                    factor_a: s,
5014                });
5015            }
5016            let smooth_ranks = ranks.clone();
5017            let n = 2usize;
5018            let q = 2usize;
5019            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5020            let mut row_htbeta = Vec::new();
5021            for i in 0..n {
5022                let mut htt =
5023                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5024                for r in 0..q {
5025                    htt[[r, r]] += q as f64 + 2.0;
5026                }
5027                sys.rows[i].htt = htt;
5028                let mut slab = vec![0.0_f64; q * border_dim];
5029                for c in 0..q {
5030                    for col in 0..border_dim {
5031                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5032                        slab[c * border_dim + col] = v;
5033                        sys.rows[i].htbeta[[c, col]] = v;
5034                    }
5035                }
5036                row_htbeta.push(slab);
5037            }
5038            let data = DeviceSaePcgData {
5039                p,
5040                beta_dim: border_dim,
5041                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5042                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5043                smooth_blocks,
5044                sparse_g_blocks: Vec::new(),
5045                frame: Some(DeviceSaeFrameData {
5046                    ranks,
5047                    basis_sizes,
5048                    border_offsets,
5049                    frame_blocks,
5050                    smooth_ranks,
5051                    row_htbeta,
5052                }),
5053            };
5054            let ridge_t = 1e-7;
5055            let ridge_beta = 1e-6;
5056            let mut first_bad: Option<usize> = None;
5057            let mut worst = 0.0_f64;
5058            let mut worst_at = 0usize;
5059            let mut worst_dev = 0.0_f64;
5060            let mut worst_cpu = 0.0_f64;
5061            for col in 0..border_dim {
5062                let mut x = vec![0.0_f64; border_dim];
5063                x[col] = 1.0;
5064                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5065                    Ok(v) => v,
5066                    Err(_) => return,
5067                };
5068                let mut cpu = vec![0.0_f64; border_dim];
5069                super::super::sae_framed_schur_matvec_cpu(
5070                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5071                )
5072                .expect("cpu matvec");
5073                for r in 0..border_dim {
5074                    let d = (dev[r] - cpu[r]).abs();
5075                    if d > 1e-9 && first_bad.is_none() {
5076                        first_bad = Some(r * border_dim + col);
5077                    }
5078                    if d > worst {
5079                        worst = d;
5080                        worst_at = r * border_dim + col;
5081                        worst_dev = dev[r];
5082                        worst_cpu = cpu[r];
5083                    }
5084                }
5085            }
5086            assert!(
5087                worst <= 1e-9,
5088                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5089                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5090                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5091                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5092                 G⊗W=cross, reduced-Schur=dense per-row)",
5093            );
5094        }
5095    }
5096}
5097
5098#[cfg(test)]
5099mod tests {
5100    use super::*;
5101    use crate::arrow_schur::ArrowSchurSystem;
5102    use ndarray::{Array2, ArrayView1};
5103
5104    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5105        let mut sys = ArrowSchurSystem::new(n, d, k);
5106        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5107        let mut sample = || -> f64 {
5108            state = state
5109                .wrapping_mul(6364136223846793005)
5110                .wrapping_add(1442695040888963407);
5111            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5112        };
5113        for row in &mut sys.rows {
5114            let mut a = Array2::<f64>::zeros((d, d));
5115            for r in 0..d {
5116                for c in 0..d {
5117                    a[[r, c]] = sample();
5118                }
5119            }
5120            let mut htt = a.t().dot(&a);
5121            for r in 0..d {
5122                htt[[r, r]] += d as f64 + 1.0;
5123            }
5124            row.htt = htt;
5125            for r in 0..d {
5126                for c in 0..k {
5127                    row.htbeta[[r, c]] = 0.1 * sample();
5128                }
5129                row.gt[r] = sample();
5130            }
5131        }
5132        let mut hbb_a = Array2::<f64>::zeros((k, k));
5133        for r in 0..k {
5134            for c in 0..k {
5135                hbb_a[[r, c]] = sample();
5136            }
5137        }
5138        let mut hbb = hbb_a.t().dot(&hbb_a);
5139        for r in 0..k {
5140            hbb[[r, r]] += k as f64 + 1.0;
5141        }
5142        sys.hbb = hbb;
5143        for r in 0..k {
5144            sys.gb[r] = sample();
5145        }
5146        sys
5147    }
5148
5149    /// The Gershgorin ridge bump must actually make a known-indefinite block PD
5150    /// on the first retry — the whole point of #1711. Verified directly by
5151    /// re-factoring `H_tt + (ridge_t + bump)·I` with the same Cholesky guard the
5152    /// device readback uses, for blocks whose `λ_min` is known in closed form.
5153    #[test]
5154    fn ridge_bump_makes_known_indefinite_blocks_pd() {
5155        // `cholesky_factor_in_place` / `CholeskyGuard` are already in scope via
5156        // `super::*` (imported at the top of the module).
5157        // A few blocks with a CLOSED-FORM smallest eigenvalue, all at ridge_t=0.
5158        // (label, matrix, λ_min) — the bump must clear each one.
5159        let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); // λ_min = -1
5160        let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); // λ_min = -250
5161        // Symmetric 2×2 [[1, 2], [2, 1]] has eigenvalues 3 and -1 → indefinite.
5162        let mut indef2 = Array2::<f64>::zeros((2, 2));
5163        indef2[[0, 0]] = 1.0;
5164        indef2[[1, 1]] = 1.0;
5165        indef2[[0, 1]] = 2.0;
5166        indef2[[1, 0]] = 2.0;
5167        // A genuinely PD block must get a bump that is the bare rounding margin
5168        // only (deficit 0), and must still factor — the helper is defensive.
5169        let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5170
5171        for (label, block) in [
5172            ("-I (λ_min=-1)", neg_identity),
5173            ("-250·I (λ_min=-250)", scaled_neg),
5174            ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5175            ("5·I (PD)", pd),
5176        ] {
5177            let ridge_t = 0.0;
5178            let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5179            assert!(
5180                bump > 0.0 && bump.is_finite(),
5181                "[{label}] bump must be strictly positive and finite, got {bump:e}"
5182            );
5183            let d = block.nrows();
5184            let mut shifted = block.clone();
5185            for i in 0..d {
5186                shifted[[i, i]] += ridge_t + bump;
5187            }
5188            assert!(
5189                cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5190                "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5191                 Gershgorin bump, but the Cholesky still rejected it"
5192            );
5193        }
5194    }
5195
5196    /// The column-major variant (multi-GPU tile path) must agree with the
5197    /// row-major helper for a symmetric block, since Gershgorin edges are
5198    /// invariant under reading the symmetric matrix by row vs by column. The
5199    /// colmajor variant takes the bound at ridge_t=0 (the ridge is already baked
5200    /// into the diagonal it reads), so compare against `ridge_bump_to_make_pd`
5201    /// with `ridge_t = 0`.
5202    #[test]
5203    fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5204        // Symmetric 3×3 with a negative-definite-ish diagonal and off-diagonals.
5205        let mut a = Array2::<f64>::zeros((3, 3));
5206        a[[0, 0]] = -2.0;
5207        a[[1, 1]] = 0.5;
5208        a[[2, 2]] = 1.0;
5209        a[[0, 1]] = 0.3;
5210        a[[1, 0]] = 0.3;
5211        a[[1, 2]] = -0.4;
5212        a[[2, 1]] = -0.4;
5213        a[[0, 2]] = 0.1;
5214        a[[2, 0]] = 0.1;
5215
5216        let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5217
5218        // Flatten column-major: block[c*d + r] = a[[r, c]].
5219        let d = 3;
5220        let mut col_major = vec![0.0_f64; d * d];
5221        for c in 0..d {
5222            for r in 0..d {
5223                col_major[c * d + r] = a[[r, c]];
5224            }
5225        }
5226        let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5227
5228        assert!(
5229            (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5230            "colmajor bump {col_major_bump:e} must match rowmajor bump \
5231             {row_major_bump:e} for a symmetric block"
5232        );
5233
5234        // And the bump must actually make it PD (sanity, same as the row-major test).
5235        let mut shifted = a.clone();
5236        for i in 0..d {
5237            shifted[[i, i]] += col_major_bump;
5238        }
5239        assert!(
5240            cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5241            "colmajor Gershgorin bump must make the symmetric block PD"
5242        );
5243    }
5244
5245    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5246        let mut s = Array2::<f64>::zeros((k, k));
5247        for row in 0..k {
5248            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5249            if row + 1 < k {
5250                s[[row, row + 1]] = -0.05;
5251                s[[row + 1, row]] = -0.05;
5252            }
5253            if row + 7 < k {
5254                s[[row, row + 7]] = 0.01;
5255                s[[row + 7, row]] = 0.01;
5256            }
5257        }
5258        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5259        (s, rhs)
5260    }
5261
5262    fn dense_pcg_cpu_reference(
5263        s: &Array2<f64>,
5264        rhs: &Array1<f64>,
5265        max_iterations: usize,
5266        relative_tolerance: f64,
5267    ) -> Array1<f64> {
5268        let k = rhs.len();
5269        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5270        if rhs_norm == 0.0 {
5271            return Array1::<f64>::zeros(k);
5272        }
5273        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5274        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5275        let mut x = Array1::<f64>::zeros(k);
5276        let mut r = rhs.clone();
5277        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5278        let mut p = z.clone();
5279        let mut sp = Array1::<f64>::zeros(k);
5280        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5281        for _ in 0..max_iterations.max(1) {
5282            for row in 0..k {
5283                let mut acc = 0.0;
5284                for col in 0..k {
5285                    acc += s[[row, col]] * p[col];
5286                }
5287                sp[row] = acc;
5288            }
5289            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5290            let alpha = rz / p_sp;
5291            for idx in 0..k {
5292                x[idx] += alpha * p[idx];
5293                r[idx] -= alpha * sp[idx];
5294            }
5295            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5296            if r_norm <= tol {
5297                break;
5298            }
5299            for idx in 0..k {
5300                z[idx] = inv_diag[idx] * r[idx];
5301            }
5302            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5303            let beta = rz_next / rz;
5304            for idx in 0..k {
5305                p[idx] = z[idx] + beta * p[idx];
5306            }
5307            rz = rz_next;
5308        }
5309        x
5310    }
5311
5312    #[test]
5313    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5314        let (s, rhs) = device_pcg_fixture(512);
5315        let max_iterations = 200usize;
5316        let relative_tolerance = 1.0e-12;
5317        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5318        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5319            &s,
5320            &rhs,
5321            max_iterations,
5322            relative_tolerance,
5323        ) {
5324            Ok(result) => result,
5325            // #1017 — fail loud, never skip-pass: this fixture clears the device
5326            // offload floor, so a CUDA device that is PRESENT yet declines/returns
5327            // Err means the device PCG kernel does not run on GPU (a real fault that
5328            // must not masquerade as a pass via this skip). Legit skip ONLY when no
5329            // usable CUDA device exists (CPU CI). The exact `ArrowSchurGpuFailure`
5330            // variant is folded into the assert message as the diagnostic.
5331            Err(failure) => {
5332                assert!(
5333                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5334                    "#1017: CUDA device present but the device reduced-beta PCG \
5335                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5336                     the kernel does not run correctly on GPU"
5337                );
5338                return;
5339            }
5340        };
5341        let max_err = cpu
5342            .iter()
5343            .zip(device.iter())
5344            .map(|(a, b)| (a - b).abs())
5345            .fold(0.0_f64, f64::max);
5346        assert!(
5347            max_err <= 1.0e-10,
5348            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5349        );
5350        assert!(diag.matvec_calls > 0);
5351        assert_eq!(diag.matvec_calls, diag.iterations);
5352    }
5353
5354    #[test]
5355    fn dense_reference_matches_independent_solve() {
5356        let sys = build_fixture(4, 5, 3, 7);
5357        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5358        // Re-solve by an independent matrix build and a textbook
5359        // Gaussian-elimination Cholesky to guard against typos in the
5360        // reference implementation itself.
5361        let n = sys.rows.len();
5362        let d = sys.d;
5363        let k = sys.k;
5364        let total = n * d + k;
5365        let mut h = Array2::<f64>::zeros((total, total));
5366        let mut g = ndarray::Array1::<f64>::zeros(total);
5367        for (i, row) in sys.rows.iter().enumerate() {
5368            let base = i * d;
5369            for c in 0..d {
5370                for r in 0..d {
5371                    h[[base + r, base + c]] = row.htt[[r, c]];
5372                }
5373            }
5374            for c in 0..k {
5375                for r in 0..d {
5376                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5377                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5378                }
5379            }
5380            for r in 0..d {
5381                g[base + r] = row.gt[r];
5382            }
5383        }
5384        for c in 0..k {
5385            for r in 0..k {
5386                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5387            }
5388            g[n * d + c] = sys.gb[c];
5389        }
5390        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5391        let rhs = g.mapv(|v| -v);
5392        let expected = cholesky_solve_vector(l.view(), rhs.view());
5393        for i in 0..n * d {
5394            assert!(
5395                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5396                "delta_t[{i}] mismatch: got {} expected {}",
5397                solution.delta_t[i],
5398                expected[i]
5399            );
5400        }
5401        for a in 0..k {
5402            assert!(
5403                (solution.delta_beta[a] - expected[n * d + a]).abs()
5404                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
5405                "delta_beta[{a}] mismatch"
5406            );
5407        }
5408    }
5409
5410    /// #1017: the row-procedural reduced-Schur matvec (the matrix-free SAE
5411    /// host backend) auto-fans its per-row point-elimination sum across rayon
5412    /// over fixed row chunks when at the top level (`n ≥
5413    /// SCHUR_MATVEC_PARALLEL_ROW_MIN`), and stays serial when already inside a
5414    /// rayon worker. The chunk-ordered fold makes the parallel result
5415    /// **deterministic** (two parallel calls are bit-identical — scheduling
5416    /// cannot change the numbers) and it agrees with the serial accumulation up
5417    /// to ULP-scale chunk reassociation (the #1017 verification gate). That
5418    /// reassociation is a genuine f64 departure from serial, so the criterion
5419    /// ranking across topology candidates is stable only up to the reassociation
5420    /// margin: a near-tie winner inside that margin can flip. This is NOT an
5421    /// exact no-move guarantee (#1211); for that, the ranking path must use the
5422    /// fixed-order serial accumulation.
5423    #[test]
5424    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5425        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5426        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; // trips the parallel path
5427        let d = 3usize;
5428        let k = 24usize;
5429        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5430        // Install a matrix-free forward/transpose pair that reads the dense
5431        // `htbeta` slabs the fixture already populated, so the procedural
5432        // backend has a well-defined operator to apply (and exercises exactly
5433        // the sparse gather/scatter the SAE Kronecker path drives).
5434        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5435        let forward_slabs = slabs.clone();
5436        let transpose_slabs = slabs;
5437        sys.set_row_htbeta_operator(
5438            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5439                let h = &forward_slabs[row];
5440                for r in 0..h.nrows() {
5441                    let mut acc = 0.0_f64;
5442                    for c in 0..h.ncols() {
5443                        acc += h[[r, c]] * x[c];
5444                    }
5445                    out[r] = acc;
5446                }
5447            },
5448            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5449                let h = &transpose_slabs[row];
5450                for r in 0..h.nrows() {
5451                    for c in 0..h.ncols() {
5452                        out[c] += h[[r, c]] * v[r];
5453                    }
5454                }
5455            },
5456        );
5457
5458        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5459            .expect("row-procedural matvec backend builds for matrix-free system");
5460        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5461
5462        // Top-level call: auto-selects the parallel chunk-fold. Run twice and
5463        // assert bit-identity — the chunk-ordered reduction must not depend on
5464        // thread scheduling.
5465        let mut out_parallel_a = Array1::<f64>::zeros(k);
5466        matvec(&x, &mut out_parallel_a);
5467        let mut out_parallel_b = Array1::<f64>::zeros(k);
5468        matvec(&x, &mut out_parallel_b);
5469        for a in 0..k {
5470            assert_eq!(
5471                out_parallel_a[a].to_bits(),
5472                out_parallel_b[a].to_bits(),
5473                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5474            );
5475        }
5476
5477        // Inside a rayon worker: auto-selects the serial path (nested-rayon
5478        // guard). `install` runs the closure on a pool thread, so
5479        // `current_thread_index()` is `Some`. The serial running sum and the
5480        // chunk-ordered parallel fold differ only by f64 reassociation.
5481        let mut out_serial = Array1::<f64>::zeros(k);
5482        rayon::ThreadPoolBuilder::new()
5483            .num_threads(2)
5484            .build()
5485            .expect("build rayon pool")
5486            .install(|| matvec(&x, &mut out_serial));
5487
5488        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5489        for a in 0..k {
5490            let diff = (out_parallel_a[a] - out_serial[a]).abs();
5491            assert!(
5492                diff <= 1e-12 * (1.0 + max_abs),
5493                "row-procedural matvec parallel vs serial diverged beyond reassociation \
5494                 at index {a}: {} vs {} (diff={diff:e})",
5495                out_parallel_a[a],
5496                out_serial[a]
5497            );
5498        }
5499    }
5500
5501    /// #1017/#1026 — the frames-engaged CPU reduced-Schur matvec
5502    /// [`sae_framed_schur_matvec_cpu`] (the bit-parity oracle the GPU kernel
5503    /// mirrors) must equal the dense reduced Schur `S = (P_ββ + ρ_β I) −
5504    /// Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)` formed by the canonical dense
5505    /// reference, on a small framed system with mixed per-atom ranks
5506    /// (`r_k < p` framed + `r_k = p` un-framed). Size-independent gate.
5507    #[test]
5508    fn framed_sae_schur_matvec_matches_dense_reference() {
5509        use crate::arrow_schur::{
5510            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5511            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5512        };
5513
5514        let p = 4usize;
5515        // Three atoms: ranks 2 (framed), 4 (un-framed), 3 (framed).
5516        let ranks = vec![2usize, 4usize, 3usize];
5517        let basis_sizes = vec![2usize, 1usize, 2usize];
5518        let n_atoms = ranks.len();
5519        let mut border_offsets = Vec::with_capacity(n_atoms);
5520        let mut acc = 0usize;
5521        for k in 0..n_atoms {
5522            border_offsets.push(acc);
5523            acc += basis_sizes[k] * ranks[k];
5524        }
5525        let border_dim = acc; // 2*2 + 1*4 + 2*3 = 14
5526
5527        let mut state = 0x1234_5678_9abc_def0u64;
5528        let mut sample = || -> f64 {
5529            state = state
5530                .wrapping_mul(6364136223846793005)
5531                .wrapping_add(1442695040888963407);
5532            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5533        };
5534
5535        // Per-atom orthonormal-ish frames U_k (p × r_k) for the W = U_iᵀU_j
5536        // factors; un-framed atom (r=p) uses U = I_p.
5537        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5538        for k in 0..n_atoms {
5539            let r = ranks[k];
5540            let mut u = Array2::<f64>::zeros((p, r));
5541            for i in 0..p {
5542                for j in 0..r {
5543                    u[[i, j]] = if r == p && i == j {
5544                        1.0
5545                    } else if r == p {
5546                        0.0
5547                    } else {
5548                        sample()
5549                    };
5550                }
5551            }
5552            frames.push(u);
5553        }
5554        let w_of = |i: usize, j: usize| -> Array2<f64> {
5555            let (ui, uj) = (&frames[i], &frames[j]);
5556            let (ri, rj) = (ranks[i], ranks[j]);
5557            let mut w = Array2::<f64>::zeros((ri, rj));
5558            for a in 0..ri {
5559                for b in 0..rj {
5560                    let mut s = 0.0;
5561                    for c in 0..p {
5562                        s += ui[[c, a]] * uj[[c, b]];
5563                    }
5564                    w[[a, b]] = s;
5565                }
5566            }
5567            w
5568        };
5569
5570        // Co-occurring data-fit blocks: all diagonal pairs + one cross (0,2).
5571        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5572        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5573        pairs.sort();
5574        for &(i, j) in &pairs {
5575            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5576            let mut g = Array2::<f64>::zeros((mi, mj));
5577            for r in 0..mi {
5578                for c in 0..mj {
5579                    g[[r, c]] = 0.3 * sample();
5580                }
5581            }
5582            // Make diagonal blocks SPD-leaning so S stays PD.
5583            if i == j {
5584                for r in 0..mi.min(mj) {
5585                    g[[r, r]] += mi as f64 + 2.0;
5586                }
5587            }
5588            frame_blocks.push(FactoredFrameGBlock {
5589                atom_i: i,
5590                atom_j: j,
5591                g,
5592                w: w_of(i, j),
5593            });
5594        }
5595
5596        // Smooth blocks λ S_k (M_k × M_k), SPD.
5597        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5598        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5599        for k in 0..n_atoms {
5600            let m = basis_sizes[k];
5601            let mut a = Array2::<f64>::zeros((m, m));
5602            for r in 0..m {
5603                for c in 0..m {
5604                    a[[r, c]] = 0.2 * sample();
5605                }
5606            }
5607            let mut s = a.t().dot(&a);
5608            for r in 0..m {
5609                s[[r, r]] += 1.0;
5610            }
5611            smooth_blocks.push(DeviceSaeSmoothBlock {
5612                global_offset: border_offsets[k],
5613                factor_a: s,
5614            });
5615            smooth_ranks.push(ranks[k]);
5616        }
5617
5618        // Build the system: n rows, dense htbeta slabs (q_i × border_dim).
5619        let n = 6usize;
5620        let q = 3usize;
5621        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5622        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5623        for i in 0..n {
5624            // SPD htt.
5625            let mut a = Array2::<f64>::zeros((q, q));
5626            for r in 0..q {
5627                for c in 0..q {
5628                    a[[r, c]] = sample();
5629                }
5630            }
5631            let mut htt = a.t().dot(&a);
5632            for r in 0..q {
5633                htt[[r, r]] += q as f64 + 1.0;
5634            }
5635            sys.rows[i].htt = htt;
5636            let mut slab = vec![0.0_f64; q * border_dim];
5637            for c in 0..q {
5638                for col in 0..border_dim {
5639                    let v = 0.15 * sample();
5640                    slab[c * border_dim + col] = v;
5641                    sys.rows[i].htbeta[[c, col]] = v;
5642                }
5643            }
5644            row_htbeta.push(slab);
5645        }
5646
5647        // Dense H_ββ from the SAME penalty ops (so the dense reference's S
5648        // matches the device penalty side exactly).
5649        let data_op =
5650            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5651                .expect("frame op");
5652        let mut hbb = data_op.to_dense();
5653        for k in 0..n_atoms {
5654            let op = IdentityRightKroneckerPenaltyOp {
5655                factor_a: smooth_blocks[k].factor_a.clone(),
5656                p: ranks[k],
5657                global_offset: border_offsets[k],
5658                k: border_dim,
5659            };
5660            let d = op.to_dense();
5661            for r in 0..border_dim {
5662                for c in 0..border_dim {
5663                    hbb[[r, c]] += d[[r, c]];
5664                }
5665            }
5666        }
5667        sys.hbb = hbb;
5668
5669        let data = DeviceSaePcgData {
5670            p,
5671            beta_dim: border_dim,
5672            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5673            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5674            smooth_blocks,
5675            sparse_g_blocks: Vec::new(),
5676            frame: Some(DeviceSaeFrameData {
5677                ranks: ranks.clone(),
5678                basis_sizes: basis_sizes.clone(),
5679                border_offsets: border_offsets.clone(),
5680                frame_blocks,
5681                smooth_ranks,
5682                row_htbeta,
5683            }),
5684        };
5685
5686        let ridge_t = 1e-7;
5687        let ridge_beta = 1e-6;
5688
5689        // Dense reference reduced Schur S (border_dim × border_dim), formed
5690        // exactly as solve_arrow_newton_step_dense_reference assembles the
5691        // bordered Hessian and eliminates the t-block.
5692        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5693        for r in 0..border_dim {
5694            for c in 0..border_dim {
5695                s_dense[[r, c]] = sys.hbb[[r, c]];
5696            }
5697            s_dense[[r, r]] += ridge_beta;
5698        }
5699        for row in &sys.rows {
5700            let mut htt = row.htt.clone();
5701            for d in 0..q {
5702                htt[[d, d]] += ridge_t;
5703            }
5704            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5705                .expect("htt PD");
5706            // Y = (htt)⁻¹ htbeta  (q × border_dim); S -= htbetaᵀ Y.
5707            let mut y = Array2::<f64>::zeros((q, border_dim));
5708            for col in 0..border_dim {
5709                let mut e = Array1::<f64>::zeros(q);
5710                for r in 0..q {
5711                    e[r] = row.htbeta[[r, col]];
5712                }
5713                let solved = cholesky_solve_vector(factor.view(), e.view());
5714                for r in 0..q {
5715                    y[[r, col]] = solved[r];
5716                }
5717            }
5718            for r in 0..border_dim {
5719                for c in 0..border_dim {
5720                    let mut acc = 0.0;
5721                    for d in 0..q {
5722                        acc += row.htbeta[[d, r]] * y[[d, c]];
5723                    }
5724                    s_dense[[r, c]] -= acc;
5725                }
5726            }
5727        }
5728
5729        // Probe vectors: compare S·x from the device-data CPU oracle vs dense S·x.
5730        let mut max_rel = 0.0_f64;
5731        for trial in 0..4 {
5732            let x: Vec<f64> = (0..border_dim)
5733                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5734                .collect();
5735            let mut got = vec![0.0_f64; border_dim];
5736            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5737                .expect("framed matvec");
5738            let mut want = vec![0.0_f64; border_dim];
5739            for r in 0..border_dim {
5740                let mut acc = 0.0;
5741                for c in 0..border_dim {
5742                    acc += s_dense[[r, c]] * x[c];
5743                }
5744                want[r] = acc;
5745            }
5746            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5747            for a in 0..border_dim {
5748                let rel = (got[a] - want[a]).abs() / scale;
5749                max_rel = max_rel.max(rel);
5750            }
5751        }
5752        assert!(
5753            max_rel <= 1e-10,
5754            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5755        );
5756    }
5757
5758    /// #1017/#1026 GPU arm: when a CUDA device admits the framed SAE PCG, its
5759    /// solved `δβ` must match the CPU dense reduced-system solve of the SAME
5760    /// framed system (size-independent — a small device validates the kernel).
5761    /// Skips cleanly (returns) when no device is available or the policy
5762    /// declines (`solve_sae_matrix_free_pcg` → `Unavailable`).
5763    #[test]
5764    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
5765        use crate::arrow_schur::{
5766            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5767            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5768        };
5769
5770        // Large enough to clear the device-offload policy floor (k ≥ 32 and
5771        // n·k·d·iters ≥ MATVEC_OFFLOAD_FLOPS_MIN) so the GPU kernel actually
5772        // runs on a device rather than the policy declining.
5773        let p = 6usize;
5774        let n_atoms = 8usize;
5775        let ranks: Vec<usize> = (0..n_atoms)
5776            .map(|k| if k % 2 == 0 { 3usize } else { p })
5777            .collect();
5778        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
5779        let mut border_offsets = Vec::with_capacity(n_atoms);
5780        let mut acc = 0usize;
5781        for k in 0..n_atoms {
5782            border_offsets.push(acc);
5783            acc += basis_sizes[k] * ranks[k];
5784        }
5785        let border_dim = acc; // Σ M_k·r_k = 4·(3·3) + 4·(3·6) = 36 + 72 = 108
5786
5787        let mut state = 0xfeed_face_dead_beefu64;
5788        let mut sample = || -> f64 {
5789            state = state
5790                .wrapping_mul(6364136223846793005)
5791                .wrapping_add(1442695040888963407);
5792            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5793        };
5794        let mut frames: Vec<Array2<f64>> = Vec::new();
5795        for k in 0..n_atoms {
5796            let r = ranks[k];
5797            let mut u = Array2::<f64>::zeros((p, r));
5798            for i in 0..p {
5799                for j in 0..r {
5800                    u[[i, j]] = if r == p && i == j {
5801                        1.0
5802                    } else if r == p {
5803                        0.0
5804                    } else {
5805                        sample()
5806                    };
5807                }
5808            }
5809            frames.push(u);
5810        }
5811        let w_of = |i: usize, j: usize| {
5812            let (ui, uj) = (&frames[i], &frames[j]);
5813            let (ri, rj) = (ranks[i], ranks[j]);
5814            let mut w = Array2::<f64>::zeros((ri, rj));
5815            for a in 0..ri {
5816                for b in 0..rj {
5817                    let mut s = 0.0;
5818                    for c in 0..p {
5819                        s += ui[[c, a]] * uj[[c, b]];
5820                    }
5821                    w[[a, b]] = s;
5822                }
5823            }
5824            w
5825        };
5826        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
5827        // A few off-diagonal cross blocks (symmetric pairs).
5828        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
5829            pairs.push((i, j));
5830            pairs.push((j, i));
5831        }
5832        let mut frame_blocks = Vec::new();
5833        for &(i, j) in &pairs {
5834            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5835            let mut g = Array2::<f64>::zeros((mi, mj));
5836            for r in 0..mi {
5837                for c in 0..mj {
5838                    g[[r, c]] = 0.25 * sample();
5839                }
5840            }
5841            if i == j {
5842                for r in 0..mi.min(mj) {
5843                    g[[r, r]] += mi as f64 + 2.0;
5844                }
5845            }
5846            frame_blocks.push(FactoredFrameGBlock {
5847                atom_i: i,
5848                atom_j: j,
5849                g,
5850                w: w_of(i, j),
5851            });
5852        }
5853        let mut smooth_blocks = Vec::new();
5854        let mut smooth_ranks = Vec::new();
5855        for k in 0..n_atoms {
5856            let m = basis_sizes[k];
5857            let mut a = Array2::<f64>::zeros((m, m));
5858            for r in 0..m {
5859                for c in 0..m {
5860                    a[[r, c]] = 0.2 * sample();
5861                }
5862            }
5863            let mut s = a.t().dot(&a);
5864            for r in 0..m {
5865                s[[r, r]] += 1.0;
5866            }
5867            smooth_blocks.push(DeviceSaeSmoothBlock {
5868                global_offset: border_offsets[k],
5869                factor_a: s,
5870            });
5871            smooth_ranks.push(ranks[k]);
5872        }
5873        let n = 400usize;
5874        let q = 4usize;
5875        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5876        let mut row_htbeta = Vec::new();
5877        for i in 0..n {
5878            let mut a = Array2::<f64>::zeros((q, q));
5879            for r in 0..q {
5880                for c in 0..q {
5881                    a[[r, c]] = sample();
5882                }
5883            }
5884            let mut htt = a.t().dot(&a);
5885            for r in 0..q {
5886                htt[[r, r]] += q as f64 + 1.0;
5887            }
5888            sys.rows[i].htt = htt;
5889            let mut slab = vec![0.0_f64; q * border_dim];
5890            for c in 0..q {
5891                for col in 0..border_dim {
5892                    // Small entries: with 400 rows the reduced-Schur subtraction
5893                    // Σ_i H_βtᵀ H_tt⁻¹ H_tβ must not overwhelm the PD penalty.
5894                    let v = 0.02 * sample();
5895                    slab[c * border_dim + col] = v;
5896                    sys.rows[i].htbeta[[c, col]] = v;
5897                }
5898            }
5899            row_htbeta.push(slab);
5900        }
5901        let data_op =
5902            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5903                .expect("frame op");
5904        let mut hbb = data_op.to_dense();
5905        for k in 0..n_atoms {
5906            let op = IdentityRightKroneckerPenaltyOp {
5907                factor_a: smooth_blocks[k].factor_a.clone(),
5908                p: ranks[k],
5909                global_offset: border_offsets[k],
5910                k: border_dim,
5911            };
5912            let d = op.to_dense();
5913            for r in 0..border_dim {
5914                for c in 0..border_dim {
5915                    hbb[[r, c]] += d[[r, c]];
5916                }
5917            }
5918        }
5919        sys.hbb = hbb;
5920        let data = DeviceSaePcgData {
5921            p,
5922            beta_dim: border_dim,
5923            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5924            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5925            smooth_blocks,
5926            sparse_g_blocks: Vec::new(),
5927            frame: Some(DeviceSaeFrameData {
5928                ranks: ranks.clone(),
5929                basis_sizes: basis_sizes.clone(),
5930                border_offsets: border_offsets.clone(),
5931                frame_blocks,
5932                smooth_ranks,
5933                row_htbeta,
5934            }),
5935        };
5936        let ridge_t = 1e-7;
5937        let ridge_beta = 1e-6;
5938        let rhs: Array1<f64> =
5939            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
5940
5941        let (device, diag) =
5942            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
5943                Ok(result) => result,
5944                // #1017 — fail loud, never skip-pass: this fixture clears the device
5945                // offload floor, so a CUDA device that is PRESENT yet declines means the
5946                // framed device PCG kernel does not run on GPU (the fault must not pass
5947                // silently). Legit skip ONLY when no usable CUDA device exists (CPU CI).
5948                // The exact `ArrowSchurGpuFailure` variant is folded into the assert.
5949                Err(failure) => {
5950                    assert!(
5951                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5952                        "#1017: CUDA device present but the framed device SAE PCG \
5953                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5954                     the kernel does not run correctly on GPU"
5955                    );
5956                    return;
5957                }
5958            };
5959
5960        // #1551 PARITY GATE — operator-residual, NOT solution-vector equality.
5961        //
5962        // The honest GPU↔CPU contract for an iterative solve of `S·δβ = rhs` is
5963        // that the device solution SOLVES the system DEFINED BY THE CPU ORACLE to
5964        // PCG tolerance — i.e. `‖S_cpu·δβ_device − rhs‖ / ‖rhs‖ ≤ tol`, where
5965        // `S_cpu` is applied with the bit-for-bit CPU oracle matvec the device
5966        // kernel mirrors (`sae_framed_schur_matvec_cpu`). This is the correct,
5967        // conditioning-robust gate: a near-singular assembled `S` has a large
5968        // condition number `κ(S)`, which amplifies an O(ε) residual difference
5969        // into an O(κ·ε) *solution-vector* difference, so comparing δβ vectors
5970        // would spuriously fail even when both operators are bit-identical and
5971        // both solves converged. (Historically this test compared δβ against a
5972        // dense-Cholesky reference and "failed" with max_rel≈0.9 because the
5973        // dense solve itself only reached ‖S·x−rhs‖≈0.1 on this fixture's
5974        // ill-conditioned S while the device PCG reached ~1e-12 — the device was
5975        // MORE accurate than the reference, not wrong. The kernel correctness is
5976        // pinned conditioning-free by `framed_sae_device_matvec_matches_cpu_oracle_*`.)
5977        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5978        let oracle_resid = |x: &Array1<f64>| -> f64 {
5979            let mut sx = vec![0.0_f64; border_dim];
5980            sae_framed_schur_matvec_cpu(
5981                &sys,
5982                &data,
5983                ridge_t,
5984                ridge_beta,
5985                x.as_slice().unwrap(),
5986                &mut sx,
5987            )
5988            .expect("cpu oracle matvec");
5989            let mut acc = 0.0_f64;
5990            for a in 0..border_dim {
5991                let e = sx[a] - rhs[a];
5992                acc += e * e;
5993            }
5994            acc.sqrt()
5995        };
5996        let s_dev_resid = oracle_resid(&device);
5997        let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
5998
5999        // Independent CPU iterative solve of the SAME operator with the SAME
6000        // Jacobi preconditioner the device builds, via the shared `pcg_core`. If
6001        // the device kernel computed a different operator, the two converged
6002        // residuals could not BOTH be tiny.
6003        let precond = {
6004            let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6005            // The reduced-Schur diagonal subtraction (device `arrow_sae_frame_diag_sub`)
6006            // mirrored on the host for the Jacobi preconditioner.
6007            let mut diag = d;
6008            for (i, row) in sys.rows.iter().enumerate() {
6009                let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6010                let qi = sys.row_dims[i];
6011                if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6012                    continue;
6013                }
6014                let mut block = row.htt.clone();
6015                for dd in 0..qi {
6016                    block[[dd, dd]] += ridge_t;
6017                }
6018                let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6019                    .expect("row htt PD");
6020                // ainv = (H_tt+ρI)⁻¹ column by column.
6021                let mut ainv = Array2::<f64>::zeros((qi, qi));
6022                for col in 0..qi {
6023                    let mut e = Array1::<f64>::zeros(qi);
6024                    e[col] = 1.0;
6025                    let s = cholesky_solve_vector(factor.view(), e.view());
6026                    for r in 0..qi {
6027                        ainv[[r, col]] = s[r];
6028                    }
6029                }
6030                for a in 0..border_dim {
6031                    let mut quad = 0.0_f64;
6032                    for c in 0..qi {
6033                        let hc = slab[c * border_dim + a];
6034                        for dd in 0..qi {
6035                            quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6036                        }
6037                    }
6038                    diag[a] -= quad;
6039                }
6040            }
6041            Array1::from_vec(diag)
6042        };
6043        let mut cpu = Array1::<f64>::zeros(border_dim);
6044        let cpu_result = {
6045            let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6046                let mut tmp = vec![0.0_f64; border_dim];
6047                sae_framed_schur_matvec_cpu(
6048                    &sys,
6049                    &data,
6050                    ridge_t,
6051                    ridge_beta,
6052                    v.as_slice().unwrap(),
6053                    &mut tmp,
6054                )
6055                .expect("cpu oracle matvec");
6056                out.assign(&Array1::from_vec(tmp));
6057            };
6058            gam_linalg::pcg::pcg_core(
6059                &mut apply,
6060                &rhs.view(),
6061                &precond.view(),
6062                1e-12,
6063                800,
6064                32,
6065                false,
6066                gam_linalg::pcg::DotReduction::Serial,
6067                &mut cpu.view_mut(),
6068            )
6069        };
6070        let s_cpu_resid = oracle_resid(&cpu);
6071        let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6072
6073        // GATE 1: the device solution solves the CPU-oracle system to PCG-grade
6074        // accuracy (proves device kernel == CPU operator AND device PCG converged).
6075        assert!(
6076            dev_rel_resid <= 1e-7,
6077            "[#1551] device δβ does not solve the CPU-oracle system: \
6078             ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6079             device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6080             means the device matvec is a DIFFERENT operator (kernel bug)",
6081            diag.stopping_reason,
6082            diag.iterations,
6083            diag.final_relative_residual,
6084        );
6085        // GATE 2: the independent CPU iterative solve of the same operator with the
6086        // same preconditioner also converges — both paths agree on the operator.
6087        assert!(
6088            cpu_rel_resid <= 1e-6,
6089            "[#1551] CPU pcg_core failed to solve the oracle system: \
6090             ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6091            cpu_result.stop,
6092            cpu_result.iterations,
6093        );
6094    }
6095
6096    /// Host mirror of the device `sae_frame_penalty_diag_host` for the framed
6097    /// Jacobi preconditioner (penalty diagonal only; the reduced-Schur diagonal
6098    /// subtraction is applied by the caller). Test-only.
6099    fn sae_frame_penalty_diag_host_for_test(
6100        data: &DeviceSaePcgData,
6101        ridge_beta: f64,
6102    ) -> Vec<f64> {
6103        let frame = data.frame.as_ref().expect("frame");
6104        let mut diag = vec![ridge_beta; data.beta_dim];
6105        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6106            let m = blk.factor_a.nrows();
6107            for ia in 0..m {
6108                let coeff = blk.factor_a[[ia, ia]];
6109                let base = blk.global_offset + ia * r;
6110                for ib in 0..r {
6111                    diag[base + ib] += coeff;
6112                }
6113            }
6114        }
6115        for blk in &frame.frame_blocks {
6116            if blk.atom_i != blk.atom_j {
6117                continue;
6118            }
6119            let r = frame.ranks[blk.atom_i];
6120            let off = frame.border_offsets[blk.atom_i];
6121            let (mi, mj) = blk.g.dim();
6122            for li in 0..mi.min(mj) {
6123                let gii = blk.g[[li, li]];
6124                let base = off + li * r;
6125                for a in 0..r {
6126                    diag[base + a] += gii * blk.w[[a, a]];
6127                }
6128            }
6129        }
6130        diag
6131    }
6132
    /// #1551 kernel-correctness parity gate: the device framed reduced-Schur
    /// matvec `out = S·x` (via `framed_schur_matvec_once_on_device`) must agree
    /// element-wise with the CPU oracle [`sae_framed_schur_matvec_cpu`] for
    /// several independent probes `x`. For every component `a` of every probe,
    /// `|device[a] − cpu[a]| / max(max_b |cpu[b]|, 1) ≤ 1e-9` is required.
    ///
    /// This gate localizes a kernel or marshalling defect directly. Unlike a
    /// solved-`δβ` comparison, it does NOT route through a linear solve, so it
    /// is independent of the conditioning of the assembled `S`. A near-singular
    /// `S` can make a dense-Cholesky vector and an iterative-PCG vector disagree
    /// at the *solution* level even when both operators are bit-correct; the
    /// operator itself must still match here.
    ///
    /// If the device call returns `Err` while a CUDA runtime is present, the
    /// test fails loud. It returns early (skips) only when
    /// `GpuRuntime::global()` is `None`.
    #[test]
    fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
        use crate::arrow_schur::{
            DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
        };

        // Small deterministic frame fixture: a mix of framed (r<p) and
        // identity-ride (r==p) atoms, a few off-diagonal cross blocks, and dense
        // per-row H_tβ. border_dim = 4·(3·3) + 4·(3·6) = 108.
        let p = 6usize;
        let n_atoms = 8usize;
        let ranks: Vec<usize> = (0..n_atoms)
            .map(|k| if k % 2 == 0 { 3usize } else { p })
            .collect();
        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;

        // Deterministic LCG. NOTE: `state >> 33` keeps 31 bits, so every draw
        // lands in [-1, 0), not [-1, 1). The ORDER of `sample()` calls below
        // fixes every fixture value, so statements that draw must not be reordered.
        let mut state = 0x1551_0017_1026_0922u64;
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };
        // Per-atom frames U_k (p × r_k): the identity when r_k == p, random otherwise.
        let mut frames: Vec<Array2<f64>> = Vec::new();
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p && i == j {
                        1.0
                    } else if r == p {
                        0.0
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // W_ij = U_iᵀ U_j (r_i × r_j).
        let w_of = |i: usize, j: usize| {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };
        // All self pairs, plus three symmetric off-diagonal cross pairs.
        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
            pairs.push((i, j));
            pairs.push((j, i));
        }
        let mut frame_blocks = Vec::new();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.25 * sample();
                }
            }
            if i == j {
                // Diagonal boost on self blocks.
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }
        // Smooth penalty factors A_k = BᵀB + I.
        let mut smooth_blocks = Vec::new();
        let mut smooth_ranks = Vec::new();
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }
        // Modest row count: this seam bypasses the offload floor, so we keep the
        // fixture small and the per-row reduced-Schur term well-scaled — the
        // matvec parity does not depend on the assembled-S conditioning at all.
        let n = 32usize;
        let q = 4usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta = Vec::new();
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            // Row H_tβ is stored twice: row-major in the device slab
            // (`slab[c * border_dim + col]`) and in `sys.rows[i].htbeta`.
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.3 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }
        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;
        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        // Several independent probe vectors x: a smooth sinusoid, a dense
        // random vector (more `sample()` draws), and three unit axes (first,
        // one-third, last). A marshalling stride/offset bug shows up as a
        // per-component mismatch on at least one.
        let mut probes: Vec<Array1<f64>> = Vec::new();
        probes.push(Array1::from_shape_fn(border_dim, |a| {
            ((a as f64 + 1.0) * 0.37).sin()
        }));
        probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
        for axis in [0usize, border_dim / 3, border_dim - 1] {
            let mut e = Array1::<f64>::zeros(border_dim);
            e[axis] = 1.0;
            probes.push(e);
        }

        // `any_ran` records that at least one probe produced a device result.
        // `worst` is the largest relative component error seen over all probes.
        let mut any_ran = false;
        let mut worst = 0.0_f64;
        for (pi, x) in probes.iter().enumerate() {
            let device = match super::framed_schur_matvec_once_on_device(
                &sys, &data, ridge_t, ridge_beta, x,
            ) {
                Ok(out) => out,
                Err(failure) => {
                    // Fail loud: a present CUDA device that declines this seam
                    // (which deliberately ignores the offload floor) means the
                    // framed matvec kernel does not run on GPU.
                    assert!(
                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                        "#1551: CUDA device present but the framed device matvec \
                         declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
                         does not run on GPU"
                    );
                    return;
                }
            };
            any_ran = true;
            let mut cpu = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, x.as_slice().unwrap(), &mut cpu)
                .expect("cpu oracle matvec");
            // Per-probe scale: the largest |cpu| component, floored at 1.0.
            let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (device[a] - cpu[a]).abs() / scale;
                worst = worst.max(rel);
                assert!(
                    rel <= 1e-9,
                    "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
                     rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
                    device[a],
                    cpu[a],
                );
            }
        }
        if any_ran {
            // Positive on-device confirmation: the framed matvec ran on the GPU
            // and matched the CPU oracle across every probe. 1e-9 is far above
            // fp64 GEMV round-off, while a structural marshalling bug would be
            // O(1). The per-component assert above already enforces the bound;
            // the `worst` assert below is a summary of the same check.
            assert!(
                gam_gpu::device_runtime::GpuRuntime::global().is_some(),
                "#1551: matvec ran but no GPU runtime — unexpected"
            );
            assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
        }
    }
6361}