Skip to main content

gam_solve/gpu_kernels/
arrow_schur.rs

1//! Fully GPU-resident batched Arrow-Schur dense Cholesky solver.
2//!
3//! Implements the square-root Schur form: each local block `D_i = L_i L_i^T`
4//! is factored on device, `u_i = L_i^{-1} g_i` and `Y_i = L_i^{-1} B_i` are
5//! formed by triangular solves, the reduced shared system
6//!     `S_β = C + ρ_β I − Σ_i Y_i^T Y_i,  r_β = -g_β + Σ_i Y_i^T u_i`
7//! is assembled on device, factored once, and the back-substitution
8//!     `w_i = u_i + Y_i · δβ,  L_i^T x_i = w_i,  δt_i = -x_i`
9//! is run on device. Only the final `(δt, δβ, log|H|)` triple is downloaded.
10//!
//! The current caller (Arrow-Schur Newton step inside PIRLS) feeds uniform
//! local block size `d` and uniform shared width `k`, so the entire pipeline
//! is dispatched as a single p-group; per-p grouping for heterogeneous blocks
//! is Layer D's NVRTC fused-kernel concern and lives in this module's
//! follow-up implementation rather than in policy plumbing.
16//!
17//! On non-Linux builds the entire module degrades to a CPU-fallback shim.
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24/// Outcome of a single Arrow-Schur Newton solve.
25pub struct ArrowSchurGpuSolution {
26    pub delta_t: Array1<f64>,
27    pub delta_beta: Array1<f64>,
28    /// Natural log of the determinant of the full bordered Hessian, computed
29    /// from the local Cholesky factors and the Schur factor on device.
30    pub log_det_hessian: f64,
31}
32
/// Why the device path did not produce a solution.
///
/// The host caller uses the variant to choose between a CPU fallback and a
/// per-row ridge escalation. Only `RidgeBumpRequired` carries a numerical
/// hint: the estimated extra diagonal shift that clears the failed pivot.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// The device path declined: no CUDA runtime, a failed allocation, or a
    /// workload below the dispatch policy.
    Unavailable,
    /// Row block `row` was not positive definite at the requested ridge.
    /// The caller may retry at `ridge_t + bump`.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// The shared Schur factor failed: the bordered system is rank-deficient
    /// at the requested ridges, and escalation belongs to the CPU path.
    SchurFactorFailed { reason: String },
    /// The system carries a matrix-free `H_ββ` and/or per-row `H_tβ` operator,
    /// which the dense GPU Schur path cannot consume. This is a routing signal,
    /// not a numerical failure: send the solve to CPU `InexactPCG` (or supply
    /// dense buffers). See `gpu/arrow_schur.rs` Part B for the planned GPU
    /// PCG path that will lift this restriction at K ≥ 5000.
    GpuRequiresDenseSystem {
        had_hbb_matvec: bool,
        had_htbeta_matvec: bool,
    },
}
56
/// Headroom factor applied to the `scale · √ε` rounding margin that the ridge
/// bump helpers (e.g. [`ridge_bump_to_make_pd`]) add on top of the shift
/// needed to clear the eigenvalue deficit.
///
/// Shifting by exactly `-λ_min` makes a block PD only in exact arithmetic. A
/// retry at precisely that magnitude tends to be rejected again by the next
/// POTRF, because forming `D + ridge·I` and re-factoring introduces
/// `O(√ε)`-scale rounding of its own. The factor is exactly 2¹⁰ (ten extra
/// bits of headroom). It is large enough to clear the pivot on the first
/// retry, yet small enough not to materially perturb the curvature the
/// Newton step sees. Every per-row, batched and fused producer shares it, so
/// they all suggest consistent bumps.
const RIDGE_BUMP_EPS_MARGIN: f64 = (1_u32 << 10) as f64;
68
69/// Diagonal ridge bump that is GUARANTEED to make `H_tt + (ridge_t + bump)·I`
70/// positive definite for a *symmetric* per-row block, sized from the block's
71/// own entries rather than from the factorization's pivot index.
72///
73/// # Why the old `scale · |pivot| · √ε · 1024` estimate is wrong
74///
75/// The batched/fused device paths derive the suggested bump from the
76/// factorization "pivot" — but cuSOLVER's `potrf` (and the NVRTC kernel's
77/// status code) report the failing pivot as a **1-based row index**, NOT the
78/// magnitude of the negative pivot. A block that is indefinite by `O(1)`
79/// (e.g. `H_tt = -I`, whose smallest eigenvalue is `-1`) then yields the same
80/// `bump ≈ √ε · 1024 ≈ 1.5e-5` as a block that is indefinite by `O(√ε)`. The
81/// outer LM escalation, which retries at `ridge_t + bump` and grows
82/// geometrically with a bounded step count, can never lift a strongly
83/// indefinite block out of the negative regime, so the solve fails to recover
84/// even though the block is trivially regularizable. (Surfaced by the V100
85/// `ridge_bump_required_on_non_pd_row_recovers_after_bump` validation test.)
86///
87/// # The bound
88///
89/// By the Gershgorin circle theorem every eigenvalue `λ` of the symmetric
90/// matrix `A = H_tt` satisfies, for some row `i`,
91///   `λ ≥ A[i,i] − Σ_{j≠i} |A[i,j]|`,
92/// so `λ_min(A) ≥ min_i ( A[i,i] − Σ_{j≠i} |A[i,j]| ) =: g` (the most negative
93/// Gershgorin left edge). Adding `t·I` shifts every eigenvalue up by `t`, so
94/// `A + t·I` is PD as soon as `t > -g`. We are already sitting at `ridge_t`, so
95/// the ADDITIONAL bump needed is `-(g + ridge_t)` when that is positive. We add
96/// a relative safety margin (`√ε · scale · 1024`, the same headroom the legacy
97/// estimate used) so the re-factored, rounding-perturbed block clears the pivot
98/// on the first retry, and a `max(1)`-scaled floor so a marginally-indefinite
99/// block still gets a strictly positive, non-vanishing bump.
100///
101/// The returned value is the bump to ADD to the current `ridge_t`. It is always
102/// strictly positive (the caller only constructs `RidgeBumpRequired` on an
103/// actual non-PD failure, but the bound is defensive regardless).
104#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106    let d = htt.nrows();
107    // Diagonal magnitude scale (also the legacy `scale`), and the most-negative
108    // Gershgorin left edge `g = min_i (A_ii − Σ_{j≠i} |A_ij|)`.
109    let mut scale = 1.0_f64;
110    let mut min_gershgorin_edge = f64::INFINITY;
111    for i in 0..d {
112        let diag = htt[[i, i]];
113        scale = scale.max(diag.abs());
114        let mut off_sum = 0.0_f64;
115        for j in 0..d {
116            if j != i {
117                off_sum += htt[[i, j]].abs();
118            }
119        }
120        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121    }
122    if !min_gershgorin_edge.is_finite() {
123        // d == 0 (no rows) or non-finite entries: fall back to the scale-only
124        // floor so the caller still gets a strictly positive bump.
125        return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126    }
127    // Additional shift needed so `λ_min(A) + ridge_t + bump > 0`, i.e.
128    // `bump > -(min_gershgorin_edge + ridge_t)`.
129    let deficit = -(min_gershgorin_edge + ridge_t);
130    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131    // Lift past the deficit (when positive) plus a rounding margin; never below
132    // the scale-relative floor so a marginal block still moves.
133    deficit.max(0.0) + margin
134}
135
136/// [`ridge_bump_to_make_pd`] for a `d × d` symmetric block stored column-major
137/// in a flat slice with the current ridge ALREADY baked into the diagonal
138/// (the device packers emit `D = H_tt + ridge_t·I` this way). Because the shift
139/// is already present, the Gershgorin bound is taken at `ridge_t = 0` and the
140/// returned value is still the ADDITIONAL bump to add on top of the current
141/// ridge. Returns the scale-only floor when `block` is mis-sized.
142// The column-major bump is a helper for the device tile packers, which only
143// exist on the linux CUDA path (`mod cuda`, `#[cfg(target_os = "linux")]`). It
144// therefore lives where it is used: gate it to linux. Its parity unit test is
145// gated to linux to match (a `test` token in the cfg would trip the build.rs
146// `#[cfg(test)]`-on-a-src-item ban; including `test` to dodge the non-linux
147// dead_code lint is exactly the escape hatch that ban forbids).
148#[cfg(target_os = "linux")]
149#[must_use]
150fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
151    if d == 0 || block.len() < d * d {
152        return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
153    }
154    // Column-major: element (row r, col c) at block[c*d + r]. The matrix is
155    // symmetric, so reading by column gives the same Gershgorin edges as by row.
156    let mut scale = 1.0_f64;
157    let mut min_gershgorin_edge = f64::INFINITY;
158    for i in 0..d {
159        let diag = block[i * d + i];
160        scale = scale.max(diag.abs());
161        let mut off_sum = 0.0_f64;
162        for j in 0..d {
163            if j != i {
164                off_sum += block[j * d + i].abs();
165            }
166        }
167        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
168    }
169    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
170    if !min_gershgorin_edge.is_finite() {
171        return margin;
172    }
173    (-min_gershgorin_edge).max(0.0) + margin
174}
175
176/// Entry point: attempt the fully device-resident Arrow-Schur Newton solve.
177/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` to indicate "device path
178/// declined, fall back to CPU" — never panics.
179pub fn solve_arrow_newton_step(
180    sys: &ArrowSchurSystem,
181    ridge_t: f64,
182    ridge_beta: f64,
183) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
184    let n = sys.rows.len();
185    let d = sys.d;
186    let k = sys.k;
187
188    // Detect matrix-free operators before any dim() checks so callers get a
189    // clear, actionable error instead of a generic SchurFactorFailed. The GPU
190    // dense-Schur path requires materialised H_ββ and per-row H_tβ slabs;
191    // CPU InexactPCG is the correct fallback when either operator is abstract.
192    let had_hbb_matvec = sys.hbb_matvec.is_some();
193    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
194    if had_hbb_matvec || had_htbeta_matvec {
195        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
196            had_hbb_matvec,
197            had_htbeta_matvec,
198        });
199    }
200
201    if sys.hbb.dim() != (k, k) {
202        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
203            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
204        });
205    }
206    if n == 0 || d == 0 {
207        return Err(ArrowSchurGpuFailure::Unavailable);
208    }
209    if sys
210        .rows
211        .iter()
212        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
213    {
214        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
215            reason: "row block dimension mismatch".to_string(),
216        });
217    }
218
219    #[cfg(not(target_os = "linux"))]
220    {
221        if ridge_t.is_nan() || ridge_beta.is_nan() {
222            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
223                reason: "ridge is NaN".to_string(),
224            });
225        }
226        Err(ArrowSchurGpuFailure::Unavailable)
227    }
228
229    #[cfg(target_os = "linux")]
230    {
231        // Multi-GPU: the arrow-Schur solve is row-block separable in its forward
232        // (per-row factor / whiten / partial-Schur) and backward (per-row
233        // back-sub) phases — only the small shared K×K reduce+factor+δβ is
234        // central. When more than one device is usable, split the WHOLE solve at
235        // row-block granularity across all GPUs. The POTRF stays fused with its
236        // dependent TRSM+GEMM on each tile's own stream, so no on-stream solve is
237        // orphaned. On `Unavailable` (one device, shape below policy, transient)
238        // fall through to the single-device fused / Layer-A paths below.
239        if gam_gpu::device_runtime::GpuRuntime::global()
240            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
241            .unwrap_or(0)
242            > 1
243        {
244            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
245                Ok(sol) => return Ok(sol),
246                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
247                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
248                }
249                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
250                    return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
251                }
252                // Unavailable / GpuRequiresDenseSystem: fall through to the
253                // single-device paths (already shape-validated above).
254                Err(_) => {}
255            }
256        }
257        // Layer D admission: when the system shape passes the
258        // (Σ p³ ≥ 1e5 OR R ≥ 16) heuristic and `p ≤ MAX_FUSED_P`, the fused
259        // NVRTC kernel replaces the cuSOLVER/cuBLAS Layer A+B+C path with a
260        // single per-row block. Layer C↔D parity (math block 3 §16 test 6)
261        // requires both paths to agree to 1e-10 on identical inputs.
262        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
263            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
264                Ok(sol) => return Ok(sol),
265                // RidgeBumpRequired must surface to the outer escalation loop —
266                // the fused path's pivot diagnostic is identical in semantics
267                // to the cuSOLVER batched POTRF info code.
268                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
269                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
270                }
271                // Any other failure (Unavailable, SchurFactorFailed) falls
272                // through to the unfused path so a flaky NVRTC compile or
273                // shared-mem allocation does not abort the outer Newton step.
274                Err(_) => {}
275            }
276        }
277        cuda::solve(sys, ridge_t, ridge_beta)
278    }
279}
280
281/// Build the stacked column-major D buffer (n local d×d blocks), the stacked
282/// stacked B buffer (n local d×k blocks), and the stacked g buffer
283/// (n local d-vectors) consumed by the device pipeline. Each block is laid
284/// out column-major so a single allocation + `cuMemcpyHtoD` reaches the
285/// device without per-row dispatch overhead.
286#[cfg(target_os = "linux")]
287fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
288    let n = sys.rows.len();
289    let d = sys.d;
290    let k = sys.k;
291    let mut d_buf = Vec::with_capacity(n * d * d);
292    let mut b_buf = Vec::with_capacity(n * d * k);
293    let mut g_buf = Vec::with_capacity(n * d);
294    for row in &sys.rows {
295        pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
296    }
297    (d_buf, b_buf, g_buf)
298}
299
300#[cfg(target_os = "linux")]
301#[inline]
302fn pack_block(
303    row: &crate::arrow_schur::ArrowRowBlock,
304    ridge_t: f64,
305    d: usize,
306    k: usize,
307    d_buf: &mut Vec<f64>,
308    b_buf: &mut Vec<f64>,
309    g_buf: &mut Vec<f64>,
310) {
311    for col in 0..d {
312        for r in 0..d {
313            let mut value = row.htt[[r, col]];
314            if r == col {
315                value += ridge_t;
316            }
317            d_buf.push(value);
318        }
319    }
320    for col in 0..k {
321        for r in 0..d {
322            b_buf.push(row.htbeta[[r, col]]);
323        }
324    }
325    for r in 0..d {
326        g_buf.push(row.gt[r]);
327    }
328}
329
330/// Test-only entry that forces the Layer D + E fused NVRTC path regardless
331/// of the admission heuristic. Used by the V100 Layer C↔D parity test to
332/// drive the fused kernel at small shapes the heuristic would otherwise
333/// route through the cuSOLVER/cuBLAS Layer A+B+C path.
334#[doc(hidden)]
335#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] // `sys` is consumed only by the linux branch
336pub fn solve_arrow_newton_step_fused_force(
337    sys: &ArrowSchurSystem,
338    ridge_t: f64,
339    ridge_beta: f64,
340) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
341    if ridge_t.is_nan() || ridge_beta.is_nan() {
342        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
343            reason: "ridge is NaN".to_string(),
344        });
345    }
346    #[cfg(not(target_os = "linux"))]
347    {
348        // No NVRTC toolchain off linux: the fused path is unconditionally
349        // unavailable. `sys` is consumed only by the linux branch below; the
350        // fn-level cfg_attr allows it to read as unused here without a banned
351        // `let _` binding or a no-op `drop` of the reference.
352        Err(ArrowSchurGpuFailure::Unavailable)
353    }
354    #[cfg(target_os = "linux")]
355    {
356        if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
357            .is_none()
358        {
359            return Err(ArrowSchurGpuFailure::Unavailable);
360        }
361        cuda::solve_fused(sys, ridge_t, ridge_beta)
362    }
363}
364
365/// #1017 Phase 3: a device-resident Arrow-Schur frame whose constant Hessian
366/// blocks (`D = H_tt`, `B = H_tβ`, border `H_ββ`) and their factors stay on the
367/// device across the inner Newton loop. Construct once per frozen gate/basis
368/// frame, then call [`ResidentArrowFrameHandle::solve_gradient`] once per
369/// iterate with the fresh residual gradient — only the `O(n·d + p)` gradient
370/// crosses to the device and only `δ` crosses back, in contrast to
371/// [`solve_arrow_newton_step`] which re-uploads and re-factors the full system
372/// every call. On a non-CUDA host construction returns
373/// `ArrowSchurGpuFailure::Unavailable`.
374pub struct ResidentArrowFrameHandle {
375    #[cfg(target_os = "linux")]
376    inner: cuda::ResidentArrowFrame,
377    #[cfg(not(target_os = "linux"))]
378    _never: std::convert::Infallible,
379}
380
381impl ResidentArrowFrameHandle {
382    /// Upload the constant Hessian blocks and perform the one-time factor work.
383    pub fn new(
384        sys: &ArrowSchurSystem,
385        ridge_t: f64,
386        ridge_beta: f64,
387    ) -> Result<Self, ArrowSchurGpuFailure> {
388        // The dense device path requires materialised blocks, same admission as
389        // `solve_arrow_newton_step`.
390        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
391            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
392                had_hbb_matvec: sys.hbb_matvec.is_some(),
393                had_htbeta_matvec: sys.htbeta_matvec.is_some(),
394            });
395        }
396        #[cfg(not(target_os = "linux"))]
397        {
398            if ridge_t.is_nan() || ridge_beta.is_nan() {
399                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
400                    reason: "ridge is NaN".to_string(),
401                });
402            }
403            Err(ArrowSchurGpuFailure::Unavailable)
404        }
405        #[cfg(target_os = "linux")]
406        {
407            Ok(Self {
408                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
409            })
410        }
411    }
412
413    /// Solve `H δ = −gradient` for a fresh gradient reusing the resident factors.
414    pub fn solve_gradient(
415        &self,
416        g_t: &[f64],
417        g_beta: &[f64],
418    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
419        #[cfg(not(target_os = "linux"))]
420        {
421            if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
422                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
423                    reason: "non-finite gradient entry".to_string(),
424                });
425            }
426            Err(ArrowSchurGpuFailure::Unavailable)
427        }
428        #[cfg(target_os = "linux")]
429        {
430            self.inner.solve_gradient(g_t, g_beta)
431        }
432    }
433
434    /// `log|H|` for the frame (constant; depends only on the factored Hessian).
435    #[must_use]
436    pub fn log_det_hessian(&self) -> f64 {
437        #[cfg(not(target_os = "linux"))]
438        {
439            // SAFETY: off-CUDA, `ResidentArrowFrameHandle::new` always returns
440            // `Err(Unavailable)`, so no handle of this type is ever constructed and
441            // this method is statically unreachable on non-Linux targets. A NaN
442            // sentinel would silently corrupt any consumer of the log-determinant,
443            // so fail loudly on the impossible path instead.
444            panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
445        }
446        #[cfg(target_os = "linux")]
447        {
448            self.inner.log_det_hessian()
449        }
450    }
451}
452
453/// Build a GPU-backed Schur matvec closure for CPU-driven PCG at K ≥ 5000.
454///
455/// Runs the fused NVRTC forward kernel once on the dense per-row `H_tβ` slabs
456/// to compute `Y_i = L_i^{-1} H_tβ^(i)` for all rows, persists the `Y_i`
457/// factors in a host-side buffer, and returns an `Arc<dyn Fn(...)>` closure
458/// that computes the full Schur matvec
459///
460/// ```text
461/// S·x = (H_ββ + ridge_beta·I)·x  −  Σ_i Y_i^T (Y_i·x)
462/// ```
463///
464/// each time it is called. At K ≥ 5000 the `Σ_i Y_i^T (Y_i·x)` term
465/// dominates over the host↔device transfer of the K-vector `x`, so the GPU
466/// path is a clear win even with per-iteration transfer.
467///
468/// `H_ββ·x` is evaluated on the CPU using `sys.hbb_matvec` when present (the
469/// matrix-free hook for SAE-manifold scale callers) or the dense `sys.hbb`
470/// block otherwise. The `Y_i` term uses cuBLAS batched GEMV device-side; only
471/// `x` (K doubles) and `out` (K doubles) cross the host↔device boundary per
472/// PCG iteration.
473///
474/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` if CUDA is unavailable or
475/// the system shape is outside the fused kernel's admission range (e.g.
476/// `d > MAX_FUSED_P = 32` or no CUDA context). Callers should fall back to CPU
477/// `InexactPCG` on `Unavailable`.
478///
479/// Returns `Err(ArrowSchurGpuFailure::RidgeBumpRequired)` if a per-row Cholesky
480/// factor failed at the requested `ridge_t`; the outer LM escalation should
481/// bump `ridge_t` and retry.
482///
483/// # Composition with the matrix-free SAE Kronecker operator
484///
485/// When `sys.htbeta_matvec` is set (matrix-free `H_tβ` Kronecker operator),
486/// the dense `H_tβ` slabs are absent — the dense forward kernel above cannot
487/// run, and at `K = 100K` the dense `Y_i = L_i^{-1} H_tβ^(i)` (`d × K` per row)
488/// could not be materialised anyway. Instead, `build_row_procedural_matvec`
489/// returns a row-procedural Schur matvec: per row it gathers
490/// `v_i = H_tβ^(i)·x` through the forward operator (sparse `O(m_i · p)`),
491/// solves `(H_tt^(i) + ρ_t·I)^{-1} v_i` through a pre-computed per-row Cholesky
492/// factor, and scatters `H_βt^(i)·w_i` through the sparse transpose operator
493/// (`O(m_i · p)`, replacing the old `O(K)` column-probe). This is the
494/// row-procedural `a_ik · Φ_k[i,m]` Kronecker apply over the active atoms only.
495pub fn gpu_schur_matvec_backend(
496    sys: &ArrowSchurSystem,
497    ridge_t: f64,
498    ridge_beta: f64,
499) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
500    // Matrix-free H_tβ operator present: drive the row-procedural sparse
501    // Kronecker apply (active atoms only) instead of the dense forward kernel.
502    if sys.htbeta_matvec.is_some() {
503        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
504    }
505
506    #[cfg(not(target_os = "linux"))]
507    {
508        // No CUDA runtime on non-Linux. NaN ridges are validated to ensure the
509        // same contract as the Linux path.
510        if ridge_t.is_nan() || ridge_beta.is_nan() {
511            return Err(ArrowSchurGpuFailure::Unavailable);
512        }
513        Err(ArrowSchurGpuFailure::Unavailable)
514    }
515
516    #[cfg(target_os = "linux")]
517    {
518        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
519    }
520}
521
522/// Build a row-procedural reduced-Schur matvec for matrix-free SAE Kronecker
523/// systems, eliminating the per-row latent block via cached per-row Cholesky
524/// factors and applying the cross-block through the sparse forward/transpose
525/// Kronecker operators (active atoms only).
526///
527/// The returned closure evaluates
528/// `S·x = (H_ββ + ρ_β·I)·x − Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x`,
529/// the same reduced Schur complement the dense path forms, but never
530/// materialises the `d × K` cross-block `H_tβ^(i)`: the forward operator
531/// (`out = H_tβ^(i)·x`) and transpose operator (`out += H_βt^(i)·v`) are the
532/// sparse Kronecker gather/scatter from `SaeKroneckerRows`. The per-row factor
533/// of `H_tt^(i) + ρ_t·I` is computed once when the closure is built and reused
534/// across every CG iteration.
535///
536/// Returns `RidgeBumpRequired` if a per-row block is not positive definite at
537/// the requested `ridge_t`; the outer LM escalation bumps `ridge_t` and retries.
538fn build_row_procedural_matvec(
539    sys: &ArrowSchurSystem,
540    ridge_t: f64,
541    ridge_beta: f64,
542) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
543    use std::sync::Arc;
544    let n = sys.rows.len();
545    let k = sys.k;
546    let forward = sys
547        .htbeta_matvec
548        .clone()
549        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
550    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
551        // A forward operator without its sparse adjoint cannot be applied
552        // row-procedurally; this is a wiring error, surfaced as a Schur failure
553        // so the caller routes to the dense CPU path rather than misreporting a
554        // numerical bump.
555        ArrowSchurGpuFailure::SchurFactorFailed {
556            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
557                     forward operator installed without its sparse adjoint"
558                .to_string(),
559        }
560    })?;
561
562    // Pre-factor each per-row block H_tt^(i) + ρ_t·I = L_i L_iᵀ on the host.
563    // The blocks are tiny (d_i ≲ 32) and the dense cross-block slabs are
564    // absent, so there is no device forward-kernel work to amortise here; the
565    // GPU win is the reduced K-system solve in `solve_reduced_beta_pcg`.
566    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
567    for (i, row) in sys.rows.iter().enumerate() {
568        let di = row.htt.nrows();
569        if row.htt.ncols() != di || row.gt.len() != di {
570            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
571                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
572            });
573        }
574        let mut block = row.htt.clone();
575        for r in 0..di {
576            block[[r, r]] += ridge_t;
577        }
578        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
579            .ok_or_else(|| {
580                // Deficit-aware bump from the block's own entries (Gershgorin),
581                // so the outer LM escalation lifts a strongly-indefinite block
582                // out of the negative regime in one retry.
583                ArrowSchurGpuFailure::RidgeBumpRequired {
584                    row: i,
585                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
586                }
587            })?;
588        factors.push(factor);
589    }
590
591    // The SAE-manifold β-Hessian lives in the structured penalty operator
592    // (data-fit Gauss-Newton `G ⊗ I_p` + smoothness Kronecker blocks + any
593    // dense analytic-β residual), NOT in the dense `hbb` accumulator — for
594    // matrix-free systems `hbb` is zero/absent. Capture the effective penalty
595    // operator so `H_ββ·x` matches the CPU `schur_matvec` path exactly. The
596    // operator's `matvec` adds (`y += P x`), so seed `out` from the ridge term.
597    let penalty_op = sys.effective_penalty_op();
598    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
599
600    let closure: crate::arrow_schur::GpuSchurMatvec =
601        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
602            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
603            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
604
605            // (H_ββ + ρ_β·I)·x into out. Seed with the ridge term, then add the
606            // structured penalty-side product (penalty_op.matvec is additive).
607            {
608                let x_slice = x.as_slice().expect("x must be contiguous");
609                let out_slice = out.as_slice_mut().expect("out must be contiguous");
610                for a in 0..k {
611                    out_slice[a] = ridge_beta * x_slice[a];
612                }
613                penalty_op.matvec(x_slice, out_slice);
614            }
615
616            // out -= Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x.
617            //
618            // #1017: this row-procedural reduced-Schur term is the matrix-free
619            // SAE path's matvec hot loop (`build_row_procedural_matvec` is the
620            // host backend `gpu_schur_matvec_backend` returns when the dense
621            // `H_tβ` slabs are absent — the production Qwen shape). At
622            // (n≈2000 rows) it ran SERIALLY on one core and allocated a fresh
623            // length-`K` `neg` plus per-row `v_i`/`w_i` on EVERY CG iteration —
624            // tens of thousands of tiny heap allocations across a solve. Each
625            // row contributes an independent length-`K` scatter, so the sum is
626            // embarrassingly parallel; fan it across rayon over fixed row chunks
627            // and fold the per-chunk length-`K` partials in chunk order so the
628            // f64 reduction is deterministic (bit-identical run-to-run)
629            // regardless of thread scheduling — it agrees with the serial sum up
630            // to ULP-scale chunk reassociation (the #1017 verification gate).
631            // Because that reassociation is a real (if tiny) departure from
632            // serial, the criterion ranking across topology candidates is stable
633            // except for candidates separated by less than the reassociation
634            // margin, where the near-tie winner can flip — not an exact no-move
635            // guarantee (#1211). Stay
636            // sequential below
637            // `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and when already inside a
638            // rayon worker (the topology race fans candidates with
639            // `run_topology_race_parallel`) — the same nested-rayon guard the
640            // CPU `schur_matvec` uses. Buffers (`v_i`, `neg`) are reused across
641            // rows within a chunk, so the per-row allocation churn is gone.
642            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
643                && rayon::current_thread_index().is_none();
644            if parallel {
645                use rayon::prelude::*;
646                const CHUNK: usize = 64;
647                let partials: Vec<Array1<f64>> = (0..n)
648                    .into_par_iter()
649                    .chunks(CHUNK)
650                    .map(|idxs| {
651                        // One length-`K` scatter accumulator per chunk; the
652                        // per-row latent vector `v_i` (length `d_i ≲ 32`) is the
653                        // only per-row buffer, sized to the row's own `d_i`.
654                        let mut neg = Array1::<f64>::zeros(k);
655                        for i in idxs {
656                            let di = row_dims[i];
657                            // v_i = H_tβ^(i)·x (sparse Kronecker gather).
658                            let mut v_i = Array1::<f64>::zeros(di);
659                            forward(i, x.view(), &mut v_i);
660                            // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
661                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
662                            // neg += H_βt^(i)·w_i (sparse scatter).
663                            transpose(i, w_i.view(), &mut neg);
664                        }
665                        neg
666                    })
667                    .collect();
668                // #1017/#1175 floating-point parity contract: each chunk's row
669                // sum is formed locally, then chunk partials are folded
670                // left-to-right. That makes the parallel row-procedural Schur
671                // term deterministic for a fixed input and chunking, but it is
672                // not required to be bit-identical to the serial path because
673                // the additions are reassociated at chunk boundaries. CPU/GPU
674                // validation should therefore allow ULP-scale drift while
675                // expecting stable run-to-run results.
676                let mut neg = Array1::<f64>::zeros(k);
677                for part in &partials {
678                    for a in 0..k {
679                        neg[a] += part[a];
680                    }
681                }
682                for a in 0..k {
683                    out[a] -= neg[a];
684                }
685            } else {
686                // Serial path: reuse one `neg` and one `v_i` across rows.
687                let mut neg = Array1::<f64>::zeros(k);
688                for i in 0..n {
689                    let di = row_dims[i];
690                    // v_i = H_tβ^(i)·x (sparse Kronecker gather, length d_i).
691                    let mut v_i = Array1::<f64>::zeros(di);
692                    forward(i, x.view(), &mut v_i);
693                    // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
694                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
695                    // neg += H_βt^(i)·w_i (sparse scatter); subtract once at end.
696                    transpose(i, w_i.view(), &mut neg);
697                }
698                for a in 0..k {
699                    out[a] -= neg[a];
700                }
701            }
702        });
703
704    Ok(closure)
705}
706
707/// Solve the reduced shared β-system `S·δβ = r` fully on device with a
708/// Jacobi-preconditioned conjugate-gradient (Steihaug truncated-CG) loop.
709///
710/// `S` is the already-reduced symmetric positive-definite `K × K` Schur
711/// complement the streaming SAE joint fit accumulates across minibatches
712/// (`StreamingArrowSchur::take_accumulators` summed over chunks, with the
713/// global β ridge folded in). The per-row latent blocks have already been
714/// eliminated into `S` on the host streaming path; the device's job is the
715/// dense `K`-dimensional solve, which is the dominant cost at `K = 100K`.
716///
717/// The dense `S·p` matvec runs on device via cuBLAS `Dgemv`, and the PCG state
718/// vectors (`x`, `r`, `z`, `p`, `S·p`) remain device-resident for the solve.
719/// Jacobi preconditioning is an elementwise CUDA kernel; only convergence
720/// scalars (`pᵀSp`, `rᵀz`, `‖r‖`) cross the host boundary per iteration, plus the
721/// final solution vector.
722///
723/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` when CUDA is unavailable
724/// or the workload is below the dispatch policy; the caller then runs the CPU
725/// reduced-β solve. Returns `Err(ArrowSchurGpuFailure::SchurFactorFailed)`
726/// when `S` carries a non-positive Jacobi diagonal (caller escalates the
727/// proximal ridge).
728pub fn solve_reduced_beta_pcg(
729    s_acc: &Array2<f64>,
730    rhs_beta: &Array1<f64>,
731    max_iterations: usize,
732    relative_tolerance: f64,
733) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
734    solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
735        .map(|(x, _)| x)
736}
737
738#[doc(hidden)]
739pub fn solve_reduced_beta_pcg_with_diagnostics(
740    s_acc: &Array2<f64>,
741    rhs_beta: &Array1<f64>,
742    max_iterations: usize,
743    relative_tolerance: f64,
744) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
745    let k = rhs_beta.len();
746    if s_acc.dim() != (k, k) {
747        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
748            reason: format!(
749                "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
750                s_acc.dim()
751            ),
752        });
753    }
754    if k == 0 {
755        return Err(ArrowSchurGpuFailure::Unavailable);
756    }
757
758    #[cfg(not(target_os = "linux"))]
759    {
760        if relative_tolerance.is_nan() || max_iterations == 0 {
761            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
762                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
763            });
764        }
765        Err(ArrowSchurGpuFailure::Unavailable)
766    }
767
768    #[cfg(target_os = "linux")]
769    {
770        cuda::solve_reduced_beta_pcg_with_diagnostics(
771            s_acc,
772            rhs_beta,
773            max_iterations,
774            relative_tolerance,
775        )
776    }
777}
778
779pub fn solve_sae_matrix_free_pcg(
780    sys: &ArrowSchurSystem,
781    data: &DeviceSaePcgData,
782    ridge_t: f64,
783    ridge_beta: f64,
784    rhs_beta: &Array1<f64>,
785    max_iterations: usize,
786    relative_tolerance: f64,
787) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
788    if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
789        return Err(ArrowSchurGpuFailure::Unavailable);
790    }
791    #[cfg(not(target_os = "linux"))]
792    {
793        if ridge_t.is_nan()
794            || ridge_beta.is_nan()
795            || relative_tolerance.is_nan()
796            || max_iterations == 0
797        {
798            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
799                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
800            });
801        }
802        Err(ArrowSchurGpuFailure::Unavailable)
803    }
804    #[cfg(target_os = "linux")]
805    {
806        // #1017/#1026 dispatch GUARD: framed data (frame metadata present) carries
807        // a factored β border `G ⊗ W_{ij}` data Hessian and dense per-row cross
808        // blocks the legacy `⊗ I_p` kernel CANNOT represent — feeding it framed
809        // data would silently return a WRONG Newton step (it returns Ok with no
810        // fallback). Route framed systems to the dedicated framed kernel and
811        // legacy full-`B` systems to the legacy kernel; the two never cross.
812        if data.frame.is_some() {
813            cuda::solve_sae_matrix_free_pcg_framed(
814                sys,
815                data,
816                ridge_t,
817                ridge_beta,
818                rhs_beta,
819                max_iterations,
820                relative_tolerance,
821            )
822        } else {
823            cuda::solve_sae_matrix_free_pcg(
824                sys,
825                data,
826                ridge_t,
827                ridge_beta,
828                rhs_beta,
829                max_iterations,
830                relative_tolerance,
831            )
832        }
833    }
834}
835
836/// #1551 kernel-isolating parity probe: run the framed reduced-Schur matvec
837/// `out = S·x` exactly once on the device and return it (no PCG, no offload-floor
838/// gate). The test suite diffs this element-wise against the CPU oracle
839/// [`sae_framed_schur_matvec_cpu`] to prove the GPU kernel computes the SAME
840/// operator — a check that is independent of solver conditioning (unlike a
841/// solved-`δβ` comparison, which can diverge purely because dense Cholesky and
842/// iterative PCG resolve an ill-conditioned `S` to different accuracies). On a
843/// non-CUDA host this returns `Unavailable` so the caller skips cleanly.
844#[doc(hidden)]
845pub fn framed_schur_matvec_once_on_device(
846    sys: &ArrowSchurSystem,
847    data: &DeviceSaePcgData,
848    ridge_t: f64,
849    ridge_beta: f64,
850    x: &Array1<f64>,
851) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
852    if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
853        return Err(ArrowSchurGpuFailure::Unavailable);
854    }
855    if data.frame.is_none() {
856        return Err(ArrowSchurGpuFailure::Unavailable);
857    }
858    #[cfg(not(target_os = "linux"))]
859    {
860        // CUDA is linux-only; the framed device solve is unavailable here. Read
861        // the ridge params (a real, cheap use) so the values the shared signature
862        // carries are not flagged unused on this target — without an `#[allow]`
863        // or `let _` dodge, both of which the build.rs scanner bans.
864        if ridge_t.is_finite() && ridge_beta.is_finite() {
865            return Err(ArrowSchurGpuFailure::Unavailable);
866        }
867        Err(ArrowSchurGpuFailure::Unavailable)
868    }
869    #[cfg(target_os = "linux")]
870    {
871        cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
872    }
873}
874
875/// Reference dense back-end used by tests and as the fallback when the
876/// GPU declines. Kept here (not in `arrow_schur_gpu.rs`) so the validation
877/// suite has one canonical baseline.
878#[doc(hidden)]
879pub fn solve_arrow_newton_step_dense_reference(
880    sys: &ArrowSchurSystem,
881    ridge_t: f64,
882    ridge_beta: f64,
883) -> Result<ArrowSchurGpuSolution, String> {
884    let n = sys.rows.len();
885    let d = sys.d;
886    let k = sys.k;
887    let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
888    let mut h = Array2::<f64>::zeros((total, total));
889    let mut rhs = Array1::<f64>::zeros(total);
890    for (i, row) in sys.rows.iter().enumerate() {
891        let base = i * d;
892        for c in 0..d {
893            for r in 0..d {
894                h[[base + r, base + c]] = row.htt[[r, c]];
895            }
896            h[[base + c, base + c]] += ridge_t;
897        }
898        for c in 0..k {
899            for r in 0..d {
900                let value = row.htbeta[[r, c]];
901                h[[base + r, n * d + c]] = value;
902                h[[n * d + c, base + r]] = value;
903            }
904        }
905        for r in 0..d {
906            rhs[base + r] = -row.gt[r];
907        }
908    }
909    for c in 0..k {
910        for r in 0..k {
911            h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
912        }
913        h[[n * d + c, n * d + c]] += ridge_beta;
914        rhs[n * d + c] = -sys.gb[c];
915    }
916    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
917        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
918    let mut log_det = 0.0_f64;
919    for i in 0..total {
920        log_det += factor[[i, i]].ln();
921    }
922    log_det *= 2.0;
923    let solved = cholesky_solve_vector(factor.view(), rhs.view());
924    let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
925    let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
926    Ok(ArrowSchurGpuSolution {
927        delta_t,
928        delta_beta,
929        log_det_hessian: log_det,
930    })
931}
932
/// Frames-engaged reduced-Schur penalty-side matvec `out = (P_ββ + ρ_β I)·x`,
/// computed purely from the factored device data (issue #1017/#1026).
///
/// This is the CPU bit-parity ORACLE for the GPU `arrow_sae_*` penalty kernels
/// on the frames path. The penalty has two parts:
/// * smooth `λ S_k ⊗ I_{r_k}`: each `smooth_blocks[i]` at its `global_offset`,
///   with right-width `frame.smooth_ranks[i]`;
/// * data-fit `G_{ij} ⊗ W_{ij}`: each `frame.frame_blocks` entry, using the
///   `μ`-major / frame-minor index
///   `border_offsets[atom] + basis·r + frame_coord`.
///
/// The accumulation order is intended to match the device kernels exactly, so
/// the loop order and the zero-coefficient skips below are part of the
/// contract. Do not reorder them.
///
/// `out` is OVERWRITTEN: first set to `ρ_β·x`, then the penalty blocks add in.
///
/// NOTE(review): the smooth term reads `blk.factor_a` directly with no λ
/// multiply here. Presumably λ is already folded into `factor_a` upstream;
/// confirm.
///
/// # Panics
///
/// Panics when:
/// * `data.frame` is `None` (the `expect` below);
/// * `x`/`out` are shorter than `data.beta_dim`;
/// * any block offset indexes past them.
#[doc(hidden)]
pub fn sae_framed_penalty_matvec_cpu(
    data: &DeviceSaePcgData,
    ridge_beta: f64,
    x: &[f64],
    out: &mut [f64],
) {
    let frame = data
        .frame
        .as_ref()
        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
    let k = data.beta_dim;
    // Ridge term first: this overwrites whatever `out` held on entry.
    for a in 0..k {
        out[a] = ridge_beta * x[a];
    }
    // Smooth penalty `λ S_k ⊗ I_{r_k}`: y[off + ia·r + ib] += Σ_ja S[ia,ja]·x[off + ja·r + ib].
    // `zip` stops at the shorter of `smooth_blocks` / `smooth_ranks`, so any
    // extra blocks without a rank are silently ignored.
    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
        let off = blk.global_offset;
        let m = blk.factor_a.nrows();
        for i_a in 0..m {
            for i_b in 0..r {
                let mut acc = 0.0_f64;
                for j_a in 0..m {
                    let s = blk.factor_a[[i_a, j_a]];
                    // Skipping exact zeros also means a NaN/inf in `x` paired
                    // with a zero coefficient does not propagate into `out`.
                    if s == 0.0 {
                        continue;
                    }
                    acc += s * x[off + j_a * r + i_b];
                }
                out[off + i_a * r + i_b] += acc;
            }
        }
    }
    // Data-fit penalty `G_{ij} ⊗ W_{ij}`.
    for blk in &frame.frame_blocks {
        let r_i = frame.ranks[blk.atom_i];
        let r_j = frame.ranks[blk.atom_j];
        let off_i = frame.border_offsets[blk.atom_i];
        let off_j = frame.border_offsets[blk.atom_j];
        let (m_i, m_j) = blk.g.dim();
        for li in 0..m_i {
            let yi_base = off_i + li * r_i;
            for lj in 0..m_j {
                let g = blk.g[[li, lj]];
                // Same zero-skip as the smooth term; whole `W·x` slice elided.
                if g == 0.0 {
                    continue;
                }
                let xj_base = off_j + lj * r_j;
                // out[yi_base + a] += g · Σ_b W[a,b]·x[xj_base + b]
                for a in 0..r_i {
                    let mut acc = 0.0_f64;
                    for b in 0..r_j {
                        acc += blk.w[[a, b]] * x[xj_base + b];
                    }
                    out[yi_base + a] += g * acc;
                }
            }
        }
    }
}
1002
1003/// Frames-engaged FULL reduced-Schur matvec `out = S·x` purely from the device
1004/// data, where `S = (P_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)`
1005/// (issue #1017/#1026). The penalty side is [`sae_framed_penalty_matvec_cpu`];
1006/// the per-row reduced term reads the dense `frame.row_htbeta[i]`
1007/// (`q_i × border_dim`, row-major), solves against the row's
1008/// `H_tt^(i)+ρ_t I` Cholesky factor, and scatters the transpose back. This is
1009/// the size-independent bit-parity oracle the device kernel mirrors; it is also
1010/// the matvec the GPU PCG iterates.
1011#[doc(hidden)]
1012pub fn sae_framed_schur_matvec_cpu(
1013    sys: &ArrowSchurSystem,
1014    data: &DeviceSaePcgData,
1015    ridge_t: f64,
1016    ridge_beta: f64,
1017    x: &[f64],
1018    out: &mut [f64],
1019) -> Result<(), String> {
1020    let frame = data
1021        .frame
1022        .as_ref()
1023        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1024    let k = data.beta_dim;
1025    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1026    if frame.row_htbeta.len() != sys.rows.len() {
1027        return Err(format!(
1028            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1029            frame.row_htbeta.len(),
1030            sys.rows.len()
1031        ));
1032    }
1033    for (i, row) in sys.rows.iter().enumerate() {
1034        let slab = &frame.row_htbeta[i];
1035        if slab.is_empty() {
1036            continue;
1037        }
1038        let qi = sys.row_dims[i];
1039        if qi == 0 || slab.len() != qi * k {
1040            continue;
1041        }
1042        // h = H_tβ^(i) · x  (length q_i).
1043        let mut h = vec![0.0_f64; qi];
1044        for c in 0..qi {
1045            let base = c * k;
1046            let mut acc = 0.0_f64;
1047            for a in 0..k {
1048                acc += slab[base + a] * x[a];
1049            }
1050            h[c] = acc;
1051        }
1052        // solve (H_tt^(i)+ρ_t I) s = h.
1053        let mut block = row.htt.clone();
1054        for d in 0..qi {
1055            block[[d, d]] += ridge_t;
1056        }
1057        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1058            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1059        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1060        // out -= H_βt^(i) · s = (H_tβ^(i))ᵀ · s.
1061        for c in 0..qi {
1062            let sc = s[c];
1063            if sc == 0.0 {
1064                continue;
1065            }
1066            let base = c * k;
1067            for a in 0..k {
1068                out[a] -= slab[base + a] * sc;
1069            }
1070        }
1071    }
1072    Ok(())
1073}
1074
1075#[cfg(target_os = "linux")]
1076mod cuda {
1077    use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1078    use gam_gpu::driver::to_i32;
1079    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1080    use crate::arrow_schur::{
1081        ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1082    };
1083    use cudarc::cublas::sys::{
1084        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1085    };
1086    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1087    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1088    use cudarc::driver::{
1089        CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1090        PushKernelArg,
1091    };
1092    use ndarray::Array1;
1093    use std::sync::{Arc, OnceLock};
1094
1095    /// Per-row work slot for the row-block-granular multi-GPU solve. Inputs are
1096    /// the packed single-row buffers (`d×d` D block + ρ_t ridge, `d×k` B block,
1097    /// `d` g vector); the forward pass fills the whitened factors `l/u/y` and the
1098    /// per-tile reduction lands in the tile's leading slot.
1099    struct RowSlot {
1100        // Inputs (packed once on the host, column-major).
1101        d_block: Vec<f64>, // d*d
1102        b_block: Vec<f64>, // d*k
1103        g_vec: Vec<f64>,   // d
1104        // Forward outputs, kept on the host for the back-sub pass.
1105        l_block: Vec<f64>, // d*d lower factor, column-major
1106        u_vec: Vec<f64>,   // d   (= L^{-1} g)
1107        y_block: Vec<f64>, // d*k (= L^{-1} B), column-major
1108        log_det_local: f64,
1109        // Set on a non-PD pivot so the orchestrator can raise RidgeBumpRequired
1110        // for the offending global row instead of silently falling back.
1111        bump: Option<f64>,
1112        // Tile-level reduction, written into the tile's first slot only.
1113        tile_partial_schur: Option<Vec<f64>>, // k*k col-major, = Σ Y_iᵀY_i
1114        tile_partial_rhs: Option<Vec<f64>>,   // k, = Σ Y_iᵀu_i
1115        // Back-sub output for this row.
1116        delta_t_block: Vec<f64>, // d
1117    }
1118
1119    /// Row-block-granular multi-GPU Arrow-Schur Newton solve.
1120    ///
1121    /// The solve is separable across row blocks in both phases:
1122    ///   * forward — each row's local Cholesky `L_i`, whitening
1123    ///     `u_i = L_i⁻¹g_i`, `Y_i = L_i⁻¹B_i`, and partial Schur
1124    ///     `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` are independent;
1125    ///   * backward — `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)` is independent.
1126    /// Only the small shared `K×K` reduce + factor + `δβ` solve is central.
1127    ///
1128    /// `gam_gpu::pool::scatter_batched` hands each device a contiguous row
1129    /// tile on its own bound context/stream; the per-tile forward keeps the
1130    /// POTRF fused with its dependent TRSM + Schur GEMM on that one stream, so no
1131    /// on-stream solve is orphaned. Tile partials and per-tile `log|L|` are
1132    /// reduced on the host (in tile/row order), `S_β` is factored on the primary
1133    /// device, and the back-sub is scattered back across the same tiles.
1134    ///
1135    /// Returns `Unavailable` (caller uses a single-device path) when the system
1136    /// carries matrix-free operators, the shared block is not dense `K×K`, the
1137    /// pool is single-device, or any tile's device work declines. A non-PD tip
1138    /// block surfaces as `RidgeBumpRequired` for the precise global row.
1139    pub(super) fn solve_multi_gpu(
1140        sys: &ArrowSchurSystem,
1141        ridge_t: f64,
1142        ridge_beta: f64,
1143    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1144        let n = sys.rows.len();
1145        let d = sys.d;
1146        let k = sys.k;
1147        if n == 0 || d == 0 || k == 0 {
1148            return Err(ArrowSchurGpuFailure::Unavailable);
1149        }
1150        // Dense shared block + materialised per-row slabs are required; the
1151        // public entry already rejected matrix-free operators, but re-check so
1152        // this routine is safe in isolation.
1153        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
1154            return Err(ArrowSchurGpuFailure::Unavailable);
1155        }
1156
1157        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
1158            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1159        if runtime.device_count() < 2 {
1160            return Err(ArrowSchurGpuFailure::Unavailable);
1161        }
1162
1163        // Pack one slot per row (column-major), folding ρ_t into each D block.
1164        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
1165        for row in &sys.rows {
1166            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
1167                return Err(ArrowSchurGpuFailure::Unavailable);
1168            }
1169            let mut d_block = Vec::with_capacity(d * d);
1170            let mut b_block = Vec::with_capacity(d * k);
1171            let mut g_vec = Vec::with_capacity(d);
1172            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
1173            slots.push(RowSlot {
1174                d_block,
1175                b_block,
1176                g_vec,
1177                l_block: Vec::new(),
1178                u_vec: Vec::new(),
1179                y_block: Vec::new(),
1180                log_det_local: 0.0,
1181                bump: None,
1182                tile_partial_schur: None,
1183                tile_partial_rhs: None,
1184                delta_t_block: vec![0.0; d],
1185            });
1186        }
1187
1188        // ---- Forward pass: per-device row tile, fused on its own stream ----
1189        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1190            forward_tile(ordinal, d, k, tile)
1191        });
1192        if forward_ok.is_none() {
1193            return Err(ArrowSchurGpuFailure::Unavailable);
1194        }
1195
1196        // Surface a non-PD tip block as a precise per-row ridge bump.
1197        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
1198        if let Some((row, bump)) = slots
1199            .iter()
1200            .enumerate()
1201            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
1202        {
1203            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
1204        }
1205
1206        // ---- Central: reduce tile partials → S_β, r_β; factor; solve δβ ----
1207        // Seed S_β with H_ββ + ρ_β I (column-major) and r_β with -g_β, then fold
1208        // in the per-tile partials in tile order so the reduction order tracks
1209        // the single-device accumulation (up to inter-tile reassociation).
1210        let mut schur_host = vec![0.0_f64; k * k];
1211        for col in 0..k {
1212            for row in 0..k {
1213                let mut v = sys.hbb[[row, col]];
1214                if row == col {
1215                    v += ridge_beta;
1216                }
1217                schur_host[col * k + row] = v;
1218            }
1219        }
1220        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1221        let mut log_det = 0.0_f64;
1222        for start in tile_starts(&row_base_of_tile) {
1223            let slot = &slots[start];
1224            let partial_schur = slot
1225                .tile_partial_schur
1226                .as_ref()
1227                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1228            let partial_rhs = slot
1229                .tile_partial_rhs
1230                .as_ref()
1231                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1232            // `accumulate_schur` writes `partial_schur = -Σ_tile Y_iᵀY_i` (GEMM
1233            // α=-1, β=1 into a zero seed) and `partial_rhs = +Σ_tile Y_iᵀu_i`.
1234            // The reduced Schur is `S = (H_ββ+ρI) − Σ_all Y_iᵀY_i`, so adding the
1235            // (already-negated) partials reproduces the single-device sign.
1236            for idx in 0..k * k {
1237                schur_host[idx] += partial_schur[idx];
1238            }
1239            for a in 0..k {
1240                rhs_host[a] += partial_rhs[a];
1241            }
1242        }
1243        for slot in &slots {
1244            log_det += slot.log_det_local;
1245        }
1246
1247        // Factor S_β and solve δβ on the primary device (small K×K leaf). The
1248        // stream carries the primary context (same pattern as `solve()`); no
1249        // thread bind is needed for the cuSOLVER/cuBLAS handles created from it.
1250        let primary = runtime.selected_device().ordinal;
1251        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
1252            .and_then(|ctx| ctx.new_stream().ok())
1253            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1254        let solver =
1255            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1256        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1257        let mut schur_dev = stream
1258            .clone_htod(&schur_host)
1259            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1260        let mut rhs_dev = stream
1261            .clone_htod(&rhs_host)
1262            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1263        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1264        if info != 0 {
1265            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1266                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
1267            });
1268        }
1269        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1270        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1271        let delta_beta_host = stream
1272            .clone_dtoh(&rhs_dev)
1273            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1274        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1275        let l_schur_host = stream
1276            .clone_dtoh(&schur_dev)
1277            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1278        for j in 0..k {
1279            log_det += l_schur_host[j * k + j].ln();
1280        }
1281        log_det *= 2.0;
1282
1283        // ---- Backward pass: δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ), per-device tile ----
1284        let delta_beta_ref = &delta_beta_host;
1285        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1286            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
1287        });
1288        if back_ok.is_none() {
1289            return Err(ArrowSchurGpuFailure::Unavailable);
1290        }
1291
1292        // Stitch per-row δt into the stacked (n*d) result.
1293        let mut delta_t = Array1::<f64>::zeros(n * d);
1294        for (i, slot) in slots.iter().enumerate() {
1295            let base = i * d;
1296            for r in 0..d {
1297                delta_t[base + r] = slot.delta_t_block[r];
1298            }
1299        }
1300
1301        Ok(ArrowSchurGpuSolution {
1302            delta_t,
1303            delta_beta,
1304            log_det_hessian: log_det,
1305        })
1306    }
1307
1308    /// Tile starts: the leading global row index of each device tile (where the
1309    /// tile-level partial reduction was written by the forward pass).
1310    fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1311        tiles.iter().map(|(_, range)| range.start)
1312    }
1313
    /// Forward pass for one device row tile, running on `ordinal`'s bound stream.
    ///
    /// For each row it factors the block and whitens `u`/`Y`. It accumulates
    /// the tile's partial Schur `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` into the tile's leading
    /// slot, keeps the per-row `L`/`u`/`Y` on the host for back-sub, and
    /// records the per-row `Σ_j log L_jj`.
    ///
    /// A non-PD pivot is recorded in `slot.bump`. The tile still returns
    /// `Some(())`, so the orchestrator raises a precise `RidgeBumpRequired`
    /// rather than collapsing the whole batch to CPU. Only the first failing
    /// row of the tile is recorded, and no partials are written in that case.
    ///
    /// Returns `None` when any stream/handle creation, transfer, or device
    /// routine reports an error (every `.ok()?` below). An empty tile returns
    /// `Some(())` immediately and writes nothing.
    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        // `scatter_batched` has already bound this ordinal's context on this
        // worker thread; the stream below targets that same device.
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        // Stack the tile's D, B, g into contiguous device buffers (same layout
        // the single-device path packs for `m` rows).
        let mut d_host = Vec::with_capacity(m * d * d);
        let mut b_host = Vec::with_capacity(m * d * k);
        let mut g_host = Vec::with_capacity(m * d);
        for slot in tile.iter() {
            d_host.extend_from_slice(&slot.d_block);
            b_host.extend_from_slice(&slot.b_block);
            g_host.extend_from_slice(&slot.g_vec);
        }
        let mut d_dev = stream.clone_htod(&d_host).ok()?;
        let mut b_dev = stream.clone_htod(&b_host).ok()?;
        let mut g_dev = stream.clone_htod(&g_host).ok()?;

        // Batched POTRF; a non-PD block records its bump and stops the tile.
        // The bump is deficit-aware (Gershgorin lower bound on λ_min of the
        // already-ridged `d_block`), NOT derived from the cuSOLVER `info` —
        // which is a 1-based pivot ROW INDEX, not a pivot magnitude — so a
        // strongly-indefinite block recovers in one outer-loop retry.
        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
        if let Some(local) = info_host.iter().position(|info| *info != 0) {
            tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
                &tile[local].d_block,
                d,
            ));
            return Some(());
        }

        // Whiten: u = L⁻¹ g, Y = L⁻¹ B.
        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;

        // Tile partial Schur: zero-seeded so the host adds the H_ββ seed once.
        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;

        // Download L, u, Y, and the tile partials.
        let l_host = stream.clone_dtoh(&d_dev).ok()?;
        let u_host = stream.clone_dtoh(&g_dev).ok()?;
        let y_host = stream.clone_dtoh(&b_dev).ok()?;
        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;

        // Split the stacked downloads back into per-row slots (row `local`
        // occupies a contiguous `d*d` / `d` / `d*k` stride in each buffer).
        for (local, slot) in tile.iter_mut().enumerate() {
            let l_base = local * d * d;
            let u_base = local * d;
            let y_base = local * d * k;
            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
            slot.u_vec = u_host[u_base..u_base + d].to_vec();
            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
            // Column-major diagonal entry (j, j) sits at `j*d + j`.
            let mut log_det_local = 0.0_f64;
            for j in 0..d {
                log_det_local += l_host[l_base + j * d + j].ln();
            }
            slot.log_det_local = log_det_local;
        }
        // Only the leading slot carries the tile reduction; the orchestrator
        // locates it via the tile's start row.
        tile[0].tile_partial_schur = Some(partial_schur);
        tile[0].tile_partial_rhs = Some(partial_rhs);
        Some(())
    }
1394
1395    /// Back-substitution for one device row tile: `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)`.
1396    /// Re-uploads the tile's kept `L`/`u`/`Y` to `ordinal`, applies the GEMV
1397    /// accumulate + transposed TRSM, and writes each row's `δt` into its slot.
1398    fn back_sub_tile(
1399        ordinal: usize,
1400        d: usize,
1401        k: usize,
1402        delta_beta: &[f64],
1403        tile: &mut [RowSlot],
1404    ) -> Option<()> {
1405        if tile.is_empty() {
1406            return Some(());
1407        }
1408        // `scatter_batched` has already bound this ordinal's context on this
1409        // worker thread; the stream below targets that same device.
1410        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1411            .and_then(|ctx| ctx.new_stream().ok())?;
1412        let blas = CudaBlas::new(stream.clone()).ok()?;
1413        let m = tile.len();
1414
1415        let mut l_host = Vec::with_capacity(m * d * d);
1416        let mut u_host = Vec::with_capacity(m * d);
1417        let mut y_host = Vec::with_capacity(m * d * k);
1418        for slot in tile.iter() {
1419            l_host.extend_from_slice(&slot.l_block);
1420            u_host.extend_from_slice(&slot.u_vec);
1421            y_host.extend_from_slice(&slot.y_block);
1422        }
1423        let d_dev = stream.clone_htod(&l_host).ok()?;
1424        let mut g_dev = stream.clone_htod(&u_host).ok()?;
1425        let b_dev = stream.clone_htod(&y_host).ok()?;
1426        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1427
1428        // g ← u + Y·δβ, then x = L⁻ᵀ g; δt = -x.
1429        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1430        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1431        let x_host = stream.clone_dtoh(&g_dev).ok()?;
1432        for (local, slot) in tile.iter_mut().enumerate() {
1433            let base = local * d;
1434            for r in 0..d {
1435                slot.delta_t_block[r] = -x_host[base + r];
1436            }
1437        }
1438        Some(())
1439    }
1440
1441    pub(super) fn solve(
1442        sys: &ArrowSchurSystem,
1443        ridge_t: f64,
1444        ridge_beta: f64,
1445    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1446        let n = sys.rows.len();
1447        let d = sys.d;
1448        let k = sys.k;
1449        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1450            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1451
1452        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1453            .and_then(|ctx| ctx.new_stream().ok())
1454            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1455        let solver =
1456            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1457        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1458
1459        // ----- Pack + upload D, B, g -----
1460        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1461        let mut d_dev = stream
1462            .clone_htod(&d_host)
1463            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1464        let mut b_dev = stream
1465            .clone_htod(&b_host)
1466            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1467        let mut g_dev = stream
1468            .clone_htod(&g_host)
1469            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1470
1471        // ----- Layer A: batched lower Cholesky of D in place -----
1472        // This POTRF is fused with the downstream TRSM + Schur GEMM + back-sub
1473        // on this one stream, so splitting only the POTRF across devices would
1474        // orphan the dependent on-stream solves. Multi-GPU here is the
1475        // whole-solve row-block split in `solve_arrow_newton_step` (see
1476        // `solve_multi_gpu`), not a per-layer split — this device-resident path
1477        // is the single-device leaf the split dispatches per tile.
1478        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1479        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1480            // `info` is cuSOLVER's 1-based pivot ROW INDEX, not a magnitude;
1481            // size the bump from the block's own entries (Gershgorin λ_min
1482            // bound) so a strongly-indefinite block recovers in one retry.
1483            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1484                row: idx,
1485                bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1486            });
1487        }
1488
1489        // ----- Layer B (1/2): in-place triangular solves -----
1490        // u_i = L_i^{-1} g_i, packed as a stacked (n*d) column-vector.
1491        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1492        // Y_i = L_i^{-1} B_i, in place over the (n*d) × k buffer (laid out as
1493        // n stacked column-major d×k tiles).
1494        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1495
1496        // ----- Layer B (2/2): Schur reduction via single big GEMM / GEMV -----
1497        // Y_all is (n*d) × k column-major: viewing all n stacked d×k tiles as
1498        // one big matrix is bit-exact because each tile is column-major with
1499        // leading dim d and the tiles are contiguous in memory, so the
1500        // combined leading dim is n*d only for the *outer* matrix view. To
1501        // make the single-GEMM equivalence hold we must treat the stacked
1502        // buffer as (n*d) × k column-major with leading dim = n*d, which
1503        // means columns of Y_all are interleaved by row across blocks.
1504        // That is NOT what we packed. So we use the cuBLAS stride pattern
1505        // instead: stride-by-block, transpose-A, and *accumulate* into one
1506        // S_β buffer via beta=1 across batches. Equivalent flop count, no
1507        // extra reduction kernel, and correct layout.
1508        //
1509        // Concretely: schur ← C + ρ_β I; rhs ← -g_β; then for each block
1510        //   schur -= Y_i^T Y_i      (k×k)
1511        //   rhs   += Y_i^T u_i      (k)
1512        // We launch this as `n` sequential GEMMs/GEMVs with beta=1 on the
1513        // accumulator. Layer D fuses these into one NVRTC launch.
1514        let schur_init: Vec<f64> = {
1515            let mut tmp = Vec::with_capacity(k * k);
1516            for col in 0..k {
1517                for row in 0..k {
1518                    let mut v = sys.hbb[[row, col]];
1519                    if row == col {
1520                        v += ridge_beta;
1521                    }
1522                    tmp.push(v);
1523                }
1524            }
1525            tmp
1526        };
1527        let mut schur_dev = stream
1528            .clone_htod(&schur_init)
1529            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1530        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1531        let mut rhs_dev = stream
1532            .clone_htod(&rhs_init)
1533            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1534
1535        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1536
1537        // ----- Layer C (1/2): factor S_β and solve for δβ -----
1538        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1539        if info != 0 {
1540            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1541                reason: format!("Schur Cholesky failed at pivot {info}"),
1542            });
1543        }
1544        // δβ ← L_S^{-T} L_S^{-1} rhs
1545        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1546        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1547        let delta_beta_host = stream
1548            .clone_dtoh(&rhs_dev)
1549            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1550        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1551
1552        // ----- Layer C (2/2): back-sub δt_i = -L_i^{-T} (u_i + Y_i δβ) -----
1553        // Already on device:
1554        //   g_dev holds u_i stacked (n*d).
1555        //   b_dev holds Y_i stacked column-major n×(d×k) tiles.
1556        // Compute g_dev ← g_dev + Y_block · δβ per block (cuBLAS gemv with beta=1),
1557        // then in-place trsm with L_i^T (CUBLAS_OP_T) to obtain x_i, and finally
1558        // δt_i = -x_i on host after download.
1559        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1560        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1561
1562        let x_host = stream
1563            .clone_dtoh(&g_dev)
1564            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1565        let mut delta_t = Array1::<f64>::zeros(n * d);
1566        for (i, v) in x_host.iter().enumerate() {
1567            delta_t[i] = -*v;
1568        }
1569
1570        // ----- log|H| = 2 Σ log L_{i,jj} + 2 Σ log R_{β,aa} -----
1571        let l_local_host = stream
1572            .clone_dtoh(&d_dev)
1573            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1574        let l_schur_host = stream
1575            .clone_dtoh(&schur_dev)
1576            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1577        let mut log_det = 0.0_f64;
1578        for i in 0..n {
1579            let base = i * d * d;
1580            for j in 0..d {
1581                log_det += l_local_host[base + j * d + j].ln();
1582            }
1583        }
1584        for j in 0..k {
1585            log_det += l_schur_host[j * k + j].ln();
1586        }
1587        log_det *= 2.0;
1588
1589        Ok(ArrowSchurGpuSolution {
1590            delta_t,
1591            delta_beta,
1592            log_det_hessian: log_det,
1593        })
1594    }
1595
    /// Factor `batch` contiguous column-major `p×p` matrices in place with
    /// `cusolverDnDpotrfBatched`, requesting the lower factor (`A_i = L_i L_iᵀ`).
    ///
    /// `matrices` must hold `batch` back-to-back `p×p` blocks with leading
    /// dim `p`. A device array of per-block pointers is built from its base
    /// address.
    ///
    /// Returns the per-block cuSOLVER `info` array copied to host; `0` marks a
    /// block that factored. Callers in this file treat any non-zero entry as
    /// "not positive definite", and read it as a 1-based pivot row index
    /// rather than a magnitude.
    ///
    /// # Errors
    /// `Unavailable` if `p` or `batch` does not fit in `i32`, a device
    /// allocation or copy fails, or cuSOLVER returns a non-success status.
    fn potrf_batched(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        batch: usize,
        matrices: &mut CudaSlice<f64>,
    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let matrix_len = p * p;
        // Byte stride between consecutive p×p blocks in `matrices`.
        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
        // `_record` is bound to a named variable (not `_`), so it stays alive
        // until the end of the function, i.e. past the cuSOLVER launch below.
        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
        // Host-side pointer table: one device address per block.
        let mut ptrs = Vec::with_capacity(batch);
        for idx in 0..batch {
            ptrs.push(base_ptr + (idx as u64) * bytes_per);
        }
        let mut ptrs_dev = stream
            .clone_htod(&ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut info_dev = stream
            .alloc_zeros::<i32>(batch)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let status = {
            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
            // SAFETY: the pointer array and the info buffer live on the device,
            // and `matrices` holds `batch` contiguous p×p column-major blocks
            // addressed by that pointer array.
            unsafe {
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            }
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        // Per-block factorization status; the caller decides how to react.
        stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
    }
1642
    /// Factor one column-major `p×p` device matrix in place with
    /// `cusolverDnDpotrf`, requesting the lower factor (`A = L Lᵀ`).
    ///
    /// This runs in two phases: a `cusolverDnDpotrf_bufferSize` workspace
    /// query, then the factorization into a freshly allocated workspace of at
    /// least one element.
    ///
    /// Returns cuSOLVER's scalar `info`; `0` means the factorization
    /// succeeded. Callers in this file treat non-zero as a failed Schur factor.
    ///
    /// # Errors
    /// `Unavailable` if `p` does not fit in `i32`, `lwork` is negative, an
    /// allocation or copy fails, or either cuSOLVER call returns a non-success
    /// status.
    fn potrf_single(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        matrix: &mut CudaSlice<f64>,
    ) -> Result<i32, ArrowSchurGpuFailure> {
        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
        let mut lwork = 0_i32;
        // Phase 1: workspace size query. The scope ends the pointer borrow
        // before the workspace allocation below.
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            // SAFETY: buffer query against a live p-by-p column-major device matrix.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf_bufferSize(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    &mut lwork,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
        }
        let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // `.max(1)` so a zero-size query still yields a valid allocation.
        let mut workspace = stream
            .alloc_zeros::<f64>(lwork_usize.max(1))
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut info_dev = stream
            .alloc_zeros::<i32>(1)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // Phase 2: the factorization itself.
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
            // SAFETY: all three pointers refer to live, correctly sized device buffers.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    work_ptr as *mut f64,
                    lwork,
                    info_ptr as *mut i32,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
        }
        let info_host = stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        Ok(info_host[0])
    }
1702
1703    /// In-place lower-triangular solves `X_i ← L_i^{-1} X_i` over the n stacked
1704    /// d×nrhs RHS tiles in `rhs`. Uses `cublasDtrsmBatched` so all n solves
1705    /// hit the device in one launch.
1706    fn trsm_batched_lower_inplace(
1707        blas: &CudaBlas,
1708        stream: &Arc<CudaStream>,
1709        d: usize,
1710        n: usize,
1711        nrhs: usize,
1712        l_stack: &CudaSlice<f64>,
1713        rhs_stack: &mut CudaSlice<f64>,
1714    ) -> Result<(), ArrowSchurGpuFailure> {
1715        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1716    }
1717
1718    /// As above but with `L_i^T` instead of `L_i`.
1719    fn trsm_batched_lower_inplace_transposed(
1720        blas: &CudaBlas,
1721        stream: &Arc<CudaStream>,
1722        d: usize,
1723        n: usize,
1724        nrhs: usize,
1725        l_stack: &CudaSlice<f64>,
1726        rhs_stack: &mut CudaSlice<f64>,
1727    ) -> Result<(), ArrowSchurGpuFailure> {
1728        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1729    }
1730
    /// Shared body of the batched triangular solves: `X_i ← op(L_i)⁻¹ X_i`
    /// for `i in 0..n`, where `op` is the identity or, when `transposed`,
    /// the transpose.
    ///
    /// Builds host-side tables of per-block device addresses: `d×d` strides
    /// into `l_stack`, and `d×nrhs` strides into `rhs_stack`. It uploads them
    /// and issues one `cublasDtrsmBatched` call with side = LEFT, fill =
    /// LOWER, non-unit diagonal, and leading dimension `d` for both operands.
    ///
    /// # Errors
    /// `Unavailable` if `d`, `nrhs` or `n` does not fit in `i32`, a pointer
    /// table upload fails, or cuBLAS returns a non-success status.
    fn trsm_batched_inplace_inner(
        blas: &CudaBlas,
        stream: &Arc<CudaStream>,
        d: usize,
        n: usize,
        nrhs: usize,
        l_stack: &CudaSlice<f64>,
        rhs_stack: &mut CudaSlice<f64>,
        transposed: bool,
    ) -> Result<(), ArrowSchurGpuFailure> {
        let alpha = 1.0_f64;
        let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        // Byte strides between consecutive factor / RHS tiles.
        let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
        let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
        // The record guards below are all bound to named variables, so they
        // live until the end of the function (past the cuBLAS call).
        let (l_base, _l_record) = l_stack.device_ptr(stream);
        let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
        let mut l_ptrs = Vec::with_capacity(n);
        let mut rhs_ptrs = Vec::with_capacity(n);
        for i in 0..n {
            l_ptrs.push(l_base + (i as u64) * l_bytes_per);
            rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
        }
        let mut l_ptrs_dev = stream
            .clone_htod(&l_ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_ptrs_dev = stream
            .clone_htod(&rhs_ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
        let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
        let op = if transposed {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };
        let handle = *blas.handle();
        // SAFETY: pointer arrays and base buffers were just constructed from
        // live device allocations covering the entire batch.
        let status = unsafe {
            cudarc::cublas::sys::cublasDtrsmBatched(
                handle,
                cublasSideMode_t::CUBLAS_SIDE_LEFT,
                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                op,
                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
                d_i,
                nrhs_i,
                &alpha,
                l_ptrs_ptr as *const *const f64,
                d_i,
                rhs_ptrs_ptr as *const *mut f64,
                d_i,
                batch_i,
            )
        };
        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        Ok(())
    }
1793
    /// Single-matrix triangular solve with one right-hand-side column:
    /// `rhs ← op(A)⁻¹ rhs`. Used for the Schur Cholesky back-substitution.
    ///
    /// `l` is an `n×n` column-major triangular factor with leading dim `n`.
    /// - `upper`: read the upper triangle of `l` (`CUBLAS_FILL_MODE_UPPER`);
    ///   otherwise read the lower triangle. The callers in this file pass
    ///   `false`.
    /// - `transposed`: `op(A) = Aᵀ` (`CUBLAS_OP_T`); otherwise `op(A) = A`.
    ///
    /// # Errors
    /// `Unavailable` if `n` does not fit in `i32` or cuBLAS returns a
    /// non-success status.
    fn trsm_single(
        blas: &CudaBlas,
        stream: &Arc<CudaStream>,
        n: usize,
        l: &CudaSlice<f64>,
        rhs: &mut CudaSlice<f64>,
        upper: bool,
        transposed: bool,
    ) -> Result<(), ArrowSchurGpuFailure> {
        let alpha = 1.0_f64;
        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let handle = *blas.handle();
        let (l_ptr, _l_rec) = l.device_ptr(stream);
        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
        // SAFETY: single n×n triangular factor and n-vector RHS on device.
        let status = unsafe {
            cudarc::cublas::sys::cublasDtrsm_v2(
                handle,
                cublasSideMode_t::CUBLAS_SIDE_LEFT,
                if upper {
                    cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
                } else {
                    cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
                },
                if transposed {
                    cublasOperation_t::CUBLAS_OP_T
                } else {
                    cublasOperation_t::CUBLAS_OP_N
                },
                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
                n_i,
                // Exactly one RHS column.
                1,
                &alpha,
                l_ptr as *const f64,
                n_i,
                rhs_ptr as *mut f64,
                n_i,
            )
        };
        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        Ok(())
    }
1840
1841    /// Accumulate `schur ← schur − Σ_i Y_i^T Y_i` and `rhs ← rhs + Σ_i Y_i^T u_i`
1842    /// using one GEMM and one GEMV per block. Each call uses beta=1 to chain
1843    /// the accumulation device-side.
1844    fn accumulate_schur(
1845        blas: &CudaBlas,
1846        d: usize,
1847        k: usize,
1848        n: usize,
1849        y_stack: &CudaSlice<f64>,
1850        u_stack: &CudaSlice<f64>,
1851        schur: &mut CudaSlice<f64>,
1852        rhs: &mut CudaSlice<f64>,
1853    ) -> Result<(), ArrowSchurGpuFailure> {
1854        let y_block_elems = d * k;
1855        let u_block_elems = d;
1856        for i in 0..n {
1857            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1858            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1859            // GEMM: schur += (-1) · Y_i^T · Y_i  (Y_i is d×k col-major; out is k×k)
1860            let gemm_cfg = GemmConfig::<f64> {
1861                transa: cublasOperation_t::CUBLAS_OP_T,
1862                transb: cublasOperation_t::CUBLAS_OP_N,
1863                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1864                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1865                k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1866                alpha: -1.0,
1867                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1868                ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1869                beta: 1.0,
1870                ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1871            };
1872            // SAFETY: y_slice is d×k col-major, schur is k×k col-major; alpha/beta scalars set above.
1873            unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1874                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1875            // GEMV: rhs += 1 · Y_i^T · u_i
1876            let gemv_cfg = GemvConfig::<f64> {
1877                trans: cublasOperation_t::CUBLAS_OP_T,
1878                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1879                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1880                alpha: 1.0,
1881                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1882                incx: 1,
1883                beta: 1.0,
1884                incy: 1,
1885            };
1886            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1887            // device buffers; `rhs` is the length-k accumulator.
1888            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1889                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1890        }
1891        Ok(())
1892    }
1893
1894    /// `#1017` resident gradient path: accumulate ONLY the Schur RHS term
1895    /// `rhs += Σ_i Y_iᵀ u_i`, skipping the `−Σ_i Y_iᵀ Y_i` matrix GEMM that the
1896    /// resident frame already folded into its persistent `L_S` factor. This is
1897    /// the per-iterate-cheap counterpart of [`accumulate_schur`]: the GEMV here
1898    /// is bit-identical to the GEMV inside `accumulate_schur` (same config, same
1899    /// `beta=1` accumulation order over rows), so the resident frame's `δβ`
1900    /// matches a full `solve()` at the same gradient.
1901    fn accumulate_schur_rhs_only(
1902        blas: &CudaBlas,
1903        d: usize,
1904        k: usize,
1905        n: usize,
1906        y_stack: &CudaSlice<f64>,
1907        u_stack: &CudaSlice<f64>,
1908        rhs: &mut CudaSlice<f64>,
1909    ) -> Result<(), ArrowSchurGpuFailure> {
1910        let y_block_elems = d * k;
1911        let u_block_elems = d;
1912        for i in 0..n {
1913            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1914            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1915            let gemv_cfg = GemvConfig::<f64> {
1916                trans: cublasOperation_t::CUBLAS_OP_T,
1917                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1918                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1919                alpha: 1.0,
1920                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1921                incx: 1,
1922                beta: 1.0,
1923                incy: 1,
1924            };
1925            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1926            // device buffers; `rhs` is the length-k accumulator.
1927            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1928                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1929        }
1930        Ok(())
1931    }
1932
1933    /// Accumulate `g_dev[i] ← u_i + Y_i · δβ` per block. This is the
1934    /// pre-trsm RHS for the back-substitution `L_i^T x_i = w_i`.
1935    fn accumulate_back_sub_rhs(
1936        blas: &CudaBlas,
1937        d: usize,
1938        k: usize,
1939        n: usize,
1940        y_stack: &CudaSlice<f64>,
1941        delta_beta: &CudaSlice<f64>,
1942        u_stack: &mut CudaSlice<f64>,
1943    ) -> Result<(), ArrowSchurGpuFailure> {
1944        let y_block_elems = d * k;
1945        let u_block_elems = d;
1946        for i in 0..n {
1947            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1948            let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1949            let gemv_cfg = GemvConfig::<f64> {
1950                trans: cublasOperation_t::CUBLAS_OP_N,
1951                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1952                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1953                alpha: 1.0,
1954                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1955                incx: 1,
1956                beta: 1.0,
1957                incy: 1,
1958            };
1959            // SAFETY: y_slice / delta_beta / u_slice are live device buffers
1960            // of the expected sizes (d×k, k, d).
1961            unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1962                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1963        }
1964        Ok(())
1965    }
1966
1967    // ────────────────────────────────────────────────────────────────────
1968    // Layer D + E — fused NVRTC dispatch.
1969    //
1970    // The forward kernel (`arrow_schur_forward_pgroup`) is a single launch
1971    // that, per row block, factors `D_i + ρI = L_i L_iᵀ` in shared memory,
1972    // forward-solves `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i`, and emits the
1973    // per-block Schur partials `partial_S[i] = Yᵀ Y` (R×R) and
1974    // `partial_r[i] = Yᵀ u` (R). The host reduces partials on the CPU after
1975    // dtoh (one fused sum across `n` blocks of R²+R doubles; cheap because
1976    // n·R² ≲ 5M doubles at large scale), assembles `S_β`, factors it via
1977    // cuSOLVER, and launches the back-substitution kernel
1978    // `arrow_schur_back_sub_pgroup` to recover `δt_i = -L_i⁻ᵀ(u_i + Y_i δβ)`
1979    // without re-uploading the local factors.
1980    // ────────────────────────────────────────────────────────────────────
1981
1982    use std::collections::HashMap;
1983    use std::sync::Mutex;
1984
1985    /// One compiled NVRTC module per `(cc_major, cc_minor, p_max, r_template)`.
1986    /// `cc_*` lets one process drive multiple device generations; the
1987    /// `(p_max, r_template)` pair selects the shared-memory layout baked into
1988    /// the kernel source.
1989    struct FusedModuleCache {
1990        modules: Mutex<
1991            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
1992        >,
1993    }
1994
1995    fn fused_module_cache() -> &'static FusedModuleCache {
1996        static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1997        CACHE.get_or_init(|| FusedModuleCache {
1998            modules: Mutex::new(HashMap::new()),
1999        })
2000    }
2001
2002    fn fused_module_for(
2003        ctx: &Arc<CudaContext>,
2004        key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2005    ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2006        let cache = fused_module_cache();
2007        if let Ok(guard) = cache.modules.lock() {
2008            if let Some(existing) = guard.get(&key) {
2009                return Ok(existing.clone());
2010            }
2011        }
2012        let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2013            key.p_max as usize,
2014            key.r_template as usize,
2015        );
2016        let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
2017            ArrowSchurGpuFailure::SchurFactorFailed {
2018                reason: format!(
2019                    "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2020                    key.p_max, key.r_template
2021                ),
2022            }
2023        })?;
2024        let module = ctx
2025            .load_module(ptx)
2026            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2027        if let Ok(mut guard) = cache.modules.lock() {
2028            guard.entry(key).or_insert_with(|| module.clone());
2029        }
2030        Ok(module)
2031    }
2032
    /// CUDA C source for the device-side SAE PCG kernels. It is compiled through
    /// NVRTC and cached by `pcg_vector_module` under the name `"arrow_pcg_vector"`.
    ///
    /// Kernel groups:
    /// * PCG vector ops: `arrow_pcg_jacobi_mul` (`z[i] = inv_diag[i] * r[i]`)
    ///   and `arrow_pcg_update_p` (`p[i] = z[i] + beta * p[i]`).
    /// * Penalty matvec: `arrow_sae_init` (`out = ridge * x`). Then
    ///   `arrow_sae_smooth_matvec` / `arrow_sae_sparse_g_matvec` accumulate the
    ///   smooth and sparse-G blocks into `out`, one penalty block per `blockIdx.y`.
    /// * Per-row reduced-Schur correction: `arrow_sae_gather_u`,
    ///   `arrow_sae_apply_l`, `arrow_sae_apply_ainv`, `arrow_sae_scatter_sub`,
    ///   plus `arrow_sae_diag_sub` for the Jacobi diagonal.
    /// * Frames-engaged (#1017/#1026) variants of the above. These use a
    ///   per-block right-width `r_k` and a dense per-row cross block `H_tβ`.
    ///
    /// All indexing is 32-bit `int`. The kernels only guard their launch-shape
    /// indices (row / block / channel / work item); stored offsets such as
    /// `beta_base[e]`, `block_offsets[b]` or `row_off[b]` are trusted. The host
    /// side must keep every offset and flattened extent in range.
    /// Cross-block accumulation uses `double` `atomicAdd`. Per the in-source
    /// comment, this relies on sm_60+ through the NVRTC arch pin (#1551).
    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri;       // basis row in atom i
    int a = linear - li * ri;   // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 *   h_i   = H_tβ^(i) · x                (length q_i)
 *   s_i   = (H_tt^(i)+ρ_t I)⁻¹ h_i      (apply cached ainv, length q_i)
 *   out  -= (H_tβ^(i))ᵀ · s_i           (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2465
2466    fn pcg_vector_module(
2467        ctx: &Arc<CudaContext>,
2468    ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2469        static CACHE: gam_gpu::device_cache::PtxModuleCache =
2470            gam_gpu::device_cache::PtxModuleCache::new();
2471        CACHE
2472            .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2473            .map_err(|err| {
2474                // #1551: an NVRTC compile / module-load failure of
2475                // PCG_VECTOR_KERNEL_SOURCE means the device SAE PCG cannot run;
2476                // log it (the historical silent collapse to `Unavailable` is what
2477                // masked the missing `--gpu-architecture` for so long) and fall
2478                // back to the CPU.
2479                log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2480                ArrowSchurGpuFailure::Unavailable
2481            })
2482    }
2483
2484    fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2485        let threads = 256u32;
2486        let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2487        Ok(LaunchConfig {
2488            grid_dim: (blocks, 1, 1),
2489            block_dim: (threads, 1, 1),
2490            shared_mem_bytes: 0,
2491        })
2492    }
2493
2494    fn launch_jacobi_mul(
2495        stream: &Arc<CudaStream>,
2496        module: &Arc<CudaModule>,
2497        inv_diag: &CudaSlice<f64>,
2498        r: &CudaSlice<f64>,
2499        z: &mut CudaSlice<f64>,
2500        n: usize,
2501    ) -> Result<(), ArrowSchurGpuFailure> {
2502        let kernel = module
2503            .load_function("arrow_pcg_jacobi_mul")
2504            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2505        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2506        let mut builder = stream.launch_builder(&kernel);
2507        builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2508        // SAFETY: all buffers have length n and belong to `stream`; the kernel only
2509        // reads/writes indices `< n`.
2510        unsafe { builder.launch(pcg_launch_config(n)?) }
2511            .map(drop)
2512            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2513    }
2514
2515    fn launch_update_p(
2516        stream: &Arc<CudaStream>,
2517        module: &Arc<CudaModule>,
2518        z: &CudaSlice<f64>,
2519        beta: f64,
2520        p: &mut CudaSlice<f64>,
2521        n: usize,
2522    ) -> Result<(), ArrowSchurGpuFailure> {
2523        let kernel = module
2524            .load_function("arrow_pcg_update_p")
2525            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2526        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2527        let mut builder = stream.launch_builder(&kernel);
2528        builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2529        // SAFETY: z/p both have length n and belong to `stream`; the kernel only
2530        // reads/writes indices `< n`.
2531        unsafe { builder.launch(pcg_launch_config(n)?) }
2532            .map(drop)
2533            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2534    }
2535
    /// Device-resident, flattened copy of a `DeviceSaePcgData` plus the per-row
    /// inverse blocks and scratch vectors used by the SAE PCG kernels.
    /// Produced by `flatten_device_sae_data`. All index arrays are `i32`
    /// because the kernels index with 32-bit `int`.
    struct DeviceSaePcgBuffers {
        /// CSR offsets (length `n_rows + 1`): row `r` owns entries
        /// `row_ptr[r]..row_ptr[r + 1]` of `beta_base` / `phi`.
        row_ptr: CudaSlice<i32>,
        /// β-coefficient base index of each `a_phi` entry; the kernels add the
        /// channel `oc < p` to it.
        beta_base: CudaSlice<i32>,
        /// Weight paired with each `beta_base` entry.
        phi: CudaSlice<f64>,
        /// Offsets (length `n_rows + 1`) into `jac`; a row's width is
        /// `(jac_ptr[r + 1] - jac_ptr[r]) / p == q_r`.
        jac_ptr: CudaSlice<i32>,
        /// Concatenated per-row local Jacobians (`q_r * p` values each).
        jac: CudaSlice<f64>,
        /// Smooth-block `global_offset` in β-coefficient units.
        smooth_offsets: CudaSlice<i32>,
        /// Smooth-block dimension `m` (factors are `m × m`).
        smooth_m: CudaSlice<i32>,
        /// Offsets into `smooth_data`, one entry per block plus a trailing end.
        smooth_ptr: CudaSlice<i32>,
        /// Row-major smooth factors, concatenated.
        smooth_data: CudaSlice<f64>,
        /// Sparse-G block row offset; the kernel multiplies `(off + l)` by `p`.
        g_row_off: CudaSlice<i32>,
        /// Sparse-G block column offset, same units as `g_row_off`.
        g_col_off: CudaSlice<i32>,
        /// Sparse-G block row count.
        g_rows: CudaSlice<i32>,
        /// Sparse-G block column count.
        g_cols: CudaSlice<i32>,
        /// Offsets into `g_data`, one entry per block plus a trailing end.
        g_ptr: CudaSlice<i32>,
        /// Row-major sparse-G block data, concatenated.
        g_data: CudaSlice<f64>,
        /// Per-row dense `(H_tt + ridge_t I)^{-1}`, zero-padded to
        /// `max_q × max_q` and laid out at `row * max_q * max_q`.
        ainv: CudaSlice<f64>,
        /// Scratch, `n_rows * p` zero-initialised values.
        u: CudaSlice<f64>,
        /// Scratch, `n_rows * max_q` zero-initialised values.
        w: CudaSlice<f64>,
        /// Scratch, `n_rows * max_q` zero-initialised values.
        v: CudaSlice<f64>,
        /// Number of per-row local blocks.
        n_rows: usize,
        /// Channels per β coefficient group (`DeviceSaePcgData::p`).
        p: usize,
        /// β dimension (`DeviceSaePcgData::beta_dim`).
        k: usize,
        /// Largest per-row local width `q_r`.
        max_q: usize,
        /// Number of smooth penalty blocks.
        smooth_blocks: usize,
        /// Number of sparse-G penalty blocks.
        g_blocks: usize,
    }
2563
2564    fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2565        to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2566    }
2567
2568    fn sae_penalty_diag_host(
2569        data: &DeviceSaePcgData,
2570        ridge_beta: f64,
2571    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2572        let mut diag = vec![ridge_beta; data.beta_dim];
2573        for block in &data.smooth_blocks {
2574            let (rows, cols) = block.factor_a.dim();
2575            if rows != cols {
2576                return Err(ArrowSchurGpuFailure::Unavailable);
2577            }
2578            for row in 0..rows {
2579                let coeff = block.factor_a[[row, row]];
2580                let base = block
2581                    .global_offset
2582                    .checked_add(
2583                        row.checked_mul(data.p)
2584                            .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2585                    )
2586                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2587                let end = base
2588                    .checked_add(data.p)
2589                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2590                if end > diag.len() {
2591                    return Err(ArrowSchurGpuFailure::Unavailable);
2592                }
2593                for channel in 0..data.p {
2594                    diag[base + channel] += coeff;
2595                }
2596            }
2597        }
2598        for block in &data.sparse_g_blocks {
2599            if block.row_off != block.col_off {
2600                continue;
2601            }
2602            let (rows, cols) = block.data.dim();
2603            for row in 0..rows.min(cols) {
2604                let coeff = block.data[[row, row]];
2605                let beta_row = block
2606                    .row_off
2607                    .checked_add(row)
2608                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2609                let base = beta_row
2610                    .checked_mul(data.p)
2611                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2612                let end = base
2613                    .checked_add(data.p)
2614                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2615                if end > diag.len() {
2616                    return Err(ArrowSchurGpuFailure::Unavailable);
2617                }
2618                for channel in 0..data.p {
2619                    diag[base + channel] += coeff;
2620                }
2621            }
2622        }
2623        Ok(diag)
2624    }
2625
2626    fn flatten_device_sae_data(
2627        sys: &ArrowSchurSystem,
2628        data: &DeviceSaePcgData,
2629        ridge_t: f64,
2630        stream: &Arc<CudaStream>,
2631    ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2632        let n_rows = sys.rows.len();
2633        let p = data.p;
2634        let k = data.beta_dim;
2635        if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2636            return Err(ArrowSchurGpuFailure::Unavailable);
2637        }
2638
2639        let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2640        let mut beta_base_host = Vec::<i32>::new();
2641        let mut phi_host = Vec::<f64>::new();
2642        row_ptr_host.push(0_i32);
2643        for row in data.a_phi.iter() {
2644            for &(base, phi) in row {
2645                beta_base_host.push(checked_i32(base)?);
2646                phi_host.push(phi);
2647            }
2648            row_ptr_host.push(checked_i32(beta_base_host.len())?);
2649        }
2650
2651        let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2652        let mut jac_host = Vec::<f64>::new();
2653        let mut max_q = 0usize;
2654        jac_ptr_host.push(0_i32);
2655        for row_jac in data.local_jac.iter() {
2656            if row_jac.len() % p != 0 {
2657                return Err(ArrowSchurGpuFailure::Unavailable);
2658            }
2659            max_q = max_q.max(row_jac.len() / p);
2660            jac_host.extend_from_slice(row_jac);
2661            jac_ptr_host.push(checked_i32(jac_host.len())?);
2662        }
2663        if max_q == 0 {
2664            return Err(ArrowSchurGpuFailure::Unavailable);
2665        }
2666
2667        let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2668        let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2669        let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2670        let mut smooth_data_host = Vec::<f64>::new();
2671        smooth_ptr_host.push(0_i32);
2672        for block in &data.smooth_blocks {
2673            let (rows, cols) = block.factor_a.dim();
2674            if rows != cols {
2675                return Err(ArrowSchurGpuFailure::Unavailable);
2676            }
2677            smooth_offsets_host.push(checked_i32(block.global_offset)?);
2678            smooth_m_host.push(checked_i32(rows)?);
2679            for r in 0..rows {
2680                for c in 0..cols {
2681                    smooth_data_host.push(block.factor_a[[r, c]]);
2682                }
2683            }
2684            smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2685        }
2686
2687        let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2688        let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2689        let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2690        let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2691        let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2692        let mut g_data_host = Vec::<f64>::new();
2693        g_ptr_host.push(0_i32);
2694        for block in &data.sparse_g_blocks {
2695            let (rows, cols) = block.data.dim();
2696            g_row_off_host.push(checked_i32(block.row_off)?);
2697            g_col_off_host.push(checked_i32(block.col_off)?);
2698            g_rows_host.push(checked_i32(rows)?);
2699            g_cols_host.push(checked_i32(cols)?);
2700            for r in 0..rows {
2701                for c in 0..cols {
2702                    g_data_host.push(block.data[[r, c]]);
2703                }
2704            }
2705            g_ptr_host.push(checked_i32(g_data_host.len())?);
2706        }
2707
2708        let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2709        for (row_idx, row) in sys.rows.iter().enumerate() {
2710            let q = data.local_jac[row_idx].len() / p;
2711            if row.htt.dim() != (q, q) {
2712                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2713                    reason: format!(
2714                        "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2715                        row.htt.dim()
2716                    ),
2717                });
2718            }
2719            let mut block = row.htt.clone();
2720            for d in 0..q {
2721                block[[d, d]] += ridge_t;
2722            }
2723            let factor = gam_linalg::triangular::cholesky_factor_in_place(
2724                block.view(),
2725                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2726            )
2727            .ok_or_else(|| {
2728                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
2729                // indefinite per-row block recovers in one outer-loop retry.
2730                ArrowSchurGpuFailure::RidgeBumpRequired {
2731                    row: row_idx,
2732                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2733                }
2734            })?;
2735            for col in 0..q {
2736                let mut e = Array1::<f64>::zeros(q);
2737                e[col] = 1.0;
2738                let solved =
2739                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2740                for r in 0..q {
2741                    ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2742                }
2743            }
2744        }
2745
2746        Ok(DeviceSaePcgBuffers {
2747            row_ptr: stream
2748                .clone_htod(&row_ptr_host)
2749                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2750            beta_base: stream
2751                .clone_htod(&beta_base_host)
2752                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2753            phi: stream
2754                .clone_htod(&phi_host)
2755                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2756            jac_ptr: stream
2757                .clone_htod(&jac_ptr_host)
2758                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2759            jac: stream
2760                .clone_htod(&jac_host)
2761                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2762            smooth_offsets: stream
2763                .clone_htod(&smooth_offsets_host)
2764                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2765            smooth_m: stream
2766                .clone_htod(&smooth_m_host)
2767                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2768            smooth_ptr: stream
2769                .clone_htod(&smooth_ptr_host)
2770                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2771            smooth_data: stream
2772                .clone_htod(&smooth_data_host)
2773                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2774            g_row_off: stream
2775                .clone_htod(&g_row_off_host)
2776                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2777            g_col_off: stream
2778                .clone_htod(&g_col_off_host)
2779                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2780            g_rows: stream
2781                .clone_htod(&g_rows_host)
2782                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2783            g_cols: stream
2784                .clone_htod(&g_cols_host)
2785                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2786            g_ptr: stream
2787                .clone_htod(&g_ptr_host)
2788                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2789            g_data: stream
2790                .clone_htod(&g_data_host)
2791                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2792            ainv: stream
2793                .clone_htod(&ainv_host)
2794                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795            u: stream
2796                .alloc_zeros::<f64>(n_rows * p)
2797                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798            w: stream
2799                .alloc_zeros::<f64>(n_rows * max_q)
2800                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801            v: stream
2802                .alloc_zeros::<f64>(n_rows * max_q)
2803                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804            n_rows,
2805            p,
2806            k,
2807            max_q,
2808            smooth_blocks: data.smooth_blocks.len(),
2809            g_blocks: data.sparse_g_blocks.len(),
2810        })
2811    }
2812
2813    fn launch_sae_init(
2814        stream: &Arc<CudaStream>,
2815        module: &Arc<CudaModule>,
2816        out: &mut CudaSlice<f64>,
2817        x: &CudaSlice<f64>,
2818        ridge: f64,
2819        n: usize,
2820    ) -> Result<(), ArrowSchurGpuFailure> {
2821        let kernel = module
2822            .load_function("arrow_sae_init")
2823            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2824        let n_i32 = checked_i32(n)?;
2825        let mut builder = stream.launch_builder(&kernel);
2826        builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2827        // SAFETY: `out` and `x` are live device buffers with at least `n`
2828        // entries on `stream`; the kernel writes one in-bounds element per
2829        // launched index below `n`.
2830        unsafe { builder.launch(pcg_launch_config(n)?) }
2831            .map(drop)
2832            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2833    }
2834
2835    fn launch_sae_penalty_matvec(
2836        stream: &Arc<CudaStream>,
2837        module: &Arc<CudaModule>,
2838        buffers: &mut DeviceSaePcgBuffers,
2839        x: &CudaSlice<f64>,
2840        out: &mut CudaSlice<f64>,
2841        ridge_beta: f64,
2842    ) -> Result<(), ArrowSchurGpuFailure> {
2843        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2844        if buffers.smooth_blocks > 0 {
2845            let kernel = module
2846                .load_function("arrow_sae_smooth_matvec")
2847                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2848            let max_m = buffers.k;
2849            let p_i32 = checked_i32(buffers.p)?;
2850            let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2851            let cfg = LaunchConfig {
2852                grid_dim: (
2853                    ((max_m as u32).saturating_add(255) / 256).max(1),
2854                    checked_i32(buffers.smooth_blocks)? as u32,
2855                    1,
2856                ),
2857                block_dim: (256, 1, 1),
2858                shared_mem_bytes: 0,
2859            };
2860            let mut builder = stream.launch_builder(&kernel);
2861            builder
2862                .arg(x)
2863                .arg(&mut *out)
2864                .arg(&buffers.smooth_offsets)
2865                .arg(&buffers.smooth_m)
2866                .arg(&buffers.smooth_ptr)
2867                .arg(&buffers.smooth_data)
2868                .arg(&p_i32)
2869                .arg(&blocks_i32);
2870            // SAFETY: smooth block metadata and dense smooth data were flattened
2871            // into live device buffers; the 2D grid covers only declared block
2872            // and coefficient-channel work items, and the kernel bounds-checks
2873            // against each block's stored size.
2874            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2875        }
2876        if buffers.g_blocks > 0 {
2877            let kernel = module
2878                .load_function("arrow_sae_sparse_g_matvec")
2879                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2880            let max_work = buffers
2881                .k
2882                .checked_div(buffers.p)
2883                .unwrap_or(0)
2884                .saturating_mul(buffers.p);
2885            let p_i32 = checked_i32(buffers.p)?;
2886            let blocks_i32 = checked_i32(buffers.g_blocks)?;
2887            let cfg = LaunchConfig {
2888                grid_dim: (
2889                    ((max_work as u32).saturating_add(255) / 256).max(1),
2890                    checked_i32(buffers.g_blocks)? as u32,
2891                    1,
2892                ),
2893                block_dim: (256, 1, 1),
2894                shared_mem_bytes: 0,
2895            };
2896            let mut builder = stream.launch_builder(&kernel);
2897            builder
2898                .arg(x)
2899                .arg(&mut *out)
2900                .arg(&buffers.g_row_off)
2901                .arg(&buffers.g_col_off)
2902                .arg(&buffers.g_rows)
2903                .arg(&buffers.g_cols)
2904                .arg(&buffers.g_ptr)
2905                .arg(&buffers.g_data)
2906                .arg(&p_i32)
2907                .arg(&blocks_i32);
2908            // SAFETY: sparse G block metadata/data are live device buffers built
2909            // from host CSR-like block descriptors; the launch dimensions cover
2910            // declared block work only and the kernel checks row/column bounds
2911            // before reading or accumulating.
2912            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2913        }
2914        Ok(())
2915    }
2916
2917    fn launch_sae_row_schur_sub(
2918        stream: &Arc<CudaStream>,
2919        module: &Arc<CudaModule>,
2920        buffers: &mut DeviceSaePcgBuffers,
2921        x: &CudaSlice<f64>,
2922        out: &mut CudaSlice<f64>,
2923    ) -> Result<(), ArrowSchurGpuFailure> {
2924        let p_i32 = checked_i32(buffers.p)?;
2925        let max_q_i32 = checked_i32(buffers.max_q)?;
2926        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2927        let cfg_p_rows = LaunchConfig {
2928            grid_dim: (
2929                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2930                checked_i32(buffers.n_rows)? as u32,
2931                1,
2932            ),
2933            block_dim: (256, 1, 1),
2934            shared_mem_bytes: 0,
2935        };
2936        let gather = module
2937            .load_function("arrow_sae_gather_u")
2938            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2939        {
2940            let mut builder = stream.launch_builder(&gather);
2941            builder
2942                .arg(x)
2943                .arg(&buffers.row_ptr)
2944                .arg(&buffers.beta_base)
2945                .arg(&buffers.phi)
2946                .arg(&mut buffers.u)
2947                .arg(&p_i32)
2948                .arg(&n_rows_i32);
2949            // SAFETY: `x`, row pointers, beta offsets, basis rows, and `u` are
2950            // live device buffers sized for `n_rows` by `p`; the kernel guards
2951            // row/channel indices before gathering.
2952            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2953        }
2954
2955        let cfg_q_rows = LaunchConfig {
2956            grid_dim: (
2957                ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2958                checked_i32(buffers.n_rows)? as u32,
2959                1,
2960            ),
2961            block_dim: (256, 1, 1),
2962            shared_mem_bytes: 0,
2963        };
2964        let apply_l = module
2965            .load_function("arrow_sae_apply_l")
2966            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2967        {
2968            let mut builder = stream.launch_builder(&apply_l);
2969            builder
2970                .arg(&buffers.u)
2971                .arg(&buffers.jac_ptr)
2972                .arg(&buffers.jac)
2973                .arg(&mut buffers.w)
2974                .arg(&p_i32)
2975                .arg(&max_q_i32)
2976                .arg(&n_rows_i32);
2977            // SAFETY: `u`, Jacobian row pointers/data, and `w` are live buffers
2978            // sized for the `(n_rows, p)` to `(n_rows, max_q)` multiply; the
2979            // kernel checks row and local-coordinate bounds.
2980            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2981        }
2982
2983        let apply_ainv = module
2984            .load_function("arrow_sae_apply_ainv")
2985            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2986        {
2987            let mut builder = stream.launch_builder(&apply_ainv);
2988            builder
2989                .arg(&buffers.ainv)
2990                .arg(&buffers.w)
2991                .arg(&mut buffers.v)
2992                .arg(&max_q_i32)
2993                .arg(&n_rows_i32);
2994            // SAFETY: `ainv`, `w`, and `v` are live device buffers sized for
2995            // `n_rows * max_q`; the kernel guards all row/local-coordinate
2996            // indices before reading or writing.
2997            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2998        }
2999
3000        let scatter = module
3001            .load_function("arrow_sae_scatter_sub")
3002            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3003        {
3004            let mut builder = stream.launch_builder(&scatter);
3005            builder
3006                .arg(&buffers.v)
3007                .arg(&buffers.jac_ptr)
3008                .arg(&buffers.jac)
3009                .arg(&buffers.row_ptr)
3010                .arg(&buffers.beta_base)
3011                .arg(&buffers.phi)
3012                .arg(out)
3013                .arg(&p_i32)
3014                .arg(&max_q_i32)
3015                .arg(&n_rows_i32);
3016            // SAFETY: `v`, Jacobian metadata, row pointers, beta offsets, basis
3017            // rows, and `out` are live buffers for `n_rows` by `p`; scatter
3018            // indices are checked against row and channel bounds in the kernel.
3019            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3020        }
3021        Ok(())
3022    }
3023
3024    fn launch_sae_diag_sub(
3025        stream: &Arc<CudaStream>,
3026        module: &Arc<CudaModule>,
3027        buffers: &DeviceSaePcgBuffers,
3028        diag: &mut CudaSlice<f64>,
3029    ) -> Result<(), ArrowSchurGpuFailure> {
3030        let kernel = module
3031            .load_function("arrow_sae_diag_sub")
3032            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3033        let p_i32 = checked_i32(buffers.p)?;
3034        let max_q_i32 = checked_i32(buffers.max_q)?;
3035        let n_rows_i32 = checked_i32(buffers.n_rows)?;
3036        let cfg = LaunchConfig {
3037            grid_dim: (
3038                ((buffers.p as u32).saturating_add(255) / 256).max(1),
3039                checked_i32(buffers.n_rows)? as u32,
3040                1,
3041            ),
3042            block_dim: (256, 1, 1),
3043            shared_mem_bytes: 0,
3044        };
3045        let mut builder = stream.launch_builder(&kernel);
3046        builder
3047            .arg(diag)
3048            .arg(&buffers.ainv)
3049            .arg(&buffers.jac_ptr)
3050            .arg(&buffers.jac)
3051            .arg(&buffers.row_ptr)
3052            .arg(&buffers.beta_base)
3053            .arg(&buffers.phi)
3054            .arg(&p_i32)
3055            .arg(&max_q_i32)
3056            .arg(&n_rows_i32);
3057        // SAFETY: diagonal output and all read-only SAE row metadata buffers are
3058        // live on `stream` with sizes matching `n_rows`, `p`, and `max_q`; the
3059        // kernel bounds-checks its flattened work index.
3060        unsafe { builder.launch(cfg) }
3061            .map(drop)
3062            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3063    }
3064
3065    fn launch_sae_matvec(
3066        stream: &Arc<CudaStream>,
3067        module: &Arc<CudaModule>,
3068        buffers: &mut DeviceSaePcgBuffers,
3069        x: &CudaSlice<f64>,
3070        out: &mut CudaSlice<f64>,
3071        ridge_beta: f64,
3072    ) -> Result<(), ArrowSchurGpuFailure> {
3073        launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3074        launch_sae_row_schur_sub(stream, module, buffers, x, out)
3075    }
3076
3077    /// Pack `D + ρ_t I`, `B`, and `g` into the strided `(n × P_MAX × P_MAX)`
3078    /// / `(n × P_MAX × R_TEMPLATE)` / `(n × P_MAX)` layout the fused kernel
3079    /// expects. Entries outside the runtime `(p, r)` window stay at zero so
3080    /// the kernel's per-element loops are safe to no-op there.
3081    fn pack_fused_host(
3082        sys: &ArrowSchurSystem,
3083        ridge_t: f64,
3084        p_max: usize,
3085        r_template: usize,
3086    ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3087        let n = sys.rows.len();
3088        let d = sys.d;
3089        let k = sys.k;
3090        let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3091        let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3092        let mut g_buf = vec![0.0_f64; n * p_max];
3093        for (i, row) in sys.rows.iter().enumerate() {
3094            // D_i + ρI, column-major in P_MAX×P_MAX strided block.
3095            for col in 0..d {
3096                let base = (i * p_max + col) * p_max;
3097                for r in 0..d {
3098                    let mut value = row.htt[[r, col]];
3099                    if r == col {
3100                        value += ridge_t;
3101                    }
3102                    d_buf[base + r] = value;
3103                }
3104            }
3105            // B_i in P_MAX×R_TEMPLATE strided block. The per-row (per-i) block
3106            // stride is `p_max · r_template` (matching the `b_buf` allocation
3107            // above and the kernel's `b_stack`/`y_out` layout), NOT
3108            // `p_max · p_max`: using the D-block multiplier here overflows the
3109            // buffer whenever `p_max > r_template` (e.g. d=30→p_max=32,
3110            // k=5→r_template=5). The within-block element offset stays
3111            // column-major `col·p_max + r` (P_MAX rows per column).
3112            for col in 0..k {
3113                let base = (i * r_template + col) * p_max;
3114                for r in 0..d {
3115                    b_buf[base + r] = row.htbeta[[r, col]];
3116                }
3117            }
3118            // g_i in P_MAX strided vector.
3119            let g_base = i * p_max;
3120            for r in 0..d {
3121                g_buf[g_base + r] = row.gt[r];
3122            }
3123        }
3124        (d_buf, b_buf, g_buf)
3125    }
3126
3127    // -----------------------------------------------------------------------
3128    // #1017 Phase 3: across-iteration device residency.
3129    //
3130    // `solve()` re-packs and re-uploads `D` (`H_tt`), `B` (`H_tβ`) and `g`,
3131    // then re-runs the per-row POTRF and the border Schur factorization on
3132    // EVERY call. For the SAE joint inner Newton at a frozen gate/basis frame
3133    // the Hessian blocks `D`, `B`, `H_ββ` are CONSTANT across the inner loop —
    // only the gradient `g = r(z) = H z − g₀` changes per iterate. So the
    // factor work (`O(n·d³ + k³)`, with `k` the shared border width) and the
    // dominant `O(n·d·k)` cross-block upload are pure waste when repeated per
    // iterate.
3137    //
3138    // `ResidentArrowFrame` performs that constant work ONCE at construction:
3139    // upload+ridge+POTRF of `D` (keeping `L_i` resident in `l_dev`), the
3140    // forward solve `Y_i = L_i^{-1} B_i` (kept resident in `y_dev`), and the
3141    // Schur assembly + border POTRF (keeping `L_S` resident in `schur_dev`).
    // Each subsequent `solve_gradient(g_t, g_β)` uploads only the `n·d` row
    // gradient and the `k`-length `−g_β`. It then runs the cheap residual path:
    // `u_i = L_i^{-1} g_i` (one batched TRSM), the Schur RHS
    // `−g_β + Σ Y_iᵀ u_i`, `δβ = L_S^{-T} L_S^{-1} rhs` (two TRSM, NO POTRF),
    // and the back-sub `δt_i = −L_i^{-T}(u_i + Y_i δβ)`. It reads back only `δ`;
    // log|H| is cached on the host at construction. The heavy buffers never
    // leave the device across iterations, so the per-iterate host transfer is
    // `O(n·d + k)`, not `O(n·d·k)`. Numerics are intended to match a `solve()`
    // at the same `(D, B, H_ββ, g, ridge_t, ridge_beta)` bit for bit. The
    // factor buffers and helper kernels are shared; the resident path merely
    // SKIPS re-deriving the parts that do not depend on `g`. (NOTE(review): the
    // RHS here goes through `accumulate_schur_rhs_only`; confirm its reduction
    // order matches the RHS half of `accumulate_schur`.) The CPU dense
    // reference (`solve_arrow_newton_step_dense_reference`) is the parity
    // oracle.
    /// Device-resident factorization of one Arrow-Schur Hessian frame.
    ///
    /// `ResidentArrowFrame::new` factors the gradient-independent blocks once.
    /// `ResidentArrowFrame::solve_gradient` then reuses those factors for each
    /// fresh gradient. Only a cuBLAS handle is kept, because the per-gradient
    /// path does triangular solves and cuBLAS-backed accumulations and never
    /// needs the cuSOLVER POTRF handle.
    pub(super) struct ResidentArrowFrame {
        /// Number of local row blocks (`sys.rows.len()` at construction).
        n: usize,
        /// Size of each local `H_tt` block (`sys.d`).
        d: usize,
        /// Shared border width, i.e. the length of `β` (`sys.k`).
        k: usize,
        /// Stream the resident buffers were allocated on. Every transfer and
        /// TRSM issued by `solve_gradient` uses this stream.
        stream: Arc<CudaStream>,
        /// cuBLAS handle created on `stream`.
        blas: CudaBlas,
        /// Per-row lower Cholesky factors `L_i` of `H_tt + ρ_t I`, stacked
        /// column-major (`n` tiles of `d×d`). Resident across iterations.
        l_dev: CudaSlice<f64>,
        /// Whitened cross blocks `Y_i = L_i^{-1} H_tβ^(i)`, stacked column-major
        /// (`n` tiles of `d×k`). Resident across iterations.
        y_dev: CudaSlice<f64>,
        /// Lower Cholesky factor `L_S` of the reduced Schur complement
        /// `S_β = H_ββ + ρ_β I − Σ_i Y_iᵀ Y_i`. Resident across iterations.
        schur_dev: CudaSlice<f64>,
        /// `log|H| = 2 Σ log L_{i,jj} + 2 Σ log L_{S,aa}`, constant for the
        /// frame (depends only on the factored Hessian, not on `g`).
        log_det_hessian: f64,
    }
3172
3173    impl ResidentArrowFrame {
3174        /// Upload the constant Hessian blocks and perform the one-time factor
3175        /// work (`POTRF(D)`, `Y_i = L_i^{-1} B_i`, Schur assembly + border
3176        /// `POTRF`). The frame then serves cheap per-gradient solves.
3177        pub(super) fn new(
3178            sys: &ArrowSchurSystem,
3179            ridge_t: f64,
3180            ridge_beta: f64,
3181        ) -> Result<Self, ArrowSchurGpuFailure> {
3182            if ridge_t.is_nan() || ridge_beta.is_nan() {
3183                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3184                    reason: "ridge is NaN".to_string(),
3185                });
3186            }
3187            let n = sys.rows.len();
3188            let d = sys.d;
3189            let k = sys.k;
3190            let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3191                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3192            let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3193                .and_then(|ctx| ctx.new_stream().ok())
3194                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3195            let solver =
3196                DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3197            let blas =
3198                CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3199
3200            // Upload the constant blocks. `g` is uploaded per-gradient, not here.
3201            let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3202            let mut l_dev = stream
3203                .clone_htod(&d_host)
3204                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3205            let mut y_dev = stream
3206                .clone_htod(&b_host)
3207                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3208
3209            // POTRF(D) → L_i, kept resident in l_dev.
3210            let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3211            if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3212                // cuSOLVER `info` is a 1-based pivot row index; size the bump
3213                // from the block (Gershgorin λ_min bound) so a strongly
3214                // indefinite block recovers in one retry.
3215                return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3216                    row: idx,
3217                    bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3218                });
3219            }
3220
3221            // Y_i = L_i^{-1} B_i, in place over y_dev. Kept resident.
3222            trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3223
3224            // Schur assembly S_β = (H_ββ + ρ_β I) − Σ Y_iᵀ Y_i, then POTRF → L_S.
3225            // The RHS accumulation is folded into the gradient path; here we
3226            // only need the (gradient-independent) Schur factor, so accumulate
3227            // into a throwaway rhs buffer.
3228            let schur_init: Vec<f64> = {
3229                let mut tmp = Vec::with_capacity(k * k);
3230                for col in 0..k {
3231                    for row in 0..k {
3232                        let mut v = sys.hbb[[row, col]];
3233                        if row == col {
3234                            v += ridge_beta;
3235                        }
3236                        tmp.push(v);
3237                    }
3238                }
3239                tmp
3240            };
3241            let mut schur_dev = stream
3242                .clone_htod(&schur_init)
3243                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3244            // A zero u-stack makes `Σ Y_iᵀ u_i = 0`, so only the `−Σ Y_iᵀ Y_i`
3245            // Schur term is accumulated (the rhs is rebuilt per gradient).
3246            let zero_u = stream
3247                .clone_htod(&vec![0.0_f64; n * d])
3248                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3249            let mut throwaway_rhs = stream
3250                .clone_htod(&vec![0.0_f64; k])
3251                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3252            accumulate_schur(
3253                &blas,
3254                d,
3255                k,
3256                n,
3257                &y_dev,
3258                &zero_u,
3259                &mut schur_dev,
3260                &mut throwaway_rhs,
3261            )?;
3262            let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3263            if info != 0 {
3264                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3265                    reason: format!("Schur Cholesky failed at pivot {info}"),
3266                });
3267            }
3268
3269            // log|H| from the resident factors (constant for the frame).
3270            let l_local_host = stream
3271                .clone_dtoh(&l_dev)
3272                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3273            let l_schur_host = stream
3274                .clone_dtoh(&schur_dev)
3275                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3276            let mut log_det = 0.0_f64;
3277            for i in 0..n {
3278                let base = i * d * d;
3279                for j in 0..d {
3280                    log_det += l_local_host[base + j * d + j].ln();
3281                }
3282            }
3283            for j in 0..k {
3284                log_det += l_schur_host[j * k + j].ln();
3285            }
3286            log_det *= 2.0;
3287
3288            Ok(Self {
3289                n,
3290                d,
3291                k,
3292                stream,
3293                blas,
3294                l_dev,
3295                y_dev,
3296                schur_dev,
3297                log_det_hessian: log_det,
3298            })
3299        }
3300
3301        #[inline]
3302        pub(super) fn log_det_hessian(&self) -> f64 {
3303            self.log_det_hessian
3304        }
3305
3306        /// Solve `H δ = −gradient` for a fresh gradient `(g_t, g_β)` reusing the
3307        /// resident factors. Uploads only `g_t` (`n·d` scalars); reads back only
3308        /// `δ`. No POTRF runs here — all factorization is amortized into `new`.
3309        pub(super) fn solve_gradient(
3310            &self,
3311            g_t: &[f64],
3312            g_beta: &[f64],
3313        ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3314            let n = self.n;
3315            let d = self.d;
3316            let k = self.k;
3317            if g_t.len() != n * d || g_beta.len() != k {
3318                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3319                    reason: format!(
3320                        "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3321                        g_t.len(),
3322                        n * d,
3323                        g_beta.len(),
3324                        k
3325                    ),
3326                });
3327            }
3328            // Upload the per-iterate row gradient → u_i = L_i^{-1} g_i in place.
3329            let mut u_dev = self
3330                .stream
3331                .clone_htod(&g_t.to_vec())
3332                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3333            trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3334
3335            // Schur RHS = −g_β + Σ_i Y_iᵀ u_i. Reuse the resident Schur factor
3336            // (no POTRF, and skip the −Σ Y_iᵀ Y_i GEMM already baked into L_S).
3337            let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3338            let mut rhs_dev = self
3339                .stream
3340                .clone_htod(&rhs_init)
3341                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3342            accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3343
3344            // δβ ← L_S^{-T} L_S^{-1} rhs using the resident border factor.
3345            trsm_single(
3346                &self.blas,
3347                &self.stream,
3348                k,
3349                &self.schur_dev,
3350                &mut rhs_dev,
3351                false,
3352                false,
3353            )?;
3354            trsm_single(
3355                &self.blas,
3356                &self.stream,
3357                k,
3358                &self.schur_dev,
3359                &mut rhs_dev,
3360                false,
3361                true,
3362            )?;
3363            let delta_beta_host = self
3364                .stream
3365                .clone_dtoh(&rhs_dev)
3366                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3367            let delta_beta = Array1::from_vec(delta_beta_host);
3368
3369            // Back-sub δt_i = −L_i^{-T}(u_i + Y_i δβ).
3370            accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3371            trsm_batched_lower_inplace_transposed(
3372                &self.blas,
3373                &self.stream,
3374                d,
3375                n,
3376                1,
3377                &self.l_dev,
3378                &mut u_dev,
3379            )?;
3380            let x_host = self
3381                .stream
3382                .clone_dtoh(&u_dev)
3383                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3384            let mut delta_t = Array1::<f64>::zeros(n * d);
3385            for (i, v) in x_host.iter().enumerate() {
3386                delta_t[i] = -*v;
3387            }
3388
3389            Ok(ArrowSchurGpuSolution {
3390                delta_t,
3391                delta_beta,
3392                log_det_hessian: self.log_det_hessian,
3393            })
3394        }
3395    }
3396
3397    pub(super) fn solve_fused(
3398        sys: &ArrowSchurSystem,
3399        ridge_t: f64,
3400        ridge_beta: f64,
3401    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3402        let n = sys.rows.len();
3403        let d = sys.d;
3404        let k = sys.k;
3405        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3406            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3407        let p_max = plan.p_max;
3408        let r_template = plan.r_template;
3409
3410        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3411            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3412        )
3413        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3414        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3415            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3416        let stream = ctx
3417            .new_stream()
3418            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3419        let cap = &runtime.device.capability;
3420        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3421            cc_major: cap.compute_major,
3422            cc_minor: cap.compute_minor,
3423            p_max: p_max as u32,
3424            r_template: r_template as u32,
3425        };
3426        let module = fused_module_for(&ctx, key)?;
3427        let forward = module
3428            .load_function("arrow_schur_forward_pgroup")
3429            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3430        let back_sub = module
3431            .load_function("arrow_schur_back_sub_pgroup")
3432            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3433
3434        // ----- Upload packed D, B, g -----
3435        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3436        let d_dev = stream
3437            .clone_htod(&d_host)
3438            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3439        let b_dev = stream
3440            .clone_htod(&b_host)
3441            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3442        let g_dev = stream
3443            .clone_htod(&g_host)
3444            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445        let mut l_out = stream
3446            .alloc_zeros::<f64>(n * p_max * p_max)
3447            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3448        let mut u_out = stream
3449            .alloc_zeros::<f64>(n * p_max)
3450            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3451        let mut y_out = stream
3452            .alloc_zeros::<f64>(n * p_max * r_template)
3453            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3454        let mut partial_s = stream
3455            .alloc_zeros::<f64>(plan.partial_s_doubles)
3456            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457        let mut partial_r = stream
3458            .alloc_zeros::<f64>(plan.partial_r_doubles)
3459            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3460        let mut status_dev = stream
3461            .alloc_zeros::<i32>(n)
3462            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463
3464        // ----- Launch forward kernel: 1 block per row, P_MAX threads -----
3465        let cfg = LaunchConfig {
3466            grid_dim: (plan.blocks, 1, 1),
3467            block_dim: (plan.threads_per_block, 1, 1),
3468            shared_mem_bytes: 0,
3469        };
3470        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3471        let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3472        let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3473        let ridge_arg = ridge_t;
3474        {
3475            let mut builder = stream.launch_builder(&forward);
3476            builder
3477                .arg(&d_dev)
3478                .arg(&b_dev)
3479                .arg(&g_dev)
3480                .arg(&n_i32)
3481                .arg(&p_i32)
3482                .arg(&r_i32)
3483                .arg(&ridge_arg)
3484                .arg(&mut l_out)
3485                .arg(&mut u_out)
3486                .arg(&mut y_out)
3487                .arg(&mut partial_s)
3488                .arg(&mut partial_r)
3489                .arg(&mut status_dev);
3490            // SAFETY: all buffers were just allocated on `stream` with sizes
3491            // derived from `plan`; kernel parameter list matches the
3492            // FORWARD_KERNEL_SOURCE signature.
3493            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3494        }
3495        stream
3496            .synchronize()
3497            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3498
3499        // ----- Check per-block pivot status -----
3500        let status_host = stream
3501            .clone_dtoh(&status_dev)
3502            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3503        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3504            // The NVRTC kernel's status code is a 1-based pivot row index, not
3505            // a magnitude; size the bump from the block (Gershgorin λ_min
3506            // bound) so a strongly indefinite block recovers in one retry.
3507            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3508                row,
3509                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3510            });
3511        }
3512
3513        // ----- Reduce partials on host into S_β and r_β -----
3514        let partial_s_host = stream
3515            .clone_dtoh(&partial_s)
3516            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3517        let partial_r_host = stream
3518            .clone_dtoh(&partial_r)
3519            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3520        let mut schur_host = vec![0.0_f64; k * k];
3521        for col in 0..k {
3522            for row in 0..k {
3523                let mut v = sys.hbb[[row, col]];
3524                if row == col {
3525                    v += ridge_beta;
3526                }
3527                schur_host[col * k + row] = v;
3528            }
3529        }
3530        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3531        for i in 0..n {
3532            // partial_S[i] stride is R_TEMPLATE × R_TEMPLATE column-major; we
3533            // only read the leading (k × k) sub-block.
3534            let s_base = i * r_template * r_template;
3535            for col in 0..k {
3536                let col_base = s_base + col * r_template;
3537                let dst_col_base = col * k;
3538                for row in 0..k {
3539                    schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3540                }
3541            }
3542            let r_base = i * r_template;
3543            for a in 0..k {
3544                rhs_host[a] += partial_r_host[r_base + a];
3545            }
3546        }
3547
3548        // ----- Factor S_β on device (cuSOLVER), solve for δβ -----
3549        let mut schur_dev = stream
3550            .clone_htod(&schur_host)
3551            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3552        let mut rhs_dev = stream
3553            .clone_htod(&rhs_host)
3554            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3555        let solver =
3556            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3557        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3558        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3559        if info != 0 {
3560            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3561                reason: format!("fused Schur Cholesky failed at pivot {info}"),
3562            });
3563        }
3564        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3565        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3566        let delta_beta_host = stream
3567            .clone_dtoh(&rhs_dev)
3568            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3569        let delta_beta = Array1::from_vec(delta_beta_host.clone());
3570
3571        // ----- Layer E: launch back-sub kernel using persisted L, u, Y -----
3572        let mut delta_t_dev = stream
3573            .alloc_zeros::<f64>(n * p_max)
3574            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3575        let back_cfg = LaunchConfig {
3576            grid_dim: (plan.blocks, 1, 1),
3577            block_dim: (plan.threads_per_block, 1, 1),
3578            shared_mem_bytes: 0,
3579        };
3580        {
3581            let mut builder = stream.launch_builder(&back_sub);
3582            builder
3583                .arg(&l_out)
3584                .arg(&u_out)
3585                .arg(&y_out)
3586                .arg(&rhs_dev)
3587                .arg(&n_i32)
3588                .arg(&p_i32)
3589                .arg(&r_i32)
3590                .arg(&mut delta_t_dev);
3591            // SAFETY: kernel parameter list matches FORWARD_KERNEL_SOURCE
3592            // back-sub signature; `rhs_dev` holds δβ in the leading k entries
3593            // (R_TEMPLATE-strided indexing is column 0..k of the R-vector).
3594            unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3595        }
3596        stream
3597            .synchronize()
3598            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3599
3600        let delta_t_host = stream
3601            .clone_dtoh(&delta_t_dev)
3602            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3603        let mut delta_t = Array1::<f64>::zeros(n * d);
3604        for i in 0..n {
3605            let src_base = i * p_max;
3606            let dst_base = i * d;
3607            for r in 0..d {
3608                delta_t[dst_base + r] = delta_t_host[src_base + r];
3609            }
3610        }
3611
3612        // ----- log|H| = 2·Σ log L_{i,jj} + 2·Σ log R_{β,aa} -----
3613        let l_local_host = stream
3614            .clone_dtoh(&l_out)
3615            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3616        let l_schur_host = stream
3617            .clone_dtoh(&schur_dev)
3618            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3619        let mut log_det = 0.0_f64;
3620        for i in 0..n {
3621            let base = i * p_max * p_max;
3622            for j in 0..d {
3623                log_det += l_local_host[base + j * p_max + j].ln();
3624            }
3625        }
3626        for j in 0..k {
3627            log_det += l_schur_host[j * k + j].ln();
3628        }
3629        log_det *= 2.0;
3630
3631        Ok(ArrowSchurGpuSolution {
3632            delta_t,
3633            delta_beta,
3634            log_det_hessian: log_det,
3635        })
3636    }
3637
    /// Pre-compute `Y_i = L_i^{-1} H_tβ^(i)` via the fused forward kernel and
    /// return a closure that evaluates the full Schur matvec
    /// `S·x = (H_ββ + ρ·I)·x − Σ_i Y_i^T (Y_i·x)` for each PCG iteration.
    ///
    /// The `Y_i` factors are kept in a host-side buffer after one GPU forward
    /// pass. Each matvec call runs O(N·d·K) host loops over the pre-computed
    /// buffer plus an optional `H_ββ·x` call (matrix-free or dense). This is
    /// the first landing of the GPU matvec; a future iteration can move the
    /// `Y_i·x` / `Y_i^T z_i` steps to cuBLAS batched GEMV.
    ///
    /// `ridge_t` is handed to `pack_fused_host` and to the forward kernel as
    /// the local-block ridge; `ridge_beta` is the `ρ` added on the shared
    /// diagonal inside the returned closure.
    ///
    /// `H_ββ·x` source, in priority order: `sys.hbb_matvec` when present (its
    /// result then gets `ρ·x` added), else the dense `sys.hbb` when it is
    /// exactly `k × k`; otherwise the `H_ββ` term is omitted and the first
    /// part of the matvec is just `ρ·x`.
    ///
    /// # Errors
    /// * `Unavailable` — no fused launch plan, GPU routing declined, or any
    ///   CUDA context / stream / module / allocation / copy / launch failure.
    /// * `RidgeBumpRequired` — the forward kernel reported a non-zero status
    ///   for some row (the first such row is reported) with a bump sized by
    ///   `ridge_bump_to_make_pd` from that row's `H_tt`.
    ///
    /// # Panics
    /// The returned closure panics if `x` or `out` does not have length `k`.
    pub(super) fn build_schur_matvec_backend(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        // Launch geometry: per-row buffers below are strided by the plan's
        // `p_max` / `r_template`, while the kernel receives the true `d` / `k`.
        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_max = plan.p_max;
        let r_template = plan.r_template;

        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
        )
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let stream = ctx
            .new_stream()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // The compiled module is keyed on device capability and padded extents.
        let cap = &runtime.device.capability;
        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
            cc_major: cap.compute_major,
            cc_minor: cap.compute_minor,
            p_max: p_max as u32,
            r_template: r_template as u32,
        };
        let module = fused_module_for(&ctx, key)?;
        let forward = module
            .load_function("arrow_schur_forward_pgroup")
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Upload the packed per-row D / B / g inputs.
        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
        let d_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let b_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let g_dev = stream
            .clone_htod(&g_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // Kernel outputs. Only `y_out` and `status_dev` are read back here;
        // the others are still required by the kernel signature.
        let mut l_out = stream
            .alloc_zeros::<f64>(n * p_max * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut u_out = stream
            .alloc_zeros::<f64>(n * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut y_out = stream
            .alloc_zeros::<f64>(n * p_max * r_template)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_s = stream
            .alloc_zeros::<f64>(plan.partial_s_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_r = stream
            .alloc_zeros::<f64>(plan.partial_r_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut status_dev = stream
            .alloc_zeros::<i32>(n)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let cfg = LaunchConfig {
            grid_dim: (plan.blocks, 1, 1),
            block_dim: (plan.threads_per_block, 1, 1),
            shared_mem_bytes: 0,
        };
        // True (unpadded) extents handed to the kernel.
        let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ridge_arg = ridge_t;
        {
            let mut builder = stream.launch_builder(&forward);
            builder
                .arg(&d_dev)
                .arg(&b_dev)
                .arg(&g_dev)
                .arg(&n_i32)
                .arg(&p_i32)
                .arg(&r_i32)
                .arg(&ridge_arg)
                .arg(&mut l_out)
                .arg(&mut u_out)
                .arg(&mut y_out)
                .arg(&mut partial_s)
                .arg(&mut partial_r)
                .arg(&mut status_dev);
            // SAFETY: all buffers were allocated on `stream` with sizes
            // derived from `plan`; parameter list matches FORWARD_KERNEL_SOURCE.
            unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        }
        stream
            .synchronize()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Per-row factorization status: the first non-zero entry aborts.
        let status_host = stream
            .clone_dtoh(&status_dev)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        if let Some(row) = status_host.iter().position(|s| *s != 0) {
            // Status code is a 1-based pivot row index, not a magnitude; size
            // the bump from the block (Gershgorin λ_min bound) so a strongly
            // indefinite block recovers in one retry.
            return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
                row,
                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
            });
        }

        // Download Y_i factors: n × p_max × r_template column-major per block.
        let y_host = stream
            .clone_dtoh(&y_out)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Capture H_ββ data for the closure. Use the matrix-free hook if present
        // (SAE-manifold callers), otherwise fall back to the dense matrix rows.
        let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
        let hbb_is_kk = sys.hbb.dim() == (k, k);
        let hbb_matvec_opt = sys.hbb_matvec.clone();

        // The closure owns host copies only; no device state is captured.
        let closure: crate::arrow_schur::GpuSchurMatvec =
            Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
                assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
                assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");

                // (H_ββ + ρ·I)·x into out.
                if let Some(ref mv) = hbb_matvec_opt {
                    mv(x.view(), out);
                    for a in 0..k {
                        out[a] += ridge_beta * x[a];
                    }
                } else if hbb_is_kk {
                    // hbb_host row-major: hbb[a, b] = hbb_host[a * k + b].
                    for a in 0..k {
                        let mut acc = ridge_beta * x[a];
                        for b in 0..k {
                            acc += hbb_host[a * k + b] * x[b];
                        }
                        out[a] = acc;
                    }
                } else {
                    // No usable H_ββ: only the ridge term remains.
                    for a in 0..k {
                        out[a] = ridge_beta * x[a];
                    }
                }

                // out[c] -= Σ_i (Y_i^T (Y_i·x))[c].
                // Y_i column-major at y_host[i·p_max·r_template + col·p_max + row].
                // `z` holds Y_i·x (length d) and is overwritten for every row.
                let mut z = vec![0.0_f64; d];
                for i in 0..n {
                    let y_base = i * p_max * r_template;
                    for r in 0..d {
                        let mut acc = 0.0;
                        for c in 0..k {
                            acc += y_host[y_base + c * p_max + r] * x[c];
                        }
                        z[r] = acc;
                    }
                    for c in 0..k {
                        let mut acc = 0.0;
                        for r in 0..d {
                            acc += y_host[y_base + c * p_max + r] * z[r];
                        }
                        out[c] -= acc;
                    }
                }
            });

        Ok(closure)
    }
3817
3818    // ── #1017/#1026 frames-engaged device PCG ──────────────────────────────
3819
    /// Device-resident, flattened form of a framed SAE system, as built by
    /// `flatten_device_sae_frame_data`. Variable-size per-block payloads are
    /// packed back-to-back into a single data buffer and addressed through a
    /// CSR-style `*_ptr` array of prefix offsets (length `blocks + 1`,
    /// starting at 0).
    struct DeviceSaeFrameBuffers {
        // Smooth `λ S_k ⊗ I_{r_k}`.
        // Per smooth block: global β offset of the block.
        s_off: CudaSlice<i32>,
        // Per smooth block: side `m` of the square factor.
        s_m: CudaSlice<i32>,
        // Per smooth block: rank `r` (replication width of `⊗ I_r`).
        s_r: CudaSlice<i32>,
        // Prefix offsets into `s_data`.
        s_ptr: CudaSlice<i32>,
        // Concatenated `m × m` smooth factors, each row-major.
        s_data: CudaSlice<f64>,
        // Number of smooth blocks.
        s_blocks: usize,
        // Data `G_{ij} ⊗ W_{ij}`.
        // Per data block: border offsets of atoms i and j.
        g_off_i: CudaSlice<i32>,
        g_off_j: CudaSlice<i32>,
        // Per data block: ranks of atoms i and j (the shape of `w`).
        g_ri: CudaSlice<i32>,
        g_rj: CudaSlice<i32>,
        // Per data block: shape `m_i × m_j` of `g`.
        g_mi: CudaSlice<i32>,
        g_mj: CudaSlice<i32>,
        // Prefix offsets into / concatenated row-major `g` blocks.
        g_ptr: CudaSlice<i32>,
        g_data: CudaSlice<f64>,
        // Prefix offsets into / concatenated row-major `w` blocks.
        w_ptr: CudaSlice<i32>,
        w_data: CudaSlice<f64>,
        // Number of data blocks.
        g_blocks: usize,
        // Largest `m_i · r_i` over data blocks (x-extent of the g launch).
        g_max_work: usize,
        // Per-row dense cross-block H_tβ^(i) + row q + factored ainv.
        // Prefix offsets into `htb`.
        htb_ptr: CudaSlice<i32>,
        // Concatenated `q_i × k` row-major cross-block slabs (none for q_i = 0).
        htb: CudaSlice<f64>,
        // Per row: effective width `q_i` (0 = row adds no reduced-Schur term).
        q_of: CudaSlice<i32>,
        // Per row: dense `(H_tt + ridge_t·I)^{-1}`, row-major at stride
        // `max_q`, `max_q²` slots per row.
        ainv: CudaSlice<f64>,
        // Scratch, `n_rows × max_q`: written by the apply_h kernel from `x`.
        hvec: CudaSlice<f64>,
        // Scratch, `n_rows × max_q`: written by the apply_ainv kernel from `hvec`.
        svec: CudaSlice<f64>,
        // Number of rows in the system.
        n_rows: usize,
        // Width of the shared β vector.
        k: usize,
        // Largest `q_i` over rows, forced to at least 1.
        max_q: usize,
    }
3852
3853    fn flatten_device_sae_frame_data(
3854        sys: &ArrowSchurSystem,
3855        data: &DeviceSaePcgData,
3856        frame: &DeviceSaeFrameData,
3857        ridge_t: f64,
3858        stream: &Arc<CudaStream>,
3859    ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3860        let n_rows = sys.rows.len();
3861        let k = data.beta_dim;
3862        if frame.row_htbeta.len() != n_rows
3863            || frame.ranks.len() != frame.basis_sizes.len()
3864            || frame.border_offsets.len() != frame.ranks.len()
3865            || data.smooth_blocks.len() != frame.smooth_ranks.len()
3866        {
3867            return Err(ArrowSchurGpuFailure::Unavailable);
3868        }
3869
3870        // Smooth blocks.
3871        let mut s_off = Vec::new();
3872        let mut s_m = Vec::new();
3873        let mut s_r = Vec::new();
3874        let mut s_ptr = vec![0_i32];
3875        let mut s_data = Vec::<f64>::new();
3876        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3877            let (m, mc) = blk.factor_a.dim();
3878            if m != mc {
3879                return Err(ArrowSchurGpuFailure::Unavailable);
3880            }
3881            s_off.push(checked_i32(blk.global_offset)?);
3882            s_m.push(checked_i32(m)?);
3883            s_r.push(checked_i32(r)?);
3884            for ri in 0..m {
3885                for ci in 0..m {
3886                    s_data.push(blk.factor_a[[ri, ci]]);
3887                }
3888            }
3889            s_ptr.push(checked_i32(s_data.len())?);
3890        }
3891
3892        // Data blocks (g + w).
3893        let mut g_off_i = Vec::new();
3894        let mut g_off_j = Vec::new();
3895        let mut g_ri = Vec::new();
3896        let mut g_rj = Vec::new();
3897        let mut g_mi = Vec::new();
3898        let mut g_mj = Vec::new();
3899        let mut g_ptr = vec![0_i32];
3900        let mut g_data = Vec::<f64>::new();
3901        let mut w_ptr = vec![0_i32];
3902        let mut w_data = Vec::<f64>::new();
3903        let mut g_max_work = 0usize;
3904        for blk in &frame.frame_blocks {
3905            let ri = frame.ranks[blk.atom_i];
3906            let rj = frame.ranks[blk.atom_j];
3907            let (mi, mj) = blk.g.dim();
3908            if blk.w.dim() != (ri, rj) {
3909                return Err(ArrowSchurGpuFailure::Unavailable);
3910            }
3911            g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3912            g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3913            g_ri.push(checked_i32(ri)?);
3914            g_rj.push(checked_i32(rj)?);
3915            g_mi.push(checked_i32(mi)?);
3916            g_mj.push(checked_i32(mj)?);
3917            for r in 0..mi {
3918                for c in 0..mj {
3919                    g_data.push(blk.g[[r, c]]);
3920                }
3921            }
3922            g_ptr.push(checked_i32(g_data.len())?);
3923            for a in 0..ri {
3924                for b in 0..rj {
3925                    w_data.push(blk.w[[a, b]]);
3926                }
3927            }
3928            w_ptr.push(checked_i32(w_data.len())?);
3929            g_max_work = g_max_work.max(mi * ri);
3930        }
3931
3932        // Per-row dense cross-block + q + ainv (factored H_tt⁻¹).
3933        let mut htb_ptr = vec![0_i32];
3934        let mut htb = Vec::<f64>::new();
3935        let mut q_of = Vec::<i32>::with_capacity(n_rows);
3936        let mut max_q = 0usize;
3937        for (i, slab) in frame.row_htbeta.iter().enumerate() {
3938            let qi = sys.row_dims[i];
3939            // A populated slab must be q_i × k row-major; an empty slab ⇒ q=0
3940            // (the row contributes no reduced-Schur term).
3941            let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3942                qi
3943            } else {
3944                0
3945            };
3946            q_of.push(checked_i32(q_eff)?);
3947            max_q = max_q.max(q_eff);
3948            if q_eff > 0 {
3949                htb.extend_from_slice(slab);
3950            }
3951            htb_ptr.push(checked_i32(htb.len())?);
3952        }
3953        if max_q == 0 {
3954            // No row contributes a reduced term — the system is pure-penalty.
3955            // Still valid; give max_q=1 so the ainv buffer is non-empty.
3956            max_q = 1;
3957        }
3958
3959        let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3960        for (i, row) in sys.rows.iter().enumerate() {
3961            let q = q_of[i] as usize;
3962            if q == 0 {
3963                continue;
3964            }
3965            if row.htt.dim() != (q, q) {
3966                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3967                    reason: format!(
3968                        "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3969                        row.htt.dim()
3970                    ),
3971                });
3972            }
3973            let mut block = row.htt.clone();
3974            for d in 0..q {
3975                block[[d, d]] += ridge_t;
3976            }
3977            let factor = gam_linalg::triangular::cholesky_factor_in_place(
3978                block.view(),
3979                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3980            )
3981            .ok_or_else(|| {
3982                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
3983                // indefinite per-row block recovers in one outer-loop retry.
3984                ArrowSchurGpuFailure::RidgeBumpRequired {
3985                    row: i,
3986                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3987                }
3988            })?;
3989            for col in 0..q {
3990                let mut e = Array1::<f64>::zeros(q);
3991                e[col] = 1.0;
3992                let solved =
3993                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3994                for r in 0..q {
3995                    ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3996                }
3997            }
3998        }
3999
4000        let htod_i = |v: &[i32]| {
4001            stream
4002                .clone_htod(v)
4003                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4004        };
4005        let htod_f = |v: &[f64]| {
4006            stream
4007                .clone_htod(v)
4008                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4009        };
4010        Ok(DeviceSaeFrameBuffers {
4011            s_off: htod_i(&s_off)?,
4012            s_m: htod_i(&s_m)?,
4013            s_r: htod_i(&s_r)?,
4014            s_ptr: htod_i(&s_ptr)?,
4015            s_data: htod_f(&s_data)?,
4016            s_blocks: data.smooth_blocks.len(),
4017            g_off_i: htod_i(&g_off_i)?,
4018            g_off_j: htod_i(&g_off_j)?,
4019            g_ri: htod_i(&g_ri)?,
4020            g_rj: htod_i(&g_rj)?,
4021            g_mi: htod_i(&g_mi)?,
4022            g_mj: htod_i(&g_mj)?,
4023            g_ptr: htod_i(&g_ptr)?,
4024            g_data: htod_f(&g_data)?,
4025            w_ptr: htod_i(&w_ptr)?,
4026            w_data: htod_f(&w_data)?,
4027            g_blocks: frame.frame_blocks.len(),
4028            g_max_work,
4029            htb_ptr: htod_i(&htb_ptr)?,
4030            htb: htod_f(&htb)?,
4031            q_of: htod_i(&q_of)?,
4032            ainv: htod_f(&ainv)?,
4033            hvec: stream
4034                .alloc_zeros::<f64>(n_rows * max_q)
4035                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4036            svec: stream
4037                .alloc_zeros::<f64>(n_rows * max_q)
4038                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4039            n_rows,
4040            k,
4041            max_q,
4042        })
4043    }
4044
4045    fn sae_frame_penalty_diag_host(
4046        data: &DeviceSaePcgData,
4047        frame: &DeviceSaeFrameData,
4048        ridge_beta: f64,
4049    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4050        let mut diag = vec![ridge_beta; data.beta_dim];
4051        // Smooth: diag[off + ia·r + ib] += S[ia,ia].
4052        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4053            let m = blk.factor_a.nrows();
4054            for ia in 0..m {
4055                let coeff = blk.factor_a[[ia, ia]];
4056                let base = blk.global_offset + ia * r;
4057                for ib in 0..r {
4058                    if base + ib >= diag.len() {
4059                        return Err(ArrowSchurGpuFailure::Unavailable);
4060                    }
4061                    diag[base + ib] += coeff;
4062                }
4063            }
4064        }
4065        // Data: on-diagonal atom blocks contribute g[li,li]·w[a,a].
4066        for blk in &frame.frame_blocks {
4067            if blk.atom_i != blk.atom_j {
4068                continue;
4069            }
4070            let r = frame.ranks[blk.atom_i];
4071            let off = frame.border_offsets[blk.atom_i];
4072            let (mi, mj) = blk.g.dim();
4073            for li in 0..mi.min(mj) {
4074                let gii = blk.g[[li, li]];
4075                let base = off + li * r;
4076                for a in 0..r {
4077                    if base + a >= diag.len() {
4078                        return Err(ArrowSchurGpuFailure::Unavailable);
4079                    }
4080                    diag[base + a] += gii * blk.w[[a, a]];
4081                }
4082            }
4083        }
4084        Ok(diag)
4085    }
4086
4087    fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4088        Ok(LaunchConfig {
4089            grid_dim: (
4090                ((work as u32).saturating_add(255) / 256).max(1),
4091                checked_i32(n_rows)? as u32,
4092                1,
4093            ),
4094            block_dim: (256, 1, 1),
4095            shared_mem_bytes: 0,
4096        })
4097    }
4098
    /// Enqueue the framed reduced-Schur matvec `out = S·x` on `stream`.
    ///
    /// Stages, launched in order on `stream` (this function does not
    /// synchronise; callers read `out` through the same stream):
    /// 1. `launch_sae_init` seeds `out` from `x` and `ridge_beta`
    ///    (NOTE(review): presumably `out = ridge_beta·x` — confirm against the
    ///    kernel source);
    /// 2. smooth-penalty blocks write into `out` (skipped when `s_blocks == 0`);
    /// 3. data `G ⊗ W` blocks write into `out` (skipped when `g_blocks == 0`);
    /// 4. the reduced-Schur subtraction: `apply_h` fills `buffers.hvec` from
    ///    `x` and the per-row cross-blocks, `apply_ainv` fills `buffers.svec`
    ///    from `ainv` and `hvec`, and `scatter_h` accumulates the result into
    ///    `out` atomically.
    ///
    /// `buffers` is `&mut` only because `hvec` / `svec` are scratch that stage 4
    /// overwrites.
    ///
    /// # Errors
    /// `Unavailable` when a kernel cannot be loaded or launched. Errors from
    /// `launch_sae_init`, `checked_i32` and `frame_grid` are propagated.
    /// Earlier stages may already be enqueued when a later one fails.
    fn launch_sae_frame_matvec(
        stream: &Arc<CudaStream>,
        module: &Arc<CudaModule>,
        buffers: &mut DeviceSaeFrameBuffers,
        x: &CudaSlice<f64>,
        out: &mut CudaSlice<f64>,
        ridge_beta: f64,
    ) -> Result<(), ArrowSchurGpuFailure> {
        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
        // Smooth penalty.
        if buffers.s_blocks > 0 {
            let kernel = module
                .load_function("arrow_sae_frame_smooth_matvec")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let blocks_i32 = checked_i32(buffers.s_blocks)?;
            let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&mut *out)
                .arg(&buffers.s_off)
                .arg(&buffers.s_m)
                .arg(&buffers.s_r)
                .arg(&buffers.s_ptr)
                .arg(&buffers.s_data)
                .arg(&blocks_i32);
            // SAFETY: smooth block metadata/data are live device buffers; the grid
            // covers (k channels × n_blocks) and the kernel bounds-checks m·r.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        // Data penalty.
        if buffers.g_blocks > 0 {
            let kernel = module
                .load_function("arrow_sae_frame_g_matvec")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let blocks_i32 = checked_i32(buffers.g_blocks)?;
            let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&mut *out)
                .arg(&buffers.g_off_i)
                .arg(&buffers.g_off_j)
                .arg(&buffers.g_ri)
                .arg(&buffers.g_rj)
                .arg(&buffers.g_mi)
                .arg(&buffers.g_mj)
                .arg(&buffers.g_ptr)
                .arg(&buffers.g_data)
                .arg(&buffers.w_ptr)
                .arg(&buffers.w_data)
                .arg(&blocks_i32);
            // SAFETY: g/w block metadata/data are live device buffers; the grid
            // covers (max m_i·r_i × n_blocks) and the kernel bounds-checks.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        // Reduced-Schur subtraction via dense per-row cross-blocks.
        let k_i32 = checked_i32(buffers.k)?;
        let max_q_i32 = checked_i32(buffers.max_q)?;
        let n_rows_i32 = checked_i32(buffers.n_rows)?;
        // Each stage lives in its own scope so the builder's borrow of the
        // scratch buffer it writes ends before the next stage reads it.
        {
            // hvec ← per-row H_tβ^(i) applied to x.
            let kernel = module
                .load_function("arrow_sae_frame_apply_h")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(x)
                .arg(&buffers.htb_ptr)
                .arg(&buffers.htb)
                .arg(&buffers.q_of)
                .arg(&mut buffers.hvec)
                .arg(&k_i32)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: dense cross-block + pointers + hvec are live buffers sized
            // for (n_rows × max_q) / (n_rows × k); kernel guards q_i and k.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        {
            // svec ← per-row ainv applied to hvec.
            let kernel = module
                .load_function("arrow_sae_frame_apply_ainv")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(&buffers.ainv)
                .arg(&buffers.hvec)
                .arg(&buffers.q_of)
                .arg(&mut buffers.svec)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: ainv/hvec/svec are live buffers sized for n_rows·max_q²
            // and n_rows·max_q; the kernel guards row/coord bounds.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        {
            // out ← out combined with the per-row cross-block scatter of svec.
            let kernel = module
                .load_function("arrow_sae_frame_scatter_h")
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let cfg = frame_grid(buffers.k, buffers.n_rows)?;
            let mut b = stream.launch_builder(&kernel);
            b.arg(&buffers.svec)
                .arg(&buffers.htb_ptr)
                .arg(&buffers.htb)
                .arg(&buffers.q_of)
                .arg(out)
                .arg(&k_i32)
                .arg(&max_q_i32)
                .arg(&n_rows_i32);
            // SAFETY: svec/cross-block/out are live buffers; the kernel atomically
            // accumulates into out[a] for a<k and reads c<q_i.
            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        }
        Ok(())
    }
4211
4212    fn launch_sae_frame_diag_sub(
4213        stream: &Arc<CudaStream>,
4214        module: &Arc<CudaModule>,
4215        buffers: &DeviceSaeFrameBuffers,
4216        diag: &mut CudaSlice<f64>,
4217    ) -> Result<(), ArrowSchurGpuFailure> {
4218        let kernel = module
4219            .load_function("arrow_sae_frame_diag_sub")
4220            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4221        let k_i32 = checked_i32(buffers.k)?;
4222        let max_q_i32 = checked_i32(buffers.max_q)?;
4223        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4224        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4225        let mut b = stream.launch_builder(&kernel);
4226        b.arg(diag)
4227            .arg(&buffers.ainv)
4228            .arg(&buffers.htb_ptr)
4229            .arg(&buffers.htb)
4230            .arg(&buffers.q_of)
4231            .arg(&k_i32)
4232            .arg(&max_q_i32)
4233            .arg(&n_rows_i32);
4234        // SAFETY: diag + cross-block + ainv live buffers; kernel guards a<k, c/d<q.
4235        unsafe { b.launch(cfg) }
4236            .map(drop)
4237            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4238    }
4239
4240    /// #1551 kernel-isolating seam: evaluate the framed reduced-Schur matvec
4241    /// `out = S·x` EXACTLY ONCE on the device (no PCG, no offload-floor gate) and
4242    /// return `out`. This is the parity probe the test harness diffs against the
4243    /// CPU oracle [`super::sae_framed_schur_matvec_cpu`] element-by-element, so a
4244    /// kernel/marshalling defect is exposed directly — independent of how the
4245    /// iterative solver behaves on an ill-conditioned assembled `S` (where dense
4246    /// Cholesky and PCG legitimately disagree at the solution level). Declines
4247    /// (`Unavailable`) only when CUDA is genuinely absent so the test skips
4248    /// cleanly off-device; it deliberately does NOT consult the offload policy so
4249    /// even a tiny verifiable fixture runs on the GPU.
4250    pub(super) fn framed_schur_matvec_once_on_device(
4251        sys: &ArrowSchurSystem,
4252        data: &DeviceSaePcgData,
4253        ridge_t: f64,
4254        ridge_beta: f64,
4255        x: &Array1<f64>,
4256    ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4257        let k = x.len();
4258        if k == 0 || data.beta_dim != k || sys.k != k {
4259            return Err(ArrowSchurGpuFailure::Unavailable);
4260        }
4261        let frame = data
4262            .frame
4263            .as_ref()
4264            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4265        // No offload-policy filter here: the seam exists to validate the kernel on
4266        // ANY device, including the smallest hand-checkable fixture.
4267        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4268            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4269        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4270            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4271        let stream = ctx
4272            .new_stream()
4273            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4274        let vector_module = pcg_vector_module(&ctx)?;
4275        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4276        let x_dev = stream
4277            .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4278            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4279        let mut out_dev = stream
4280            .alloc_zeros::<f64>(k)
4281            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4282        launch_sae_frame_matvec(
4283            &stream,
4284            vector_module,
4285            &mut buffers,
4286            &x_dev,
4287            &mut out_dev,
4288            ridge_beta,
4289        )?;
4290        let out = stream
4291            .clone_dtoh(&out_dev)
4292            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4293        Ok(Array1::from_vec(out))
4294    }
4295
4296    pub(super) fn solve_sae_matrix_free_pcg_framed(
4297        sys: &ArrowSchurSystem,
4298        data: &DeviceSaePcgData,
4299        ridge_t: f64,
4300        ridge_beta: f64,
4301        rhs_beta: &Array1<f64>,
4302        max_iterations: usize,
4303        relative_tolerance: f64,
4304    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4305        let k = rhs_beta.len();
4306        if k == 0 || data.beta_dim != k || sys.k != k {
4307            return Err(ArrowSchurGpuFailure::Unavailable);
4308        }
4309        let frame = data
4310            .frame
4311            .as_ref()
4312            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4313        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4314            .filter(|rt| {
4315                rt.policy().reduced_schur_matvec_should_offload(
4316                    sys.rows.len(),
4317                    sys.k,
4318                    sys.d,
4319                    max_iterations,
4320                )
4321            })
4322            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4323        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4324            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4325        let stream = ctx
4326            .new_stream()
4327            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4328        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4329        let vector_module = pcg_vector_module(&ctx)?;
4330        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4331
4332        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4333        if rhs_norm == 0.0 {
4334            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4335        }
4336        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4337        let rhs_dev = stream
4338            .clone_htod(
4339                rhs_beta
4340                    .as_slice()
4341                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4342            )
4343            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4344        let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4345        let mut diag_dev = stream
4346            .clone_htod(&diag_host)
4347            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4348        launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4349        let diag_host = stream
4350            .clone_dtoh(&diag_dev)
4351            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4352        let mut inv_diag = Vec::with_capacity(k);
4353        for (idx, &d) in diag_host.iter().enumerate() {
4354            if !d.is_finite() || d <= 1.0e-18 {
4355                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4356                    reason: format!(
4357                        "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4358                    ),
4359                });
4360            }
4361            inv_diag.push(1.0 / d);
4362        }
4363        let inv_diag_dev = stream
4364            .clone_htod(&inv_diag)
4365            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4366
4367        let mut x_dev = stream
4368            .alloc_zeros::<f64>(k)
4369            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4370        let mut r_dev = stream
4371            .alloc_zeros::<f64>(k)
4372            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4373        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4374        let mut z_dev = stream
4375            .alloc_zeros::<f64>(k)
4376            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4377        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4378        let mut p_dev = stream
4379            .alloc_zeros::<f64>(k)
4380            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4381        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4382        let mut ap_dev = stream
4383            .alloc_zeros::<f64>(k)
4384            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4385
4386        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4387        if rz <= 0.0 || !rz.is_finite() {
4388            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4389                reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4390            });
4391        }
4392        let mut diag = PcgDiagnostics {
4393            precond_apply_calls: 1,
4394            stopping_reason: PcgStopReason::MaxIter,
4395            ..PcgDiagnostics::default()
4396        };
4397        for _ in 0..max_iterations.max(1) {
4398            launch_sae_frame_matvec(
4399                &stream,
4400                vector_module,
4401                &mut buffers,
4402                &p_dev,
4403                &mut ap_dev,
4404                ridge_beta,
4405            )?;
4406            diag.matvec_calls += 1;
4407            diag.iterations += 1;
4408            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4409            if pap <= 0.0 || !pap.is_finite() {
4410                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4411                    reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4412                });
4413            }
4414            let alpha = rz / pap;
4415            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4416            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4417            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4418            if r_norm <= tol {
4419                diag.final_relative_residual = r_norm / rhs_norm;
4420                diag.stopping_reason = PcgStopReason::Converged;
4421                break;
4422            }
4423            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4424            diag.precond_apply_calls += 1;
4425            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4426            if rz_new <= 0.0 || !rz_new.is_finite() {
4427                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4428                    reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4429                });
4430            }
4431            let beta = rz_new / rz;
4432            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4433            rz = rz_new;
4434        }
4435        if diag.stopping_reason != PcgStopReason::Converged {
4436            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4437            diag.final_relative_residual = r_norm / rhs_norm;
4438            diag.stopping_reason = PcgStopReason::MaxIter;
4439        }
4440        let x = stream
4441            .clone_dtoh(&x_dev)
4442            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4443        Ok((Array1::from_vec(x), diag))
4444    }
4445
4446    /// #1551 stage-isolating triage seam: run the framed reduced-Schur matvec
4447    /// `out = S·x` ONCE on the device (no PCG, no offload-floor gate) and return
4448    /// `out`, so a tiny hand-verifiable fixture can diff it against the CPU oracle
4449    /// `sae_framed_schur_matvec_cpu` element-by-element to localize the structural
4450    /// divergence to a single kernel stage. Returns `Unavailable` only when CUDA
4451    /// is genuinely absent (so the test skips cleanly off-device).
4452    pub(super) fn solve_sae_matrix_free_pcg(
4453        sys: &ArrowSchurSystem,
4454        data: &DeviceSaePcgData,
4455        ridge_t: f64,
4456        ridge_beta: f64,
4457        rhs_beta: &Array1<f64>,
4458        max_iterations: usize,
4459        relative_tolerance: f64,
4460    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4461        let k = rhs_beta.len();
4462        if k == 0 || data.beta_dim != k || sys.k != k {
4463            return Err(ArrowSchurGpuFailure::Unavailable);
4464        }
4465        // #1017/#1026 GUARD: the legacy `⊗ I_p` kernel must NEVER receive framed
4466        // data (factored `G ⊗ W_{ij}` + dense per-row cross blocks); decline so a
4467        // mis-route falls back to the CPU rather than returning a wrong step.
4468        if data.frame.is_some() {
4469            return Err(ArrowSchurGpuFailure::Unavailable);
4470        }
4471        // #1017 Phase-1 dispatch re-key: this is the matrix-free SAE reduced-Schur
4472        // PCG — the production hot path, not a single dense factorization. The
4473        // dense-Direct floor `dense_hessian_work_target_is_gpu(n, k)` keys on
4474        // `2·n·k²` and is the WRONG gate here: it ignores the per-row frame depth
4475        // `d` (the M dimension that multiplies the per-apply work) and the
4476        // `1/cg_iters` staging amortisation, so it both undercounts the SAE batched
4477        // work `n·k·d` and applies a cold single-launch breakeven to an apply that
4478        // reuses device-resident frames `max_iterations` times. Key instead on the
4479        // CG-amortised total batched work — the same predicate the host injection
4480        // gate (`maybe_inject_gpu_schur_matvec`) consults — so few-row/wide-`k`/
4481        // modest-`d` LLM shapes register the real `n × k × d × cg_iters` arithmetic.
4482        // Kernels and numerics are untouched; only where the matvec runs changes,
4483        // and the host falls back to the bit-identical CPU matvec when this declines.
4484        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4485            .filter(|rt| {
4486                rt.policy().reduced_schur_matvec_should_offload(
4487                    sys.rows.len(),
4488                    sys.k,
4489                    sys.d,
4490                    max_iterations,
4491                )
4492            })
4493            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4494        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4495            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4496        let stream = ctx
4497            .new_stream()
4498            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4499        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4500        let vector_module = pcg_vector_module(&ctx)?;
4501        let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4502
4503        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4504        if rhs_norm == 0.0 {
4505            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4506        }
4507        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4508        let rhs_dev = stream
4509            .clone_htod(
4510                rhs_beta
4511                    .as_slice()
4512                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4513            )
4514            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4515        let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4516        let mut diag_dev = stream
4517            .clone_htod(&diag_host)
4518            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4519        launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4520        let diag_host = stream
4521            .clone_dtoh(&diag_dev)
4522            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4523        let mut inv_diag = Vec::with_capacity(k);
4524        for (idx, &d) in diag_host.iter().enumerate() {
4525            if !d.is_finite() || d <= 1.0e-18 {
4526                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4527                    reason: format!(
4528                        "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4529                    ),
4530                });
4531            }
4532            inv_diag.push(1.0 / d);
4533        }
4534        let inv_diag_dev = stream
4535            .clone_htod(&inv_diag)
4536            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4537
4538        let mut x_dev = stream
4539            .alloc_zeros::<f64>(k)
4540            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4541        let mut r_dev = stream
4542            .alloc_zeros::<f64>(k)
4543            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4545        let mut z_dev = stream
4546            .alloc_zeros::<f64>(k)
4547            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4548        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4549        let mut p_dev = stream
4550            .alloc_zeros::<f64>(k)
4551            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4552        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4553        let mut ap_dev = stream
4554            .alloc_zeros::<f64>(k)
4555            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4556
4557        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4558        if rz <= 0.0 || !rz.is_finite() {
4559            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4560                reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4561            });
4562        }
4563        let mut diag = PcgDiagnostics {
4564            precond_apply_calls: 1,
4565            stopping_reason: PcgStopReason::MaxIter,
4566            ..PcgDiagnostics::default()
4567        };
4568
4569        for _ in 0..max_iterations.max(1) {
4570            launch_sae_matvec(
4571                &stream,
4572                vector_module,
4573                &mut buffers,
4574                &p_dev,
4575                &mut ap_dev,
4576                ridge_beta,
4577            )?;
4578            diag.matvec_calls += 1;
4579            diag.iterations += 1;
4580            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4581            if pap <= 0.0 || !pap.is_finite() {
4582                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4583                    reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4584                });
4585            }
4586            let alpha = rz / pap;
4587            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4588            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4589            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4590            if r_norm <= tol {
4591                diag.final_relative_residual = r_norm / rhs_norm;
4592                diag.stopping_reason = PcgStopReason::Converged;
4593                break;
4594            }
4595            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4596            diag.precond_apply_calls += 1;
4597            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4598            if rz_new <= 0.0 || !rz_new.is_finite() {
4599                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4600                    reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4601                });
4602            }
4603            let beta = rz_new / rz;
4604            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4605            rz = rz_new;
4606        }
4607        if diag.stopping_reason != PcgStopReason::Converged {
4608            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4609            diag.final_relative_residual = r_norm / rhs_norm;
4610            diag.stopping_reason = PcgStopReason::MaxIter;
4611        }
4612        let x = stream
4613            .clone_dtoh(&x_dev)
4614            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4615        Ok((Array1::from_vec(x), diag))
4616    }
4617
4618    pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4619        s_acc: &ndarray::Array2<f64>,
4620        rhs_beta: &Array1<f64>,
4621        max_iterations: usize,
4622        relative_tolerance: f64,
4623    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4624        let k = rhs_beta.len();
4625        // #1017 dispatch re-key: this is an ITERATIVE device-resident PCG, not a
4626        // single GEMV. `S` (k×k) is uploaded once and reused for `max_iterations`
4627        // `S·p` GEMVs while only convergence scalars cross PCIe, so the staging
4628        // cost is amortised over the whole CG solve. Gating on the flops of ONE
4629        // `Gemv{k,k}` (`2·k²`) understates the work by the iteration count and
4630        // declines shapes (e.g. k≈512) whose total iterated arithmetic
4631        // `2·k²·iters` clears the device floor by orders of magnitude — the same
4632        // single-launch-breakeven miskey #1017 fixed for the framed reduced-Schur
4633        // matvec. Key on the CG-amortised total work via a `Gemm{k,k,iters}` whose
4634        // `flops()` is exactly `2·k²·iters`; numerics and kernels are untouched,
4635        // and the host falls back to the bit-identical CPU PCG when this declines.
4636        let cg_iters = max_iterations.max(1);
4637        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4638            gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4639                m: k,
4640                n: k,
4641                k: cg_iters,
4642            },
4643        )
4644        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4645        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4646            .and_then(|ctx| ctx.new_stream().ok())
4647            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4648        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4649        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4650            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4651        let vector_module = pcg_vector_module(&ctx)?;
4652
4653        // Jacobi diagonal from S; must be strictly positive for SPD.
4654        let mut inv_diag = vec![0.0_f64; k];
4655        for j in 0..k {
4656            let djj = s_acc[[j, j]];
4657            if !(djj > 0.0) {
4658                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4659                    reason: format!(
4660                        "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4661                    ),
4662                });
4663            }
4664            inv_diag[j] = 1.0 / djj;
4665        }
4666
4667        // Upload S column-major (S[row,col] at col*k + row).
4668        let mut s_host = vec![0.0_f64; k * k];
4669        for col in 0..k {
4670            for row in 0..k {
4671                s_host[col * k + row] = s_acc[[row, col]];
4672            }
4673        }
4674        let s_dev = stream
4675            .clone_htod(&s_host)
4676            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4677
4678        // Steihaug truncated-CG with Jacobi preconditioner, host scalar
4679        // recurrences and a device `S·p` matvec. The streaming reduced solve
4680        // uses an unbounded trust region (pure CG to tolerance).
4681        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4682        if rhs_norm == 0.0 {
4683            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4684        }
4685        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4686
4687        // Device-resident PCG state. Only convergence scalars cross back during
4688        // the loop; x/r/z/p/Sp stay on CUDA until the final solution download.
4689        let mut x_dev = stream
4690            .alloc_zeros::<f64>(k)
4691            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4692        let mut r_dev = stream
4693            .clone_htod(
4694                rhs_beta
4695                    .as_slice()
4696                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4697            )
4698            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4699        let inv_diag_dev = stream
4700            .clone_htod(&inv_diag)
4701            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702        let mut z_dev = stream
4703            .alloc_zeros::<f64>(k)
4704            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4705        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4706        let mut p_dev = stream
4707            .alloc_zeros::<f64>(k)
4708            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4709        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4710        let mut sp_dev = stream
4711            .alloc_zeros::<f64>(k)
4712            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4713        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4714        let mut diag = PcgDiagnostics {
4715            precond_apply_calls: 1,
4716            stopping_reason: PcgStopReason::MaxIter,
4717            ..PcgDiagnostics::default()
4718        };
4719        if rz <= 0.0 || !rz.is_finite() {
4720            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4721                reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4722            });
4723        }
4724
4725        let max_iters = max_iterations.max(1);
4726        for _ in 0..max_iters {
4727            // sp = S · p (device GEMV, S column-major k×k, op = N).
4728            let gemv_cfg = GemvConfig::<f64> {
4729                trans: cublasOperation_t::CUBLAS_OP_N,
4730                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4731                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4732                alpha: 1.0,
4733                lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4734                incx: 1,
4735                beta: 0.0,
4736                incy: 1,
4737            };
4738            // SAFETY: s_dev is k×k column-major, p_dev / sp_dev length k.
4739            unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4740                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4741            diag.matvec_calls += 1;
4742            diag.iterations += 1;
4743
4744            let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4745            if !(p_sp > 0.0) {
4746                // Non-positive curvature on a (proximal-ridged) SPD system means
4747                // numerical breakdown; surface so the caller escalates.
4748                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4749                    reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4750                });
4751            }
4752            let alpha = rz / p_sp;
4753            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4754            device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4755            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4756            if r_norm <= tol {
4757                diag.final_relative_residual = r_norm / rhs_norm;
4758                diag.stopping_reason = PcgStopReason::Converged;
4759                break;
4760            }
4761            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4762            diag.precond_apply_calls += 1;
4763            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4764            if rz_new <= 0.0 || !rz_new.is_finite() {
4765                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4766                    reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4767                });
4768            }
4769            let beta = rz_new / rz;
4770            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4771            rz = rz_new;
4772        }
4773        if diag.stopping_reason != PcgStopReason::Converged {
4774            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4775            diag.final_relative_residual = r_norm / rhs_norm;
4776            diag.stopping_reason = PcgStopReason::MaxIter;
4777        }
4778
4779        let x = stream
4780            .clone_dtoh(&x_dev)
4781            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4782        Ok((Array1::from_vec(x), diag))
4783    }
4784
4785    fn device_copy(
4786        blas: &CudaBlas,
4787        stream: &Arc<CudaStream>,
4788        n: usize,
4789        src: &CudaSlice<f64>,
4790        dst: &mut CudaSlice<f64>,
4791    ) -> Result<(), ArrowSchurGpuFailure> {
4792        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4793        let (src_ptr, _src_rec) = src.device_ptr(stream);
4794        let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4795        // SAFETY: src and dst are live device allocations on this stream with at
4796        // least n contiguous f64 entries and unit stride.
4797        let status = unsafe {
4798            cudarc::cublas::sys::cublasDcopy_v2(
4799                *blas.handle(),
4800                n_i,
4801                src_ptr as *const f64,
4802                1,
4803                dst_ptr as *mut f64,
4804                1,
4805            )
4806        };
4807        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4808            Ok(())
4809        } else {
4810            Err(ArrowSchurGpuFailure::Unavailable)
4811        }
4812    }
4813
4814    fn device_axpy(
4815        blas: &CudaBlas,
4816        stream: &Arc<CudaStream>,
4817        n: usize,
4818        alpha: f64,
4819        x: &CudaSlice<f64>,
4820        y: &mut CudaSlice<f64>,
4821    ) -> Result<(), ArrowSchurGpuFailure> {
4822        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4823        let (x_ptr, _x_rec) = x.device_ptr(stream);
4824        let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4825        // SAFETY: x and y are live device allocations on this stream with at
4826        // least n contiguous f64 entries and unit stride; cuBLAS only reads alpha.
4827        let status = unsafe {
4828            cudarc::cublas::sys::cublasDaxpy_v2(
4829                *blas.handle(),
4830                n_i,
4831                &alpha,
4832                x_ptr as *const f64,
4833                1,
4834                y_ptr as *mut f64,
4835                1,
4836            )
4837        };
4838        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4839            Ok(())
4840        } else {
4841            Err(ArrowSchurGpuFailure::Unavailable)
4842        }
4843    }
4844
4845    fn device_dot(
4846        blas: &CudaBlas,
4847        stream: &Arc<CudaStream>,
4848        n: usize,
4849        x: &CudaSlice<f64>,
4850        y: &CudaSlice<f64>,
4851    ) -> Result<f64, ArrowSchurGpuFailure> {
4852        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4853        let (x_ptr, _x_rec) = x.device_ptr(stream);
4854        let (y_ptr, _y_rec) = y.device_ptr(stream);
4855        let mut result = 0.0_f64;
4856        // SAFETY: x and y are live device allocations on this stream with at
4857        // least n contiguous f64 entries and unit stride; result is a valid host
4858        // out-pointer for the cuBLAS scalar.
4859        let status = unsafe {
4860            cudarc::cublas::sys::cublasDdot_v2(
4861                *blas.handle(),
4862                n_i,
4863                x_ptr as *const f64,
4864                1,
4865                y_ptr as *const f64,
4866                1,
4867                &mut result,
4868            )
4869        };
4870        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4871            Ok(result)
4872        } else {
4873            Err(ArrowSchurGpuFailure::Unavailable)
4874        }
4875    }
4876
4877    fn device_nrm2(
4878        blas: &CudaBlas,
4879        stream: &Arc<CudaStream>,
4880        n: usize,
4881        x: &CudaSlice<f64>,
4882    ) -> Result<f64, ArrowSchurGpuFailure> {
4883        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4884        let (x_ptr, _x_rec) = x.device_ptr(stream);
4885        let mut result = 0.0_f64;
4886        // SAFETY: x is a live device allocation on this stream with at least n
4887        // contiguous f64 entries and unit stride; result is a valid host
4888        // out-pointer for the cuBLAS scalar.
4889        let status = unsafe {
4890            cudarc::cublas::sys::cublasDnrm2_v2(
4891                *blas.handle(),
4892                n_i,
4893                x_ptr as *const f64,
4894                1,
4895                &mut result,
4896            )
4897        };
4898        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4899            Ok(result)
4900        } else {
4901            Err(ArrowSchurGpuFailure::Unavailable)
4902        }
4903    }
4904
4905    #[cfg(test)]
4906    mod tests {
4907        //! #1551 device-side framed-matvec triage. Lives inside `mod cuda` so it
4908        //! can call the private kernel launchers directly (no test-only public
4909        //! seam, which the ban-scanner forbids). A bare `#[cfg(test)] mod tests`
4910        //! is the one form the scanner permits.
4911        use super::*;
4912        use crate::arrow_schur::{
4913            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4914            FactoredFrameGBlock,
4915        };
4916        use ndarray::Array2;
4917
4918        /// Run the framed reduced-Schur matvec `out = S·x` ONCE on the device
4919        /// (no PCG, no offload gate) and return `out`.
4920        fn device_matvec_once(
4921            sys: &ArrowSchurSystem,
4922            data: &DeviceSaePcgData,
4923            ridge_t: f64,
4924            ridge_beta: f64,
4925            x_host: &[f64],
4926        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4927            let k = x_host.len();
4928            let frame = data
4929                .frame
4930                .as_ref()
4931                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4932            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4933                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4934            let ctx =
4935                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4936                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4937            let stream = ctx
4938                .new_stream()
4939                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4940            let vector_module = pcg_vector_module(&ctx)?;
4941            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4942            let x_dev = stream
4943                .clone_htod(x_host)
4944                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4945            let mut out_dev = stream
4946                .alloc_zeros::<f64>(k)
4947                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4948            launch_sae_frame_matvec(
4949                &stream,
4950                vector_module,
4951                &mut buffers,
4952                &x_dev,
4953                &mut out_dev,
4954                ridge_beta,
4955            )?;
4956            stream
4957                .clone_dtoh(&out_dev)
4958                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4959        }
4960
4961        /// #1551 stage-isolating matvec triage on a TINY hand-verifiable fixture:
4962        /// diff the device framed matvec `S·e_col` against the CPU oracle
4963        /// `sae_framed_schur_matvec_cpu` for every identity column, reporting the
4964        /// worst-divergent border index so the structural 91% localizes to one
4965        /// kernel stage. Skips cleanly off-device.
4966        #[test]
4967        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4968            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4969                return;
4970            }
4971            let p = 3usize;
4972            let ranks = vec![2usize, 3usize];
4973            let basis_sizes = vec![2usize, 2usize];
4974            let mut border_offsets = Vec::new();
4975            let mut acc = 0usize;
4976            for k in 0..2 {
4977                border_offsets.push(acc);
4978                acc += basis_sizes[k] * ranks[k];
4979            }
4980            let border_dim = acc; // 2·2 + 2·3 = 10
4981            let frame_of = |k: usize| -> Array2<f64> {
4982                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4983                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4984                })
4985            };
4986            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4987            let w_of = |i: usize, j: usize| -> Array2<f64> {
4988                let (ui, uj) = (&frames[i], &frames[j]);
4989                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4990                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4991                })
4992            };
4993            let mut frame_blocks = Vec::new();
4994            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4995                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4996                let mut g =
4997                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
4998                if i == j {
4999                    for r in 0..mi.min(mj) {
5000                        g[[r, r]] += mi as f64 + 2.0;
5001                    }
5002                }
5003                frame_blocks.push(FactoredFrameGBlock {
5004                    atom_i: i,
5005                    atom_j: j,
5006                    g,
5007                    w: w_of(i, j),
5008                });
5009            }
5010            let mut smooth_blocks = Vec::new();
5011            for k in 0..2 {
5012                let m = basis_sizes[k];
5013                let mut s =
5014                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5015                for r in 0..m {
5016                    s[[r, r]] += 1.0;
5017                }
5018                smooth_blocks.push(DeviceSaeSmoothBlock {
5019                    global_offset: border_offsets[k],
5020                    factor_a: s,
5021                });
5022            }
5023            let smooth_ranks = ranks.clone();
5024            let n = 2usize;
5025            let q = 2usize;
5026            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5027            let mut row_htbeta = Vec::new();
5028            for i in 0..n {
5029                let mut htt =
5030                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5031                for r in 0..q {
5032                    htt[[r, r]] += q as f64 + 2.0;
5033                }
5034                sys.rows[i].htt = htt;
5035                let mut slab = vec![0.0_f64; q * border_dim];
5036                for c in 0..q {
5037                    for col in 0..border_dim {
5038                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5039                        slab[c * border_dim + col] = v;
5040                        sys.rows[i].htbeta[[c, col]] = v;
5041                    }
5042                }
5043                row_htbeta.push(slab);
5044            }
5045            let data = DeviceSaePcgData {
5046                p,
5047                beta_dim: border_dim,
5048                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5049                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5050                smooth_blocks,
5051                sparse_g_blocks: Vec::new(),
5052                frame: Some(DeviceSaeFrameData {
5053                    ranks,
5054                    basis_sizes,
5055                    border_offsets,
5056                    frame_blocks,
5057                    smooth_ranks,
5058                    row_htbeta,
5059                }),
5060            };
5061            let ridge_t = 1e-7;
5062            let ridge_beta = 1e-6;
5063            let mut first_bad: Option<usize> = None;
5064            let mut worst = 0.0_f64;
5065            let mut worst_at = 0usize;
5066            let mut worst_dev = 0.0_f64;
5067            let mut worst_cpu = 0.0_f64;
5068            for col in 0..border_dim {
5069                let mut x = vec![0.0_f64; border_dim];
5070                x[col] = 1.0;
5071                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5072                    Ok(v) => v,
5073                    Err(_) => return,
5074                };
5075                let mut cpu = vec![0.0_f64; border_dim];
5076                super::super::sae_framed_schur_matvec_cpu(
5077                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5078                )
5079                .expect("cpu matvec");
5080                for r in 0..border_dim {
5081                    let d = (dev[r] - cpu[r]).abs();
5082                    if d > 1e-9 && first_bad.is_none() {
5083                        first_bad = Some(r * border_dim + col);
5084                    }
5085                    if d > worst {
5086                        worst = d;
5087                        worst_at = r * border_dim + col;
5088                        worst_dev = dev[r];
5089                        worst_cpu = cpu[r];
5090                    }
5091                }
5092            }
5093            assert!(
5094                worst <= 1e-9,
5095                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5096                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5097                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5098                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5099                 G⊗W=cross, reduced-Schur=dense per-row)",
5100            );
5101        }
5102    }
5103}
5104
5105#[cfg(test)]
5106mod tests {
5107    use super::*;
5108    use crate::arrow_schur::ArrowSchurSystem;
5109    use ndarray::{Array2, ArrayView1};
5110
5111    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5112        let mut sys = ArrowSchurSystem::new(n, d, k);
5113        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5114        let mut sample = || -> f64 {
5115            state = state
5116                .wrapping_mul(6364136223846793005)
5117                .wrapping_add(1442695040888963407);
5118            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5119        };
5120        for row in &mut sys.rows {
5121            let mut a = Array2::<f64>::zeros((d, d));
5122            for r in 0..d {
5123                for c in 0..d {
5124                    a[[r, c]] = sample();
5125                }
5126            }
5127            let mut htt = a.t().dot(&a);
5128            for r in 0..d {
5129                htt[[r, r]] += d as f64 + 1.0;
5130            }
5131            row.htt = htt;
5132            for r in 0..d {
5133                for c in 0..k {
5134                    row.htbeta[[r, c]] = 0.1 * sample();
5135                }
5136                row.gt[r] = sample();
5137            }
5138        }
5139        let mut hbb_a = Array2::<f64>::zeros((k, k));
5140        for r in 0..k {
5141            for c in 0..k {
5142                hbb_a[[r, c]] = sample();
5143            }
5144        }
5145        let mut hbb = hbb_a.t().dot(&hbb_a);
5146        for r in 0..k {
5147            hbb[[r, r]] += k as f64 + 1.0;
5148        }
5149        sys.hbb = hbb;
5150        for r in 0..k {
5151            sys.gb[r] = sample();
5152        }
5153        sys
5154    }
5155
5156    /// The Gershgorin ridge bump must actually make a known-indefinite block PD
5157    /// on the first retry — the whole point of #1711. Verified directly by
5158    /// re-factoring `H_tt + (ridge_t + bump)·I` with the same Cholesky guard the
5159    /// device readback uses, for blocks whose `λ_min` is known in closed form.
5160    #[test]
5161    fn ridge_bump_makes_known_indefinite_blocks_pd() {
5162        // `cholesky_factor_in_place` / `CholeskyGuard` are already in scope via
5163        // `super::*` (imported at the top of the module).
5164        // A few blocks with a CLOSED-FORM smallest eigenvalue, all at ridge_t=0.
5165        // (label, matrix, λ_min) — the bump must clear each one.
5166        let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); // λ_min = -1
5167        let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); // λ_min = -250
5168        // Symmetric 2×2 [[1, 2], [2, 1]] has eigenvalues 3 and -1 → indefinite.
5169        let mut indef2 = Array2::<f64>::zeros((2, 2));
5170        indef2[[0, 0]] = 1.0;
5171        indef2[[1, 1]] = 1.0;
5172        indef2[[0, 1]] = 2.0;
5173        indef2[[1, 0]] = 2.0;
5174        // A genuinely PD block must get a bump that is the bare rounding margin
5175        // only (deficit 0), and must still factor — the helper is defensive.
5176        let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5177
5178        for (label, block) in [
5179            ("-I (λ_min=-1)", neg_identity),
5180            ("-250·I (λ_min=-250)", scaled_neg),
5181            ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5182            ("5·I (PD)", pd),
5183        ] {
5184            let ridge_t = 0.0;
5185            let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5186            assert!(
5187                bump > 0.0 && bump.is_finite(),
5188                "[{label}] bump must be strictly positive and finite, got {bump:e}"
5189            );
5190            let d = block.nrows();
5191            let mut shifted = block.clone();
5192            for i in 0..d {
5193                shifted[[i, i]] += ridge_t + bump;
5194            }
5195            assert!(
5196                cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5197                "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5198                 Gershgorin bump, but the Cholesky still rejected it"
5199            );
5200        }
5201    }
5202
5203    /// The column-major variant (multi-GPU tile path) must agree with the
5204    /// row-major helper for a symmetric block, since Gershgorin edges are
5205    /// invariant under reading the symmetric matrix by row vs by column. The
5206    /// colmajor variant takes the bound at ridge_t=0 (the ridge is already baked
5207    /// into the diagonal it reads), so compare against `ridge_bump_to_make_pd`
5208    /// with `ridge_t = 0`.
5209    ///
5210    /// Gated to linux: `ridge_bump_to_make_pd_colmajor` only exists on the
5211    /// linux CUDA tile path, so the parity test runs where the function does.
5212    #[cfg(target_os = "linux")]
5213    #[test]
5214    fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5215        // Symmetric 3×3 with a negative-definite-ish diagonal and off-diagonals.
5216        let mut a = Array2::<f64>::zeros((3, 3));
5217        a[[0, 0]] = -2.0;
5218        a[[1, 1]] = 0.5;
5219        a[[2, 2]] = 1.0;
5220        a[[0, 1]] = 0.3;
5221        a[[1, 0]] = 0.3;
5222        a[[1, 2]] = -0.4;
5223        a[[2, 1]] = -0.4;
5224        a[[0, 2]] = 0.1;
5225        a[[2, 0]] = 0.1;
5226
5227        let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5228
5229        // Flatten column-major: block[c*d + r] = a[[r, c]].
5230        let d = 3;
5231        let mut col_major = vec![0.0_f64; d * d];
5232        for c in 0..d {
5233            for r in 0..d {
5234                col_major[c * d + r] = a[[r, c]];
5235            }
5236        }
5237        let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5238
5239        assert!(
5240            (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5241            "colmajor bump {col_major_bump:e} must match rowmajor bump \
5242             {row_major_bump:e} for a symmetric block"
5243        );
5244
5245        // And the bump must actually make it PD (sanity, same as the row-major test).
5246        let mut shifted = a.clone();
5247        for i in 0..d {
5248            shifted[[i, i]] += col_major_bump;
5249        }
5250        assert!(
5251            cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5252            "colmajor Gershgorin bump must make the symmetric block PD"
5253        );
5254    }
5255
5256    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5257        let mut s = Array2::<f64>::zeros((k, k));
5258        for row in 0..k {
5259            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5260            if row + 1 < k {
5261                s[[row, row + 1]] = -0.05;
5262                s[[row + 1, row]] = -0.05;
5263            }
5264            if row + 7 < k {
5265                s[[row, row + 7]] = 0.01;
5266                s[[row + 7, row]] = 0.01;
5267            }
5268        }
5269        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5270        (s, rhs)
5271    }
5272
5273    fn dense_pcg_cpu_reference(
5274        s: &Array2<f64>,
5275        rhs: &Array1<f64>,
5276        max_iterations: usize,
5277        relative_tolerance: f64,
5278    ) -> Array1<f64> {
5279        let k = rhs.len();
5280        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5281        if rhs_norm == 0.0 {
5282            return Array1::<f64>::zeros(k);
5283        }
5284        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5285        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5286        let mut x = Array1::<f64>::zeros(k);
5287        let mut r = rhs.clone();
5288        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5289        let mut p = z.clone();
5290        let mut sp = Array1::<f64>::zeros(k);
5291        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5292        for _ in 0..max_iterations.max(1) {
5293            for row in 0..k {
5294                let mut acc = 0.0;
5295                for col in 0..k {
5296                    acc += s[[row, col]] * p[col];
5297                }
5298                sp[row] = acc;
5299            }
5300            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5301            let alpha = rz / p_sp;
5302            for idx in 0..k {
5303                x[idx] += alpha * p[idx];
5304                r[idx] -= alpha * sp[idx];
5305            }
5306            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5307            if r_norm <= tol {
5308                break;
5309            }
5310            for idx in 0..k {
5311                z[idx] = inv_diag[idx] * r[idx];
5312            }
5313            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5314            let beta = rz_next / rz;
5315            for idx in 0..k {
5316                p[idx] = z[idx] + beta * p[idx];
5317            }
5318            rz = rz_next;
5319        }
5320        x
5321    }
5322
5323    #[test]
5324    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5325        let (s, rhs) = device_pcg_fixture(512);
5326        let max_iterations = 200usize;
5327        let relative_tolerance = 1.0e-12;
5328        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5329        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5330            &s,
5331            &rhs,
5332            max_iterations,
5333            relative_tolerance,
5334        ) {
5335            Ok(result) => result,
5336            // #1017 — fail loud, never skip-pass: this fixture clears the device
5337            // offload floor, so a CUDA device that is PRESENT yet declines/returns
5338            // Err means the device PCG kernel does not run on GPU (a real fault that
5339            // must not masquerade as a pass via this skip). Legit skip ONLY when no
5340            // usable CUDA device exists (CPU CI). The exact `ArrowSchurGpuFailure`
5341            // variant is folded into the assert message as the diagnostic.
5342            Err(failure) => {
5343                assert!(
5344                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5345                    "#1017: CUDA device present but the device reduced-beta PCG \
5346                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5347                     the kernel does not run correctly on GPU"
5348                );
5349                return;
5350            }
5351        };
5352        let max_err = cpu
5353            .iter()
5354            .zip(device.iter())
5355            .map(|(a, b)| (a - b).abs())
5356            .fold(0.0_f64, f64::max);
5357        assert!(
5358            max_err <= 1.0e-10,
5359            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5360        );
5361        assert!(diag.matvec_calls > 0);
5362        assert_eq!(diag.matvec_calls, diag.iterations);
5363    }
5364
5365    #[test]
5366    fn dense_reference_matches_independent_solve() {
5367        let sys = build_fixture(4, 5, 3, 7);
5368        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5369        // Re-solve by an independent matrix build and a textbook
5370        // Gaussian-elimination Cholesky to guard against typos in the
5371        // reference implementation itself.
5372        let n = sys.rows.len();
5373        let d = sys.d;
5374        let k = sys.k;
5375        let total = n * d + k;
5376        let mut h = Array2::<f64>::zeros((total, total));
5377        let mut g = ndarray::Array1::<f64>::zeros(total);
5378        for (i, row) in sys.rows.iter().enumerate() {
5379            let base = i * d;
5380            for c in 0..d {
5381                for r in 0..d {
5382                    h[[base + r, base + c]] = row.htt[[r, c]];
5383                }
5384            }
5385            for c in 0..k {
5386                for r in 0..d {
5387                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5388                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5389                }
5390            }
5391            for r in 0..d {
5392                g[base + r] = row.gt[r];
5393            }
5394        }
5395        for c in 0..k {
5396            for r in 0..k {
5397                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5398            }
5399            g[n * d + c] = sys.gb[c];
5400        }
5401        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5402        let rhs = g.mapv(|v| -v);
5403        let expected = cholesky_solve_vector(l.view(), rhs.view());
5404        for i in 0..n * d {
5405            assert!(
5406                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5407                "delta_t[{i}] mismatch: got {} expected {}",
5408                solution.delta_t[i],
5409                expected[i]
5410            );
5411        }
5412        for a in 0..k {
5413            assert!(
5414                (solution.delta_beta[a] - expected[n * d + a]).abs()
5415                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
5416                "delta_beta[{a}] mismatch"
5417            );
5418        }
5419    }
5420
5421    /// #1017: the row-procedural reduced-Schur matvec (the matrix-free SAE
5422    /// host backend) auto-fans its per-row point-elimination sum across rayon
5423    /// over fixed row chunks when at the top level (`n ≥
5424    /// SCHUR_MATVEC_PARALLEL_ROW_MIN`), and stays serial when already inside a
5425    /// rayon worker. The chunk-ordered fold makes the parallel result
5426    /// **deterministic** (two parallel calls are bit-identical — scheduling
5427    /// cannot change the numbers) and it agrees with the serial accumulation up
5428    /// to ULP-scale chunk reassociation (the #1017 verification gate). That
5429    /// reassociation is a genuine f64 departure from serial, so the criterion
5430    /// ranking across topology candidates is stable only up to the reassociation
5431    /// margin: a near-tie winner inside that margin can flip. This is NOT an
5432    /// exact no-move guarantee (#1211); for that, the ranking path must use the
5433    /// fixed-order serial accumulation.
5434    #[test]
5435    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5436        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5437        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; // trips the parallel path
5438        let d = 3usize;
5439        let k = 24usize;
5440        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5441        // Install a matrix-free forward/transpose pair that reads the dense
5442        // `htbeta` slabs the fixture already populated, so the procedural
5443        // backend has a well-defined operator to apply (and exercises exactly
5444        // the sparse gather/scatter the SAE Kronecker path drives).
5445        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5446        let forward_slabs = slabs.clone();
5447        let transpose_slabs = slabs;
5448        sys.set_row_htbeta_operator(
5449            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5450                let h = &forward_slabs[row];
5451                for r in 0..h.nrows() {
5452                    let mut acc = 0.0_f64;
5453                    for c in 0..h.ncols() {
5454                        acc += h[[r, c]] * x[c];
5455                    }
5456                    out[r] = acc;
5457                }
5458            },
5459            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5460                let h = &transpose_slabs[row];
5461                for r in 0..h.nrows() {
5462                    for c in 0..h.ncols() {
5463                        out[c] += h[[r, c]] * v[r];
5464                    }
5465                }
5466            },
5467        );
5468
5469        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5470            .expect("row-procedural matvec backend builds for matrix-free system");
5471        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5472
5473        // Top-level call: auto-selects the parallel chunk-fold. Run twice and
5474        // assert bit-identity — the chunk-ordered reduction must not depend on
5475        // thread scheduling.
5476        let mut out_parallel_a = Array1::<f64>::zeros(k);
5477        matvec(&x, &mut out_parallel_a);
5478        let mut out_parallel_b = Array1::<f64>::zeros(k);
5479        matvec(&x, &mut out_parallel_b);
5480        for a in 0..k {
5481            assert_eq!(
5482                out_parallel_a[a].to_bits(),
5483                out_parallel_b[a].to_bits(),
5484                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5485            );
5486        }
5487
5488        // Inside a rayon worker: auto-selects the serial path (nested-rayon
5489        // guard). `install` runs the closure on a pool thread, so
5490        // `current_thread_index()` is `Some`. The serial running sum and the
5491        // chunk-ordered parallel fold differ only by f64 reassociation.
5492        let mut out_serial = Array1::<f64>::zeros(k);
5493        rayon::ThreadPoolBuilder::new()
5494            .num_threads(2)
5495            .build()
5496            .expect("build rayon pool")
5497            .install(|| matvec(&x, &mut out_serial));
5498
5499        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5500        for a in 0..k {
5501            let diff = (out_parallel_a[a] - out_serial[a]).abs();
5502            assert!(
5503                diff <= 1e-12 * (1.0 + max_abs),
5504                "row-procedural matvec parallel vs serial diverged beyond reassociation \
5505                 at index {a}: {} vs {} (diff={diff:e})",
5506                out_parallel_a[a],
5507                out_serial[a]
5508            );
5509        }
5510    }
5511
5512    /// #1017/#1026 — the frames-engaged CPU reduced-Schur matvec
5513    /// [`sae_framed_schur_matvec_cpu`] (the bit-parity oracle the GPU kernel
5514    /// mirrors) must equal the dense reduced Schur `S = (P_ββ + ρ_β I) −
5515    /// Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)` formed by the canonical dense
5516    /// reference, on a small framed system with mixed per-atom ranks
5517    /// (`r_k < p` framed + `r_k = p` un-framed). Size-independent gate.
5518    #[test]
5519    fn framed_sae_schur_matvec_matches_dense_reference() {
5520        use crate::arrow_schur::{
5521            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5522            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5523        };
5524
5525        let p = 4usize;
5526        // Three atoms: ranks 2 (framed), 4 (un-framed), 3 (framed).
5527        let ranks = vec![2usize, 4usize, 3usize];
5528        let basis_sizes = vec![2usize, 1usize, 2usize];
5529        let n_atoms = ranks.len();
5530        let mut border_offsets = Vec::with_capacity(n_atoms);
5531        let mut acc = 0usize;
5532        for k in 0..n_atoms {
5533            border_offsets.push(acc);
5534            acc += basis_sizes[k] * ranks[k];
5535        }
5536        let border_dim = acc; // 2*2 + 1*4 + 2*3 = 14
5537
5538        let mut state = 0x1234_5678_9abc_def0u64;
5539        let mut sample = || -> f64 {
5540            state = state
5541                .wrapping_mul(6364136223846793005)
5542                .wrapping_add(1442695040888963407);
5543            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5544        };
5545
5546        // Per-atom orthonormal-ish frames U_k (p × r_k) for the W = U_iᵀU_j
5547        // factors; un-framed atom (r=p) uses U = I_p.
5548        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5549        for k in 0..n_atoms {
5550            let r = ranks[k];
5551            let mut u = Array2::<f64>::zeros((p, r));
5552            for i in 0..p {
5553                for j in 0..r {
5554                    u[[i, j]] = if r == p && i == j {
5555                        1.0
5556                    } else if r == p {
5557                        0.0
5558                    } else {
5559                        sample()
5560                    };
5561                }
5562            }
5563            frames.push(u);
5564        }
5565        let w_of = |i: usize, j: usize| -> Array2<f64> {
5566            let (ui, uj) = (&frames[i], &frames[j]);
5567            let (ri, rj) = (ranks[i], ranks[j]);
5568            let mut w = Array2::<f64>::zeros((ri, rj));
5569            for a in 0..ri {
5570                for b in 0..rj {
5571                    let mut s = 0.0;
5572                    for c in 0..p {
5573                        s += ui[[c, a]] * uj[[c, b]];
5574                    }
5575                    w[[a, b]] = s;
5576                }
5577            }
5578            w
5579        };
5580
5581        // Co-occurring data-fit blocks: all diagonal pairs + one cross (0,2).
5582        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5583        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5584        pairs.sort();
5585        for &(i, j) in &pairs {
5586            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5587            let mut g = Array2::<f64>::zeros((mi, mj));
5588            for r in 0..mi {
5589                for c in 0..mj {
5590                    g[[r, c]] = 0.3 * sample();
5591                }
5592            }
5593            // Make diagonal blocks SPD-leaning so S stays PD.
5594            if i == j {
5595                for r in 0..mi.min(mj) {
5596                    g[[r, r]] += mi as f64 + 2.0;
5597                }
5598            }
5599            frame_blocks.push(FactoredFrameGBlock {
5600                atom_i: i,
5601                atom_j: j,
5602                g,
5603                w: w_of(i, j),
5604            });
5605        }
5606
5607        // Smooth blocks λ S_k (M_k × M_k), SPD.
5608        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5609        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5610        for k in 0..n_atoms {
5611            let m = basis_sizes[k];
5612            let mut a = Array2::<f64>::zeros((m, m));
5613            for r in 0..m {
5614                for c in 0..m {
5615                    a[[r, c]] = 0.2 * sample();
5616                }
5617            }
5618            let mut s = a.t().dot(&a);
5619            for r in 0..m {
5620                s[[r, r]] += 1.0;
5621            }
5622            smooth_blocks.push(DeviceSaeSmoothBlock {
5623                global_offset: border_offsets[k],
5624                factor_a: s,
5625            });
5626            smooth_ranks.push(ranks[k]);
5627        }
5628
5629        // Build the system: n rows, dense htbeta slabs (q_i × border_dim).
5630        let n = 6usize;
5631        let q = 3usize;
5632        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5633        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5634        for i in 0..n {
5635            // SPD htt.
5636            let mut a = Array2::<f64>::zeros((q, q));
5637            for r in 0..q {
5638                for c in 0..q {
5639                    a[[r, c]] = sample();
5640                }
5641            }
5642            let mut htt = a.t().dot(&a);
5643            for r in 0..q {
5644                htt[[r, r]] += q as f64 + 1.0;
5645            }
5646            sys.rows[i].htt = htt;
5647            let mut slab = vec![0.0_f64; q * border_dim];
5648            for c in 0..q {
5649                for col in 0..border_dim {
5650                    let v = 0.15 * sample();
5651                    slab[c * border_dim + col] = v;
5652                    sys.rows[i].htbeta[[c, col]] = v;
5653                }
5654            }
5655            row_htbeta.push(slab);
5656        }
5657
5658        // Dense H_ββ from the SAME penalty ops (so the dense reference's S
5659        // matches the device penalty side exactly).
5660        let data_op =
5661            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5662                .expect("frame op");
5663        let mut hbb = data_op.to_dense();
5664        for k in 0..n_atoms {
5665            let op = IdentityRightKroneckerPenaltyOp {
5666                factor_a: smooth_blocks[k].factor_a.clone(),
5667                p: ranks[k],
5668                global_offset: border_offsets[k],
5669                k: border_dim,
5670            };
5671            let d = op.to_dense();
5672            for r in 0..border_dim {
5673                for c in 0..border_dim {
5674                    hbb[[r, c]] += d[[r, c]];
5675                }
5676            }
5677        }
5678        sys.hbb = hbb;
5679
5680        let data = DeviceSaePcgData {
5681            p,
5682            beta_dim: border_dim,
5683            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5684            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5685            smooth_blocks,
5686            sparse_g_blocks: Vec::new(),
5687            frame: Some(DeviceSaeFrameData {
5688                ranks: ranks.clone(),
5689                basis_sizes: basis_sizes.clone(),
5690                border_offsets: border_offsets.clone(),
5691                frame_blocks,
5692                smooth_ranks,
5693                row_htbeta,
5694            }),
5695        };
5696
5697        let ridge_t = 1e-7;
5698        let ridge_beta = 1e-6;
5699
5700        // Dense reference reduced Schur S (border_dim × border_dim), formed
5701        // exactly as solve_arrow_newton_step_dense_reference assembles the
5702        // bordered Hessian and eliminates the t-block.
5703        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5704        for r in 0..border_dim {
5705            for c in 0..border_dim {
5706                s_dense[[r, c]] = sys.hbb[[r, c]];
5707            }
5708            s_dense[[r, r]] += ridge_beta;
5709        }
5710        for row in &sys.rows {
5711            let mut htt = row.htt.clone();
5712            for d in 0..q {
5713                htt[[d, d]] += ridge_t;
5714            }
5715            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5716                .expect("htt PD");
5717            // Y = (htt)⁻¹ htbeta  (q × border_dim); S -= htbetaᵀ Y.
5718            let mut y = Array2::<f64>::zeros((q, border_dim));
5719            for col in 0..border_dim {
5720                let mut e = Array1::<f64>::zeros(q);
5721                for r in 0..q {
5722                    e[r] = row.htbeta[[r, col]];
5723                }
5724                let solved = cholesky_solve_vector(factor.view(), e.view());
5725                for r in 0..q {
5726                    y[[r, col]] = solved[r];
5727                }
5728            }
5729            for r in 0..border_dim {
5730                for c in 0..border_dim {
5731                    let mut acc = 0.0;
5732                    for d in 0..q {
5733                        acc += row.htbeta[[d, r]] * y[[d, c]];
5734                    }
5735                    s_dense[[r, c]] -= acc;
5736                }
5737            }
5738        }
5739
5740        // Probe vectors: compare S·x from the device-data CPU oracle vs dense S·x.
5741        let mut max_rel = 0.0_f64;
5742        for trial in 0..4 {
5743            let x: Vec<f64> = (0..border_dim)
5744                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5745                .collect();
5746            let mut got = vec![0.0_f64; border_dim];
5747            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5748                .expect("framed matvec");
5749            let mut want = vec![0.0_f64; border_dim];
5750            for r in 0..border_dim {
5751                let mut acc = 0.0;
5752                for c in 0..border_dim {
5753                    acc += s_dense[[r, c]] * x[c];
5754                }
5755                want[r] = acc;
5756            }
5757            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5758            for a in 0..border_dim {
5759                let rel = (got[a] - want[a]).abs() / scale;
5760                max_rel = max_rel.max(rel);
5761            }
5762        }
5763        assert!(
5764            max_rel <= 1e-10,
5765            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5766        );
5767    }
5768
5769    /// #1017/#1026 — large-K, many-atom, dense-cross-pair parity for the framed
5770    /// SAE reduced-Schur CPU oracle. The small `framed_sae_schur_matvec_matches_dense_reference`
5771    /// pins `border_dim=14`/3 atoms/1 cross pair; the interactions that only
5772    /// appear at scale — variable per-atom `r_k` (mixed framed `r_k<p` and
5773    /// un-framed `r_k=p`), the prefix-sum `border_offsets`, and dense cross-atom
5774    /// `W_ij` coupling across MANY co-occurring `frame_blocks` and many `row_htbeta`
5775    /// slabs — were validated only on the A100. This pins the CPU oracle (and
5776    /// therefore the device kernel it mirrors) against the dense reduced Schur at
5777    /// 40 atoms / `border_dim≈240` / a neighbour-coupled cross-pair set, on CPU.
5778    #[test]
5779    fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
5780        use crate::arrow_schur::{
5781            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5782            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5783        };
5784
5785        let p = 12usize;
5786        let n_atoms = 40usize;
5787        // Variable per-atom rank: most atoms are genuinely framed (r_k<p), every
5788        // fifth atom is un-framed (r_k=p, U=I_p — the within-atom G⊗I_r collapse).
5789        let ranks: Vec<usize> = (0..n_atoms)
5790            .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
5791            .collect();
5792        let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();
5793        let mut border_offsets = Vec::with_capacity(n_atoms);
5794        let mut acc = 0usize;
5795        for k in 0..n_atoms {
5796            border_offsets.push(acc);
5797            acc += basis_sizes[k] * ranks[k];
5798        }
5799        let border_dim = acc;
5800
5801        let mut state = 0x0bad_c0de_dead_beefu64;
5802        let mut sample = || -> f64 {
5803            state = state
5804                .wrapping_mul(6364136223846793005)
5805                .wrapping_add(1442695040888963407);
5806            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5807        };
5808
5809        // Per-atom frames U_k (p × r_k); un-framed atom (r=p) uses U = I_p.
5810        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5811        for k in 0..n_atoms {
5812            let r = ranks[k];
5813            let mut u = Array2::<f64>::zeros((p, r));
5814            for i in 0..p {
5815                for j in 0..r {
5816                    u[[i, j]] = if r == p {
5817                        if i == j { 1.0 } else { 0.0 }
5818                    } else {
5819                        sample()
5820                    };
5821                }
5822            }
5823            frames.push(u);
5824        }
5825        let w_of = |i: usize, j: usize| -> Array2<f64> {
5826            let (ui, uj) = (&frames[i], &frames[j]);
5827            let (ri, rj) = (ranks[i], ranks[j]);
5828            let mut w = Array2::<f64>::zeros((ri, rj));
5829            for a in 0..ri {
5830                for b in 0..rj {
5831                    let mut s = 0.0;
5832                    for c in 0..p {
5833                        s += ui[[c, a]] * uj[[c, b]];
5834                    }
5835                    w[[a, b]] = s;
5836                }
5837            }
5838            w
5839        };
5840
5841        // Co-occurring data-fit blocks: every diagonal pair + each neighbour pair
5842        // (k,k+1) and its transpose — dense cross coupling across the whole border.
5843        let mut pairs: Vec<(usize, usize)> = Vec::new();
5844        for k in 0..n_atoms {
5845            pairs.push((k, k));
5846        }
5847        for k in 0..n_atoms - 1 {
5848            pairs.push((k, k + 1));
5849            pairs.push((k + 1, k));
5850        }
5851        pairs.sort_unstable();
5852        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5853        for &(i, j) in &pairs {
5854            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5855            let mut g = Array2::<f64>::zeros((mi, mj));
5856            for r in 0..mi {
5857                for c in 0..mj {
5858                    g[[r, c]] = 0.3 * sample();
5859                }
5860            }
5861            if i == j {
5862                for r in 0..mi.min(mj) {
5863                    g[[r, r]] += mi as f64 + 2.0;
5864                }
5865            }
5866            frame_blocks.push(FactoredFrameGBlock {
5867                atom_i: i,
5868                atom_j: j,
5869                g,
5870                w: w_of(i, j),
5871            });
5872        }
5873
5874        // Smooth blocks λ S_k (M_k × M_k), SPD.
5875        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5876        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5877        for k in 0..n_atoms {
5878            let m = basis_sizes[k];
5879            let mut a = Array2::<f64>::zeros((m, m));
5880            for r in 0..m {
5881                for c in 0..m {
5882                    a[[r, c]] = 0.2 * sample();
5883                }
5884            }
5885            let mut s = a.t().dot(&a);
5886            for r in 0..m {
5887                s[[r, r]] += 1.0;
5888            }
5889            smooth_blocks.push(DeviceSaeSmoothBlock {
5890                global_offset: border_offsets[k],
5891                factor_a: s,
5892            });
5893            smooth_ranks.push(ranks[k]);
5894        }
5895
5896        // n rows with SPD htt and full-width (q × border_dim) htbeta slabs.
5897        let n = 8usize;
5898        let q = 3usize;
5899        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5900        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5901        for i in 0..n {
5902            let mut a = Array2::<f64>::zeros((q, q));
5903            for r in 0..q {
5904                for c in 0..q {
5905                    a[[r, c]] = sample();
5906                }
5907            }
5908            let mut htt = a.t().dot(&a);
5909            for r in 0..q {
5910                htt[[r, r]] += q as f64 + 1.0;
5911            }
5912            sys.rows[i].htt = htt;
5913            let mut slab = vec![0.0_f64; q * border_dim];
5914            for c in 0..q {
5915                for col in 0..border_dim {
5916                    let v = 0.15 * sample();
5917                    slab[c * border_dim + col] = v;
5918                    sys.rows[i].htbeta[[c, col]] = v;
5919                }
5920            }
5921            row_htbeta.push(slab);
5922        }
5923
5924        // Dense H_ββ from the SAME penalty ops (data-fit + smooth), so the dense
5925        // reference S matches the device penalty side exactly.
5926        let data_op =
5927            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5928                .expect("frame op");
5929        let mut hbb = data_op.to_dense();
5930        for k in 0..n_atoms {
5931            let op = IdentityRightKroneckerPenaltyOp {
5932                factor_a: smooth_blocks[k].factor_a.clone(),
5933                p: ranks[k],
5934                global_offset: border_offsets[k],
5935                k: border_dim,
5936            };
5937            let d = op.to_dense();
5938            for r in 0..border_dim {
5939                for c in 0..border_dim {
5940                    hbb[[r, c]] += d[[r, c]];
5941                }
5942            }
5943        }
5944        sys.hbb = hbb;
5945
5946        let data = DeviceSaePcgData {
5947            p,
5948            beta_dim: border_dim,
5949            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5950            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5951            smooth_blocks,
5952            sparse_g_blocks: Vec::new(),
5953            frame: Some(DeviceSaeFrameData {
5954                ranks: ranks.clone(),
5955                basis_sizes: basis_sizes.clone(),
5956                border_offsets: border_offsets.clone(),
5957                frame_blocks,
5958                smooth_ranks,
5959                row_htbeta,
5960            }),
5961        };
5962
5963        let ridge_t = 1e-7;
5964        let ridge_beta = 1e-6;
5965
5966        // Dense reference reduced Schur S = (hbb + ridge_beta I) - Σ_i htbetaᵀ (htt+ridge_t I)⁻¹ htbeta.
5967        let mut s_dense = sys.hbb.clone();
5968        for r in 0..border_dim {
5969            s_dense[[r, r]] += ridge_beta;
5970        }
5971        for row in &sys.rows {
5972            let mut htt = row.htt.clone();
5973            for d in 0..q {
5974                htt[[d, d]] += ridge_t;
5975            }
5976            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5977                .expect("htt PD");
5978            let mut y = Array2::<f64>::zeros((q, border_dim));
5979            for col in 0..border_dim {
5980                let mut e = Array1::<f64>::zeros(q);
5981                for r in 0..q {
5982                    e[r] = row.htbeta[[r, col]];
5983                }
5984                let solved = cholesky_solve_vector(factor.view(), e.view());
5985                for r in 0..q {
5986                    y[[r, col]] = solved[r];
5987                }
5988            }
5989            for r in 0..border_dim {
5990                for c in 0..border_dim {
5991                    let mut acc = 0.0;
5992                    for d in 0..q {
5993                        acc += row.htbeta[[d, r]] * y[[d, c]];
5994                    }
5995                    s_dense[[r, c]] -= acc;
5996                }
5997            }
5998        }
5999
6000        let mut max_rel = 0.0_f64;
6001        for trial in 0..4 {
6002            let x: Vec<f64> = (0..border_dim)
6003                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
6004                .collect();
6005            let mut got = vec![0.0_f64; border_dim];
6006            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
6007                .expect("framed matvec");
6008            let mut want = vec![0.0_f64; border_dim];
6009            for r in 0..border_dim {
6010                let mut acc = 0.0;
6011                for c in 0..border_dim {
6012                    acc += s_dense[[r, c]] * x[c];
6013                }
6014                want[r] = acc;
6015            }
6016            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6017            for a in 0..border_dim {
6018                let rel = (got[a] - want[a]).abs() / scale;
6019                max_rel = max_rel.max(rel);
6020            }
6021        }
6022        assert!(
6023            max_rel <= 1e-10,
6024            "large-K framed SAE Schur matvec vs dense reference diverged: \
6025             max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
6026        );
6027    }
6028
6029    /// #1017/#1026 GPU arm: when a CUDA device admits the framed SAE PCG, its
6030    /// solved `δβ` must match the CPU dense reduced-system solve of the SAME
6031    /// framed system (size-independent — a small device validates the kernel).
6032    /// Skips cleanly (returns) when no device is available or the policy
6033    /// declines (`solve_sae_matrix_free_pcg` → `Unavailable`).
6034    #[test]
6035    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
6036        use crate::arrow_schur::{
6037            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
6038            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
6039        };
6040
6041        // Large enough to clear the device-offload policy floor (k ≥ 32 and
6042        // n·k·d·iters ≥ MATVEC_OFFLOAD_FLOPS_MIN) so the GPU kernel actually
6043        // runs on a device rather than the policy declining.
6044        let p = 6usize;
6045        let n_atoms = 8usize;
6046        let ranks: Vec<usize> = (0..n_atoms)
6047            .map(|k| if k % 2 == 0 { 3usize } else { p })
6048            .collect();
6049        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6050        let mut border_offsets = Vec::with_capacity(n_atoms);
6051        let mut acc = 0usize;
6052        for k in 0..n_atoms {
6053            border_offsets.push(acc);
6054            acc += basis_sizes[k] * ranks[k];
6055        }
6056        let border_dim = acc; // Σ M_k·r_k = 4·(3·3) + 4·(3·6) = 36 + 72 = 108
6057
6058        let mut state = 0xfeed_face_dead_beefu64;
6059        let mut sample = || -> f64 {
6060            state = state
6061                .wrapping_mul(6364136223846793005)
6062                .wrapping_add(1442695040888963407);
6063            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6064        };
6065        let mut frames: Vec<Array2<f64>> = Vec::new();
6066        for k in 0..n_atoms {
6067            let r = ranks[k];
6068            let mut u = Array2::<f64>::zeros((p, r));
6069            for i in 0..p {
6070                for j in 0..r {
6071                    u[[i, j]] = if r == p && i == j {
6072                        1.0
6073                    } else if r == p {
6074                        0.0
6075                    } else {
6076                        sample()
6077                    };
6078                }
6079            }
6080            frames.push(u);
6081        }
6082        let w_of = |i: usize, j: usize| {
6083            let (ui, uj) = (&frames[i], &frames[j]);
6084            let (ri, rj) = (ranks[i], ranks[j]);
6085            let mut w = Array2::<f64>::zeros((ri, rj));
6086            for a in 0..ri {
6087                for b in 0..rj {
6088                    let mut s = 0.0;
6089                    for c in 0..p {
6090                        s += ui[[c, a]] * uj[[c, b]];
6091                    }
6092                    w[[a, b]] = s;
6093                }
6094            }
6095            w
6096        };
6097        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6098        // A few off-diagonal cross blocks (symmetric pairs).
6099        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6100            pairs.push((i, j));
6101            pairs.push((j, i));
6102        }
6103        let mut frame_blocks = Vec::new();
6104        for &(i, j) in &pairs {
6105            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6106            let mut g = Array2::<f64>::zeros((mi, mj));
6107            for r in 0..mi {
6108                for c in 0..mj {
6109                    g[[r, c]] = 0.25 * sample();
6110                }
6111            }
6112            if i == j {
6113                for r in 0..mi.min(mj) {
6114                    g[[r, r]] += mi as f64 + 2.0;
6115                }
6116            }
6117            frame_blocks.push(FactoredFrameGBlock {
6118                atom_i: i,
6119                atom_j: j,
6120                g,
6121                w: w_of(i, j),
6122            });
6123        }
6124        let mut smooth_blocks = Vec::new();
6125        let mut smooth_ranks = Vec::new();
6126        for k in 0..n_atoms {
6127            let m = basis_sizes[k];
6128            let mut a = Array2::<f64>::zeros((m, m));
6129            for r in 0..m {
6130                for c in 0..m {
6131                    a[[r, c]] = 0.2 * sample();
6132                }
6133            }
6134            let mut s = a.t().dot(&a);
6135            for r in 0..m {
6136                s[[r, r]] += 1.0;
6137            }
6138            smooth_blocks.push(DeviceSaeSmoothBlock {
6139                global_offset: border_offsets[k],
6140                factor_a: s,
6141            });
6142            smooth_ranks.push(ranks[k]);
6143        }
6144        let n = 400usize;
6145        let q = 4usize;
6146        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6147        let mut row_htbeta = Vec::new();
6148        for i in 0..n {
6149            let mut a = Array2::<f64>::zeros((q, q));
6150            for r in 0..q {
6151                for c in 0..q {
6152                    a[[r, c]] = sample();
6153                }
6154            }
6155            let mut htt = a.t().dot(&a);
6156            for r in 0..q {
6157                htt[[r, r]] += q as f64 + 1.0;
6158            }
6159            sys.rows[i].htt = htt;
6160            let mut slab = vec![0.0_f64; q * border_dim];
6161            for c in 0..q {
6162                for col in 0..border_dim {
6163                    // Small entries: with 400 rows the reduced-Schur subtraction
6164                    // Σ_i H_βtᵀ H_tt⁻¹ H_tβ must not overwhelm the PD penalty.
6165                    let v = 0.02 * sample();
6166                    slab[c * border_dim + col] = v;
6167                    sys.rows[i].htbeta[[c, col]] = v;
6168                }
6169            }
6170            row_htbeta.push(slab);
6171        }
6172        let data_op =
6173            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
6174                .expect("frame op");
6175        let mut hbb = data_op.to_dense();
6176        for k in 0..n_atoms {
6177            let op = IdentityRightKroneckerPenaltyOp {
6178                factor_a: smooth_blocks[k].factor_a.clone(),
6179                p: ranks[k],
6180                global_offset: border_offsets[k],
6181                k: border_dim,
6182            };
6183            let d = op.to_dense();
6184            for r in 0..border_dim {
6185                for c in 0..border_dim {
6186                    hbb[[r, c]] += d[[r, c]];
6187                }
6188            }
6189        }
6190        sys.hbb = hbb;
6191        let data = DeviceSaePcgData {
6192            p,
6193            beta_dim: border_dim,
6194            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6195            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6196            smooth_blocks,
6197            sparse_g_blocks: Vec::new(),
6198            frame: Some(DeviceSaeFrameData {
6199                ranks: ranks.clone(),
6200                basis_sizes: basis_sizes.clone(),
6201                border_offsets: border_offsets.clone(),
6202                frame_blocks,
6203                smooth_ranks,
6204                row_htbeta,
6205            }),
6206        };
6207        let ridge_t = 1e-7;
6208        let ridge_beta = 1e-6;
6209        let rhs: Array1<f64> =
6210            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
6211
6212        let (device, diag) =
6213            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
6214                Ok(result) => result,
6215                // #1017 — fail loud, never skip-pass: this fixture clears the device
6216                // offload floor, so a CUDA device that is PRESENT yet declines means the
6217                // framed device PCG kernel does not run on GPU (the fault must not pass
6218                // silently). Legit skip ONLY when no usable CUDA device exists (CPU CI).
6219                // The exact `ArrowSchurGpuFailure` variant is folded into the assert.
6220                Err(failure) => {
6221                    assert!(
6222                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6223                        "#1017: CUDA device present but the framed device SAE PCG \
6224                     declined/faulted instead of returning a result (tag: {failure:?}) — \
6225                     the kernel does not run correctly on GPU"
6226                    );
6227                    return;
6228                }
6229            };
6230
6231        // #1551 PARITY GATE — operator-residual, NOT solution-vector equality.
6232        //
6233        // The honest GPU↔CPU contract for an iterative solve of `S·δβ = rhs` is
6234        // that the device solution SOLVES the system DEFINED BY THE CPU ORACLE to
6235        // PCG tolerance — i.e. `‖S_cpu·δβ_device − rhs‖ / ‖rhs‖ ≤ tol`, where
6236        // `S_cpu` is applied with the bit-for-bit CPU oracle matvec the device
6237        // kernel mirrors (`sae_framed_schur_matvec_cpu`). This is the correct,
6238        // conditioning-robust gate: a near-singular assembled `S` has a large
6239        // condition number `κ(S)`, which amplifies an O(ε) residual difference
6240        // into an O(κ·ε) *solution-vector* difference, so comparing δβ vectors
6241        // would spuriously fail even when both operators are bit-identical and
6242        // both solves converged. (Historically this test compared δβ against a
6243        // dense-Cholesky reference and "failed" with max_rel≈0.9 because the
6244        // dense solve itself only reached ‖S·x−rhs‖≈0.1 on this fixture's
6245        // ill-conditioned S while the device PCG reached ~1e-12 — the device was
6246        // MORE accurate than the reference, not wrong. The kernel correctness is
6247        // pinned conditioning-free by `framed_sae_device_matvec_matches_cpu_oracle_*`.)
6248        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
6249        let oracle_resid = |x: &Array1<f64>| -> f64 {
6250            let mut sx = vec![0.0_f64; border_dim];
6251            sae_framed_schur_matvec_cpu(
6252                &sys,
6253                &data,
6254                ridge_t,
6255                ridge_beta,
6256                x.as_slice().unwrap(),
6257                &mut sx,
6258            )
6259            .expect("cpu oracle matvec");
6260            let mut acc = 0.0_f64;
6261            for a in 0..border_dim {
6262                let e = sx[a] - rhs[a];
6263                acc += e * e;
6264            }
6265            acc.sqrt()
6266        };
6267        let s_dev_resid = oracle_resid(&device);
6268        let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
6269
6270        // Independent CPU iterative solve of the SAME operator with the SAME
6271        // Jacobi preconditioner the device builds, via the shared `pcg_core`. If
6272        // the device kernel computed a different operator, the two converged
6273        // residuals could not BOTH be tiny.
6274        let precond = {
6275            let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6276            // The reduced-Schur diagonal subtraction (device `arrow_sae_frame_diag_sub`)
6277            // mirrored on the host for the Jacobi preconditioner.
6278            let mut diag = d;
6279            for (i, row) in sys.rows.iter().enumerate() {
6280                let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6281                let qi = sys.row_dims[i];
6282                if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6283                    continue;
6284                }
6285                let mut block = row.htt.clone();
6286                for dd in 0..qi {
6287                    block[[dd, dd]] += ridge_t;
6288                }
6289                let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6290                    .expect("row htt PD");
6291                // ainv = (H_tt+ρI)⁻¹ column by column.
6292                let mut ainv = Array2::<f64>::zeros((qi, qi));
6293                for col in 0..qi {
6294                    let mut e = Array1::<f64>::zeros(qi);
6295                    e[col] = 1.0;
6296                    let s = cholesky_solve_vector(factor.view(), e.view());
6297                    for r in 0..qi {
6298                        ainv[[r, col]] = s[r];
6299                    }
6300                }
6301                for a in 0..border_dim {
6302                    let mut quad = 0.0_f64;
6303                    for c in 0..qi {
6304                        let hc = slab[c * border_dim + a];
6305                        for dd in 0..qi {
6306                            quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6307                        }
6308                    }
6309                    diag[a] -= quad;
6310                }
6311            }
6312            Array1::from_vec(diag)
6313        };
6314        let mut cpu = Array1::<f64>::zeros(border_dim);
6315        let cpu_result = {
6316            let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6317                let mut tmp = vec![0.0_f64; border_dim];
6318                sae_framed_schur_matvec_cpu(
6319                    &sys,
6320                    &data,
6321                    ridge_t,
6322                    ridge_beta,
6323                    v.as_slice().unwrap(),
6324                    &mut tmp,
6325                )
6326                .expect("cpu oracle matvec");
6327                out.assign(&Array1::from_vec(tmp));
6328            };
6329            gam_linalg::pcg::pcg_core(
6330                &mut apply,
6331                &rhs.view(),
6332                &precond.view(),
6333                1e-12,
6334                800,
6335                32,
6336                false,
6337                gam_linalg::pcg::DotReduction::Serial,
6338                &mut cpu.view_mut(),
6339            )
6340        };
6341        let s_cpu_resid = oracle_resid(&cpu);
6342        let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6343
6344        // GATE 1: the device solution solves the CPU-oracle system to PCG-grade
6345        // accuracy (proves device kernel == CPU operator AND device PCG converged).
6346        assert!(
6347            dev_rel_resid <= 1e-7,
6348            "[#1551] device δβ does not solve the CPU-oracle system: \
6349             ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6350             device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6351             means the device matvec is a DIFFERENT operator (kernel bug)",
6352            diag.stopping_reason,
6353            diag.iterations,
6354            diag.final_relative_residual,
6355        );
6356        // GATE 2: the independent CPU iterative solve of the same operator with the
6357        // same preconditioner also converges — both paths agree on the operator.
6358        assert!(
6359            cpu_rel_resid <= 1e-6,
6360            "[#1551] CPU pcg_core failed to solve the oracle system: \
6361             ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6362            cpu_result.stop,
6363            cpu_result.iterations,
6364        );
6365    }
6366
6367    /// Host mirror of the device `sae_frame_penalty_diag_host` for the framed
6368    /// Jacobi preconditioner (penalty diagonal only; the reduced-Schur diagonal
6369    /// subtraction is applied by the caller). Test-only.
6370    fn sae_frame_penalty_diag_host_for_test(
6371        data: &DeviceSaePcgData,
6372        ridge_beta: f64,
6373    ) -> Vec<f64> {
6374        let frame = data.frame.as_ref().expect("frame");
6375        let mut diag = vec![ridge_beta; data.beta_dim];
6376        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6377            let m = blk.factor_a.nrows();
6378            for ia in 0..m {
6379                let coeff = blk.factor_a[[ia, ia]];
6380                let base = blk.global_offset + ia * r;
6381                for ib in 0..r {
6382                    diag[base + ib] += coeff;
6383                }
6384            }
6385        }
6386        for blk in &frame.frame_blocks {
6387            if blk.atom_i != blk.atom_j {
6388                continue;
6389            }
6390            let r = frame.ranks[blk.atom_i];
6391            let off = frame.border_offsets[blk.atom_i];
6392            let (mi, mj) = blk.g.dim();
6393            for li in 0..mi.min(mj) {
6394                let gii = blk.g[[li, li]];
6395                let base = off + li * r;
6396                for a in 0..r {
6397                    diag[base + a] += gii * blk.w[[a, a]];
6398                }
6399            }
6400        }
6401        diag
6402    }
6403
6404    /// #1551 DEFINITIVE kernel-correctness proof: the framed reduced-Schur matvec
6405    /// `out = S·x` must agree with the CPU oracle [`sae_framed_schur_matvec_cpu`]
6406    /// element-wise, for several independent `x`, to ≤ 1e-9. This is the parity
6407    /// gate that actually localizes a kernel/marshalling defect — unlike a
6408    /// solved-`δβ` comparison, it does NOT route through a linear solve, so it is
6409    /// independent of the conditioning of the assembled `S` (a near-singular `S`
6410    /// can make a dense-Cholesky vector and an iterative-PCG vector disagree at
6411    /// the *solution* level even when both operators are bit-correct; the
6412    /// operator itself must still match here). Fails loud if CUDA is present but
6413    /// the device matvec declines; skips cleanly only when no device exists.
6414    #[test]
6415    fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
6416        use crate::arrow_schur::{
6417            DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
6418        };
6419
6420        // Hand-checkable frame fixture: a mix of framed (r<p) and identity-ride
6421        // (r==p) atoms, a few off-diagonal cross blocks, dense per-row H_tβ.
6422        let p = 6usize;
6423        let n_atoms = 8usize;
6424        let ranks: Vec<usize> = (0..n_atoms)
6425            .map(|k| if k % 2 == 0 { 3usize } else { p })
6426            .collect();
6427        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6428        let mut border_offsets = Vec::with_capacity(n_atoms);
6429        let mut acc = 0usize;
6430        for k in 0..n_atoms {
6431            border_offsets.push(acc);
6432            acc += basis_sizes[k] * ranks[k];
6433        }
6434        let border_dim = acc;
6435
6436        let mut state = 0x1551_0017_1026_0922u64;
6437        let mut sample = || -> f64 {
6438            state = state
6439                .wrapping_mul(6364136223846793005)
6440                .wrapping_add(1442695040888963407);
6441            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6442        };
6443        let mut frames: Vec<Array2<f64>> = Vec::new();
6444        for k in 0..n_atoms {
6445            let r = ranks[k];
6446            let mut u = Array2::<f64>::zeros((p, r));
6447            for i in 0..p {
6448                for j in 0..r {
6449                    u[[i, j]] = if r == p && i == j {
6450                        1.0
6451                    } else if r == p {
6452                        0.0
6453                    } else {
6454                        sample()
6455                    };
6456                }
6457            }
6458            frames.push(u);
6459        }
6460        let w_of = |i: usize, j: usize| {
6461            let (ui, uj) = (&frames[i], &frames[j]);
6462            let (ri, rj) = (ranks[i], ranks[j]);
6463            let mut w = Array2::<f64>::zeros((ri, rj));
6464            for a in 0..ri {
6465                for b in 0..rj {
6466                    let mut s = 0.0;
6467                    for c in 0..p {
6468                        s += ui[[c, a]] * uj[[c, b]];
6469                    }
6470                    w[[a, b]] = s;
6471                }
6472            }
6473            w
6474        };
6475        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6476        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6477            pairs.push((i, j));
6478            pairs.push((j, i));
6479        }
6480        let mut frame_blocks = Vec::new();
6481        for &(i, j) in &pairs {
6482            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6483            let mut g = Array2::<f64>::zeros((mi, mj));
6484            for r in 0..mi {
6485                for c in 0..mj {
6486                    g[[r, c]] = 0.25 * sample();
6487                }
6488            }
6489            if i == j {
6490                for r in 0..mi.min(mj) {
6491                    g[[r, r]] += mi as f64 + 2.0;
6492                }
6493            }
6494            frame_blocks.push(FactoredFrameGBlock {
6495                atom_i: i,
6496                atom_j: j,
6497                g,
6498                w: w_of(i, j),
6499            });
6500        }
6501        let mut smooth_blocks = Vec::new();
6502        let mut smooth_ranks = Vec::new();
6503        for k in 0..n_atoms {
6504            let m = basis_sizes[k];
6505            let mut a = Array2::<f64>::zeros((m, m));
6506            for r in 0..m {
6507                for c in 0..m {
6508                    a[[r, c]] = 0.2 * sample();
6509                }
6510            }
6511            let mut s = a.t().dot(&a);
6512            for r in 0..m {
6513                s[[r, r]] += 1.0;
6514            }
6515            smooth_blocks.push(DeviceSaeSmoothBlock {
6516                global_offset: border_offsets[k],
6517                factor_a: s,
6518            });
6519            smooth_ranks.push(ranks[k]);
6520        }
6521        // Modest row count: this seam bypasses the offload floor, so we keep the
6522        // fixture small and the per-row reduced-Schur term well-scaled — the
6523        // matvec parity does not depend on the assembled-S conditioning at all.
6524        let n = 32usize;
6525        let q = 4usize;
6526        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6527        let mut row_htbeta = Vec::new();
6528        for i in 0..n {
6529            let mut a = Array2::<f64>::zeros((q, q));
6530            for r in 0..q {
6531                for c in 0..q {
6532                    a[[r, c]] = sample();
6533                }
6534            }
6535            let mut htt = a.t().dot(&a);
6536            for r in 0..q {
6537                htt[[r, r]] += q as f64 + 1.0;
6538            }
6539            sys.rows[i].htt = htt;
6540            let mut slab = vec![0.0_f64; q * border_dim];
6541            for c in 0..q {
6542                for col in 0..border_dim {
6543                    let v = 0.3 * sample();
6544                    slab[c * border_dim + col] = v;
6545                    sys.rows[i].htbeta[[c, col]] = v;
6546                }
6547            }
6548            row_htbeta.push(slab);
6549        }
6550        let ridge_t = 1e-7;
6551        let ridge_beta = 1e-6;
6552        let data = DeviceSaePcgData {
6553            p,
6554            beta_dim: border_dim,
6555            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6556            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6557            smooth_blocks,
6558            sparse_g_blocks: Vec::new(),
6559            frame: Some(DeviceSaeFrameData {
6560                ranks: ranks.clone(),
6561                basis_sizes: basis_sizes.clone(),
6562                border_offsets: border_offsets.clone(),
6563                frame_blocks,
6564                smooth_ranks,
6565                row_htbeta,
6566            }),
6567        };
6568
6569        // Several independent probe vectors x, including unit axes and dense
6570        // random — a marshalling stride/offset bug shows up as a per-component
6571        // mismatch on at least one.
6572        let mut probes: Vec<Array1<f64>> = Vec::new();
6573        probes.push(Array1::from_shape_fn(border_dim, |a| {
6574            ((a as f64 + 1.0) * 0.37).sin()
6575        }));
6576        probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
6577        for axis in [0usize, border_dim / 3, border_dim - 1] {
6578            let mut e = Array1::<f64>::zeros(border_dim);
6579            e[axis] = 1.0;
6580            probes.push(e);
6581        }
6582
6583        let mut any_ran = false;
6584        let mut worst = 0.0_f64;
6585        for (pi, x) in probes.iter().enumerate() {
6586            let device = match super::framed_schur_matvec_once_on_device(
6587                &sys, &data, ridge_t, ridge_beta, x,
6588            ) {
6589                Ok(out) => out,
6590                Err(failure) => {
6591                    // Fail loud: a present CUDA device that declines this seam
6592                    // (which deliberately ignores the offload floor) means the
6593                    // framed matvec kernel does not run on GPU.
6594                    assert!(
6595                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6596                        "#1551: CUDA device present but the framed device matvec \
6597                         declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
6598                         does not run on GPU"
6599                    );
6600                    return;
6601                }
6602            };
6603            any_ran = true;
6604            let mut cpu = vec![0.0_f64; border_dim];
6605            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, x.as_slice().unwrap(), &mut cpu)
6606                .expect("cpu oracle matvec");
6607            let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6608            for a in 0..border_dim {
6609                let rel = (device[a] - cpu[a]).abs() / scale;
6610                worst = worst.max(rel);
6611                assert!(
6612                    rel <= 1e-9,
6613                    "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
6614                     rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
6615                    device[a],
6616                    cpu[a],
6617                );
6618            }
6619        }
6620        if any_ran {
6621            // Positive on-device confirmation: the framed matvec ran on the GPU
6622            // and matched the CPU oracle across every probe. (1e-9 is far above
6623            // the ~1e-13 fp64 GEMV round-off; a structural marshalling bug would
6624            // be O(1).)
6625            assert!(
6626                gam_gpu::device_runtime::GpuRuntime::global().is_some(),
6627                "#1551: matvec ran but no GPU runtime — unexpected"
6628            );
6629            assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
6630        }
6631    }
6632}