Skip to main content

gam_solve/gpu_kernels/
arrow_schur.rs

1//! Fully GPU-resident batched Arrow-Schur dense Cholesky solver.
2//!
3//! Implements the square-root Schur form: each local block `D_i = L_i L_i^T`
4//! is factored on device, `u_i = L_i^{-1} g_i` and `Y_i = L_i^{-1} B_i` are
5//! formed by triangular solves, the reduced shared system
6//!     `S_β = C + ρ_β I − Σ_i Y_i^T Y_i,  r_β = -g_β + Σ_i Y_i^T u_i`
7//! is assembled on device, factored once, and the back-substitution
8//!     `w_i = u_i + Y_i · δβ,  L_i^T x_i = w_i,  δt_i = -x_i`
9//! is run on device. Only the final `(δt, δβ, log|H|)` triple is downloaded.
10//!
11//! The current caller (Arrow-Schur Newton step inside PIRLS) feeds uniform
12//! local block size `d` and uniform shared width `k`, so the entire pipeline
//! is dispatched as a single p-group; per-p grouping for heterogeneous blocks
14//! is Layer D's NVRTC fused-kernel concern and lives in this module's
15//! follow-up implementation rather than in policy plumbing.
16//!
17//! On non-Linux builds the entire module degrades to a CPU-fallback shim.
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24/// Outcome of a single Arrow-Schur Newton solve.
25pub struct ArrowSchurGpuSolution {
26    pub delta_t: Array1<f64>,
27    pub delta_beta: Array1<f64>,
28    /// Natural log of the determinant of the full bordered Hessian, computed
29    /// from the local Cholesky factors and the Schur factor on device.
30    pub log_det_hessian: f64,
31}
32
/// Why the device path did not produce a solution.
///
/// The host caller dispatches on the variant. Only
/// [`RidgeBumpRequired`](Self::RidgeBumpRequired) asks for a retry at a larger
/// ridge. The capability variants ([`Unavailable`](Self::Unavailable),
/// [`GpuRequiresDenseSystem`](Self::GpuRequiresDenseSystem)) mean "route this
/// solve elsewhere", typically to the CPU lane.
#[derive(Clone, Debug)]
pub enum ArrowSchurGpuFailure {
    /// No usable device: the CUDA runtime is missing, a device allocation
    /// failed, or the workload falls below the dispatch policy.
    Unavailable,
    /// Row block `row` stayed non-positive-definite at the requested ridge.
    /// `bump` estimates the extra diagonal shift needed to clear the failed
    /// pivot; the caller may retry with `ridge_t + bump`.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// The shared Schur factor failed. At the requested ridges the bordered
    /// system is rank-deficient, so escalation is left to the CPU path. Also
    /// carries malformed-input rejections detected before factoring (row-block
    /// dimension mismatch, NaN ridge); `reason` says which.
    SchurFactorFailed { reason: String },
    /// Capability mismatch — NOT a numerical failure. The dense GPU Schur path
    /// cannot consume this system's β-block. Either:
    /// * the system carries matrix-free `H_ββ` / per-row `H_tβ` operators (the
    ///   matching `had_*_matvec` flag is set), or
    /// * the dense `(K×K)` `H_ββ` block is absent or superseded (both flags
    ///   false). An example is an SAE-manifold system whose β-curvature lives
    ///   in a `penalty_op` / factored-frame representation, with `hbb`
    ///   reclaimed to a `0×0` workspace.
    ///
    /// The caller should route to CPU `InexactPCG` (or supply dense buffers)
    /// rather than escalate a ridge. See `gpu/arrow_schur.rs` Part B for the
    /// planned GPU PCG path that will lift this restriction at K ≥ 5000.
    GpuRequiresDenseSystem {
        had_hbb_matvec: bool,
        had_htbeta_matvec: bool,
    },
}
61
/// Headroom multiplier applied to `diag_scale · √ε` when suggesting a ridge bump.
///
/// [`ridge_bump_to_make_pd`] first lifts a block past its Gershgorin deficit,
/// then adds `scale · √ε · RIDGE_BUMP_EPS_MARGIN` on top. Why the extra headroom:
/// a retry at exactly the deficit is routinely re-rejected by the next POTRF,
/// because forming `D + ridge·I` and re-factoring it carries its own rounding
/// error. This margin lets the pivot clear on the first retry.
///
/// For f64, `√ε = 2⁻²⁶`, so the margin is exactly `scale · 2⁻¹⁶ ≈ 1.53e-5 ·
/// scale`. That is small enough not to materially perturb the curvature the
/// Newton step sees. Every per-row, batched and fused producer shares this
/// value so they all suggest a consistent bump.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1_024.0; // 2^10
73
74/// Diagonal ridge bump that is GUARANTEED to make `H_tt + (ridge_t + bump)·I`
75/// positive definite for a *symmetric* per-row block, sized from the block's
76/// own entries rather than from the factorization's pivot index.
77///
78/// # Why the old `scale · |pivot| · √ε · 1024` estimate is wrong
79///
80/// The batched/fused device paths derive the suggested bump from the
81/// factorization "pivot" — but cuSOLVER's `potrf` (and the NVRTC kernel's
82/// status code) report the failing pivot as a **1-based row index**, NOT the
83/// magnitude of the negative pivot. A block that is indefinite by `O(1)`
84/// (e.g. `H_tt = -I`, whose smallest eigenvalue is `-1`) then yields the same
85/// `bump ≈ √ε · 1024 ≈ 1.5e-5` as a block that is indefinite by `O(√ε)`. The
86/// outer LM escalation, which retries at `ridge_t + bump` and grows
87/// geometrically with a bounded step count, can never lift a strongly
88/// indefinite block out of the negative regime, so the solve fails to recover
89/// even though the block is trivially regularizable. (Surfaced by the V100
90/// `ridge_bump_required_on_non_pd_row_recovers_after_bump` validation test.)
91///
92/// # The bound
93///
94/// By the Gershgorin circle theorem every eigenvalue `λ` of the symmetric
95/// matrix `A = H_tt` satisfies, for some row `i`,
96///   `λ ≥ A[i,i] − Σ_{j≠i} |A[i,j]|`,
97/// so `λ_min(A) ≥ min_i ( A[i,i] − Σ_{j≠i} |A[i,j]| ) =: g` (the most negative
98/// Gershgorin left edge). Adding `t·I` shifts every eigenvalue up by `t`, so
99/// `A + t·I` is PD as soon as `t > -g`. We are already sitting at `ridge_t`, so
100/// the ADDITIONAL bump needed is `-(g + ridge_t)` when that is positive. We add
101/// a relative safety margin (`√ε · scale · 1024`, the same headroom the legacy
102/// estimate used) so the re-factored, rounding-perturbed block clears the pivot
103/// on the first retry, and a `max(1)`-scaled floor so a marginally-indefinite
104/// block still gets a strictly positive, non-vanishing bump.
105///
106/// The returned value is the bump to ADD to the current `ridge_t`. It is always
107/// strictly positive (the caller only constructs `RidgeBumpRequired` on an
108/// actual non-PD failure, but the bound is defensive regardless).
109#[must_use]
110fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
111    let d = htt.nrows();
112    // Diagonal magnitude scale (also the legacy `scale`), and the most-negative
113    // Gershgorin left edge `g = min_i (A_ii − Σ_{j≠i} |A_ij|)`.
114    let mut scale = 1.0_f64;
115    let mut min_gershgorin_edge = f64::INFINITY;
116    for i in 0..d {
117        let diag = htt[[i, i]];
118        scale = scale.max(diag.abs());
119        let mut off_sum = 0.0_f64;
120        for j in 0..d {
121            if j != i {
122                off_sum += htt[[i, j]].abs();
123            }
124        }
125        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
126    }
127    if !min_gershgorin_edge.is_finite() {
128        // d == 0 (no rows) or non-finite entries: fall back to the scale-only
129        // floor so the caller still gets a strictly positive bump.
130        return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131    }
132    // Additional shift needed so `λ_min(A) + ridge_t + bump > 0`, i.e.
133    // `bump > -(min_gershgorin_edge + ridge_t)`.
134    let deficit = -(min_gershgorin_edge + ridge_t);
135    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
136    // Lift past the deficit (when positive) plus a rounding margin; never below
137    // the scale-relative floor so a marginal block still moves.
138    deficit.max(0.0) + margin
139}
140
141/// [`ridge_bump_to_make_pd`] for a `d × d` symmetric block stored column-major
142/// in a flat slice with the current ridge ALREADY baked into the diagonal
143/// (the device packers emit `D = H_tt + ridge_t·I` this way). Because the shift
144/// is already present, the Gershgorin bound is taken at `ridge_t = 0` and the
145/// returned value is still the ADDITIONAL bump to add on top of the current
146/// ridge. Returns the scale-only floor when `block` is mis-sized.
147// The column-major bump is a helper for the device tile packers, which only
148// exist on the linux CUDA path (`mod cuda`, `#[cfg(target_os = "linux")]`). It
149// therefore lives where it is used: gate it to linux. Its parity unit test is
150// gated to linux to match (a `test` token in the cfg would trip the build.rs
151// `#[cfg(test)]`-on-a-src-item ban; including `test` to dodge the non-linux
152// dead_code lint is exactly the escape hatch that ban forbids).
153#[cfg(target_os = "linux")]
154#[must_use]
155fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
156    if d == 0 || block.len() < d * d {
157        return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
158    }
159    // Column-major: element (row r, col c) at block[c*d + r]. The matrix is
160    // symmetric, so reading by column gives the same Gershgorin edges as by row.
161    let mut scale = 1.0_f64;
162    let mut min_gershgorin_edge = f64::INFINITY;
163    for i in 0..d {
164        let diag = block[i * d + i];
165        scale = scale.max(diag.abs());
166        let mut off_sum = 0.0_f64;
167        for j in 0..d {
168            if j != i {
169                off_sum += block[j * d + i].abs();
170            }
171        }
172        min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
173    }
174    let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
175    if !min_gershgorin_edge.is_finite() {
176        return margin;
177    }
178    (-min_gershgorin_edge).max(0.0) + margin
179}
180
181/// Entry point: attempt the fully device-resident Arrow-Schur Newton solve.
182/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` to indicate "device path
183/// declined, fall back to CPU" — never panics.
184pub fn solve_arrow_newton_step(
185    sys: &ArrowSchurSystem,
186    ridge_t: f64,
187    ridge_beta: f64,
188) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
189    let n = sys.rows.len();
190    let d = sys.d;
191    let k = sys.k;
192
193    // Detect matrix-free operators before any dim() checks so callers get a
194    // clear, actionable error instead of a generic SchurFactorFailed. The GPU
195    // dense-Schur path requires materialised H_ββ and per-row H_tβ slabs;
196    // CPU InexactPCG is the correct fallback when either operator is abstract.
197    let had_hbb_matvec = sys.hbb_matvec.is_some();
198    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
199    if had_hbb_matvec || had_htbeta_matvec {
200        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
201            had_hbb_matvec,
202            had_htbeta_matvec,
203        });
204    }
205
206    // A `penalty_op` is the AUTHORITATIVE β-curvature source whenever it is
207    // installed: assembly bypasses the dense `hbb` accumulator (and, for
208    // frames-engaged SAE systems, reclaims it to a 0×0 workspace), so whatever
209    // `hbb` survives is STALE relative to the operator. The dense device Schur
210    // path reads ONLY `hbb`, so it must decline here — BEFORE the shape gate
211    // below — even when a stale `(k, k)` `hbb` would pass that gate: proceeding
212    // would silently compute the WRONG Newton step from stale curvature instead
213    // of routing to the CPU matrix-free lane (which reads `penalty_op`) that
214    // returns the correct one. The production caller `try_device_arrow_direct`
215    // already short-circuits this shape, but enforcing it at the entry keeps
216    // EVERY caller covered — the resident / reupload harnesses and any future
217    // direct caller — so no path can reach the device solve with a stale dense
218    // block. Both matvec flags are false here: control only reaches this point
219    // when neither matrix-free operator is installed (returned above), and the
220    // frames-engaged path installs `penalty_op` with no `hbb_matvec` /
221    // `htbeta_matvec`.
222    if sys.penalty_op.is_some() {
223        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
224            had_hbb_matvec: false,
225            had_htbeta_matvec: false,
226        });
227    }
228
229    if sys.hbb.dim() != (k, k) {
230        // The dense (K×K) H_ββ block is absent (e.g. an SAE-manifold system
231        // whose β-curvature is carried by a matrix-free `penalty_op` /
232        // factored-frame representation, with `hbb` reclaimed to a 0×0
233        // workspace at the end of assembly). This is a CAPABILITY decline, not
234        // a numerical failure: the dense device Schur path simply cannot
235        // consume this system, so the host must route it to the CPU lane
236        // (which reads the matrix-free operators) exactly as it does for the
237        // `hbb_matvec` / `htbeta_matvec` case above. Returning `SchurFactorFailed`
238        // here would masquerade as a non-PD/rank-deficient factorization and be
239        // escalated (and ultimately surfaced as a FATAL RemlConvergenceError)
240        // by the outer LM loop instead of falling back. Decline instead. Both
241        // matvec flags are false because the absence is structural (no dense
242        // block was materialized), not caused by an installed matrix-free op.
243        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
244            had_hbb_matvec: false,
245            had_htbeta_matvec: false,
246        });
247    }
248    if n == 0 || d == 0 {
249        return Err(ArrowSchurGpuFailure::Unavailable);
250    }
251    if sys
252        .rows
253        .iter()
254        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
255    {
256        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
257            reason: "row block dimension mismatch".to_string(),
258        });
259    }
260
261    #[cfg(not(target_os = "linux"))]
262    {
263        if ridge_t.is_nan() || ridge_beta.is_nan() {
264            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
265                reason: "ridge is NaN".to_string(),
266            });
267        }
268        Err(ArrowSchurGpuFailure::Unavailable)
269    }
270
271    #[cfg(target_os = "linux")]
272    {
273        // Multi-GPU: the arrow-Schur solve is row-block separable in its forward
274        // (per-row factor / whiten / partial-Schur) and backward (per-row
275        // back-sub) phases — only the small shared K×K reduce+factor+δβ is
276        // central. When more than one device is usable, split the WHOLE solve at
277        // row-block granularity across all GPUs. The POTRF stays fused with its
278        // dependent TRSM+GEMM on each tile's own stream, so no on-stream solve is
279        // orphaned. On `Unavailable` (one device, shape below policy, transient)
280        // fall through to the single-device fused / Layer-A paths below.
281        if gam_gpu::device_runtime::GpuRuntime::global()
282            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
283            .unwrap_or(0)
284            > 1
285        {
286            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
287                Ok(sol) => return Ok(sol),
288                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
289                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
290                }
291                Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
292                    return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
293                }
294                // Unavailable / GpuRequiresDenseSystem: fall through to the
295                // single-device paths (already shape-validated above).
296                Err(_) => {}
297            }
298        }
299        // Layer D admission: when the system shape passes the
300        // (Σ p³ ≥ 1e5 OR R ≥ 16) heuristic and `p ≤ MAX_FUSED_P`, the fused
301        // NVRTC kernel replaces the cuSOLVER/cuBLAS Layer A+B+C path with a
302        // single per-row block. Layer C↔D parity (math block 3 §16 test 6)
303        // requires both paths to agree to 1e-10 on identical inputs.
304        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
305            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
306                Ok(sol) => return Ok(sol),
307                // RidgeBumpRequired must surface to the outer escalation loop —
308                // the fused path's pivot diagnostic is identical in semantics
309                // to the cuSOLVER batched POTRF info code.
310                Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
311                    return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
312                }
313                // Any other failure (Unavailable, SchurFactorFailed) falls
314                // through to the unfused path so a flaky NVRTC compile or
315                // shared-mem allocation does not abort the outer Newton step.
316                Err(_) => {}
317            }
318        }
319        cuda::solve(sys, ridge_t, ridge_beta)
320    }
321}
322
323/// Build the stacked column-major D buffer (n local d×d blocks), the stacked
324/// stacked B buffer (n local d×k blocks), and the stacked g buffer
325/// (n local d-vectors) consumed by the device pipeline. Each block is laid
326/// out column-major so a single allocation + `cuMemcpyHtoD` reaches the
327/// device without per-row dispatch overhead.
328#[cfg(target_os = "linux")]
329fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
330    let n = sys.rows.len();
331    let d = sys.d;
332    let k = sys.k;
333    let mut d_buf = Vec::with_capacity(n * d * d);
334    let mut b_buf = Vec::with_capacity(n * d * k);
335    let mut g_buf = Vec::with_capacity(n * d);
336    for row in &sys.rows {
337        pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
338    }
339    (d_buf, b_buf, g_buf)
340}
341
342#[cfg(target_os = "linux")]
343#[inline]
344fn pack_block(
345    row: &crate::arrow_schur::ArrowRowBlock,
346    ridge_t: f64,
347    d: usize,
348    k: usize,
349    d_buf: &mut Vec<f64>,
350    b_buf: &mut Vec<f64>,
351    g_buf: &mut Vec<f64>,
352) {
353    for col in 0..d {
354        for r in 0..d {
355            let mut value = row.htt[[r, col]];
356            if r == col {
357                value += ridge_t;
358            }
359            d_buf.push(value);
360        }
361    }
362    for col in 0..k {
363        for r in 0..d {
364            b_buf.push(row.htbeta[[r, col]]);
365        }
366    }
367    for r in 0..d {
368        g_buf.push(row.gt[r]);
369    }
370}
371
372/// Test-only entry that forces the Layer D + E fused NVRTC path regardless
373/// of the admission heuristic. Used by the V100 Layer C↔D parity test to
374/// drive the fused kernel at small shapes the heuristic would otherwise
375/// route through the cuSOLVER/cuBLAS Layer A+B+C path.
376#[doc(hidden)]
377#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] // `sys` is consumed only by the linux branch
378pub fn solve_arrow_newton_step_fused_force(
379    sys: &ArrowSchurSystem,
380    ridge_t: f64,
381    ridge_beta: f64,
382) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
383    if ridge_t.is_nan() || ridge_beta.is_nan() {
384        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
385            reason: "ridge is NaN".to_string(),
386        });
387    }
388    #[cfg(not(target_os = "linux"))]
389    {
390        // No NVRTC toolchain off linux: the fused path is unconditionally
391        // unavailable. `sys` is consumed only by the linux branch below; the
392        // fn-level cfg_attr allows it to read as unused here without a banned
393        // `let _` binding or a no-op `drop` of the reference.
394        Err(ArrowSchurGpuFailure::Unavailable)
395    }
396    #[cfg(target_os = "linux")]
397    {
398        if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
399            .is_none()
400        {
401            return Err(ArrowSchurGpuFailure::Unavailable);
402        }
403        cuda::solve_fused(sys, ridge_t, ridge_beta)
404    }
405}
406
407/// #1017 Phase 3: a device-resident Arrow-Schur frame whose constant Hessian
408/// blocks (`D = H_tt`, `B = H_tβ`, border `H_ββ`) and their factors stay on the
409/// device across the inner Newton loop. Construct once per frozen gate/basis
410/// frame, then call [`ResidentArrowFrameHandle::solve_gradient`] once per
411/// iterate with the fresh residual gradient — only the `O(n·d + p)` gradient
412/// crosses to the device and only `δ` crosses back, in contrast to
413/// [`solve_arrow_newton_step`] which re-uploads and re-factors the full system
414/// every call. On a non-CUDA host construction returns
415/// `ArrowSchurGpuFailure::Unavailable`.
416pub struct ResidentArrowFrameHandle {
417    #[cfg(target_os = "linux")]
418    inner: cuda::ResidentArrowFrame,
419    #[cfg(not(target_os = "linux"))]
420    _never: std::convert::Infallible,
421}
422
423impl ResidentArrowFrameHandle {
424    /// Upload the constant Hessian blocks and perform the one-time factor work.
425    pub fn new(
426        sys: &ArrowSchurSystem,
427        ridge_t: f64,
428        ridge_beta: f64,
429    ) -> Result<Self, ArrowSchurGpuFailure> {
430        // The dense device path requires materialised blocks, same admission as
431        // `solve_arrow_newton_step`.
432        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
433            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
434                had_hbb_matvec: sys.hbb_matvec.is_some(),
435                had_htbeta_matvec: sys.htbeta_matvec.is_some(),
436            });
437        }
438        #[cfg(not(target_os = "linux"))]
439        {
440            if ridge_t.is_nan() || ridge_beta.is_nan() {
441                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
442                    reason: "ridge is NaN".to_string(),
443                });
444            }
445            Err(ArrowSchurGpuFailure::Unavailable)
446        }
447        #[cfg(target_os = "linux")]
448        {
449            Ok(Self {
450                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
451            })
452        }
453    }
454
455    /// Solve `H δ = −gradient` for a fresh gradient reusing the resident factors.
456    pub fn solve_gradient(
457        &self,
458        g_t: &[f64],
459        g_beta: &[f64],
460    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
461        #[cfg(not(target_os = "linux"))]
462        {
463            if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
464                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
465                    reason: "non-finite gradient entry".to_string(),
466                });
467            }
468            Err(ArrowSchurGpuFailure::Unavailable)
469        }
470        #[cfg(target_os = "linux")]
471        {
472            self.inner.solve_gradient(g_t, g_beta)
473        }
474    }
475
476    /// `log|H|` for the frame (constant; depends only on the factored Hessian).
477    #[must_use]
478    pub fn log_det_hessian(&self) -> f64 {
479        #[cfg(not(target_os = "linux"))]
480        {
481            // SAFETY: off-CUDA, `ResidentArrowFrameHandle::new` always returns
482            // `Err(Unavailable)`, so no handle of this type is ever constructed and
483            // this method is statically unreachable on non-Linux targets. A NaN
484            // sentinel would silently corrupt any consumer of the log-determinant,
485            // so fail loudly on the impossible path instead.
486            panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
487        }
488        #[cfg(target_os = "linux")]
489        {
490            self.inner.log_det_hessian()
491        }
492    }
493}
494
495/// Build a GPU-backed Schur matvec closure for CPU-driven PCG at K ≥ 5000.
496///
497/// Runs the fused NVRTC forward kernel once on the dense per-row `H_tβ` slabs
498/// to compute `Y_i = L_i^{-1} H_tβ^(i)` for all rows, persists the `Y_i`
499/// factors in a host-side buffer, and returns an `Arc<dyn Fn(...)>` closure
500/// that computes the full Schur matvec
501///
502/// ```text
503/// S·x = (H_ββ + ridge_beta·I)·x  −  Σ_i Y_i^T (Y_i·x)
504/// ```
505///
506/// each time it is called. At K ≥ 5000 the `Σ_i Y_i^T (Y_i·x)` term
507/// dominates over the host↔device transfer of the K-vector `x`, so the GPU
508/// path is a clear win even with per-iteration transfer.
509///
510/// `H_ββ·x` is evaluated on the CPU using `sys.hbb_matvec` when present (the
511/// matrix-free hook for SAE-manifold scale callers) or the dense `sys.hbb`
512/// block otherwise. The `Y_i` term uses cuBLAS batched GEMV device-side; only
513/// `x` (K doubles) and `out` (K doubles) cross the host↔device boundary per
514/// PCG iteration.
515///
516/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` if CUDA is unavailable or
517/// the system shape is outside the fused kernel's admission range (e.g.
518/// `d > MAX_FUSED_P = 32` or no CUDA context). Callers should fall back to CPU
519/// `InexactPCG` on `Unavailable`.
520///
521/// Returns `Err(ArrowSchurGpuFailure::RidgeBumpRequired)` if a per-row Cholesky
522/// factor failed at the requested `ridge_t`; the outer LM escalation should
523/// bump `ridge_t` and retry.
524///
525/// # Composition with the matrix-free SAE Kronecker operator
526///
527/// When `sys.htbeta_matvec` is set (matrix-free `H_tβ` Kronecker operator),
528/// the dense `H_tβ` slabs are absent — the dense forward kernel above cannot
529/// run, and at `K = 100K` the dense `Y_i = L_i^{-1} H_tβ^(i)` (`d × K` per row)
530/// could not be materialised anyway. Instead, `build_row_procedural_matvec`
531/// returns a row-procedural Schur matvec: per row it gathers
532/// `v_i = H_tβ^(i)·x` through the forward operator (sparse `O(m_i · p)`),
533/// solves `(H_tt^(i) + ρ_t·I)^{-1} v_i` through a pre-computed per-row Cholesky
534/// factor, and scatters `H_βt^(i)·w_i` through the sparse transpose operator
535/// (`O(m_i · p)`, replacing the old `O(K)` column-probe). This is the
536/// row-procedural `a_ik · Φ_k[i,m]` Kronecker apply over the active atoms only.
537pub fn gpu_schur_matvec_backend(
538    sys: &ArrowSchurSystem,
539    ridge_t: f64,
540    ridge_beta: f64,
541) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
542    // Matrix-free H_tβ operator present: drive the row-procedural sparse
543    // Kronecker apply (active atoms only) instead of the dense forward kernel.
544    if sys.htbeta_matvec.is_some() {
545        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
546    }
547
548    #[cfg(not(target_os = "linux"))]
549    {
550        // No CUDA runtime on non-Linux. NaN ridges are validated to ensure the
551        // same contract as the Linux path.
552        if ridge_t.is_nan() || ridge_beta.is_nan() {
553            return Err(ArrowSchurGpuFailure::Unavailable);
554        }
555        Err(ArrowSchurGpuFailure::Unavailable)
556    }
557
558    #[cfg(target_os = "linux")]
559    {
560        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
561    }
562}
563
564/// Build a row-procedural reduced-Schur matvec for matrix-free SAE Kronecker
565/// systems, eliminating the per-row latent block via cached per-row Cholesky
566/// factors and applying the cross-block through the sparse forward/transpose
567/// Kronecker operators (active atoms only).
568///
569/// The returned closure evaluates
570/// `S·x = (H_ββ + ρ_β·I)·x − Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x`,
571/// the same reduced Schur complement the dense path forms, but never
572/// materialises the `d × K` cross-block `H_tβ^(i)`: the forward operator
573/// (`out = H_tβ^(i)·x`) and transpose operator (`out += H_βt^(i)·v`) are the
574/// sparse Kronecker gather/scatter from `SaeKroneckerRows`. The per-row factor
575/// of `H_tt^(i) + ρ_t·I` is computed once when the closure is built and reused
576/// across every CG iteration.
577///
578/// Returns `RidgeBumpRequired` if a per-row block is not positive definite at
579/// the requested `ridge_t`; the outer LM escalation bumps `ridge_t` and retries.
580fn build_row_procedural_matvec(
581    sys: &ArrowSchurSystem,
582    ridge_t: f64,
583    ridge_beta: f64,
584) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
585    use std::sync::Arc;
586    let n = sys.rows.len();
587    let k = sys.k;
588    let forward = sys
589        .htbeta_matvec
590        .clone()
591        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
592    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
593        // A forward operator without its sparse adjoint cannot be applied
594        // row-procedurally; this is a wiring error, surfaced as a Schur failure
595        // so the caller routes to the dense CPU path rather than misreporting a
596        // numerical bump.
597        ArrowSchurGpuFailure::SchurFactorFailed {
598            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
599                     forward operator installed without its sparse adjoint"
600                .to_string(),
601        }
602    })?;
603
604    // Pre-factor each per-row block H_tt^(i) + ρ_t·I = L_i L_iᵀ on the host.
605    // The blocks are tiny (d_i ≲ 32) and the dense cross-block slabs are
606    // absent, so there is no device forward-kernel work to amortise here; the
607    // GPU win is the reduced K-system solve in `solve_reduced_beta_pcg`.
608    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
609    for (i, row) in sys.rows.iter().enumerate() {
610        let di = row.htt.nrows();
611        if row.htt.ncols() != di || row.gt.len() != di {
612            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
613                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
614            });
615        }
616        let mut block = row.htt.clone();
617        for r in 0..di {
618            block[[r, r]] += ridge_t;
619        }
620        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
621            .ok_or_else(|| {
622                // Deficit-aware bump from the block's own entries (Gershgorin),
623                // so the outer LM escalation lifts a strongly-indefinite block
624                // out of the negative regime in one retry.
625                ArrowSchurGpuFailure::RidgeBumpRequired {
626                    row: i,
627                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
628                }
629            })?;
630        factors.push(factor);
631    }
632
633    // The SAE-manifold β-Hessian lives in the structured penalty operator
634    // (data-fit Gauss-Newton `G ⊗ I_p` + smoothness Kronecker blocks + any
635    // dense analytic-β residual), NOT in the dense `hbb` accumulator — for
636    // matrix-free systems `hbb` is zero/absent. Capture the effective penalty
637    // operator so `H_ββ·x` matches the CPU `schur_matvec` path exactly. The
638    // operator's `matvec` adds (`y += P x`), so seed `out` from the ridge term.
639    let penalty_op = sys.effective_penalty_op();
640    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
641
642    let closure: crate::arrow_schur::GpuSchurMatvec =
643        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
644            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
645            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
646
647            // (H_ββ + ρ_β·I)·x into out. Seed with the ridge term, then add the
648            // structured penalty-side product (penalty_op.matvec is additive).
649            {
650                let x_slice = x.as_slice().expect("x must be contiguous");
651                let out_slice = out.as_slice_mut().expect("out must be contiguous");
652                for a in 0..k {
653                    out_slice[a] = ridge_beta * x_slice[a];
654                }
655                penalty_op.matvec(x_slice, out_slice);
656            }
657
658            // out -= Σ_i H_βt^(i) (H_tt^(i) + ρ_t·I)^{-1} H_tβ^(i)·x.
659            //
660            // #1017: this row-procedural reduced-Schur term is the matrix-free
661            // SAE path's matvec hot loop (`build_row_procedural_matvec` is the
662            // host backend `gpu_schur_matvec_backend` returns when the dense
663            // `H_tβ` slabs are absent — the production Qwen shape). At
664            // (n≈2000 rows) it ran SERIALLY on one core and allocated a fresh
665            // length-`K` `neg` plus per-row `v_i`/`w_i` on EVERY CG iteration —
666            // tens of thousands of tiny heap allocations across a solve. Each
667            // row contributes an independent length-`K` scatter, so the sum is
668            // embarrassingly parallel; fan it across rayon over fixed row chunks
669            // and fold the per-chunk length-`K` partials in chunk order so the
670            // f64 reduction is deterministic (bit-identical run-to-run)
671            // regardless of thread scheduling — it agrees with the serial sum up
672            // to ULP-scale chunk reassociation (the #1017 verification gate).
673            // Because that reassociation is a real (if tiny) departure from
674            // serial, the criterion ranking across topology candidates is stable
675            // except for candidates separated by less than the reassociation
676            // margin, where the near-tie winner can flip — not an exact no-move
677            // guarantee (#1211). Stay
678            // sequential below
679            // `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and when already inside a
680            // rayon worker (the topology race fans candidates with
681            // `run_topology_race_parallel`) — the same nested-rayon guard the
682            // CPU `schur_matvec` uses. Buffers (`v_i`, `neg`) are reused across
683            // rows within a chunk, so the per-row allocation churn is gone.
684            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
685                && rayon::current_thread_index().is_none();
686            if parallel {
687                use rayon::prelude::*;
688                const CHUNK: usize = 64;
689                let partials: Vec<Array1<f64>> = (0..n)
690                    .into_par_iter()
691                    .chunks(CHUNK)
692                    .map(|idxs| {
693                        // One length-`K` scatter accumulator per chunk; the
694                        // per-row latent vector `v_i` (length `d_i ≲ 32`) is the
695                        // only per-row buffer, sized to the row's own `d_i`.
696                        let mut neg = Array1::<f64>::zeros(k);
697                        for i in idxs {
698                            let di = row_dims[i];
699                            // v_i = H_tβ^(i)·x (sparse Kronecker gather).
700                            let mut v_i = Array1::<f64>::zeros(di);
701                            forward(i, x.view(), &mut v_i);
702                            // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
703                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
704                            // neg += H_βt^(i)·w_i (sparse scatter).
705                            transpose(i, w_i.view(), &mut neg);
706                        }
707                        neg
708                    })
709                    .collect();
710                // #1017/#1175 floating-point parity contract: Rayon may
711                // schedule chunks on any worker, but `.chunks(CHUNK).collect()`
712                // returns partials in chunk-index order. Each chunk's row sum
713                // is formed locally in increasing row order, then chunk
714                // partials are folded left-to-right below. That makes the
715                // parallel row-procedural Schur term deterministic for a fixed
716                // input and chunking (no scheduling-dependent gather/scatter
717                // reordering), but it is not required to be bit-identical to
718                // the serial path because additions are reassociated at chunk
719                // boundaries. CPU/GPU validation should therefore allow
720                // ULP-scale drift while expecting stable run-to-run results.
721                let mut neg = Array1::<f64>::zeros(k);
722                for part in &partials {
723                    for a in 0..k {
724                        neg[a] += part[a];
725                    }
726                }
727                for a in 0..k {
728                    out[a] -= neg[a];
729                }
730            } else {
731                // Serial path: reuse one `neg` and one `v_i` across rows.
732                let mut neg = Array1::<f64>::zeros(k);
733                for i in 0..n {
734                    let di = row_dims[i];
735                    // v_i = H_tβ^(i)·x (sparse Kronecker gather, length d_i).
736                    let mut v_i = Array1::<f64>::zeros(di);
737                    forward(i, x.view(), &mut v_i);
738                    // w_i = (H_tt^(i) + ρ_t·I)^{-1} v_i via L_i L_iᵀ.
739                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
740                    // neg += H_βt^(i)·w_i (sparse scatter); subtract once at end.
741                    transpose(i, w_i.view(), &mut neg);
742                }
743                for a in 0..k {
744                    out[a] -= neg[a];
745                }
746            }
747        });
748
749    Ok(closure)
750}
751
752/// Solve the reduced shared β-system `S·δβ = r` fully on device with a
753/// Jacobi-preconditioned conjugate-gradient (Steihaug truncated-CG) loop.
754///
755/// `S` is the already-reduced symmetric positive-definite `K × K` Schur
756/// complement the streaming SAE joint fit accumulates across minibatches
757/// (`StreamingArrowSchur::take_accumulators` summed over chunks, with the
758/// global β ridge folded in). The per-row latent blocks have already been
759/// eliminated into `S` on the host streaming path; the device's job is the
760/// dense `K`-dimensional solve, which is the dominant cost at `K = 100K`.
761///
762/// The dense `S·p` matvec runs on device via cuBLAS `Dgemv`, and the PCG state
763/// vectors (`x`, `r`, `z`, `p`, `S·p`) remain device-resident for the solve.
764/// Jacobi preconditioning is an elementwise CUDA kernel; only convergence
765/// scalars (`pᵀSp`, `rᵀz`, `‖r‖`) cross the host boundary per iteration, plus the
766/// final solution vector.
767///
768/// Returns `Err(ArrowSchurGpuFailure::Unavailable)` when CUDA is unavailable
769/// or the workload is below the dispatch policy; the caller then runs the CPU
770/// reduced-β solve. Returns `Err(ArrowSchurGpuFailure::SchurFactorFailed)`
771/// when `S` carries a non-positive Jacobi diagonal (caller escalates the
772/// proximal ridge).
773pub fn solve_reduced_beta_pcg(
774    s_acc: &Array2<f64>,
775    rhs_beta: &Array1<f64>,
776    max_iterations: usize,
777    relative_tolerance: f64,
778) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
779    solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
780        .map(|(x, _)| x)
781}
782
783#[doc(hidden)]
784pub fn solve_reduced_beta_pcg_with_diagnostics(
785    s_acc: &Array2<f64>,
786    rhs_beta: &Array1<f64>,
787    max_iterations: usize,
788    relative_tolerance: f64,
789) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
790    let k = rhs_beta.len();
791    if s_acc.dim() != (k, k) {
792        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
793            reason: format!(
794                "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
795                s_acc.dim()
796            ),
797        });
798    }
799    if k == 0 {
800        return Err(ArrowSchurGpuFailure::Unavailable);
801    }
802
803    #[cfg(not(target_os = "linux"))]
804    {
805        if relative_tolerance.is_nan() || max_iterations == 0 {
806            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
807                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
808            });
809        }
810        Err(ArrowSchurGpuFailure::Unavailable)
811    }
812
813    #[cfg(target_os = "linux")]
814    {
815        cuda::solve_reduced_beta_pcg_with_diagnostics(
816            s_acc,
817            rhs_beta,
818            max_iterations,
819            relative_tolerance,
820        )
821    }
822}
823
824pub fn solve_sae_matrix_free_pcg(
825    sys: &ArrowSchurSystem,
826    data: &DeviceSaePcgData,
827    ridge_t: f64,
828    ridge_beta: f64,
829    rhs_beta: &Array1<f64>,
830    max_iterations: usize,
831    relative_tolerance: f64,
832) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
833    if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
834        return Err(ArrowSchurGpuFailure::Unavailable);
835    }
836    #[cfg(not(target_os = "linux"))]
837    {
838        if ridge_t.is_nan()
839            || ridge_beta.is_nan()
840            || relative_tolerance.is_nan()
841            || max_iterations == 0
842        {
843            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
844                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
845            });
846        }
847        Err(ArrowSchurGpuFailure::Unavailable)
848    }
849    #[cfg(target_os = "linux")]
850    {
851        // #1017/#1026 dispatch GUARD: framed data (frame metadata present) carries
852        // a factored β border `G ⊗ W_{ij}` data Hessian and dense per-row cross
853        // blocks the legacy `⊗ I_p` kernel CANNOT represent — feeding it framed
854        // data would silently return a WRONG Newton step (it returns Ok with no
855        // fallback). Route framed systems to the dedicated framed kernel and
856        // legacy full-`B` systems to the legacy kernel; the two never cross.
857        if data.frame.is_some() {
858            cuda::solve_sae_matrix_free_pcg_framed(
859                sys,
860                data,
861                ridge_t,
862                ridge_beta,
863                rhs_beta,
864                max_iterations,
865                relative_tolerance,
866            )
867        } else {
868            cuda::solve_sae_matrix_free_pcg(
869                sys,
870                data,
871                ridge_t,
872                ridge_beta,
873                rhs_beta,
874                max_iterations,
875                relative_tolerance,
876            )
877        }
878    }
879}
880
881/// #1551 kernel-isolating parity probe: run the framed reduced-Schur matvec
882/// `out = S·x` exactly once on the device and return it (no PCG, no offload-floor
883/// gate). The test suite diffs this element-wise against the CPU oracle
884/// [`sae_framed_schur_matvec_cpu`] to prove the GPU kernel computes the SAME
885/// operator — a check that is independent of solver conditioning (unlike a
886/// solved-`δβ` comparison, which can diverge purely because dense Cholesky and
887/// iterative PCG resolve an ill-conditioned `S` to different accuracies). On a
888/// non-CUDA host this returns `Unavailable` so the caller skips cleanly.
889#[doc(hidden)]
890pub fn framed_schur_matvec_once_on_device(
891    sys: &ArrowSchurSystem,
892    data: &DeviceSaePcgData,
893    ridge_t: f64,
894    ridge_beta: f64,
895    x: &Array1<f64>,
896) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
897    if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
898        return Err(ArrowSchurGpuFailure::Unavailable);
899    }
900    if data.frame.is_none() {
901        return Err(ArrowSchurGpuFailure::Unavailable);
902    }
903    #[cfg(not(target_os = "linux"))]
904    {
905        // CUDA is linux-only; the framed device solve is unavailable here. Read
906        // the ridge params (a real, cheap use) so the values the shared signature
907        // carries are not flagged unused on this target — without an `#[allow]`
908        // or `let _` dodge, both of which the build.rs scanner bans.
909        if ridge_t.is_finite() && ridge_beta.is_finite() {
910            return Err(ArrowSchurGpuFailure::Unavailable);
911        }
912        Err(ArrowSchurGpuFailure::Unavailable)
913    }
914    #[cfg(target_os = "linux")]
915    {
916        cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
917    }
918}
919
920/// Reference dense back-end used by tests and as the fallback when the
921/// GPU declines. Kept here (not in `arrow_schur_gpu.rs`) so the validation
922/// suite has one canonical baseline.
923#[doc(hidden)]
924pub fn solve_arrow_newton_step_dense_reference(
925    sys: &ArrowSchurSystem,
926    ridge_t: f64,
927    ridge_beta: f64,
928) -> Result<ArrowSchurGpuSolution, String> {
929    let n = sys.rows.len();
930    let d = sys.d;
931    let k = sys.k;
932    let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
933    let mut h = Array2::<f64>::zeros((total, total));
934    let mut rhs = Array1::<f64>::zeros(total);
935    for (i, row) in sys.rows.iter().enumerate() {
936        let base = i * d;
937        for c in 0..d {
938            for r in 0..d {
939                h[[base + r, base + c]] = row.htt[[r, c]];
940            }
941            h[[base + c, base + c]] += ridge_t;
942        }
943        for c in 0..k {
944            for r in 0..d {
945                let value = row.htbeta[[r, c]];
946                h[[base + r, n * d + c]] = value;
947                h[[n * d + c, base + r]] = value;
948            }
949        }
950        for r in 0..d {
951            rhs[base + r] = -row.gt[r];
952        }
953    }
954    for c in 0..k {
955        for r in 0..k {
956            h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
957        }
958        h[[n * d + c, n * d + c]] += ridge_beta;
959        rhs[n * d + c] = -sys.gb[c];
960    }
961    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
962        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
963    let mut log_det = 0.0_f64;
964    for i in 0..total {
965        log_det += factor[[i, i]].ln();
966    }
967    log_det *= 2.0;
968    let solved = cholesky_solve_vector(factor.view(), rhs.view());
969    let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
970    let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
971    Ok(ArrowSchurGpuSolution {
972        delta_t,
973        delta_beta,
974        log_det_hessian: log_det,
975    })
976}
977
978/// Frames-engaged reduced-Schur penalty-side matvec `out = (P_ββ + ρ_β I)·x`,
979/// computed purely from the factored device data (issue #1017/#1026). This is
980/// the CPU bit-parity ORACLE for the GPU `arrow_sae_*` penalty kernels on the
981/// frames path: smooth `λ S_k ⊗ I_{r_k}` (each `smooth_blocks[i]` at its
982/// `global_offset` with right-width `frame.smooth_ranks[i]`) plus data-fit
983/// `G_{ij} ⊗ W_{ij}` (each `frame.frame_blocks` entry, with the `μ`-major /
984/// frame-minor index `border_offset[atom] + basis·r + frame_coord`). The
985/// accumulation order matches the device kernels exactly.
986///
987/// `out` is OVERWRITTEN: first set to `ρ_β·x`, then the penalty blocks add in.
988#[doc(hidden)]
989pub fn sae_framed_penalty_matvec_cpu(
990    data: &DeviceSaePcgData,
991    ridge_beta: f64,
992    x: &[f64],
993    out: &mut [f64],
994) {
995    let frame = data
996        .frame
997        .as_ref()
998        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
999    let k = data.beta_dim;
1000    for a in 0..k {
1001        out[a] = ridge_beta * x[a];
1002    }
1003    // Smooth penalty `λ S_k ⊗ I_{r_k}`: y[off + ia·r + ib] += Σ_ja S[ia,ja]·x[off + ja·r + ib].
1004    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
1005        let off = blk.global_offset;
1006        let m = blk.factor_a.nrows();
1007        for i_a in 0..m {
1008            for i_b in 0..r {
1009                let mut acc = 0.0_f64;
1010                for j_a in 0..m {
1011                    let s = blk.factor_a[[i_a, j_a]];
1012                    if s == 0.0 {
1013                        continue;
1014                    }
1015                    acc += s * x[off + j_a * r + i_b];
1016                }
1017                out[off + i_a * r + i_b] += acc;
1018            }
1019        }
1020    }
1021    // Data-fit penalty `G_{ij} ⊗ W_{ij}`.
1022    for blk in &frame.frame_blocks {
1023        let r_i = frame.ranks[blk.atom_i];
1024        let r_j = frame.ranks[blk.atom_j];
1025        let off_i = frame.border_offsets[blk.atom_i];
1026        let off_j = frame.border_offsets[blk.atom_j];
1027        let (m_i, m_j) = blk.g.dim();
1028        for li in 0..m_i {
1029            let yi_base = off_i + li * r_i;
1030            for lj in 0..m_j {
1031                let g = blk.g[[li, lj]];
1032                if g == 0.0 {
1033                    continue;
1034                }
1035                let xj_base = off_j + lj * r_j;
1036                for a in 0..r_i {
1037                    let mut acc = 0.0_f64;
1038                    for b in 0..r_j {
1039                        acc += blk.w[[a, b]] * x[xj_base + b];
1040                    }
1041                    out[yi_base + a] += g * acc;
1042                }
1043            }
1044        }
1045    }
1046}
1047
1048/// Frames-engaged FULL reduced-Schur matvec `out = S·x` purely from the device
1049/// data, where `S = (P_ββ + ρ_β I) − Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)`
1050/// (issue #1017/#1026). The penalty side is [`sae_framed_penalty_matvec_cpu`];
1051/// the per-row reduced term reads the dense `frame.row_htbeta[i]`
1052/// (`q_i × border_dim`, row-major), solves against the row's
1053/// `H_tt^(i)+ρ_t I` Cholesky factor, and scatters the transpose back. This is
1054/// the size-independent bit-parity oracle the device kernel mirrors; it is also
1055/// the matvec the GPU PCG iterates.
1056#[doc(hidden)]
1057pub fn sae_framed_schur_matvec_cpu(
1058    sys: &ArrowSchurSystem,
1059    data: &DeviceSaePcgData,
1060    ridge_t: f64,
1061    ridge_beta: f64,
1062    x: &[f64],
1063    out: &mut [f64],
1064) -> Result<(), String> {
1065    let frame = data
1066        .frame
1067        .as_ref()
1068        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1069    let k = data.beta_dim;
1070    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1071    if frame.row_htbeta.len() != sys.rows.len() {
1072        return Err(format!(
1073            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1074            frame.row_htbeta.len(),
1075            sys.rows.len()
1076        ));
1077    }
1078    for (i, row) in sys.rows.iter().enumerate() {
1079        let slab = &frame.row_htbeta[i];
1080        if slab.is_empty() {
1081            continue;
1082        }
1083        let qi = sys.row_dims[i];
1084        if qi == 0 || slab.len() != qi * k {
1085            continue;
1086        }
1087        // h = H_tβ^(i) · x  (length q_i).
1088        let mut h = vec![0.0_f64; qi];
1089        for c in 0..qi {
1090            let base = c * k;
1091            let mut acc = 0.0_f64;
1092            for a in 0..k {
1093                acc += slab[base + a] * x[a];
1094            }
1095            h[c] = acc;
1096        }
1097        // solve (H_tt^(i)+ρ_t I) s = h.
1098        let mut block = row.htt.clone();
1099        for d in 0..qi {
1100            block[[d, d]] += ridge_t;
1101        }
1102        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1103            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1104        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1105        // out -= H_βt^(i) · s = (H_tβ^(i))ᵀ · s.
1106        for c in 0..qi {
1107            let sc = s[c];
1108            if sc == 0.0 {
1109                continue;
1110            }
1111            let base = c * k;
1112            for a in 0..k {
1113                out[a] -= slab[base + a] * sc;
1114            }
1115        }
1116    }
1117    Ok(())
1118}
1119
1120#[cfg(target_os = "linux")]
1121mod cuda {
1122    use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1123    use gam_gpu::driver::to_i32;
1124    use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1125    use crate::arrow_schur::{
1126        ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1127    };
1128    use cudarc::cublas::sys::{
1129        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1130    };
1131    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1132    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1133    use cudarc::driver::{
1134        CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1135        PushKernelArg,
1136    };
1137    use ndarray::Array1;
1138    use std::sync::{Arc, OnceLock};
1139
1140    /// Per-row work slot for the row-block-granular multi-GPU solve. Inputs are
1141    /// the packed single-row buffers (`d×d` D block + ρ_t ridge, `d×k` B block,
1142    /// `d` g vector); the forward pass fills the whitened factors `l/u/y` and the
1143    /// per-tile reduction lands in the tile's leading slot.
1144    struct RowSlot {
1145        // Inputs (packed once on the host, column-major).
1146        d_block: Vec<f64>, // d*d
1147        b_block: Vec<f64>, // d*k
1148        g_vec: Vec<f64>,   // d
1149        // Forward outputs, kept on the host for the back-sub pass.
1150        l_block: Vec<f64>, // d*d lower factor, column-major
1151        u_vec: Vec<f64>,   // d   (= L^{-1} g)
1152        y_block: Vec<f64>, // d*k (= L^{-1} B), column-major
1153        log_det_local: f64,
1154        // Set on a non-PD pivot so the orchestrator can raise RidgeBumpRequired
1155        // for the offending global row instead of silently falling back.
1156        bump: Option<f64>,
1157        // Tile-level reduction, written into the tile's first slot only.
1158        tile_partial_schur: Option<Vec<f64>>, // k*k col-major, = Σ Y_iᵀY_i
1159        tile_partial_rhs: Option<Vec<f64>>,   // k, = Σ Y_iᵀu_i
1160        // Back-sub output for this row.
1161        delta_t_block: Vec<f64>, // d
1162    }
1163
1164    /// Row-block-granular multi-GPU Arrow-Schur Newton solve.
1165    ///
1166    /// The solve is separable across row blocks in both phases:
1167    ///   * forward — each row's local Cholesky `L_i`, whitening
1168    ///     `u_i = L_i⁻¹g_i`, `Y_i = L_i⁻¹B_i`, and partial Schur
1169    ///     `(Σ Y_iᵀY_i, Σ Y_iᵀu_i)` are independent;
1170    ///   * backward — `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)` is independent.
1171    /// Only the small shared `K×K` reduce + factor + `δβ` solve is central.
1172    ///
1173    /// `gam_gpu::pool::scatter_batched` hands each device a contiguous row
1174    /// tile on its own bound context/stream; the per-tile forward keeps the
1175    /// POTRF fused with its dependent TRSM + Schur GEMM on that one stream, so no
1176    /// on-stream solve is orphaned. Tile partials and per-tile `log|L|` are
1177    /// reduced on the host (in tile/row order), `S_β` is factored on the primary
1178    /// device, and the back-sub is scattered back across the same tiles.
1179    ///
1180    /// Returns `Unavailable` (caller uses a single-device path) when the system
1181    /// carries matrix-free operators, the shared block is not dense `K×K`, the
1182    /// pool is single-device, or any tile's device work declines. A non-PD tip
1183    /// block surfaces as `RidgeBumpRequired` for the precise global row.
1184    pub(super) fn solve_multi_gpu(
1185        sys: &ArrowSchurSystem,
1186        ridge_t: f64,
1187        ridge_beta: f64,
1188    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1189        let n = sys.rows.len();
1190        let d = sys.d;
1191        let k = sys.k;
1192        if n == 0 || d == 0 || k == 0 {
1193            return Err(ArrowSchurGpuFailure::Unavailable);
1194        }
1195        // Dense shared block + materialised per-row slabs are required; the
1196        // public entry already rejected matrix-free operators, but re-check so
1197        // this routine is safe in isolation.
1198        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
1199            return Err(ArrowSchurGpuFailure::Unavailable);
1200        }
1201
1202        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
1203            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1204        if runtime.device_count() < 2 {
1205            return Err(ArrowSchurGpuFailure::Unavailable);
1206        }
1207
1208        // Pack one slot per row (column-major), folding ρ_t into each D block.
1209        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
1210        for row in &sys.rows {
1211            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
1212                return Err(ArrowSchurGpuFailure::Unavailable);
1213            }
1214            let mut d_block = Vec::with_capacity(d * d);
1215            let mut b_block = Vec::with_capacity(d * k);
1216            let mut g_vec = Vec::with_capacity(d);
1217            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
1218            slots.push(RowSlot {
1219                d_block,
1220                b_block,
1221                g_vec,
1222                l_block: Vec::new(),
1223                u_vec: Vec::new(),
1224                y_block: Vec::new(),
1225                log_det_local: 0.0,
1226                bump: None,
1227                tile_partial_schur: None,
1228                tile_partial_rhs: None,
1229                delta_t_block: vec![0.0; d],
1230            });
1231        }
1232
1233        // ---- Forward pass: per-device row tile, fused on its own stream ----
1234        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1235            forward_tile(ordinal, d, k, tile)
1236        });
1237        if forward_ok.is_none() {
1238            return Err(ArrowSchurGpuFailure::Unavailable);
1239        }
1240
1241        // Surface a non-PD tip block as a precise per-row ridge bump.
1242        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
1243        if let Some((row, bump)) = slots
1244            .iter()
1245            .enumerate()
1246            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
1247        {
1248            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
1249        }
1250
1251        // ---- Central: reduce tile partials → S_β, r_β; factor; solve δβ ----
1252        // Seed S_β with H_ββ + ρ_β I (column-major) and r_β with -g_β, then fold
1253        // in the per-tile partials in tile order so the reduction order tracks
1254        // the single-device accumulation (up to inter-tile reassociation).
1255        let mut schur_host = vec![0.0_f64; k * k];
1256        for col in 0..k {
1257            for row in 0..k {
1258                let mut v = sys.hbb[[row, col]];
1259                if row == col {
1260                    v += ridge_beta;
1261                }
1262                schur_host[col * k + row] = v;
1263            }
1264        }
1265        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1266        let mut log_det = 0.0_f64;
1267        for start in tile_starts(&row_base_of_tile) {
1268            let slot = &slots[start];
1269            let partial_schur = slot
1270                .tile_partial_schur
1271                .as_ref()
1272                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1273            let partial_rhs = slot
1274                .tile_partial_rhs
1275                .as_ref()
1276                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1277            // `accumulate_schur` writes `partial_schur = -Σ_tile Y_iᵀY_i` (GEMM
1278            // α=-1, β=1 into a zero seed) and `partial_rhs = +Σ_tile Y_iᵀu_i`.
1279            // The reduced Schur is `S = (H_ββ+ρI) − Σ_all Y_iᵀY_i`, so adding the
1280            // (already-negated) partials reproduces the single-device sign.
1281            for idx in 0..k * k {
1282                schur_host[idx] += partial_schur[idx];
1283            }
1284            for a in 0..k {
1285                rhs_host[a] += partial_rhs[a];
1286            }
1287        }
1288        for slot in &slots {
1289            log_det += slot.log_det_local;
1290        }
1291
1292        // Factor S_β and solve δβ on the primary device (small K×K leaf). The
1293        // stream carries the primary context (same pattern as `solve()`); no
1294        // thread bind is needed for the cuSOLVER/cuBLAS handles created from it.
1295        let primary = runtime.selected_device().ordinal;
1296        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
1297            .and_then(|ctx| ctx.new_stream().ok())
1298            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1299        let solver =
1300            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1301        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1302        let mut schur_dev = stream
1303            .clone_htod(&schur_host)
1304            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1305        let mut rhs_dev = stream
1306            .clone_htod(&rhs_host)
1307            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1308        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1309        if info != 0 {
1310            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1311                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
1312            });
1313        }
1314        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1315        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1316        let delta_beta_host = stream
1317            .clone_dtoh(&rhs_dev)
1318            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1319        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1320        let l_schur_host = stream
1321            .clone_dtoh(&schur_dev)
1322            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1323        for j in 0..k {
1324            log_det += l_schur_host[j * k + j].ln();
1325        }
1326        log_det *= 2.0;
1327
1328        // ---- Backward pass: δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ), per-device tile ----
1329        let delta_beta_ref = &delta_beta_host;
1330        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1331            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
1332        });
1333        if back_ok.is_none() {
1334            return Err(ArrowSchurGpuFailure::Unavailable);
1335        }
1336
1337        // Stitch per-row δt into the stacked (n*d) result.
1338        let mut delta_t = Array1::<f64>::zeros(n * d);
1339        for (i, slot) in slots.iter().enumerate() {
1340            let base = i * d;
1341            for r in 0..d {
1342                delta_t[base + r] = slot.delta_t_block[r];
1343            }
1344        }
1345
1346        Ok(ArrowSchurGpuSolution {
1347            delta_t,
1348            delta_beta,
1349            log_det_hessian: log_det,
1350        })
1351    }
1352
1353    /// Tile starts: the leading global row index of each device tile (where the
1354    /// tile-level partial reduction was written by the forward pass).
1355    fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1356        tiles.iter().map(|(_, range)| range.start)
1357    }
1358
    /// Forward pass for one device row tile, running on a fresh stream created
    /// on device `ordinal`.
    ///
    /// For the `m = tile.len()` row slots this packs each slot's `d_block`,
    /// `b_block` and `g_vec` into contiguous buffers, uploads them, and:
    /// 1. factors every `d×d` block in one batched POTRF, `D_i = L_i L_iᵀ`;
    /// 2. whitens `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i` with batched TRSMs;
    /// 3. reduces the tile's partial Schur terms into zero-seeded accumulators
    ///    via `accumulate_schur`. Because that helper uses `alpha = -1` for the
    ///    GEMM, the partial matrix is `−Σ_i Y_iᵀY_i`; the partial RHS is
    ///    `Σ_i Y_iᵀu_i`. The `H_ββ + ρ_β I` / `−g_β` seed is left for the host
    ///    to add exactly once.
    /// 4. downloads everything and, per slot, stores `l_block`, `u_vec` and
    ///    `y_block` (kept on the host for `back_sub_tile`), plus
    ///    `log_det_local = Σ_j ln L_i,jj` (NOT yet doubled).
    /// 5. stores the tile partials on the tile's FIRST slot only
    ///    (`tile_partial_schur`, k×k column-major, and `tile_partial_rhs`).
    ///
    /// If the batched POTRF reports a non-PD block, the function handles it as
    /// follows:
    /// - only the FIRST failing row gets `slot.bump` set, sized from its
    ///   `d_block` rather than from the cuSOLVER `info`;
    /// - the function returns `Some(())` immediately;
    /// - no factors, log-dets or partials are written on this path.
    ///
    /// This lets the orchestrator raise a precise `RidgeBumpRequired` rather
    /// than collapsing the whole batch to CPU.
    ///
    /// Returns `None` on any CUDA/cuBLAS/cuSOLVER failure (every device call is
    /// mapped through `.ok()?`). An empty tile is a no-op `Some(())`.
    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        // `scatter_batched` has already bound this ordinal's context on this
        // worker thread; the stream below targets that same device.
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        // Stack the tile's D, B, g into contiguous device buffers (same layout
        // the single-device path packs for `m` rows).
        let mut d_host = Vec::with_capacity(m * d * d);
        let mut b_host = Vec::with_capacity(m * d * k);
        let mut g_host = Vec::with_capacity(m * d);
        for slot in tile.iter() {
            d_host.extend_from_slice(&slot.d_block);
            b_host.extend_from_slice(&slot.b_block);
            g_host.extend_from_slice(&slot.g_vec);
        }
        let mut d_dev = stream.clone_htod(&d_host).ok()?;
        let mut b_dev = stream.clone_htod(&b_host).ok()?;
        let mut g_dev = stream.clone_htod(&g_host).ok()?;

        // Batched POTRF; a non-PD block records its bump and stops the tile.
        // The bump is deficit-aware (Gershgorin lower bound on λ_min of the
        // already-ridged `d_block`), NOT derived from the cuSOLVER `info` —
        // which is a 1-based pivot ROW INDEX, not a pivot magnitude — so a
        // strongly-indefinite block recovers in one outer-loop retry.
        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
        if let Some(local) = info_host.iter().position(|info| *info != 0) {
            tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
                &tile[local].d_block,
                d,
            ));
            return Some(());
        }

        // Whiten: u = L⁻¹ g, Y = L⁻¹ B.
        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;

        // Tile partial Schur: zero-seeded so the host adds the H_ββ seed once.
        // After this call `schur_dev` holds −Σ Y_iᵀY_i and `rhs_dev` Σ Y_iᵀu_i.
        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;

        // Download L, u, Y, and the tile partials.
        let l_host = stream.clone_dtoh(&d_dev).ok()?;
        let u_host = stream.clone_dtoh(&g_dev).ok()?;
        let y_host = stream.clone_dtoh(&b_dev).ok()?;
        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;

        // Row `local` occupies a fixed stride in each stacked download:
        // d·d for L, d for u, d·k for Y.
        for (local, slot) in tile.iter_mut().enumerate() {
            let l_base = local * d * d;
            let u_base = local * d;
            let y_base = local * d * k;
            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
            slot.u_vec = u_host[u_base..u_base + d].to_vec();
            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
            // Diagonal entry (j, j) of the column-major L_i sits at j*d + j.
            // The sum is left un-doubled here.
            let mut log_det_local = 0.0_f64;
            for j in 0..d {
                log_det_local += l_host[l_base + j * d + j].ln();
            }
            slot.log_det_local = log_det_local;
        }
        // Only the tile's first slot carries the tile-level partials.
        tile[0].tile_partial_schur = Some(partial_schur);
        tile[0].tile_partial_rhs = Some(partial_rhs);
        Some(())
    }
1439
1440    /// Back-substitution for one device row tile: `δt_i = -L_iᵀ⁻¹(u_i + Y_iδβ)`.
1441    /// Re-uploads the tile's kept `L`/`u`/`Y` to `ordinal`, applies the GEMV
1442    /// accumulate + transposed TRSM, and writes each row's `δt` into its slot.
1443    fn back_sub_tile(
1444        ordinal: usize,
1445        d: usize,
1446        k: usize,
1447        delta_beta: &[f64],
1448        tile: &mut [RowSlot],
1449    ) -> Option<()> {
1450        if tile.is_empty() {
1451            return Some(());
1452        }
1453        // `scatter_batched` has already bound this ordinal's context on this
1454        // worker thread; the stream below targets that same device.
1455        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1456            .and_then(|ctx| ctx.new_stream().ok())?;
1457        let blas = CudaBlas::new(stream.clone()).ok()?;
1458        let m = tile.len();
1459
1460        let mut l_host = Vec::with_capacity(m * d * d);
1461        let mut u_host = Vec::with_capacity(m * d);
1462        let mut y_host = Vec::with_capacity(m * d * k);
1463        for slot in tile.iter() {
1464            l_host.extend_from_slice(&slot.l_block);
1465            u_host.extend_from_slice(&slot.u_vec);
1466            y_host.extend_from_slice(&slot.y_block);
1467        }
1468        let d_dev = stream.clone_htod(&l_host).ok()?;
1469        let mut g_dev = stream.clone_htod(&u_host).ok()?;
1470        let b_dev = stream.clone_htod(&y_host).ok()?;
1471        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1472
1473        // g ← u + Y·δβ, then x = L⁻ᵀ g; δt = -x.
1474        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1475        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1476        let x_host = stream.clone_dtoh(&g_dev).ok()?;
1477        for (local, slot) in tile.iter_mut().enumerate() {
1478            let base = local * d;
1479            for r in 0..d {
1480                slot.delta_t_block[r] = -x_host[base + r];
1481            }
1482        }
1483        Some(())
1484    }
1485
1486    pub(super) fn solve(
1487        sys: &ArrowSchurSystem,
1488        ridge_t: f64,
1489        ridge_beta: f64,
1490    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1491        let n = sys.rows.len();
1492        let d = sys.d;
1493        let k = sys.k;
1494        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1495            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1496
1497        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1498            .and_then(|ctx| ctx.new_stream().ok())
1499            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1500        let solver =
1501            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1502        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1503
1504        // ----- Pack + upload D, B, g -----
1505        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1506        let mut d_dev = stream
1507            .clone_htod(&d_host)
1508            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1509        let mut b_dev = stream
1510            .clone_htod(&b_host)
1511            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1512        let mut g_dev = stream
1513            .clone_htod(&g_host)
1514            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1515
1516        // ----- Layer A: batched lower Cholesky of D in place -----
1517        // This POTRF is fused with the downstream TRSM + Schur GEMM + back-sub
1518        // on this one stream, so splitting only the POTRF across devices would
1519        // orphan the dependent on-stream solves. Multi-GPU here is the
1520        // whole-solve row-block split in `solve_arrow_newton_step` (see
1521        // `solve_multi_gpu`), not a per-layer split — this device-resident path
1522        // is the single-device leaf the split dispatches per tile.
1523        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1524        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1525            // `info` is cuSOLVER's 1-based pivot ROW INDEX, not a magnitude;
1526            // size the bump from the block's own entries (Gershgorin λ_min
1527            // bound) so a strongly-indefinite block recovers in one retry.
1528            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1529                row: idx,
1530                bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1531            });
1532        }
1533
1534        // ----- Layer B (1/2): in-place triangular solves -----
1535        // u_i = L_i^{-1} g_i, packed as a stacked (n*d) column-vector.
1536        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1537        // Y_i = L_i^{-1} B_i, in place over the (n*d) × k buffer (laid out as
1538        // n stacked column-major d×k tiles).
1539        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1540
1541        // ----- Layer B (2/2): Schur reduction via single big GEMM / GEMV -----
1542        // Y_all is (n*d) × k column-major: viewing all n stacked d×k tiles as
1543        // one big matrix is bit-exact because each tile is column-major with
1544        // leading dim d and the tiles are contiguous in memory, so the
1545        // combined leading dim is n*d only for the *outer* matrix view. To
1546        // make the single-GEMM equivalence hold we must treat the stacked
1547        // buffer as (n*d) × k column-major with leading dim = n*d, which
1548        // means columns of Y_all are interleaved by row across blocks.
1549        // That is NOT what we packed. So we use the cuBLAS stride pattern
1550        // instead: stride-by-block, transpose-A, and *accumulate* into one
1551        // S_β buffer via beta=1 across batches. Equivalent flop count, no
1552        // extra reduction kernel, and correct layout.
1553        //
1554        // Concretely: schur ← C + ρ_β I; rhs ← -g_β; then for each block
1555        //   schur -= Y_i^T Y_i      (k×k)
1556        //   rhs   += Y_i^T u_i      (k)
1557        // We launch this as `n` sequential GEMMs/GEMVs with beta=1 on the
1558        // accumulator. Layer D fuses these into one NVRTC launch.
1559        let schur_init: Vec<f64> = {
1560            let mut tmp = Vec::with_capacity(k * k);
1561            for col in 0..k {
1562                for row in 0..k {
1563                    let mut v = sys.hbb[[row, col]];
1564                    if row == col {
1565                        v += ridge_beta;
1566                    }
1567                    tmp.push(v);
1568                }
1569            }
1570            tmp
1571        };
1572        let mut schur_dev = stream
1573            .clone_htod(&schur_init)
1574            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1575        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1576        let mut rhs_dev = stream
1577            .clone_htod(&rhs_init)
1578            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1579
1580        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1581
1582        // ----- Layer C (1/2): factor S_β and solve for δβ -----
1583        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1584        if info != 0 {
1585            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1586                reason: format!("Schur Cholesky failed at pivot {info}"),
1587            });
1588        }
1589        // δβ ← L_S^{-T} L_S^{-1} rhs
1590        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1591        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1592        let delta_beta_host = stream
1593            .clone_dtoh(&rhs_dev)
1594            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1595        let delta_beta = Array1::from_vec(delta_beta_host.clone());
1596
1597        // ----- Layer C (2/2): back-sub δt_i = -L_i^{-T} (u_i + Y_i δβ) -----
1598        // Already on device:
1599        //   g_dev holds u_i stacked (n*d).
1600        //   b_dev holds Y_i stacked column-major n×(d×k) tiles.
1601        // Compute g_dev ← g_dev + Y_block · δβ per block (cuBLAS gemv with beta=1),
1602        // then in-place trsm with L_i^T (CUBLAS_OP_T) to obtain x_i, and finally
1603        // δt_i = -x_i on host after download.
1604        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1605        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1606
1607        let x_host = stream
1608            .clone_dtoh(&g_dev)
1609            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1610        let mut delta_t = Array1::<f64>::zeros(n * d);
1611        for (i, v) in x_host.iter().enumerate() {
1612            delta_t[i] = -*v;
1613        }
1614
1615        // ----- log|H| = 2 Σ log L_{i,jj} + 2 Σ log R_{β,aa} -----
1616        let l_local_host = stream
1617            .clone_dtoh(&d_dev)
1618            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1619        let l_schur_host = stream
1620            .clone_dtoh(&schur_dev)
1621            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1622        let mut log_det = 0.0_f64;
1623        for i in 0..n {
1624            let base = i * d * d;
1625            for j in 0..d {
1626                log_det += l_local_host[base + j * d + j].ln();
1627            }
1628        }
1629        for j in 0..k {
1630            log_det += l_schur_host[j * k + j].ln();
1631        }
1632        log_det *= 2.0;
1633
1634        Ok(ArrowSchurGpuSolution {
1635            delta_t,
1636            delta_beta,
1637            log_det_hessian: log_det,
1638        })
1639    }
1640
    /// Batched in-place lower Cholesky (`cusolverDnDpotrfBatched`) of `batch`
    /// contiguous `p×p` column-major matrices stored back to back in
    /// `matrices` (leading dimension `p`).
    ///
    /// The function:
    /// 1. builds the per-matrix device-pointer array on the host (base pointer
    ///    plus `idx · p² · size_of::<f64>()` bytes) and uploads it;
    /// 2. launches the batched factorization on `solver`;
    /// 3. downloads the per-matrix `info` array.
    ///
    /// Following cuSOLVER's convention, `info[i] == 0` means matrix `i`
    /// factored, and a positive value is the 1-based order of the leading minor
    /// that is not positive definite.
    ///
    /// # Errors
    /// `Unavailable` if `p`/`batch` do not fit in `i32`, a device
    /// allocation/copy fails, or cuSOLVER returns a non-success status.
    fn potrf_batched(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        batch: usize,
        matrices: &mut CudaSlice<f64>,
    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let matrix_len = p * p;
        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
        // Bound to `_record` (not `_`) so it lives until the end of the
        // function, i.e. past the launch and the `info` download below.
        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
        let mut ptrs = Vec::with_capacity(batch);
        for idx in 0..batch {
            ptrs.push(base_ptr + (idx as u64) * bytes_per);
        }
        let mut ptrs_dev = stream
            .clone_htod(&ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut info_dev = stream
            .alloc_zeros::<i32>(batch)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let status = {
            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
            // SAFETY: pointer array and info buffer live on the device,
            // `matrices` holds `batch` contiguous p×p column-major blocks.
            unsafe {
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            }
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
    }
1687
    /// In-place lower Cholesky (`cusolverDnDpotrf`) of a single `p×p`
    /// column-major matrix (leading dimension `p`); used for the Schur block.
    ///
    /// The function queries the workspace size, allocates at least one `f64`
    /// of workspace, factors the matrix, and downloads cuSOLVER's scalar
    /// `info`. By cuSOLVER convention, `info` is:
    /// - `0` on success;
    /// - positive for the order of the first non-PD leading minor;
    /// - negative for an invalid argument.
    ///
    /// The callers in this file treat any non-zero value as a factorization
    /// failure.
    ///
    /// # Errors
    /// `Unavailable` if `p` does not fit in `i32`, the reported `lwork` is
    /// negative, a device allocation/copy fails, or either cuSOLVER call
    /// returns a non-success status.
    fn potrf_single(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        matrix: &mut CudaSlice<f64>,
    ) -> Result<i32, ArrowSchurGpuFailure> {
        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
        let mut lwork = 0_i32;
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            // SAFETY: buffer query against a live p-by-p column-major device matrix.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf_bufferSize(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    &mut lwork,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
        }
        // Allocate at least one element so the workspace pointer is never a
        // zero-length allocation, even when cuSOLVER reports `lwork == 0`.
        let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut workspace = stream
            .alloc_zeros::<f64>(lwork_usize.max(1))
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut info_dev = stream
            .alloc_zeros::<i32>(1)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
            // SAFETY: all three pointers refer to live, correctly sized device buffers.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    work_ptr as *mut f64,
                    lwork,
                    info_ptr as *mut i32,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
        }
        let info_host = stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        Ok(info_host[0])
    }
1747
1748    /// In-place lower-triangular solves `X_i ← L_i^{-1} X_i` over the n stacked
1749    /// d×nrhs RHS tiles in `rhs`. Uses `cublasDtrsmBatched` so all n solves
1750    /// hit the device in one launch.
1751    fn trsm_batched_lower_inplace(
1752        blas: &CudaBlas,
1753        stream: &Arc<CudaStream>,
1754        d: usize,
1755        n: usize,
1756        nrhs: usize,
1757        l_stack: &CudaSlice<f64>,
1758        rhs_stack: &mut CudaSlice<f64>,
1759    ) -> Result<(), ArrowSchurGpuFailure> {
1760        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1761    }
1762
1763    /// As above but with `L_i^T` instead of `L_i`.
1764    fn trsm_batched_lower_inplace_transposed(
1765        blas: &CudaBlas,
1766        stream: &Arc<CudaStream>,
1767        d: usize,
1768        n: usize,
1769        nrhs: usize,
1770        l_stack: &CudaSlice<f64>,
1771        rhs_stack: &mut CudaSlice<f64>,
1772    ) -> Result<(), ArrowSchurGpuFailure> {
1773        trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1774    }
1775
    /// Shared implementation of the batched in-place triangular solves
    /// `X_i ← op(L_i)⁻¹ X_i` for `i in 0..n`. Here `op` is the transpose when
    /// `transposed` is set, otherwise the identity.
    ///
    /// Layout:
    /// - `l_stack` holds `n` contiguous `d×d` column-major lower factors;
    /// - `rhs_stack` holds `n` contiguous `d×nrhs` column-major tiles, each
    ///   with leading dimension `d`.
    ///
    /// Per-matrix device pointers are computed on the host, uploaded, and
    /// passed to one `cublasDtrsmBatched` launch (left side, lower fill,
    /// non-unit diagonal, `alpha = 1`).
    ///
    /// # Errors
    /// `Unavailable` if `d`, `nrhs` or `n` does not fit in `i32`, a
    /// pointer-array upload fails, or cuBLAS returns a non-success status.
    fn trsm_batched_inplace_inner(
        blas: &CudaBlas,
        stream: &Arc<CudaStream>,
        d: usize,
        n: usize,
        nrhs: usize,
        l_stack: &CudaSlice<f64>,
        rhs_stack: &mut CudaSlice<f64>,
        transposed: bool,
    ) -> Result<(), ArrowSchurGpuFailure> {
        let alpha = 1.0_f64;
        let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        // Byte strides between consecutive matrices in each stacked buffer.
        let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
        let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
        let (l_base, _l_record) = l_stack.device_ptr(stream);
        let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
        let mut l_ptrs = Vec::with_capacity(n);
        let mut rhs_ptrs = Vec::with_capacity(n);
        for i in 0..n {
            l_ptrs.push(l_base + (i as u64) * l_bytes_per);
            rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
        }
        let mut l_ptrs_dev = stream
            .clone_htod(&l_ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_ptrs_dev = stream
            .clone_htod(&rhs_ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
        let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
        let op = if transposed {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };
        let handle = *blas.handle();
        // SAFETY: pointer arrays and base buffers were just constructed from
        // live device allocations covering the entire batch.
        let status = unsafe {
            cudarc::cublas::sys::cublasDtrsmBatched(
                handle,
                cublasSideMode_t::CUBLAS_SIDE_LEFT,
                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                op,
                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
                d_i,
                nrhs_i,
                &alpha,
                l_ptrs_ptr as *const *const f64,
                d_i,
                rhs_ptrs_ptr as *const *mut f64,
                d_i,
                batch_i,
            )
        };
        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        Ok(())
    }
1838
1839    /// Single-matrix lower-triangular solve: `rhs ← L^{-1} rhs` (or
1840    /// `L^{-T} rhs` if `transposed`). For the Schur Cholesky back-sub.
1841    fn trsm_single(
1842        blas: &CudaBlas,
1843        stream: &Arc<CudaStream>,
1844        n: usize,
1845        l: &CudaSlice<f64>,
1846        rhs: &mut CudaSlice<f64>,
1847        upper: bool,
1848        transposed: bool,
1849    ) -> Result<(), ArrowSchurGpuFailure> {
1850        let alpha = 1.0_f64;
1851        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1852        let handle = *blas.handle();
1853        let (l_ptr, _l_rec) = l.device_ptr(stream);
1854        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1855        // SAFETY: single n×n lower factor and n-vector RHS on device.
1856        let status = unsafe {
1857            cudarc::cublas::sys::cublasDtrsm_v2(
1858                handle,
1859                cublasSideMode_t::CUBLAS_SIDE_LEFT,
1860                if upper {
1861                    cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1862                } else {
1863                    cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1864                },
1865                if transposed {
1866                    cublasOperation_t::CUBLAS_OP_T
1867                } else {
1868                    cublasOperation_t::CUBLAS_OP_N
1869                },
1870                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1871                n_i,
1872                1,
1873                &alpha,
1874                l_ptr as *const f64,
1875                n_i,
1876                rhs_ptr as *mut f64,
1877                n_i,
1878            )
1879        };
1880        if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1881            return Err(ArrowSchurGpuFailure::Unavailable);
1882        }
1883        Ok(())
1884    }
1885
1886    /// Accumulate `schur ← schur − Σ_i Y_i^T Y_i` and `rhs ← rhs + Σ_i Y_i^T u_i`
1887    /// using one GEMM and one GEMV per block. Each call uses beta=1 to chain
1888    /// the accumulation device-side.
1889    fn accumulate_schur(
1890        blas: &CudaBlas,
1891        d: usize,
1892        k: usize,
1893        n: usize,
1894        y_stack: &CudaSlice<f64>,
1895        u_stack: &CudaSlice<f64>,
1896        schur: &mut CudaSlice<f64>,
1897        rhs: &mut CudaSlice<f64>,
1898    ) -> Result<(), ArrowSchurGpuFailure> {
1899        let y_block_elems = d * k;
1900        let u_block_elems = d;
1901        for i in 0..n {
1902            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1903            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1904            // GEMM: schur += (-1) · Y_i^T · Y_i  (Y_i is d×k col-major; out is k×k)
1905            let gemm_cfg = GemmConfig::<f64> {
1906                transa: cublasOperation_t::CUBLAS_OP_T,
1907                transb: cublasOperation_t::CUBLAS_OP_N,
1908                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1909                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1910                k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1911                alpha: -1.0,
1912                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1913                ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1914                beta: 1.0,
1915                ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1916            };
1917            // SAFETY: y_slice is d×k col-major, schur is k×k col-major; alpha/beta scalars set above.
1918            unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1919                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1920            // GEMV: rhs += 1 · Y_i^T · u_i
1921            let gemv_cfg = GemvConfig::<f64> {
1922                trans: cublasOperation_t::CUBLAS_OP_T,
1923                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1924                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1925                alpha: 1.0,
1926                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1927                incx: 1,
1928                beta: 1.0,
1929                incy: 1,
1930            };
1931            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1932            // device buffers; `rhs` is the length-k accumulator.
1933            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1934                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1935        }
1936        Ok(())
1937    }
1938
1939    /// `#1017` resident gradient path: accumulate ONLY the Schur RHS term
1940    /// `rhs += Σ_i Y_iᵀ u_i`, skipping the `−Σ_i Y_iᵀ Y_i` matrix GEMM that the
1941    /// resident frame already folded into its persistent `L_S` factor. This is
1942    /// the per-iterate-cheap counterpart of [`accumulate_schur`]: the GEMV here
1943    /// is bit-identical to the GEMV inside `accumulate_schur` (same config, same
1944    /// `beta=1` accumulation order over rows), so the resident frame's `δβ`
1945    /// matches a full `solve()` at the same gradient.
1946    fn accumulate_schur_rhs_only(
1947        blas: &CudaBlas,
1948        d: usize,
1949        k: usize,
1950        n: usize,
1951        y_stack: &CudaSlice<f64>,
1952        u_stack: &CudaSlice<f64>,
1953        rhs: &mut CudaSlice<f64>,
1954    ) -> Result<(), ArrowSchurGpuFailure> {
1955        let y_block_elems = d * k;
1956        let u_block_elems = d;
1957        for i in 0..n {
1958            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1959            let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1960            let gemv_cfg = GemvConfig::<f64> {
1961                trans: cublasOperation_t::CUBLAS_OP_T,
1962                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1963                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1964                alpha: 1.0,
1965                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1966                incx: 1,
1967                beta: 1.0,
1968                incy: 1,
1969            };
1970            // SAFETY: y_slice (d×k col-major) and u_slice (length d) are live
1971            // device buffers; `rhs` is the length-k accumulator.
1972            unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1973                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1974        }
1975        Ok(())
1976    }
1977
1978    /// Accumulate `g_dev[i] ← u_i + Y_i · δβ` per block. This is the
1979    /// pre-trsm RHS for the back-substitution `L_i^T x_i = w_i`.
1980    fn accumulate_back_sub_rhs(
1981        blas: &CudaBlas,
1982        d: usize,
1983        k: usize,
1984        n: usize,
1985        y_stack: &CudaSlice<f64>,
1986        delta_beta: &CudaSlice<f64>,
1987        u_stack: &mut CudaSlice<f64>,
1988    ) -> Result<(), ArrowSchurGpuFailure> {
1989        let y_block_elems = d * k;
1990        let u_block_elems = d;
1991        for i in 0..n {
1992            let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1993            let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1994            let gemv_cfg = GemvConfig::<f64> {
1995                trans: cublasOperation_t::CUBLAS_OP_N,
1996                m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1997                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1998                alpha: 1.0,
1999                lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
2000                incx: 1,
2001                beta: 1.0,
2002                incy: 1,
2003            };
2004            // SAFETY: y_slice / delta_beta / u_slice are live device buffers
2005            // of the expected sizes (d×k, k, d).
2006            unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
2007                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2008        }
2009        Ok(())
2010    }
2011
2012    // ────────────────────────────────────────────────────────────────────
2013    // Layer D + E — fused NVRTC dispatch.
2014    //
2015    // The forward kernel (`arrow_schur_forward_pgroup`) is a single launch
2016    // that, per row block, factors `D_i + ρI = L_i L_iᵀ` in shared memory,
2017    // forward-solves `u_i = L_i⁻¹ g_i` and `Y_i = L_i⁻¹ B_i`, and emits the
2018    // per-block Schur partials `partial_S[i] = Yᵀ Y` (R×R) and
2019    // `partial_r[i] = Yᵀ u` (R). The host reduces partials on the CPU after
2020    // dtoh (one fused sum across `n` blocks of R²+R doubles; cheap because
2021    // n·R² ≲ 5M doubles at large scale), assembles `S_β`, factors it via
2022    // cuSOLVER, and launches the back-substitution kernel
2023    // `arrow_schur_back_sub_pgroup` to recover `δt_i = -L_i⁻ᵀ(u_i + Y_i δβ)`
2024    // without re-uploading the local factors.
2025    // ────────────────────────────────────────────────────────────────────
2026
2027    use std::collections::HashMap;
2028    use std::sync::Mutex;
2029
2030    /// One compiled NVRTC module per `(cc_major, cc_minor, p_max, r_template)`.
2031    /// `cc_*` lets one process drive multiple device generations; the
2032    /// `(p_max, r_template)` pair selects the shared-memory layout baked into
2033    /// the kernel source.
2034    struct FusedModuleCache {
2035        modules: Mutex<
2036            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
2037        >,
2038    }
2039
2040    fn fused_module_cache() -> &'static FusedModuleCache {
2041        static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
2042        CACHE.get_or_init(|| FusedModuleCache {
2043            modules: Mutex::new(HashMap::new()),
2044        })
2045    }
2046
2047    fn fused_module_for(
2048        ctx: &Arc<CudaContext>,
2049        key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2050    ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2051        let cache = fused_module_cache();
2052        if let Ok(guard) = cache.modules.lock() {
2053            if let Some(existing) = guard.get(&key) {
2054                return Ok(existing.clone());
2055            }
2056        }
2057        let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2058            key.p_max as usize,
2059            key.r_template as usize,
2060        );
2061        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).map_err(|err| {
2062            ArrowSchurGpuFailure::SchurFactorFailed {
2063                reason: format!(
2064                    "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2065                    key.p_max, key.r_template
2066                ),
2067            }
2068        })?;
2069        let module = ctx
2070            .load_module(ptx)
2071            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2072        if let Ok(mut guard) = cache.modules.lock() {
2073            guard.entry(key).or_insert_with(|| module.clone());
2074        }
2075        Ok(module)
2076    }
2077
2078    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
2079extern "C" __global__ void arrow_pcg_jacobi_mul(
2080    const double* __restrict__ inv_diag,
2081    const double* __restrict__ r,
2082    double* __restrict__ z,
2083    int n
2084) {
2085    int idx = blockIdx.x * blockDim.x + threadIdx.x;
2086    if (idx < n) {
2087        z[idx] = inv_diag[idx] * r[idx];
2088    }
2089}
2090
2091extern "C" __global__ void arrow_pcg_update_p(
2092    const double* __restrict__ z,
2093    double beta,
2094    double* __restrict__ p,
2095    int n
2096) {
2097    int idx = blockIdx.x * blockDim.x + threadIdx.x;
2098    if (idx < n) {
2099        p[idx] = z[idx] + beta * p[idx];
2100    }
2101}
2102
2103extern "C" __global__ void arrow_sae_init(
2104    double* __restrict__ out,
2105    const double* __restrict__ x,
2106    double ridge,
2107    int n
2108) {
2109    int idx = blockIdx.x * blockDim.x + threadIdx.x;
2110    if (idx < n) {
2111        out[idx] = ridge * x[idx];
2112    }
2113}
2114
2115extern "C" __global__ void arrow_sae_smooth_matvec(
2116    const double* __restrict__ x,
2117    double* __restrict__ out,
2118    const int* __restrict__ block_offsets,
2119    const int* __restrict__ block_m,
2120    const int* __restrict__ factor_ptr,
2121    const double* __restrict__ factors,
2122    int p,
2123    int n_blocks
2124) {
2125    int block_id = blockIdx.y;
2126    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2127    if (block_id >= n_blocks) {
2128        return;
2129    }
2130    int m = block_m[block_id];
2131    int total = m * p;
2132    if (linear >= total) {
2133        return;
2134    }
2135    int li = linear / p;
2136    int oc = linear - li * p;
2137    int off = block_offsets[block_id];
2138    int fbase = factor_ptr[block_id];
2139    double acc = 0.0;
2140    for (int lj = 0; lj < m; ++lj) {
2141        double a = factors[fbase + li * m + lj];
2142        acc += a * x[off + lj * p + oc];
2143    }
2144    out[off + li * p + oc] += acc;
2145}
2146
2147extern "C" __global__ void arrow_sae_sparse_g_matvec(
2148    const double* __restrict__ x,
2149    double* __restrict__ out,
2150    const int* __restrict__ row_off,
2151    const int* __restrict__ col_off,
2152    const int* __restrict__ rows,
2153    const int* __restrict__ cols,
2154    const int* __restrict__ data_ptr,
2155    const double* __restrict__ data,
2156    int p,
2157    int n_blocks
2158) {
2159    int block_id = blockIdx.y;
2160    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2161    if (block_id >= n_blocks) {
2162        return;
2163    }
2164    int m_i = rows[block_id];
2165    int m_j = cols[block_id];
2166    int total = m_i * p;
2167    if (linear >= total) {
2168        return;
2169    }
2170    int li = linear / p;
2171    int oc = linear - li * p;
2172    int rbase = row_off[block_id];
2173    int cbase = col_off[block_id];
2174    int dbase = data_ptr[block_id];
2175    double acc = 0.0;
2176    for (int lj = 0; lj < m_j; ++lj) {
2177        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
2178    }
2179    // #1017 — a row atom co-occurs with multiple column atoms, so several
2180    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
2181    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
2182    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
2183    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
2184    atomicAdd(&out[(rbase + li) * p + oc], acc);
2185}
2186
2187extern "C" __global__ void arrow_sae_gather_u(
2188    const double* __restrict__ x,
2189    const int* __restrict__ row_ptr,
2190    const int* __restrict__ beta_base,
2191    const double* __restrict__ phi,
2192    double* __restrict__ u,
2193    int p,
2194    int n_rows
2195) {
2196    int row = blockIdx.y;
2197    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2198    if (row >= n_rows || oc >= p) {
2199        return;
2200    }
2201    double acc = 0.0;
2202    int start = row_ptr[row];
2203    int end = row_ptr[row + 1];
2204    for (int e = start; e < end; ++e) {
2205        acc += phi[e] * x[beta_base[e] + oc];
2206    }
2207    u[row * p + oc] = acc;
2208}
2209
2210extern "C" __global__ void arrow_sae_apply_l(
2211    const double* __restrict__ u,
2212    const int* __restrict__ jac_ptr,
2213    const double* __restrict__ jac,
2214    double* __restrict__ w,
2215    int p,
2216    int max_q,
2217    int n_rows
2218) {
2219    int row = blockIdx.y;
2220    int c = blockIdx.x * blockDim.x + threadIdx.x;
2221    if (row >= n_rows) {
2222        return;
2223    }
2224    int jstart = jac_ptr[row];
2225    int q = (jac_ptr[row + 1] - jstart) / p;
2226    if (c >= q) {
2227        return;
2228    }
2229    double acc = 0.0;
2230    for (int oc = 0; oc < p; ++oc) {
2231        acc += jac[jstart + c * p + oc] * u[row * p + oc];
2232    }
2233    w[row * max_q + c] = acc;
2234}
2235
2236extern "C" __global__ void arrow_sae_apply_ainv(
2237    const double* __restrict__ ainv,
2238    const double* __restrict__ w,
2239    double* __restrict__ v,
2240    int max_q,
2241    int n_rows
2242) {
2243    int row = blockIdx.y;
2244    int c = blockIdx.x * blockDim.x + threadIdx.x;
2245    if (row >= n_rows || c >= max_q) {
2246        return;
2247    }
2248    double acc = 0.0;
2249    int base = row * max_q * max_q;
2250    for (int j = 0; j < max_q; ++j) {
2251        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
2252    }
2253    v[row * max_q + c] = acc;
2254}
2255
2256extern "C" __global__ void arrow_sae_scatter_sub(
2257    const double* __restrict__ v,
2258    const int* __restrict__ jac_ptr,
2259    const double* __restrict__ jac,
2260    const int* __restrict__ row_ptr,
2261    const int* __restrict__ beta_base,
2262    const double* __restrict__ phi,
2263    double* __restrict__ out,
2264    int p,
2265    int max_q,
2266    int n_rows
2267) {
2268    int row = blockIdx.y;
2269    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2270    if (row >= n_rows || oc >= p) {
2271        return;
2272    }
2273    int jstart = jac_ptr[row];
2274    int q = (jac_ptr[row + 1] - jstart) / p;
2275    double lt_v = 0.0;
2276    for (int c = 0; c < q; ++c) {
2277        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
2278    }
2279    int start = row_ptr[row];
2280    int end = row_ptr[row + 1];
2281    for (int e = start; e < end; ++e) {
2282        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
2283    }
2284}
2285
2286extern "C" __global__ void arrow_sae_diag_sub(
2287    double* __restrict__ diag,
2288    const double* __restrict__ ainv,
2289    const int* __restrict__ jac_ptr,
2290    const double* __restrict__ jac,
2291    const int* __restrict__ row_ptr,
2292    const int* __restrict__ beta_base,
2293    const double* __restrict__ phi,
2294    int p,
2295    int max_q,
2296    int n_rows
2297) {
2298    int row = blockIdx.y;
2299    int oc = blockIdx.x * blockDim.x + threadIdx.x;
2300    if (row >= n_rows || oc >= p) {
2301        return;
2302    }
2303    int jstart = jac_ptr[row];
2304    int q = (jac_ptr[row + 1] - jstart) / p;
2305    int abase = row * max_q * max_q;
2306    double quad = 0.0;
2307    for (int c = 0; c < q; ++c) {
2308        double lc = jac[jstart + c * p + oc];
2309        for (int d = 0; d < q; ++d) {
2310            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
2311        }
2312    }
2313    int start = row_ptr[row];
2314    int end = row_ptr[row + 1];
2315    for (int e = start; e < end; ++e) {
2316        double pe = phi[e];
2317        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
2318    }
2319}
2320
2321/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
2322 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
2323 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
2324 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
2325 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */
2326
2327extern "C" __global__ void arrow_sae_frame_smooth_matvec(
2328    const double* __restrict__ x,
2329    double* __restrict__ out,
2330    const int* __restrict__ block_offsets,
2331    const int* __restrict__ block_m,
2332    const int* __restrict__ block_r,
2333    const int* __restrict__ factor_ptr,
2334    const double* __restrict__ factors,
2335    int n_blocks
2336) {
2337    int block_id = blockIdx.y;
2338    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2339    if (block_id >= n_blocks) {
2340        return;
2341    }
2342    int m = block_m[block_id];
2343    int r = block_r[block_id];
2344    int total = m * r;
2345    if (linear >= total) {
2346        return;
2347    }
2348    int li = linear / r;
2349    int ib = linear - li * r;
2350    int off = block_offsets[block_id];
2351    int fbase = factor_ptr[block_id];
2352    double acc = 0.0;
2353    for (int lj = 0; lj < m; ++lj) {
2354        double a = factors[fbase + li * m + lj];
2355        acc += a * x[off + lj * r + ib];
2356    }
2357    out[off + li * r + ib] += acc;
2358}
2359
2360extern "C" __global__ void arrow_sae_frame_g_matvec(
2361    const double* __restrict__ x,
2362    double* __restrict__ out,
2363    const int* __restrict__ off_i,
2364    const int* __restrict__ off_j,
2365    const int* __restrict__ r_i,
2366    const int* __restrict__ r_j,
2367    const int* __restrict__ m_i,
2368    const int* __restrict__ m_j,
2369    const int* __restrict__ g_ptr,
2370    const double* __restrict__ g_data,
2371    const int* __restrict__ w_ptr,
2372    const double* __restrict__ w_data,
2373    int n_blocks
2374) {
2375    int block_id = blockIdx.y;
2376    int linear = blockIdx.x * blockDim.x + threadIdx.x;
2377    if (block_id >= n_blocks) {
2378        return;
2379    }
2380    int ri = r_i[block_id];
2381    int rj = r_j[block_id];
2382    int mi = m_i[block_id];
2383    int mj = m_j[block_id];
2384    int total = mi * ri;
2385    if (linear >= total) {
2386        return;
2387    }
2388    int li = linear / ri;       // basis row in atom i
2389    int a = linear - li * ri;   // frame coord in atom i
2390    int oi = off_i[block_id];
2391    int oj = off_j[block_id];
2392    int gbase = g_ptr[block_id];
2393    int wbase = w_ptr[block_id];
2394    double acc = 0.0;
2395    for (int lj = 0; lj < mj; ++lj) {
2396        double g = g_data[gbase + li * mj + lj];
2397        if (g == 0.0) { continue; }
2398        int xj_base = oj + lj * rj;
2399        double inner = 0.0;
2400        for (int b = 0; b < rj; ++b) {
2401            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
2402        }
2403        acc += g * inner;
2404    }
2405    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
2406    // multiple co-occurring (i,j) frame blocks running concurrently on
2407    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
2408    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
2409    atomicAdd(&out[oi + li * ri + a], acc);
2410}
2411
2412/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
2413 *   h_i   = H_tβ^(i) · x                (length q_i)
2414 *   s_i   = (H_tt^(i)+ρ_t I)⁻¹ h_i      (apply cached ainv, length q_i)
2415 *   out  -= (H_tβ^(i))ᵀ · s_i           (scatter into border_dim)
2416 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
2417 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
2418extern "C" __global__ void arrow_sae_frame_apply_h(
2419    const double* __restrict__ x,
2420    const int* __restrict__ htb_ptr,
2421    const double* __restrict__ htb,
2422    const int* __restrict__ q_of,
2423    double* __restrict__ hvec,
2424    int k,
2425    int max_q,
2426    int n_rows
2427) {
2428    int row = blockIdx.y;
2429    int c = blockIdx.x * blockDim.x + threadIdx.x;
2430    if (row >= n_rows) { return; }
2431    int q = q_of[row];
2432    if (c >= q) { return; }
2433    int base = htb_ptr[row] + c * k;
2434    double acc = 0.0;
2435    for (int a = 0; a < k; ++a) {
2436        acc += htb[base + a] * x[a];
2437    }
2438    hvec[row * max_q + c] = acc;
2439}
2440
2441extern "C" __global__ void arrow_sae_frame_apply_ainv(
2442    const double* __restrict__ ainv,
2443    const double* __restrict__ hvec,
2444    const int* __restrict__ q_of,
2445    double* __restrict__ svec,
2446    int max_q,
2447    int n_rows
2448) {
2449    int row = blockIdx.y;
2450    int c = blockIdx.x * blockDim.x + threadIdx.x;
2451    if (row >= n_rows || c >= max_q) { return; }
2452    int q = q_of[row];
2453    double acc = 0.0;
2454    int abase = row * max_q * max_q;
2455    for (int j = 0; j < q; ++j) {
2456        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
2457    }
2458    svec[row * max_q + c] = acc;
2459}
2460
2461extern "C" __global__ void arrow_sae_frame_scatter_h(
2462    const double* __restrict__ svec,
2463    const int* __restrict__ htb_ptr,
2464    const double* __restrict__ htb,
2465    const int* __restrict__ q_of,
2466    double* __restrict__ out,
2467    int k,
2468    int max_q,
2469    int n_rows
2470) {
2471    int row = blockIdx.y;
2472    int a = blockIdx.x * blockDim.x + threadIdx.x;
2473    if (row >= n_rows || a >= k) { return; }
2474    int q = q_of[row];
2475    int hbase = htb_ptr[row];
2476    double acc = 0.0;
2477    for (int c = 0; c < q; ++c) {
2478        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
2479    }
2480    atomicAdd(&out[a], -acc);
2481}
2482
2483/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
2484extern "C" __global__ void arrow_sae_frame_diag_sub(
2485    double* __restrict__ diag,
2486    const double* __restrict__ ainv,
2487    const int* __restrict__ htb_ptr,
2488    const double* __restrict__ htb,
2489    const int* __restrict__ q_of,
2490    int k,
2491    int max_q,
2492    int n_rows
2493) {
2494    int row = blockIdx.y;
2495    int a = blockIdx.x * blockDim.x + threadIdx.x;
2496    if (row >= n_rows || a >= k) { return; }
2497    int q = q_of[row];
2498    int hbase = htb_ptr[row];
2499    int abase = row * max_q * max_q;
2500    double quad = 0.0;
2501    for (int c = 0; c < q; ++c) {
2502        double hc = htb[hbase + c * k + a];
2503        for (int d = 0; d < q; ++d) {
2504            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
2505        }
2506    }
2507    atomicAdd(&diag[a], -quad);
2508}
2509"#;
2510
2511    fn pcg_vector_module(
2512        ctx: &Arc<CudaContext>,
2513    ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2514        static CACHE: gam_gpu::device_cache::PtxModuleCache =
2515            gam_gpu::device_cache::PtxModuleCache::new();
2516        CACHE
2517            .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2518            .map_err(|err| {
2519                // #1551: an NVRTC compile / module-load failure of
2520                // PCG_VECTOR_KERNEL_SOURCE means the device SAE PCG cannot run;
2521                // log it (the historical silent collapse to `Unavailable` is what
2522                // masked the missing `--gpu-architecture` for so long) and fall
2523                // back to the CPU.
2524                log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2525                ArrowSchurGpuFailure::Unavailable
2526            })
2527    }
2528
2529    fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2530        let threads = 256u32;
2531        let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2532        Ok(LaunchConfig {
2533            grid_dim: (blocks, 1, 1),
2534            block_dim: (threads, 1, 1),
2535            shared_mem_bytes: 0,
2536        })
2537    }
2538
2539    fn launch_jacobi_mul(
2540        stream: &Arc<CudaStream>,
2541        module: &Arc<CudaModule>,
2542        inv_diag: &CudaSlice<f64>,
2543        r: &CudaSlice<f64>,
2544        z: &mut CudaSlice<f64>,
2545        n: usize,
2546    ) -> Result<(), ArrowSchurGpuFailure> {
2547        let kernel = module
2548            .load_function("arrow_pcg_jacobi_mul")
2549            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2550        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2551        let mut builder = stream.launch_builder(&kernel);
2552        builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2553        // SAFETY: all buffers have length n and belong to `stream`; the kernel only
2554        // reads/writes indices `< n`.
2555        unsafe { builder.launch(pcg_launch_config(n)?) }
2556            .map(drop)
2557            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2558    }
2559
2560    fn launch_update_p(
2561        stream: &Arc<CudaStream>,
2562        module: &Arc<CudaModule>,
2563        z: &CudaSlice<f64>,
2564        beta: f64,
2565        p: &mut CudaSlice<f64>,
2566        n: usize,
2567    ) -> Result<(), ArrowSchurGpuFailure> {
2568        let kernel = module
2569            .load_function("arrow_pcg_update_p")
2570            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2571        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2572        let mut builder = stream.launch_builder(&kernel);
2573        builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2574        // SAFETY: z/p both have length n and belong to `stream`; the kernel only
2575        // reads/writes indices `< n`.
2576        unsafe { builder.launch(pcg_launch_config(n)?) }
2577            .map(drop)
2578            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2579    }
2580
2581    struct DeviceSaePcgBuffers {
2582        row_ptr: CudaSlice<i32>,
2583        beta_base: CudaSlice<i32>,
2584        phi: CudaSlice<f64>,
2585        jac_ptr: CudaSlice<i32>,
2586        jac: CudaSlice<f64>,
2587        smooth_offsets: CudaSlice<i32>,
2588        smooth_m: CudaSlice<i32>,
2589        smooth_ptr: CudaSlice<i32>,
2590        smooth_data: CudaSlice<f64>,
2591        g_row_off: CudaSlice<i32>,
2592        g_col_off: CudaSlice<i32>,
2593        g_rows: CudaSlice<i32>,
2594        g_cols: CudaSlice<i32>,
2595        g_ptr: CudaSlice<i32>,
2596        g_data: CudaSlice<f64>,
2597        ainv: CudaSlice<f64>,
2598        u: CudaSlice<f64>,
2599        w: CudaSlice<f64>,
2600        v: CudaSlice<f64>,
2601        n_rows: usize,
2602        p: usize,
2603        k: usize,
2604        max_q: usize,
2605        smooth_blocks: usize,
2606        g_blocks: usize,
2607    }
2608
2609    fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2610        to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2611    }
2612
2613    fn sae_penalty_diag_host(
2614        data: &DeviceSaePcgData,
2615        ridge_beta: f64,
2616    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2617        let mut diag = vec![ridge_beta; data.beta_dim];
2618        for block in &data.smooth_blocks {
2619            let (rows, cols) = block.factor_a.dim();
2620            if rows != cols {
2621                return Err(ArrowSchurGpuFailure::Unavailable);
2622            }
2623            for row in 0..rows {
2624                let coeff = block.factor_a[[row, row]];
2625                let base = block
2626                    .global_offset
2627                    .checked_add(
2628                        row.checked_mul(data.p)
2629                            .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2630                    )
2631                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2632                let end = base
2633                    .checked_add(data.p)
2634                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2635                if end > diag.len() {
2636                    return Err(ArrowSchurGpuFailure::Unavailable);
2637                }
2638                for channel in 0..data.p {
2639                    diag[base + channel] += coeff;
2640                }
2641            }
2642        }
2643        for block in &data.sparse_g_blocks {
2644            if block.row_off != block.col_off {
2645                continue;
2646            }
2647            let (rows, cols) = block.data.dim();
2648            for row in 0..rows.min(cols) {
2649                let coeff = block.data[[row, row]];
2650                let beta_row = block
2651                    .row_off
2652                    .checked_add(row)
2653                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2654                let base = beta_row
2655                    .checked_mul(data.p)
2656                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2657                let end = base
2658                    .checked_add(data.p)
2659                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2660                if end > diag.len() {
2661                    return Err(ArrowSchurGpuFailure::Unavailable);
2662                }
2663                for channel in 0..data.p {
2664                    diag[base + channel] += coeff;
2665                }
2666            }
2667        }
2668        Ok(diag)
2669    }
2670
2671    fn flatten_device_sae_data(
2672        sys: &ArrowSchurSystem,
2673        data: &DeviceSaePcgData,
2674        ridge_t: f64,
2675        stream: &Arc<CudaStream>,
2676    ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2677        let n_rows = sys.rows.len();
2678        let p = data.p;
2679        let k = data.beta_dim;
2680        if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2681            return Err(ArrowSchurGpuFailure::Unavailable);
2682        }
2683
2684        let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2685        let mut beta_base_host = Vec::<i32>::new();
2686        let mut phi_host = Vec::<f64>::new();
2687        row_ptr_host.push(0_i32);
2688        for row in data.a_phi.iter() {
2689            for &(base, phi) in row {
2690                beta_base_host.push(checked_i32(base)?);
2691                phi_host.push(phi);
2692            }
2693            row_ptr_host.push(checked_i32(beta_base_host.len())?);
2694        }
2695
2696        let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2697        let mut jac_host = Vec::<f64>::new();
2698        let mut max_q = 0usize;
2699        jac_ptr_host.push(0_i32);
2700        for row_jac in data.local_jac.iter() {
2701            if row_jac.len() % p != 0 {
2702                return Err(ArrowSchurGpuFailure::Unavailable);
2703            }
2704            max_q = max_q.max(row_jac.len() / p);
2705            jac_host.extend_from_slice(row_jac);
2706            jac_ptr_host.push(checked_i32(jac_host.len())?);
2707        }
2708        if max_q == 0 {
2709            return Err(ArrowSchurGpuFailure::Unavailable);
2710        }
2711
2712        let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2713        let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2714        let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2715        let mut smooth_data_host = Vec::<f64>::new();
2716        smooth_ptr_host.push(0_i32);
2717        for block in &data.smooth_blocks {
2718            let (rows, cols) = block.factor_a.dim();
2719            if rows != cols {
2720                return Err(ArrowSchurGpuFailure::Unavailable);
2721            }
2722            smooth_offsets_host.push(checked_i32(block.global_offset)?);
2723            smooth_m_host.push(checked_i32(rows)?);
2724            for r in 0..rows {
2725                for c in 0..cols {
2726                    smooth_data_host.push(block.factor_a[[r, c]]);
2727                }
2728            }
2729            smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2730        }
2731
2732        let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2733        let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2734        let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2735        let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2736        let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2737        let mut g_data_host = Vec::<f64>::new();
2738        g_ptr_host.push(0_i32);
2739        for block in &data.sparse_g_blocks {
2740            let (rows, cols) = block.data.dim();
2741            g_row_off_host.push(checked_i32(block.row_off)?);
2742            g_col_off_host.push(checked_i32(block.col_off)?);
2743            g_rows_host.push(checked_i32(rows)?);
2744            g_cols_host.push(checked_i32(cols)?);
2745            for r in 0..rows {
2746                for c in 0..cols {
2747                    g_data_host.push(block.data[[r, c]]);
2748                }
2749            }
2750            g_ptr_host.push(checked_i32(g_data_host.len())?);
2751        }
2752
2753        let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2754        for (row_idx, row) in sys.rows.iter().enumerate() {
2755            let q = data.local_jac[row_idx].len() / p;
2756            if row.htt.dim() != (q, q) {
2757                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2758                    reason: format!(
2759                        "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2760                        row.htt.dim()
2761                    ),
2762                });
2763            }
2764            let mut block = row.htt.clone();
2765            for d in 0..q {
2766                block[[d, d]] += ridge_t;
2767            }
2768            let factor = gam_linalg::triangular::cholesky_factor_in_place(
2769                block.view(),
2770                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2771            )
2772            .ok_or_else(|| {
2773                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
2774                // indefinite per-row block recovers in one outer-loop retry.
2775                ArrowSchurGpuFailure::RidgeBumpRequired {
2776                    row: row_idx,
2777                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2778                }
2779            })?;
2780            for col in 0..q {
2781                let mut e = Array1::<f64>::zeros(q);
2782                e[col] = 1.0;
2783                let solved =
2784                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2785                for r in 0..q {
2786                    ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2787                }
2788            }
2789        }
2790
2791        Ok(DeviceSaePcgBuffers {
2792            row_ptr: stream
2793                .clone_htod(&row_ptr_host)
2794                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795            beta_base: stream
2796                .clone_htod(&beta_base_host)
2797                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798            phi: stream
2799                .clone_htod(&phi_host)
2800                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801            jac_ptr: stream
2802                .clone_htod(&jac_ptr_host)
2803                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804            jac: stream
2805                .clone_htod(&jac_host)
2806                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2807            smooth_offsets: stream
2808                .clone_htod(&smooth_offsets_host)
2809                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2810            smooth_m: stream
2811                .clone_htod(&smooth_m_host)
2812                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2813            smooth_ptr: stream
2814                .clone_htod(&smooth_ptr_host)
2815                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2816            smooth_data: stream
2817                .clone_htod(&smooth_data_host)
2818                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2819            g_row_off: stream
2820                .clone_htod(&g_row_off_host)
2821                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2822            g_col_off: stream
2823                .clone_htod(&g_col_off_host)
2824                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2825            g_rows: stream
2826                .clone_htod(&g_rows_host)
2827                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2828            g_cols: stream
2829                .clone_htod(&g_cols_host)
2830                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2831            g_ptr: stream
2832                .clone_htod(&g_ptr_host)
2833                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2834            g_data: stream
2835                .clone_htod(&g_data_host)
2836                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2837            ainv: stream
2838                .clone_htod(&ainv_host)
2839                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2840            u: stream
2841                .alloc_zeros::<f64>(n_rows * p)
2842                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2843            w: stream
2844                .alloc_zeros::<f64>(n_rows * max_q)
2845                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2846            v: stream
2847                .alloc_zeros::<f64>(n_rows * max_q)
2848                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2849            n_rows,
2850            p,
2851            k,
2852            max_q,
2853            smooth_blocks: data.smooth_blocks.len(),
2854            g_blocks: data.sparse_g_blocks.len(),
2855        })
2856    }
2857
2858    fn launch_sae_init(
2859        stream: &Arc<CudaStream>,
2860        module: &Arc<CudaModule>,
2861        out: &mut CudaSlice<f64>,
2862        x: &CudaSlice<f64>,
2863        ridge: f64,
2864        n: usize,
2865    ) -> Result<(), ArrowSchurGpuFailure> {
2866        let kernel = module
2867            .load_function("arrow_sae_init")
2868            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2869        let n_i32 = checked_i32(n)?;
2870        let mut builder = stream.launch_builder(&kernel);
2871        builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2872        // SAFETY: `out` and `x` are live device buffers with at least `n`
2873        // entries on `stream`; the kernel writes one in-bounds element per
2874        // launched index below `n`.
2875        unsafe { builder.launch(pcg_launch_config(n)?) }
2876            .map(drop)
2877            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2878    }
2879
2880    fn launch_sae_penalty_matvec(
2881        stream: &Arc<CudaStream>,
2882        module: &Arc<CudaModule>,
2883        buffers: &mut DeviceSaePcgBuffers,
2884        x: &CudaSlice<f64>,
2885        out: &mut CudaSlice<f64>,
2886        ridge_beta: f64,
2887    ) -> Result<(), ArrowSchurGpuFailure> {
2888        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2889        if buffers.smooth_blocks > 0 {
2890            let kernel = module
2891                .load_function("arrow_sae_smooth_matvec")
2892                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2893            let max_m = buffers.k;
2894            let p_i32 = checked_i32(buffers.p)?;
2895            let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2896            let cfg = LaunchConfig {
2897                grid_dim: (
2898                    ((max_m as u32).saturating_add(255) / 256).max(1),
2899                    checked_i32(buffers.smooth_blocks)? as u32,
2900                    1,
2901                ),
2902                block_dim: (256, 1, 1),
2903                shared_mem_bytes: 0,
2904            };
2905            let mut builder = stream.launch_builder(&kernel);
2906            builder
2907                .arg(x)
2908                .arg(&mut *out)
2909                .arg(&buffers.smooth_offsets)
2910                .arg(&buffers.smooth_m)
2911                .arg(&buffers.smooth_ptr)
2912                .arg(&buffers.smooth_data)
2913                .arg(&p_i32)
2914                .arg(&blocks_i32);
2915            // SAFETY: smooth block metadata and dense smooth data were flattened
2916            // into live device buffers; the 2D grid covers only declared block
2917            // and coefficient-channel work items, and the kernel bounds-checks
2918            // against each block's stored size.
2919            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2920        }
2921        if buffers.g_blocks > 0 {
2922            let kernel = module
2923                .load_function("arrow_sae_sparse_g_matvec")
2924                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2925            let max_work = buffers
2926                .k
2927                .checked_div(buffers.p)
2928                .unwrap_or(0)
2929                .saturating_mul(buffers.p);
2930            let p_i32 = checked_i32(buffers.p)?;
2931            let blocks_i32 = checked_i32(buffers.g_blocks)?;
2932            let cfg = LaunchConfig {
2933                grid_dim: (
2934                    ((max_work as u32).saturating_add(255) / 256).max(1),
2935                    checked_i32(buffers.g_blocks)? as u32,
2936                    1,
2937                ),
2938                block_dim: (256, 1, 1),
2939                shared_mem_bytes: 0,
2940            };
2941            let mut builder = stream.launch_builder(&kernel);
2942            builder
2943                .arg(x)
2944                .arg(&mut *out)
2945                .arg(&buffers.g_row_off)
2946                .arg(&buffers.g_col_off)
2947                .arg(&buffers.g_rows)
2948                .arg(&buffers.g_cols)
2949                .arg(&buffers.g_ptr)
2950                .arg(&buffers.g_data)
2951                .arg(&p_i32)
2952                .arg(&blocks_i32);
2953            // SAFETY: sparse G block metadata/data are live device buffers built
2954            // from host CSR-like block descriptors; the launch dimensions cover
2955            // declared block work only and the kernel checks row/column bounds
2956            // before reading or accumulating.
2957            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2958        }
2959        Ok(())
2960    }
2961
2962    fn launch_sae_row_schur_sub(
2963        stream: &Arc<CudaStream>,
2964        module: &Arc<CudaModule>,
2965        buffers: &mut DeviceSaePcgBuffers,
2966        x: &CudaSlice<f64>,
2967        out: &mut CudaSlice<f64>,
2968    ) -> Result<(), ArrowSchurGpuFailure> {
2969        let p_i32 = checked_i32(buffers.p)?;
2970        let max_q_i32 = checked_i32(buffers.max_q)?;
2971        let n_rows_i32 = checked_i32(buffers.n_rows)?;
2972        let cfg_p_rows = LaunchConfig {
2973            grid_dim: (
2974                ((buffers.p as u32).saturating_add(255) / 256).max(1),
2975                checked_i32(buffers.n_rows)? as u32,
2976                1,
2977            ),
2978            block_dim: (256, 1, 1),
2979            shared_mem_bytes: 0,
2980        };
2981        let gather = module
2982            .load_function("arrow_sae_gather_u")
2983            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2984        {
2985            let mut builder = stream.launch_builder(&gather);
2986            builder
2987                .arg(x)
2988                .arg(&buffers.row_ptr)
2989                .arg(&buffers.beta_base)
2990                .arg(&buffers.phi)
2991                .arg(&mut buffers.u)
2992                .arg(&p_i32)
2993                .arg(&n_rows_i32);
2994            // SAFETY: `x`, row pointers, beta offsets, basis rows, and `u` are
2995            // live device buffers sized for `n_rows` by `p`; the kernel guards
2996            // row/channel indices before gathering.
2997            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2998        }
2999
3000        let cfg_q_rows = LaunchConfig {
3001            grid_dim: (
3002                ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
3003                checked_i32(buffers.n_rows)? as u32,
3004                1,
3005            ),
3006            block_dim: (256, 1, 1),
3007            shared_mem_bytes: 0,
3008        };
3009        let apply_l = module
3010            .load_function("arrow_sae_apply_l")
3011            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3012        {
3013            let mut builder = stream.launch_builder(&apply_l);
3014            builder
3015                .arg(&buffers.u)
3016                .arg(&buffers.jac_ptr)
3017                .arg(&buffers.jac)
3018                .arg(&mut buffers.w)
3019                .arg(&p_i32)
3020                .arg(&max_q_i32)
3021                .arg(&n_rows_i32);
3022            // SAFETY: `u`, Jacobian row pointers/data, and `w` are live buffers
3023            // sized for the `(n_rows, p)` to `(n_rows, max_q)` multiply; the
3024            // kernel checks row and local-coordinate bounds.
3025            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3026        }
3027
3028        let apply_ainv = module
3029            .load_function("arrow_sae_apply_ainv")
3030            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3031        {
3032            let mut builder = stream.launch_builder(&apply_ainv);
3033            builder
3034                .arg(&buffers.ainv)
3035                .arg(&buffers.w)
3036                .arg(&mut buffers.v)
3037                .arg(&max_q_i32)
3038                .arg(&n_rows_i32);
3039            // SAFETY: `ainv`, `w`, and `v` are live device buffers sized for
3040            // `n_rows * max_q`; the kernel guards all row/local-coordinate
3041            // indices before reading or writing.
3042            unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3043        }
3044
3045        let scatter = module
3046            .load_function("arrow_sae_scatter_sub")
3047            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3048        {
3049            let mut builder = stream.launch_builder(&scatter);
3050            builder
3051                .arg(&buffers.v)
3052                .arg(&buffers.jac_ptr)
3053                .arg(&buffers.jac)
3054                .arg(&buffers.row_ptr)
3055                .arg(&buffers.beta_base)
3056                .arg(&buffers.phi)
3057                .arg(out)
3058                .arg(&p_i32)
3059                .arg(&max_q_i32)
3060                .arg(&n_rows_i32);
3061            // SAFETY: `v`, Jacobian metadata, row pointers, beta offsets, basis
3062            // rows, and `out` are live buffers for `n_rows` by `p`; scatter
3063            // indices are checked against row and channel bounds in the kernel.
3064            unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3065        }
3066        Ok(())
3067    }
3068
3069    fn launch_sae_diag_sub(
3070        stream: &Arc<CudaStream>,
3071        module: &Arc<CudaModule>,
3072        buffers: &DeviceSaePcgBuffers,
3073        diag: &mut CudaSlice<f64>,
3074    ) -> Result<(), ArrowSchurGpuFailure> {
3075        let kernel = module
3076            .load_function("arrow_sae_diag_sub")
3077            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3078        let p_i32 = checked_i32(buffers.p)?;
3079        let max_q_i32 = checked_i32(buffers.max_q)?;
3080        let n_rows_i32 = checked_i32(buffers.n_rows)?;
3081        let cfg = LaunchConfig {
3082            grid_dim: (
3083                ((buffers.p as u32).saturating_add(255) / 256).max(1),
3084                checked_i32(buffers.n_rows)? as u32,
3085                1,
3086            ),
3087            block_dim: (256, 1, 1),
3088            shared_mem_bytes: 0,
3089        };
3090        let mut builder = stream.launch_builder(&kernel);
3091        builder
3092            .arg(diag)
3093            .arg(&buffers.ainv)
3094            .arg(&buffers.jac_ptr)
3095            .arg(&buffers.jac)
3096            .arg(&buffers.row_ptr)
3097            .arg(&buffers.beta_base)
3098            .arg(&buffers.phi)
3099            .arg(&p_i32)
3100            .arg(&max_q_i32)
3101            .arg(&n_rows_i32);
3102        // SAFETY: diagonal output and all read-only SAE row metadata buffers are
3103        // live on `stream` with sizes matching `n_rows`, `p`, and `max_q`; the
3104        // kernel bounds-checks its flattened work index.
3105        unsafe { builder.launch(cfg) }
3106            .map(drop)
3107            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3108    }
3109
3110    fn launch_sae_matvec(
3111        stream: &Arc<CudaStream>,
3112        module: &Arc<CudaModule>,
3113        buffers: &mut DeviceSaePcgBuffers,
3114        x: &CudaSlice<f64>,
3115        out: &mut CudaSlice<f64>,
3116        ridge_beta: f64,
3117    ) -> Result<(), ArrowSchurGpuFailure> {
3118        launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3119        launch_sae_row_schur_sub(stream, module, buffers, x, out)
3120    }
3121
3122    /// Pack `D + ρ_t I`, `B`, and `g` into the strided `(n × P_MAX × P_MAX)`
3123    /// / `(n × P_MAX × R_TEMPLATE)` / `(n × P_MAX)` layout the fused kernel
3124    /// expects. Entries outside the runtime `(p, r)` window stay at zero so
3125    /// the kernel's per-element loops are safe to no-op there.
3126    fn pack_fused_host(
3127        sys: &ArrowSchurSystem,
3128        ridge_t: f64,
3129        p_max: usize,
3130        r_template: usize,
3131    ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3132        let n = sys.rows.len();
3133        let d = sys.d;
3134        let k = sys.k;
3135        let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3136        let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3137        let mut g_buf = vec![0.0_f64; n * p_max];
3138        for (i, row) in sys.rows.iter().enumerate() {
3139            // D_i + ρI, column-major in P_MAX×P_MAX strided block.
3140            for col in 0..d {
3141                let base = (i * p_max + col) * p_max;
3142                for r in 0..d {
3143                    let mut value = row.htt[[r, col]];
3144                    if r == col {
3145                        value += ridge_t;
3146                    }
3147                    d_buf[base + r] = value;
3148                }
3149            }
3150            // B_i in P_MAX×R_TEMPLATE strided block. The per-row (per-i) block
3151            // stride is `p_max · r_template` (matching the `b_buf` allocation
3152            // above and the kernel's `b_stack`/`y_out` layout), NOT
3153            // `p_max · p_max`: using the D-block multiplier here overflows the
3154            // buffer whenever `p_max > r_template` (e.g. d=30→p_max=32,
3155            // k=5→r_template=5). The within-block element offset stays
3156            // column-major `col·p_max + r` (P_MAX rows per column).
3157            for col in 0..k {
3158                let base = (i * r_template + col) * p_max;
3159                for r in 0..d {
3160                    b_buf[base + r] = row.htbeta[[r, col]];
3161                }
3162            }
3163            // g_i in P_MAX strided vector.
3164            let g_base = i * p_max;
3165            for r in 0..d {
3166                g_buf[g_base + r] = row.gt[r];
3167            }
3168        }
3169        (d_buf, b_buf, g_buf)
3170    }
3171
3172    // -----------------------------------------------------------------------
3173    // #1017 Phase 3: across-iteration device residency.
3174    //
3175    // `solve()` re-packs and re-uploads `D` (`H_tt`), `B` (`H_tβ`) and `g`,
3176    // then re-runs the per-row POTRF and the border Schur factorization on
3177    // EVERY call. For the SAE joint inner Newton at a frozen gate/basis frame
3178    // the Hessian blocks `D`, `B`, `H_ββ` are CONSTANT across the inner loop —
3179    // only the gradient `g = r(z) = H z − g₀` changes per iterate. So the
3180    // factor work (`O(n·d³ + p³)`) and the dominant `O(n·d·p)` cross-block
3181    // upload are pure waste when repeated per iterate.
3182    //
3183    // `ResidentArrowFrame` performs that constant work ONCE at construction:
3184    // upload+ridge+POTRF of `D` (keeping `L_i` resident in `l_dev`), the
3185    // forward solve `Y_i = L_i^{-1} B_i` (kept resident in `y_dev`), and the
3186    // Schur assembly + border POTRF (keeping `L_S` resident in `schur_dev`).
3187    // Each subsequent `solve_gradient(g)` uploads only the `n·d` row gradient,
3188    // runs the cheap residual path — `u_i = L_i^{-1} g_i` (one batched TRSM),
3189    // Schur RHS `−g_β + Σ Y_iᵀ u_i`, `δβ = L_S^{-T} L_S^{-1} rhs` (two TRSM,
3190    // NO POTRF), back-sub `δt_i = −L_i^{-T}(u_i + Y_i δβ)` — and reads back only
3191    // `δ` and the cached log|H|. The heavy buffers never leave the device
3192    // across iterations; the per-iterate host transfer is `O(n·d + p)`, not
3193    // `O(n·d·p)`. Numerics are bit-identical to a `solve()` at the same
3194    // `(D, B, H_ββ, g, ridge_t, ridge_beta)` because the factor buffers and the
3195    // helper kernels are the same; the resident path merely SKIPS re-deriving
3196    // the parts that do not depend on `g`. The CPU dense reference
3197    // (`solve_arrow_newton_step_dense_reference`) is the parity oracle.
3198    pub(super) struct ResidentArrowFrame {
3199        n: usize,
3200        d: usize,
3201        k: usize,
3202        stream: Arc<CudaStream>,
3203        blas: CudaBlas,
3204        /// Per-row lower Cholesky factors `L_i` of `H_tt + ρ_t I`, stacked
3205        /// column-major (`n` tiles of `d×d`). Resident across iterations.
3206        l_dev: CudaSlice<f64>,
3207        /// Whitened cross blocks `Y_i = L_i^{-1} H_tβ^(i)`, stacked column-major
3208        /// (`n` tiles of `d×k`). Resident across iterations.
3209        y_dev: CudaSlice<f64>,
3210        /// Lower Cholesky factor `L_S` of the reduced Schur complement
3211        /// `S_β = H_ββ + ρ_β I − Σ_i Y_iᵀ Y_i`. Resident across iterations.
3212        schur_dev: CudaSlice<f64>,
3213        /// `log|H| = 2 Σ log L_{i,jj} + 2 Σ log L_{S,aa}`, constant for the
3214        /// frame (depends only on the factored Hessian, not on `g`).
3215        log_det_hessian: f64,
3216    }
3217
3218    impl ResidentArrowFrame {
3219        /// Upload the constant Hessian blocks and perform the one-time factor
3220        /// work (`POTRF(D)`, `Y_i = L_i^{-1} B_i`, Schur assembly + border
3221        /// `POTRF`). The frame then serves cheap per-gradient solves.
3222        pub(super) fn new(
3223            sys: &ArrowSchurSystem,
3224            ridge_t: f64,
3225            ridge_beta: f64,
3226        ) -> Result<Self, ArrowSchurGpuFailure> {
3227            if ridge_t.is_nan() || ridge_beta.is_nan() {
3228                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3229                    reason: "ridge is NaN".to_string(),
3230                });
3231            }
3232            let n = sys.rows.len();
3233            let d = sys.d;
3234            let k = sys.k;
3235            let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3236                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3237            let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3238                .and_then(|ctx| ctx.new_stream().ok())
3239                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3240            let solver =
3241                DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3242            let blas =
3243                CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3244
3245            // Upload the constant blocks. `g` is uploaded per-gradient, not here.
3246            let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3247            let mut l_dev = stream
3248                .clone_htod(&d_host)
3249                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3250            let mut y_dev = stream
3251                .clone_htod(&b_host)
3252                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3253
3254            // POTRF(D) → L_i, kept resident in l_dev.
3255            let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3256            if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3257                // cuSOLVER `info` is a 1-based pivot row index; size the bump
3258                // from the block (Gershgorin λ_min bound) so a strongly
3259                // indefinite block recovers in one retry.
3260                return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3261                    row: idx,
3262                    bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3263                });
3264            }
3265
3266            // Y_i = L_i^{-1} B_i, in place over y_dev. Kept resident.
3267            trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3268
3269            // Schur assembly S_β = (H_ββ + ρ_β I) − Σ Y_iᵀ Y_i, then POTRF → L_S.
3270            // The RHS accumulation is folded into the gradient path; here we
3271            // only need the (gradient-independent) Schur factor, so accumulate
3272            // into a throwaway rhs buffer.
3273            let schur_init: Vec<f64> = {
3274                let mut tmp = Vec::with_capacity(k * k);
3275                for col in 0..k {
3276                    for row in 0..k {
3277                        let mut v = sys.hbb[[row, col]];
3278                        if row == col {
3279                            v += ridge_beta;
3280                        }
3281                        tmp.push(v);
3282                    }
3283                }
3284                tmp
3285            };
3286            let mut schur_dev = stream
3287                .clone_htod(&schur_init)
3288                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3289            // A zero u-stack makes `Σ Y_iᵀ u_i = 0`, so only the `−Σ Y_iᵀ Y_i`
3290            // Schur term is accumulated (the rhs is rebuilt per gradient).
3291            let zero_u = stream
3292                .clone_htod(&vec![0.0_f64; n * d])
3293                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3294            let mut throwaway_rhs = stream
3295                .clone_htod(&vec![0.0_f64; k])
3296                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3297            accumulate_schur(
3298                &blas,
3299                d,
3300                k,
3301                n,
3302                &y_dev,
3303                &zero_u,
3304                &mut schur_dev,
3305                &mut throwaway_rhs,
3306            )?;
3307            let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3308            if info != 0 {
3309                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3310                    reason: format!("Schur Cholesky failed at pivot {info}"),
3311                });
3312            }
3313
3314            // log|H| from the resident factors (constant for the frame).
3315            let l_local_host = stream
3316                .clone_dtoh(&l_dev)
3317                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3318            let l_schur_host = stream
3319                .clone_dtoh(&schur_dev)
3320                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3321            let mut log_det = 0.0_f64;
3322            for i in 0..n {
3323                let base = i * d * d;
3324                for j in 0..d {
3325                    log_det += l_local_host[base + j * d + j].ln();
3326                }
3327            }
3328            for j in 0..k {
3329                log_det += l_schur_host[j * k + j].ln();
3330            }
3331            log_det *= 2.0;
3332
3333            Ok(Self {
3334                n,
3335                d,
3336                k,
3337                stream,
3338                blas,
3339                l_dev,
3340                y_dev,
3341                schur_dev,
3342                log_det_hessian: log_det,
3343            })
3344        }
3345
3346        #[inline]
3347        pub(super) fn log_det_hessian(&self) -> f64 {
3348            self.log_det_hessian
3349        }
3350
3351        /// Solve `H δ = −gradient` for a fresh gradient `(g_t, g_β)` reusing the
3352        /// resident factors. Uploads only `g_t` (`n·d` scalars); reads back only
3353        /// `δ`. No POTRF runs here — all factorization is amortized into `new`.
3354        pub(super) fn solve_gradient(
3355            &self,
3356            g_t: &[f64],
3357            g_beta: &[f64],
3358        ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3359            let n = self.n;
3360            let d = self.d;
3361            let k = self.k;
3362            if g_t.len() != n * d || g_beta.len() != k {
3363                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3364                    reason: format!(
3365                        "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3366                        g_t.len(),
3367                        n * d,
3368                        g_beta.len(),
3369                        k
3370                    ),
3371                });
3372            }
3373            // Upload the per-iterate row gradient → u_i = L_i^{-1} g_i in place.
3374            let mut u_dev = self
3375                .stream
3376                .clone_htod(&g_t.to_vec())
3377                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3378            trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3379
3380            // Schur RHS = −g_β + Σ_i Y_iᵀ u_i. Reuse the resident Schur factor
3381            // (no POTRF, and skip the −Σ Y_iᵀ Y_i GEMM already baked into L_S).
3382            let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3383            let mut rhs_dev = self
3384                .stream
3385                .clone_htod(&rhs_init)
3386                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3387            accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3388
3389            // δβ ← L_S^{-T} L_S^{-1} rhs using the resident border factor.
3390            trsm_single(
3391                &self.blas,
3392                &self.stream,
3393                k,
3394                &self.schur_dev,
3395                &mut rhs_dev,
3396                false,
3397                false,
3398            )?;
3399            trsm_single(
3400                &self.blas,
3401                &self.stream,
3402                k,
3403                &self.schur_dev,
3404                &mut rhs_dev,
3405                false,
3406                true,
3407            )?;
3408            let delta_beta_host = self
3409                .stream
3410                .clone_dtoh(&rhs_dev)
3411                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3412            let delta_beta = Array1::from_vec(delta_beta_host);
3413
3414            // Back-sub δt_i = −L_i^{-T}(u_i + Y_i δβ).
3415            accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3416            trsm_batched_lower_inplace_transposed(
3417                &self.blas,
3418                &self.stream,
3419                d,
3420                n,
3421                1,
3422                &self.l_dev,
3423                &mut u_dev,
3424            )?;
3425            let x_host = self
3426                .stream
3427                .clone_dtoh(&u_dev)
3428                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3429            let mut delta_t = Array1::<f64>::zeros(n * d);
3430            for (i, v) in x_host.iter().enumerate() {
3431                delta_t[i] = -*v;
3432            }
3433
3434            Ok(ArrowSchurGpuSolution {
3435                delta_t,
3436                delta_beta,
3437                log_det_hessian: self.log_det_hessian,
3438            })
3439        }
3440    }
3441
3442    pub(super) fn solve_fused(
3443        sys: &ArrowSchurSystem,
3444        ridge_t: f64,
3445        ridge_beta: f64,
3446    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3447        let n = sys.rows.len();
3448        let d = sys.d;
3449        let k = sys.k;
3450        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3451            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3452        let p_max = plan.p_max;
3453        let r_template = plan.r_template;
3454
3455        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3456            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3457        )
3458        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3459        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3460            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3461        let stream = ctx
3462            .new_stream()
3463            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3464        let cap = &runtime.device.capability;
3465        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3466            cc_major: cap.compute_major,
3467            cc_minor: cap.compute_minor,
3468            p_max: p_max as u32,
3469            r_template: r_template as u32,
3470        };
3471        let module = fused_module_for(&ctx, key)?;
3472        let forward = module
3473            .load_function("arrow_schur_forward_pgroup")
3474            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3475        let back_sub = module
3476            .load_function("arrow_schur_back_sub_pgroup")
3477            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3478
3479        // ----- Upload packed D, B, g -----
3480        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3481        let d_dev = stream
3482            .clone_htod(&d_host)
3483            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3484        let b_dev = stream
3485            .clone_htod(&b_host)
3486            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487        let g_dev = stream
3488            .clone_htod(&g_host)
3489            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3490        let mut l_out = stream
3491            .alloc_zeros::<f64>(n * p_max * p_max)
3492            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3493        let mut u_out = stream
3494            .alloc_zeros::<f64>(n * p_max)
3495            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3496        let mut y_out = stream
3497            .alloc_zeros::<f64>(n * p_max * r_template)
3498            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3499        let mut partial_s = stream
3500            .alloc_zeros::<f64>(plan.partial_s_doubles)
3501            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3502        let mut partial_r = stream
3503            .alloc_zeros::<f64>(plan.partial_r_doubles)
3504            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3505        let mut status_dev = stream
3506            .alloc_zeros::<i32>(n)
3507            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3508
3509        // ----- Launch forward kernel: 1 block per row, P_MAX threads -----
3510        let cfg = LaunchConfig {
3511            grid_dim: (plan.blocks, 1, 1),
3512            block_dim: (plan.threads_per_block, 1, 1),
3513            shared_mem_bytes: 0,
3514        };
3515        let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3516        let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3517        let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3518        let ridge_arg = ridge_t;
3519        {
3520            let mut builder = stream.launch_builder(&forward);
3521            builder
3522                .arg(&d_dev)
3523                .arg(&b_dev)
3524                .arg(&g_dev)
3525                .arg(&n_i32)
3526                .arg(&p_i32)
3527                .arg(&r_i32)
3528                .arg(&ridge_arg)
3529                .arg(&mut l_out)
3530                .arg(&mut u_out)
3531                .arg(&mut y_out)
3532                .arg(&mut partial_s)
3533                .arg(&mut partial_r)
3534                .arg(&mut status_dev);
3535            // SAFETY: all buffers were just allocated on `stream` with sizes
3536            // derived from `plan`; kernel parameter list matches the
3537            // FORWARD_KERNEL_SOURCE signature.
3538            unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3539        }
3540        stream
3541            .synchronize()
3542            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3543
3544        // ----- Check per-block pivot status -----
3545        let status_host = stream
3546            .clone_dtoh(&status_dev)
3547            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3548        if let Some(row) = status_host.iter().position(|s| *s != 0) {
3549            // The NVRTC kernel's status code is a 1-based pivot row index, not
3550            // a magnitude; size the bump from the block (Gershgorin λ_min
3551            // bound) so a strongly indefinite block recovers in one retry.
3552            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3553                row,
3554                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3555            });
3556        }
3557
3558        // ----- Reduce partials on host into S_β and r_β -----
3559        let partial_s_host = stream
3560            .clone_dtoh(&partial_s)
3561            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3562        let partial_r_host = stream
3563            .clone_dtoh(&partial_r)
3564            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3565        let mut schur_host = vec![0.0_f64; k * k];
3566        for col in 0..k {
3567            for row in 0..k {
3568                let mut v = sys.hbb[[row, col]];
3569                if row == col {
3570                    v += ridge_beta;
3571                }
3572                schur_host[col * k + row] = v;
3573            }
3574        }
3575        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3576        for i in 0..n {
3577            // partial_S[i] stride is R_TEMPLATE × R_TEMPLATE column-major; we
3578            // only read the leading (k × k) sub-block.
3579            let s_base = i * r_template * r_template;
3580            for col in 0..k {
3581                let col_base = s_base + col * r_template;
3582                let dst_col_base = col * k;
3583                for row in 0..k {
3584                    schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3585                }
3586            }
3587            let r_base = i * r_template;
3588            for a in 0..k {
3589                rhs_host[a] += partial_r_host[r_base + a];
3590            }
3591        }
3592
3593        // ----- Factor S_β on device (cuSOLVER), solve for δβ -----
3594        let mut schur_dev = stream
3595            .clone_htod(&schur_host)
3596            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3597        let mut rhs_dev = stream
3598            .clone_htod(&rhs_host)
3599            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3600        let solver =
3601            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3602        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3603        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3604        if info != 0 {
3605            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3606                reason: format!("fused Schur Cholesky failed at pivot {info}"),
3607            });
3608        }
3609        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3610        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3611        let delta_beta_host = stream
3612            .clone_dtoh(&rhs_dev)
3613            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3614        let delta_beta = Array1::from_vec(delta_beta_host.clone());
3615
3616        // ----- Layer E: launch back-sub kernel using persisted L, u, Y -----
3617        let mut delta_t_dev = stream
3618            .alloc_zeros::<f64>(n * p_max)
3619            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3620        let back_cfg = LaunchConfig {
3621            grid_dim: (plan.blocks, 1, 1),
3622            block_dim: (plan.threads_per_block, 1, 1),
3623            shared_mem_bytes: 0,
3624        };
3625        {
3626            let mut builder = stream.launch_builder(&back_sub);
3627            builder
3628                .arg(&l_out)
3629                .arg(&u_out)
3630                .arg(&y_out)
3631                .arg(&rhs_dev)
3632                .arg(&n_i32)
3633                .arg(&p_i32)
3634                .arg(&r_i32)
3635                .arg(&mut delta_t_dev);
3636            // SAFETY: kernel parameter list matches FORWARD_KERNEL_SOURCE
3637            // back-sub signature; `rhs_dev` holds δβ in the leading k entries
3638            // (R_TEMPLATE-strided indexing is column 0..k of the R-vector).
3639            unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3640        }
3641        stream
3642            .synchronize()
3643            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3644
3645        let delta_t_host = stream
3646            .clone_dtoh(&delta_t_dev)
3647            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3648        let mut delta_t = Array1::<f64>::zeros(n * d);
3649        for i in 0..n {
3650            let src_base = i * p_max;
3651            let dst_base = i * d;
3652            for r in 0..d {
3653                delta_t[dst_base + r] = delta_t_host[src_base + r];
3654            }
3655        }
3656
3657        // ----- log|H| = 2·Σ log L_{i,jj} + 2·Σ log R_{β,aa} -----
3658        let l_local_host = stream
3659            .clone_dtoh(&l_out)
3660            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3661        let l_schur_host = stream
3662            .clone_dtoh(&schur_dev)
3663            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3664        let mut log_det = 0.0_f64;
3665        for i in 0..n {
3666            let base = i * p_max * p_max;
3667            for j in 0..d {
3668                log_det += l_local_host[base + j * p_max + j].ln();
3669            }
3670        }
3671        for j in 0..k {
3672            log_det += l_schur_host[j * k + j].ln();
3673        }
3674        log_det *= 2.0;
3675
3676        Ok(ArrowSchurGpuSolution {
3677            delta_t,
3678            delta_beta,
3679            log_det_hessian: log_det,
3680        })
3681    }
3682
    /// Pre-compute `Y_i = L_i^{-1} H_tβ^(i)` via the fused forward kernel and
    /// return a closure that evaluates the full Schur matvec
    /// `S·x = (H_ββ + ρ·I)·x − Σ_i Y_i^T (Y_i·x)` for each PCG iteration.
    ///
    /// The `Y_i` factors are kept in a host-side buffer after one GPU forward
    /// pass. Each matvec call runs O(N·d·K) host loops over the pre-computed
    /// buffer plus an optional `H_ββ·x` call (matrix-free or dense). This is
    /// the first landing of the GPU matvec; a future iteration can move the
    /// `Y_i·x` / `Y_i^T z_i` steps to cuBLAS batched GEMV.
    ///
    /// `ridge_t` is handed to the host packer and the forward kernel (and sizes
    /// any ridge bump); `ridge_beta` is the `ρ` added on the shared diagonal
    /// inside the returned closure.
    ///
    /// # Errors
    /// * `Unavailable` — no fused launch plan, the GPU route/context lookup
    ///   returned `None`, kernel lookup failed, a dimension does not fit in
    ///   `i32`, or any allocation/copy/launch/synchronize failed.
    /// * `RidgeBumpRequired` — the forward kernel reported a non-zero status for
    ///   some row; the first such row is reported with a bump sized by
    ///   `ridge_bump_to_make_pd`.
    /// * Errors from `fused_module_for` are propagated unchanged.
    ///
    /// # Panics
    /// The returned closure asserts `x.len() == k` and `out.len() == k`.
    pub(super) fn build_schur_matvec_backend(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        // Padded per-block sizes (p_max ≥ d, r_template ≥ k) and launch geometry.
        let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_max = plan.p_max;
        let r_template = plan.r_template;

        // NOTE(review): presumably the offload policy decides here whether this
        // batched-potrf-shaped workload runs on the GPU; `None` ⇒ decline.
        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
            gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
        )
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let stream = ctx
            .new_stream()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // The NVRTC module is keyed by compute capability and padded sizes.
        let cap = &runtime.device.capability;
        let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
            cc_major: cap.compute_major,
            cc_minor: cap.compute_minor,
            p_max: p_max as u32,
            r_template: r_template as u32,
        };
        let module = fused_module_for(&ctx, key)?;
        let forward = module
            .load_function("arrow_schur_forward_pgroup")
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
        let d_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let b_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let g_dev = stream
            .clone_htod(&g_host)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        // The forward kernel writes all of these outputs, but only `y_out` (and
        // the per-row status) is read back below; L, u and the Schur partials
        // are unused by the matvec path.
        let mut l_out = stream
            .alloc_zeros::<f64>(n * p_max * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut u_out = stream
            .alloc_zeros::<f64>(n * p_max)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut y_out = stream
            .alloc_zeros::<f64>(n * p_max * r_template)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_s = stream
            .alloc_zeros::<f64>(plan.partial_s_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut partial_r = stream
            .alloc_zeros::<f64>(plan.partial_r_doubles)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        let mut status_dev = stream
            .alloc_zeros::<i32>(n)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let cfg = LaunchConfig {
            grid_dim: (plan.blocks, 1, 1),
            block_dim: (plan.threads_per_block, 1, 1),
            shared_mem_bytes: 0,
        };
        let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
        let ridge_arg = ridge_t;
        {
            // Argument order must match the kernel signature exactly.
            let mut builder = stream.launch_builder(&forward);
            builder
                .arg(&d_dev)
                .arg(&b_dev)
                .arg(&g_dev)
                .arg(&n_i32)
                .arg(&p_i32)
                .arg(&r_i32)
                .arg(&ridge_arg)
                .arg(&mut l_out)
                .arg(&mut u_out)
                .arg(&mut y_out)
                .arg(&mut partial_s)
                .arg(&mut partial_r)
                .arg(&mut status_dev);
            // SAFETY: all buffers were allocated on `stream` with sizes
            // derived from `plan`; parameter list matches FORWARD_KERNEL_SOURCE.
            unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        }
        stream
            .synchronize()
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        let status_host = stream
            .clone_dtoh(&status_dev)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
        if let Some(row) = status_host.iter().position(|s| *s != 0) {
            // Status code is a 1-based pivot row index, not a magnitude; size
            // the bump from the block (Gershgorin λ_min bound) so a strongly
            // indefinite block recovers in one retry.
            return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
                row,
                bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
            });
        }

        // Download Y_i factors: n × p_max × r_template column-major per block.
        let y_host = stream
            .clone_dtoh(&y_out)
            .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

        // Capture H_ββ data for the closure. Use the matrix-free hook if present
        // (SAE-manifold callers), otherwise fall back to the dense matrix rows.
        // `ndarray::iter` walks in logical row-major order regardless of memory
        // layout, so `hbb_host[a * k + b] == hbb[[a, b]]` when hbb is k × k.
        let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
        let hbb_is_kk = sys.hbb.dim() == (k, k);
        let hbb_matvec_opt = sys.hbb_matvec.clone();

        let closure: crate::arrow_schur::GpuSchurMatvec =
            Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
                assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
                assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");

                // (H_ββ + ρ·I)·x into out.
                if let Some(ref mv) = hbb_matvec_opt {
                    // NOTE(review): assumes `mv` overwrites `out` rather than
                    // accumulating into it — confirm against the hook contract.
                    mv(x.view(), out);
                    for a in 0..k {
                        out[a] += ridge_beta * x[a];
                    }
                } else if hbb_is_kk {
                    // hbb_host row-major: hbb[a, b] = hbb_host[a * k + b].
                    for a in 0..k {
                        let mut acc = ridge_beta * x[a];
                        for b in 0..k {
                            acc += hbb_host[a * k + b] * x[b];
                        }
                        out[a] = acc;
                    }
                } else {
                    // Neither a matrix-free hook nor a dense k × k H_ββ: only
                    // the ridge term contributes.
                    for a in 0..k {
                        out[a] = ridge_beta * x[a];
                    }
                }

                // out[c] -= Σ_i (Y_i^T (Y_i·x))[c].
                // Y_i column-major at y_host[i·p_max·r_template + col·p_max + row].
                // Only the leading d rows / k columns of each padded Y_i are used.
                let mut z = vec![0.0_f64; d];
                for i in 0..n {
                    let y_base = i * p_max * r_template;
                    // z = Y_i · x
                    for r in 0..d {
                        let mut acc = 0.0;
                        for c in 0..k {
                            acc += y_host[y_base + c * p_max + r] * x[c];
                        }
                        z[r] = acc;
                    }
                    // out -= Y_i^T · z
                    for c in 0..k {
                        let mut acc = 0.0;
                        for r in 0..d {
                            acc += y_host[y_base + c * p_max + r] * z[r];
                        }
                        out[c] -= acc;
                    }
                }
            });

        Ok(closure)
    }
3862
3863    // ── #1017/#1026 frames-engaged device PCG ──────────────────────────────
3864
    /// Device-resident, flattened form of the framed SAE operator consumed by
    /// the framed device PCG kernels; built by `flatten_device_sae_frame_data`.
    ///
    /// Variable-length per-block payloads are concatenated into one `*_data`
    /// buffer indexed by a CSR-style `*_ptr` prefix-offset array
    /// (`ptr[0] = 0`, block `b` occupies `data[ptr[b]..ptr[b + 1]]`).
    struct DeviceSaeFrameBuffers {
        // Smooth `λ S_k ⊗ I_{r_k}`.
        // Per smooth block: global β offset, factor size m, and rank r.
        s_off: CudaSlice<i32>,
        s_m: CudaSlice<i32>,
        s_r: CudaSlice<i32>,
        // Prefix offsets into `s_data`, which holds each block's m × m
        // `factor_a` in row-major order, concatenated.
        s_ptr: CudaSlice<i32>,
        s_data: CudaSlice<f64>,
        // Number of smooth blocks.
        s_blocks: usize,
        // Data `G_{ij} ⊗ W_{ij}`.
        // Per frame block: border offsets of atoms i / j, their ranks, and the
        // (m_i, m_j) shape of `g`.
        g_off_i: CudaSlice<i32>,
        g_off_j: CudaSlice<i32>,
        g_ri: CudaSlice<i32>,
        g_rj: CudaSlice<i32>,
        g_mi: CudaSlice<i32>,
        g_mj: CudaSlice<i32>,
        // Row-major `g` (m_i × m_j) and `w` (r_i × r_j) payloads with their
        // prefix offsets.
        g_ptr: CudaSlice<i32>,
        g_data: CudaSlice<f64>,
        w_ptr: CudaSlice<i32>,
        w_data: CudaSlice<f64>,
        // Number of frame blocks, and the largest m_i·r_i over them (the x-extent
        // of the data-penalty launch).
        g_blocks: usize,
        g_max_work: usize,
        // Per-row dense cross-block H_tβ^(i) (row-major q_i × k slabs, with
        // prefix offsets), the effective row width q_i (0 when the row has no
        // slab), and the explicit inverse of H_tt^(i) + ridge_t·I stored
        // row-major in a zero-padded max_q × max_q slot per row.
        htb_ptr: CudaSlice<i32>,
        htb: CudaSlice<f64>,
        q_of: CudaSlice<i32>,
        ainv: CudaSlice<f64>,
        // Zero-initialised n_rows × max_q scratch vectors written by the
        // reduced-Schur kernels on every matvec.
        hvec: CudaSlice<f64>,
        svec: CudaSlice<f64>,
        // Number of rows, shared width k (= beta_dim), and max q_i (≥ 1).
        n_rows: usize,
        k: usize,
        max_q: usize,
    }
3897
3898    fn flatten_device_sae_frame_data(
3899        sys: &ArrowSchurSystem,
3900        data: &DeviceSaePcgData,
3901        frame: &DeviceSaeFrameData,
3902        ridge_t: f64,
3903        stream: &Arc<CudaStream>,
3904    ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3905        let n_rows = sys.rows.len();
3906        let k = data.beta_dim;
3907        if frame.row_htbeta.len() != n_rows
3908            || frame.ranks.len() != frame.basis_sizes.len()
3909            || frame.border_offsets.len() != frame.ranks.len()
3910            || data.smooth_blocks.len() != frame.smooth_ranks.len()
3911        {
3912            return Err(ArrowSchurGpuFailure::Unavailable);
3913        }
3914
3915        // Smooth blocks.
3916        let mut s_off = Vec::new();
3917        let mut s_m = Vec::new();
3918        let mut s_r = Vec::new();
3919        let mut s_ptr = vec![0_i32];
3920        let mut s_data = Vec::<f64>::new();
3921        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3922            let (m, mc) = blk.factor_a.dim();
3923            if m != mc {
3924                return Err(ArrowSchurGpuFailure::Unavailable);
3925            }
3926            s_off.push(checked_i32(blk.global_offset)?);
3927            s_m.push(checked_i32(m)?);
3928            s_r.push(checked_i32(r)?);
3929            for ri in 0..m {
3930                for ci in 0..m {
3931                    s_data.push(blk.factor_a[[ri, ci]]);
3932                }
3933            }
3934            s_ptr.push(checked_i32(s_data.len())?);
3935        }
3936
3937        // Data blocks (g + w).
3938        let mut g_off_i = Vec::new();
3939        let mut g_off_j = Vec::new();
3940        let mut g_ri = Vec::new();
3941        let mut g_rj = Vec::new();
3942        let mut g_mi = Vec::new();
3943        let mut g_mj = Vec::new();
3944        let mut g_ptr = vec![0_i32];
3945        let mut g_data = Vec::<f64>::new();
3946        let mut w_ptr = vec![0_i32];
3947        let mut w_data = Vec::<f64>::new();
3948        let mut g_max_work = 0usize;
3949        for blk in &frame.frame_blocks {
3950            let ri = frame.ranks[blk.atom_i];
3951            let rj = frame.ranks[blk.atom_j];
3952            let (mi, mj) = blk.g.dim();
3953            if blk.w.dim() != (ri, rj) {
3954                return Err(ArrowSchurGpuFailure::Unavailable);
3955            }
3956            g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3957            g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3958            g_ri.push(checked_i32(ri)?);
3959            g_rj.push(checked_i32(rj)?);
3960            g_mi.push(checked_i32(mi)?);
3961            g_mj.push(checked_i32(mj)?);
3962            for r in 0..mi {
3963                for c in 0..mj {
3964                    g_data.push(blk.g[[r, c]]);
3965                }
3966            }
3967            g_ptr.push(checked_i32(g_data.len())?);
3968            for a in 0..ri {
3969                for b in 0..rj {
3970                    w_data.push(blk.w[[a, b]]);
3971                }
3972            }
3973            w_ptr.push(checked_i32(w_data.len())?);
3974            g_max_work = g_max_work.max(mi * ri);
3975        }
3976
3977        // Per-row dense cross-block + q + ainv (factored H_tt⁻¹).
3978        let mut htb_ptr = vec![0_i32];
3979        let mut htb = Vec::<f64>::new();
3980        let mut q_of = Vec::<i32>::with_capacity(n_rows);
3981        let mut max_q = 0usize;
3982        for (i, slab) in frame.row_htbeta.iter().enumerate() {
3983            let qi = sys.row_dims[i];
3984            // A populated slab must be q_i × k row-major; an empty slab ⇒ q=0
3985            // (the row contributes no reduced-Schur term).
3986            let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3987                qi
3988            } else {
3989                0
3990            };
3991            q_of.push(checked_i32(q_eff)?);
3992            max_q = max_q.max(q_eff);
3993            if q_eff > 0 {
3994                htb.extend_from_slice(slab);
3995            }
3996            htb_ptr.push(checked_i32(htb.len())?);
3997        }
3998        if max_q == 0 {
3999            // No row contributes a reduced term — the system is pure-penalty.
4000            // Still valid; give max_q=1 so the ainv buffer is non-empty.
4001            max_q = 1;
4002        }
4003
4004        let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
4005        for (i, row) in sys.rows.iter().enumerate() {
4006            let q = q_of[i] as usize;
4007            if q == 0 {
4008                continue;
4009            }
4010            if row.htt.dim() != (q, q) {
4011                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4012                    reason: format!(
4013                        "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
4014                        row.htt.dim()
4015                    ),
4016                });
4017            }
4018            let mut block = row.htt.clone();
4019            for d in 0..q {
4020                block[[d, d]] += ridge_t;
4021            }
4022            let factor = gam_linalg::triangular::cholesky_factor_in_place(
4023                block.view(),
4024                gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
4025            )
4026            .ok_or_else(|| {
4027                // Deficit-aware bump (Gershgorin λ_min bound) so a strongly
4028                // indefinite per-row block recovers in one outer-loop retry.
4029                ArrowSchurGpuFailure::RidgeBumpRequired {
4030                    row: i,
4031                    bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
4032                }
4033            })?;
4034            for col in 0..q {
4035                let mut e = Array1::<f64>::zeros(q);
4036                e[col] = 1.0;
4037                let solved =
4038                    gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
4039                for r in 0..q {
4040                    ainv[i * max_q * max_q + r * max_q + col] = solved[r];
4041                }
4042            }
4043        }
4044
4045        let htod_i = |v: &[i32]| {
4046            stream
4047                .clone_htod(v)
4048                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4049        };
4050        let htod_f = |v: &[f64]| {
4051            stream
4052                .clone_htod(v)
4053                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4054        };
4055        Ok(DeviceSaeFrameBuffers {
4056            s_off: htod_i(&s_off)?,
4057            s_m: htod_i(&s_m)?,
4058            s_r: htod_i(&s_r)?,
4059            s_ptr: htod_i(&s_ptr)?,
4060            s_data: htod_f(&s_data)?,
4061            s_blocks: data.smooth_blocks.len(),
4062            g_off_i: htod_i(&g_off_i)?,
4063            g_off_j: htod_i(&g_off_j)?,
4064            g_ri: htod_i(&g_ri)?,
4065            g_rj: htod_i(&g_rj)?,
4066            g_mi: htod_i(&g_mi)?,
4067            g_mj: htod_i(&g_mj)?,
4068            g_ptr: htod_i(&g_ptr)?,
4069            g_data: htod_f(&g_data)?,
4070            w_ptr: htod_i(&w_ptr)?,
4071            w_data: htod_f(&w_data)?,
4072            g_blocks: frame.frame_blocks.len(),
4073            g_max_work,
4074            htb_ptr: htod_i(&htb_ptr)?,
4075            htb: htod_f(&htb)?,
4076            q_of: htod_i(&q_of)?,
4077            ainv: htod_f(&ainv)?,
4078            hvec: stream
4079                .alloc_zeros::<f64>(n_rows * max_q)
4080                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4081            svec: stream
4082                .alloc_zeros::<f64>(n_rows * max_q)
4083                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4084            n_rows,
4085            k,
4086            max_q,
4087        })
4088    }
4089
4090    fn sae_frame_penalty_diag_host(
4091        data: &DeviceSaePcgData,
4092        frame: &DeviceSaeFrameData,
4093        ridge_beta: f64,
4094    ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4095        let mut diag = vec![ridge_beta; data.beta_dim];
4096        // Smooth: diag[off + ia·r + ib] += S[ia,ia].
4097        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4098            let m = blk.factor_a.nrows();
4099            for ia in 0..m {
4100                let coeff = blk.factor_a[[ia, ia]];
4101                let base = blk.global_offset + ia * r;
4102                for ib in 0..r {
4103                    if base + ib >= diag.len() {
4104                        return Err(ArrowSchurGpuFailure::Unavailable);
4105                    }
4106                    diag[base + ib] += coeff;
4107                }
4108            }
4109        }
4110        // Data: on-diagonal atom blocks contribute g[li,li]·w[a,a].
4111        for blk in &frame.frame_blocks {
4112            if blk.atom_i != blk.atom_j {
4113                continue;
4114            }
4115            let r = frame.ranks[blk.atom_i];
4116            let off = frame.border_offsets[blk.atom_i];
4117            let (mi, mj) = blk.g.dim();
4118            for li in 0..mi.min(mj) {
4119                let gii = blk.g[[li, li]];
4120                let base = off + li * r;
4121                for a in 0..r {
4122                    if base + a >= diag.len() {
4123                        return Err(ArrowSchurGpuFailure::Unavailable);
4124                    }
4125                    diag[base + a] += gii * blk.w[[a, a]];
4126                }
4127            }
4128        }
4129        Ok(diag)
4130    }
4131
4132    fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4133        Ok(LaunchConfig {
4134            grid_dim: (
4135                ((work as u32).saturating_add(255) / 256).max(1),
4136                checked_i32(n_rows)? as u32,
4137                1,
4138            ),
4139            block_dim: (256, 1, 1),
4140            shared_mem_bytes: 0,
4141        })
4142    }
4143
4144    fn launch_sae_frame_matvec(
4145        stream: &Arc<CudaStream>,
4146        module: &Arc<CudaModule>,
4147        buffers: &mut DeviceSaeFrameBuffers,
4148        x: &CudaSlice<f64>,
4149        out: &mut CudaSlice<f64>,
4150        ridge_beta: f64,
4151    ) -> Result<(), ArrowSchurGpuFailure> {
4152        launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4153        // Smooth penalty.
4154        if buffers.s_blocks > 0 {
4155            let kernel = module
4156                .load_function("arrow_sae_frame_smooth_matvec")
4157                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4158            let blocks_i32 = checked_i32(buffers.s_blocks)?;
4159            let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4160            let mut b = stream.launch_builder(&kernel);
4161            b.arg(x)
4162                .arg(&mut *out)
4163                .arg(&buffers.s_off)
4164                .arg(&buffers.s_m)
4165                .arg(&buffers.s_r)
4166                .arg(&buffers.s_ptr)
4167                .arg(&buffers.s_data)
4168                .arg(&blocks_i32);
4169            // SAFETY: smooth block metadata/data are live device buffers; the grid
4170            // covers (k channels × n_blocks) and the kernel bounds-checks m·r.
4171            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4172        }
4173        // Data penalty.
4174        if buffers.g_blocks > 0 {
4175            let kernel = module
4176                .load_function("arrow_sae_frame_g_matvec")
4177                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4178            let blocks_i32 = checked_i32(buffers.g_blocks)?;
4179            let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4180            let mut b = stream.launch_builder(&kernel);
4181            b.arg(x)
4182                .arg(&mut *out)
4183                .arg(&buffers.g_off_i)
4184                .arg(&buffers.g_off_j)
4185                .arg(&buffers.g_ri)
4186                .arg(&buffers.g_rj)
4187                .arg(&buffers.g_mi)
4188                .arg(&buffers.g_mj)
4189                .arg(&buffers.g_ptr)
4190                .arg(&buffers.g_data)
4191                .arg(&buffers.w_ptr)
4192                .arg(&buffers.w_data)
4193                .arg(&blocks_i32);
4194            // SAFETY: g/w block metadata/data are live device buffers; the grid
4195            // covers (max m_i·r_i × n_blocks) and the kernel bounds-checks.
4196            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4197        }
4198        // Reduced-Schur subtraction via dense per-row cross-blocks.
4199        let k_i32 = checked_i32(buffers.k)?;
4200        let max_q_i32 = checked_i32(buffers.max_q)?;
4201        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4202        {
4203            let kernel = module
4204                .load_function("arrow_sae_frame_apply_h")
4205                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4206            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4207            let mut b = stream.launch_builder(&kernel);
4208            b.arg(x)
4209                .arg(&buffers.htb_ptr)
4210                .arg(&buffers.htb)
4211                .arg(&buffers.q_of)
4212                .arg(&mut buffers.hvec)
4213                .arg(&k_i32)
4214                .arg(&max_q_i32)
4215                .arg(&n_rows_i32);
4216            // SAFETY: dense cross-block + pointers + hvec are live buffers sized
4217            // for (n_rows × max_q) / (n_rows × k); kernel guards q_i and k.
4218            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4219        }
4220        {
4221            let kernel = module
4222                .load_function("arrow_sae_frame_apply_ainv")
4223                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4224            let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4225            let mut b = stream.launch_builder(&kernel);
4226            b.arg(&buffers.ainv)
4227                .arg(&buffers.hvec)
4228                .arg(&buffers.q_of)
4229                .arg(&mut buffers.svec)
4230                .arg(&max_q_i32)
4231                .arg(&n_rows_i32);
4232            // SAFETY: ainv/hvec/svec are live buffers sized for n_rows·max_q²
4233            // and n_rows·max_q; the kernel guards row/coord bounds.
4234            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4235        }
4236        {
4237            let kernel = module
4238                .load_function("arrow_sae_frame_scatter_h")
4239                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4240            let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4241            let mut b = stream.launch_builder(&kernel);
4242            b.arg(&buffers.svec)
4243                .arg(&buffers.htb_ptr)
4244                .arg(&buffers.htb)
4245                .arg(&buffers.q_of)
4246                .arg(out)
4247                .arg(&k_i32)
4248                .arg(&max_q_i32)
4249                .arg(&n_rows_i32);
4250            // SAFETY: svec/cross-block/out are live buffers; the kernel atomically
4251            // accumulates into out[a] for a<k and reads c<q_i.
4252            unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4253        }
4254        Ok(())
4255    }
4256
4257    fn launch_sae_frame_diag_sub(
4258        stream: &Arc<CudaStream>,
4259        module: &Arc<CudaModule>,
4260        buffers: &DeviceSaeFrameBuffers,
4261        diag: &mut CudaSlice<f64>,
4262    ) -> Result<(), ArrowSchurGpuFailure> {
4263        let kernel = module
4264            .load_function("arrow_sae_frame_diag_sub")
4265            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4266        let k_i32 = checked_i32(buffers.k)?;
4267        let max_q_i32 = checked_i32(buffers.max_q)?;
4268        let n_rows_i32 = checked_i32(buffers.n_rows)?;
4269        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4270        let mut b = stream.launch_builder(&kernel);
4271        b.arg(diag)
4272            .arg(&buffers.ainv)
4273            .arg(&buffers.htb_ptr)
4274            .arg(&buffers.htb)
4275            .arg(&buffers.q_of)
4276            .arg(&k_i32)
4277            .arg(&max_q_i32)
4278            .arg(&n_rows_i32);
4279        // SAFETY: diag + cross-block + ainv live buffers; kernel guards a<k, c/d<q.
4280        unsafe { b.launch(cfg) }
4281            .map(drop)
4282            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4283    }
4284
4285    /// #1551 kernel-isolating seam: evaluate the framed reduced-Schur matvec
4286    /// `out = S·x` EXACTLY ONCE on the device (no PCG, no offload-floor gate) and
4287    /// return `out`. This is the parity probe the test harness diffs against the
4288    /// CPU oracle [`super::sae_framed_schur_matvec_cpu`] element-by-element, so a
4289    /// kernel/marshalling defect is exposed directly — independent of how the
4290    /// iterative solver behaves on an ill-conditioned assembled `S` (where dense
4291    /// Cholesky and PCG legitimately disagree at the solution level). Declines
4292    /// (`Unavailable`) only when CUDA is genuinely absent so the test skips
4293    /// cleanly off-device; it deliberately does NOT consult the offload policy so
4294    /// even a tiny verifiable fixture runs on the GPU.
4295    pub(super) fn framed_schur_matvec_once_on_device(
4296        sys: &ArrowSchurSystem,
4297        data: &DeviceSaePcgData,
4298        ridge_t: f64,
4299        ridge_beta: f64,
4300        x: &Array1<f64>,
4301    ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4302        let k = x.len();
4303        if k == 0 || data.beta_dim != k || sys.k != k {
4304            return Err(ArrowSchurGpuFailure::Unavailable);
4305        }
4306        let frame = data
4307            .frame
4308            .as_ref()
4309            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4310        // No offload-policy filter here: the seam exists to validate the kernel on
4311        // ANY device, including the smallest hand-checkable fixture.
4312        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4313            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4314        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4315            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4316        let stream = ctx
4317            .new_stream()
4318            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4319        let vector_module = pcg_vector_module(&ctx)?;
4320        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4321        let x_dev = stream
4322            .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4323            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4324        let mut out_dev = stream
4325            .alloc_zeros::<f64>(k)
4326            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4327        launch_sae_frame_matvec(
4328            &stream,
4329            vector_module,
4330            &mut buffers,
4331            &x_dev,
4332            &mut out_dev,
4333            ridge_beta,
4334        )?;
4335        let out = stream
4336            .clone_dtoh(&out_dev)
4337            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4338        Ok(Array1::from_vec(out))
4339    }
4340
4341    pub(super) fn solve_sae_matrix_free_pcg_framed(
4342        sys: &ArrowSchurSystem,
4343        data: &DeviceSaePcgData,
4344        ridge_t: f64,
4345        ridge_beta: f64,
4346        rhs_beta: &Array1<f64>,
4347        max_iterations: usize,
4348        relative_tolerance: f64,
4349    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4350        let k = rhs_beta.len();
4351        if k == 0 || data.beta_dim != k || sys.k != k {
4352            return Err(ArrowSchurGpuFailure::Unavailable);
4353        }
4354        let frame = data
4355            .frame
4356            .as_ref()
4357            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4358        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4359            .filter(|rt| {
4360                rt.policy().reduced_schur_matvec_should_offload(
4361                    sys.rows.len(),
4362                    sys.k,
4363                    sys.d,
4364                    max_iterations,
4365                )
4366            })
4367            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4368        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4369            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4370        let stream = ctx
4371            .new_stream()
4372            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4373        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4374        let vector_module = pcg_vector_module(&ctx)?;
4375        let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4376
4377        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4378        if rhs_norm == 0.0 {
4379            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4380        }
4381        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4382        let rhs_dev = stream
4383            .clone_htod(
4384                rhs_beta
4385                    .as_slice()
4386                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4387            )
4388            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4389        let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4390        let mut diag_dev = stream
4391            .clone_htod(&diag_host)
4392            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4393        launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4394        let diag_host = stream
4395            .clone_dtoh(&diag_dev)
4396            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4397        let mut inv_diag = Vec::with_capacity(k);
4398        for (idx, &d) in diag_host.iter().enumerate() {
4399            if !d.is_finite() || d <= 1.0e-18 {
4400                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4401                    reason: format!(
4402                        "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4403                    ),
4404                });
4405            }
4406            inv_diag.push(1.0 / d);
4407        }
4408        let inv_diag_dev = stream
4409            .clone_htod(&inv_diag)
4410            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4411
4412        let mut x_dev = stream
4413            .alloc_zeros::<f64>(k)
4414            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4415        let mut r_dev = stream
4416            .alloc_zeros::<f64>(k)
4417            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4418        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4419        let mut z_dev = stream
4420            .alloc_zeros::<f64>(k)
4421            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4422        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4423        let mut p_dev = stream
4424            .alloc_zeros::<f64>(k)
4425            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4426        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4427        let mut ap_dev = stream
4428            .alloc_zeros::<f64>(k)
4429            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4430
4431        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4432        if rz <= 0.0 || !rz.is_finite() {
4433            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4434                reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4435            });
4436        }
4437        let mut diag = PcgDiagnostics {
4438            precond_apply_calls: 1,
4439            stopping_reason: PcgStopReason::MaxIter,
4440            ..PcgDiagnostics::default()
4441        };
4442        for _ in 0..max_iterations.max(1) {
4443            launch_sae_frame_matvec(
4444                &stream,
4445                vector_module,
4446                &mut buffers,
4447                &p_dev,
4448                &mut ap_dev,
4449                ridge_beta,
4450            )?;
4451            diag.matvec_calls += 1;
4452            diag.iterations += 1;
4453            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4454            if pap <= 0.0 || !pap.is_finite() {
4455                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4456                    reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4457                });
4458            }
4459            let alpha = rz / pap;
4460            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4461            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4462            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4463            if r_norm <= tol {
4464                diag.final_relative_residual = r_norm / rhs_norm;
4465                diag.stopping_reason = PcgStopReason::Converged;
4466                break;
4467            }
4468            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4469            diag.precond_apply_calls += 1;
4470            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4471            if rz_new <= 0.0 || !rz_new.is_finite() {
4472                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4473                    reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4474                });
4475            }
4476            let beta = rz_new / rz;
4477            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4478            rz = rz_new;
4479        }
4480        if diag.stopping_reason != PcgStopReason::Converged {
4481            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4482            diag.final_relative_residual = r_norm / rhs_norm;
4483            diag.stopping_reason = PcgStopReason::MaxIter;
4484        }
4485        let x = stream
4486            .clone_dtoh(&x_dev)
4487            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4488        Ok((Array1::from_vec(x), diag))
4489    }
4490
4491    /// #1551 stage-isolating triage seam: run the framed reduced-Schur matvec
4492    /// `out = S·x` ONCE on the device (no PCG, no offload-floor gate) and return
4493    /// `out`, so a tiny hand-verifiable fixture can diff it against the CPU oracle
4494    /// `sae_framed_schur_matvec_cpu` element-by-element to localize the structural
4495    /// divergence to a single kernel stage. Returns `Unavailable` only when CUDA
4496    /// is genuinely absent (so the test skips cleanly off-device).
4497    pub(super) fn solve_sae_matrix_free_pcg(
4498        sys: &ArrowSchurSystem,
4499        data: &DeviceSaePcgData,
4500        ridge_t: f64,
4501        ridge_beta: f64,
4502        rhs_beta: &Array1<f64>,
4503        max_iterations: usize,
4504        relative_tolerance: f64,
4505    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4506        let k = rhs_beta.len();
4507        if k == 0 || data.beta_dim != k || sys.k != k {
4508            return Err(ArrowSchurGpuFailure::Unavailable);
4509        }
4510        // #1017/#1026 GUARD: the legacy `⊗ I_p` kernel must NEVER receive framed
4511        // data (factored `G ⊗ W_{ij}` + dense per-row cross blocks); decline so a
4512        // mis-route falls back to the CPU rather than returning a wrong step.
4513        if data.frame.is_some() {
4514            return Err(ArrowSchurGpuFailure::Unavailable);
4515        }
4516        // #1017 Phase-1 dispatch re-key: this is the matrix-free SAE reduced-Schur
4517        // PCG — the production hot path, not a single dense factorization. The
4518        // dense-Direct floor `dense_hessian_work_target_is_gpu(n, k)` keys on
4519        // `2·n·k²` and is the WRONG gate here: it ignores the per-row frame depth
4520        // `d` (the M dimension that multiplies the per-apply work) and the
4521        // `1/cg_iters` staging amortisation, so it both undercounts the SAE batched
4522        // work `n·k·d` and applies a cold single-launch breakeven to an apply that
4523        // reuses device-resident frames `max_iterations` times. Key instead on the
4524        // CG-amortised total batched work — the same predicate the host injection
4525        // gate (`maybe_inject_gpu_schur_matvec`) consults — so few-row/wide-`k`/
4526        // modest-`d` LLM shapes register the real `n × k × d × cg_iters` arithmetic.
4527        // Kernels and numerics are untouched; only where the matvec runs changes,
4528        // and the host falls back to the bit-identical CPU matvec when this declines.
4529        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4530            .filter(|rt| {
4531                rt.policy().reduced_schur_matvec_should_offload(
4532                    sys.rows.len(),
4533                    sys.k,
4534                    sys.d,
4535                    max_iterations,
4536                )
4537            })
4538            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4539        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4540            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4541        let stream = ctx
4542            .new_stream()
4543            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4545        let vector_module = pcg_vector_module(&ctx)?;
4546        let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4547
4548        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4549        if rhs_norm == 0.0 {
4550            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4551        }
4552        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4553        let rhs_dev = stream
4554            .clone_htod(
4555                rhs_beta
4556                    .as_slice()
4557                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4558            )
4559            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4560        let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4561        let mut diag_dev = stream
4562            .clone_htod(&diag_host)
4563            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4564        launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4565        let diag_host = stream
4566            .clone_dtoh(&diag_dev)
4567            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4568        let mut inv_diag = Vec::with_capacity(k);
4569        for (idx, &d) in diag_host.iter().enumerate() {
4570            if !d.is_finite() || d <= 1.0e-18 {
4571                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4572                    reason: format!(
4573                        "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4574                    ),
4575                });
4576            }
4577            inv_diag.push(1.0 / d);
4578        }
4579        let inv_diag_dev = stream
4580            .clone_htod(&inv_diag)
4581            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4582
4583        let mut x_dev = stream
4584            .alloc_zeros::<f64>(k)
4585            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4586        let mut r_dev = stream
4587            .alloc_zeros::<f64>(k)
4588            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4589        device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4590        let mut z_dev = stream
4591            .alloc_zeros::<f64>(k)
4592            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4593        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4594        let mut p_dev = stream
4595            .alloc_zeros::<f64>(k)
4596            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4597        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4598        let mut ap_dev = stream
4599            .alloc_zeros::<f64>(k)
4600            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4601
4602        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4603        if rz <= 0.0 || !rz.is_finite() {
4604            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4605                reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4606            });
4607        }
4608        let mut diag = PcgDiagnostics {
4609            precond_apply_calls: 1,
4610            stopping_reason: PcgStopReason::MaxIter,
4611            ..PcgDiagnostics::default()
4612        };
4613
4614        for _ in 0..max_iterations.max(1) {
4615            launch_sae_matvec(
4616                &stream,
4617                vector_module,
4618                &mut buffers,
4619                &p_dev,
4620                &mut ap_dev,
4621                ridge_beta,
4622            )?;
4623            diag.matvec_calls += 1;
4624            diag.iterations += 1;
4625            let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4626            if pap <= 0.0 || !pap.is_finite() {
4627                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4628                    reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4629                });
4630            }
4631            let alpha = rz / pap;
4632            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4633            device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4634            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4635            if r_norm <= tol {
4636                diag.final_relative_residual = r_norm / rhs_norm;
4637                diag.stopping_reason = PcgStopReason::Converged;
4638                break;
4639            }
4640            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4641            diag.precond_apply_calls += 1;
4642            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4643            if rz_new <= 0.0 || !rz_new.is_finite() {
4644                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4645                    reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4646                });
4647            }
4648            let beta = rz_new / rz;
4649            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4650            rz = rz_new;
4651        }
4652        if diag.stopping_reason != PcgStopReason::Converged {
4653            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4654            diag.final_relative_residual = r_norm / rhs_norm;
4655            diag.stopping_reason = PcgStopReason::MaxIter;
4656        }
4657        let x = stream
4658            .clone_dtoh(&x_dev)
4659            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4660        Ok((Array1::from_vec(x), diag))
4661    }
4662
4663    pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4664        s_acc: &ndarray::Array2<f64>,
4665        rhs_beta: &Array1<f64>,
4666        max_iterations: usize,
4667        relative_tolerance: f64,
4668    ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4669        let k = rhs_beta.len();
4670        // #1017 dispatch re-key: this is an ITERATIVE device-resident PCG, not a
4671        // single GEMV. `S` (k×k) is uploaded once and reused for `max_iterations`
4672        // `S·p` GEMVs while only convergence scalars cross PCIe, so the staging
4673        // cost is amortised over the whole CG solve. Gating on the flops of ONE
4674        // `Gemv{k,k}` (`2·k²`) understates the work by the iteration count and
4675        // declines shapes (e.g. k≈512) whose total iterated arithmetic
4676        // `2·k²·iters` clears the device floor by orders of magnitude — the same
4677        // single-launch-breakeven miskey #1017 fixed for the framed reduced-Schur
4678        // matvec. Key on the CG-amortised total work via a `Gemm{k,k,iters}` whose
4679        // `flops()` is exactly `2·k²·iters`; numerics and kernels are untouched,
4680        // and the host falls back to the bit-identical CPU PCG when this declines.
4681        let cg_iters = max_iterations.max(1);
4682        let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4683            gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4684                m: k,
4685                n: k,
4686                k: cg_iters,
4687            },
4688        )
4689        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4690        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4691            .and_then(|ctx| ctx.new_stream().ok())
4692            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4693        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4694        let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4695            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4696        let vector_module = pcg_vector_module(&ctx)?;
4697
4698        // Jacobi diagonal from S; must be strictly positive for SPD.
4699        let mut inv_diag = vec![0.0_f64; k];
4700        for j in 0..k {
4701            let djj = s_acc[[j, j]];
4702            if !(djj > 0.0) {
4703                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4704                    reason: format!(
4705                        "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4706                    ),
4707                });
4708            }
4709            inv_diag[j] = 1.0 / djj;
4710        }
4711
4712        // Upload S column-major (S[row,col] at col*k + row).
4713        let mut s_host = vec![0.0_f64; k * k];
4714        for col in 0..k {
4715            for row in 0..k {
4716                s_host[col * k + row] = s_acc[[row, col]];
4717            }
4718        }
4719        let s_dev = stream
4720            .clone_htod(&s_host)
4721            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4722
4723        // Steihaug truncated-CG with Jacobi preconditioner, host scalar
4724        // recurrences and a device `S·p` matvec. The streaming reduced solve
4725        // uses an unbounded trust region (pure CG to tolerance).
4726        let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4727        if rhs_norm == 0.0 {
4728            return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4729        }
4730        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4731
4732        // Device-resident PCG state. Only convergence scalars cross back during
4733        // the loop; x/r/z/p/Sp stay on CUDA until the final solution download.
4734        let mut x_dev = stream
4735            .alloc_zeros::<f64>(k)
4736            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4737        let mut r_dev = stream
4738            .clone_htod(
4739                rhs_beta
4740                    .as_slice()
4741                    .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4742            )
4743            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4744        let inv_diag_dev = stream
4745            .clone_htod(&inv_diag)
4746            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4747        let mut z_dev = stream
4748            .alloc_zeros::<f64>(k)
4749            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4750        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4751        let mut p_dev = stream
4752            .alloc_zeros::<f64>(k)
4753            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4754        device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4755        let mut sp_dev = stream
4756            .alloc_zeros::<f64>(k)
4757            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4758        let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4759        let mut diag = PcgDiagnostics {
4760            precond_apply_calls: 1,
4761            stopping_reason: PcgStopReason::MaxIter,
4762            ..PcgDiagnostics::default()
4763        };
4764        if rz <= 0.0 || !rz.is_finite() {
4765            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4766                reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4767            });
4768        }
4769
4770        let max_iters = max_iterations.max(1);
4771        for _ in 0..max_iters {
4772            // sp = S · p (device GEMV, S column-major k×k, op = N).
4773            let gemv_cfg = GemvConfig::<f64> {
4774                trans: cublasOperation_t::CUBLAS_OP_N,
4775                m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4776                n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4777                alpha: 1.0,
4778                lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4779                incx: 1,
4780                beta: 0.0,
4781                incy: 1,
4782            };
4783            // SAFETY: s_dev is k×k column-major, p_dev / sp_dev length k.
4784            unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4785                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4786            diag.matvec_calls += 1;
4787            diag.iterations += 1;
4788
4789            let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4790            if !(p_sp > 0.0) {
4791                // Non-positive curvature on a (proximal-ridged) SPD system means
4792                // numerical breakdown; surface so the caller escalates.
4793                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4794                    reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4795                });
4796            }
4797            let alpha = rz / p_sp;
4798            device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4799            device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4800            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4801            if r_norm <= tol {
4802                diag.final_relative_residual = r_norm / rhs_norm;
4803                diag.stopping_reason = PcgStopReason::Converged;
4804                break;
4805            }
4806            launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4807            diag.precond_apply_calls += 1;
4808            let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4809            if rz_new <= 0.0 || !rz_new.is_finite() {
4810                return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4811                    reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4812                });
4813            }
4814            let beta = rz_new / rz;
4815            launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4816            rz = rz_new;
4817        }
4818        if diag.stopping_reason != PcgStopReason::Converged {
4819            let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4820            diag.final_relative_residual = r_norm / rhs_norm;
4821            diag.stopping_reason = PcgStopReason::MaxIter;
4822        }
4823
4824        let x = stream
4825            .clone_dtoh(&x_dev)
4826            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4827        Ok((Array1::from_vec(x), diag))
4828    }
4829
4830    fn device_copy(
4831        blas: &CudaBlas,
4832        stream: &Arc<CudaStream>,
4833        n: usize,
4834        src: &CudaSlice<f64>,
4835        dst: &mut CudaSlice<f64>,
4836    ) -> Result<(), ArrowSchurGpuFailure> {
4837        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4838        let (src_ptr, _src_rec) = src.device_ptr(stream);
4839        let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4840        // SAFETY: src and dst are live device allocations on this stream with at
4841        // least n contiguous f64 entries and unit stride.
4842        let status = unsafe {
4843            cudarc::cublas::sys::cublasDcopy_v2(
4844                *blas.handle(),
4845                n_i,
4846                src_ptr as *const f64,
4847                1,
4848                dst_ptr as *mut f64,
4849                1,
4850            )
4851        };
4852        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4853            Ok(())
4854        } else {
4855            Err(ArrowSchurGpuFailure::Unavailable)
4856        }
4857    }
4858
4859    fn device_axpy(
4860        blas: &CudaBlas,
4861        stream: &Arc<CudaStream>,
4862        n: usize,
4863        alpha: f64,
4864        x: &CudaSlice<f64>,
4865        y: &mut CudaSlice<f64>,
4866    ) -> Result<(), ArrowSchurGpuFailure> {
4867        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4868        let (x_ptr, _x_rec) = x.device_ptr(stream);
4869        let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4870        // SAFETY: x and y are live device allocations on this stream with at
4871        // least n contiguous f64 entries and unit stride; cuBLAS only reads alpha.
4872        let status = unsafe {
4873            cudarc::cublas::sys::cublasDaxpy_v2(
4874                *blas.handle(),
4875                n_i,
4876                &alpha,
4877                x_ptr as *const f64,
4878                1,
4879                y_ptr as *mut f64,
4880                1,
4881            )
4882        };
4883        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4884            Ok(())
4885        } else {
4886            Err(ArrowSchurGpuFailure::Unavailable)
4887        }
4888    }
4889
4890    fn device_dot(
4891        blas: &CudaBlas,
4892        stream: &Arc<CudaStream>,
4893        n: usize,
4894        x: &CudaSlice<f64>,
4895        y: &CudaSlice<f64>,
4896    ) -> Result<f64, ArrowSchurGpuFailure> {
4897        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4898        let (x_ptr, _x_rec) = x.device_ptr(stream);
4899        let (y_ptr, _y_rec) = y.device_ptr(stream);
4900        let mut result = 0.0_f64;
4901        // SAFETY: x and y are live device allocations on this stream with at
4902        // least n contiguous f64 entries and unit stride; result is a valid host
4903        // out-pointer for the cuBLAS scalar.
4904        let status = unsafe {
4905            cudarc::cublas::sys::cublasDdot_v2(
4906                *blas.handle(),
4907                n_i,
4908                x_ptr as *const f64,
4909                1,
4910                y_ptr as *const f64,
4911                1,
4912                &mut result,
4913            )
4914        };
4915        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4916            Ok(result)
4917        } else {
4918            Err(ArrowSchurGpuFailure::Unavailable)
4919        }
4920    }
4921
4922    fn device_nrm2(
4923        blas: &CudaBlas,
4924        stream: &Arc<CudaStream>,
4925        n: usize,
4926        x: &CudaSlice<f64>,
4927    ) -> Result<f64, ArrowSchurGpuFailure> {
4928        let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4929        let (x_ptr, _x_rec) = x.device_ptr(stream);
4930        let mut result = 0.0_f64;
4931        // SAFETY: x is a live device allocation on this stream with at least n
4932        // contiguous f64 entries and unit stride; result is a valid host
4933        // out-pointer for the cuBLAS scalar.
4934        let status = unsafe {
4935            cudarc::cublas::sys::cublasDnrm2_v2(
4936                *blas.handle(),
4937                n_i,
4938                x_ptr as *const f64,
4939                1,
4940                &mut result,
4941            )
4942        };
4943        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4944            Ok(result)
4945        } else {
4946            Err(ArrowSchurGpuFailure::Unavailable)
4947        }
4948    }
4949
4950    #[cfg(test)]
4951    mod tests {
4952        //! #1551 device-side framed-matvec triage. Lives inside `mod cuda` so it
4953        //! can call the private kernel launchers directly (no test-only public
4954        //! seam, which the ban-scanner forbids). A bare `#[cfg(test)] mod tests`
4955        //! is the one form the scanner permits.
4956        use super::*;
4957        use crate::arrow_schur::{
4958            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4959            FactoredFrameGBlock,
4960        };
4961        use ndarray::Array2;
4962
4963        /// Run the framed reduced-Schur matvec `out = S·x` ONCE on the device
4964        /// (no PCG, no offload gate) and return `out`.
4965        fn device_matvec_once(
4966            sys: &ArrowSchurSystem,
4967            data: &DeviceSaePcgData,
4968            ridge_t: f64,
4969            ridge_beta: f64,
4970            x_host: &[f64],
4971        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4972            let k = x_host.len();
4973            let frame = data
4974                .frame
4975                .as_ref()
4976                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4977            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4978                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4979            let ctx =
4980                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4981                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4982            let stream = ctx
4983                .new_stream()
4984                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4985            let vector_module = pcg_vector_module(&ctx)?;
4986            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4987            let x_dev = stream
4988                .clone_htod(x_host)
4989                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4990            let mut out_dev = stream
4991                .alloc_zeros::<f64>(k)
4992                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4993            launch_sae_frame_matvec(
4994                &stream,
4995                vector_module,
4996                &mut buffers,
4997                &x_dev,
4998                &mut out_dev,
4999                ridge_beta,
5000            )?;
5001            stream
5002                .clone_dtoh(&out_dev)
5003                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
5004        }
5005
5006        /// #1551 stage-isolating matvec triage on a TINY hand-verifiable fixture:
5007        /// diff the device framed matvec `S·e_col` against the CPU oracle
5008        /// `sae_framed_schur_matvec_cpu` for every identity column, reporting the
5009        /// worst-divergent border index so the structural 91% localizes to one
5010        /// kernel stage. Skips cleanly off-device.
5011        #[test]
5012        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
5013            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
5014                return;
5015            }
5016            let p = 3usize;
5017            let ranks = vec![2usize, 3usize];
5018            let basis_sizes = vec![2usize, 2usize];
5019            let mut border_offsets = Vec::new();
5020            let mut acc = 0usize;
5021            for k in 0..2 {
5022                border_offsets.push(acc);
5023                acc += basis_sizes[k] * ranks[k];
5024            }
5025            let border_dim = acc; // 2·2 + 2·3 = 10
5026            let frame_of = |k: usize| -> Array2<f64> {
5027                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
5028                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
5029                })
5030            };
5031            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
5032            let w_of = |i: usize, j: usize| -> Array2<f64> {
5033                let (ui, uj) = (&frames[i], &frames[j]);
5034                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
5035                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
5036                })
5037            };
5038            let mut frame_blocks = Vec::new();
5039            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
5040                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5041                let mut g =
5042                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
5043                if i == j {
5044                    for r in 0..mi.min(mj) {
5045                        g[[r, r]] += mi as f64 + 2.0;
5046                    }
5047                }
5048                frame_blocks.push(FactoredFrameGBlock {
5049                    atom_i: i,
5050                    atom_j: j,
5051                    g,
5052                    w: w_of(i, j),
5053                });
5054            }
5055            let mut smooth_blocks = Vec::new();
5056            for k in 0..2 {
5057                let m = basis_sizes[k];
5058                let mut s =
5059                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5060                for r in 0..m {
5061                    s[[r, r]] += 1.0;
5062                }
5063                smooth_blocks.push(DeviceSaeSmoothBlock {
5064                    global_offset: border_offsets[k],
5065                    factor_a: s,
5066                });
5067            }
5068            let smooth_ranks = ranks.clone();
5069            let n = 2usize;
5070            let q = 2usize;
5071            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5072            let mut row_htbeta = Vec::new();
5073            for i in 0..n {
5074                let mut htt =
5075                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5076                for r in 0..q {
5077                    htt[[r, r]] += q as f64 + 2.0;
5078                }
5079                sys.rows[i].htt = htt;
5080                let mut slab = vec![0.0_f64; q * border_dim];
5081                for c in 0..q {
5082                    for col in 0..border_dim {
5083                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5084                        slab[c * border_dim + col] = v;
5085                        sys.rows[i].htbeta[[c, col]] = v;
5086                    }
5087                }
5088                row_htbeta.push(slab);
5089            }
5090            let data = DeviceSaePcgData {
5091                p,
5092                beta_dim: border_dim,
5093                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5094                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5095                smooth_blocks,
5096                sparse_g_blocks: Vec::new(),
5097                frame: Some(DeviceSaeFrameData {
5098                    ranks,
5099                    basis_sizes,
5100                    border_offsets,
5101                    frame_blocks,
5102                    smooth_ranks,
5103                    row_htbeta,
5104                }),
5105            };
5106            let ridge_t = 1e-7;
5107            let ridge_beta = 1e-6;
5108            let mut first_bad: Option<usize> = None;
5109            let mut worst = 0.0_f64;
5110            let mut worst_at = 0usize;
5111            let mut worst_dev = 0.0_f64;
5112            let mut worst_cpu = 0.0_f64;
5113            for col in 0..border_dim {
5114                let mut x = vec![0.0_f64; border_dim];
5115                x[col] = 1.0;
5116                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5117                    Ok(v) => v,
5118                    Err(_) => return,
5119                };
5120                let mut cpu = vec![0.0_f64; border_dim];
5121                super::super::sae_framed_schur_matvec_cpu(
5122                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5123                )
5124                .expect("cpu matvec");
5125                for r in 0..border_dim {
5126                    let d = (dev[r] - cpu[r]).abs();
5127                    if d > 1e-9 && first_bad.is_none() {
5128                        first_bad = Some(r * border_dim + col);
5129                    }
5130                    if d > worst {
5131                        worst = d;
5132                        worst_at = r * border_dim + col;
5133                        worst_dev = dev[r];
5134                        worst_cpu = cpu[r];
5135                    }
5136                }
5137            }
5138            assert!(
5139                worst <= 1e-9,
5140                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5141                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5142                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5143                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5144                 G⊗W=cross, reduced-Schur=dense per-row)",
5145            );
5146        }
5147    }
5148}
5149
5150#[cfg(test)]
5151mod tests {
5152    use super::*;
5153    use crate::arrow_schur::ArrowSchurSystem;
5154    use ndarray::{Array2, ArrayView1};
5155
5156    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5157        let mut sys = ArrowSchurSystem::new(n, d, k);
5158        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5159        let mut sample = || -> f64 {
5160            state = state
5161                .wrapping_mul(6364136223846793005)
5162                .wrapping_add(1442695040888963407);
5163            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5164        };
5165        for row in &mut sys.rows {
5166            let mut a = Array2::<f64>::zeros((d, d));
5167            for r in 0..d {
5168                for c in 0..d {
5169                    a[[r, c]] = sample();
5170                }
5171            }
5172            let mut htt = a.t().dot(&a);
5173            for r in 0..d {
5174                htt[[r, r]] += d as f64 + 1.0;
5175            }
5176            row.htt = htt;
5177            for r in 0..d {
5178                for c in 0..k {
5179                    row.htbeta[[r, c]] = 0.1 * sample();
5180                }
5181                row.gt[r] = sample();
5182            }
5183        }
5184        let mut hbb_a = Array2::<f64>::zeros((k, k));
5185        for r in 0..k {
5186            for c in 0..k {
5187                hbb_a[[r, c]] = sample();
5188            }
5189        }
5190        let mut hbb = hbb_a.t().dot(&hbb_a);
5191        for r in 0..k {
5192            hbb[[r, r]] += k as f64 + 1.0;
5193        }
5194        sys.hbb = hbb;
5195        for r in 0..k {
5196            sys.gb[r] = sample();
5197        }
5198        sys
5199    }
5200
5201    /// The Gershgorin ridge bump must actually make a known-indefinite block PD
5202    /// on the first retry — the whole point of #1711. Verified directly by
5203    /// re-factoring `H_tt + (ridge_t + bump)·I` with the same Cholesky guard the
5204    /// device readback uses, for blocks whose `λ_min` is known in closed form.
5205    #[test]
5206    fn ridge_bump_makes_known_indefinite_blocks_pd() {
5207        // `cholesky_factor_in_place` / `CholeskyGuard` are already in scope via
5208        // `super::*` (imported at the top of the module).
5209        // A few blocks with a CLOSED-FORM smallest eigenvalue, all at ridge_t=0.
5210        // (label, matrix, λ_min) — the bump must clear each one.
5211        let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); // λ_min = -1
5212        let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); // λ_min = -250
5213        // Symmetric 2×2 [[1, 2], [2, 1]] has eigenvalues 3 and -1 → indefinite.
5214        let mut indef2 = Array2::<f64>::zeros((2, 2));
5215        indef2[[0, 0]] = 1.0;
5216        indef2[[1, 1]] = 1.0;
5217        indef2[[0, 1]] = 2.0;
5218        indef2[[1, 0]] = 2.0;
5219        // A genuinely PD block must get a bump that is the bare rounding margin
5220        // only (deficit 0), and must still factor — the helper is defensive.
5221        let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5222
5223        for (label, block) in [
5224            ("-I (λ_min=-1)", neg_identity),
5225            ("-250·I (λ_min=-250)", scaled_neg),
5226            ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5227            ("5·I (PD)", pd),
5228        ] {
5229            let ridge_t = 0.0;
5230            let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5231            assert!(
5232                bump > 0.0 && bump.is_finite(),
5233                "[{label}] bump must be strictly positive and finite, got {bump:e}"
5234            );
5235            let d = block.nrows();
5236            let mut shifted = block.clone();
5237            for i in 0..d {
5238                shifted[[i, i]] += ridge_t + bump;
5239            }
5240            assert!(
5241                cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5242                "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5243                 Gershgorin bump, but the Cholesky still rejected it"
5244            );
5245        }
5246    }
5247
5248    /// The column-major variant (multi-GPU tile path) must agree with the
5249    /// row-major helper for a symmetric block, since Gershgorin edges are
5250    /// invariant under reading the symmetric matrix by row vs by column. The
5251    /// colmajor variant takes the bound at ridge_t=0 (the ridge is already baked
5252    /// into the diagonal it reads), so compare against `ridge_bump_to_make_pd`
5253    /// with `ridge_t = 0`.
5254    ///
5255    /// Gated to linux: `ridge_bump_to_make_pd_colmajor` only exists on the
5256    /// linux CUDA tile path, so the parity test runs where the function does.
5257    #[cfg(target_os = "linux")]
5258    #[test]
5259    fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5260        // Symmetric 3×3 with a negative-definite-ish diagonal and off-diagonals.
5261        let mut a = Array2::<f64>::zeros((3, 3));
5262        a[[0, 0]] = -2.0;
5263        a[[1, 1]] = 0.5;
5264        a[[2, 2]] = 1.0;
5265        a[[0, 1]] = 0.3;
5266        a[[1, 0]] = 0.3;
5267        a[[1, 2]] = -0.4;
5268        a[[2, 1]] = -0.4;
5269        a[[0, 2]] = 0.1;
5270        a[[2, 0]] = 0.1;
5271
5272        let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5273
5274        // Flatten column-major: block[c*d + r] = a[[r, c]].
5275        let d = 3;
5276        let mut col_major = vec![0.0_f64; d * d];
5277        for c in 0..d {
5278            for r in 0..d {
5279                col_major[c * d + r] = a[[r, c]];
5280            }
5281        }
5282        let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5283
5284        assert!(
5285            (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5286            "colmajor bump {col_major_bump:e} must match rowmajor bump \
5287             {row_major_bump:e} for a symmetric block"
5288        );
5289
5290        // And the bump must actually make it PD (sanity, same as the row-major test).
5291        let mut shifted = a.clone();
5292        for i in 0..d {
5293            shifted[[i, i]] += col_major_bump;
5294        }
5295        assert!(
5296            cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5297            "colmajor Gershgorin bump must make the symmetric block PD"
5298        );
5299    }
5300
5301    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5302        let mut s = Array2::<f64>::zeros((k, k));
5303        for row in 0..k {
5304            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5305            if row + 1 < k {
5306                s[[row, row + 1]] = -0.05;
5307                s[[row + 1, row]] = -0.05;
5308            }
5309            if row + 7 < k {
5310                s[[row, row + 7]] = 0.01;
5311                s[[row + 7, row]] = 0.01;
5312            }
5313        }
5314        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5315        (s, rhs)
5316    }
5317
5318    fn dense_pcg_cpu_reference(
5319        s: &Array2<f64>,
5320        rhs: &Array1<f64>,
5321        max_iterations: usize,
5322        relative_tolerance: f64,
5323    ) -> Array1<f64> {
5324        let k = rhs.len();
5325        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5326        if rhs_norm == 0.0 {
5327            return Array1::<f64>::zeros(k);
5328        }
5329        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5330        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5331        let mut x = Array1::<f64>::zeros(k);
5332        let mut r = rhs.clone();
5333        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5334        let mut p = z.clone();
5335        let mut sp = Array1::<f64>::zeros(k);
5336        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5337        for _ in 0..max_iterations.max(1) {
5338            for row in 0..k {
5339                let mut acc = 0.0;
5340                for col in 0..k {
5341                    acc += s[[row, col]] * p[col];
5342                }
5343                sp[row] = acc;
5344            }
5345            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5346            let alpha = rz / p_sp;
5347            for idx in 0..k {
5348                x[idx] += alpha * p[idx];
5349                r[idx] -= alpha * sp[idx];
5350            }
5351            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5352            if r_norm <= tol {
5353                break;
5354            }
5355            for idx in 0..k {
5356                z[idx] = inv_diag[idx] * r[idx];
5357            }
5358            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5359            let beta = rz_next / rz;
5360            for idx in 0..k {
5361                p[idx] = z[idx] + beta * p[idx];
5362            }
5363            rz = rz_next;
5364        }
5365        x
5366    }
5367
5368    #[test]
5369    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5370        let (s, rhs) = device_pcg_fixture(512);
5371        let max_iterations = 200usize;
5372        let relative_tolerance = 1.0e-12;
5373        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5374        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5375            &s,
5376            &rhs,
5377            max_iterations,
5378            relative_tolerance,
5379        ) {
5380            Ok(result) => result,
5381            // #1017 — fail loud, never skip-pass: this fixture clears the device
5382            // offload floor, so a CUDA device that is PRESENT yet declines/returns
5383            // Err means the device PCG kernel does not run on GPU (a real fault that
5384            // must not masquerade as a pass via this skip). Legit skip ONLY when no
5385            // usable CUDA device exists (CPU CI). The exact `ArrowSchurGpuFailure`
5386            // variant is folded into the assert message as the diagnostic.
5387            Err(failure) => {
5388                assert!(
5389                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5390                    "#1017: CUDA device present but the device reduced-beta PCG \
5391                     declined/faulted instead of returning a result (tag: {failure:?}) — \
5392                     the kernel does not run correctly on GPU"
5393                );
5394                return;
5395            }
5396        };
5397        let max_err = cpu
5398            .iter()
5399            .zip(device.iter())
5400            .map(|(a, b)| (a - b).abs())
5401            .fold(0.0_f64, f64::max);
5402        assert!(
5403            max_err <= 1.0e-10,
5404            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5405        );
5406        assert!(diag.matvec_calls > 0);
5407        assert_eq!(diag.matvec_calls, diag.iterations);
5408    }
5409
5410    #[test]
5411    fn dense_reference_matches_independent_solve() {
5412        let sys = build_fixture(4, 5, 3, 7);
5413        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5414        // Re-solve by an independent matrix build and a textbook
5415        // Gaussian-elimination Cholesky to guard against typos in the
5416        // reference implementation itself.
5417        let n = sys.rows.len();
5418        let d = sys.d;
5419        let k = sys.k;
5420        let total = n * d + k;
5421        let mut h = Array2::<f64>::zeros((total, total));
5422        let mut g = ndarray::Array1::<f64>::zeros(total);
5423        for (i, row) in sys.rows.iter().enumerate() {
5424            let base = i * d;
5425            for c in 0..d {
5426                for r in 0..d {
5427                    h[[base + r, base + c]] = row.htt[[r, c]];
5428                }
5429            }
5430            for c in 0..k {
5431                for r in 0..d {
5432                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5433                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5434                }
5435            }
5436            for r in 0..d {
5437                g[base + r] = row.gt[r];
5438            }
5439        }
5440        for c in 0..k {
5441            for r in 0..k {
5442                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5443            }
5444            g[n * d + c] = sys.gb[c];
5445        }
5446        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5447        let rhs = g.mapv(|v| -v);
5448        let expected = cholesky_solve_vector(l.view(), rhs.view());
5449        for i in 0..n * d {
5450            assert!(
5451                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5452                "delta_t[{i}] mismatch: got {} expected {}",
5453                solution.delta_t[i],
5454                expected[i]
5455            );
5456        }
5457        for a in 0..k {
5458            assert!(
5459                (solution.delta_beta[a] - expected[n * d + a]).abs()
5460                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
5461                "delta_beta[{a}] mismatch"
5462            );
5463        }
5464    }
5465
5466    /// #1017: the row-procedural reduced-Schur matvec (the matrix-free SAE
5467    /// host backend) auto-fans its per-row point-elimination sum across rayon
5468    /// over fixed row chunks when at the top level (`n ≥
5469    /// SCHUR_MATVEC_PARALLEL_ROW_MIN`), and stays serial when already inside a
5470    /// rayon worker. The chunk-ordered fold makes the parallel result
5471    /// **deterministic** (two parallel calls are bit-identical — scheduling
5472    /// cannot change the numbers) and it agrees with the serial accumulation up
5473    /// to ULP-scale chunk reassociation (the #1017 verification gate). That
5474    /// reassociation is a genuine f64 departure from serial, so the criterion
5475    /// ranking across topology candidates is stable only up to the reassociation
5476    /// margin: a near-tie winner inside that margin can flip. This is NOT an
5477    /// exact no-move guarantee (#1211); for that, the ranking path must use the
5478    /// fixed-order serial accumulation.
5479    #[test]
5480    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5481        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5482        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; // trips the parallel path
5483        let d = 3usize;
5484        let k = 24usize;
5485        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5486        // Install a matrix-free forward/transpose pair that reads the dense
5487        // `htbeta` slabs the fixture already populated, so the procedural
5488        // backend has a well-defined operator to apply (and exercises exactly
5489        // the sparse gather/scatter the SAE Kronecker path drives).
5490        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5491        let forward_slabs = slabs.clone();
5492        let transpose_slabs = slabs;
5493        sys.set_row_htbeta_operator(
5494            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5495                let h = &forward_slabs[row];
5496                for r in 0..h.nrows() {
5497                    let mut acc = 0.0_f64;
5498                    for c in 0..h.ncols() {
5499                        acc += h[[r, c]] * x[c];
5500                    }
5501                    out[r] = acc;
5502                }
5503            },
5504            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5505                let h = &transpose_slabs[row];
5506                for r in 0..h.nrows() {
5507                    for c in 0..h.ncols() {
5508                        out[c] += h[[r, c]] * v[r];
5509                    }
5510                }
5511            },
5512        );
5513
5514        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5515            .expect("row-procedural matvec backend builds for matrix-free system");
5516        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5517
5518        // Top-level call: auto-selects the parallel chunk-fold. Run twice and
5519        // assert bit-identity — the chunk-ordered reduction must not depend on
5520        // thread scheduling.
5521        let mut out_parallel_a = Array1::<f64>::zeros(k);
5522        matvec(&x, &mut out_parallel_a);
5523        let mut out_parallel_b = Array1::<f64>::zeros(k);
5524        matvec(&x, &mut out_parallel_b);
5525        for a in 0..k {
5526            assert_eq!(
5527                out_parallel_a[a].to_bits(),
5528                out_parallel_b[a].to_bits(),
5529                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5530            );
5531        }
5532
5533        // Inside a rayon worker: auto-selects the serial path (nested-rayon
5534        // guard). `install` runs the closure on a pool thread, so
5535        // `current_thread_index()` is `Some`. The serial running sum and the
5536        // chunk-ordered parallel fold differ only by f64 reassociation.
5537        let mut out_serial = Array1::<f64>::zeros(k);
5538        rayon::ThreadPoolBuilder::new()
5539            .num_threads(2)
5540            .build()
5541            .expect("build rayon pool")
5542            .install(|| matvec(&x, &mut out_serial));
5543
5544        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5545        for a in 0..k {
5546            let diff = (out_parallel_a[a] - out_serial[a]).abs();
5547            assert!(
5548                diff <= 1e-12 * (1.0 + max_abs),
5549                "row-procedural matvec parallel vs serial diverged beyond reassociation \
5550                 at index {a}: {} vs {} (diff={diff:e})",
5551                out_parallel_a[a],
5552                out_serial[a]
5553            );
5554        }
5555    }
5556
5557    /// #1017/#1026 — the frames-engaged CPU reduced-Schur matvec
5558    /// [`sae_framed_schur_matvec_cpu`] (the bit-parity oracle the GPU kernel
5559    /// mirrors) must equal the dense reduced Schur `S = (P_ββ + ρ_β I) −
5560    /// Σ_i H_βt^(i)(H_tt^(i)+ρ_t I)⁻¹ H_tβ^(i)` formed by the canonical dense
5561    /// reference, on a small framed system with mixed per-atom ranks
5562    /// (`r_k < p` framed + `r_k = p` un-framed). Size-independent gate.
5563    #[test]
5564    fn framed_sae_schur_matvec_matches_dense_reference() {
5565        use crate::arrow_schur::{
5566            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5567            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5568        };
5569
5570        let p = 4usize;
5571        // Three atoms: ranks 2 (framed), 4 (un-framed), 3 (framed).
5572        let ranks = vec![2usize, 4usize, 3usize];
5573        let basis_sizes = vec![2usize, 1usize, 2usize];
5574        let n_atoms = ranks.len();
5575        let mut border_offsets = Vec::with_capacity(n_atoms);
5576        let mut acc = 0usize;
5577        for k in 0..n_atoms {
5578            border_offsets.push(acc);
5579            acc += basis_sizes[k] * ranks[k];
5580        }
5581        let border_dim = acc; // 2*2 + 1*4 + 2*3 = 14
5582
5583        let mut state = 0x1234_5678_9abc_def0u64;
5584        let mut sample = || -> f64 {
5585            state = state
5586                .wrapping_mul(6364136223846793005)
5587                .wrapping_add(1442695040888963407);
5588            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5589        };
5590
5591        // Per-atom orthonormal-ish frames U_k (p × r_k) for the W = U_iᵀU_j
5592        // factors; un-framed atom (r=p) uses U = I_p.
5593        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5594        for k in 0..n_atoms {
5595            let r = ranks[k];
5596            let mut u = Array2::<f64>::zeros((p, r));
5597            for i in 0..p {
5598                for j in 0..r {
5599                    u[[i, j]] = if r == p && i == j {
5600                        1.0
5601                    } else if r == p {
5602                        0.0
5603                    } else {
5604                        sample()
5605                    };
5606                }
5607            }
5608            frames.push(u);
5609        }
5610        let w_of = |i: usize, j: usize| -> Array2<f64> {
5611            let (ui, uj) = (&frames[i], &frames[j]);
5612            let (ri, rj) = (ranks[i], ranks[j]);
5613            let mut w = Array2::<f64>::zeros((ri, rj));
5614            for a in 0..ri {
5615                for b in 0..rj {
5616                    let mut s = 0.0;
5617                    for c in 0..p {
5618                        s += ui[[c, a]] * uj[[c, b]];
5619                    }
5620                    w[[a, b]] = s;
5621                }
5622            }
5623            w
5624        };
5625
5626        // Co-occurring data-fit blocks: all diagonal pairs + one cross (0,2).
5627        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5628        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5629        pairs.sort();
5630        for &(i, j) in &pairs {
5631            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5632            let mut g = Array2::<f64>::zeros((mi, mj));
5633            for r in 0..mi {
5634                for c in 0..mj {
5635                    g[[r, c]] = 0.3 * sample();
5636                }
5637            }
5638            // Make diagonal blocks SPD-leaning so S stays PD.
5639            if i == j {
5640                for r in 0..mi.min(mj) {
5641                    g[[r, r]] += mi as f64 + 2.0;
5642                }
5643            }
5644            frame_blocks.push(FactoredFrameGBlock {
5645                atom_i: i,
5646                atom_j: j,
5647                g,
5648                w: w_of(i, j),
5649            });
5650        }
5651
5652        // Smooth blocks λ S_k (M_k × M_k), SPD.
5653        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5654        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5655        for k in 0..n_atoms {
5656            let m = basis_sizes[k];
5657            let mut a = Array2::<f64>::zeros((m, m));
5658            for r in 0..m {
5659                for c in 0..m {
5660                    a[[r, c]] = 0.2 * sample();
5661                }
5662            }
5663            let mut s = a.t().dot(&a);
5664            for r in 0..m {
5665                s[[r, r]] += 1.0;
5666            }
5667            smooth_blocks.push(DeviceSaeSmoothBlock {
5668                global_offset: border_offsets[k],
5669                factor_a: s,
5670            });
5671            smooth_ranks.push(ranks[k]);
5672        }
5673
5674        // Build the system: n rows, dense htbeta slabs (q_i × border_dim).
5675        let n = 6usize;
5676        let q = 3usize;
5677        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5678        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5679        for i in 0..n {
5680            // SPD htt.
5681            let mut a = Array2::<f64>::zeros((q, q));
5682            for r in 0..q {
5683                for c in 0..q {
5684                    a[[r, c]] = sample();
5685                }
5686            }
5687            let mut htt = a.t().dot(&a);
5688            for r in 0..q {
5689                htt[[r, r]] += q as f64 + 1.0;
5690            }
5691            sys.rows[i].htt = htt;
5692            let mut slab = vec![0.0_f64; q * border_dim];
5693            for c in 0..q {
5694                for col in 0..border_dim {
5695                    let v = 0.15 * sample();
5696                    slab[c * border_dim + col] = v;
5697                    sys.rows[i].htbeta[[c, col]] = v;
5698                }
5699            }
5700            row_htbeta.push(slab);
5701        }
5702
5703        // Dense H_ββ from the SAME penalty ops (so the dense reference's S
5704        // matches the device penalty side exactly).
5705        let data_op =
5706            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5707                .expect("frame op");
5708        let mut hbb = data_op.to_dense();
5709        for k in 0..n_atoms {
5710            let op = IdentityRightKroneckerPenaltyOp {
5711                factor_a: smooth_blocks[k].factor_a.clone(),
5712                p: ranks[k],
5713                global_offset: border_offsets[k],
5714                k: border_dim,
5715            };
5716            let d = op.to_dense();
5717            for r in 0..border_dim {
5718                for c in 0..border_dim {
5719                    hbb[[r, c]] += d[[r, c]];
5720                }
5721            }
5722        }
5723        sys.hbb = hbb;
5724
5725        let data = DeviceSaePcgData {
5726            p,
5727            beta_dim: border_dim,
5728            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5729            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5730            smooth_blocks,
5731            sparse_g_blocks: Vec::new(),
5732            frame: Some(DeviceSaeFrameData {
5733                ranks: ranks.clone(),
5734                basis_sizes: basis_sizes.clone(),
5735                border_offsets: border_offsets.clone(),
5736                frame_blocks,
5737                smooth_ranks,
5738                row_htbeta,
5739            }),
5740        };
5741
5742        let ridge_t = 1e-7;
5743        let ridge_beta = 1e-6;
5744
5745        // Dense reference reduced Schur S (border_dim × border_dim), formed
5746        // exactly as solve_arrow_newton_step_dense_reference assembles the
5747        // bordered Hessian and eliminates the t-block.
5748        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5749        for r in 0..border_dim {
5750            for c in 0..border_dim {
5751                s_dense[[r, c]] = sys.hbb[[r, c]];
5752            }
5753            s_dense[[r, r]] += ridge_beta;
5754        }
5755        for row in &sys.rows {
5756            let mut htt = row.htt.clone();
5757            for d in 0..q {
5758                htt[[d, d]] += ridge_t;
5759            }
5760            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5761                .expect("htt PD");
5762            // Y = (htt)⁻¹ htbeta  (q × border_dim); S -= htbetaᵀ Y.
5763            let mut y = Array2::<f64>::zeros((q, border_dim));
5764            for col in 0..border_dim {
5765                let mut e = Array1::<f64>::zeros(q);
5766                for r in 0..q {
5767                    e[r] = row.htbeta[[r, col]];
5768                }
5769                let solved = cholesky_solve_vector(factor.view(), e.view());
5770                for r in 0..q {
5771                    y[[r, col]] = solved[r];
5772                }
5773            }
5774            for r in 0..border_dim {
5775                for c in 0..border_dim {
5776                    let mut acc = 0.0;
5777                    for d in 0..q {
5778                        acc += row.htbeta[[d, r]] * y[[d, c]];
5779                    }
5780                    s_dense[[r, c]] -= acc;
5781                }
5782            }
5783        }
5784
5785        // Probe vectors: compare S·x from the device-data CPU oracle vs dense S·x.
5786        let mut max_rel = 0.0_f64;
5787        for trial in 0..4 {
5788            let x: Vec<f64> = (0..border_dim)
5789                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5790                .collect();
5791            let mut got = vec![0.0_f64; border_dim];
5792            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5793                .expect("framed matvec");
5794            let mut want = vec![0.0_f64; border_dim];
5795            for r in 0..border_dim {
5796                let mut acc = 0.0;
5797                for c in 0..border_dim {
5798                    acc += s_dense[[r, c]] * x[c];
5799                }
5800                want[r] = acc;
5801            }
5802            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5803            for a in 0..border_dim {
5804                let rel = (got[a] - want[a]).abs() / scale;
5805                max_rel = max_rel.max(rel);
5806            }
5807        }
5808        assert!(
5809            max_rel <= 1e-10,
5810            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5811        );
5812    }
5813
5814    /// #1017/#1026 — large-K, many-atom, dense-cross-pair parity for the framed
5815    /// SAE reduced-Schur CPU oracle. The small `framed_sae_schur_matvec_matches_dense_reference`
5816    /// pins `border_dim=14`/3 atoms/1 cross pair; the interactions that only
5817    /// appear at scale — variable per-atom `r_k` (mixed framed `r_k<p` and
5818    /// un-framed `r_k=p`), the prefix-sum `border_offsets`, and dense cross-atom
5819    /// `W_ij` coupling across MANY co-occurring `frame_blocks` and many `row_htbeta`
5820    /// slabs — were validated only on the A100. This pins the CPU oracle (and
5821    /// therefore the device kernel it mirrors) against the dense reduced Schur at
5822    /// 40 atoms / `border_dim≈240` / a neighbour-coupled cross-pair set, on CPU.
5823    #[test]
5824    fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
5825        use crate::arrow_schur::{
5826            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5827            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5828        };
5829
5830        let p = 12usize;
5831        let n_atoms = 40usize;
5832        // Variable per-atom rank: most atoms are genuinely framed (r_k<p), every
5833        // fifth atom is un-framed (r_k=p, U=I_p — the within-atom G⊗I_r collapse).
5834        let ranks: Vec<usize> = (0..n_atoms)
5835            .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
5836            .collect();
5837        let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();
5838        let mut border_offsets = Vec::with_capacity(n_atoms);
5839        let mut acc = 0usize;
5840        for k in 0..n_atoms {
5841            border_offsets.push(acc);
5842            acc += basis_sizes[k] * ranks[k];
5843        }
5844        let border_dim = acc;
5845
5846        let mut state = 0x0bad_c0de_dead_beefu64;
5847        let mut sample = || -> f64 {
5848            state = state
5849                .wrapping_mul(6364136223846793005)
5850                .wrapping_add(1442695040888963407);
5851            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5852        };
5853
5854        // Per-atom frames U_k (p × r_k); un-framed atom (r=p) uses U = I_p.
5855        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5856        for k in 0..n_atoms {
5857            let r = ranks[k];
5858            let mut u = Array2::<f64>::zeros((p, r));
5859            for i in 0..p {
5860                for j in 0..r {
5861                    u[[i, j]] = if r == p {
5862                        if i == j { 1.0 } else { 0.0 }
5863                    } else {
5864                        sample()
5865                    };
5866                }
5867            }
5868            frames.push(u);
5869        }
5870        let w_of = |i: usize, j: usize| -> Array2<f64> {
5871            let (ui, uj) = (&frames[i], &frames[j]);
5872            let (ri, rj) = (ranks[i], ranks[j]);
5873            let mut w = Array2::<f64>::zeros((ri, rj));
5874            for a in 0..ri {
5875                for b in 0..rj {
5876                    let mut s = 0.0;
5877                    for c in 0..p {
5878                        s += ui[[c, a]] * uj[[c, b]];
5879                    }
5880                    w[[a, b]] = s;
5881                }
5882            }
5883            w
5884        };
5885
5886        // Co-occurring data-fit blocks: every diagonal pair + each neighbour pair
5887        // (k,k+1) and its transpose — dense cross coupling across the whole border.
5888        let mut pairs: Vec<(usize, usize)> = Vec::new();
5889        for k in 0..n_atoms {
5890            pairs.push((k, k));
5891        }
5892        for k in 0..n_atoms - 1 {
5893            pairs.push((k, k + 1));
5894            pairs.push((k + 1, k));
5895        }
5896        pairs.sort_unstable();
5897        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5898        for &(i, j) in &pairs {
5899            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5900            let mut g = Array2::<f64>::zeros((mi, mj));
5901            for r in 0..mi {
5902                for c in 0..mj {
5903                    g[[r, c]] = 0.3 * sample();
5904                }
5905            }
5906            if i == j {
5907                for r in 0..mi.min(mj) {
5908                    g[[r, r]] += mi as f64 + 2.0;
5909                }
5910            }
5911            frame_blocks.push(FactoredFrameGBlock {
5912                atom_i: i,
5913                atom_j: j,
5914                g,
5915                w: w_of(i, j),
5916            });
5917        }
5918
5919        // Smooth blocks λ S_k (M_k × M_k), SPD.
5920        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5921        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5922        for k in 0..n_atoms {
5923            let m = basis_sizes[k];
5924            let mut a = Array2::<f64>::zeros((m, m));
5925            for r in 0..m {
5926                for c in 0..m {
5927                    a[[r, c]] = 0.2 * sample();
5928                }
5929            }
5930            let mut s = a.t().dot(&a);
5931            for r in 0..m {
5932                s[[r, r]] += 1.0;
5933            }
5934            smooth_blocks.push(DeviceSaeSmoothBlock {
5935                global_offset: border_offsets[k],
5936                factor_a: s,
5937            });
5938            smooth_ranks.push(ranks[k]);
5939        }
5940
5941        // n rows with SPD htt and full-width (q × border_dim) htbeta slabs.
5942        let n = 8usize;
5943        let q = 3usize;
5944        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5945        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5946        for i in 0..n {
5947            let mut a = Array2::<f64>::zeros((q, q));
5948            for r in 0..q {
5949                for c in 0..q {
5950                    a[[r, c]] = sample();
5951                }
5952            }
5953            let mut htt = a.t().dot(&a);
5954            for r in 0..q {
5955                htt[[r, r]] += q as f64 + 1.0;
5956            }
5957            sys.rows[i].htt = htt;
5958            let mut slab = vec![0.0_f64; q * border_dim];
5959            for c in 0..q {
5960                for col in 0..border_dim {
5961                    let v = 0.15 * sample();
5962                    slab[c * border_dim + col] = v;
5963                    sys.rows[i].htbeta[[c, col]] = v;
5964                }
5965            }
5966            row_htbeta.push(slab);
5967        }
5968
5969        // Dense H_ββ from the SAME penalty ops (data-fit + smooth), so the dense
5970        // reference S matches the device penalty side exactly.
5971        let data_op =
5972            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5973                .expect("frame op");
5974        let mut hbb = data_op.to_dense();
5975        for k in 0..n_atoms {
5976            let op = IdentityRightKroneckerPenaltyOp {
5977                factor_a: smooth_blocks[k].factor_a.clone(),
5978                p: ranks[k],
5979                global_offset: border_offsets[k],
5980                k: border_dim,
5981            };
5982            let d = op.to_dense();
5983            for r in 0..border_dim {
5984                for c in 0..border_dim {
5985                    hbb[[r, c]] += d[[r, c]];
5986                }
5987            }
5988        }
5989        sys.hbb = hbb;
5990
5991        let data = DeviceSaePcgData {
5992            p,
5993            beta_dim: border_dim,
5994            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5995            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5996            smooth_blocks,
5997            sparse_g_blocks: Vec::new(),
5998            frame: Some(DeviceSaeFrameData {
5999                ranks: ranks.clone(),
6000                basis_sizes: basis_sizes.clone(),
6001                border_offsets: border_offsets.clone(),
6002                frame_blocks,
6003                smooth_ranks,
6004                row_htbeta,
6005            }),
6006        };
6007
6008        let ridge_t = 1e-7;
6009        let ridge_beta = 1e-6;
6010
6011        // Dense reference reduced Schur S = (hbb + ridge_beta I) - Σ_i htbetaᵀ (htt+ridge_t I)⁻¹ htbeta.
6012        let mut s_dense = sys.hbb.clone();
6013        for r in 0..border_dim {
6014            s_dense[[r, r]] += ridge_beta;
6015        }
6016        for row in &sys.rows {
6017            let mut htt = row.htt.clone();
6018            for d in 0..q {
6019                htt[[d, d]] += ridge_t;
6020            }
6021            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
6022                .expect("htt PD");
6023            let mut y = Array2::<f64>::zeros((q, border_dim));
6024            for col in 0..border_dim {
6025                let mut e = Array1::<f64>::zeros(q);
6026                for r in 0..q {
6027                    e[r] = row.htbeta[[r, col]];
6028                }
6029                let solved = cholesky_solve_vector(factor.view(), e.view());
6030                for r in 0..q {
6031                    y[[r, col]] = solved[r];
6032                }
6033            }
6034            for r in 0..border_dim {
6035                for c in 0..border_dim {
6036                    let mut acc = 0.0;
6037                    for d in 0..q {
6038                        acc += row.htbeta[[d, r]] * y[[d, c]];
6039                    }
6040                    s_dense[[r, c]] -= acc;
6041                }
6042            }
6043        }
6044
6045        let mut max_rel = 0.0_f64;
6046        for trial in 0..4 {
6047            let x: Vec<f64> = (0..border_dim)
6048                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
6049                .collect();
6050            let mut got = vec![0.0_f64; border_dim];
6051            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
6052                .expect("framed matvec");
6053            let mut want = vec![0.0_f64; border_dim];
6054            for r in 0..border_dim {
6055                let mut acc = 0.0;
6056                for c in 0..border_dim {
6057                    acc += s_dense[[r, c]] * x[c];
6058                }
6059                want[r] = acc;
6060            }
6061            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6062            for a in 0..border_dim {
6063                let rel = (got[a] - want[a]).abs() / scale;
6064                max_rel = max_rel.max(rel);
6065            }
6066        }
6067        assert!(
6068            max_rel <= 1e-10,
6069            "large-K framed SAE Schur matvec vs dense reference diverged: \
6070             max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
6071        );
6072    }
6073
6074    /// #1017/#1026 GPU arm: when a CUDA device admits the framed SAE PCG, its
6075    /// solved `δβ` must match the CPU dense reduced-system solve of the SAME
6076    /// framed system (size-independent — a small device validates the kernel).
6077    /// Skips cleanly (returns) when no device is available or the policy
6078    /// declines (`solve_sae_matrix_free_pcg` → `Unavailable`).
6079    #[test]
6080    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
6081        use crate::arrow_schur::{
6082            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
6083            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
6084        };
6085
6086        // Large enough to clear the device-offload policy floor (k ≥ 32 and
6087        // n·k·d·iters ≥ MATVEC_OFFLOAD_FLOPS_MIN) so the GPU kernel actually
6088        // runs on a device rather than the policy declining.
6089        let p = 6usize;
6090        let n_atoms = 8usize;
6091        let ranks: Vec<usize> = (0..n_atoms)
6092            .map(|k| if k % 2 == 0 { 3usize } else { p })
6093            .collect();
6094        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6095        let mut border_offsets = Vec::with_capacity(n_atoms);
6096        let mut acc = 0usize;
6097        for k in 0..n_atoms {
6098            border_offsets.push(acc);
6099            acc += basis_sizes[k] * ranks[k];
6100        }
6101        let border_dim = acc; // Σ M_k·r_k = 4·(3·3) + 4·(3·6) = 36 + 72 = 108
6102
6103        let mut state = 0xfeed_face_dead_beefu64;
6104        let mut sample = || -> f64 {
6105            state = state
6106                .wrapping_mul(6364136223846793005)
6107                .wrapping_add(1442695040888963407);
6108            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6109        };
6110        let mut frames: Vec<Array2<f64>> = Vec::new();
6111        for k in 0..n_atoms {
6112            let r = ranks[k];
6113            let mut u = Array2::<f64>::zeros((p, r));
6114            for i in 0..p {
6115                for j in 0..r {
6116                    u[[i, j]] = if r == p && i == j {
6117                        1.0
6118                    } else if r == p {
6119                        0.0
6120                    } else {
6121                        sample()
6122                    };
6123                }
6124            }
6125            frames.push(u);
6126        }
6127        let w_of = |i: usize, j: usize| {
6128            let (ui, uj) = (&frames[i], &frames[j]);
6129            let (ri, rj) = (ranks[i], ranks[j]);
6130            let mut w = Array2::<f64>::zeros((ri, rj));
6131            for a in 0..ri {
6132                for b in 0..rj {
6133                    let mut s = 0.0;
6134                    for c in 0..p {
6135                        s += ui[[c, a]] * uj[[c, b]];
6136                    }
6137                    w[[a, b]] = s;
6138                }
6139            }
6140            w
6141        };
6142        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6143        // A few off-diagonal cross blocks (symmetric pairs).
6144        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6145            pairs.push((i, j));
6146            pairs.push((j, i));
6147        }
6148        let mut frame_blocks = Vec::new();
6149        for &(i, j) in &pairs {
6150            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6151            let mut g = Array2::<f64>::zeros((mi, mj));
6152            for r in 0..mi {
6153                for c in 0..mj {
6154                    g[[r, c]] = 0.25 * sample();
6155                }
6156            }
6157            if i == j {
6158                for r in 0..mi.min(mj) {
6159                    g[[r, r]] += mi as f64 + 2.0;
6160                }
6161            }
6162            frame_blocks.push(FactoredFrameGBlock {
6163                atom_i: i,
6164                atom_j: j,
6165                g,
6166                w: w_of(i, j),
6167            });
6168        }
6169        let mut smooth_blocks = Vec::new();
6170        let mut smooth_ranks = Vec::new();
6171        for k in 0..n_atoms {
6172            let m = basis_sizes[k];
6173            let mut a = Array2::<f64>::zeros((m, m));
6174            for r in 0..m {
6175                for c in 0..m {
6176                    a[[r, c]] = 0.2 * sample();
6177                }
6178            }
6179            let mut s = a.t().dot(&a);
6180            for r in 0..m {
6181                s[[r, r]] += 1.0;
6182            }
6183            smooth_blocks.push(DeviceSaeSmoothBlock {
6184                global_offset: border_offsets[k],
6185                factor_a: s,
6186            });
6187            smooth_ranks.push(ranks[k]);
6188        }
6189        let n = 400usize;
6190        let q = 4usize;
6191        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6192        let mut row_htbeta = Vec::new();
6193        for i in 0..n {
6194            let mut a = Array2::<f64>::zeros((q, q));
6195            for r in 0..q {
6196                for c in 0..q {
6197                    a[[r, c]] = sample();
6198                }
6199            }
6200            let mut htt = a.t().dot(&a);
6201            for r in 0..q {
6202                htt[[r, r]] += q as f64 + 1.0;
6203            }
6204            sys.rows[i].htt = htt;
6205            let mut slab = vec![0.0_f64; q * border_dim];
6206            for c in 0..q {
6207                for col in 0..border_dim {
6208                    // Small entries: with 400 rows the reduced-Schur subtraction
6209                    // Σ_i H_βtᵀ H_tt⁻¹ H_tβ must not overwhelm the PD penalty.
6210                    let v = 0.02 * sample();
6211                    slab[c * border_dim + col] = v;
6212                    sys.rows[i].htbeta[[c, col]] = v;
6213                }
6214            }
6215            row_htbeta.push(slab);
6216        }
6217        let data_op =
6218            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
6219                .expect("frame op");
6220        let mut hbb = data_op.to_dense();
6221        for k in 0..n_atoms {
6222            let op = IdentityRightKroneckerPenaltyOp {
6223                factor_a: smooth_blocks[k].factor_a.clone(),
6224                p: ranks[k],
6225                global_offset: border_offsets[k],
6226                k: border_dim,
6227            };
6228            let d = op.to_dense();
6229            for r in 0..border_dim {
6230                for c in 0..border_dim {
6231                    hbb[[r, c]] += d[[r, c]];
6232                }
6233            }
6234        }
6235        sys.hbb = hbb;
6236        let data = DeviceSaePcgData {
6237            p,
6238            beta_dim: border_dim,
6239            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6240            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6241            smooth_blocks,
6242            sparse_g_blocks: Vec::new(),
6243            frame: Some(DeviceSaeFrameData {
6244                ranks: ranks.clone(),
6245                basis_sizes: basis_sizes.clone(),
6246                border_offsets: border_offsets.clone(),
6247                frame_blocks,
6248                smooth_ranks,
6249                row_htbeta,
6250            }),
6251        };
6252        let ridge_t = 1e-7;
6253        let ridge_beta = 1e-6;
6254        let rhs: Array1<f64> =
6255            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
6256
6257        let (device, diag) =
6258            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
6259                Ok(result) => result,
6260                // #1017 — fail loud, never skip-pass: this fixture clears the device
6261                // offload floor, so a CUDA device that is PRESENT yet declines means the
6262                // framed device PCG kernel does not run on GPU (the fault must not pass
6263                // silently). Legit skip ONLY when no usable CUDA device exists (CPU CI).
6264                // The exact `ArrowSchurGpuFailure` variant is folded into the assert.
6265                Err(failure) => {
6266                    assert!(
6267                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6268                        "#1017: CUDA device present but the framed device SAE PCG \
6269                     declined/faulted instead of returning a result (tag: {failure:?}) — \
6270                     the kernel does not run correctly on GPU"
6271                    );
6272                    return;
6273                }
6274            };
6275
6276        // #1551 PARITY GATE — operator-residual, NOT solution-vector equality.
6277        //
6278        // The honest GPU↔CPU contract for an iterative solve of `S·δβ = rhs` is
6279        // that the device solution SOLVES the system DEFINED BY THE CPU ORACLE to
6280        // PCG tolerance — i.e. `‖S_cpu·δβ_device − rhs‖ / ‖rhs‖ ≤ tol`, where
6281        // `S_cpu` is applied with the bit-for-bit CPU oracle matvec the device
6282        // kernel mirrors (`sae_framed_schur_matvec_cpu`). This is the correct,
6283        // conditioning-robust gate: a near-singular assembled `S` has a large
6284        // condition number `κ(S)`, which amplifies an O(ε) residual difference
6285        // into an O(κ·ε) *solution-vector* difference, so comparing δβ vectors
6286        // would spuriously fail even when both operators are bit-identical and
6287        // both solves converged. (Historically this test compared δβ against a
6288        // dense-Cholesky reference and "failed" with max_rel≈0.9 because the
6289        // dense solve itself only reached ‖S·x−rhs‖≈0.1 on this fixture's
6290        // ill-conditioned S while the device PCG reached ~1e-12 — the device was
6291        // MORE accurate than the reference, not wrong. The kernel correctness is
6292        // pinned conditioning-free by `framed_sae_device_matvec_matches_cpu_oracle_*`.)
6293        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
6294        let oracle_resid = |x: &Array1<f64>| -> f64 {
6295            let mut sx = vec![0.0_f64; border_dim];
6296            sae_framed_schur_matvec_cpu(
6297                &sys,
6298                &data,
6299                ridge_t,
6300                ridge_beta,
6301                x.as_slice().unwrap(),
6302                &mut sx,
6303            )
6304            .expect("cpu oracle matvec");
6305            let mut acc = 0.0_f64;
6306            for a in 0..border_dim {
6307                let e = sx[a] - rhs[a];
6308                acc += e * e;
6309            }
6310            acc.sqrt()
6311        };
6312        let s_dev_resid = oracle_resid(&device);
6313        let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
6314
6315        // Independent CPU iterative solve of the SAME operator with the SAME
6316        // Jacobi preconditioner the device builds, via the shared `pcg_core`. If
6317        // the device kernel computed a different operator, the two converged
6318        // residuals could not BOTH be tiny.
6319        let precond = {
6320            let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6321            // The reduced-Schur diagonal subtraction (device `arrow_sae_frame_diag_sub`)
6322            // mirrored on the host for the Jacobi preconditioner.
6323            let mut diag = d;
6324            for (i, row) in sys.rows.iter().enumerate() {
6325                let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6326                let qi = sys.row_dims[i];
6327                if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6328                    continue;
6329                }
6330                let mut block = row.htt.clone();
6331                for dd in 0..qi {
6332                    block[[dd, dd]] += ridge_t;
6333                }
6334                let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6335                    .expect("row htt PD");
6336                // ainv = (H_tt+ρI)⁻¹ column by column.
6337                let mut ainv = Array2::<f64>::zeros((qi, qi));
6338                for col in 0..qi {
6339                    let mut e = Array1::<f64>::zeros(qi);
6340                    e[col] = 1.0;
6341                    let s = cholesky_solve_vector(factor.view(), e.view());
6342                    for r in 0..qi {
6343                        ainv[[r, col]] = s[r];
6344                    }
6345                }
6346                for a in 0..border_dim {
6347                    let mut quad = 0.0_f64;
6348                    for c in 0..qi {
6349                        let hc = slab[c * border_dim + a];
6350                        for dd in 0..qi {
6351                            quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6352                        }
6353                    }
6354                    diag[a] -= quad;
6355                }
6356            }
6357            Array1::from_vec(diag)
6358        };
6359        let mut cpu = Array1::<f64>::zeros(border_dim);
6360        let cpu_result = {
6361            let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6362                let mut tmp = vec![0.0_f64; border_dim];
6363                sae_framed_schur_matvec_cpu(
6364                    &sys,
6365                    &data,
6366                    ridge_t,
6367                    ridge_beta,
6368                    v.as_slice().unwrap(),
6369                    &mut tmp,
6370                )
6371                .expect("cpu oracle matvec");
6372                out.assign(&Array1::from_vec(tmp));
6373            };
6374            gam_linalg::pcg::pcg_core(
6375                &mut apply,
6376                &rhs.view(),
6377                &precond.view(),
6378                1e-12,
6379                800,
6380                32,
6381                false,
6382                gam_linalg::pcg::DotReduction::Serial,
6383                &mut cpu.view_mut(),
6384            )
6385        };
6386        let s_cpu_resid = oracle_resid(&cpu);
6387        let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6388
6389        // GATE 1: the device solution solves the CPU-oracle system to PCG-grade
6390        // accuracy (proves device kernel == CPU operator AND device PCG converged).
6391        assert!(
6392            dev_rel_resid <= 1e-7,
6393            "[#1551] device δβ does not solve the CPU-oracle system: \
6394             ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6395             device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6396             means the device matvec is a DIFFERENT operator (kernel bug)",
6397            diag.stopping_reason,
6398            diag.iterations,
6399            diag.final_relative_residual,
6400        );
6401        // GATE 2: the independent CPU iterative solve of the same operator with the
6402        // same preconditioner also converges — both paths agree on the operator.
6403        assert!(
6404            cpu_rel_resid <= 1e-6,
6405            "[#1551] CPU pcg_core failed to solve the oracle system: \
6406             ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6407            cpu_result.stop,
6408            cpu_result.iterations,
6409        );
6410    }
6411
6412    /// Host mirror of the device `sae_frame_penalty_diag_host` for the framed
6413    /// Jacobi preconditioner (penalty diagonal only; the reduced-Schur diagonal
6414    /// subtraction is applied by the caller). Test-only.
6415    fn sae_frame_penalty_diag_host_for_test(
6416        data: &DeviceSaePcgData,
6417        ridge_beta: f64,
6418    ) -> Vec<f64> {
6419        let frame = data.frame.as_ref().expect("frame");
6420        let mut diag = vec![ridge_beta; data.beta_dim];
6421        for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6422            let m = blk.factor_a.nrows();
6423            for ia in 0..m {
6424                let coeff = blk.factor_a[[ia, ia]];
6425                let base = blk.global_offset + ia * r;
6426                for ib in 0..r {
6427                    diag[base + ib] += coeff;
6428                }
6429            }
6430        }
6431        for blk in &frame.frame_blocks {
6432            if blk.atom_i != blk.atom_j {
6433                continue;
6434            }
6435            let r = frame.ranks[blk.atom_i];
6436            let off = frame.border_offsets[blk.atom_i];
6437            let (mi, mj) = blk.g.dim();
6438            for li in 0..mi.min(mj) {
6439                let gii = blk.g[[li, li]];
6440                let base = off + li * r;
6441                for a in 0..r {
6442                    diag[base + a] += gii * blk.w[[a, a]];
6443                }
6444            }
6445        }
6446        diag
6447    }
6448
6449    /// #1551 DEFINITIVE kernel-correctness proof: the framed reduced-Schur matvec
6450    /// `out = S·x` must agree with the CPU oracle [`sae_framed_schur_matvec_cpu`]
6451    /// element-wise, for several independent `x`, to ≤ 1e-9. This is the parity
6452    /// gate that actually localizes a kernel/marshalling defect — unlike a
6453    /// solved-`δβ` comparison, it does NOT route through a linear solve, so it is
6454    /// independent of the conditioning of the assembled `S` (a near-singular `S`
6455    /// can make a dense-Cholesky vector and an iterative-PCG vector disagree at
6456    /// the *solution* level even when both operators are bit-correct; the
6457    /// operator itself must still match here). Fails loud if CUDA is present but
6458    /// the device matvec declines; skips cleanly only when no device exists.
6459    #[test]
6460    fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
6461        use crate::arrow_schur::{
6462            DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
6463        };
6464
6465        // Hand-checkable frame fixture: a mix of framed (r<p) and identity-ride
6466        // (r==p) atoms, a few off-diagonal cross blocks, dense per-row H_tβ.
6467        let p = 6usize;
6468        let n_atoms = 8usize;
6469        let ranks: Vec<usize> = (0..n_atoms)
6470            .map(|k| if k % 2 == 0 { 3usize } else { p })
6471            .collect();
6472        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6473        let mut border_offsets = Vec::with_capacity(n_atoms);
6474        let mut acc = 0usize;
6475        for k in 0..n_atoms {
6476            border_offsets.push(acc);
6477            acc += basis_sizes[k] * ranks[k];
6478        }
6479        let border_dim = acc;
6480
6481        let mut state = 0x1551_0017_1026_0922u64;
6482        let mut sample = || -> f64 {
6483            state = state
6484                .wrapping_mul(6364136223846793005)
6485                .wrapping_add(1442695040888963407);
6486            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6487        };
6488        let mut frames: Vec<Array2<f64>> = Vec::new();
6489        for k in 0..n_atoms {
6490            let r = ranks[k];
6491            let mut u = Array2::<f64>::zeros((p, r));
6492            for i in 0..p {
6493                for j in 0..r {
6494                    u[[i, j]] = if r == p && i == j {
6495                        1.0
6496                    } else if r == p {
6497                        0.0
6498                    } else {
6499                        sample()
6500                    };
6501                }
6502            }
6503            frames.push(u);
6504        }
6505        let w_of = |i: usize, j: usize| {
6506            let (ui, uj) = (&frames[i], &frames[j]);
6507            let (ri, rj) = (ranks[i], ranks[j]);
6508            let mut w = Array2::<f64>::zeros((ri, rj));
6509            for a in 0..ri {
6510                for b in 0..rj {
6511                    let mut s = 0.0;
6512                    for c in 0..p {
6513                        s += ui[[c, a]] * uj[[c, b]];
6514                    }
6515                    w[[a, b]] = s;
6516                }
6517            }
6518            w
6519        };
6520        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6521        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6522            pairs.push((i, j));
6523            pairs.push((j, i));
6524        }
6525        let mut frame_blocks = Vec::new();
6526        for &(i, j) in &pairs {
6527            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6528            let mut g = Array2::<f64>::zeros((mi, mj));
6529            for r in 0..mi {
6530                for c in 0..mj {
6531                    g[[r, c]] = 0.25 * sample();
6532                }
6533            }
6534            if i == j {
6535                for r in 0..mi.min(mj) {
6536                    g[[r, r]] += mi as f64 + 2.0;
6537                }
6538            }
6539            frame_blocks.push(FactoredFrameGBlock {
6540                atom_i: i,
6541                atom_j: j,
6542                g,
6543                w: w_of(i, j),
6544            });
6545        }
6546        let mut smooth_blocks = Vec::new();
6547        let mut smooth_ranks = Vec::new();
6548        for k in 0..n_atoms {
6549            let m = basis_sizes[k];
6550            let mut a = Array2::<f64>::zeros((m, m));
6551            for r in 0..m {
6552                for c in 0..m {
6553                    a[[r, c]] = 0.2 * sample();
6554                }
6555            }
6556            let mut s = a.t().dot(&a);
6557            for r in 0..m {
6558                s[[r, r]] += 1.0;
6559            }
6560            smooth_blocks.push(DeviceSaeSmoothBlock {
6561                global_offset: border_offsets[k],
6562                factor_a: s,
6563            });
6564            smooth_ranks.push(ranks[k]);
6565        }
6566        // Modest row count: this seam bypasses the offload floor, so we keep the
6567        // fixture small and the per-row reduced-Schur term well-scaled — the
6568        // matvec parity does not depend on the assembled-S conditioning at all.
6569        let n = 32usize;
6570        let q = 4usize;
6571        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6572        let mut row_htbeta = Vec::new();
6573        for i in 0..n {
6574            let mut a = Array2::<f64>::zeros((q, q));
6575            for r in 0..q {
6576                for c in 0..q {
6577                    a[[r, c]] = sample();
6578                }
6579            }
6580            let mut htt = a.t().dot(&a);
6581            for r in 0..q {
6582                htt[[r, r]] += q as f64 + 1.0;
6583            }
6584            sys.rows[i].htt = htt;
6585            let mut slab = vec![0.0_f64; q * border_dim];
6586            for c in 0..q {
6587                for col in 0..border_dim {
6588                    let v = 0.3 * sample();
6589                    slab[c * border_dim + col] = v;
6590                    sys.rows[i].htbeta[[c, col]] = v;
6591                }
6592            }
6593            row_htbeta.push(slab);
6594        }
6595        let ridge_t = 1e-7;
6596        let ridge_beta = 1e-6;
6597        let data = DeviceSaePcgData {
6598            p,
6599            beta_dim: border_dim,
6600            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6601            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6602            smooth_blocks,
6603            sparse_g_blocks: Vec::new(),
6604            frame: Some(DeviceSaeFrameData {
6605                ranks: ranks.clone(),
6606                basis_sizes: basis_sizes.clone(),
6607                border_offsets: border_offsets.clone(),
6608                frame_blocks,
6609                smooth_ranks,
6610                row_htbeta,
6611            }),
6612        };
6613
6614        // Several independent probe vectors x, including unit axes and dense
6615        // random — a marshalling stride/offset bug shows up as a per-component
6616        // mismatch on at least one.
6617        let mut probes: Vec<Array1<f64>> = Vec::new();
6618        probes.push(Array1::from_shape_fn(border_dim, |a| {
6619            ((a as f64 + 1.0) * 0.37).sin()
6620        }));
6621        probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
6622        for axis in [0usize, border_dim / 3, border_dim - 1] {
6623            let mut e = Array1::<f64>::zeros(border_dim);
6624            e[axis] = 1.0;
6625            probes.push(e);
6626        }
6627
6628        let mut any_ran = false;
6629        let mut worst = 0.0_f64;
6630        for (pi, x) in probes.iter().enumerate() {
6631            let device = match super::framed_schur_matvec_once_on_device(
6632                &sys, &data, ridge_t, ridge_beta, x,
6633            ) {
6634                Ok(out) => out,
6635                Err(failure) => {
6636                    // Fail loud: a present CUDA device that declines this seam
6637                    // (which deliberately ignores the offload floor) means the
6638                    // framed matvec kernel does not run on GPU.
6639                    assert!(
6640                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6641                        "#1551: CUDA device present but the framed device matvec \
6642                         declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
6643                         does not run on GPU"
6644                    );
6645                    return;
6646                }
6647            };
6648            any_ran = true;
6649            let mut cpu = vec![0.0_f64; border_dim];
6650            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, x.as_slice().unwrap(), &mut cpu)
6651                .expect("cpu oracle matvec");
6652            let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6653            for a in 0..border_dim {
6654                let rel = (device[a] - cpu[a]).abs() / scale;
6655                worst = worst.max(rel);
6656                assert!(
6657                    rel <= 1e-9,
6658                    "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
6659                     rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
6660                    device[a],
6661                    cpu[a],
6662                );
6663            }
6664        }
6665        if any_ran {
6666            // Positive on-device confirmation: the framed matvec ran on the GPU
6667            // and matched the CPU oracle across every probe. (1e-9 is far above
6668            // the ~1e-13 fp64 GEMV round-off; a structural marshalling bug would
6669            // be O(1).)
6670            assert!(
6671                gam_gpu::device_runtime::GpuRuntime::global().is_some(),
6672                "#1551: matvec ran but no GPU runtime — unexpected"
6673            );
6674            assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
6675        }
6676    }
6677}