Skip to main content

gam_solve/gpu_kernels/
arrow_schur.rs

1//! Fully GPU-resident batched Arrow-Schur dense Cholesky solver.
2//!
3//! Implements the square-root Schur form: each local block `D_i = L_i L_i^T`
4//! is factored on device, `u_i = L_i^{-1} g_i` and `Y_i = L_i^{-1} B_i` are
5//! formed by triangular solves, the reduced shared system
6//!     `S_β = C + ρ_β I − Σ_i Y_i^T Y_i,  r_β = -g_β + Σ_i Y_i^T u_i`
7//! is assembled on device, factored once, and the back-substitution
8//!     `w_i = u_i + Y_i · δβ,  L_i^T x_i = w_i,  δt_i = -x_i`
9//! is run on device. Only the final `(δt, δβ, log|H|)` triple is downloaded.
10//!
//! The current caller (Arrow-Schur Newton step inside PIRLS) feeds a uniform
//! local block size `d` and a uniform shared width `k`, so the entire pipeline
//! is dispatched as a single p-group; per-p grouping for heterogeneous blocks
//! is the concern of Layer D's fused NVRTC kernel and belongs in this module's
//! follow-up implementation rather than in policy plumbing.
//!
//! On non-Linux builds the dense device entry points validate their inputs and
//! then decline with `ArrowSchurGpuFailure::Unavailable`, so callers fall back
//! to the CPU path; the matrix-free row-procedural Schur matvec is host-side
//! and remains available.
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24/// Outcome of a single Arrow-Schur Newton solve.
25pub struct ArrowSchurGpuSolution {
26    pub delta_t: Array1<f64>,
27    pub delta_beta: Array1<f64>,
28    /// Natural log of the determinant of the full bordered Hessian, computed
29    /// from the local Cholesky factors and the Schur factor on device.
30    pub log_det_hessian: f64,
31}
32
/// Why a device path declined to produce a solution, so the host caller can
/// choose between a CPU fallback and per-row ridge escalation.
///
/// Only [`ArrowSchurGpuFailure::RidgeBumpRequired`] is a "retry with more
/// regularisation" signal; it carries the estimated diagonal bump that should
/// clear the failed pivot.
#[derive(Clone, Debug)]
pub enum ArrowSchurGpuFailure {
    /// The device path did not run: CUDA runtime missing, an allocation
    /// failed, or the workload is below the dispatch policy.
    Unavailable,
    /// Row block `row` was not positive definite even after the requested
    /// ridge; the caller may retry at `ridge_t + bump`.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// The shared Schur factor failed (or the inputs were rejected before
    /// factoring): the bordered system is rank-deficient at the requested
    /// ridges and the CPU path should handle escalation.
    SchurFactorFailed { reason: String },
    /// The system carries a matrix-free `H_ββ` and/or per-row `H_tβ` operator
    /// that the dense GPU Schur path cannot consume. Route to CPU
    /// `InexactPCG` (or supply dense buffers) instead of treating this as a
    /// numerical failure. See `gpu/arrow_schur.rs` Part B for the planned GPU
    /// PCG path that will lift this restriction at K ≥ 5000.
    GpuRequiresDenseSystem { had_hbb_matvec: bool, had_htbeta_matvec: bool },
}
56
/// Safety multiplier on `scale · √ε` that sizes the rounding margin added on
/// top of the deficit-clearing shift in [`ridge_bump_to_make_pd`].
///
/// Shifting a block by exactly `-λ_min` makes it positive definite only in
/// exact arithmetic. Forming `D + ridge·I` and re-running POTRF perturbs the
/// smallest eigenvalue again, so a retry at precisely that shift is routinely
/// rejected a second time. With this multiplier the margin is
/// `2¹⁰ · 2⁻²⁶ · scale = 2⁻¹⁶ · scale ≈ 1.5e-5 · scale`: large enough to
/// clear the pivot on the first retry, small enough not to materially perturb
/// the curvature the Newton step sees. Every per-row, batched and fused
/// producer uses the same value so they all suggest a consistent bump.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1_024.0;
68
69/// Diagonal ridge bump that is GUARANTEED to make `H_tt + (ridge_t + bump)·I`
70/// positive definite for a *symmetric* per-row block, sized from the block's
71/// own entries rather than from the factorization's pivot index.
72///
73/// # Why the old `scale · |pivot| · √ε · 1024` estimate is wrong
74///
75/// The batched/fused device paths derive the suggested bump from the
76/// factorization "pivot" — but cuSOLVER's `potrf` (and the NVRTC kernel's
77/// status code) report the failing pivot as a **1-based row index**, NOT the
78/// magnitude of the negative pivot. A block that is indefinite by `O(1)`
79/// (e.g. `H_tt = -I`, whose smallest eigenvalue is `-1`) then yields the same
80/// `bump ≈ √ε · 1024 ≈ 1.5e-5` as a block that is indefinite by `O(√ε)`. The
81/// outer LM escalation, which retries at `ridge_t + bump` and grows
82/// geometrically with a bounded step count, can never lift a strongly
83/// indefinite block out of the negative regime, so the solve fails to recover
84/// even though the block is trivially regularizable. (Surfaced by the V100
85/// `ridge_bump_required_on_non_pd_row_recovers_after_bump` validation test.)
86///
87/// # The bound
88///
89/// By the Gershgorin circle theorem every eigenvalue `λ` of the symmetric
90/// matrix `A = H_tt` satisfies, for some row `i`,
91///   `λ ≥ A[i,i] − Σ_{j≠i} |A[i,j]|`,
92/// so `λ_min(A) ≥ min_i ( A[i,i] − Σ_{j≠i} |A[i,j]| ) =: g` (the most negative
93/// Gershgorin left edge). Adding `t·I` shifts every eigenvalue up by `t`, so
94/// `A + t·I` is PD as soon as `t > -g`. We are already sitting at `ridge_t`, so
95/// the ADDITIONAL bump needed is `-(g + ridge_t)` when that is positive. We add
96/// a relative safety margin (`√ε · scale · 1024`, the same headroom the legacy
97/// estimate used) so the re-factored, rounding-perturbed block clears the pivot
98/// on the first retry, and a `max(1)`-scaled floor so a marginally-indefinite
99/// block still gets a strictly positive, non-vanishing bump.
100///
101/// The returned value is the bump to ADD to the current `ridge_t`. It is always
102/// strictly positive (the caller only constructs `RidgeBumpRequired` on an
103/// actual non-PD failure, but the bound is defensive regardless).
104#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106    let d = htt.nrows();
107    // Diagonal magnitude scale (also the legacy `scale`), and the most-negative
108    // Gershgorin left edge `g = min_i (A_ii − Σ_{j≠i} |A_ij|)`.
109    let mut scale = 1.0_f64;
110    let mut min_gershgorin_edge = f64::INFINITY;
111    for i in 0..d {
112        let diag = htt[[i, i]];
113        scale = scale.max(diag.abs());
114        let mut off_sum = 0.0_f64;
115        for j in 0..d {
116            if j != i {
117                off_sum += htt[[i, j]].abs();
118            }
119        }
120        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121    }
122    if !min_gershgorin_edge.is_finite() {
123        // d == 0 (no rows) or non-finite entries: fall back to the scale-only
124        // floor so the caller still gets a strictly positive bump.
125        return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126    }
127    // Additional shift needed so `λ_min(A) + ridge_t + bump > 0`, i.e.
128    // `bump > -(min_gershgorin_edge + ridge_t)`.
129    let deficit = -(min_gershgorin_edge + ridge_t);
130    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131    // Lift past the deficit (when positive) plus a rounding margin; never below
132    // the scale-relative floor so a marginal block still moves.
133    deficit.max(0.0) + margin
134}
135
136/// [`ridge_bump_to_make_pd`] for a `d × d` symmetric block stored column-major
137/// in a flat slice with the current ridge ALREADY baked into the diagonal
138/// (the device packers emit `D = H_tt + ridge_t·I` this way). Because the shift
139/// is already present, the Gershgorin bound is taken at `ridge_t = 0` and the
140/// returned value is still the ADDITIONAL bump to add on top of the current
141/// ridge. Returns the scale-only floor when `block` is mis-sized.
142// The column-major bump is a helper for the device tile packers, which only
143// exist on the linux CUDA path (`mod cuda`, `#[cfg(target_os = "linux")]`). It
144// therefore lives where it is used: gate it to linux. Its parity unit test is
145// gated to linux to match (a `test` token in the cfg would trip the build.rs
146// `#[cfg(test)]`-on-a-src-item ban; including `test` to dodge the non-linux
147// dead_code lint is exactly the escape hatch that ban forbids).
148#[cfg(target_os = "linux")]
149#[must_use]
150fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
151    if d == 0 || block.len() < d * d {
152        return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
153    }
154    // Column-major: element (row r, col c) at block[c*d + r]. The matrix is
155    // symmetric, so reading by column gives the same Gershgorin edges as by row.
156    let mut scale = 1.0_f64;
157    let mut min_gershgorin_edge = f64::INFINITY;
158    for i in 0..d {
159        let diag = block[i * d + i];
160        scale = scale.max(diag.abs());
161        let mut off_sum = 0.0_f64;
162        for j in 0..d {
163            if j != i {
164                off_sum += block[j * d + i].abs();
165            }
166        }
167        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
168    }
169    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
170    if !min_gershgorin_edge.is_finite() {
171        return margin;
172    }
173    (-min_gershgorin_edge).max(0.0) + margin
174}
175
176/// Entry point: attempt the fully device-resident Arrow-Schur Newton solve.
177/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` to indicate "device path
178/// declined, fall back to CPU" — never panics.
179pub fn solve_arrow_newton_step(
180    sys: &ArrowSchurSystem,
181    ridge_t: f64,
182    ridge_beta: f64,
183) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
184    let n = sys.rows.len();
185    let d = sys.d;
186    let k = sys.k;
187
188    // Detect matrix-free operators before any dim() checks so callers get a
189    // clear, actionable error instead of a generic SchurFactorFailed. The GPU
190    // dense-Schur path requires materialised H_ββ and per-row H_tβ slabs;
191    // CPU InexactPCG is the correct fallback when either operator is abstract.
192    let had_hbb_matvec = sys.hbb_matvec.is_some();
193    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
194    if had_hbb_matvec || had_htbeta_matvec {
195        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
196            had_hbb_matvec,
197            had_htbeta_matvec,
198        });
199    }
200
201    if sys.hbb.dim() != (k, k) {
202        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
203            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
204        });
205    }
206    if n == 0 || d == 0 {
207        return Err(ArrowSchurGpuFailure::Unavailable);
208    }
209    if sys
210        .rows
211        .iter()
212        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
213    {
214        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
215            reason: "row block dimension mismatch".to_string(),
216        });
217    }
218
219    #[cfg(not(target_os = "linux"))]
220    {
221        if ridge_t.is_nan() || ridge_beta.is_nan() {
222            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
223                reason: "ridge is NaN".to_string(),
224            });
225        }
226        Err(ArrowSchurGpuFailure::Unavailable)
227    }
228
229    #[cfg(target_os = "linux")]
230    {
231        // Multi-GPU: the arrow-Schur solve is row-block separable in its forward
232        // (per-row factor / whiten / partial-Schur) and backward (per-row
233        // back-sub) phases — only the small shared K×K reduce+factor+δβ is
234        // central. When more than one device is usable, split the WHOLE solve at
235        // row-block granularity across all GPUs. The POTRF stays fused with its
236        // dependent TRSM+GEMM on each tile's own stream, so no on-stream solve is
237        // orphaned. On `Unavailable` (one device, shape below policy, transient)
238        // fall through to the single-device fused / Layer-A paths below.
239        if gam_gpu::device_runtime::GpuRuntime::global()
240            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
241            .unwrap_or(0)
242            > 1
243        {
244            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
245                Ok(sol) => return Ok(sol),
246                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
247                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
248                }
249                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
250                    return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
251                }
252                // Unavailable / GpuRequiresDenseSystem: fall through to the
253                // single-device paths (already shape-validated above).
254                Err(_) => {}
255            }
256        }
257        // Layer D admission: when the system shape passes the
258        // (Σ p³ ≥ 1e5 OR R ≥ 16) heuristic and `p ≤ MAX_FUSED_P`, the fused
259        // NVRTC kernel replaces the cuSOLVER/cuBLAS Layer A+B+C path with a
260        // single per-row block. Layer C↔D parity (math block 3 §16 test 6)
261        // requires both paths to agree to 1e-10 on identical inputs.
262        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
263            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
264                Ok(sol) => return Ok(sol),
265                // RidgeBumpRequired must surface to the outer escalation loop —
266                // the fused path's pivot diagnostic is identical in semantics
267                // to the cuSOLVER batched POTRF info code.
268                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
269                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
270                }
271                // Any other failure (Unavailable, SchurFactorFailed) falls
272                // through to the unfused path so a flaky NVRTC compile or
273                // shared-mem allocation does not abort the outer Newton step.
274                Err(_) => {}
275            }
276        }
277        cuda::solve(sys, ridge_t, ridge_beta)
278    }
279}
280
281/// Build the stacked column-major D buffer (n local d×d blocks), the stacked
282/// stacked B buffer (n local d×k blocks), and the stacked g buffer
283/// (n local d-vectors) consumed by the device pipeline. Each block is laid
284/// out column-major so a single allocation + `cuMemcpyHtoD` reaches the
285/// device without per-row dispatch overhead.
286#[cfg(target_os = "linux")]
287fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
288    let n = sys.rows.len();
289    let d = sys.d;
290    let k = sys.k;
291    let mut d_buf = Vec::with_capacity(n * d * d);
292    let mut b_buf = Vec::with_capacity(n * d * k);
293    let mut g_buf = Vec::with_capacity(n * d);
294    for row in &sys.rows {
295        pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
296    }
297    (d_buf, b_buf, g_buf)
298}
299
300#[cfg(target_os = "linux")]
301#[inline]
302fn pack_block(
303    row: &crate::arrow_schur::ArrowRowBlock,
304    ridge_t: f64,
305    d: usize,
306    k: usize,
307    d_buf: &mut Vec<f64>,
308    b_buf: &mut Vec<f64>,
309    g_buf: &mut Vec<f64>,
310) {
311    for col in 0..d {
312        for r in 0..d {
313            let mut value = row.htt[[r, col]];
314            if r == col {
315                value += ridge_t;
316            }
317            d_buf.push(value);
318        }
319    }
320    for col in 0..k {
321        for r in 0..d {
322            b_buf.push(row.htbeta[[r, col]]);
323        }
324    }
325    for r in 0..d {
326        g_buf.push(row.gt[r]);
327    }
328}
329
330/// Test-only entry that forces the Layer D + E fused NVRTC path regardless
331/// of the admission heuristic. Used by the V100 Layer C↔D parity test to
332/// drive the fused kernel at small shapes the heuristic would otherwise
333/// route through the cuSOLVER/cuBLAS Layer A+B+C path.
334#[doc(hidden)]
335#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] // `sys` is consumed only by the linux branch
336pub fn solve_arrow_newton_step_fused_force(
337    sys: &ArrowSchurSystem,
338    ridge_t: f64,
339    ridge_beta: f64,
340) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
341    if ridge_t.is_nan() || ridge_beta.is_nan() {
342        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
343            reason: "ridge is NaN".to_string(),
344        });
345    }
346    #[cfg(not(target_os = "linux"))]
347    {
348        // No NVRTC toolchain off linux: the fused path is unconditionally
349        // unavailable. `sys` is consumed only by the linux branch below; the
350        // fn-level cfg_attr allows it to read as unused here without a banned
351        // `let _` binding or a no-op `drop` of the reference.
352        Err(ArrowSchurGpuFailure::Unavailable)
353    }
354    #[cfg(target_os = "linux")]
355    {
356        if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
357            .is_none()
358        {
359            return Err(ArrowSchurGpuFailure::Unavailable);
360        }
361        cuda::solve_fused(sys, ridge_t, ridge_beta)
362    }
363}
364
365/// #1017 Phase 3: a device-resident Arrow-Schur frame whose constant Hessian
366/// blocks (`D = H_tt`, `B = H_tβ`, border `H_ββ`) and their factors stay on the
367/// device across the inner Newton loop. Construct once per frozen gate/basis
368/// frame, then call [`ResidentArrowFrameHandle::solve_gradient`] once per
369/// iterate with the fresh residual gradient — only the `O(n·d + p)` gradient
370/// crosses to the device and only `δ` crosses back, in contrast to
371/// [`solve_arrow_newton_step`] which re-uploads and re-factors the full system
372/// every call. On a non-CUDA host construction returns
373/// `ArrowSchurGpuFailure::Unavailable`.
374pub struct ResidentArrowFrameHandle {
375    #[cfg(target_os = "linux")]
376    inner: cuda::ResidentArrowFrame,
377    #[cfg(not(target_os = "linux"))]
378    _never: std::convert::Infallible,
379}
380
381impl ResidentArrowFrameHandle {
382    /// Upload the constant Hessian blocks and perform the one-time factor work.
383    pub fn new(
384        sys: &ArrowSchurSystem,
385        ridge_t: f64,
386        ridge_beta: f64,
387    ) -> Result<Self, ArrowSchurGpuFailure> {
388        // The dense device path requires materialised blocks, same admission as
389        // `solve_arrow_newton_step`.
390        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
391            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
392                had_hbb_matvec: sys.hbb_matvec.is_some(),
393                had_htbeta_matvec: sys.htbeta_matvec.is_some(),
394            });
395        }
396        #[cfg(not(target_os = "linux"))]
397        {
398            if ridge_t.is_nan() || ridge_beta.is_nan() {
399                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
400                    reason: "ridge is NaN".to_string(),
401                });
402            }
403            Err(ArrowSchurGpuFailure::Unavailable)
404        }
405        #[cfg(target_os = "linux")]
406        {
407            Ok(Self {
408                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
409            })
410        }
411    }
412
413    /// Solve `H δ = −gradient` for a fresh gradient reusing the resident factors.
414    pub fn solve_gradient(
415        &self,
416        g_t: &[f64],
417        g_beta: &[f64],
418    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
419        #[cfg(not(target_os = "linux"))]
420        {
421            if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
422                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
423                    reason: "non-finite gradient entry".to_string(),
424                });
425            }
426            Err(ArrowSchurGpuFailure::Unavailable)
427        }
428        #[cfg(target_os = "linux")]
429        {
430            self.inner.solve_gradient(g_t, g_beta)
431        }
432    }
433
434    /// `log|H|` for the frame (constant; depends only on the factored Hessian).
435    #[must_use]
436    pub fn log_det_hessian(&self) -> f64 {
437        #[cfg(not(target_os = "linux"))]
438        {
439            // SAFETY: off-CUDA, `ResidentArrowFrameHandle::new` always returns
440            // `Err(Unavailable)`, so no handle of this type is ever constructed and
441            // this method is statically unreachable on non-Linux targets. A NaN
442            // sentinel would silently corrupt any consumer of the log-determinant,
443            // so fail loudly on the impossible path instead.
444            panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
445        }
446        #[cfg(target_os = "linux")]
447        {
448            self.inner.log_det_hessian()
449        }
450    }
451}
452
453/// Build a GPU-backed Schur matvec closure for CPU-driven PCG at K ≥ 5000.
454///
455/// Runs the fused NVRTC forward kernel once on the dense per-row `H_tβ` slabs
456/// to compute `Y_i = L_i^{-1} H_tβ^(i)` for all rows, persists the `Y_i`
457/// factors in a host-side buffer, and returns an `Arc<dyn Fn(...)>` closure
458/// that computes the full Schur matvec
459///
460/// ```text
461/// S·x = (H_ββ + ridge_beta·I)·x  −  Σ_i Y_i^T (Y_i·x)
462/// ```
463///
464/// each time it is called. At K ≥ 5000 the `Σ_i Y_i^T (Y_i·x)` term
465/// dominates over the host↔device transfer of the K-vector `x`, so the GPU
466/// path is a clear win even with per-iteration transfer.
467///
468/// `H_ββ·x` is evaluated on the CPU using `sys.hbb_matvec` when present (the
469/// matrix-free hook for SAE-manifold scale callers) or the dense `sys.hbb`
470/// block otherwise. The `Y_i` term uses cuBLAS batched GEMV device-side; only
471/// `x` (K doubles) and `out` (K doubles) cross the host↔device boundary per
472/// PCG iteration.
473///
474/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` if CUDA is unavailable or
475/// the system shape is outside the fused kernel's admission range (e.g.
476/// `d > MAX_FUSED_P = 32` or no CUDA context). Callers should fall back to CPU
477/// `InexactPCG` on `Unavailable`.
478///
479/// Returns `Err(ArrowSchurGpuFailure::RidgeBumpRequired)` if a per-row Cholesky
480/// factor failed at the requested `ridge_t`; the outer LM escalation should
481/// bump `ridge_t` and retry.
482///
483/// # Composition with the matrix-free SAE Kronecker operator
484///
485/// When `sys.htbeta_matvec` is set (matrix-free `H_tβ` Kronecker operator),
486/// the dense `H_tβ` slabs are absent — the dense forward kernel above cannot
487/// run, and at `K = 100K` the dense `Y_i = L_i^{-1} H_tβ^(i)` (`d × K` per row)
488/// could not be materialised anyway. Instead, `build_row_procedural_matvec`
489/// returns a row-procedural Schur matvec: per row it gathers
490/// `v_i = H_tβ^(i)·x` through the forward operator (sparse `O(m_i · p)`),
491/// solves `(H_tt^(i) + ρ_t·I)^{-1} v_i` through a pre-computed per-row Cholesky
492/// factor, and scatters `H_βt^(i)·w_i` through the sparse transpose operator
493/// (`O(m_i · p)`, replacing the old `O(K)` column-probe). This is the
494/// row-procedural `a_ik · Φ_k[i,m]` Kronecker apply over the active atoms only.
495pub fn gpu_schur_matvec_backend(
496    sys: &ArrowSchurSystem,
497    ridge_t: f64,
498    ridge_beta: f64,
499) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
500    // Matrix-free H_tβ operator present: drive the row-procedural sparse
501    // Kronecker apply (active atoms only) instead of the dense forward kernel.
502    if sys.htbeta_matvec.is_some() {
503        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
504    }
505
506    #[cfg(not(target_os = "linux"))]
507    {
508        // No CUDA runtime on non-Linux. NaN ridges are validated to ensure the
509        // same contract as the Linux path.
510        if ridge_t.is_nan() || ridge_beta.is_nan() {
511            return Err(ArrowSchurGpuFailure::Unavailable);
512        }
513        Err(ArrowSchurGpuFailure::Unavailable)
514    }
515
516    #[cfg(target_os = "linux")]
517    {
518        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
519    }
520}
521
522/// Build a row-procedural reduced-Schur matvec for matrix-free SAE Kronecker
523/// systems, eliminating the per-row latent block via cached per-row Cholesky
524/// factors and applying the cross-block through the sparse forward/transpose
525/// Kronecker operators (active atoms only).
526///
527/// The returned closure evaluates
528/// `S·x = (H_ββ + ρ_β·I)·x − Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x`,
529/// the same reduced Schur complement the dense path forms, but never
530/// materialises the `d × K` cross-block `H_tβ^(i)`: the forward operator
531/// (`out = H_tβ^(i)·x`) and transpose operator (`out += H_βt^(i)·v`) are the
532/// sparse Kronecker gather/scatter from `SaeKroneckerRows`. The per-row factor
533/// of `H_tt^(i) + ρ_t·I` is computed once when the closure is built and reused
534/// across every CG iteration.
535///
536/// Returns `RidgeBumpRequired` if a per-row block is not positive definite at
537/// the requested `ridge_t`; the outer LM escalation bumps `ridge_t` and retries.
538fn build_row_procedural_matvec(
539    sys: &ArrowSchurSystem,
540    ridge_t: f64,
541    ridge_beta: f64,
542) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
543    use std::sync::Arc;
544    let n = sys.rows.len();
545    let k = sys.k;
546    let forward = sys
547        .htbeta_matvec
548        .clone()
549        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
550    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
551        // A forward operator without its sparse adjoint cannot be applied
552        // row-procedurally; this is a wiring error, surfaced as a Schur failure
553        // so the caller routes to the dense CPU path rather than misreporting a
554        // numerical bump.
555        ArrowSchurGpuFailure::SchurFactorFailed {
556            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
557                     forward operator installed without its sparse adjoint"
558                .to_string(),
559        }
560    })?;
561
562    // Pre-factor each per-row block H_tt^(i) + ρ_t·I = L_i L_iᵀ on the host.
563    // The blocks are tiny (d_i ≲ 32) and the dense cross-block slabs are
564    // absent, so there is no device forward-kernel work to amortise here; the
565    // GPU win is the reduced K-system solve in `solve_reduced_beta_pcg`.
566    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
567    for (i, row) in sys.rows.iter().enumerate() {
568        let di = row.htt.nrows();
569        if row.htt.ncols() != di || row.gt.len() != di {
570            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
571                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
572            });
573        }
574        let mut block = row.htt.clone();
575        for r in 0..di {
576            block[[r, r]] += ridge_t;
577        }
578        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
579            .ok_or_else(|| {
580                // Deficit-aware bump from the block's own entries (Gershgorin),
581                // so the outer LM escalation lifts a strongly-indefinite block
582                // out of the negative regime in one retry.
583                ArrowSchurGpuFailure::RidgeBumpRequired {
584                    row: i,
585                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
586                }
587            })?;
588        factors.push(factor);
589    }
590
591    // The SAE-manifold β-Hessian lives in the structured penalty operator
592    // (data-fit Gauss-Newton `G ⊗ I_p` + smoothness Kronecker blocks + any
593    // dense analytic-β residual), NOT in the dense `hbb` accumulator — for
594    // matrix-free systems `hbb` is zero/absent. Capture the effective penalty
595    // operator so `H_ββ·x` matches the CPU `schur_matvec` path exactly. The
596    // operator's `matvec` adds (`y += P x`), so seed `out` from the ridge term.
597    let penalty_op = sys.effective_penalty_op();
598    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
599
600    let closure: crate::arrow_schur::GpuSchurMatvec =
601        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
602            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
603            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
604
605            // (H_ββ + ρ_β·I)·x into out. Seed with the ridge term, then add the
606            // structured penalty-side product (penalty_op.matvec is additive).
607            {
608                let x_slice = x.as_slice().expect("x must be contiguous");
609                let out_slice = out.as_slice_mut().expect("out must be contiguous");
610                for a in 0..k {
611                    out_slice[a] = ridge_beta * x_slice[a];
612                }
613                penalty_op.matvec(x_slice, out_slice);
614            }
615
616            // out -= Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x.
617            //
618            // #1017: this row-procedural reduced-Schur term is the matrix-free
619            // SAE path's matvec hot loop (`build_row_procedural_matvec` is the
620            // host backend `gpu_schur_matvec_backend` returns when the dense
621            // `H_tβ` slabs are absent — the production Qwen shape). At
622            // (n≈2000 rows) it ran SERIALLY on one core and allocated a fresh
623            // length-`K` `neg` plus per-row `v_i`/`w_i` on EVERY CG iteration —
624            // tens of thousands of tiny heap allocations across a solve. Each
625            // row contributes an independent length-`K` scatter, so the sum is
626            // embarrassingly parallel; fan it across rayon over fixed row chunks
627            // and fold the per-chunk length-`K` partials in chunk order so the
628            // f64 reduction is deterministic (bit-identical run-to-run)
629            // regardless of thread scheduling — it agrees with the serial sum up
630            // to ULP-scale chunk reassociation (the #1017 verification gate).
631            // Because that reassociation is a real (if tiny) departure from
632            // serial, the criterion ranking across topology candidates is stable
633            // except for candidates separated by less than the reassociation
634            // margin, where the near-tie winner can flip — not an exact no-move
635            // guarantee (#1211). Stay
636            // sequential below
637            // `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and when already inside a
638            // rayon worker (the topology race fans candidates with
639            // `run_topology_race_parallel`) — the same nested-rayon guard the
640            // CPU `schur_matvec` uses. Buffers (`v_i`, `neg`) are reused across
641            // rows within a chunk, so the per-row allocation churn is gone.
642            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
643                && rayon::current_thread_index().is_none();
644            if parallel {
645                use rayon::prelude::*;
646                const CHUNK: usize = 64;
647                let partials: Vec<Array1<f64>> = (0..n)
648                    .into_par_iter()
649                    .chunks(CHUNK)
650                    .map(|idxs| {
651                        // One length-`K` scatter accumulator per chunk; the
652                        // per-row latent vector `v_i` (length `d_i ≲ 32`) is the
653                        // only per-row buffer, sized to the row's own `d_i`.
654                        let mut neg = Array1::<f64>::zeros(k);
655                        for i in idxs {
656                            let di = row_dims[i];
657                            // v_i = H_tβ^(i)·x (sparse Kronecker gather).
658                            let mut v_i = Array1::<f64>::zeros(di);
659                            forward(i, x.view(), &mut v_i);
660                            // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
661                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
662                            // neg += H_βt^(i)·w_i (sparse scatter).
663                            transpose(i, w_i.view(), &mut neg);
664                        }
665                        neg
666                    })
667                    .collect();
668                // #1017/#1175 floating-point parity contract: Rayon may
669                // schedule chunks on any worker, but `.chunks(CHUNK).collect()`
670                // returns partials in chunk-index order. Each chunk's row sum
671                // is formed locally in increasing row order, then chunk
672                // partials are folded left-to-right below. That makes the
673                // parallel row-procedural Schur term deterministic for a fixed
674                // input and chunking (no scheduling-dependent gather/scatter
675                // reordering), but it is not required to be bit-identical to
676                // the serial path because additions are reassociated at chunk
677                // boundaries. CPU/GPU validation should therefore allow
678                // ULP-scale drift while expecting stable run-to-run results.
679                let mut neg = Array1::<f64>::zeros(k);
680                for part in &partials {
681                    for a in 0..k {
682                        neg[a] += part[a];
683                    }
684                }
685                for a in 0..k {
686                    out[a] -= neg[a];
687                }
688            } else {
689                // Serial path: reuse one `neg` and one `v_i` across rows.
690                let mut neg = Array1::<f64>::zeros(k);
691                for i in 0..n {
692                    let di = row_dims[i];
693                    // v_i = H_tβ^(i)·x (sparse Kronecker gather, length d_i).
694                    let mut v_i = Array1::<f64>::zeros(di);
695                    forward(i, x.view(), &mut v_i);
696                    // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
697                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
698                    // neg += H_βt^(i)·w_i (sparse scatter); subtract once at end.
699                    transpose(i, w_i.view(), &mut neg);
700                }
701                for a in 0..k {
702                    out[a] -= neg[a];
703                }
704            }
705        });
706
707    Ok(closure)
708}
709
710/// Solve the reduced shared β-system `S·δβ = r` fully on device with a
711/// Jacobi-preconditioned conjugate-gradient (Steihaug truncated-CG) loop.
712///
713/// `S` is the already-reduced symmetric positive-definite `K × K` Schur
714/// complement the streaming SAE joint fit accumulates across minibatches
715/// (`StreamingArrowSchur::take_accumulators` summed over chunks, with the
716/// global β ridge folded in). The per-row latent blocks have already been
717/// eliminated into `S` on the host streaming path; the device's job is the
718/// dense `K`-dimensional solve, which is the dominant cost at `K = 100K`.
719///
720/// The dense `S·p` matvec runs on device via cuBLAS `Dgemv`, and the PCG state
721/// vectors (`x`, `r`, `z`, `p`, `S·p`) remain device-resident for the solve.
722/// Jacobi preconditioning is an elementwise CUDA kernel; only convergence
723/// scalars (`pᵀSp`, `rᵀz`, `‖r‖`) cross the host boundary per iteration, plus the
724/// final solution vector.
725///
726/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` when CUDA is unavailable
727/// or the workload is below the dispatch policy; the caller then runs the CPU
728/// reduced-β solve. Returns `Err(ArrowSchurGpuFailure::SchurFactorFailed)`
729/// when `S` carries a non-positive Jacobi diagonal (caller escalates the
730/// proximal ridge).
731pub fn solve_reduced_beta_pcg(
732    s_acc: &Array2<f64>,
733    rhs_beta: &Array1<f64>,
734    max_iterations: usize,
735    relative_tolerance: f64,
736) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
737    solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
738        .map(|(x, _)| x)
739}
740
741#[doc(hidden)]
742pub fn solve_reduced_beta_pcg_with_diagnostics(
743    s_acc: &Array2<f64>,
744    rhs_beta: &Array1<f64>,
745    max_iterations: usize,
746    relative_tolerance: f64,
747) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
748    let k = rhs_beta.len();
749    if s_acc.dim() != (k, k) {
750        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
751            reason: format!(
752                "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
753                s_acc.dim()
754            ),
755        });
756    }
757    if k == 0 {
758        return Err(ArrowSchurGpuFailure::Unavailable);
759    }
760
761    #[cfg(not(target_os = "linux"))]
762    {
763        if relative_tolerance.is_nan() || max_iterations == 0 {
764            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
765                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
766            });
767        }
768        Err(ArrowSchurGpuFailure::Unavailable)
769    }
770
771    #[cfg(target_os = "linux")]
772    {
773        cuda::solve_reduced_beta_pcg_with_diagnostics(
774            s_acc,
775            rhs_beta,
776            max_iterations,
777            relative_tolerance,
778        )
779    }
780}
781
782pub fn solve_sae_matrix_free_pcg(
783    sys: &ArrowSchurSystem,
784    data: &DeviceSaePcgData,
785    ridge_t: f64,
786    ridge_beta: f64,
787    rhs_beta: &Array1<f64>,
788    max_iterations: usize,
789    relative_tolerance: f64,
790) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
791    if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
792        return Err(ArrowSchurGpuFailure::Unavailable);
793    }
794    #[cfg(not(target_os = "linux"))]
795    {
796        if ridge_t.is_nan()
797            || ridge_beta.is_nan()
798            || relative_tolerance.is_nan()
799            || max_iterations == 0
800        {
801            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
802                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
803            });
804        }
805        Err(ArrowSchurGpuFailure::Unavailable)
806    }
807    #[cfg(target_os = "linux")]
808    {
809        // #1017/#1026 dispatch GUARD: framed data (frame metadata present) carries
810        // a factored β border `G ⊗ W_{ij}` data Hessian and dense per-row cross
811        // blocks the legacy `⊗ I_p` kernel CANNOT represent — feeding it framed
812        // data would silently return a WRONG Newton step (it returns Ok with no
813        // fallback). Route framed systems to the dedicated framed kernel and
814        // legacy full-`B` systems to the legacy kernel; the two never cross.
815        if data.frame.is_some() {
816            cuda::solve_sae_matrix_free_pcg_framed(
817                sys,
818                data,
819                ridge_t,
820                ridge_beta,
821                rhs_beta,
822                max_iterations,
823                relative_tolerance,
824            )
825        } else {
826            cuda::solve_sae_matrix_free_pcg(
827                sys,
828                data,
829                ridge_t,
830                ridge_beta,
831                rhs_beta,
832                max_iterations,
833                relative_tolerance,
834            )
835        }
836    }
837}
838
839/// #1551 kernel-isolating parity probe: run the framed reduced-Schur matvec
840/// `out = S·x` exactly once on the device and return it (no PCG, no offload-floor
841/// gate). The test suite diffs this element-wise against the CPU oracle
842/// [`sae_framed_schur_matvec_cpu`] to prove the GPU kernel computes the SAME
843/// operator — a check that is independent of solver conditioning (unlike a
844/// solved-`δβ` comparison, which can diverge purely because dense Cholesky and
845/// iterative PCG resolve an ill-conditioned `S` to different accuracies). On a
846/// non-CUDA host this returns `Unavailable` so the caller skips cleanly.
847#[doc(hidden)]
848pub fn framed_schur_matvec_once_on_device(
849    sys: &ArrowSchurSystem,
850    data: &DeviceSaePcgData,
851    ridge_t: f64,
852    ridge_beta: f64,
853    x: &Array1<f64>,
854) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
855    if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
856        return Err(ArrowSchurGpuFailure::Unavailable);
857    }
858    if data.frame.is_none() {
859        return Err(ArrowSchurGpuFailure::Unavailable);
860    }
861    #[cfg(not(target_os = "linux"))]
862    {
863        // CUDA is linux-only; the framed device solve is unavailable here. Read
864        // the ridge params (a real, cheap use) so the values the shared signature
865        // carries are not flagged unused on this target — without an `#[allow]`
866        // or `let _` dodge, both of which the build.rs scanner bans.
867        if ridge_t.is_finite() && ridge_beta.is_finite() {
868            return Err(ArrowSchurGpuFailure::Unavailable);
869        }
870        Err(ArrowSchurGpuFailure::Unavailable)
871    }
872    #[cfg(target_os = "linux")]
873    {
874        cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
875    }
876}
877
878/// Reference dense back-end used by tests and as the fallback when the
879/// GPU declines. Kept here (not in `arrow_schur_gpu.rs`) so the validation
880/// suite has one canonical baseline.
881#[doc(hidden)]
882pub fn solve_arrow_newton_step_dense_reference(
883    sys: &ArrowSchurSystem,
884    ridge_t: f64,
885    ridge_beta: f64,
886) -> Result<ArrowSchurGpuSolution, String> {
887    let n = sys.rows.len();
888    let d = sys.d;
889    let k = sys.k;
890    let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
891    let mut h = Array2::<f64>::zeros((total, total));
892    let mut rhs = Array1::<f64>::zeros(total);
893    for (i, row) in sys.rows.iter().enumerate() {
894        let base = i * d;
895        for c in 0..d {
896            for r in 0..d {
897                h[[base + r, base + c]] = row.htt[[r, c]];
898            }
899            h[[base + c, base + c]] += ridge_t;
900        }
901        for c in 0..k {
902            for r in 0..d {
903                let value = row.htbeta[[r, c]];
904                h[[base + r, n * d + c]] = value;
905                h[[n * d + c, base + r]] = value;
906            }
907        }
908        for r in 0..d {
909            rhs[base + r] = -row.gt[r];
910        }
911    }
912    for c in 0..k {
913        for r in 0..k {
914            h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
915        }
916        h[[n * d + c, n * d + c]] += ridge_beta;
917        rhs[n * d + c] = -sys.gb[c];
918    }
919    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
920        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
921    let mut log_det = 0.0_f64;
922    for i in 0..total {
923        log_det += factor[[i, i]].ln();
924    }
925    log_det *= 2.0;
926    let solved = cholesky_solve_vector(factor.view(), rhs.view());
927    let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
928    let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
929    Ok(ArrowSchurGpuSolution {
930        delta_t,
931        delta_beta,
932        log_det_hessian: log_det,
933    })
934}
935
936/// Frames-engaged reduced-Schur penalty-side matvec `out = (P_ββ + ρ_β I)·x`,
937/// computed purely from the factored device data (issue #1017/#1026). This is
938/// the CPU bit-parity ORACLE for the GPU `arrow_sae_*` penalty kernels on the
939/// frames path: smooth `λ S_k ⊗ I_{r_k}` (each `smooth_blocks[i]` at its
940/// `global_offset` with right-width `frame.smooth_ranks[i]`) plus data-fit
941/// `G_{ij} ⊗ W_{ij}` (each `frame.frame_blocks` entry, with the `μ`-major /
942/// frame-minor index `border_offset[atom] + basis·r + frame_coord`). The
943/// accumulation order matches the device kernels exactly.
944///
945/// `out` is OVERWRITTEN: first set to `ρ_β·x`, then the penalty blocks add in.
946#[doc(hidden)]
947pub fn sae_framed_penalty_matvec_cpu(
948    data: &DeviceSaePcgData,
949    ridge_beta: f64,
950    x: &[f64],
951    out: &mut [f64],
952) {
953    let frame = data
954        .frame
955        .as_ref()
956        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
957    let k = data.beta_dim;
958    for a in 0..k {
959        out[a] = ridge_beta * x[a];
960    }
961    // Smooth penalty `λ S_k ⊗ I_{r_k}`: y[off + ia·r + ib] += Σ_ja S[ia,ja]·x[off + ja·r + ib].
962    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
963        let off = blk.global_offset;
964        let m = blk.factor_a.nrows();
965        for i_a in 0..m {
966            for i_b in 0..r {
967                let mut acc = 0.0_f64;
968                for j_a in 0..m {
969                    let s = blk.factor_a[[i_a, j_a]];
970                    if s == 0.0 {
971                        continue;
972                    }
973                    acc += s * x[off + j_a * r + i_b];
974                }
975                out[off + i_a * r + i_b] += acc;
976            }
977        }
978    }
979    // Data-fit penalty `G_{ij} ⊗ W_{ij}`.
980    for blk in &frame.frame_blocks {
981        let r_i = frame.ranks[blk.atom_i];
982        let r_j = frame.ranks[blk.atom_j];
983        let off_i = frame.border_offsets[blk.atom_i];
984        let off_j = frame.border_offsets[blk.atom_j];
985        let (m_i, m_j) = blk.g.dim();
986        for li in 0..m_i {
987            let yi_base = off_i + li * r_i;
988            for lj in 0..m_j {
989                let g = blk.g[[li, lj]];
990                if g == 0.0 {
991                    continue;
992                }
993                let xj_base = off_j + lj * r_j;
994                for a in 0..r_i {
995                    let mut acc = 0.0_f64;
996                    for b in 0..r_j {
997                        acc += blk.w[[a, b]] * x[xj_base + b];
998                    }
999                    out[yi_base + a] += g * acc;
1000                }
1001            }
1002        }
1003    }
1004}
1005
1006/// Frames-engaged FULL reduced-Schur matvec `out = S·x` purely from the device
1007/// data, where `S = (P_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)`
1008/// (issue #1017/#1026). The penalty side is [`sae_framed_penalty_matvec_cpu`];
1009/// the per-row reduced term reads the dense `frame.row_htbeta[i]`
1010/// (`q_i × border_dim`, row-major), solves against the row's
1011/// `H_tt^(i)+ρ_t I` Cholesky factor, and scatters the transpose back. This is
1012/// the size-independent bit-parity oracle the device kernel mirrors; it is also
1013/// the matvec the GPU PCG iterates.
1014#[doc(hidden)]
1015pub fn sae_framed_schur_matvec_cpu(
1016    sys: &ArrowSchurSystem,
1017    data: &DeviceSaePcgData,
1018    ridge_t: f64,
1019    ridge_beta: f64,
1020    x: &[f64],
1021    out: &mut [f64],
1022) -> Result<(), String> {
1023    let frame = data
1024        .frame
1025        .as_ref()
1026        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1027    let k = data.beta_dim;
1028    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1029    if frame.row_htbeta.len() != sys.rows.len() {
1030        return Err(format!(
1031            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1032            frame.row_htbeta.len(),
1033            sys.rows.len()
1034        ));
1035    }
1036    for (i, row) in sys.rows.iter().enumerate() {
1037        let slab = &frame.row_htbeta[i];
1038        if slab.is_empty() {
1039            continue;
1040        }
1041        let qi = sys.row_dims[i];
1042        if qi == 0 || slab.len() != qi * k {
1043            continue;
1044        }
1045        // h = H_tβ^(i) · x  (length q_i).
1046        let mut h = vec![0.0_f64; qi];
1047        for c in 0..qi {
1048            let base = c * k;
1049            let mut acc = 0.0_f64;
1050            for a in 0..k {
1051                acc += slab[base + a] * x[a];
1052            }
1053            h[c] = acc;
1054        }
1055        // solve (H_tt^(i)+ρ_t I) s = h.
1056        let mut block = row.htt.clone();
1057        for d in 0..qi {
1058            block[[d, d]] += ridge_t;
1059        }
1060        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1061            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1062        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1063        // out -= H_βt^(i) · s = (H_tβ^(i))ᵀ · s.
1064        for c in 0..qi {
1065            let sc = s[c];
1066            if sc == 0.0 {
1067                continue;
1068            }
1069            let base = c * k;
1070            for a in 0..k {
1071                out[a] -= slab[base + a] * sc;
1072            }
1073        }
1074    }
1075    Ok(())
1076}
1077
1078#[cfg(target_os = "linux")]
1079mod cuda {
1080    use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1081    use gam_gpu::driver::to_i32;
1082    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1083    use crate::arrow_schur::{
1084        ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1085    };
1086    use cudarc::cublas::sys::{
1087        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1088    };
1089    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1090    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1091    use cudarc::driver::{
1092        CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1093        PushKernelArg,
1094    };
1095    use ndarray::Array1;
1096    use std::sync::{Arc, OnceLock};
1097
1098    /// Per-row work slot for the row-block-granular multi-GPU solve. Inputs are
1099    /// the packed single-row buffers (`d×d` D block + ρ_t ridge, `d×k` B block,
1100    /// `d` g vector); the forward pass fills the whitened factors `l/u/y` and the
1101    /// per-tile reduction lands in the tile's leading slot.
1102    struct RowSlot {
1103        // Inputs (packed once on the host, column-major).
1104        d_block: Vec<f64>, // d*d
1105        b_block: Vec<f64>, // d*k
1106        g_vec: Vec<f64>,   // d
1107        // Forward outputs, kept on the host for the back-sub pass.
1108        l_block: Vec<f64>, // d*d lower factor, column-major
1109        u_vec: Vec<f64>,   // d   (= L^{-1} g)
1110        y_block: Vec<f64>, // d*k (= L^{-1} B), column-major
1111        log_det_local: f64,
1112        // Set on a non-PD pivot so the orchestrator can raise RidgeBumpRequired
1113        // for the offending global row instead of silently falling back.
1114        bump: Option<f64>,
1115        // Tile-level reduction, written into the tile's first slot only.
1116        tile_partial_schur: Option<Vec<f64>>, // k*k col-major, = Σ Y_iᵀY_i
1117        tile_partial_rhs: Option<Vec<f64>>,   // k, = Σ Y_iᵀu_i
1118        // Back-sub output for this row.
1119        delta_t_block: Vec<f64>, // d
1120    }
1121
1122    /// Row-block-granular multi-GPU Arrow-Schur Newton solve.
1123    ///
1124    /// The solve is separable across row blocks in both phases:
1125    ///   * forward — each row's local Cholesky `L_i`, whitening
1126    ///     `u_i = L_i⁻¹g_i`, `Y_i = L_i⁻¹B_i`, and partial Schur
1127    ///     `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` are independent;
1128    ///   * backward — `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)` is independent.
1129    /// Only the small shared `K×K` reduce + factor + `δβ` solve is central.
1130    ///
1131    /// `gam_gpu::pool::scatter_batched` hands each device a contiguous row
1132    /// tile on its own bound context/stream; the per-tile forward keeps the
1133    /// POTRF fused with its dependent TRSM + Schur GEMM on that one stream, so no
1134    /// on-stream solve is orphaned. Tile partials and per-tile `log|L|` are
1135    /// reduced on the host (in tile/row order), `S_β` is factored on the primary
1136    /// device, and the back-sub is scattered back across the same tiles.
1137    ///
1138    /// Returns `Unavailable` (caller uses a single-device path) when the system
1139    /// carries matrix-free operators, the shared block is not dense `K×K`, the
1140    /// pool is single-device, or any tile's device work declines. A non-PD tip
1141    /// block surfaces as `RidgeBumpRequired` for the precise global row.
1142    pub(super) fn solve_multi_gpu(
1143        sys: &ArrowSchurSystem,
1144        ridge_t: f64,
1145        ridge_beta: f64,
1146    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1147        let n = sys.rows.len();
1148        let d = sys.d;
1149        let k = sys.k;
1150        if n == 0 || d == 0 || k == 0 {
1151            return Err(ArrowSchurGpuFailure::Unavailable);
1152        }
1153        // Dense shared block + materialised per-row slabs are required; the
1154        // public entry already rejected matrix-free operators, but re-check so
1155        // this routine is safe in isolation.
1156        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
1157            return Err(ArrowSchurGpuFailure::Unavailable);
1158        }
1159
1160        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
1161            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1162        if runtime.device_count() < 2 {
1163            return Err(ArrowSchurGpuFailure::Unavailable);
1164        }
1165
1166        // Pack one slot per row (column-major), folding ρ_t into each D block.
1167        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
1168        for row in &sys.rows {
1169            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
1170                return Err(ArrowSchurGpuFailure::Unavailable);
1171            }
1172            let mut d_block = Vec::with_capacity(d * d);
1173            let mut b_block = Vec::with_capacity(d * k);
1174            let mut g_vec = Vec::with_capacity(d);
1175            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
1176            slots.push(RowSlot {
1177                d_block,
1178                b_block,
1179                g_vec,
1180                l_block: Vec::new(),
1181                u_vec: Vec::new(),
1182                y_block: Vec::new(),
1183                log_det_local: 0.0,
1184                bump: None,
1185                tile_partial_schur: None,
1186                tile_partial_rhs: None,
1187                delta_t_block: vec![0.0; d],
1188            });
1189        }
1190
1191        // ---- Forward pass: per-device row tile, fused on its own stream ----
1192        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1193            forward_tile(ordinal, d, k, tile)
1194        });
1195        if forward_ok.is_none() {
1196            return Err(ArrowSchurGpuFailure::Unavailable);
1197        }
1198
1199        // Surface a non-PD tip block as a precise per-row ridge bump.
1200        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
1201        if let Some((row, bump)) = slots
1202            .iter()
1203            .enumerate()
1204            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
1205        {
1206            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
1207        }
1208
1209        // ---- Central: reduce tile partials → S_β, r_β; factor; solve δβ ----
1210        // Seed S_β with H_ββ + ρ_β I (column-major) and r_β with -g_β, then fold
1211        // in the per-tile partials in tile order so the reduction order tracks
1212        // the single-device accumulation (up to inter-tile reassociation).
1213        let mut schur_host = vec![0.0_f64; k * k];
1214        for col in 0..k {
1215            for row in 0..k {
1216                let mut v = sys.hbb[[row, col]];
1217                if row == col {
1218                    v += ridge_beta;
1219                }
1220                schur_host[col * k + row] = v;
1221            }
1222        }
1223        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1224        let mut log_det = 0.0_f64;
1225        for start in tile_starts(&row_base_of_tile) {
1226            let slot = &slots[start];
1227            let partial_schur = slot
1228                .tile_partial_schur
1229                .as_ref()
1230                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1231            let partial_rhs = slot
1232                .tile_partial_rhs
1233                .as_ref()
1234                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1235            // `accumulate_schur` writes `partial_schur = -Σ_tile Y_iᵀY_i` (GEMM
1236            // α=-1, β=1 into a zero seed) and `partial_rhs = +Σ_tile Y_iᵀu_i`.
1237            // The reduced Schur is `S = (H_ββ+ρI) − Σ_all Y_iᵀY_i`, so adding the
1238            // (already-negated) partials reproduces the single-device sign.
1239            for idx in 0..k * k {
1240                schur_host[idx] += partial_schur[idx];
1241            }
1242            for a in 0..k {
1243                rhs_host[a] += partial_rhs[a];
1244            }
1245        }
1246        for slot in &slots {
1247            log_det += slot.log_det_local;
1248        }
1249
1250        // Factor S_β and solve δβ on the primary device (small K×K leaf). The
1251        // stream carries the primary context (same pattern as `solve()`); no
1252        // thread bind is needed for the cuSOLVER/cuBLAS handles created from it.
1253        let primary = runtime.selected_device().ordinal;
1254        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
1255            .and_then(|ctx| ctx.new_stream().ok())
1256            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1257        let solver =
1258            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1259        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1260        let mut schur_dev = stream
1261            .clone_htod(&schur_host)
1262            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1263        let mut rhs_dev = stream
1264            .clone_htod(&rhs_host)
1265            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1266        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1267        if info != 0 {
1268            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1269                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
1270            });
1271        }
1272        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1273        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1274        let delta_beta_host = stream
1275            .clone_dtoh(&rhs_dev)
1276            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1277        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1278        let l_schur_host = stream
1279            .clone_dtoh(&schur_dev)
1280            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1281        for j in 0..k {
1282            log_det += l_schur_host[j * k + j].ln();
1283        }
1284        log_det *= 2.0;
1285
1286        // ---- Backward pass: δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ), per-device tile ----
1287        let delta_beta_ref = &delta_beta_host;
1288        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1289            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
1290        });
1291        if back_ok.is_none() {
1292            return Err(ArrowSchurGpuFailure::Unavailable);
1293        }
1294
1295        // Stitch per-row δt into the stacked (n*d) result.
1296        let mut delta_t = Array1::<f64>::zeros(n * d);
1297        for (i, slot) in slots.iter().enumerate() {
1298            let base = i * d;
1299            for r in 0..d {
1300                delta_t[base + r] = slot.delta_t_block[r];
1301            }
1302        }
1303
1304        Ok(ArrowSchurGpuSolution {
1305            delta_t,
1306            delta_beta,
1307            log_det_hessian: log_det,
1308        })
1309    }
1310
1311    /// Tile starts: the leading global row index of each device tile (where the
1312    /// tile-level partial reduction was written by the forward pass).
1313    fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1314        tiles.iter().map(|(_, range)| range.start)
1315    }
1316
1317    /// Forward pass for one device row tile, running on `ordinal`'s bound stream.
1318    /// Factors each row block, whitens `u`/`Y`, accumulates the tile's partial
1319    /// Schur `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` into the tile's leading slot, keeps the
1320    /// per-row `L`/`u`/`Y` on the host for back-sub, and records the per-row
1321    /// `Σ_j log L_jj`. A non-PD pivot is recorded in `slot.bump` (the tile still
1322    /// returns `Some(())` so the orchestrator raises a precise `RidgeBumpRequired`
1323    /// rather than collapsing the whole batch to CPU).
1324    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1325        if tile.is_empty() {
1326            return Some(());
1327        }
1328        // `scatter_batched` has already bound this ordinal's context on this
1329        // worker thread; the stream below targets that same device.
1330        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1331            .and_then(|ctx| ctx.new_stream().ok())?;
1332        let solver = DnHandle::new(stream.clone()).ok()?;
1333        let blas = CudaBlas::new(stream.clone()).ok()?;
1334        let m = tile.len();
1335
1336        // Stack the tile's D, B, g into contiguous device buffers (same layout
1337        // the single-device path packs for `m` rows).
1338        let mut d_host = Vec::with_capacity(m * d * d);
1339        let mut b_host = Vec::with_capacity(m * d * k);
1340        let mut g_host = Vec::with_capacity(m * d);
1341        for slot in tile.iter() {
1342            d_host.extend_from_slice(&slot.d_block);
1343            b_host.extend_from_slice(&slot.b_block);
1344            g_host.extend_from_slice(&slot.g_vec);
1345        }
1346        let mut d_dev = stream.clone_htod(&d_host).ok()?;
1347        let mut b_dev = stream.clone_htod(&b_host).ok()?;
1348        let mut g_dev = stream.clone_htod(&g_host).ok()?;
1349
1350        // Batched POTRF; a non-PD block records its bump and stops the tile.
1351        // The bump is deficit-aware (Gershgorin lower bound on λ_min of the
1352        // already-ridged `d_block`), NOT derived from the cuSOLVER `info` —
1353        // which is a 1-based pivot ROW INDEX, not a pivot magnitude — so a
1354        // strongly-indefinite block recovers in one outer-loop retry.
1355        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1356        if let Some(local) = info_host.iter().position(|info| *info != 0) {
1357            tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
1358                &tile[local].d_block,
1359                d,
1360            ));
1361            return Some(());
1362        }
1363
1364        // Whiten: u = L⁻¹ g, Y = L⁻¹ B.
1365        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1366        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1367
1368        // Tile partial Schur: zero-seeded so the host adds the H_ββ seed once.
1369        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1370        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1371        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1372
1373        // Download L, u, Y, and the tile partials.
1374        let l_host = stream.clone_dtoh(&d_dev).ok()?;
1375        let u_host = stream.clone_dtoh(&g_dev).ok()?;
1376        let y_host = stream.clone_dtoh(&b_dev).ok()?;
1377        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1378        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1379
1380        for (local, slot) in tile.iter_mut().enumerate() {
1381            let l_base = local * d * d;
1382            let u_base = local * d;
1383            let y_base = local * d * k;
1384            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1385            slot.u_vec = u_host[u_base..u_base + d].to_vec();
1386            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1387            let mut log_det_local = 0.0_f64;
1388            for j in 0..d {
1389                log_det_local += l_host[l_base + j * d + j].ln();
1390            }
1391            slot.log_det_local = log_det_local;
1392        }
1393        tile[0].tile_partial_schur = Some(partial_schur);
1394        tile[0].tile_partial_rhs = Some(partial_rhs);
1395        Some(())
1396    }
1397
1398    /// Back-substitution for one device row tile: `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)`.
1399    /// Re-uploads the tile's kept `L`/`u`/`Y` to `ordinal`, applies the GEMV
1400    /// accumulate + transposed TRSM, and writes each row's `δt` into its slot.
1401    fn back_sub_tile(
1402        ordinal: usize,
1403        d: usize,
1404        k: usize,
1405        delta_beta: &[f64],
1406        tile: &mut [RowSlot],
1407    ) -> Option<()> {
1408        if tile.is_empty() {
1409            return Some(());
1410        }
1411        // `scatter_batched` has already bound this ordinal's context on this
1412        // worker thread; the stream below targets that same device.
1413        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1414            .and_then(|ctx| ctx.new_stream().ok())?;
1415        let blas = CudaBlas::new(stream.clone()).ok()?;
1416        let m = tile.len();
1417
1418        let mut l_host = Vec::with_capacity(m * d * d);
1419        let mut u_host = Vec::with_capacity(m * d);
1420        let mut y_host = Vec::with_capacity(m * d * k);
1421        for slot in tile.iter() {
1422            l_host.extend_from_slice(&slot.l_block);
1423            u_host.extend_from_slice(&slot.u_vec);
1424            y_host.extend_from_slice(&slot.y_block);
1425        }
1426        let d_dev = stream.clone_htod(&l_host).ok()?;
1427        let mut g_dev = stream.clone_htod(&u_host).ok()?;
1428        let b_dev = stream.clone_htod(&y_host).ok()?;
1429        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1430
1431        // g ← u + Y·δβ, then x = L⁻ᵀ g; δt = -x.
1432        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1433        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1434        let x_host = stream.clone_dtoh(&g_dev).ok()?;
1435        for (local, slot) in tile.iter_mut().enumerate() {
1436            let base = local * d;
1437            for r in 0..d {
1438                slot.delta_t_block[r] = -x_host[base + r];
1439            }
1440        }
1441        Some(())
1442    }
1443
1444    pub(super) fn solve(
1445        sys: &ArrowSchurSystem,
1446        ridge_t: f64,
1447        ridge_beta: f64,
1448    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1449        let n = sys.rows.len();
1450        let d = sys.d;
1451        let k = sys.k;
1452        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1453            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1454
1455        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1456            .and_then(|ctx| ctx.new_stream().ok())
1457            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1458        let solver =
1459            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1460        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1461
1462        // ----- Pack + upload D, B, g -----
1463        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1464        let mut d_dev = stream
1465            .clone_htod(&d_host)
1466            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1467        let mut b_dev = stream
1468            .clone_htod(&b_host)
1469            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1470        let mut g_dev = stream
1471            .clone_htod(&g_host)
1472            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1473
1474        // ----- Layer A: batched lower Cholesky of D in place -----
1475        // This POTRF is fused with the downstream TRSM + Schur GEMM + back-sub
1476        // on this one stream, so splitting only the POTRF across devices would
1477        // orphan the dependent on-stream solves. Multi-GPU here is the
1478        // whole-solve row-block split in `solve_arrow_newton_step` (see
1479        // `solve_multi_gpu`), not a per-layer split — this device-resident path
1480        // is the single-device leaf the split dispatches per tile.
1481        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1482        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1483            // `info` is cuSOLVER's 1-based pivot ROW INDEX, not a magnitude;
1484            // size the bump from the block's own entries (Gershgorin λ_min
1485            // bound) so a strongly-indefinite block recovers in one retry.
1486            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1487                row: idx,
1488                bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1489            });
1490        }
1491
1492        // ----- Layer B (1/2): in-place triangular solves -----
1493        // u_i = L_i^{-1} g_i, packed as a stacked (n*d) column-vector.
1494        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1495        // Y_i = L_i^{-1} B_i, in place over the (n*d) × k buffer (laid out as
1496        // n stacked column-major d×k tiles).
1497        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1498
1499        // ----- Layer B (2/2): Schur reduction via single big GEMM / GEMV -----
1500        // Y_all is (n*d) × k column-major: viewing all n stacked d×k tiles as
1501        // one big matrix is bit-exact because each tile is column-major with
1502        // leading dim d and the tiles are contiguous in memory, so the
1503        // combined leading dim is n*d only for the *outer* matrix view. To
1504        // make the single-GEMM equivalence hold we must treat the stacked
1505        // buffer as (n*d) × k column-major with leading dim = n*d, which
1506        // means columns of Y_all are interleaved by row across blocks.
1507        // That is NOT what we packed. So we use the cuBLAS stride pattern
1508        // instead: stride-by-block, transpose-A, and *accumulate* into one
1509        // S_β buffer via beta=1 across batches. Equivalent flop count, no
1510        // extra reduction kernel, and correct layout.
1511        //
1512        // Concretely: schur ← C + ρ_β I; rhs ← -g_β; then for each block
1513        //   schur -= Y_i^T Y_i      (k×k)
1514        //   rhs   += Y_i^T u_i      (k)
1515        // We launch this as `n` sequential GEMMs/GEMVs with beta=1 on the
1516        // accumulator. Layer D fuses these into one NVRTC launch.
1517        let schur_init: Vec<f64> = {
1518            let mut tmp = Vec::with_capacity(k * k);
1519            for col in 0..k {
1520                for row in 0..k {
1521                    let mut v = sys.hbb[[row, col]];
1522                    if row == col {
1523                        v += ridge_beta;
1524                    }
1525                    tmp.push(v);
1526                }
1527            }
1528            tmp
1529        };
1530        let mut schur_dev = stream
1531            .clone_htod(&schur_init)
1532            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1533        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1534        let mut rhs_dev = stream
1535            .clone_htod(&rhs_init)
1536            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1537
1538        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1539
1540        // ----- Layer C (1/2): factor S_β and solve for δβ -----
1541        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1542        if info != 0 {
1543            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1544                reason: format!("Schur Cholesky failed at pivot {info}"),
1545            });
1546        }
1547        // δβ ← L_S^{-T} L_S^{-1} rhs
1548        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1549        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1550        let delta_beta_host = stream
1551            .clone_dtoh(&rhs_dev)
1552            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1553        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1554
1555        // ----- Layer C (2/2): back-sub δt_i = -L_i^{-T} (u_i + Y_i δβ) -----
1556        // Already on device:
1557        //   g_dev holds u_i stacked (n*d).
1558        //   b_dev holds Y_i stacked column-major n×(d×k) tiles.
1559        // Compute g_dev ← g_dev + Y_block · δβ per block (cuBLAS gemv with beta=1),
1560        // then in-place trsm with L_i^T (CUBLAS_OP_T) to obtain x_i, and finally
1561        // δt_i = -x_i on host after download.
1562        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1563        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1564
1565        let x_host = stream
1566            .clone_dtoh(&g_dev)
1567            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1568        let mut delta_t = Array1::<f64>::zeros(n * d);
1569        for (i, v) in x_host.iter().enumerate() {
1570            delta_t[i] = -*v;
1571        }
1572
1573        // ----- log|H| = 2 Σ log L_{i,jj} + 2 Σ log R_{β,aa} -----
1574        let l_local_host = stream
1575            .clone_dtoh(&d_dev)
1576            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1577        let l_schur_host = stream
1578            .clone_dtoh(&schur_dev)
1579            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1580        let mut log_det = 0.0_f64;
1581        for i in 0..n {
1582            let base = i * d * d;
1583            for j in 0..d {
1584                log_det += l_local_host[base + j * d + j].ln();
1585            }
1586        }
1587        for j in 0..k {
1588            log_det += l_schur_host[j * k + j].ln();
1589        }
1590        log_det *= 2.0;
1591
1592        Ok(ArrowSchurGpuSolution {
1593            delta_t,
1594            delta_beta,
1595            log_det_hessian: log_det,
1596        })
1597    }
1598
1599    fn potrf_batched(
1600        solver: &DnHandle,
1601        stream: &Arc<CudaStream>,
1602        p: usize,
1603        batch: usize,
1604        matrices: &mut CudaSlice<f64>,
1605    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1606        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1607        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1608        let matrix_len = p * p;
1609        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1610        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1611        let mut ptrs = Vec::with_capacity(batch);
1612        for idx in 0..batch {
1613            ptrs.push(base_ptr + (idx as u64) * bytes_per);
1614        }
1615        let mut ptrs_dev = stream
1616            .clone_htod(&ptrs)
1617            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1618        let mut info_dev = stream
1619            .alloc_zeros::<i32>(batch)
1620            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1621        let status = {
1622            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1623            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1624            // SAFETY: pointer array and info buffer live on the device,
1625            // matrices_dev holds `batch` contiguous p×p column-major blocks.
1626            unsafe {
1627                cusolver_sys::cusolverDnDpotrfBatched(
1628                    solver.cu(),
1629                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1630                    p_i,
1631                    ptrs_ptr as *mut *mut f64,
1632                    p_i,
1633                    info_ptr as *mut i32,
1634                    batch_i,
1635                )
1636            }
1637        };
1638        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1639            return Err(ArrowSchurGpuFailure::Unavailable);
1640        }
1641        stream
1642            .clone_dtoh(&info_dev)
1643            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1644    }
1645
1646    fn potrf_single(
1647        solver: &DnHandle,
1648        stream: &Arc<CudaStream>,
1649        p: usize,
1650        matrix: &mut CudaSlice<f64>,
1651    ) -> Result<i32, ArrowSchurGpuFailure> {
1652        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1653        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1654        let mut lwork = 0_i32;
1655        {
1656            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1657            // SAFETY: buffer query against a live p-by-p column-major device matrix.
1658            let status = unsafe {
1659                cusolver_sys::cusolverDnDpotrf_bufferSize(
1660                    solver.cu(),
1661                    uplo,
1662                    p_i,
1663                    mat_ptr as *mut f64,
1664                    p_i,
1665                    &mut lwork,
1666                )
1667            };
1668            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1669                return Err(ArrowSchurGpuFailure::Unavailable);
1670            }
1671        }
1672        let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1673        let mut workspace = stream
1674            .alloc_zeros::<f64>(lwork_usize.max(1))
1675            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1676        let mut info_dev = stream
1677            .alloc_zeros::<i32>(1)
1678            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1679        {
1680            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1681            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1682            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1683            // SAFETY: all three pointers refer to live, correctly sized device buffers.
1684            let status = unsafe {
1685                cusolver_sys::cusolverDnDpotrf(
1686                    solver.cu(),
1687                    uplo,
1688                    p_i,
1689                    mat_ptr as *mut f64,
1690                    p_i,
1691                    work_ptr as *mut f64,
1692                    lwork,
1693                    info_ptr as *mut i32,
1694                )
1695            };
1696            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1697                return Err(ArrowSchurGpuFailure::Unavailable);
1698            }
1699        }
1700        let info_host = stream
1701            .clone_dtoh(&info_dev)
1702            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1703        Ok(info_host[0])
1704    }
1705
1706    /// In-place lower-triangular solves `X_i ← L_i^{-1} X_i` over the n stacked
1707    /// d×nrhs RHS tiles in `rhs`. Uses `cublasDtrsmBatched` so all n solves
1708    /// hit the device in one launch.
1709    fn trsm_batched_lower_inplace(
1710        blas: &CudaBlas,
1711        stream: &Arc<CudaStream>,
1712        d: usize,
1713        n: usize,
1714        nrhs: usize,
1715        l_stack: &CudaSlice<f64>,
1716        rhs_stack: &mut CudaSlice<f64>,
1717    ) -> Result<(), ArrowSchurGpuFailure> {
1718        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1719    }
1720
1721    /// As above but with `L_i^T` instead of `L_i`.
1722    fn trsm_batched_lower_inplace_transposed(
1723        blas: &CudaBlas,
1724        stream: &Arc<CudaStream>,
1725        d: usize,
1726        n: usize,
1727        nrhs: usize,
1728        l_stack: &CudaSlice<f64>,
1729        rhs_stack: &mut CudaSlice<f64>,
1730    ) -> Result<(), ArrowSchurGpuFailure> {
1731        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1732    }
1733
1734    fn trsm_batched_inplace_inner(
1735        blas: &CudaBlas,
1736        stream: &Arc<CudaStream>,
1737        d: usize,
1738        n: usize,
1739        nrhs: usize,
1740        l_stack: &CudaSlice<f64>,
1741        rhs_stack: &mut CudaSlice<f64>,
1742        transposed: bool,
1743    ) -> Result<(), ArrowSchurGpuFailure> {
1744        let alpha = 1.0_f64;
1745        let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1746        let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1747        let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1748        let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1749        let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1750        let (l_base, _l_record) = l_stack.device_ptr(stream);
1751        let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1752        let mut l_ptrs = Vec::with_capacity(n);
1753        let mut rhs_ptrs = Vec::with_capacity(n);
1754        for i in 0..n {
1755            l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1756            rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1757        }
1758        let mut l_ptrs_dev = stream
1759            .clone_htod(&l_ptrs)
1760            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1761        let mut rhs_ptrs_dev = stream
1762            .clone_htod(&rhs_ptrs)
1763            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1764        let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1765        let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1766        let op = if transposed {
1767            cublasOperation_t::CUBLAS_OP_T
1768        } else {
1769            cublasOperation_t::CUBLAS_OP_N
1770        };
1771        let handle = *blas.handle();
1772        // SAFETY: pointer arrays and base buffers were just constructed from
1773        // live device allocations covering the entire batch.
1774        let status = unsafe {
1775            cudarc::cublas::sys::cublasDtrsmBatched(
1776                handle,
1777                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1778                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1779                op,
1780                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1781                d_i,
1782                nrhs_i,
1783                &alpha,
1784                l_ptrs_ptr as *const *const f64,
1785                d_i,
1786                rhs_ptrs_ptr as *const *mut f64,
1787                d_i,
1788                batch_i,
1789            )
1790        };
1791        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1792            return Err(ArrowSchurGpuFailure::Unavailable);
1793        }
1794        Ok(())
1795    }
1796
1797    /// Single-matrix lower-triangular solve: `rhs ← L^{-1} rhs` (or
1798    /// `L^{-T} rhs` if `transposed`). For the Schur Cholesky back-sub.
1799    fn trsm_single(
1800        blas: &CudaBlas,
1801        stream: &Arc<CudaStream>,
1802        n: usize,
1803        l: &CudaSlice<f64>,
1804        rhs: &mut CudaSlice<f64>,
1805        upper: bool,
1806        transposed: bool,
1807    ) -> Result<(), ArrowSchurGpuFailure> {
1808        let alpha = 1.0_f64;
1809        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1810        let handle = *blas.handle();
1811        let (l_ptr, _l_rec) = l.device_ptr(stream);
1812        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1813        // SAFETY: single n×n lower factor and n-vector RHS on device.
1814        let status = unsafe {
1815            cudarc::cublas::sys::cublasDtrsm_v2(
1816                handle,
1817                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1818                if upper {
1819                    cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1820                } else {
1821                    cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1822                },
1823                if transposed {
1824                    cublasOperation_t::CUBLAS_OP_T
1825                } else {
1826                    cublasOperation_t::CUBLAS_OP_N
1827                },
1828                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1829                n_i,
1830                1,
1831                &alpha,
1832                l_ptr as *const f64,
1833                n_i,
1834                rhs_ptr as *mut f64,
1835                n_i,
1836            )
1837        };
1838        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1839            return Err(ArrowSchurGpuFailure::Unavailable);
1840        }
1841        Ok(())
1842    }
1843
1844    /// Accumulate `schur ← schur − Σ_i Y_i^T Y_i` and `rhs ← rhs + Σ_i Y_i^T u_i`
1845    /// using one GEMM and one GEMV per block. Each call uses beta=1 to chain
1846    /// the accumulation device-side.
1847    fn accumulate_schur(
1848        blas: &CudaBlas,
1849        d: usize,
1850        k: usize,
1851        n: usize,
1852        y_stack: &CudaSlice<f64>,
1853        u_stack: &CudaSlice<f64>,
1854        schur: &mut CudaSlice<f64>,
1855        rhs: &mut CudaSlice<f64>,
1856    ) -> Result<(), ArrowSchurGpuFailure> {
1857        let y_block_elems = d * k;
1858        let u_block_elems = d;
1859        for i in 0..n {
1860            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1861            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1862            // GEMM: schur += (-1) · Y_i^T · Y_i  (Y_i is d×k col-major; out is k×k)
1863            let gemm_cfg = GemmConfig::<f64> {
1864                transa: cublasOperation_t::CUBLAS_OP_T,
1865                transb: cublasOperation_t::CUBLAS_OP_N,
1866                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1867                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1868                k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1869                alpha: -1.0,
1870                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1871                ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1872                beta: 1.0,
1873                ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1874            };
1875            // SAFETY: y_slice is d×k col-major, schur is k×k col-major; alpha/beta scalars set above.
1876            unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1877                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1878            // GEMV: rhs += 1 · Y_i^T · u_i
1879            let gemv_cfg = GemvConfig::<f64> {
1880                trans: cublasOperation_t::CUBLAS_OP_T,
1881                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1882                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1883                alpha: 1.0,
1884                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1885                incx: 1,
1886                beta: 1.0,
1887                incy: 1,
1888            };
1889            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1890            // device buffers; `rhs` is the length-k accumulator.
1891            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1892                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1893        }
1894        Ok(())
1895    }
1896
1897    /// `#1017` resident gradient path: accumulate ONLY the Schur RHS term
1898    /// `rhs += Σ_i Y_iᵀ u_i`, skipping the `−Σ_i Y_iᵀ Y_i` matrix GEMM that the
1899    /// resident frame already folded into its persistent `L_S` factor. This is
1900    /// the per-iterate-cheap counterpart of [`accumulate_schur`]: the GEMV here
1901    /// is bit-identical to the GEMV inside `accumulate_schur` (same config, same
1902    /// `beta=1` accumulation order over rows), so the resident frame's `δβ`
1903    /// matches a full `solve()` at the same gradient.
1904    fn accumulate_schur_rhs_only(
1905        blas: &CudaBlas,
1906        d: usize,
1907        k: usize,
1908        n: usize,
1909        y_stack: &CudaSlice<f64>,
1910        u_stack: &CudaSlice<f64>,
1911        rhs: &mut CudaSlice<f64>,
1912    ) -> Result<(), ArrowSchurGpuFailure> {
1913        let y_block_elems = d * k;
1914        let u_block_elems = d;
1915        for i in 0..n {
1916            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1917            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1918            let gemv_cfg = GemvConfig::<f64> {
1919                trans: cublasOperation_t::CUBLAS_OP_T,
1920                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1921                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1922                alpha: 1.0,
1923                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1924                incx: 1,
1925                beta: 1.0,
1926                incy: 1,
1927            };
1928            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1929            // device buffers; `rhs` is the length-k accumulator.
1930            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1931                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1932        }
1933        Ok(())
1934    }
1935
1936    /// Accumulate `g_dev[i] ← u_i + Y_i · δβ` per block. This is the
1937    /// pre-trsm RHS for the back-substitution `L_i^T x_i = w_i`.
1938    fn accumulate_back_sub_rhs(
1939        blas: &CudaBlas,
1940        d: usize,
1941        k: usize,
1942        n: usize,
1943        y_stack: &CudaSlice<f64>,
1944        delta_beta: &CudaSlice<f64>,
1945        u_stack: &mut CudaSlice<f64>,
1946    ) -> Result<(), ArrowSchurGpuFailure> {
1947        let y_block_elems = d * k;
1948        let u_block_elems = d;
1949        for i in 0..n {
1950            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1951            let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1952            let gemv_cfg = GemvConfig::<f64> {
1953                trans: cublasOperation_t::CUBLAS_OP_N,
1954                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1955                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1956                alpha: 1.0,
1957                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1958                incx: 1,
1959                beta: 1.0,
1960                incy: 1,
1961            };
1962            // SAFETY: y_slice / delta_beta / u_slice are live device buffers
1963            // of the expected sizes (d×k, k, d).
1964            unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1965                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1966        }
1967        Ok(())
1968    }
1969
1970    // ────────────────────────────────────────────────────────────────────
1971    // Layer D + E — fused NVRTC dispatch.
1972    //
1973    // The forward kernel (`arrow_schur_forward_pgroup`) is a single launch
1974    // that, per row block, factors `D_i + ρI = L_i L_iᵀ` in shared memory,
1975    // forward-solves `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i`, and emits the
1976    // per-block Schur partials `partial_S[i] = Yᵀ Y` (R×R) and
1977    // `partial_r[i] = Yᵀ u` (R). The host reduces partials on the CPU after
1978    // dtoh (one fused sum across `n` blocks of R²+R doubles; cheap because
1979    // n·R² ≲ 5M doubles at large scale), assembles `S_β`, factors it via
1980    // cuSOLVER, and launches the back-substitution kernel
1981    // `arrow_schur_back_sub_pgroup` to recover `δt_i = -L_i⁻ᵀ(u_i + Y_i δβ)`
1982    // without re-uploading the local factors.
1983    // ────────────────────────────────────────────────────────────────────
1984
1985    use std::collections::HashMap;
1986    use std::sync::Mutex;
1987
    /// Process-wide cache of loaded NVRTC modules, keyed by
    /// `FusedModuleCacheKey`.
    ///
    /// The key is `(cc_major, cc_minor, p_max, r_template)`:
    /// - `cc_*` lets one process drive multiple device generations.
    /// - `(p_max, r_template)` selects the shared-memory layout baked into the
    ///   kernel source.
    ///
    /// NOTE(review): each cached `CudaModule` was loaded into one specific
    /// `CudaContext` (see `fused_module_for`), but this map is keyed only by
    /// `FusedModuleCacheKey`. If that key does not identify the device (for
    /// example, two GPUs with the same compute capability), a lookup made for
    /// one device can return a module loaded into another device's context.
    /// Confirm the key distinguishes devices before sharing this cache across
    /// ordinals.
    struct FusedModuleCache {
        modules: Mutex<
            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
        >,
    }
1997
1998    fn fused_module_cache() -> &'static FusedModuleCache {
1999        static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
2000        CACHE.get_or_init(|| FusedModuleCache {
2001            modules: Mutex::new(HashMap::new()),
2002        })
2003    }
2004
2005    fn fused_module_for(
2006        ctx: &Arc<CudaContext>,
2007        key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2008    ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2009        let cache = fused_module_cache();
2010        if let Ok(guard) = cache.modules.lock() {
2011            if let Some(existing) = guard.get(&key) {
2012                return Ok(existing.clone());
2013            }
2014        }
2015        let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2016            key.p_max as usize,
2017            key.r_template as usize,
2018        );
2019        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).map_err(|err| {
2020            ArrowSchurGpuFailure::SchurFactorFailed {
2021                reason: format!(
2022                    "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2023                    key.p_max, key.r_template
2024                ),
2025            }
2026        })?;
2027        let module = ctx
2028            .load_module(ptx)
2029            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2030        if let Ok(mut guard) = cache.modules.lock() {
2031            guard.entry(key).or_insert_with(|| module.clone());
2032        }
2033        Ok(module)
2034    }
2035
    /// CUDA C source for the device SAE PCG kernels. It is compiled at runtime by
    /// `pcg_vector_module` under the cache name `"arrow_pcg_vector"`.
    ///
    /// Kernel groups:
    /// * `arrow_pcg_jacobi_mul` / `arrow_pcg_update_p`: element-wise PCG vector
    ///   updates, `z[i] = inv_diag[i] * r[i]` and `p[i] = z[i] + beta * p[i]`.
    /// * `arrow_sae_init`, `arrow_sae_smooth_matvec`, `arrow_sae_sparse_g_matvec`:
    ///   β-space penalty matvec. `out = ridge * x` is followed by accumulation of
    ///   the dense smooth-factor and sparse-G block products into `out`.
    /// * `arrow_sae_gather_u`, `arrow_sae_apply_l`, `arrow_sae_apply_ainv`,
    ///   `arrow_sae_scatter_sub`: per-row reduced-Schur subtraction:
    ///   - `u[row, oc] = Σ_e phi[e] * x[beta_base[e] + oc]`
    ///   - `w = jac · u`
    ///   - `v = ainv · w`
    ///   - atomically `out[beta_base[e] + oc] -= phi[e] * (jacᵀ v)[oc]`
    /// * `arrow_sae_diag_sub`: the matching per-row subtraction from a diagonal.
    /// * `arrow_sae_frame_*`: the frames-engaged variants that use a dense
    ///   per-row cross block `H_tβ`.
    ///
    /// Every index is computed in 32-bit `int`, so host-side flattened sizes must
    /// stay within `i32` range. The `double` `atomicAdd` calls require sm_60+.
    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri;       // basis row in atom i
    int a = linear - li * ri;   // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 *   h_i   = H_tβ^(i) · x                (length q_i)
 *   s_i   = (H_tt^(i)+ρ_t I)⁻¹ h_i      (apply cached ainv, length q_i)
 *   out  -= (H_tβ^(i))ᵀ · s_i           (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2468
2469    fn pcg_vector_module(
2470        ctx: &Arc<CudaContext>,
2471    ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2472        static CACHE: gam_gpu::device_cache::PtxModuleCache =
2473            gam_gpu::device_cache::PtxModuleCache::new();
2474        CACHE
2475            .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2476            .map_err(|err| {
2477                // #1551: an NVRTC compile / module-load failure of
2478                // PCG_VECTOR_KERNEL_SOURCE means the device SAE PCG cannot run;
2479                // log it (the historical silent collapse to `Unavailable` is what
2480                // masked the missing `--gpu-architecture` for so long) and fall
2481                // back to the CPU.
2482                log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2483                ArrowSchurGpuFailure::Unavailable
2484            })
2485    }
2486
2487    fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2488        let threads = 256u32;
2489        let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2490        Ok(LaunchConfig {
2491            grid_dim: (blocks, 1, 1),
2492            block_dim: (threads, 1, 1),
2493            shared_mem_bytes: 0,
2494        })
2495    }
2496
2497    fn launch_jacobi_mul(
2498        stream: &Arc<CudaStream>,
2499        module: &Arc<CudaModule>,
2500        inv_diag: &CudaSlice<f64>,
2501        r: &CudaSlice<f64>,
2502        z: &mut CudaSlice<f64>,
2503        n: usize,
2504    ) -> Result<(), ArrowSchurGpuFailure> {
2505        let kernel = module
2506            .load_function("arrow_pcg_jacobi_mul")
2507            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2508        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2509        let mut builder = stream.launch_builder(&kernel);
2510        builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2511        // SAFETY: all buffers have length n and belong to `stream`; the kernel only
2512        // reads/writes indices `< n`.
2513        unsafe { builder.launch(pcg_launch_config(n)?) }
2514            .map(drop)
2515            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2516    }
2517
2518    fn launch_update_p(
2519        stream: &Arc<CudaStream>,
2520        module: &Arc<CudaModule>,
2521        z: &CudaSlice<f64>,
2522        beta: f64,
2523        p: &mut CudaSlice<f64>,
2524        n: usize,
2525    ) -> Result<(), ArrowSchurGpuFailure> {
2526        let kernel = module
2527            .load_function("arrow_pcg_update_p")
2528            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2529        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2530        let mut builder = stream.launch_builder(&kernel);
2531        builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2532        // SAFETY: z/p both have length n and belong to `stream`; the kernel only
2533        // reads/writes indices `< n`.
2534        unsafe { builder.launch(pcg_launch_config(n)?) }
2535            .map(drop)
2536            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2537    }
2538
    /// Device-resident SAE PCG operator data, as built by `flatten_device_sae_data`.
    ///
    /// Variable-length per-row and per-block data is stored CSR-style: a prefix
    /// offset array (`row_ptr`, `jac_ptr`, `smooth_ptr`, `g_ptr`) plus a flat
    /// payload buffer. Offsets and sizes are `i32` because the kernels index in
    /// 32-bit `int`.
    struct DeviceSaePcgBuffers {
        /// Prefix offsets into `beta_base`/`phi`; length `n_rows + 1`.
        row_ptr: CudaSlice<i32>,
        /// β-space base offset of each `a_phi` entry.
        beta_base: CudaSlice<i32>,
        /// Coefficient of each `a_phi` entry.
        phi: CudaSlice<f64>,
        /// Prefix offsets into `jac`; length `n_rows + 1`. Row `i` spans `q_i * p` values.
        jac_ptr: CudaSlice<i32>,
        /// Concatenated per-row local Jacobians.
        jac: CudaSlice<f64>,
        /// `global_offset` of each smooth block (element offset, used unscaled by the kernel).
        smooth_offsets: CudaSlice<i32>,
        /// Side length `m` of each (square) smooth factor.
        smooth_m: CudaSlice<i32>,
        /// Prefix offsets into `smooth_data`; length `smooth_blocks + 1`.
        smooth_ptr: CudaSlice<i32>,
        /// Row-major smooth factors, concatenated.
        smooth_data: CudaSlice<f64>,
        /// `row_off` of each sparse G block (coefficient-row index; the kernel scales it by `p`).
        g_row_off: CudaSlice<i32>,
        /// `col_off` of each sparse G block (coefficient-row index; the kernel scales it by `p`).
        g_col_off: CudaSlice<i32>,
        /// Row count of each sparse G block.
        g_rows: CudaSlice<i32>,
        /// Column count of each sparse G block.
        g_cols: CudaSlice<i32>,
        /// Prefix offsets into `g_data`; length `g_blocks + 1`.
        g_ptr: CudaSlice<i32>,
        /// Row-major sparse G block payloads, concatenated.
        g_data: CudaSlice<f64>,
        /// Per-row `(H_tt + ridge_t I)^{-1}`, each zero-padded to `max_q × max_q`.
        ainv: CudaSlice<f64>,
        /// Zero-initialised, `n_rows * p` values (presumably the gather_u output; verify at launch sites).
        u: CudaSlice<f64>,
        /// Zero-initialised, `n_rows * max_q` values (presumably the apply_l output).
        w: CudaSlice<f64>,
        /// Zero-initialised, `n_rows * max_q` values (presumably the apply_ainv output).
        v: CudaSlice<f64>,
        /// Number of rows (`sys.rows.len()`).
        n_rows: usize,
        /// Channels per β coefficient row (`data.p`).
        p: usize,
        /// β-space dimension (`data.beta_dim`).
        k: usize,
        /// Largest per-row local width `q_i`.
        max_q: usize,
        /// Number of smooth blocks.
        smooth_blocks: usize,
        /// Number of sparse G blocks.
        g_blocks: usize,
    }
2566
2567    fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2568        to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2569    }
2570
2571    fn sae_penalty_diag_host(
2572        data: &DeviceSaePcgData,
2573        ridge_beta: f64,
2574    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2575        let mut diag = vec![ridge_beta; data.beta_dim];
2576        for block in &data.smooth_blocks {
2577            let (rows, cols) = block.factor_a.dim();
2578            if rows != cols {
2579                return Err(ArrowSchurGpuFailure::Unavailable);
2580            }
2581            for row in 0..rows {
2582                let coeff = block.factor_a[[row, row]];
2583                let base = block
2584                    .global_offset
2585                    .checked_add(
2586                        row.checked_mul(data.p)
2587                            .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2588                    )
2589                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2590                let end = base
2591                    .checked_add(data.p)
2592                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2593                if end > diag.len() {
2594                    return Err(ArrowSchurGpuFailure::Unavailable);
2595                }
2596                for channel in 0..data.p {
2597                    diag[base + channel] += coeff;
2598                }
2599            }
2600        }
2601        for block in &data.sparse_g_blocks {
2602            if block.row_off != block.col_off {
2603                continue;
2604            }
2605            let (rows, cols) = block.data.dim();
2606            for row in 0..rows.min(cols) {
2607                let coeff = block.data[[row, row]];
2608                let beta_row = block
2609                    .row_off
2610                    .checked_add(row)
2611                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2612                let base = beta_row
2613                    .checked_mul(data.p)
2614                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2615                let end = base
2616                    .checked_add(data.p)
2617                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2618                if end > diag.len() {
2619                    return Err(ArrowSchurGpuFailure::Unavailable);
2620                }
2621                for channel in 0..data.p {
2622                    diag[base + channel] += coeff;
2623                }
2624            }
2625        }
2626        Ok(diag)
2627    }
2628
2629    fn flatten_device_sae_data(
2630        sys: &ArrowSchurSystem,
2631        data: &DeviceSaePcgData,
2632        ridge_t: f64,
2633        stream: &Arc<CudaStream>,
2634    ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2635        let n_rows = sys.rows.len();
2636        let p = data.p;
2637        let k = data.beta_dim;
2638        if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2639            return Err(ArrowSchurGpuFailure::Unavailable);
2640        }
2641
2642        let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2643        let mut beta_base_host = Vec::<i32>::new();
2644        let mut phi_host = Vec::<f64>::new();
2645        row_ptr_host.push(0_i32);
2646        for row in data.a_phi.iter() {
2647            for &(base, phi) in row {
2648                beta_base_host.push(checked_i32(base)?);
2649                phi_host.push(phi);
2650            }
2651            row_ptr_host.push(checked_i32(beta_base_host.len())?);
2652        }
2653
2654        let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2655        let mut jac_host = Vec::<f64>::new();
2656        let mut max_q = 0usize;
2657        jac_ptr_host.push(0_i32);
2658        for row_jac in data.local_jac.iter() {
2659            if row_jac.len() % p != 0 {
2660                return Err(ArrowSchurGpuFailure::Unavailable);
2661            }
2662            max_q = max_q.max(row_jac.len() / p);
2663            jac_host.extend_from_slice(row_jac);
2664            jac_ptr_host.push(checked_i32(jac_host.len())?);
2665        }
2666        if max_q == 0 {
2667            return Err(ArrowSchurGpuFailure::Unavailable);
2668        }
2669
2670        let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2671        let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2672        let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2673        let mut smooth_data_host = Vec::<f64>::new();
2674        smooth_ptr_host.push(0_i32);
2675        for block in &data.smooth_blocks {
2676            let (rows, cols) = block.factor_a.dim();
2677            if rows != cols {
2678                return Err(ArrowSchurGpuFailure::Unavailable);
2679            }
2680            smooth_offsets_host.push(checked_i32(block.global_offset)?);
2681            smooth_m_host.push(checked_i32(rows)?);
2682            for r in 0..rows {
2683                for c in 0..cols {
2684                    smooth_data_host.push(block.factor_a[[r, c]]);
2685                }
2686            }
2687            smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2688        }
2689
2690        let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2691        let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2692        let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2693        let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2694        let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2695        let mut g_data_host = Vec::<f64>::new();
2696        g_ptr_host.push(0_i32);
2697        for block in &data.sparse_g_blocks {
2698            let (rows, cols) = block.data.dim();
2699            g_row_off_host.push(checked_i32(block.row_off)?);
2700            g_col_off_host.push(checked_i32(block.col_off)?);
2701            g_rows_host.push(checked_i32(rows)?);
2702            g_cols_host.push(checked_i32(cols)?);
2703            for r in 0..rows {
2704                for c in 0..cols {
2705                    g_data_host.push(block.data[[r, c]]);
2706                }
2707            }
2708            g_ptr_host.push(checked_i32(g_data_host.len())?);
2709        }
2710
2711        let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2712        for (row_idx, row) in sys.rows.iter().enumerate() {
2713            let q = data.local_jac[row_idx].len() / p;
2714            if row.htt.dim() != (q, q) {
2715                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2716                    reason: format!(
2717                        "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2718                        row.htt.dim()
2719                    ),
2720                });
2721            }
2722            let mut block = row.htt.clone();
2723            for d in 0..q {
2724                block[[d, d]] += ridge_t;
2725            }
2726            let factor = gam_linalg::triangular::cholesky_factor_in_place(
2727                block.view(),
2728                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2729            )
2730            .ok_or_else(|| {
2731                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
2732                // indefinite per-row block recovers in one outer-loop retry.
2733                ArrowSchurGpuFailure::RidgeBumpRequired {
2734                    row: row_idx,
2735                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2736                }
2737            })?;
2738            for col in 0..q {
2739                let mut e = Array1::<f64>::zeros(q);
2740                e[col] = 1.0;
2741                let solved =
2742                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2743                for r in 0..q {
2744                    ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2745                }
2746            }
2747        }
2748
2749        Ok(DeviceSaePcgBuffers {
2750            row_ptr: stream
2751                .clone_htod(&row_ptr_host)
2752                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2753            beta_base: stream
2754                .clone_htod(&beta_base_host)
2755                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2756            phi: stream
2757                .clone_htod(&phi_host)
2758                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2759            jac_ptr: stream
2760                .clone_htod(&jac_ptr_host)
2761                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2762            jac: stream
2763                .clone_htod(&jac_host)
2764                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2765            smooth_offsets: stream
2766                .clone_htod(&smooth_offsets_host)
2767                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2768            smooth_m: stream
2769                .clone_htod(&smooth_m_host)
2770                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2771            smooth_ptr: stream
2772                .clone_htod(&smooth_ptr_host)
2773                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2774            smooth_data: stream
2775                .clone_htod(&smooth_data_host)
2776                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2777            g_row_off: stream
2778                .clone_htod(&g_row_off_host)
2779                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2780            g_col_off: stream
2781                .clone_htod(&g_col_off_host)
2782                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2783            g_rows: stream
2784                .clone_htod(&g_rows_host)
2785                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2786            g_cols: stream
2787                .clone_htod(&g_cols_host)
2788                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2789            g_ptr: stream
2790                .clone_htod(&g_ptr_host)
2791                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2792            g_data: stream
2793                .clone_htod(&g_data_host)
2794                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795            ainv: stream
2796                .clone_htod(&ainv_host)
2797                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798            u: stream
2799                .alloc_zeros::<f64>(n_rows * p)
2800                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801            w: stream
2802                .alloc_zeros::<f64>(n_rows * max_q)
2803                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804            v: stream
2805                .alloc_zeros::<f64>(n_rows * max_q)
2806                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2807            n_rows,
2808            p,
2809            k,
2810            max_q,
2811            smooth_blocks: data.smooth_blocks.len(),
2812            g_blocks: data.sparse_g_blocks.len(),
2813        })
2814    }
2815
2816    fn launch_sae_init(
2817        stream: &Arc<CudaStream>,
2818        module: &Arc<CudaModule>,
2819        out: &mut CudaSlice<f64>,
2820        x: &CudaSlice<f64>,
2821        ridge: f64,
2822        n: usize,
2823    ) -> Result<(), ArrowSchurGpuFailure> {
2824        let kernel = module
2825            .load_function("arrow_sae_init")
2826            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2827        let n_i32 = checked_i32(n)?;
2828        let mut builder = stream.launch_builder(&kernel);
2829        builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2830        // SAFETY: `out` and `x` are live device buffers with at least `n`
2831        // entries on `stream`; the kernel writes one in-bounds element per
2832        // launched index below `n`.
2833        unsafe { builder.launch(pcg_launch_config(n)?) }
2834            .map(drop)
2835            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2836    }
2837
2838    fn launch_sae_penalty_matvec(
2839        stream: &Arc<CudaStream>,
2840        module: &Arc<CudaModule>,
2841        buffers: &mut DeviceSaePcgBuffers,
2842        x: &CudaSlice<f64>,
2843        out: &mut CudaSlice<f64>,
2844        ridge_beta: f64,
2845    ) -> Result<(), ArrowSchurGpuFailure> {
2846        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2847        if buffers.smooth_blocks > 0 {
2848            let kernel = module
2849                .load_function("arrow_sae_smooth_matvec")
2850                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2851            let max_m = buffers.k;
2852            let p_i32 = checked_i32(buffers.p)?;
2853            let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2854            let cfg = LaunchConfig {
2855                grid_dim: (
2856                    ((max_m as u32).saturating_add(255) / 256).max(1),
2857                    checked_i32(buffers.smooth_blocks)? as u32,
2858                    1,
2859                ),
2860                block_dim: (256, 1, 1),
2861                shared_mem_bytes: 0,
2862            };
2863            let mut builder = stream.launch_builder(&kernel);
2864            builder
2865                .arg(x)
2866                .arg(&mut *out)
2867                .arg(&buffers.smooth_offsets)
2868                .arg(&buffers.smooth_m)
2869                .arg(&buffers.smooth_ptr)
2870                .arg(&buffers.smooth_data)
2871                .arg(&p_i32)
2872                .arg(&blocks_i32);
2873            // SAFETY: smooth block metadata and dense smooth data were flattened
2874            // into live device buffers; the 2D grid covers only declared block
2875            // and coefficient-channel work items, and the kernel bounds-checks
2876            // against each block's stored size.
2877            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2878        }
2879        if buffers.g_blocks > 0 {
2880            let kernel = module
2881                .load_function("arrow_sae_sparse_g_matvec")
2882                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2883            let max_work = buffers
2884                .k
2885                .checked_div(buffers.p)
2886                .unwrap_or(0)
2887                .saturating_mul(buffers.p);
2888            let p_i32 = checked_i32(buffers.p)?;
2889            let blocks_i32 = checked_i32(buffers.g_blocks)?;
2890            let cfg = LaunchConfig {
2891                grid_dim: (
2892                    ((max_work as u32).saturating_add(255) / 256).max(1),
2893                    checked_i32(buffers.g_blocks)? as u32,
2894                    1,
2895                ),
2896                block_dim: (256, 1, 1),
2897                shared_mem_bytes: 0,
2898            };
2899            let mut builder = stream.launch_builder(&kernel);
2900            builder
2901                .arg(x)
2902                .arg(&mut *out)
2903                .arg(&buffers.g_row_off)
2904                .arg(&buffers.g_col_off)
2905                .arg(&buffers.g_rows)
2906                .arg(&buffers.g_cols)
2907                .arg(&buffers.g_ptr)
2908                .arg(&buffers.g_data)
2909                .arg(&p_i32)
2910                .arg(&blocks_i32);
2911            // SAFETY: sparse G block metadata/data are live device buffers built
2912            // from host CSR-like block descriptors; the launch dimensions cover
2913            // declared block work only and the kernel checks row/column bounds
2914            // before reading or accumulating.
2915            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2916        }
2917        Ok(())
2918    }
2919
2920    fn launch_sae_row_schur_sub(
2921        stream: &Arc<CudaStream>,
2922        module: &Arc<CudaModule>,
2923        buffers: &mut DeviceSaePcgBuffers,
2924        x: &CudaSlice<f64>,
2925        out: &mut CudaSlice<f64>,
2926    ) -> Result<(), ArrowSchurGpuFailure> {
2927        let p_i32 = checked_i32(buffers.p)?;
2928        let max_q_i32 = checked_i32(buffers.max_q)?;
2929        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2930        let cfg_p_rows = LaunchConfig {
2931            grid_dim: (
2932                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2933                checked_i32(buffers.n_rows)? as u32,
2934                1,
2935            ),
2936            block_dim: (256, 1, 1),
2937            shared_mem_bytes: 0,
2938        };
2939        let gather = module
2940            .load_function("arrow_sae_gather_u")
2941            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2942        {
2943            let mut builder = stream.launch_builder(&gather);
2944            builder
2945                .arg(x)
2946                .arg(&buffers.row_ptr)
2947                .arg(&buffers.beta_base)
2948                .arg(&buffers.phi)
2949                .arg(&mut buffers.u)
2950                .arg(&p_i32)
2951                .arg(&n_rows_i32);
2952            // SAFETY: `x`, row pointers, beta offsets, basis rows, and `u` are
2953            // live device buffers sized for `n_rows` by `p`; the kernel guards
2954            // row/channel indices before gathering.
2955            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2956        }
2957
2958        let cfg_q_rows = LaunchConfig {
2959            grid_dim: (
2960                ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2961                checked_i32(buffers.n_rows)? as u32,
2962                1,
2963            ),
2964            block_dim: (256, 1, 1),
2965            shared_mem_bytes: 0,
2966        };
2967        let apply_l = module
2968            .load_function("arrow_sae_apply_l")
2969            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2970        {
2971            let mut builder = stream.launch_builder(&apply_l);
2972            builder
2973                .arg(&buffers.u)
2974                .arg(&buffers.jac_ptr)
2975                .arg(&buffers.jac)
2976                .arg(&mut buffers.w)
2977                .arg(&p_i32)
2978                .arg(&max_q_i32)
2979                .arg(&n_rows_i32);
2980            // SAFETY: `u`, Jacobian row pointers/data, and `w` are live buffers
2981            // sized for the `(n_rows, p)` to `(n_rows, max_q)` multiply; the
2982            // kernel checks row and local-coordinate bounds.
2983            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2984        }
2985
2986        let apply_ainv = module
2987            .load_function("arrow_sae_apply_ainv")
2988            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2989        {
2990            let mut builder = stream.launch_builder(&apply_ainv);
2991            builder
2992                .arg(&buffers.ainv)
2993                .arg(&buffers.w)
2994                .arg(&mut buffers.v)
2995                .arg(&max_q_i32)
2996                .arg(&n_rows_i32);
2997            // SAFETY: `ainv`, `w`, and `v` are live device buffers sized for
2998            // `n_rows * max_q`; the kernel guards all row/local-coordinate
2999            // indices before reading or writing.
3000            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3001        }
3002
3003        let scatter = module
3004            .load_function("arrow_sae_scatter_sub")
3005            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3006        {
3007            let mut builder = stream.launch_builder(&scatter);
3008            builder
3009                .arg(&buffers.v)
3010                .arg(&buffers.jac_ptr)
3011                .arg(&buffers.jac)
3012                .arg(&buffers.row_ptr)
3013                .arg(&buffers.beta_base)
3014                .arg(&buffers.phi)
3015                .arg(out)
3016                .arg(&p_i32)
3017                .arg(&max_q_i32)
3018                .arg(&n_rows_i32);
3019            // SAFETY: `v`, Jacobian metadata, row pointers, beta offsets, basis
3020            // rows, and `out` are live buffers for `n_rows` by `p`; scatter
3021            // indices are checked against row and channel bounds in the kernel.
3022            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3023        }
3024        Ok(())
3025    }
3026
3027    fn launch_sae_diag_sub(
3028        stream: &Arc<CudaStream>,
3029        module: &Arc<CudaModule>,
3030        buffers: &DeviceSaePcgBuffers,
3031        diag: &mut CudaSlice<f64>,
3032    ) -> Result<(), ArrowSchurGpuFailure> {
3033        let kernel = module
3034            .load_function("arrow_sae_diag_sub")
3035            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3036        let p_i32 = checked_i32(buffers.p)?;
3037        let max_q_i32 = checked_i32(buffers.max_q)?;
3038        let n_rows_i32 = checked_i32(buffers.n_rows)?;
3039        let cfg = LaunchConfig {
3040            grid_dim: (
3041                ((buffers.p as u32).saturating_add(255) / 256).max(1),
3042                checked_i32(buffers.n_rows)? as u32,
3043                1,
3044            ),
3045            block_dim: (256, 1, 1),
3046            shared_mem_bytes: 0,
3047        };
3048        let mut builder = stream.launch_builder(&kernel);
3049        builder
3050            .arg(diag)
3051            .arg(&buffers.ainv)
3052            .arg(&buffers.jac_ptr)
3053            .arg(&buffers.jac)
3054            .arg(&buffers.row_ptr)
3055            .arg(&buffers.beta_base)
3056            .arg(&buffers.phi)
3057            .arg(&p_i32)
3058            .arg(&max_q_i32)
3059            .arg(&n_rows_i32);
3060        // SAFETY: diagonal output and all read-only SAE row metadata buffers are
3061        // live on `stream` with sizes matching `n_rows`, `p`, and `max_q`; the
3062        // kernel bounds-checks its flattened work index.
3063        unsafe { builder.launch(cfg) }
3064            .map(drop)
3065            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3066    }
3067
3068    fn launch_sae_matvec(
3069        stream: &Arc<CudaStream>,
3070        module: &Arc<CudaModule>,
3071        buffers: &mut DeviceSaePcgBuffers,
3072        x: &CudaSlice<f64>,
3073        out: &mut CudaSlice<f64>,
3074        ridge_beta: f64,
3075    ) -> Result<(), ArrowSchurGpuFailure> {
3076        launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3077        launch_sae_row_schur_sub(stream, module, buffers, x, out)
3078    }
3079
3080    /// Pack `D + ρ_t I`, `B`, and `g` into the strided `(n × P_MAX × P_MAX)`
3081    /// / `(n × P_MAX × R_TEMPLATE)` / `(n × P_MAX)` layout the fused kernel
3082    /// expects. Entries outside the runtime `(p, r)` window stay at zero so
3083    /// the kernel's per-element loops are safe to no-op there.
3084    fn pack_fused_host(
3085        sys: &ArrowSchurSystem,
3086        ridge_t: f64,
3087        p_max: usize,
3088        r_template: usize,
3089    ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3090        let n = sys.rows.len();
3091        let d = sys.d;
3092        let k = sys.k;
3093        let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3094        let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3095        let mut g_buf = vec![0.0_f64; n * p_max];
3096        for (i, row) in sys.rows.iter().enumerate() {
3097            // D_i + ρI, column-major in P_MAX×P_MAX strided block.
3098            for col in 0..d {
3099                let base = (i * p_max + col) * p_max;
3100                for r in 0..d {
3101                    let mut value = row.htt[[r, col]];
3102                    if r == col {
3103                        value += ridge_t;
3104                    }
3105                    d_buf[base + r] = value;
3106                }
3107            }
3108            // B_i in P_MAX×R_TEMPLATE strided block. The per-row (per-i) block
3109            // stride is `p_max · r_template` (matching the `b_buf` allocation
3110            // above and the kernel's `b_stack`/`y_out` layout), NOT
3111            // `p_max · p_max`: using the D-block multiplier here overflows the
3112            // buffer whenever `p_max > r_template` (e.g. d=30→p_max=32,
3113            // k=5→r_template=5). The within-block element offset stays
3114            // column-major `col·p_max + r` (P_MAX rows per column).
3115            for col in 0..k {
3116                let base = (i * r_template + col) * p_max;
3117                for r in 0..d {
3118                    b_buf[base + r] = row.htbeta[[r, col]];
3119                }
3120            }
3121            // g_i in P_MAX strided vector.
3122            let g_base = i * p_max;
3123            for r in 0..d {
3124                g_buf[g_base + r] = row.gt[r];
3125            }
3126        }
3127        (d_buf, b_buf, g_buf)
3128    }
3129
3130    // -----------------------------------------------------------------------
3131    // #1017 Phase 3: across-iteration device residency.
3132    //
3133    // `solve()` re-packs and re-uploads `D` (`H_tt`), `B` (`H_tβ`) and `g`,
3134    // then re-runs the per-row POTRF and the border Schur factorization on
3135    // EVERY call. For the SAE joint inner Newton at a frozen gate/basis frame
3136    // the Hessian blocks `D`, `B`, `H_ββ` are CONSTANT across the inner loop —
3137    // only the gradient `g = r(z) = H z − g₀` changes per iterate. So the
3138    // factor work (`O(n·d³ + p³)`) and the dominant `O(n·d·p)` cross-block
3139    // upload are pure waste when repeated per iterate.
3140    //
3141    // `ResidentArrowFrame` performs that constant work ONCE at construction:
3142    // upload+ridge+POTRF of `D` (keeping `L_i` resident in `l_dev`), the
3143    // forward solve `Y_i = L_i^{-1} B_i` (kept resident in `y_dev`), and the
3144    // Schur assembly + border POTRF (keeping `L_S` resident in `schur_dev`).
3145    // Each subsequent `solve_gradient(g)` uploads only the `n·d` row gradient,
3146    // runs the cheap residual path — `u_i = L_i^{-1} g_i` (one batched TRSM),
3147    // Schur RHS `−g_β + Σ Y_iᵀ u_i`, `δβ = L_S^{-T} L_S^{-1} rhs` (two TRSM,
3148    // NO POTRF), back-sub `δt_i = −L_i^{-T}(u_i + Y_i δβ)` — and reads back only
3149    // `δ` and the cached log|H|. The heavy buffers never leave the device
3150    // across iterations; the per-iterate host transfer is `O(n·d + p)`, not
3151    // `O(n·d·p)`. Numerics are bit-identical to a `solve()` at the same
3152    // `(D, B, H_ββ, g, ridge_t, ridge_beta)` because the factor buffers and the
3153    // helper kernels are the same; the resident path merely SKIPS re-deriving
3154    // the parts that do not depend on `g`. The CPU dense reference
3155    // (`solve_arrow_newton_step_dense_reference`) is the parity oracle.
3156    pub(super) struct ResidentArrowFrame {
3157        n: usize,
3158        d: usize,
3159        k: usize,
3160        stream: Arc<CudaStream>,
3161        blas: CudaBlas,
3162        /// Per-row lower Cholesky factors `L_i` of `H_tt + ρ_t I`, stacked
3163        /// column-major (`n` tiles of `d×d`). Resident across iterations.
3164        l_dev: CudaSlice<f64>,
3165        /// Whitened cross blocks `Y_i = L_i^{-1} H_tβ^(i)`, stacked column-major
3166        /// (`n` tiles of `d×k`). Resident across iterations.
3167        y_dev: CudaSlice<f64>,
3168        /// Lower Cholesky factor `L_S` of the reduced Schur complement
3169        /// `S_β = H_ββ + ρ_β I − Σ_i Y_iᵀ Y_i`. Resident across iterations.
3170        schur_dev: CudaSlice<f64>,
3171        /// `log|H| = 2 Σ log L_{i,jj} + 2 Σ log L_{S,aa}`, constant for the
3172        /// frame (depends only on the factored Hessian, not on `g`).
3173        log_det_hessian: f64,
3174    }
3175
3176    impl ResidentArrowFrame {
3177        /// Upload the constant Hessian blocks and perform the one-time factor
3178        /// work (`POTRF(D)`, `Y_i = L_i^{-1} B_i`, Schur assembly + border
3179        /// `POTRF`). The frame then serves cheap per-gradient solves.
3180        pub(super) fn new(
3181            sys: &ArrowSchurSystem,
3182            ridge_t: f64,
3183            ridge_beta: f64,
3184        ) -> Result<Self, ArrowSchurGpuFailure> {
3185            if ridge_t.is_nan() || ridge_beta.is_nan() {
3186                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3187                    reason: "ridge is NaN".to_string(),
3188                });
3189            }
3190            let n = sys.rows.len();
3191            let d = sys.d;
3192            let k = sys.k;
3193            let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3194                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3195            let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3196                .and_then(|ctx| ctx.new_stream().ok())
3197                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3198            let solver =
3199                DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3200            let blas =
3201                CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3202
3203            // Upload the constant blocks. `g` is uploaded per-gradient, not here.
3204            let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3205            let mut l_dev = stream
3206                .clone_htod(&d_host)
3207                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3208            let mut y_dev = stream
3209                .clone_htod(&b_host)
3210                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3211
3212            // POTRF(D) → L_i, kept resident in l_dev.
3213            let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3214            if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3215                // cuSOLVER `info` is a 1-based pivot row index; size the bump
3216                // from the block (Gershgorin λ_min bound) so a strongly
3217                // indefinite block recovers in one retry.
3218                return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3219                    row: idx,
3220                    bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3221                });
3222            }
3223
3224            // Y_i = L_i^{-1} B_i, in place over y_dev. Kept resident.
3225            trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3226
3227            // Schur assembly S_β = (H_ββ + ρ_β I) − Σ Y_iᵀ Y_i, then POTRF → L_S.
3228            // The RHS accumulation is folded into the gradient path; here we
3229            // only need the (gradient-independent) Schur factor, so accumulate
3230            // into a throwaway rhs buffer.
3231            let schur_init: Vec<f64> = {
3232                let mut tmp = Vec::with_capacity(k * k);
3233                for col in 0..k {
3234                    for row in 0..k {
3235                        let mut v = sys.hbb[[row, col]];
3236                        if row == col {
3237                            v += ridge_beta;
3238                        }
3239                        tmp.push(v);
3240                    }
3241                }
3242                tmp
3243            };
3244            let mut schur_dev = stream
3245                .clone_htod(&schur_init)
3246                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3247            // A zero u-stack makes `Σ Y_iᵀ u_i = 0`, so only the `−Σ Y_iᵀ Y_i`
3248            // Schur term is accumulated (the rhs is rebuilt per gradient).
3249            let zero_u = stream
3250                .clone_htod(&vec![0.0_f64; n * d])
3251                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3252            let mut throwaway_rhs = stream
3253                .clone_htod(&vec![0.0_f64; k])
3254                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3255            accumulate_schur(
3256                &blas,
3257                d,
3258                k,
3259                n,
3260                &y_dev,
3261                &zero_u,
3262                &mut schur_dev,
3263                &mut throwaway_rhs,
3264            )?;
3265            let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3266            if info != 0 {
3267                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3268                    reason: format!("Schur Cholesky failed at pivot {info}"),
3269                });
3270            }
3271
3272            // log|H| from the resident factors (constant for the frame).
3273            let l_local_host = stream
3274                .clone_dtoh(&l_dev)
3275                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3276            let l_schur_host = stream
3277                .clone_dtoh(&schur_dev)
3278                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3279            let mut log_det = 0.0_f64;
3280            for i in 0..n {
3281                let base = i * d * d;
3282                for j in 0..d {
3283                    log_det += l_local_host[base + j * d + j].ln();
3284                }
3285            }
3286            for j in 0..k {
3287                log_det += l_schur_host[j * k + j].ln();
3288            }
3289            log_det *= 2.0;
3290
3291            Ok(Self {
3292                n,
3293                d,
3294                k,
3295                stream,
3296                blas,
3297                l_dev,
3298                y_dev,
3299                schur_dev,
3300                log_det_hessian: log_det,
3301            })
3302        }
3303
3304        #[inline]
3305        pub(super) fn log_det_hessian(&self) -> f64 {
3306            self.log_det_hessian
3307        }
3308
3309        /// Solve `H δ = −gradient` for a fresh gradient `(g_t, g_β)` reusing the
3310        /// resident factors. Uploads only `g_t` (`n·d` scalars); reads back only
3311        /// `δ`. No POTRF runs here — all factorization is amortized into `new`.
3312        pub(super) fn solve_gradient(
3313            &self,
3314            g_t: &[f64],
3315            g_beta: &[f64],
3316        ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3317            let n = self.n;
3318            let d = self.d;
3319            let k = self.k;
3320            if g_t.len() != n * d || g_beta.len() != k {
3321                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3322                    reason: format!(
3323                        "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3324                        g_t.len(),
3325                        n * d,
3326                        g_beta.len(),
3327                        k
3328                    ),
3329                });
3330            }
3331            // Upload the per-iterate row gradient → u_i = L_i^{-1} g_i in place.
3332            let mut u_dev = self
3333                .stream
3334                .clone_htod(&g_t.to_vec())
3335                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3336            trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3337
3338            // Schur RHS = −g_β + Σ_i Y_iᵀ u_i. Reuse the resident Schur factor
3339            // (no POTRF, and skip the −Σ Y_iᵀ Y_i GEMM already baked into L_S).
3340            let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3341            let mut rhs_dev = self
3342                .stream
3343                .clone_htod(&rhs_init)
3344                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3345            accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3346
3347            // δβ ← L_S^{-T} L_S^{-1} rhs using the resident border factor.
3348            trsm_single(
3349                &self.blas,
3350                &self.stream,
3351                k,
3352                &self.schur_dev,
3353                &mut rhs_dev,
3354                false,
3355                false,
3356            )?;
3357            trsm_single(
3358                &self.blas,
3359                &self.stream,
3360                k,
3361                &self.schur_dev,
3362                &mut rhs_dev,
3363                false,
3364                true,
3365            )?;
3366            let delta_beta_host = self
3367                .stream
3368                .clone_dtoh(&rhs_dev)
3369                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3370            let delta_beta = Array1::from_vec(delta_beta_host);
3371
3372            // Back-sub δt_i = −L_i^{-T}(u_i + Y_i δβ).
3373            accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3374            trsm_batched_lower_inplace_transposed(
3375                &self.blas,
3376                &self.stream,
3377                d,
3378                n,
3379                1,
3380                &self.l_dev,
3381                &mut u_dev,
3382            )?;
3383            let x_host = self
3384                .stream
3385                .clone_dtoh(&u_dev)
3386                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3387            let mut delta_t = Array1::<f64>::zeros(n * d);
3388            for (i, v) in x_host.iter().enumerate() {
3389                delta_t[i] = -*v;
3390            }
3391
3392            Ok(ArrowSchurGpuSolution {
3393                delta_t,
3394                delta_beta,
3395                log_det_hessian: self.log_det_hessian,
3396            })
3397        }
3398    }
3399
3400    pub(super) fn solve_fused(
3401        sys: &ArrowSchurSystem,
3402        ridge_t: f64,
3403        ridge_beta: f64,
3404    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3405        let n = sys.rows.len();
3406        let d = sys.d;
3407        let k = sys.k;
3408        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3409            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3410        let p_max = plan.p_max;
3411        let r_template = plan.r_template;
3412
3413        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3414            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3415        )
3416        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3417        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3418            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3419        let stream = ctx
3420            .new_stream()
3421            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3422        let cap = &runtime.device.capability;
3423        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3424            cc_major: cap.compute_major,
3425            cc_minor: cap.compute_minor,
3426            p_max: p_max as u32,
3427            r_template: r_template as u32,
3428        };
3429        let module = fused_module_for(&ctx, key)?;
3430        let forward = module
3431            .load_function("arrow_schur_forward_pgroup")
3432            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3433        let back_sub = module
3434            .load_function("arrow_schur_back_sub_pgroup")
3435            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3436
3437        // ----- Upload packed D, B, g -----
3438        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3439        let d_dev = stream
3440            .clone_htod(&d_host)
3441            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3442        let b_dev = stream
3443            .clone_htod(&b_host)
3444            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445        let g_dev = stream
3446            .clone_htod(&g_host)
3447            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3448        let mut l_out = stream
3449            .alloc_zeros::<f64>(n * p_max * p_max)
3450            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3451        let mut u_out = stream
3452            .alloc_zeros::<f64>(n * p_max)
3453            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3454        let mut y_out = stream
3455            .alloc_zeros::<f64>(n * p_max * r_template)
3456            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457        let mut partial_s = stream
3458            .alloc_zeros::<f64>(plan.partial_s_doubles)
3459            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3460        let mut partial_r = stream
3461            .alloc_zeros::<f64>(plan.partial_r_doubles)
3462            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463        let mut status_dev = stream
3464            .alloc_zeros::<i32>(n)
3465            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3466
3467        // ----- Launch forward kernel: 1 block per row, P_MAX threads -----
3468        let cfg = LaunchConfig {
3469            grid_dim: (plan.blocks, 1, 1),
3470            block_dim: (plan.threads_per_block, 1, 1),
3471            shared_mem_bytes: 0,
3472        };
3473        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3474        let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3475        let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3476        let ridge_arg = ridge_t;
3477        {
3478            let mut builder = stream.launch_builder(&forward);
3479            builder
3480                .arg(&d_dev)
3481                .arg(&b_dev)
3482                .arg(&g_dev)
3483                .arg(&n_i32)
3484                .arg(&p_i32)
3485                .arg(&r_i32)
3486                .arg(&ridge_arg)
3487                .arg(&mut l_out)
3488                .arg(&mut u_out)
3489                .arg(&mut y_out)
3490                .arg(&mut partial_s)
3491                .arg(&mut partial_r)
3492                .arg(&mut status_dev);
3493            // SAFETY: all buffers were just allocated on `stream` with sizes
3494            // derived from `plan`; kernel parameter list matches the
3495            // FORWARD_KERNEL_SOURCE signature.
3496            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3497        }
3498        stream
3499            .synchronize()
3500            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3501
3502        // ----- Check per-block pivot status -----
3503        let status_host = stream
3504            .clone_dtoh(&status_dev)
3505            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3506        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3507            // The NVRTC kernel's status code is a 1-based pivot row index, not
3508            // a magnitude; size the bump from the block (Gershgorin λ_min
3509            // bound) so a strongly indefinite block recovers in one retry.
3510            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3511                row,
3512                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3513            });
3514        }
3515
3516        // ----- Reduce partials on host into S_β and r_β -----
3517        let partial_s_host = stream
3518            .clone_dtoh(&partial_s)
3519            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3520        let partial_r_host = stream
3521            .clone_dtoh(&partial_r)
3522            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3523        let mut schur_host = vec![0.0_f64; k * k];
3524        for col in 0..k {
3525            for row in 0..k {
3526                let mut v = sys.hbb[[row, col]];
3527                if row == col {
3528                    v += ridge_beta;
3529                }
3530                schur_host[col * k + row] = v;
3531            }
3532        }
3533        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3534        for i in 0..n {
3535            // partial_S[i] stride is R_TEMPLATE × R_TEMPLATE column-major; we
3536            // only read the leading (k × k) sub-block.
3537            let s_base = i * r_template * r_template;
3538            for col in 0..k {
3539                let col_base = s_base + col * r_template;
3540                let dst_col_base = col * k;
3541                for row in 0..k {
3542                    schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3543                }
3544            }
3545            let r_base = i * r_template;
3546            for a in 0..k {
3547                rhs_host[a] += partial_r_host[r_base + a];
3548            }
3549        }
3550
3551        // ----- Factor S_β on device (cuSOLVER), solve for δβ -----
3552        let mut schur_dev = stream
3553            .clone_htod(&schur_host)
3554            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3555        let mut rhs_dev = stream
3556            .clone_htod(&rhs_host)
3557            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3558        let solver =
3559            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3560        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3561        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3562        if info != 0 {
3563            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3564                reason: format!("fused Schur Cholesky failed at pivot {info}"),
3565            });
3566        }
3567        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3568        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3569        let delta_beta_host = stream
3570            .clone_dtoh(&rhs_dev)
3571            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3572        let delta_beta = Array1::from_vec(delta_beta_host.clone());
3573
3574        // ----- Layer E: launch back-sub kernel using persisted L, u, Y -----
3575        let mut delta_t_dev = stream
3576            .alloc_zeros::<f64>(n * p_max)
3577            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3578        let back_cfg = LaunchConfig {
3579            grid_dim: (plan.blocks, 1, 1),
3580            block_dim: (plan.threads_per_block, 1, 1),
3581            shared_mem_bytes: 0,
3582        };
3583        {
3584            let mut builder = stream.launch_builder(&back_sub);
3585            builder
3586                .arg(&l_out)
3587                .arg(&u_out)
3588                .arg(&y_out)
3589                .arg(&rhs_dev)
3590                .arg(&n_i32)
3591                .arg(&p_i32)
3592                .arg(&r_i32)
3593                .arg(&mut delta_t_dev);
3594            // SAFETY: kernel parameter list matches FORWARD_KERNEL_SOURCE
3595            // back-sub signature; `rhs_dev` holds δβ in the leading k entries
3596            // (R_TEMPLATE-strided indexing is column 0..k of the R-vector).
3597            unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3598        }
3599        stream
3600            .synchronize()
3601            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3602
3603        let delta_t_host = stream
3604            .clone_dtoh(&delta_t_dev)
3605            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3606        let mut delta_t = Array1::<f64>::zeros(n * d);
3607        for i in 0..n {
3608            let src_base = i * p_max;
3609            let dst_base = i * d;
3610            for r in 0..d {
3611                delta_t[dst_base + r] = delta_t_host[src_base + r];
3612            }
3613        }
3614
3615        // ----- log|H| = 2·Σ log L_{i,jj} + 2·Σ log R_{β,aa} -----
3616        let l_local_host = stream
3617            .clone_dtoh(&l_out)
3618            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3619        let l_schur_host = stream
3620            .clone_dtoh(&schur_dev)
3621            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3622        let mut log_det = 0.0_f64;
3623        for i in 0..n {
3624            let base = i * p_max * p_max;
3625            for j in 0..d {
3626                log_det += l_local_host[base + j * p_max + j].ln();
3627            }
3628        }
3629        for j in 0..k {
3630            log_det += l_schur_host[j * k + j].ln();
3631        }
3632        log_det *= 2.0;
3633
3634        Ok(ArrowSchurGpuSolution {
3635            delta_t,
3636            delta_beta,
3637            log_det_hessian: log_det,
3638        })
3639    }
3640
    /// Pre-compute `Y_i = L_i^{-1} H_tβ^(i)` via the fused forward kernel and
    /// return a closure that evaluates the full Schur matvec
    /// `S·x = (H_ββ + ρ·I)·x − Σ_i Y_i^T (Y_i·x)` for each PCG iteration.
    ///
    /// The `Y_i` factors are kept in a host-side buffer after one GPU forward
    /// pass. Each matvec call runs O(N·d·K) host loops over the pre-computed
    /// buffer plus an optional `H_ββ·x` call (matrix-free or dense). This is
    /// the first landing of the GPU matvec; a future iteration can move the
    /// `Y_i·x` / `Y_i^T z_i` steps to cuBLAS batched GEMV.
    ///
    /// `ridge_t` is handed to both the host packer (`pack_fused_host`) and the
    /// forward kernel as the local-block ridge; `ridge_beta` is the `ρ` added
    /// on the shared diagonal inside the returned closure.
    ///
    /// `H_ββ·x` comes from `sys.hbb_matvec` when present, otherwise from the
    /// dense `sys.hbb` when it is exactly `k × k`; in any other case `H_ββ`
    /// contributes nothing and only `ρ·x` is kept.
    ///
    /// # Errors
    ///
    /// * `Unavailable` — no fused launch plan for `(n, d, k)`, the dispatch
    ///   policy declines, a dimension does not fit `i32`, or any CUDA
    ///   context / stream / module / upload / launch / download step fails.
    /// * `RidgeBumpRequired` — the forward kernel reported a non-zero status;
    ///   `row` is the first failing block and `bump` is sized from its `H_tt`.
    ///
    /// # Panics
    ///
    /// The returned closure panics if `x` or `out` does not have length `k`.
    pub(super) fn build_schur_matvec_backend(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        // The plan fixes the padded per-block strides (p_max, r_template) and
        // the launch geometry; no plan means the fused kernel cannot host it.
        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_max = plan.p_max;
        let r_template = plan.r_template;

        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
        )
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let stream = ctx
            .new_stream()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // The NVRTC module is cached per (compute capability, p_max, r_template).
        let cap = &runtime.device.capability;
        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
            cc_major: cap.compute_major,
            cc_minor: cap.compute_minor,
            p_max: p_max as u32,
            r_template: r_template as u32,
        };
        let module = fused_module_for(&ctx, key)?;
        let forward = module
            .load_function("arrow_schur_forward_pgroup")
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Upload the padded local blocks, cross-blocks and gradients.
        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
        let d_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let b_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let g_dev = stream
            .clone_htod(&g_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // Kernel outputs. Only `y_out` and `status_dev` are read back below;
        // the others are allocated because the kernel signature requires them.
        let mut l_out = stream
            .alloc_zeros::<f64>(n * p_max * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut u_out = stream
            .alloc_zeros::<f64>(n * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut y_out = stream
            .alloc_zeros::<f64>(n * p_max * r_template)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_s = stream
            .alloc_zeros::<f64>(plan.partial_s_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_r = stream
            .alloc_zeros::<f64>(plan.partial_r_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut status_dev = stream
            .alloc_zeros::<i32>(n)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let cfg = LaunchConfig {
            grid_dim: (plan.blocks, 1, 1),
            block_dim: (plan.threads_per_block, 1, 1),
            shared_mem_bytes: 0,
        };
        // Logical (unpadded) sizes: local block size `d`, shared width `k`.
        let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ridge_arg = ridge_t;
        {
            let mut builder = stream.launch_builder(&forward);
            builder
                .arg(&d_dev)
                .arg(&b_dev)
                .arg(&g_dev)
                .arg(&n_i32)
                .arg(&p_i32)
                .arg(&r_i32)
                .arg(&ridge_arg)
                .arg(&mut l_out)
                .arg(&mut u_out)
                .arg(&mut y_out)
                .arg(&mut partial_s)
                .arg(&mut partial_r)
                .arg(&mut status_dev);
            // SAFETY: all buffers were allocated on `stream` with sizes
            // derived from `plan`; parameter list matches FORWARD_KERNEL_SOURCE.
            unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        }
        stream
            .synchronize()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Any non-zero per-block status aborts; the first failing block wins.
        let status_host = stream
            .clone_dtoh(&status_dev)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        if let Some(row) = status_host.iter().position(|s| *s != 0) {
            // Status code is a 1-based pivot row index, not a magnitude; size
            // the bump from the block (Gershgorin λ_min bound) so a strongly
            // indefinite block recovers in one retry.
            return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
                row,
                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
            });
        }

        // Download Y_i factors: n × p_max × r_template column-major per block.
        let y_host = stream
            .clone_dtoh(&y_out)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Capture H_ββ data for the closure. Use the matrix-free hook if present
        // (SAE-manifold callers), otherwise fall back to the dense matrix rows.
        let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
        let hbb_is_kk = sys.hbb.dim() == (k, k);
        let hbb_matvec_opt = sys.hbb_matvec.clone();

        let closure: crate::arrow_schur::GpuSchurMatvec =
            Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
                assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
                assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");

                // (H_ββ + ρ·I)·x into out.
                if let Some(ref mv) = hbb_matvec_opt {
                    mv(x.view(), out);
                    for a in 0..k {
                        out[a] += ridge_beta * x[a];
                    }
                } else if hbb_is_kk {
                    // hbb_host row-major: hbb[a, b] = hbb_host[a * k + b].
                    for a in 0..k {
                        let mut acc = ridge_beta * x[a];
                        for b in 0..k {
                            acc += hbb_host[a * k + b] * x[b];
                        }
                        out[a] = acc;
                    }
                } else {
                    // No usable H_ββ: only the ridge term survives.
                    for a in 0..k {
                        out[a] = ridge_beta * x[a];
                    }
                }

                // out[c] -= Σ_i (Y_i^T (Y_i·x))[c].
                // Y_i column-major at y_host[i·p_max·r_template + col·p_max + row].
                let mut z = vec![0.0_f64; d];
                for i in 0..n {
                    let y_base = i * p_max * r_template;
                    // z = Y_i·x over the unpadded d × k sub-block.
                    for r in 0..d {
                        let mut acc = 0.0;
                        for c in 0..k {
                            acc += y_host[y_base + c * p_max + r] * x[c];
                        }
                        z[r] = acc;
                    }
                    // out -= Y_i^T·z.
                    for c in 0..k {
                        let mut acc = 0.0;
                        for r in 0..d {
                            acc += y_host[y_base + c * p_max + r] * z[r];
                        }
                        out[c] -= acc;
                    }
                }
            });

        Ok(closure)
    }
3820
3821    // ── #1017/#1026 frames-engaged device PCG ──────────────────────────────
3822
    /// Device-resident, flattened form of the framed SAE reduced-Schur
    /// operator, as produced by `flatten_device_sae_frame_data`.
    ///
    /// Variable-size per-block payloads are concatenated into one `*_data`
    /// buffer and addressed through a `*_ptr` prefix-offset array of length
    /// `blocks + 1`. All dense payloads are stored row-major.
    struct DeviceSaeFrameBuffers {
        // Smooth `λ S_k ⊗ I_{r_k}`.
        /// Global β offset of each smooth block.
        s_off: CudaSlice<i32>,
        /// Side length `m` of each square smooth factor.
        s_m: CudaSlice<i32>,
        /// Rank `r` each smooth factor is replicated over.
        s_r: CudaSlice<i32>,
        /// Prefix offsets into `s_data` (length `s_blocks + 1`).
        s_ptr: CudaSlice<i32>,
        /// Concatenated row-major `m × m` smooth factors.
        s_data: CudaSlice<f64>,
        /// Number of smooth blocks.
        s_blocks: usize,
        // Data `G_{ij} ⊗ W_{ij}`.
        /// β offset of atom `i` / atom `j` of each data block.
        g_off_i: CudaSlice<i32>,
        g_off_j: CudaSlice<i32>,
        /// Ranks of atom `i` / atom `j` (the `w` block is `ri × rj`).
        g_ri: CudaSlice<i32>,
        g_rj: CudaSlice<i32>,
        /// Row / column count of each `g` block.
        g_mi: CudaSlice<i32>,
        g_mj: CudaSlice<i32>,
        /// Prefix offsets into `g_data` (length `g_blocks + 1`).
        g_ptr: CudaSlice<i32>,
        /// Concatenated row-major `mi × mj` `g` blocks.
        g_data: CudaSlice<f64>,
        /// Prefix offsets into `w_data` (length `g_blocks + 1`).
        w_ptr: CudaSlice<i32>,
        /// Concatenated row-major `ri × rj` `w` blocks.
        w_data: CudaSlice<f64>,
        /// Number of data blocks.
        g_blocks: usize,
        /// Largest `mi · ri` over data blocks; sizes the data-kernel grid.
        g_max_work: usize,
        // Per-row dense cross-block H_tβ^(i) + row q + factored ainv.
        /// Prefix offsets into `htb` (length `n_rows + 1`).
        htb_ptr: CudaSlice<i32>,
        /// Concatenated row-major `q_i × k` cross-blocks (rows with `q_i = 0`
        /// contribute nothing).
        htb: CudaSlice<f64>,
        /// Effective local width `q_i` of each row (0 = no reduced term).
        q_of: CudaSlice<i32>,
        /// Per-row dense inverse of `H_tt^(i) + ridge_t·I`, padded to
        /// `max_q × max_q` with row stride `max_q`.
        ainv: CudaSlice<f64>,
        /// Scratch `n_rows × max_q`, written by `arrow_sae_frame_apply_h`.
        hvec: CudaSlice<f64>,
        /// Scratch `n_rows × max_q`, written by `arrow_sae_frame_apply_ainv`.
        svec: CudaSlice<f64>,
        /// Number of arrow rows.
        n_rows: usize,
        /// β dimension.
        k: usize,
        /// Largest `q_i` (at least 1 so `ainv` is never empty).
        max_q: usize,
    }
3855
3856    fn flatten_device_sae_frame_data(
3857        sys: &ArrowSchurSystem,
3858        data: &DeviceSaePcgData,
3859        frame: &DeviceSaeFrameData,
3860        ridge_t: f64,
3861        stream: &Arc<CudaStream>,
3862    ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3863        let n_rows = sys.rows.len();
3864        let k = data.beta_dim;
3865        if frame.row_htbeta.len() != n_rows
3866            || frame.ranks.len() != frame.basis_sizes.len()
3867            || frame.border_offsets.len() != frame.ranks.len()
3868            || data.smooth_blocks.len() != frame.smooth_ranks.len()
3869        {
3870            return Err(ArrowSchurGpuFailure::Unavailable);
3871        }
3872
3873        // Smooth blocks.
3874        let mut s_off = Vec::new();
3875        let mut s_m = Vec::new();
3876        let mut s_r = Vec::new();
3877        let mut s_ptr = vec![0_i32];
3878        let mut s_data = Vec::<f64>::new();
3879        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3880            let (m, mc) = blk.factor_a.dim();
3881            if m != mc {
3882                return Err(ArrowSchurGpuFailure::Unavailable);
3883            }
3884            s_off.push(checked_i32(blk.global_offset)?);
3885            s_m.push(checked_i32(m)?);
3886            s_r.push(checked_i32(r)?);
3887            for ri in 0..m {
3888                for ci in 0..m {
3889                    s_data.push(blk.factor_a[[ri, ci]]);
3890                }
3891            }
3892            s_ptr.push(checked_i32(s_data.len())?);
3893        }
3894
3895        // Data blocks (g + w).
3896        let mut g_off_i = Vec::new();
3897        let mut g_off_j = Vec::new();
3898        let mut g_ri = Vec::new();
3899        let mut g_rj = Vec::new();
3900        let mut g_mi = Vec::new();
3901        let mut g_mj = Vec::new();
3902        let mut g_ptr = vec![0_i32];
3903        let mut g_data = Vec::<f64>::new();
3904        let mut w_ptr = vec![0_i32];
3905        let mut w_data = Vec::<f64>::new();
3906        let mut g_max_work = 0usize;
3907        for blk in &frame.frame_blocks {
3908            let ri = frame.ranks[blk.atom_i];
3909            let rj = frame.ranks[blk.atom_j];
3910            let (mi, mj) = blk.g.dim();
3911            if blk.w.dim() != (ri, rj) {
3912                return Err(ArrowSchurGpuFailure::Unavailable);
3913            }
3914            g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3915            g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3916            g_ri.push(checked_i32(ri)?);
3917            g_rj.push(checked_i32(rj)?);
3918            g_mi.push(checked_i32(mi)?);
3919            g_mj.push(checked_i32(mj)?);
3920            for r in 0..mi {
3921                for c in 0..mj {
3922                    g_data.push(blk.g[[r, c]]);
3923                }
3924            }
3925            g_ptr.push(checked_i32(g_data.len())?);
3926            for a in 0..ri {
3927                for b in 0..rj {
3928                    w_data.push(blk.w[[a, b]]);
3929                }
3930            }
3931            w_ptr.push(checked_i32(w_data.len())?);
3932            g_max_work = g_max_work.max(mi * ri);
3933        }
3934
3935        // Per-row dense cross-block + q + ainv (factored H_tt⁻¹).
3936        let mut htb_ptr = vec![0_i32];
3937        let mut htb = Vec::<f64>::new();
3938        let mut q_of = Vec::<i32>::with_capacity(n_rows);
3939        let mut max_q = 0usize;
3940        for (i, slab) in frame.row_htbeta.iter().enumerate() {
3941            let qi = sys.row_dims[i];
3942            // A populated slab must be q_i × k row-major; an empty slab ⇒ q=0
3943            // (the row contributes no reduced-Schur term).
3944            let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3945                qi
3946            } else {
3947                0
3948            };
3949            q_of.push(checked_i32(q_eff)?);
3950            max_q = max_q.max(q_eff);
3951            if q_eff > 0 {
3952                htb.extend_from_slice(slab);
3953            }
3954            htb_ptr.push(checked_i32(htb.len())?);
3955        }
3956        if max_q == 0 {
3957            // No row contributes a reduced term — the system is pure-penalty.
3958            // Still valid; give max_q=1 so the ainv buffer is non-empty.
3959            max_q = 1;
3960        }
3961
3962        let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3963        for (i, row) in sys.rows.iter().enumerate() {
3964            let q = q_of[i] as usize;
3965            if q == 0 {
3966                continue;
3967            }
3968            if row.htt.dim() != (q, q) {
3969                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3970                    reason: format!(
3971                        "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3972                        row.htt.dim()
3973                    ),
3974                });
3975            }
3976            let mut block = row.htt.clone();
3977            for d in 0..q {
3978                block[[d, d]] += ridge_t;
3979            }
3980            let factor = gam_linalg::triangular::cholesky_factor_in_place(
3981                block.view(),
3982                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3983            )
3984            .ok_or_else(|| {
3985                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
3986                // indefinite per-row block recovers in one outer-loop retry.
3987                ArrowSchurGpuFailure::RidgeBumpRequired {
3988                    row: i,
3989                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3990                }
3991            })?;
3992            for col in 0..q {
3993                let mut e = Array1::<f64>::zeros(q);
3994                e[col] = 1.0;
3995                let solved =
3996                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3997                for r in 0..q {
3998                    ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3999                }
4000            }
4001        }
4002
4003        let htod_i = |v: &[i32]| {
4004            stream
4005                .clone_htod(v)
4006                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4007        };
4008        let htod_f = |v: &[f64]| {
4009            stream
4010                .clone_htod(v)
4011                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4012        };
4013        Ok(DeviceSaeFrameBuffers {
4014            s_off: htod_i(&s_off)?,
4015            s_m: htod_i(&s_m)?,
4016            s_r: htod_i(&s_r)?,
4017            s_ptr: htod_i(&s_ptr)?,
4018            s_data: htod_f(&s_data)?,
4019            s_blocks: data.smooth_blocks.len(),
4020            g_off_i: htod_i(&g_off_i)?,
4021            g_off_j: htod_i(&g_off_j)?,
4022            g_ri: htod_i(&g_ri)?,
4023            g_rj: htod_i(&g_rj)?,
4024            g_mi: htod_i(&g_mi)?,
4025            g_mj: htod_i(&g_mj)?,
4026            g_ptr: htod_i(&g_ptr)?,
4027            g_data: htod_f(&g_data)?,
4028            w_ptr: htod_i(&w_ptr)?,
4029            w_data: htod_f(&w_data)?,
4030            g_blocks: frame.frame_blocks.len(),
4031            g_max_work,
4032            htb_ptr: htod_i(&htb_ptr)?,
4033            htb: htod_f(&htb)?,
4034            q_of: htod_i(&q_of)?,
4035            ainv: htod_f(&ainv)?,
4036            hvec: stream
4037                .alloc_zeros::<f64>(n_rows * max_q)
4038                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4039            svec: stream
4040                .alloc_zeros::<f64>(n_rows * max_q)
4041                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4042            n_rows,
4043            k,
4044            max_q,
4045        })
4046    }
4047
4048    fn sae_frame_penalty_diag_host(
4049        data: &DeviceSaePcgData,
4050        frame: &DeviceSaeFrameData,
4051        ridge_beta: f64,
4052    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4053        let mut diag = vec![ridge_beta; data.beta_dim];
4054        // Smooth: diag[off + ia·r + ib] += S[ia,ia].
4055        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4056            let m = blk.factor_a.nrows();
4057            for ia in 0..m {
4058                let coeff = blk.factor_a[[ia, ia]];
4059                let base = blk.global_offset + ia * r;
4060                for ib in 0..r {
4061                    if base + ib >= diag.len() {
4062                        return Err(ArrowSchurGpuFailure::Unavailable);
4063                    }
4064                    diag[base + ib] += coeff;
4065                }
4066            }
4067        }
4068        // Data: on-diagonal atom blocks contribute g[li,li]·w[a,a].
4069        for blk in &frame.frame_blocks {
4070            if blk.atom_i != blk.atom_j {
4071                continue;
4072            }
4073            let r = frame.ranks[blk.atom_i];
4074            let off = frame.border_offsets[blk.atom_i];
4075            let (mi, mj) = blk.g.dim();
4076            for li in 0..mi.min(mj) {
4077                let gii = blk.g[[li, li]];
4078                let base = off + li * r;
4079                for a in 0..r {
4080                    if base + a >= diag.len() {
4081                        return Err(ArrowSchurGpuFailure::Unavailable);
4082                    }
4083                    diag[base + a] += gii * blk.w[[a, a]];
4084                }
4085            }
4086        }
4087        Ok(diag)
4088    }
4089
4090    fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4091        Ok(LaunchConfig {
4092            grid_dim: (
4093                ((work as u32).saturating_add(255) / 256).max(1),
4094                checked_i32(n_rows)? as u32,
4095                1,
4096            ),
4097            block_dim: (256, 1, 1),
4098            shared_mem_bytes: 0,
4099        })
4100    }
4101
4102    fn launch_sae_frame_matvec(
4103        stream: &Arc<CudaStream>,
4104        module: &Arc<CudaModule>,
4105        buffers: &mut DeviceSaeFrameBuffers,
4106        x: &CudaSlice<f64>,
4107        out: &mut CudaSlice<f64>,
4108        ridge_beta: f64,
4109    ) -> Result<(), ArrowSchurGpuFailure> {
4110        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4111        // Smooth penalty.
4112        if buffers.s_blocks > 0 {
4113            let kernel = module
4114                .load_function("arrow_sae_frame_smooth_matvec")
4115                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4116            let blocks_i32 = checked_i32(buffers.s_blocks)?;
4117            let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4118            let mut b = stream.launch_builder(&kernel);
4119            b.arg(x)
4120                .arg(&mut *out)
4121                .arg(&buffers.s_off)
4122                .arg(&buffers.s_m)
4123                .arg(&buffers.s_r)
4124                .arg(&buffers.s_ptr)
4125                .arg(&buffers.s_data)
4126                .arg(&blocks_i32);
4127            // SAFETY: smooth block metadata/data are live device buffers; the grid
4128            // covers (k channels × n_blocks) and the kernel bounds-checks m·r.
4129            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4130        }
4131        // Data penalty.
4132        if buffers.g_blocks > 0 {
4133            let kernel = module
4134                .load_function("arrow_sae_frame_g_matvec")
4135                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4136            let blocks_i32 = checked_i32(buffers.g_blocks)?;
4137            let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4138            let mut b = stream.launch_builder(&kernel);
4139            b.arg(x)
4140                .arg(&mut *out)
4141                .arg(&buffers.g_off_i)
4142                .arg(&buffers.g_off_j)
4143                .arg(&buffers.g_ri)
4144                .arg(&buffers.g_rj)
4145                .arg(&buffers.g_mi)
4146                .arg(&buffers.g_mj)
4147                .arg(&buffers.g_ptr)
4148                .arg(&buffers.g_data)
4149                .arg(&buffers.w_ptr)
4150                .arg(&buffers.w_data)
4151                .arg(&blocks_i32);
4152            // SAFETY: g/w block metadata/data are live device buffers; the grid
4153            // covers (max m_i·r_i × n_blocks) and the kernel bounds-checks.
4154            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4155        }
4156        // Reduced-Schur subtraction via dense per-row cross-blocks.
4157        let k_i32 = checked_i32(buffers.k)?;
4158        let max_q_i32 = checked_i32(buffers.max_q)?;
4159        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4160        {
4161            let kernel = module
4162                .load_function("arrow_sae_frame_apply_h")
4163                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4164            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4165            let mut b = stream.launch_builder(&kernel);
4166            b.arg(x)
4167                .arg(&buffers.htb_ptr)
4168                .arg(&buffers.htb)
4169                .arg(&buffers.q_of)
4170                .arg(&mut buffers.hvec)
4171                .arg(&k_i32)
4172                .arg(&max_q_i32)
4173                .arg(&n_rows_i32);
4174            // SAFETY: dense cross-block + pointers + hvec are live buffers sized
4175            // for (n_rows × max_q) / (n_rows × k); kernel guards q_i and k.
4176            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4177        }
4178        {
4179            let kernel = module
4180                .load_function("arrow_sae_frame_apply_ainv")
4181                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4182            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4183            let mut b = stream.launch_builder(&kernel);
4184            b.arg(&buffers.ainv)
4185                .arg(&buffers.hvec)
4186                .arg(&buffers.q_of)
4187                .arg(&mut buffers.svec)
4188                .arg(&max_q_i32)
4189                .arg(&n_rows_i32);
4190            // SAFETY: ainv/hvec/svec are live buffers sized for n_rows·max_q²
4191            // and n_rows·max_q; the kernel guards row/coord bounds.
4192            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4193        }
4194        {
4195            let kernel = module
4196                .load_function("arrow_sae_frame_scatter_h")
4197                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4198            let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4199            let mut b = stream.launch_builder(&kernel);
4200            b.arg(&buffers.svec)
4201                .arg(&buffers.htb_ptr)
4202                .arg(&buffers.htb)
4203                .arg(&buffers.q_of)
4204                .arg(out)
4205                .arg(&k_i32)
4206                .arg(&max_q_i32)
4207                .arg(&n_rows_i32);
4208            // SAFETY: svec/cross-block/out are live buffers; the kernel atomically
4209            // accumulates into out[a] for a<k and reads c<q_i.
4210            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4211        }
4212        Ok(())
4213    }
4214
4215    fn launch_sae_frame_diag_sub(
4216        stream: &Arc<CudaStream>,
4217        module: &Arc<CudaModule>,
4218        buffers: &DeviceSaeFrameBuffers,
4219        diag: &mut CudaSlice<f64>,
4220    ) -> Result<(), ArrowSchurGpuFailure> {
4221        let kernel = module
4222            .load_function("arrow_sae_frame_diag_sub")
4223            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4224        let k_i32 = checked_i32(buffers.k)?;
4225        let max_q_i32 = checked_i32(buffers.max_q)?;
4226        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4227        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4228        let mut b = stream.launch_builder(&kernel);
4229        b.arg(diag)
4230            .arg(&buffers.ainv)
4231            .arg(&buffers.htb_ptr)
4232            .arg(&buffers.htb)
4233            .arg(&buffers.q_of)
4234            .arg(&k_i32)
4235            .arg(&max_q_i32)
4236            .arg(&n_rows_i32);
4237        // SAFETY: diag + cross-block + ainv live buffers; kernel guards a<k, c/d<q.
4238        unsafe { b.launch(cfg) }
4239            .map(drop)
4240            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4241    }
4242
4243    /// #1551 kernel-isolating seam: evaluate the framed reduced-Schur matvec
4244    /// `out = S·x` EXACTLY ONCE on the device (no PCG, no offload-floor gate) and
4245    /// return `out`. This is the parity probe the test harness diffs against the
4246    /// CPU oracle [`super::sae_framed_schur_matvec_cpu`] element-by-element, so a
4247    /// kernel/marshalling defect is exposed directly — independent of how the
4248    /// iterative solver behaves on an ill-conditioned assembled `S` (where dense
4249    /// Cholesky and PCG legitimately disagree at the solution level). Declines
4250    /// (`Unavailable`) only when CUDA is genuinely absent so the test skips
4251    /// cleanly off-device; it deliberately does NOT consult the offload policy so
4252    /// even a tiny verifiable fixture runs on the GPU.
4253    pub(super) fn framed_schur_matvec_once_on_device(
4254        sys: &ArrowSchurSystem,
4255        data: &DeviceSaePcgData,
4256        ridge_t: f64,
4257        ridge_beta: f64,
4258        x: &Array1<f64>,
4259    ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4260        let k = x.len();
4261        if k == 0 || data.beta_dim != k || sys.k != k {
4262            return Err(ArrowSchurGpuFailure::Unavailable);
4263        }
4264        let frame = data
4265            .frame
4266            .as_ref()
4267            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4268        // No offload-policy filter here: the seam exists to validate the kernel on
4269        // ANY device, including the smallest hand-checkable fixture.
4270        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4271            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4272        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4273            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4274        let stream = ctx
4275            .new_stream()
4276            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4277        let vector_module = pcg_vector_module(&ctx)?;
4278        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4279        let x_dev = stream
4280            .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4281            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4282        let mut out_dev = stream
4283            .alloc_zeros::<f64>(k)
4284            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4285        launch_sae_frame_matvec(
4286            &stream,
4287            vector_module,
4288            &mut buffers,
4289            &x_dev,
4290            &mut out_dev,
4291            ridge_beta,
4292        )?;
4293        let out = stream
4294            .clone_dtoh(&out_dev)
4295            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4296        Ok(Array1::from_vec(out))
4297    }
4298
4299    pub(super) fn solve_sae_matrix_free_pcg_framed(
4300        sys: &ArrowSchurSystem,
4301        data: &DeviceSaePcgData,
4302        ridge_t: f64,
4303        ridge_beta: f64,
4304        rhs_beta: &Array1<f64>,
4305        max_iterations: usize,
4306        relative_tolerance: f64,
4307    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4308        let k = rhs_beta.len();
4309        if k == 0 || data.beta_dim != k || sys.k != k {
4310            return Err(ArrowSchurGpuFailure::Unavailable);
4311        }
4312        let frame = data
4313            .frame
4314            .as_ref()
4315            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4316        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4317            .filter(|rt| {
4318                rt.policy().reduced_schur_matvec_should_offload(
4319                    sys.rows.len(),
4320                    sys.k,
4321                    sys.d,
4322                    max_iterations,
4323                )
4324            })
4325            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4326        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4327            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4328        let stream = ctx
4329            .new_stream()
4330            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4331        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4332        let vector_module = pcg_vector_module(&ctx)?;
4333        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4334
4335        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4336        if rhs_norm == 0.0 {
4337            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4338        }
4339        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4340        let rhs_dev = stream
4341            .clone_htod(
4342                rhs_beta
4343                    .as_slice()
4344                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4345            )
4346            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4347        let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4348        let mut diag_dev = stream
4349            .clone_htod(&diag_host)
4350            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4351        launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4352        let diag_host = stream
4353            .clone_dtoh(&diag_dev)
4354            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4355        let mut inv_diag = Vec::with_capacity(k);
4356        for (idx, &d) in diag_host.iter().enumerate() {
4357            if !d.is_finite() || d <= 1.0e-18 {
4358                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4359                    reason: format!(
4360                        "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4361                    ),
4362                });
4363            }
4364            inv_diag.push(1.0 / d);
4365        }
4366        let inv_diag_dev = stream
4367            .clone_htod(&inv_diag)
4368            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4369
4370        let mut x_dev = stream
4371            .alloc_zeros::<f64>(k)
4372            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4373        let mut r_dev = stream
4374            .alloc_zeros::<f64>(k)
4375            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4376        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4377        let mut z_dev = stream
4378            .alloc_zeros::<f64>(k)
4379            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4380        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4381        let mut p_dev = stream
4382            .alloc_zeros::<f64>(k)
4383            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4384        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4385        let mut ap_dev = stream
4386            .alloc_zeros::<f64>(k)
4387            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4388
4389        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4390        if rz <= 0.0 || !rz.is_finite() {
4391            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4392                reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4393            });
4394        }
4395        let mut diag = PcgDiagnostics {
4396            precond_apply_calls: 1,
4397            stopping_reason: PcgStopReason::MaxIter,
4398            ..PcgDiagnostics::default()
4399        };
4400        for _ in 0..max_iterations.max(1) {
4401            launch_sae_frame_matvec(
4402                &stream,
4403                vector_module,
4404                &mut buffers,
4405                &p_dev,
4406                &mut ap_dev,
4407                ridge_beta,
4408            )?;
4409            diag.matvec_calls += 1;
4410            diag.iterations += 1;
4411            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4412            if pap <= 0.0 || !pap.is_finite() {
4413                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4414                    reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4415                });
4416            }
4417            let alpha = rz / pap;
4418            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4419            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4420            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4421            if r_norm <= tol {
4422                diag.final_relative_residual = r_norm / rhs_norm;
4423                diag.stopping_reason = PcgStopReason::Converged;
4424                break;
4425            }
4426            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4427            diag.precond_apply_calls += 1;
4428            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4429            if rz_new <= 0.0 || !rz_new.is_finite() {
4430                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4431                    reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4432                });
4433            }
4434            let beta = rz_new / rz;
4435            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4436            rz = rz_new;
4437        }
4438        if diag.stopping_reason != PcgStopReason::Converged {
4439            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4440            diag.final_relative_residual = r_norm / rhs_norm;
4441            diag.stopping_reason = PcgStopReason::MaxIter;
4442        }
4443        let x = stream
4444            .clone_dtoh(&x_dev)
4445            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4446        Ok((Array1::from_vec(x), diag))
4447    }
4448
4449    /// #1551 stage-isolating triage seam: run the framed reduced-Schur matvec
4450    /// `out = S·x` ONCE on the device (no PCG, no offload-floor gate) and return
4451    /// `out`, so a tiny hand-verifiable fixture can diff it against the CPU oracle
4452    /// `sae_framed_schur_matvec_cpu` element-by-element to localize the structural
4453    /// divergence to a single kernel stage. Returns `Unavailable` only when CUDA
4454    /// is genuinely absent (so the test skips cleanly off-device).
4455    pub(super) fn solve_sae_matrix_free_pcg(
4456        sys: &ArrowSchurSystem,
4457        data: &DeviceSaePcgData,
4458        ridge_t: f64,
4459        ridge_beta: f64,
4460        rhs_beta: &Array1<f64>,
4461        max_iterations: usize,
4462        relative_tolerance: f64,
4463    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4464        let k = rhs_beta.len();
4465        if k == 0 || data.beta_dim != k || sys.k != k {
4466            return Err(ArrowSchurGpuFailure::Unavailable);
4467        }
4468        // #1017/#1026 GUARD: the legacy `⊗ I_p` kernel must NEVER receive framed
4469        // data (factored `G ⊗ W_{ij}` + dense per-row cross blocks); decline so a
4470        // mis-route falls back to the CPU rather than returning a wrong step.
4471        if data.frame.is_some() {
4472            return Err(ArrowSchurGpuFailure::Unavailable);
4473        }
4474        // #1017 Phase-1 dispatch re-key: this is the matrix-free SAE reduced-Schur
4475        // PCG — the production hot path, not a single dense factorization. The
4476        // dense-Direct floor `dense_hessian_work_target_is_gpu(n, k)` keys on
4477        // `2·n·k²` and is the WRONG gate here: it ignores the per-row frame depth
4478        // `d` (the M dimension that multiplies the per-apply work) and the
4479        // `1/cg_iters` staging amortisation, so it both undercounts the SAE batched
4480        // work `n·k·d` and applies a cold single-launch breakeven to an apply that
4481        // reuses device-resident frames `max_iterations` times. Key instead on the
4482        // CG-amortised total batched work — the same predicate the host injection
4483        // gate (`maybe_inject_gpu_schur_matvec`) consults — so few-row/wide-`k`/
4484        // modest-`d` LLM shapes register the real `n × k × d × cg_iters` arithmetic.
4485        // Kernels and numerics are untouched; only where the matvec runs changes,
4486        // and the host falls back to the bit-identical CPU matvec when this declines.
4487        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4488            .filter(|rt| {
4489                rt.policy().reduced_schur_matvec_should_offload(
4490                    sys.rows.len(),
4491                    sys.k,
4492                    sys.d,
4493                    max_iterations,
4494                )
4495            })
4496            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4497        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4498            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4499        let stream = ctx
4500            .new_stream()
4501            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4502        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4503        let vector_module = pcg_vector_module(&ctx)?;
4504        let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4505
4506        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4507        if rhs_norm == 0.0 {
4508            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4509        }
4510        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4511        let rhs_dev = stream
4512            .clone_htod(
4513                rhs_beta
4514                    .as_slice()
4515                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4516            )
4517            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4518        let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4519        let mut diag_dev = stream
4520            .clone_htod(&diag_host)
4521            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4522        launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4523        let diag_host = stream
4524            .clone_dtoh(&diag_dev)
4525            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4526        let mut inv_diag = Vec::with_capacity(k);
4527        for (idx, &d) in diag_host.iter().enumerate() {
4528            if !d.is_finite() || d <= 1.0e-18 {
4529                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4530                    reason: format!(
4531                        "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4532                    ),
4533                });
4534            }
4535            inv_diag.push(1.0 / d);
4536        }
4537        let inv_diag_dev = stream
4538            .clone_htod(&inv_diag)
4539            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4540
4541        let mut x_dev = stream
4542            .alloc_zeros::<f64>(k)
4543            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544        let mut r_dev = stream
4545            .alloc_zeros::<f64>(k)
4546            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4547        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4548        let mut z_dev = stream
4549            .alloc_zeros::<f64>(k)
4550            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4551        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4552        let mut p_dev = stream
4553            .alloc_zeros::<f64>(k)
4554            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4555        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4556        let mut ap_dev = stream
4557            .alloc_zeros::<f64>(k)
4558            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4559
4560        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4561        if rz <= 0.0 || !rz.is_finite() {
4562            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4563                reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4564            });
4565        }
4566        let mut diag = PcgDiagnostics {
4567            precond_apply_calls: 1,
4568            stopping_reason: PcgStopReason::MaxIter,
4569            ..PcgDiagnostics::default()
4570        };
4571
4572        for _ in 0..max_iterations.max(1) {
4573            launch_sae_matvec(
4574                &stream,
4575                vector_module,
4576                &mut buffers,
4577                &p_dev,
4578                &mut ap_dev,
4579                ridge_beta,
4580            )?;
4581            diag.matvec_calls += 1;
4582            diag.iterations += 1;
4583            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4584            if pap <= 0.0 || !pap.is_finite() {
4585                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4586                    reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4587                });
4588            }
4589            let alpha = rz / pap;
4590            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4591            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4592            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4593            if r_norm <= tol {
4594                diag.final_relative_residual = r_norm / rhs_norm;
4595                diag.stopping_reason = PcgStopReason::Converged;
4596                break;
4597            }
4598            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4599            diag.precond_apply_calls += 1;
4600            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4601            if rz_new <= 0.0 || !rz_new.is_finite() {
4602                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4603                    reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4604                });
4605            }
4606            let beta = rz_new / rz;
4607            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4608            rz = rz_new;
4609        }
4610        if diag.stopping_reason != PcgStopReason::Converged {
4611            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4612            diag.final_relative_residual = r_norm / rhs_norm;
4613            diag.stopping_reason = PcgStopReason::MaxIter;
4614        }
4615        let x = stream
4616            .clone_dtoh(&x_dev)
4617            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4618        Ok((Array1::from_vec(x), diag))
4619    }
4620
4621    pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4622        s_acc: &ndarray::Array2<f64>,
4623        rhs_beta: &Array1<f64>,
4624        max_iterations: usize,
4625        relative_tolerance: f64,
4626    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4627        let k = rhs_beta.len();
4628        // #1017 dispatch re-key: this is an ITERATIVE device-resident PCG, not a
4629        // single GEMV. `S` (k×k) is uploaded once and reused for `max_iterations`
4630        // `S·p` GEMVs while only convergence scalars cross PCIe, so the staging
4631        // cost is amortised over the whole CG solve. Gating on the flops of ONE
4632        // `Gemv{k,k}` (`2·k²`) understates the work by the iteration count and
4633        // declines shapes (e.g. k≈512) whose total iterated arithmetic
4634        // `2·k²·iters` clears the device floor by orders of magnitude — the same
4635        // single-launch-breakeven miskey #1017 fixed for the framed reduced-Schur
4636        // matvec. Key on the CG-amortised total work via a `Gemm{k,k,iters}` whose
4637        // `flops()` is exactly `2·k²·iters`; numerics and kernels are untouched,
4638        // and the host falls back to the bit-identical CPU PCG when this declines.
4639        let cg_iters = max_iterations.max(1);
4640        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4641            gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4642                m: k,
4643                n: k,
4644                k: cg_iters,
4645            },
4646        )
4647        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4648        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4649            .and_then(|ctx| ctx.new_stream().ok())
4650            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4651        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4652        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4653            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4654        let vector_module = pcg_vector_module(&ctx)?;
4655
4656        // Jacobi diagonal from S; must be strictly positive for SPD.
4657        let mut inv_diag = vec![0.0_f64; k];
4658        for j in 0..k {
4659            let djj = s_acc[[j, j]];
4660            if !(djj > 0.0) {
4661                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4662                    reason: format!(
4663                        "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4664                    ),
4665                });
4666            }
4667            inv_diag[j] = 1.0 / djj;
4668        }
4669
4670        // Upload S column-major (S[row,col] at col*k + row).
4671        let mut s_host = vec![0.0_f64; k * k];
4672        for col in 0..k {
4673            for row in 0..k {
4674                s_host[col * k + row] = s_acc[[row, col]];
4675            }
4676        }
4677        let s_dev = stream
4678            .clone_htod(&s_host)
4679            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4680
4681        // Steihaug truncated-CG with Jacobi preconditioner, host scalar
4682        // recurrences and a device `S·p` matvec. The streaming reduced solve
4683        // uses an unbounded trust region (pure CG to tolerance).
4684        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4685        if rhs_norm == 0.0 {
4686            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4687        }
4688        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4689
4690        // Device-resident PCG state. Only convergence scalars cross back during
4691        // the loop; x/r/z/p/Sp stay on CUDA until the final solution download.
4692        let mut x_dev = stream
4693            .alloc_zeros::<f64>(k)
4694            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4695        let mut r_dev = stream
4696            .clone_htod(
4697                rhs_beta
4698                    .as_slice()
4699                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4700            )
4701            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702        let inv_diag_dev = stream
4703            .clone_htod(&inv_diag)
4704            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4705        let mut z_dev = stream
4706            .alloc_zeros::<f64>(k)
4707            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4708        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4709        let mut p_dev = stream
4710            .alloc_zeros::<f64>(k)
4711            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4712        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4713        let mut sp_dev = stream
4714            .alloc_zeros::<f64>(k)
4715            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4716        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4717        let mut diag = PcgDiagnostics {
4718            precond_apply_calls: 1,
4719            stopping_reason: PcgStopReason::MaxIter,
4720            ..PcgDiagnostics::default()
4721        };
4722        if rz <= 0.0 || !rz.is_finite() {
4723            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4724                reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4725            });
4726        }
4727
4728        let max_iters = max_iterations.max(1);
4729        for _ in 0..max_iters {
4730            // sp = S · p (device GEMV, S column-major k×k, op = N).
4731            let gemv_cfg = GemvConfig::<f64> {
4732                trans: cublasOperation_t::CUBLAS_OP_N,
4733                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4734                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4735                alpha: 1.0,
4736                lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4737                incx: 1,
4738                beta: 0.0,
4739                incy: 1,
4740            };
4741            // SAFETY: s_dev is k×k column-major, p_dev / sp_dev length k.
4742            unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4743                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4744            diag.matvec_calls += 1;
4745            diag.iterations += 1;
4746
4747            let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4748            if !(p_sp > 0.0) {
4749                // Non-positive curvature on a (proximal-ridged) SPD system means
4750                // numerical breakdown; surface so the caller escalates.
4751                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4752                    reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4753                });
4754            }
4755            let alpha = rz / p_sp;
4756            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4757            device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4758            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4759            if r_norm <= tol {
4760                diag.final_relative_residual = r_norm / rhs_norm;
4761                diag.stopping_reason = PcgStopReason::Converged;
4762                break;
4763            }
4764            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4765            diag.precond_apply_calls += 1;
4766            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4767            if rz_new <= 0.0 || !rz_new.is_finite() {
4768                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4769                    reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4770                });
4771            }
4772            let beta = rz_new / rz;
4773            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4774            rz = rz_new;
4775        }
4776        if diag.stopping_reason != PcgStopReason::Converged {
4777            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4778            diag.final_relative_residual = r_norm / rhs_norm;
4779            diag.stopping_reason = PcgStopReason::MaxIter;
4780        }
4781
4782        let x = stream
4783            .clone_dtoh(&x_dev)
4784            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4785        Ok((Array1::from_vec(x), diag))
4786    }
4787
4788    fn device_copy(
4789        blas: &CudaBlas,
4790        stream: &Arc<CudaStream>,
4791        n: usize,
4792        src: &CudaSlice<f64>,
4793        dst: &mut CudaSlice<f64>,
4794    ) -> Result<(), ArrowSchurGpuFailure> {
4795        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4796        let (src_ptr, _src_rec) = src.device_ptr(stream);
4797        let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4798        // SAFETY: src and dst are live device allocations on this stream with at
4799        // least n contiguous f64 entries and unit stride.
4800        let status = unsafe {
4801            cudarc::cublas::sys::cublasDcopy_v2(
4802                *blas.handle(),
4803                n_i,
4804                src_ptr as *const f64,
4805                1,
4806                dst_ptr as *mut f64,
4807                1,
4808            )
4809        };
4810        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4811            Ok(())
4812        } else {
4813            Err(ArrowSchurGpuFailure::Unavailable)
4814        }
4815    }
4816
4817    fn device_axpy(
4818        blas: &CudaBlas,
4819        stream: &Arc<CudaStream>,
4820        n: usize,
4821        alpha: f64,
4822        x: &CudaSlice<f64>,
4823        y: &mut CudaSlice<f64>,
4824    ) -> Result<(), ArrowSchurGpuFailure> {
4825        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4826        let (x_ptr, _x_rec) = x.device_ptr(stream);
4827        let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4828        // SAFETY: x and y are live device allocations on this stream with at
4829        // least n contiguous f64 entries and unit stride; cuBLAS only reads alpha.
4830        let status = unsafe {
4831            cudarc::cublas::sys::cublasDaxpy_v2(
4832                *blas.handle(),
4833                n_i,
4834                &alpha,
4835                x_ptr as *const f64,
4836                1,
4837                y_ptr as *mut f64,
4838                1,
4839            )
4840        };
4841        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4842            Ok(())
4843        } else {
4844            Err(ArrowSchurGpuFailure::Unavailable)
4845        }
4846    }
4847
4848    fn device_dot(
4849        blas: &CudaBlas,
4850        stream: &Arc<CudaStream>,
4851        n: usize,
4852        x: &CudaSlice<f64>,
4853        y: &CudaSlice<f64>,
4854    ) -> Result<f64, ArrowSchurGpuFailure> {
4855        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4856        let (x_ptr, _x_rec) = x.device_ptr(stream);
4857        let (y_ptr, _y_rec) = y.device_ptr(stream);
4858        let mut result = 0.0_f64;
4859        // SAFETY: x and y are live device allocations on this stream with at
4860        // least n contiguous f64 entries and unit stride; result is a valid host
4861        // out-pointer for the cuBLAS scalar.
4862        let status = unsafe {
4863            cudarc::cublas::sys::cublasDdot_v2(
4864                *blas.handle(),
4865                n_i,
4866                x_ptr as *const f64,
4867                1,
4868                y_ptr as *const f64,
4869                1,
4870                &mut result,
4871            )
4872        };
4873        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4874            Ok(result)
4875        } else {
4876            Err(ArrowSchurGpuFailure::Unavailable)
4877        }
4878    }
4879
4880    fn device_nrm2(
4881        blas: &CudaBlas,
4882        stream: &Arc<CudaStream>,
4883        n: usize,
4884        x: &CudaSlice<f64>,
4885    ) -> Result<f64, ArrowSchurGpuFailure> {
4886        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4887        let (x_ptr, _x_rec) = x.device_ptr(stream);
4888        let mut result = 0.0_f64;
4889        // SAFETY: x is a live device allocation on this stream with at least n
4890        // contiguous f64 entries and unit stride; result is a valid host
4891        // out-pointer for the cuBLAS scalar.
4892        let status = unsafe {
4893            cudarc::cublas::sys::cublasDnrm2_v2(
4894                *blas.handle(),
4895                n_i,
4896                x_ptr as *const f64,
4897                1,
4898                &mut result,
4899            )
4900        };
4901        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4902            Ok(result)
4903        } else {
4904            Err(ArrowSchurGpuFailure::Unavailable)
4905        }
4906    }
4907
4908    #[cfg(test)]
4909    mod tests {
4910        //! #1551 device-side framed-matvec triage. Lives inside `mod cuda` so it
4911        //! can call the private kernel launchers directly (no test-only public
4912        //! seam, which the ban-scanner forbids). A bare `#[cfg(test)] mod tests`
4913        //! is the one form the scanner permits.
4914        use super::*;
4915        use crate::arrow_schur::{
4916            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4917            FactoredFrameGBlock,
4918        };
4919        use ndarray::Array2;
4920
4921        /// Run the framed reduced-Schur matvec `out = S·x` ONCE on the device
4922        /// (no PCG, no offload gate) and return `out`.
4923        fn device_matvec_once(
4924            sys: &ArrowSchurSystem,
4925            data: &DeviceSaePcgData,
4926            ridge_t: f64,
4927            ridge_beta: f64,
4928            x_host: &[f64],
4929        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4930            let k = x_host.len();
4931            let frame = data
4932                .frame
4933                .as_ref()
4934                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4935            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4936                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4937            let ctx =
4938                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4939                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4940            let stream = ctx
4941                .new_stream()
4942                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4943            let vector_module = pcg_vector_module(&ctx)?;
4944            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4945            let x_dev = stream
4946                .clone_htod(x_host)
4947                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4948            let mut out_dev = stream
4949                .alloc_zeros::<f64>(k)
4950                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4951            launch_sae_frame_matvec(
4952                &stream,
4953                vector_module,
4954                &mut buffers,
4955                &x_dev,
4956                &mut out_dev,
4957                ridge_beta,
4958            )?;
4959            stream
4960                .clone_dtoh(&out_dev)
4961                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4962        }
4963
4964        /// #1551 stage-isolating matvec triage on a TINY hand-verifiable fixture:
4965        /// diff the device framed matvec `S·e_col` against the CPU oracle
4966        /// `sae_framed_schur_matvec_cpu` for every identity column, reporting the
4967        /// worst-divergent border index so the structural 91% localizes to one
4968        /// kernel stage. Skips cleanly off-device.
4969        #[test]
4970        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4971            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4972                return;
4973            }
4974            let p = 3usize;
4975            let ranks = vec![2usize, 3usize];
4976            let basis_sizes = vec![2usize, 2usize];
4977            let mut border_offsets = Vec::new();
4978            let mut acc = 0usize;
4979            for k in 0..2 {
4980                border_offsets.push(acc);
4981                acc += basis_sizes[k] * ranks[k];
4982            }
4983            let border_dim = acc; // 2·2 + 2·3 = 10
4984            let frame_of = |k: usize| -> Array2<f64> {
4985                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4986                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4987                })
4988            };
4989            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4990            let w_of = |i: usize, j: usize| -> Array2<f64> {
4991                let (ui, uj) = (&frames[i], &frames[j]);
4992                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4993                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4994                })
4995            };
4996            let mut frame_blocks = Vec::new();
4997            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4998                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4999                let mut g =
5000                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
5001                if i == j {
5002                    for r in 0..mi.min(mj) {
5003                        g[[r, r]] += mi as f64 + 2.0;
5004                    }
5005                }
5006                frame_blocks.push(FactoredFrameGBlock {
5007                    atom_i: i,
5008                    atom_j: j,
5009                    g,
5010                    w: w_of(i, j),
5011                });
5012            }
5013            let mut smooth_blocks = Vec::new();
5014            for k in 0..2 {
5015                let m = basis_sizes[k];
5016                let mut s =
5017                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5018                for r in 0..m {
5019                    s[[r, r]] += 1.0;
5020                }
5021                smooth_blocks.push(DeviceSaeSmoothBlock {
5022                    global_offset: border_offsets[k],
5023                    factor_a: s,
5024                });
5025            }
5026            let smooth_ranks = ranks.clone();
5027            let n = 2usize;
5028            let q = 2usize;
5029            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5030            let mut row_htbeta = Vec::new();
5031            for i in 0..n {
5032                let mut htt =
5033                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5034                for r in 0..q {
5035                    htt[[r, r]] += q as f64 + 2.0;
5036                }
5037                sys.rows[i].htt = htt;
5038                let mut slab = vec![0.0_f64; q * border_dim];
5039                for c in 0..q {
5040                    for col in 0..border_dim {
5041                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5042                        slab[c * border_dim + col] = v;
5043                        sys.rows[i].htbeta[[c, col]] = v;
5044                    }
5045                }
5046                row_htbeta.push(slab);
5047            }
5048            let data = DeviceSaePcgData {
5049                p,
5050                beta_dim: border_dim,
5051                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5052                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5053                smooth_blocks,
5054                sparse_g_blocks: Vec::new(),
5055                frame: Some(DeviceSaeFrameData {
5056                    ranks,
5057                    basis_sizes,
5058                    border_offsets,
5059                    frame_blocks,
5060                    smooth_ranks,
5061                    row_htbeta,
5062                }),
5063            };
5064            let ridge_t = 1e-7;
5065            let ridge_beta = 1e-6;
5066            let mut first_bad: Option<usize> = None;
5067            let mut worst = 0.0_f64;
5068            let mut worst_at = 0usize;
5069            let mut worst_dev = 0.0_f64;
5070            let mut worst_cpu = 0.0_f64;
5071            for col in 0..border_dim {
5072                let mut x = vec![0.0_f64; border_dim];
5073                x[col] = 1.0;
5074                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5075                    Ok(v) => v,
5076                    Err(_) => return,
5077                };
5078                let mut cpu = vec![0.0_f64; border_dim];
5079                super::super::sae_framed_schur_matvec_cpu(
5080                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5081                )
5082                .expect("cpu matvec");
5083                for r in 0..border_dim {
5084                    let d = (dev[r] - cpu[r]).abs();
5085                    if d > 1e-9 && first_bad.is_none() {
5086                        first_bad = Some(r * border_dim + col);
5087                    }
5088                    if d > worst {
5089                        worst = d;
5090                        worst_at = r * border_dim + col;
5091                        worst_dev = dev[r];
5092                        worst_cpu = cpu[r];
5093                    }
5094                }
5095            }
5096            assert!(
5097                worst <= 1e-9,
5098                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5099                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5100                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5101                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5102                 G⊗W=cross, reduced-Schur=dense per-row)",
5103            );
5104        }
5105    }
5106}
5107
5108#[cfg(test)]
5109mod tests {
5110    use super::*;
5111    use crate::arrow_schur::ArrowSchurSystem;
5112    use ndarray::{Array2, ArrayView1};
5113
5114    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5115        let mut sys = ArrowSchurSystem::new(n, d, k);
5116        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5117        let mut sample = || -> f64 {
5118            state = state
5119                .wrapping_mul(6364136223846793005)
5120                .wrapping_add(1442695040888963407);
5121            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5122        };
5123        for row in &mut sys.rows {
5124            let mut a = Array2::<f64>::zeros((d, d));
5125            for r in 0..d {
5126                for c in 0..d {
5127                    a[[r, c]] = sample();
5128                }
5129            }
5130            let mut htt = a.t().dot(&a);
5131            for r in 0..d {
5132                htt[[r, r]] += d as f64 + 1.0;
5133            }
5134            row.htt = htt;
5135            for r in 0..d {
5136                for c in 0..k {
5137                    row.htbeta[[r, c]] = 0.1 * sample();
5138                }
5139                row.gt[r] = sample();
5140            }
5141        }
5142        let mut hbb_a = Array2::<f64>::zeros((k, k));
5143        for r in 0..k {
5144            for c in 0..k {
5145                hbb_a[[r, c]] = sample();
5146            }
5147        }
5148        let mut hbb = hbb_a.t().dot(&hbb_a);
5149        for r in 0..k {
5150            hbb[[r, r]] += k as f64 + 1.0;
5151        }
5152        sys.hbb = hbb;
5153        for r in 0..k {
5154            sys.gb[r] = sample();
5155        }
5156        sys
5157    }
5158
5159    /// The Gershgorin ridge bump must actually make a known-indefinite block PD
5160    /// on the first retry — the whole point of #1711. Verified directly by
5161    /// re-factoring `H_tt + (ridge_t + bump)·I` with the same Cholesky guard the
5162    /// device readback uses, for blocks whose `λ_min` is known in closed form.
5163    #[test]
5164    fn ridge_bump_makes_known_indefinite_blocks_pd() {
5165        // `cholesky_factor_in_place` / `CholeskyGuard` are already in scope via
5166        // `super::*` (imported at the top of the module).
5167        // A few blocks with a CLOSED-FORM smallest eigenvalue, all at ridge_t=0.
5168        // (label, matrix, λ_min) — the bump must clear each one.
5169        let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); // λ_min = -1
5170        let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); // λ_min = -250
5171        // Symmetric 2×2 [[1, 2], [2, 1]] has eigenvalues 3 and -1 → indefinite.
5172        let mut indef2 = Array2::<f64>::zeros((2, 2));
5173        indef2[[0, 0]] = 1.0;
5174        indef2[[1, 1]] = 1.0;
5175        indef2[[0, 1]] = 2.0;
5176        indef2[[1, 0]] = 2.0;
5177        // A genuinely PD block must get a bump that is the bare rounding margin
5178        // only (deficit 0), and must still factor — the helper is defensive.
5179        let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5180
5181        for (label, block) in [
5182            ("-I (λ_min=-1)", neg_identity),
5183            ("-250·I (λ_min=-250)", scaled_neg),
5184            ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5185            ("5·I (PD)", pd),
5186        ] {
5187            let ridge_t = 0.0;
5188            let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5189            assert!(
5190                bump > 0.0 && bump.is_finite(),
5191                "[{label}] bump must be strictly positive and finite, got {bump:e}"
5192            );
5193            let d = block.nrows();
5194            let mut shifted = block.clone();
5195            for i in 0..d {
5196                shifted[[i, i]] += ridge_t + bump;
5197            }
5198            assert!(
5199                cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5200                "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5201                 Gershgorin bump, but the Cholesky still rejected it"
5202            );
5203        }
5204    }
5205
5206    /// The column-major variant (multi-GPU tile path) must agree with the
5207    /// row-major helper for a symmetric block, since Gershgorin edges are
5208    /// invariant under reading the symmetric matrix by row vs by column. The
5209    /// colmajor variant takes the bound at ridge_t=0 (the ridge is already baked
5210    /// into the diagonal it reads), so compare against `ridge_bump_to_make_pd`
5211    /// with `ridge_t = 0`.
5212    ///
5213    /// Gated to linux: `ridge_bump_to_make_pd_colmajor` only exists on the
5214    /// linux CUDA tile path, so the parity test runs where the function does.
5215    #[cfg(target_os = "linux")]
5216    #[test]
5217    fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5218        // Symmetric 3×3 with a negative-definite-ish diagonal and off-diagonals.
5219        let mut a = Array2::<f64>::zeros((3, 3));
5220        a[[0, 0]] = -2.0;
5221        a[[1, 1]] = 0.5;
5222        a[[2, 2]] = 1.0;
5223        a[[0, 1]] = 0.3;
5224        a[[1, 0]] = 0.3;
5225        a[[1, 2]] = -0.4;
5226        a[[2, 1]] = -0.4;
5227        a[[0, 2]] = 0.1;
5228        a[[2, 0]] = 0.1;
5229
5230        let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5231
5232        // Flatten column-major: block[c*d + r] = a[[r, c]].
5233        let d = 3;
5234        let mut col_major = vec![0.0_f64; d * d];
5235        for c in 0..d {
5236            for r in 0..d {
5237                col_major[c * d + r] = a[[r, c]];
5238            }
5239        }
5240        let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5241
5242        assert!(
5243            (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5244            "colmajor bump {col_major_bump:e} must match rowmajor bump \
5245             {row_major_bump:e} for a symmetric block"
5246        );
5247
5248        // And the bump must actually make it PD (sanity, same as the row-major test).
5249        let mut shifted = a.clone();
5250        for i in 0..d {
5251            shifted[[i, i]] += col_major_bump;
5252        }
5253        assert!(
5254            cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5255            "colmajor Gershgorin bump must make the symmetric block PD"
5256        );
5257    }
5258
5259    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5260        let mut s = Array2::<f64>::zeros((k, k));
5261        for row in 0..k {
5262            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5263            if row + 1 < k {
5264                s[[row, row + 1]] = -0.05;
5265                s[[row + 1, row]] = -0.05;
5266            }
5267            if row + 7 < k {
5268                s[[row, row + 7]] = 0.01;
5269                s[[row + 7, row]] = 0.01;
5270            }
5271        }
5272        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5273        (s, rhs)
5274    }
5275
5276    fn dense_pcg_cpu_reference(
5277        s: &Array2<f64>,
5278        rhs: &Array1<f64>,
5279        max_iterations: usize,
5280        relative_tolerance: f64,
5281    ) -> Array1<f64> {
5282        let k = rhs.len();
5283        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5284        if rhs_norm == 0.0 {
5285            return Array1::<f64>::zeros(k);
5286        }
5287        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5288        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5289        let mut x = Array1::<f64>::zeros(k);
5290        let mut r = rhs.clone();
5291        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5292        let mut p = z.clone();
5293        let mut sp = Array1::<f64>::zeros(k);
5294        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5295        for _ in 0..max_iterations.max(1) {
5296            for row in 0..k {
5297                let mut acc = 0.0;
5298                for col in 0..k {
5299                    acc += s[[row, col]] * p[col];
5300                }
5301                sp[row] = acc;
5302            }
5303            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5304            let alpha = rz / p_sp;
5305            for idx in 0..k {
5306                x[idx] += alpha * p[idx];
5307                r[idx] -= alpha * sp[idx];
5308            }
5309            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5310            if r_norm <= tol {
5311                break;
5312            }
5313            for idx in 0..k {
5314                z[idx] = inv_diag[idx] * r[idx];
5315            }
5316            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5317            let beta = rz_next / rz;
5318            for idx in 0..k {
5319                p[idx] = z[idx] + beta * p[idx];
5320            }
5321            rz = rz_next;
5322        }
5323        x
5324    }
5325
5326    #[test]
5327    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5328        let (s, rhs) = device_pcg_fixture(512);
5329        let max_iterations = 200usize;
5330        let relative_tolerance = 1.0e-12;
5331        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5332        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5333            &s,
5334            &rhs,
5335            max_iterations,
5336            relative_tolerance,
5337        ) {
5338            Ok(result) => result,
5339            // #1017 — fail loud, never skip-pass: this fixture clears the device
5340            // offload floor, so a CUDA device that is PRESENT yet declines/returns
5341            // Err means the device PCG kernel does not run on GPU (a real fault that
5342            // must not masquerade as a pass via this skip). Legit skip ONLY when no
5343            // usable CUDA device exists (CPU CI). The exact `ArrowSchurGpuFailure`
5344            // variant is folded into the assert message as the diagnostic.
5345            Err(failure) => {
5346                assert!(
5347                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5348                    "#1017: CUDA device present but the device reduced-beta PCG \
5349                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5350                     the kernel does not run correctly on GPU"
5351                );
5352                return;
5353            }
5354        };
5355        let max_err = cpu
5356            .iter()
5357            .zip(device.iter())
5358            .map(|(a, b)| (a - b).abs())
5359            .fold(0.0_f64, f64::max);
5360        assert!(
5361            max_err <= 1.0e-10,
5362            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5363        );
5364        assert!(diag.matvec_calls > 0);
5365        assert_eq!(diag.matvec_calls, diag.iterations);
5366    }
5367
5368    #[test]
5369    fn dense_reference_matches_independent_solve() {
5370        let sys = build_fixture(4, 5, 3, 7);
5371        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5372        // Re-solve by an independent matrix build and a textbook
5373        // Gaussian-elimination Cholesky to guard against typos in the
5374        // reference implementation itself.
5375        let n = sys.rows.len();
5376        let d = sys.d;
5377        let k = sys.k;
5378        let total = n * d + k;
5379        let mut h = Array2::<f64>::zeros((total, total));
5380        let mut g = ndarray::Array1::<f64>::zeros(total);
5381        for (i, row) in sys.rows.iter().enumerate() {
5382            let base = i * d;
5383            for c in 0..d {
5384                for r in 0..d {
5385                    h[[base + r, base + c]] = row.htt[[r, c]];
5386                }
5387            }
5388            for c in 0..k {
5389                for r in 0..d {
5390                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5391                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5392                }
5393            }
5394            for r in 0..d {
5395                g[base + r] = row.gt[r];
5396            }
5397        }
5398        for c in 0..k {
5399            for r in 0..k {
5400                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5401            }
5402            g[n * d + c] = sys.gb[c];
5403        }
5404        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5405        let rhs = g.mapv(|v| -v);
5406        let expected = cholesky_solve_vector(l.view(), rhs.view());
5407        for i in 0..n * d {
5408            assert!(
5409                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5410                "delta_t[{i}] mismatch: got {} expected {}",
5411                solution.delta_t[i],
5412                expected[i]
5413            );
5414        }
5415        for a in 0..k {
5416            assert!(
5417                (solution.delta_beta[a] - expected[n * d + a]).abs()
5418                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
5419                "delta_beta[{a}] mismatch"
5420            );
5421        }
5422    }
5423
5424    /// #1017: the row-procedural reduced-Schur matvec (the matrix-free SAE
5425    /// host backend) auto-fans its per-row point-elimination sum across rayon
5426    /// over fixed row chunks when at the top level (`n ≥
5427    /// SCHUR_MATVEC_PARALLEL_ROW_MIN`), and stays serial when already inside a
5428    /// rayon worker. The chunk-ordered fold makes the parallel result
5429    /// **deterministic** (two parallel calls are bit-identical — scheduling
5430    /// cannot change the numbers) and it agrees with the serial accumulation up
5431    /// to ULP-scale chunk reassociation (the #1017 verification gate). That
5432    /// reassociation is a genuine f64 departure from serial, so the criterion
5433    /// ranking across topology candidates is stable only up to the reassociation
5434    /// margin: a near-tie winner inside that margin can flip. This is NOT an
5435    /// exact no-move guarantee (#1211); for that, the ranking path must use the
5436    /// fixed-order serial accumulation.
5437    #[test]
5438    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5439        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5440        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; // trips the parallel path
5441        let d = 3usize;
5442        let k = 24usize;
5443        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5444        // Install a matrix-free forward/transpose pair that reads the dense
5445        // `htbeta` slabs the fixture already populated, so the procedural
5446        // backend has a well-defined operator to apply (and exercises exactly
5447        // the sparse gather/scatter the SAE Kronecker path drives).
5448        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5449        let forward_slabs = slabs.clone();
5450        let transpose_slabs = slabs;
5451        sys.set_row_htbeta_operator(
5452            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5453                let h = &forward_slabs[row];
5454                for r in 0..h.nrows() {
5455                    let mut acc = 0.0_f64;
5456                    for c in 0..h.ncols() {
5457                        acc += h[[r, c]] * x[c];
5458                    }
5459                    out[r] = acc;
5460                }
5461            },
5462            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5463                let h = &transpose_slabs[row];
5464                for r in 0..h.nrows() {
5465                    for c in 0..h.ncols() {
5466                        out[c] += h[[r, c]] * v[r];
5467                    }
5468                }
5469            },
5470        );
5471
5472        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5473            .expect("row-procedural matvec backend builds for matrix-free system");
5474        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5475
5476        // Top-level call: auto-selects the parallel chunk-fold. Run twice and
5477        // assert bit-identity — the chunk-ordered reduction must not depend on
5478        // thread scheduling.
5479        let mut out_parallel_a = Array1::<f64>::zeros(k);
5480        matvec(&x, &mut out_parallel_a);
5481        let mut out_parallel_b = Array1::<f64>::zeros(k);
5482        matvec(&x, &mut out_parallel_b);
5483        for a in 0..k {
5484            assert_eq!(
5485                out_parallel_a[a].to_bits(),
5486                out_parallel_b[a].to_bits(),
5487                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5488            );
5489        }
5490
5491        // Inside a rayon worker: auto-selects the serial path (nested-rayon
5492        // guard). `install` runs the closure on a pool thread, so
5493        // `current_thread_index()` is `Some`. The serial running sum and the
5494        // chunk-ordered parallel fold differ only by f64 reassociation.
5495        let mut out_serial = Array1::<f64>::zeros(k);
5496        rayon::ThreadPoolBuilder::new()
5497            .num_threads(2)
5498            .build()
5499            .expect("build rayon pool")
5500            .install(|| matvec(&x, &mut out_serial));
5501
5502        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5503        for a in 0..k {
5504            let diff = (out_parallel_a[a] - out_serial[a]).abs();
5505            assert!(
5506                diff <= 1e-12 * (1.0 + max_abs),
5507                "row-procedural matvec parallel vs serial diverged beyond reassociation \
5508                 at index {a}: {} vs {} (diff={diff:e})",
5509                out_parallel_a[a],
5510                out_serial[a]
5511            );
5512        }
5513    }
5514
5515    /// #1017/#1026 — the frames-engaged CPU reduced-Schur matvec
5516    /// [`sae_framed_schur_matvec_cpu`] (the bit-parity oracle the GPU kernel
5517    /// mirrors) must equal the dense reduced Schur `S = (P_ββ + ρ_β I) −
5518    /// Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)` formed by the canonical dense
5519    /// reference, on a small framed system with mixed per-atom ranks
5520    /// (`r_k < p` framed + `r_k = p` un-framed). Size-independent gate.
5521    #[test]
5522    fn framed_sae_schur_matvec_matches_dense_reference() {
5523        use crate::arrow_schur::{
5524            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5525            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5526        };
5527
5528        let p = 4usize;
5529        // Three atoms: ranks 2 (framed), 4 (un-framed), 3 (framed).
5530        let ranks = vec![2usize, 4usize, 3usize];
5531        let basis_sizes = vec![2usize, 1usize, 2usize];
5532        let n_atoms = ranks.len();
5533        let mut border_offsets = Vec::with_capacity(n_atoms);
5534        let mut acc = 0usize;
5535        for k in 0..n_atoms {
5536            border_offsets.push(acc);
5537            acc += basis_sizes[k] * ranks[k];
5538        }
5539        let border_dim = acc; // 2*2 + 1*4 + 2*3 = 14
5540
5541        let mut state = 0x1234_5678_9abc_def0u64;
5542        let mut sample = || -> f64 {
5543            state = state
5544                .wrapping_mul(6364136223846793005)
5545                .wrapping_add(1442695040888963407);
5546            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5547        };
5548
5549        // Per-atom orthonormal-ish frames U_k (p × r_k) for the W = U_iᵀU_j
5550        // factors; un-framed atom (r=p) uses U = I_p.
5551        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5552        for k in 0..n_atoms {
5553            let r = ranks[k];
5554            let mut u = Array2::<f64>::zeros((p, r));
5555            for i in 0..p {
5556                for j in 0..r {
5557                    u[[i, j]] = if r == p && i == j {
5558                        1.0
5559                    } else if r == p {
5560                        0.0
5561                    } else {
5562                        sample()
5563                    };
5564                }
5565            }
5566            frames.push(u);
5567        }
5568        let w_of = |i: usize, j: usize| -> Array2<f64> {
5569            let (ui, uj) = (&frames[i], &frames[j]);
5570            let (ri, rj) = (ranks[i], ranks[j]);
5571            let mut w = Array2::<f64>::zeros((ri, rj));
5572            for a in 0..ri {
5573                for b in 0..rj {
5574                    let mut s = 0.0;
5575                    for c in 0..p {
5576                        s += ui[[c, a]] * uj[[c, b]];
5577                    }
5578                    w[[a, b]] = s;
5579                }
5580            }
5581            w
5582        };
5583
5584        // Co-occurring data-fit blocks: all diagonal pairs + one cross (0,2).
5585        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5586        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5587        pairs.sort();
5588        for &(i, j) in &pairs {
5589            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5590            let mut g = Array2::<f64>::zeros((mi, mj));
5591            for r in 0..mi {
5592                for c in 0..mj {
5593                    g[[r, c]] = 0.3 * sample();
5594                }
5595            }
5596            // Make diagonal blocks SPD-leaning so S stays PD.
5597            if i == j {
5598                for r in 0..mi.min(mj) {
5599                    g[[r, r]] += mi as f64 + 2.0;
5600                }
5601            }
5602            frame_blocks.push(FactoredFrameGBlock {
5603                atom_i: i,
5604                atom_j: j,
5605                g,
5606                w: w_of(i, j),
5607            });
5608        }
5609
5610        // Smooth blocks λ S_k (M_k × M_k), SPD.
5611        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5612        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5613        for k in 0..n_atoms {
5614            let m = basis_sizes[k];
5615            let mut a = Array2::<f64>::zeros((m, m));
5616            for r in 0..m {
5617                for c in 0..m {
5618                    a[[r, c]] = 0.2 * sample();
5619                }
5620            }
5621            let mut s = a.t().dot(&a);
5622            for r in 0..m {
5623                s[[r, r]] += 1.0;
5624            }
5625            smooth_blocks.push(DeviceSaeSmoothBlock {
5626                global_offset: border_offsets[k],
5627                factor_a: s,
5628            });
5629            smooth_ranks.push(ranks[k]);
5630        }
5631
5632        // Build the system: n rows, dense htbeta slabs (q_i × border_dim).
5633        let n = 6usize;
5634        let q = 3usize;
5635        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5636        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5637        for i in 0..n {
5638            // SPD htt.
5639            let mut a = Array2::<f64>::zeros((q, q));
5640            for r in 0..q {
5641                for c in 0..q {
5642                    a[[r, c]] = sample();
5643                }
5644            }
5645            let mut htt = a.t().dot(&a);
5646            for r in 0..q {
5647                htt[[r, r]] += q as f64 + 1.0;
5648            }
5649            sys.rows[i].htt = htt;
5650            let mut slab = vec![0.0_f64; q * border_dim];
5651            for c in 0..q {
5652                for col in 0..border_dim {
5653                    let v = 0.15 * sample();
5654                    slab[c * border_dim + col] = v;
5655                    sys.rows[i].htbeta[[c, col]] = v;
5656                }
5657            }
5658            row_htbeta.push(slab);
5659        }
5660
5661        // Dense H_ββ from the SAME penalty ops (so the dense reference's S
5662        // matches the device penalty side exactly).
5663        let data_op =
5664            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5665                .expect("frame op");
5666        let mut hbb = data_op.to_dense();
5667        for k in 0..n_atoms {
5668            let op = IdentityRightKroneckerPenaltyOp {
5669                factor_a: smooth_blocks[k].factor_a.clone(),
5670                p: ranks[k],
5671                global_offset: border_offsets[k],
5672                k: border_dim,
5673            };
5674            let d = op.to_dense();
5675            for r in 0..border_dim {
5676                for c in 0..border_dim {
5677                    hbb[[r, c]] += d[[r, c]];
5678                }
5679            }
5680        }
5681        sys.hbb = hbb;
5682
5683        let data = DeviceSaePcgData {
5684            p,
5685            beta_dim: border_dim,
5686            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5687            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5688            smooth_blocks,
5689            sparse_g_blocks: Vec::new(),
5690            frame: Some(DeviceSaeFrameData {
5691                ranks: ranks.clone(),
5692                basis_sizes: basis_sizes.clone(),
5693                border_offsets: border_offsets.clone(),
5694                frame_blocks,
5695                smooth_ranks,
5696                row_htbeta,
5697            }),
5698        };
5699
5700        let ridge_t = 1e-7;
5701        let ridge_beta = 1e-6;
5702
5703        // Dense reference reduced Schur S (border_dim × border_dim), formed
5704        // exactly as solve_arrow_newton_step_dense_reference assembles the
5705        // bordered Hessian and eliminates the t-block.
5706        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5707        for r in 0..border_dim {
5708            for c in 0..border_dim {
5709                s_dense[[r, c]] = sys.hbb[[r, c]];
5710            }
5711            s_dense[[r, r]] += ridge_beta;
5712        }
5713        for row in &sys.rows {
5714            let mut htt = row.htt.clone();
5715            for d in 0..q {
5716                htt[[d, d]] += ridge_t;
5717            }
5718            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5719                .expect("htt PD");
5720            // Y = (htt)⁻¹ htbeta  (q × border_dim); S -= htbetaᵀ Y.
5721            let mut y = Array2::<f64>::zeros((q, border_dim));
5722            for col in 0..border_dim {
5723                let mut e = Array1::<f64>::zeros(q);
5724                for r in 0..q {
5725                    e[r] = row.htbeta[[r, col]];
5726                }
5727                let solved = cholesky_solve_vector(factor.view(), e.view());
5728                for r in 0..q {
5729                    y[[r, col]] = solved[r];
5730                }
5731            }
5732            for r in 0..border_dim {
5733                for c in 0..border_dim {
5734                    let mut acc = 0.0;
5735                    for d in 0..q {
5736                        acc += row.htbeta[[d, r]] * y[[d, c]];
5737                    }
5738                    s_dense[[r, c]] -= acc;
5739                }
5740            }
5741        }
5742
5743        // Probe vectors: compare S·x from the device-data CPU oracle vs dense S·x.
5744        let mut max_rel = 0.0_f64;
5745        for trial in 0..4 {
5746            let x: Vec<f64> = (0..border_dim)
5747                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5748                .collect();
5749            let mut got = vec![0.0_f64; border_dim];
5750            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5751                .expect("framed matvec");
5752            let mut want = vec![0.0_f64; border_dim];
5753            for r in 0..border_dim {
5754                let mut acc = 0.0;
5755                for c in 0..border_dim {
5756                    acc += s_dense[[r, c]] * x[c];
5757                }
5758                want[r] = acc;
5759            }
5760            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5761            for a in 0..border_dim {
5762                let rel = (got[a] - want[a]).abs() / scale;
5763                max_rel = max_rel.max(rel);
5764            }
5765        }
5766        assert!(
5767            max_rel <= 1e-10,
5768            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5769        );
5770    }
5771
5772    /// #1017/#1026 — large-K, many-atom, dense-cross-pair parity for the framed
5773    /// SAE reduced-Schur CPU oracle. The small `framed_sae_schur_matvec_matches_dense_reference`
5774    /// pins `border_dim=14`/3 atoms/1 cross pair; the interactions that only
5775    /// appear at scale — variable per-atom `r_k` (mixed framed `r_k<p` and
5776    /// un-framed `r_k=p`), the prefix-sum `border_offsets`, and dense cross-atom
5777    /// `W_ij` coupling across MANY co-occurring `frame_blocks` and many `row_htbeta`
5778    /// slabs — were validated only on the A100. This pins the CPU oracle (and
5779    /// therefore the device kernel it mirrors) against the dense reduced Schur at
5780    /// 40 atoms / `border_dim≈240` / a neighbour-coupled cross-pair set, on CPU.
5781    #[test]
5782    fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
5783        use crate::arrow_schur::{
5784            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5785            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5786        };
5787
5788        let p = 12usize;
5789        let n_atoms = 40usize;
5790        // Variable per-atom rank: most atoms are genuinely framed (r_k<p), every
5791        // fifth atom is un-framed (r_k=p, U=I_p — the within-atom G⊗I_r collapse).
5792        let ranks: Vec<usize> = (0..n_atoms)
5793            .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
5794            .collect();
5795        let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();
5796        let mut border_offsets = Vec::with_capacity(n_atoms);
5797        let mut acc = 0usize;
5798        for k in 0..n_atoms {
5799            border_offsets.push(acc);
5800            acc += basis_sizes[k] * ranks[k];
5801        }
5802        let border_dim = acc;
5803
5804        let mut state = 0x0bad_c0de_dead_beefu64;
5805        let mut sample = || -> f64 {
5806            state = state
5807                .wrapping_mul(6364136223846793005)
5808                .wrapping_add(1442695040888963407);
5809            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5810        };
5811
5812        // Per-atom frames U_k (p × r_k); un-framed atom (r=p) uses U = I_p.
5813        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5814        for k in 0..n_atoms {
5815            let r = ranks[k];
5816            let mut u = Array2::<f64>::zeros((p, r));
5817            for i in 0..p {
5818                for j in 0..r {
5819                    u[[i, j]] = if r == p {
5820                        if i == j { 1.0 } else { 0.0 }
5821                    } else {
5822                        sample()
5823                    };
5824                }
5825            }
5826            frames.push(u);
5827        }
5828        let w_of = |i: usize, j: usize| -> Array2<f64> {
5829            let (ui, uj) = (&frames[i], &frames[j]);
5830            let (ri, rj) = (ranks[i], ranks[j]);
5831            let mut w = Array2::<f64>::zeros((ri, rj));
5832            for a in 0..ri {
5833                for b in 0..rj {
5834                    let mut s = 0.0;
5835                    for c in 0..p {
5836                        s += ui[[c, a]] * uj[[c, b]];
5837                    }
5838                    w[[a, b]] = s;
5839                }
5840            }
5841            w
5842        };
5843
5844        // Co-occurring data-fit blocks: every diagonal pair + each neighbour pair
5845        // (k,k+1) and its transpose — dense cross coupling across the whole border.
5846        let mut pairs: Vec<(usize, usize)> = Vec::new();
5847        for k in 0..n_atoms {
5848            pairs.push((k, k));
5849        }
5850        for k in 0..n_atoms - 1 {
5851            pairs.push((k, k + 1));
5852            pairs.push((k + 1, k));
5853        }
5854        pairs.sort_unstable();
5855        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5856        for &(i, j) in &pairs {
5857            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5858            let mut g = Array2::<f64>::zeros((mi, mj));
5859            for r in 0..mi {
5860                for c in 0..mj {
5861                    g[[r, c]] = 0.3 * sample();
5862                }
5863            }
5864            if i == j {
5865                for r in 0..mi.min(mj) {
5866                    g[[r, r]] += mi as f64 + 2.0;
5867                }
5868            }
5869            frame_blocks.push(FactoredFrameGBlock {
5870                atom_i: i,
5871                atom_j: j,
5872                g,
5873                w: w_of(i, j),
5874            });
5875        }
5876
5877        // Smooth blocks λ S_k (M_k × M_k), SPD.
5878        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5879        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5880        for k in 0..n_atoms {
5881            let m = basis_sizes[k];
5882            let mut a = Array2::<f64>::zeros((m, m));
5883            for r in 0..m {
5884                for c in 0..m {
5885                    a[[r, c]] = 0.2 * sample();
5886                }
5887            }
5888            let mut s = a.t().dot(&a);
5889            for r in 0..m {
5890                s[[r, r]] += 1.0;
5891            }
5892            smooth_blocks.push(DeviceSaeSmoothBlock {
5893                global_offset: border_offsets[k],
5894                factor_a: s,
5895            });
5896            smooth_ranks.push(ranks[k]);
5897        }
5898
5899        // n rows with SPD htt and full-width (q × border_dim) htbeta slabs.
5900        let n = 8usize;
5901        let q = 3usize;
5902        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5903        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5904        for i in 0..n {
5905            let mut a = Array2::<f64>::zeros((q, q));
5906            for r in 0..q {
5907                for c in 0..q {
5908                    a[[r, c]] = sample();
5909                }
5910            }
5911            let mut htt = a.t().dot(&a);
5912            for r in 0..q {
5913                htt[[r, r]] += q as f64 + 1.0;
5914            }
5915            sys.rows[i].htt = htt;
5916            let mut slab = vec![0.0_f64; q * border_dim];
5917            for c in 0..q {
5918                for col in 0..border_dim {
5919                    let v = 0.15 * sample();
5920                    slab[c * border_dim + col] = v;
5921                    sys.rows[i].htbeta[[c, col]] = v;
5922                }
5923            }
5924            row_htbeta.push(slab);
5925        }
5926
5927        // Dense H_ββ from the SAME penalty ops (data-fit + smooth), so the dense
5928        // reference S matches the device penalty side exactly.
5929        let data_op =
5930            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5931                .expect("frame op");
5932        let mut hbb = data_op.to_dense();
5933        for k in 0..n_atoms {
5934            let op = IdentityRightKroneckerPenaltyOp {
5935                factor_a: smooth_blocks[k].factor_a.clone(),
5936                p: ranks[k],
5937                global_offset: border_offsets[k],
5938                k: border_dim,
5939            };
5940            let d = op.to_dense();
5941            for r in 0..border_dim {
5942                for c in 0..border_dim {
5943                    hbb[[r, c]] += d[[r, c]];
5944                }
5945            }
5946        }
5947        sys.hbb = hbb;
5948
5949        let data = DeviceSaePcgData {
5950            p,
5951            beta_dim: border_dim,
5952            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5953            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5954            smooth_blocks,
5955            sparse_g_blocks: Vec::new(),
5956            frame: Some(DeviceSaeFrameData {
5957                ranks: ranks.clone(),
5958                basis_sizes: basis_sizes.clone(),
5959                border_offsets: border_offsets.clone(),
5960                frame_blocks,
5961                smooth_ranks,
5962                row_htbeta,
5963            }),
5964        };
5965
5966        let ridge_t = 1e-7;
5967        let ridge_beta = 1e-6;
5968
5969        // Dense reference reduced Schur S = (hbb + ridge_beta I) - Σ_i htbetaᵀ (htt+ridge_t I)⁻¹ htbeta.
5970        let mut s_dense = sys.hbb.clone();
5971        for r in 0..border_dim {
5972            s_dense[[r, r]] += ridge_beta;
5973        }
5974        for row in &sys.rows {
5975            let mut htt = row.htt.clone();
5976            for d in 0..q {
5977                htt[[d, d]] += ridge_t;
5978            }
5979            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5980                .expect("htt PD");
5981            let mut y = Array2::<f64>::zeros((q, border_dim));
5982            for col in 0..border_dim {
5983                let mut e = Array1::<f64>::zeros(q);
5984                for r in 0..q {
5985                    e[r] = row.htbeta[[r, col]];
5986                }
5987                let solved = cholesky_solve_vector(factor.view(), e.view());
5988                for r in 0..q {
5989                    y[[r, col]] = solved[r];
5990                }
5991            }
5992            for r in 0..border_dim {
5993                for c in 0..border_dim {
5994                    let mut acc = 0.0;
5995                    for d in 0..q {
5996                        acc += row.htbeta[[d, r]] * y[[d, c]];
5997                    }
5998                    s_dense[[r, c]] -= acc;
5999                }
6000            }
6001        }
6002
6003        let mut max_rel = 0.0_f64;
6004        for trial in 0..4 {
6005            let x: Vec<f64> = (0..border_dim)
6006                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
6007                .collect();
6008            let mut got = vec![0.0_f64; border_dim];
6009            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
6010                .expect("framed matvec");
6011            let mut want = vec![0.0_f64; border_dim];
6012            for r in 0..border_dim {
6013                let mut acc = 0.0;
6014                for c in 0..border_dim {
6015                    acc += s_dense[[r, c]] * x[c];
6016                }
6017                want[r] = acc;
6018            }
6019            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6020            for a in 0..border_dim {
6021                let rel = (got[a] - want[a]).abs() / scale;
6022                max_rel = max_rel.max(rel);
6023            }
6024        }
6025        assert!(
6026            max_rel <= 1e-10,
6027            "large-K framed SAE Schur matvec vs dense reference diverged: \
6028             max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
6029        );
6030    }
6031
6032    /// #1017/#1026 GPU arm: when a CUDA device admits the framed SAE PCG, its
6033    /// solved `δβ` must match the CPU dense reduced-system solve of the SAME
6034    /// framed system (size-independent — a small device validates the kernel).
6035    /// Skips cleanly (returns) when no device is available or the policy
6036    /// declines (`solve_sae_matrix_free_pcg` → `Unavailable`).
6037    #[test]
6038    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
6039        use crate::arrow_schur::{
6040            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
6041            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
6042        };
6043
6044        // Large enough to clear the device-offload policy floor (k ≥ 32 and
6045        // n·k·d·iters ≥ MATVEC_OFFLOAD_FLOPS_MIN) so the GPU kernel actually
6046        // runs on a device rather than the policy declining.
6047        let p = 6usize;
6048        let n_atoms = 8usize;
6049        let ranks: Vec<usize> = (0..n_atoms)
6050            .map(|k| if k % 2 == 0 { 3usize } else { p })
6051            .collect();
6052        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6053        let mut border_offsets = Vec::with_capacity(n_atoms);
6054        let mut acc = 0usize;
6055        for k in 0..n_atoms {
6056            border_offsets.push(acc);
6057            acc += basis_sizes[k] * ranks[k];
6058        }
6059        let border_dim = acc; // Σ M_k·r_k = 4·(3·3) + 4·(3·6) = 36 + 72 = 108
6060
6061        let mut state = 0xfeed_face_dead_beefu64;
6062        let mut sample = || -> f64 {
6063            state = state
6064                .wrapping_mul(6364136223846793005)
6065                .wrapping_add(1442695040888963407);
6066            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6067        };
6068        let mut frames: Vec<Array2<f64>> = Vec::new();
6069        for k in 0..n_atoms {
6070            let r = ranks[k];
6071            let mut u = Array2::<f64>::zeros((p, r));
6072            for i in 0..p {
6073                for j in 0..r {
6074                    u[[i, j]] = if r == p && i == j {
6075                        1.0
6076                    } else if r == p {
6077                        0.0
6078                    } else {
6079                        sample()
6080                    };
6081                }
6082            }
6083            frames.push(u);
6084        }
6085        let w_of = |i: usize, j: usize| {
6086            let (ui, uj) = (&frames[i], &frames[j]);
6087            let (ri, rj) = (ranks[i], ranks[j]);
6088            let mut w = Array2::<f64>::zeros((ri, rj));
6089            for a in 0..ri {
6090                for b in 0..rj {
6091                    let mut s = 0.0;
6092                    for c in 0..p {
6093                        s += ui[[c, a]] * uj[[c, b]];
6094                    }
6095                    w[[a, b]] = s;
6096                }
6097            }
6098            w
6099        };
6100        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6101        // A few off-diagonal cross blocks (symmetric pairs).
6102        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6103            pairs.push((i, j));
6104            pairs.push((j, i));
6105        }
6106        let mut frame_blocks = Vec::new();
6107        for &(i, j) in &pairs {
6108            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6109            let mut g = Array2::<f64>::zeros((mi, mj));
6110            for r in 0..mi {
6111                for c in 0..mj {
6112                    g[[r, c]] = 0.25 * sample();
6113                }
6114            }
6115            if i == j {
6116                for r in 0..mi.min(mj) {
6117                    g[[r, r]] += mi as f64 + 2.0;
6118                }
6119            }
6120            frame_blocks.push(FactoredFrameGBlock {
6121                atom_i: i,
6122                atom_j: j,
6123                g,
6124                w: w_of(i, j),
6125            });
6126        }
6127        let mut smooth_blocks = Vec::new();
6128        let mut smooth_ranks = Vec::new();
6129        for k in 0..n_atoms {
6130            let m = basis_sizes[k];
6131            let mut a = Array2::<f64>::zeros((m, m));
6132            for r in 0..m {
6133                for c in 0..m {
6134                    a[[r, c]] = 0.2 * sample();
6135                }
6136            }
6137            let mut s = a.t().dot(&a);
6138            for r in 0..m {
6139                s[[r, r]] += 1.0;
6140            }
6141            smooth_blocks.push(DeviceSaeSmoothBlock {
6142                global_offset: border_offsets[k],
6143                factor_a: s,
6144            });
6145            smooth_ranks.push(ranks[k]);
6146        }
6147        let n = 400usize;
6148        let q = 4usize;
6149        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6150        let mut row_htbeta = Vec::new();
6151        for i in 0..n {
6152            let mut a = Array2::<f64>::zeros((q, q));
6153            for r in 0..q {
6154                for c in 0..q {
6155                    a[[r, c]] = sample();
6156                }
6157            }
6158            let mut htt = a.t().dot(&a);
6159            for r in 0..q {
6160                htt[[r, r]] += q as f64 + 1.0;
6161            }
6162            sys.rows[i].htt = htt;
6163            let mut slab = vec![0.0_f64; q * border_dim];
6164            for c in 0..q {
6165                for col in 0..border_dim {
6166                    // Small entries: with 400 rows the reduced-Schur subtraction
6167                    // Σ_i H_βtᵀ H_tt⁻¹ H_tβ must not overwhelm the PD penalty.
6168                    let v = 0.02 * sample();
6169                    slab[c * border_dim + col] = v;
6170                    sys.rows[i].htbeta[[c, col]] = v;
6171                }
6172            }
6173            row_htbeta.push(slab);
6174        }
6175        let data_op =
6176            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
6177                .expect("frame op");
6178        let mut hbb = data_op.to_dense();
6179        for k in 0..n_atoms {
6180            let op = IdentityRightKroneckerPenaltyOp {
6181                factor_a: smooth_blocks[k].factor_a.clone(),
6182                p: ranks[k],
6183                global_offset: border_offsets[k],
6184                k: border_dim,
6185            };
6186            let d = op.to_dense();
6187            for r in 0..border_dim {
6188                for c in 0..border_dim {
6189                    hbb[[r, c]] += d[[r, c]];
6190                }
6191            }
6192        }
6193        sys.hbb = hbb;
6194        let data = DeviceSaePcgData {
6195            p,
6196            beta_dim: border_dim,
6197            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6198            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6199            smooth_blocks,
6200            sparse_g_blocks: Vec::new(),
6201            frame: Some(DeviceSaeFrameData {
6202                ranks: ranks.clone(),
6203                basis_sizes: basis_sizes.clone(),
6204                border_offsets: border_offsets.clone(),
6205                frame_blocks,
6206                smooth_ranks,
6207                row_htbeta,
6208            }),
6209        };
6210        let ridge_t = 1e-7;
6211        let ridge_beta = 1e-6;
6212        let rhs: Array1<f64> =
6213            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
6214
6215        let (device, diag) =
6216            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
6217                Ok(result) => result,
6218                // #1017 — fail loud, never skip-pass: this fixture clears the device
6219                // offload floor, so a CUDA device that is PRESENT yet declines means the
6220                // framed device PCG kernel does not run on GPU (the fault must not pass
6221                // silently). Legit skip ONLY when no usable CUDA device exists (CPU CI).
6222                // The exact `ArrowSchurGpuFailure` variant is folded into the assert.
6223                Err(failure) => {
6224                    assert!(
6225                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6226                        "#1017: CUDA device present but the framed device SAE PCG \
6227                     declined/faulted instead of returning a result (tag: {failure:?}) — \
6228                     the kernel does not run correctly on GPU"
6229                    );
6230                    return;
6231                }
6232            };
6233
6234        // #1551 PARITY GATE — operator-residual, NOT solution-vector equality.
6235        //
6236        // The honest GPU↔CPU contract for an iterative solve of `S·δβ = rhs` is
6237        // that the device solution SOLVES the system DEFINED BY THE CPU ORACLE to
6238        // PCG tolerance — i.e. `‖S_cpu·δβ_device − rhs‖ / ‖rhs‖ ≤ tol`, where
6239        // `S_cpu` is applied with the bit-for-bit CPU oracle matvec the device
6240        // kernel mirrors (`sae_framed_schur_matvec_cpu`). This is the correct,
6241        // conditioning-robust gate: a near-singular assembled `S` has a large
6242        // condition number `κ(S)`, which amplifies an O(ε) residual difference
6243        // into an O(κ·ε) *solution-vector* difference, so comparing δβ vectors
6244        // would spuriously fail even when both operators are bit-identical and
6245        // both solves converged. (Historically this test compared δβ against a
6246        // dense-Cholesky reference and "failed" with max_rel≈0.9 because the
6247        // dense solve itself only reached ‖S·x−rhs‖≈0.1 on this fixture's
6248        // ill-conditioned S while the device PCG reached ~1e-12 — the device was
6249        // MORE accurate than the reference, not wrong. The kernel correctness is
6250        // pinned conditioning-free by `framed_sae_device_matvec_matches_cpu_oracle_*`.)
6251        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
6252        let oracle_resid = |x: &Array1<f64>| -> f64 {
6253            let mut sx = vec![0.0_f64; border_dim];
6254            sae_framed_schur_matvec_cpu(
6255                &sys,
6256                &data,
6257                ridge_t,
6258                ridge_beta,
6259                x.as_slice().unwrap(),
6260                &mut sx,
6261            )
6262            .expect("cpu oracle matvec");
6263            let mut acc = 0.0_f64;
6264            for a in 0..border_dim {
6265                let e = sx[a] - rhs[a];
6266                acc += e * e;
6267            }
6268            acc.sqrt()
6269        };
6270        let s_dev_resid = oracle_resid(&device);
6271        let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
6272
6273        // Independent CPU iterative solve of the SAME operator with the SAME
6274        // Jacobi preconditioner the device builds, via the shared `pcg_core`. If
6275        // the device kernel computed a different operator, the two converged
6276        // residuals could not BOTH be tiny.
6277        let precond = {
6278            let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6279            // The reduced-Schur diagonal subtraction (device `arrow_sae_frame_diag_sub`)
6280            // mirrored on the host for the Jacobi preconditioner.
6281            let mut diag = d;
6282            for (i, row) in sys.rows.iter().enumerate() {
6283                let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6284                let qi = sys.row_dims[i];
6285                if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6286                    continue;
6287                }
6288                let mut block = row.htt.clone();
6289                for dd in 0..qi {
6290                    block[[dd, dd]] += ridge_t;
6291                }
6292                let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6293                    .expect("row htt PD");
6294                // ainv = (H_tt+ρI)⁻¹ column by column.
6295                let mut ainv = Array2::<f64>::zeros((qi, qi));
6296                for col in 0..qi {
6297                    let mut e = Array1::<f64>::zeros(qi);
6298                    e[col] = 1.0;
6299                    let s = cholesky_solve_vector(factor.view(), e.view());
6300                    for r in 0..qi {
6301                        ainv[[r, col]] = s[r];
6302                    }
6303                }
6304                for a in 0..border_dim {
6305                    let mut quad = 0.0_f64;
6306                    for c in 0..qi {
6307                        let hc = slab[c * border_dim + a];
6308                        for dd in 0..qi {
6309                            quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6310                        }
6311                    }
6312                    diag[a] -= quad;
6313                }
6314            }
6315            Array1::from_vec(diag)
6316        };
6317        let mut cpu = Array1::<f64>::zeros(border_dim);
6318        let cpu_result = {
6319            let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6320                let mut tmp = vec![0.0_f64; border_dim];
6321                sae_framed_schur_matvec_cpu(
6322                    &sys,
6323                    &data,
6324                    ridge_t,
6325                    ridge_beta,
6326                    v.as_slice().unwrap(),
6327                    &mut tmp,
6328                )
6329                .expect("cpu oracle matvec");
6330                out.assign(&Array1::from_vec(tmp));
6331            };
6332            gam_linalg::pcg::pcg_core(
6333                &mut apply,
6334                &rhs.view(),
6335                &precond.view(),
6336                1e-12,
6337                800,
6338                32,
6339                false,
6340                gam_linalg::pcg::DotReduction::Serial,
6341                &mut cpu.view_mut(),
6342            )
6343        };
6344        let s_cpu_resid = oracle_resid(&cpu);
6345        let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6346
6347        // GATE 1: the device solution solves the CPU-oracle system to PCG-grade
6348        // accuracy (proves device kernel == CPU operator AND device PCG converged).
6349        assert!(
6350            dev_rel_resid <= 1e-7,
6351            "[#1551] device δβ does not solve the CPU-oracle system: \
6352             ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6353             device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6354             means the device matvec is a DIFFERENT operator (kernel bug)",
6355            diag.stopping_reason,
6356            diag.iterations,
6357            diag.final_relative_residual,
6358        );
6359        // GATE 2: the independent CPU iterative solve of the same operator with the
6360        // same preconditioner also converges — both paths agree on the operator.
6361        assert!(
6362            cpu_rel_resid <= 1e-6,
6363            "[#1551] CPU pcg_core failed to solve the oracle system: \
6364             ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6365            cpu_result.stop,
6366            cpu_result.iterations,
6367        );
6368    }
6369
6370    /// Host mirror of the device `sae_frame_penalty_diag_host` for the framed
6371    /// Jacobi preconditioner (penalty diagonal only; the reduced-Schur diagonal
6372    /// subtraction is applied by the caller). Test-only.
6373    fn sae_frame_penalty_diag_host_for_test(
6374        data: &DeviceSaePcgData,
6375        ridge_beta: f64,
6376    ) -> Vec<f64> {
6377        let frame = data.frame.as_ref().expect("frame");
6378        let mut diag = vec![ridge_beta; data.beta_dim];
6379        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6380            let m = blk.factor_a.nrows();
6381            for ia in 0..m {
6382                let coeff = blk.factor_a[[ia, ia]];
6383                let base = blk.global_offset + ia * r;
6384                for ib in 0..r {
6385                    diag[base + ib] += coeff;
6386                }
6387            }
6388        }
6389        for blk in &frame.frame_blocks {
6390            if blk.atom_i != blk.atom_j {
6391                continue;
6392            }
6393            let r = frame.ranks[blk.atom_i];
6394            let off = frame.border_offsets[blk.atom_i];
6395            let (mi, mj) = blk.g.dim();
6396            for li in 0..mi.min(mj) {
6397                let gii = blk.g[[li, li]];
6398                let base = off + li * r;
6399                for a in 0..r {
6400                    diag[base + a] += gii * blk.w[[a, a]];
6401                }
6402            }
6403        }
6404        diag
6405    }
6406
6407    /// #1551 DEFINITIVE kernel-correctness proof: the framed reduced-Schur matvec
6408    /// `out = S·x` must agree with the CPU oracle [`sae_framed_schur_matvec_cpu`]
6409    /// element-wise, for several independent `x`, to ≤ 1e-9. This is the parity
6410    /// gate that actually localizes a kernel/marshalling defect — unlike a
6411    /// solved-`δβ` comparison, it does NOT route through a linear solve, so it is
6412    /// independent of the conditioning of the assembled `S` (a near-singular `S`
6413    /// can make a dense-Cholesky vector and an iterative-PCG vector disagree at
6414    /// the *solution* level even when both operators are bit-correct; the
6415    /// operator itself must still match here). Fails loud if CUDA is present but
6416    /// the device matvec declines; skips cleanly only when no device exists.
6417    #[test]
6418    fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
6419        use crate::arrow_schur::{
6420            DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
6421        };
6422
6423        // Hand-checkable frame fixture: a mix of framed (r<p) and identity-ride
6424        // (r==p) atoms, a few off-diagonal cross blocks, dense per-row H_tβ.
6425        let p = 6usize;
6426        let n_atoms = 8usize;
6427        let ranks: Vec<usize> = (0..n_atoms)
6428            .map(|k| if k % 2 == 0 { 3usize } else { p })
6429            .collect();
6430        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6431        let mut border_offsets = Vec::with_capacity(n_atoms);
6432        let mut acc = 0usize;
6433        for k in 0..n_atoms {
6434            border_offsets.push(acc);
6435            acc += basis_sizes[k] * ranks[k];
6436        }
6437        let border_dim = acc;
6438
6439        let mut state = 0x1551_0017_1026_0922u64;
6440        let mut sample = || -> f64 {
6441            state = state
6442                .wrapping_mul(6364136223846793005)
6443                .wrapping_add(1442695040888963407);
6444            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6445        };
6446        let mut frames: Vec<Array2<f64>> = Vec::new();
6447        for k in 0..n_atoms {
6448            let r = ranks[k];
6449            let mut u = Array2::<f64>::zeros((p, r));
6450            for i in 0..p {
6451                for j in 0..r {
6452                    u[[i, j]] = if r == p && i == j {
6453                        1.0
6454                    } else if r == p {
6455                        0.0
6456                    } else {
6457                        sample()
6458                    };
6459                }
6460            }
6461            frames.push(u);
6462        }
6463        let w_of = |i: usize, j: usize| {
6464            let (ui, uj) = (&frames[i], &frames[j]);
6465            let (ri, rj) = (ranks[i], ranks[j]);
6466            let mut w = Array2::<f64>::zeros((ri, rj));
6467            for a in 0..ri {
6468                for b in 0..rj {
6469                    let mut s = 0.0;
6470                    for c in 0..p {
6471                        s += ui[[c, a]] * uj[[c, b]];
6472                    }
6473                    w[[a, b]] = s;
6474                }
6475            }
6476            w
6477        };
6478        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6479        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6480            pairs.push((i, j));
6481            pairs.push((j, i));
6482        }
6483        let mut frame_blocks = Vec::new();
6484        for &(i, j) in &pairs {
6485            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6486            let mut g = Array2::<f64>::zeros((mi, mj));
6487            for r in 0..mi {
6488                for c in 0..mj {
6489                    g[[r, c]] = 0.25 * sample();
6490                }
6491            }
6492            if i == j {
6493                for r in 0..mi.min(mj) {
6494                    g[[r, r]] += mi as f64 + 2.0;
6495                }
6496            }
6497            frame_blocks.push(FactoredFrameGBlock {
6498                atom_i: i,
6499                atom_j: j,
6500                g,
6501                w: w_of(i, j),
6502            });
6503        }
6504        let mut smooth_blocks = Vec::new();
6505        let mut smooth_ranks = Vec::new();
6506        for k in 0..n_atoms {
6507            let m = basis_sizes[k];
6508            let mut a = Array2::<f64>::zeros((m, m));
6509            for r in 0..m {
6510                for c in 0..m {
6511                    a[[r, c]] = 0.2 * sample();
6512                }
6513            }
6514            let mut s = a.t().dot(&a);
6515            for r in 0..m {
6516                s[[r, r]] += 1.0;
6517            }
6518            smooth_blocks.push(DeviceSaeSmoothBlock {
6519                global_offset: border_offsets[k],
6520                factor_a: s,
6521            });
6522            smooth_ranks.push(ranks[k]);
6523        }
6524        // Modest row count: this seam bypasses the offload floor, so we keep the
6525        // fixture small and the per-row reduced-Schur term well-scaled — the
6526        // matvec parity does not depend on the assembled-S conditioning at all.
6527        let n = 32usize;
6528        let q = 4usize;
6529        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6530        let mut row_htbeta = Vec::new();
6531        for i in 0..n {
6532            let mut a = Array2::<f64>::zeros((q, q));
6533            for r in 0..q {
6534                for c in 0..q {
6535                    a[[r, c]] = sample();
6536                }
6537            }
6538            let mut htt = a.t().dot(&a);
6539            for r in 0..q {
6540                htt[[r, r]] += q as f64 + 1.0;
6541            }
6542            sys.rows[i].htt = htt;
6543            let mut slab = vec![0.0_f64; q * border_dim];
6544            for c in 0..q {
6545                for col in 0..border_dim {
6546                    let v = 0.3 * sample();
6547                    slab[c * border_dim + col] = v;
6548                    sys.rows[i].htbeta[[c, col]] = v;
6549                }
6550            }
6551            row_htbeta.push(slab);
6552        }
6553        let ridge_t = 1e-7;
6554        let ridge_beta = 1e-6;
6555        let data = DeviceSaePcgData {
6556            p,
6557            beta_dim: border_dim,
6558            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6559            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6560            smooth_blocks,
6561            sparse_g_blocks: Vec::new(),
6562            frame: Some(DeviceSaeFrameData {
6563                ranks: ranks.clone(),
6564                basis_sizes: basis_sizes.clone(),
6565                border_offsets: border_offsets.clone(),
6566                frame_blocks,
6567                smooth_ranks,
6568                row_htbeta,
6569            }),
6570        };
6571
6572        // Several independent probe vectors x, including unit axes and dense
6573        // random — a marshalling stride/offset bug shows up as a per-component
6574        // mismatch on at least one.
6575        let mut probes: Vec<Array1<f64>> = Vec::new();
6576        probes.push(Array1::from_shape_fn(border_dim, |a| {
6577            ((a as f64 + 1.0) * 0.37).sin()
6578        }));
6579        probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
6580        for axis in [0usize, border_dim / 3, border_dim - 1] {
6581            let mut e = Array1::<f64>::zeros(border_dim);
6582            e[axis] = 1.0;
6583            probes.push(e);
6584        }
6585
6586        let mut any_ran = false;
6587        let mut worst = 0.0_f64;
6588        for (pi, x) in probes.iter().enumerate() {
6589            let device = match super::framed_schur_matvec_once_on_device(
6590                &sys, &data, ridge_t, ridge_beta, x,
6591            ) {
6592                Ok(out) => out,
6593                Err(failure) => {
6594                    // Fail loud: a present CUDA device that declines this seam
6595                    // (which deliberately ignores the offload floor) means the
6596                    // framed matvec kernel does not run on GPU.
6597                    assert!(
6598                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6599                        "#1551: CUDA device present but the framed device matvec \
6600                         declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
6601                         does not run on GPU"
6602                    );
6603                    return;
6604                }
6605            };
6606            any_ran = true;
6607            let mut cpu = vec![0.0_f64; border_dim];
6608            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, x.as_slice().unwrap(), &mut cpu)
6609                .expect("cpu oracle matvec");
6610            let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6611            for a in 0..border_dim {
6612                let rel = (device[a] - cpu[a]).abs() / scale;
6613                worst = worst.max(rel);
6614                assert!(
6615                    rel <= 1e-9,
6616                    "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
6617                     rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
6618                    device[a],
6619                    cpu[a],
6620                );
6621            }
6622        }
6623        if any_ran {
6624            // Positive on-device confirmation: the framed matvec ran on the GPU
6625            // and matched the CPU oracle across every probe. (1e-9 is far above
6626            // the ~1e-13 fp64 GEMV round-off; a structural marshalling bug would
6627            // be O(1).)
6628            assert!(
6629                gam_gpu::device_runtime::GpuRuntime::global().is_some(),
6630                "#1551: matvec ran but no GPU runtime — unexpected"
6631            );
6632            assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
6633        }
6634    }
6635}