Skip to main content

gam_solve/gpu/
pirls_gpu.rs

1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2
/// Host-side inputs for one PIRLS Newton step, including the design matrix.
///
/// Unlike the stream-pool input types, this form carries `x` explicitly
/// rather than relying on a device-resident design. NOTE(review): presumably
/// consumed by the one-shot `solve_step` path — confirm at the call site.
#[derive(Clone, Debug)]
pub struct PirlsGpuInput<'a> {
    /// Design matrix. NOTE(review): assumed `n × p`, with `n == weights.len()`;
    /// the consuming function's validation is not visible here.
    pub x: ArrayView2<'a, f64>,
    /// Per-row weights for this Newton step. NOTE(review): assumed length `n`.
    pub weights: ArrayView1<'a, f64>,
    /// Penalty Hessian `S`. NOTE(review): assumed `p × p`.
    pub penalty_hessian: ArrayView2<'a, f64>,
    /// Full descent-direction RHS: `Xᵀ·score − S·β + linear_shift`. The
    /// returned `PirlsGpuStep::direction = H⁻¹·gradient` (no negation, #257).
    /// Callers must assemble the corrected RHS before passing it here.
    pub gradient: ArrayView1<'a, f64>,
    /// Temporary Levenberg–Marquardt damping; added to H for the solve
    /// only. Never enters the exported `penalized_hessian`, `RidgePassport`,
    /// EDF, REML curvature, or penalty term.
    pub step_lm_lambda: f64,
    /// Real model-objective ridge. Enters the exported `penalized_hessian`,
    /// `RidgePassport`, EDF, REML curvature, and penalty term.
    pub objective_ridge: f64,
}
20
/// Result of one PIRLS Newton step.
#[derive(Clone, Debug)]
pub struct PirlsGpuStep {
    /// Exported penalised Hessian. In the stream path this is
    /// `Qsᵀ·XᵀWX·Qs + S + objective_ridge·I`; the temporary
    /// `step_lm_lambda` damping is deliberately excluded.
    pub penalized_hessian: Array2<f64>,
    /// Newton direction `d = H⁻¹·g` for the caller-assembled RHS `g`, with no
    /// negation applied (#257).
    pub direction: Array1<f64>,
    /// Log-determinant `2·Σ ln L[i,i]` taken from a Cholesky factor.
    /// NOTE(review): in `solve_step_on_stream` the factor is of the damped
    /// solve matrix (which includes `step_lm_lambda·I`), not of
    /// `penalized_hessian`. The two differ whenever
    /// `step_lm_lambda != objective_ridge` — confirm this is intended.
    pub logdet: f64,
}
27
/// Per-step inputs for [`solve_pirls_step_on_stream`].
///
/// Mirrors [`PirlsGpuInput`] but elides the design matrix `x` because that
/// lives device-resident in the shared batch state. Each PIRLS Newton step
/// only changes `weights`, `penalty_hessian` (with the current Sλ sum),
/// `gradient`, and the LM ridge — these are the small per-step uploads the
/// stream-pool path streams to the device.
#[derive(Clone, Debug)]
pub struct PirlsStepStreamInput<'a> {
    /// Working weights for this step. Must have length `n` and be
    /// contiguous; it is uploaded verbatim into the workspace weight buffer.
    pub weights: ArrayView1<'a, f64>,
    /// Penalty Hessian `Sλ`; must be `p × p`.
    pub penalty_hessian: ArrayView2<'a, f64>,
    /// Full Newton RHS. Must have length `p`, be contiguous, and already be
    /// expressed in the transformed (Qs) coordinates. The step solves
    /// `H·d = gradient` with no negation.
    pub gradient: ArrayView1<'a, f64>,
    /// Temporary LM damping for this Newton solve step only. Added to H
    /// before potrf; stripped out of the snapshotted `penalized_hessian`.
    pub step_lm_lambda: f64,
    /// Real model-objective ridge. Appears in the exported
    /// `penalized_hessian` that flows to EDF / REML curvature.
    pub objective_ridge: f64,
}
47
/// Stage 3.2 device-input variant of [`PirlsStepStreamInput`].
///
/// Where the host-input form uploads `weights` + `gradient` per Newton
/// step, this form reads them straight from the
/// [`crate::gpu_kernels::pirls_row::RowOutputDevBuffers`] populated by the
/// device-side row-reweight kernel — no host round-trip for the row
/// state. The penalty matrix still crosses the host boundary because the
/// outer REML loop updates Sλ + LM ridge between PIRLS steps. The small
/// p-length β is also downloaded, and `linear_shift` is supplied from the
/// host, so that the RHS correction can be formed (see those fields).
///
/// Lifetime `'a` borrows only the solver-weight buffer; `'b` covers the
/// remaining device buffers and host views.
#[cfg(target_os = "linux")]
pub struct PirlsStepStreamDeviceInput<'a, 'b> {
    /// Device-resident solver weights `w_solver_i` (length n). Read
    /// in-place by the cublasDdgmm WX assembly.
    pub w_solver_dev: &'a cudarc::driver::CudaSlice<f64>,
    /// Device-resident IRLS gradient `∂ℓ/∂η_i` (length n). Read by the
    /// `Xᵀg` dgemv to form the Newton RHS.
    pub grad_eta_dev: &'b cudarc::driver::CudaSlice<f64>,
    /// Penalty Hessian Sλ in row-major host layout (p × p).
    pub penalty_hessian: ArrayView2<'b, f64>,
    /// Temporary LM damping for this Newton solve step only. Added to H
    /// before potrf; stripped out of the snapshotted `penalized_hessian`.
    pub step_lm_lambda: f64,
    /// Real model-objective ridge. Appears in the exported
    /// `penalized_hessian` that flows to EDF / REML curvature.
    pub objective_ridge: f64,
    /// Current coefficient vector β (length p). Downloaded to the host to
    /// form the Newton RHS correction S·β. Only p f64 values cross the
    /// boundary (β is small), so the round-trip cost is negligible.
    pub beta_dev: &'b cudarc::driver::CudaSlice<f64>,
    /// Linear shift vector (length p) in transformed coordinates, on host.
    /// Added to Newton RHS so the solve targets Xᵀ·score − S·β + linear_shift.
    pub linear_shift: ArrayView1<'b, f64>,
}
81
/// Shared, batch-wide GPU state for stream-pool sigma-cubature PIRLS.
///
/// Construct once per model via [`upload_shared_pirls_gpu`] and hand a
/// shared reference to many [`SigmaPirlsGpuWorkspace`]s. X_original, y,
/// prior_w, and offset are uploaded once and reused across all ρ / σ
/// points. Per ρ / σ point, only the small `Qs` reparam matrix is
/// re-uploaded into the workspace.
#[cfg(target_os = "linux")]
pub struct PirlsGpuSharedData {
    /// CUDA context that owns the uploaded buffers. Workspaces create their
    /// own streams on it, and it is passed to the kernel-launch helpers.
    pub(crate) ctx: std::sync::Arc<cudarc::driver::CudaContext>,
    /// Number of observations (rows of `X_original`).
    pub(crate) n: usize,
    /// Number of coefficients (columns of `X_original`).
    pub(crate) p: usize,
    /// `n*p` f64 column-major **original** design matrix `X_original`,
    /// device-resident. Never the pre-multiplied `X·Qs` form.
    pub(crate) x_original_dev: cudarc::driver::CudaSlice<f64>,
    /// Response vector `y`, length `n`, device-resident.
    pub(crate) y_dev: cudarc::driver::CudaSlice<f64>,
    /// Prior weights, length `n`, device-resident.
    pub(crate) prior_w_dev: cudarc::driver::CudaSlice<f64>,
    /// Observation offset, length `n`, device-resident.
    pub(crate) offset_dev: cudarc::driver::CudaSlice<f64>,
}
104
/// Per-stream workspace for [`solve_pirls_step_on_stream`].
///
/// Owns a non-default CUDA stream plus cuBLAS / cuSOLVER handles bound to
/// that stream, and the persistent device buffers that every PIRLS Newton
/// step in this sigma fit reuses (no per-step allocation, no per-step
/// handle creation). Multiple workspaces on independent streams sharing
/// one [`PirlsGpuSharedData`] are the substrate the stream-pool cubature
/// executor (Block 6 P3) composes.
///
/// When `p < FUSED_XTWX_P_THRESHOLD`, the workspace skips the `n×p` `wx_dev`
/// temporary entirely and routes through the fused `xtwx_lower` + `xtscore`
/// kernels instead. `wx_dev` is `Some` only for the large-p fallback path
/// where `ddgmm + gemm` beats the fused kernel.
#[cfg(target_os = "linux")]
pub struct SigmaPirlsGpuWorkspace {
    /// Non-default stream owned by this workspace. Uploads, kernels, and
    /// library calls for this fit are enqueued here.
    pub(crate) stream: std::sync::Arc<cudarc::driver::CudaStream>,
    /// cuBLAS handle created on `stream`.
    pub(crate) blas: cudarc::cublas::CudaBlas,
    /// cuSOLVER dense handle created on `stream`.
    pub(crate) solver: cudarc::cusolver::DnHandle,
    /// `None` when `p < FUSED_XTWX_P_THRESHOLD` (fused path). `Some` for the
    /// large-p fallback where the `ddgmm + dgemm` route is faster.
    pub(crate) wx_dev: Option<cudarc::driver::CudaSlice<f64>>,
    /// Per-step weights (length n), overwritten by each Newton step's upload.
    pub(crate) w_dev: cudarc::driver::CudaSlice<f64>,
    /// `X_originalᵀ W X_original` (p×p) — intermediate before Qs projection.
    pub(crate) xtwx_dev: cudarc::driver::CudaSlice<f64>,
    /// p×p solve matrix `H_step`. It is Cholesky-factored in place by POTRF,
    /// so after a step it holds the factor rather than `H_step`.
    pub(crate) h_dev: cudarc::driver::CudaSlice<f64>,
    /// p-vector RHS: receives the gradient before POTRS and holds the solved
    /// direction afterwards.
    pub(crate) rhs_dev: cudarc::driver::CudaSlice<f64>,
    /// p×p `S + step_lm_lambda·I` for the current step, column-major.
    pub(crate) penalty_dev: cudarc::driver::CudaSlice<f64>,
    /// Reparameterisation matrix `Qs` (p×p, column-major), uploaded once per
    /// ρ / σ point. Identity when no reparameterisation is active. Used to
    /// project `A = X_originalᵀ W X_original` into the transformed frame:
    /// `H_step = Qsᵀ A Qs + S + λI`.
    pub(crate) qs_dev: cudarc::driver::CudaSlice<f64>,
    /// Scratch p×p buffer for the two-step `Qsᵀ A Qs` accumulation:
    /// first `tmp = A Qs`, then `H = Qsᵀ tmp`.
    pub(crate) qs_tmp_dev: cudarc::driver::CudaSlice<f64>,
    /// p-vector: `beta_orig = Qs · β` computed before each `eta = X · beta_orig`.
    pub(crate) beta_orig_dev: cudarc::driver::CudaSlice<f64>,
    /// p-vector scratch used for `Qs · direction` when forming `xd = X · (Qs · δ)`.
    pub(crate) dir_orig_dev: cudarc::driver::CudaSlice<f64>,
    /// Pre-allocated cuSOLVER POTRF workspace buffer. Sized once at
    /// construction via `potrf_query_lwork`; reused every Newton step.
    pub(crate) potrf_work_dev: cudarc::driver::CudaSlice<f64>,
    /// Number of f64 elements in `potrf_work_dev`, stored as i32 to match
    /// the cuSOLVER API signature for cusolverDnDpotrf.
    pub(crate) potrf_lwork: i32,
    /// Deferred POTRF info scalar. Stays device-resident across all PIRLS
    /// Newton steps; downloaded once at end-of-fit via
    /// `check_deferred_potrf_info`.
    pub(crate) potrf_info_dev: cudarc::driver::CudaSlice<i32>,
    /// Deferred POTRS info scalar. Mirrors the POTRF discipline.
    pub(crate) potrs_info_dev: cudarc::driver::CudaSlice<i32>,
    /// Row count this workspace was sized for; must match the shared design.
    pub(crate) n: usize,
    /// Column count this workspace was sized for; must match the shared design.
    pub(crate) p: usize,
}
159
160#[cfg(target_os = "linux")]
161pub(crate) mod cuda {
162    use super::{
163        PirlsGpuInput, PirlsGpuSharedData, PirlsGpuStep, PirlsStepStreamDeviceInput,
164        PirlsStepStreamInput, SigmaPirlsGpuWorkspace,
165    };
166    use gam_gpu::device_cache::PtxModuleCache;
167    use gam_gpu::driver::{from_col_major, to_col_major};
168    use gam_gpu::solver::{
169        check_deferred_potrf_info, check_deferred_potrs_info, context_and_stream, pinned_htod,
170        potrf_in_place_reuse, potrf_query_lwork, potrs_in_place_reuse,
171    };
172    use cudarc::cublas::sys::{
173        cublasDdgmm, cublasDgeam, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
174    };
175    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
176    use cudarc::cusolver::DnHandle;
177    use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut, LaunchConfig, PushKernelArg};
178    use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
179
    /// One-thread reduction over a p×p column-major Cholesky factor's
    /// diagonal, computing `2·Σ ln(L[i,i])` device-side and writing a
    /// single f64 into `out[0]`. The factor's lower-triangular Cholesky
    /// has positive diagonal by construction, so no abs/clamp needed.
    /// One thread is enough for the dominant p ≤ ~200 sizes; the cost was
    /// previously a full p² download, so even a serial device sweep wins.
    ///
    /// Only thread 0 of block 0 does any work and every other thread returns
    /// immediately, so any launch shape is correct and a 1×1 launch is the
    /// cheapest. Because the matrix is square, `factor[i * p + i]` addresses
    /// the diagonal whatever the storage order.
    ///
    /// NOTE(review): if the preceding POTRF failed, the diagonal may hold
    /// non-positive entries and `log` would yield NaN/-inf here. Presumably
    /// the deferred POTRF info check guards against consuming that value.
    const CHOL_LOGDET_PTX_SOURCE: &str = r#"
extern "C" __global__ void chol_logdet_col_major(
    const double* __restrict__ factor,
    int p,
    double* __restrict__ out
) {
    if (threadIdx.x != 0 || blockIdx.x != 0) return;
    double acc = 0.0;
    long long pp = (long long)p;
    for (long long i = 0; i < pp; ++i) {
        acc += log(factor[i * pp + i]);
    }
    out[0] = 2.0 * acc;
}
"#;

    /// Process-wide `PtxModuleCache` for `CHOL_LOGDET_PTX_SOURCE`.
    /// NOTE(review): presumably compiles/loads the module lazily on first
    /// use — see `gam_gpu::device_cache` for the exact semantics.
    static CHOL_LOGDET_CACHE: PtxModuleCache = PtxModuleCache::new();
203
    /// When `p` is below this threshold the workspace uses the fused
    /// `xtwx_lower` + `xtscore` + `symmetrize_lower` kernels and omits the
    /// `n*p` `wx_dev` temporary entirely. For `p >= FUSED_XTWX_P_THRESHOLD`
    /// the existing `ddgmm + dgemm` path is used.
    ///
    /// The fused kernels compute their triangle indices in 32-bit `int`
    /// (`p*(p+1)/2`, `8*t+1`). Below this bound that arithmetic is far from
    /// overflow, so raising the threshold substantially should be paired
    /// with a review of those kernels.
    const FUSED_XTWX_P_THRESHOLD: usize = 256;
209
    /// NVRTC kernels for the fused path.
    ///
    /// `xtwx_lower`: one thread per lower-tri pair `(j,k)` with `j >= k`;
    /// iterates over `n` rows, writes `A[j + k*p]` (col-major lower triangle).
    ///
    /// `xtscore`: one thread per `j`; writes `s[j] = sum_i score[i]*X[i,j]`.
    ///
    /// `symmetrize_lower`: one thread per strict-lower pair `(j,k)` with
    /// `j > k`; copies `A[k + j*p] = A[j + k*p]` to fill the upper triangle.
    ///
    /// All kernels read `X` as an `n × p` column-major matrix, with column `j`
    /// starting at `X + j*n`. Surplus threads return early, so launches need
    /// *at least* `p(p+1)/2`, `p`, and `p(p-1)/2` threads respectively.
    /// `symmetrize_lower` reads the lower triangle, so it must run after
    /// `xtwx_lower` has completed (e.g. enqueued later on the same stream).
    ///
    /// `concat!` joins the literals with no separator or newline, so each
    /// fragment must be self-delimiting C/CUDA.
    const FUSED_XTWX_PTX_SOURCE: &str = concat!(
        // xtwx_lower: enumerate lower triangle row-by-row.
        // Row j has entries (j,0),(j,1),...,(j,j).
        // Cumulative offset before row j = j*(j+1)/2.
        // Unrank t -> j = floor((sqrt(8t+1)-1)/2), k = t - j*(j+1)/2.
        // Output: A[j + k*p] in col-major for j >= k.
        "extern \"C\" __global__ void xtwx_lower(",
        "const double* __restrict__ X,",
        "const double* __restrict__ w,",
        "double* __restrict__ A,",
        "int n, int p) {",
        "int t=blockIdx.x*blockDim.x+threadIdx.x;",
        "int np=p*(p+1)/2; if(t>=np)return;",
        // j = floor((sqrt(8t+1)-1)/2) from the fp estimate; the two integer
        // loops that follow correct any fp rounding so that
        // j*(j+1)/2 <= t < (j+1)*(j+2)/2 holds exactly.
        "int jv=(int)((__dsqrt_rn((double)(8*t+1))-1.0)*0.5);",
        "while((long long)(jv+1)*(jv+2)/2<=t)jv++;",
        "while(jv>0&&(long long)jv*(jv+1)/2>t)jv--;",
        "int kv=t-(int)((long long)jv*(jv+1)/2);",
        "double acc=0.0;",
        "const double*Xj=X+(long long)jv*n;",
        "const double*Xk=X+(long long)kv*n;",
        "for(int i=0;i<n;++i)acc+=w[i]*Xj[i]*Xk[i];",
        // col-major index: A[jv, kv] = A[jv + kv*p]
        "A[jv+(long long)kv*p]=acc;}",
        // xtscore: one thread per output index j
        "extern \"C\" __global__ void xtscore(",
        "const double* __restrict__ X,",
        "const double* __restrict__ score,",
        "double* __restrict__ s,",
        "int n, int p) {",
        "int j=blockIdx.x*blockDim.x+threadIdx.x;",
        "if(j>=p)return;",
        "double acc=0.0;",
        "const double*Xj=X+(long long)j*n;",
        "for(int i=0;i<n;++i)acc+=score[i]*Xj[i];",
        "s[j]=acc;}",
        // symmetrize_lower: strict lower pairs (j,k) with j>k.
        // Enumerate row-by-row: row j=1 has entry (1,0); row j=2 has (2,0),(2,1); etc.
        // Cumulative before row j: j*(j-1)/2.
        // Unrank t -> j = floor((sqrt(8t+1)+1)/2), k = t - j*(j-1)/2.
        "extern \"C\" __global__ void symmetrize_lower(",
        "double* __restrict__ A, int p) {",
        "int ns=p*(p-1)/2;",
        "int t=blockIdx.x*blockDim.x+threadIdx.x;",
        "if(t>=ns)return;",
        // j = floor((sqrt(8t+1)+1)/2) from the fp estimate; the integer loops
        // that follow enforce j*(j-1)/2 <= t < (j+1)*j/2 exactly.
        "int jv=(int)((__dsqrt_rn((double)(8*t+1))+1.0)*0.5);",
        "while((long long)jv*(jv-1)/2>t)jv--;",
        "while((long long)(jv+1)*jv/2<=t)jv++;",
        "int kv=t-(int)((long long)jv*(jv-1)/2);",
        // A[kv, jv] = A[kv + jv*p] = A[jv + kv*p] (copy lower to upper)
        "A[kv+(long long)jv*p]=A[jv+(long long)kv*p];}",
    );

    /// Process-wide `PtxModuleCache` for `FUSED_XTWX_PTX_SOURCE`.
    /// NOTE(review): presumably compiles/loads the module lazily on first
    /// use — see `gam_gpu::device_cache` for the exact semantics.
    static FUSED_XTWX_CACHE: PtxModuleCache = PtxModuleCache::new();
274
275    impl PirlsGpuSharedData {
276        /// Upload `x` to the cached per-ordinal CUDA context and return a
277        /// Upload X_original, y, prior_w, and offset to the device once.
278        /// Returns a shared handle reused across all ρ / σ points.
279        pub(crate) fn upload_impl(
280            x: ArrayView2<'_, f64>,
281            y: ArrayView1<'_, f64>,
282            prior_w: ArrayView1<'_, f64>,
283            offset: ArrayView1<'_, f64>,
284        ) -> Result<Self, String> {
285            let (n, p) = x.dim();
286            if n == 0 || p == 0 {
287                return Err("empty design cannot be uploaded".to_string());
288            }
289            if y.len() != n || prior_w.len() != n || offset.len() != n {
290                return Err(format!(
291                    "y/prior_w/offset length mismatch (y={}, w={}, offset={}, n={n})",
292                    y.len(),
293                    prior_w.len(),
294                    offset.len()
295                ));
296            }
297            let (ctx, stream) = context_and_stream()?;
298            let x_col = to_col_major(&x);
299            let x_original_dev = pinned_htod(&stream, &x_col)?;
300            let y_dev = pinned_htod(&stream, y.as_slice().ok_or("y not contiguous")?)?;
301            let prior_w_dev =
302                pinned_htod(&stream, prior_w.as_slice().ok_or("prior_w not contiguous")?)?;
303            let offset_dev =
304                pinned_htod(&stream, offset.as_slice().ok_or("offset not contiguous")?)?;
305            // Synchronize the upload stream so all buffers are visible to
306            // every workspace we hand off to. Workspaces use independent
307            // streams; the uploads completed on the bootstrap stream above.
308            stream
309                .synchronize()
310                .map_err(|e| format!("cuda sync after model upload: {e}"))?;
311            Ok(Self {
312                ctx,
313                n,
314                p,
315                x_original_dev,
316                y_dev,
317                prior_w_dev,
318                offset_dev,
319            })
320        }
321    }
322
323    impl SigmaPirlsGpuWorkspace {
324        /// Allocate a workspace bound to a fresh non-default CUDA stream on
325        /// the shared context. cuBLAS and cuSOLVER handles are created with
326        /// that stream so every kernel issued through them is enqueued on
327        /// this workspace's stream, allowing concurrent overlap with peer
328        /// workspaces in the stream pool.
329        pub(crate) fn allocate_impl(shared: &PirlsGpuSharedData) -> Result<Self, String> {
330            let n = shared.n;
331            let p = shared.p;
332            let stream = shared
333                .ctx
334                .new_stream()
335                .map_err(|e| format!("cuda stream alloc: {e}"))?;
336            let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
337            let solver =
338                DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
339            let np = n.checked_mul(p).ok_or("X size overflow")?;
340            let pp = p.checked_mul(p).ok_or("H size overflow")?;
341            // Skip the n*p WX scratch when the fused kernels will be used.
342            let wx_dev = if p >= FUSED_XTWX_P_THRESHOLD {
343                Some(
344                    stream
345                        .alloc_zeros::<f64>(np)
346                        .map_err(|e| format!("cuda alloc WX: {e}"))?,
347                )
348            } else {
349                None
350            };
351            let w_dev = stream
352                .alloc_zeros::<f64>(n)
353                .map_err(|e| format!("cuda alloc W: {e}"))?;
354            let xtwx_dev = stream
355                .alloc_zeros::<f64>(pp)
356                .map_err(|e| format!("cuda alloc XtWX: {e}"))?;
357            let h_dev = stream
358                .alloc_zeros::<f64>(pp)
359                .map_err(|e| format!("cuda alloc H: {e}"))?;
360            let rhs_dev = stream
361                .alloc_zeros::<f64>(p)
362                .map_err(|e| format!("cuda alloc RHS: {e}"))?;
363            let penalty_dev = stream
364                .alloc_zeros::<f64>(pp)
365                .map_err(|e| format!("cuda alloc penalty: {e}"))?;
366            // Qs and scratch: p×p identity-initialized and p-vector zeros.
367            let mut qs_dev = stream
368                .alloc_zeros::<f64>(pp)
369                .map_err(|e| format!("cuda alloc Qs: {e}"))?;
370            // Initialize Qs to identity: diagonal = 1.0.
371            {
372                let mut qs_host = vec![0.0_f64; pp];
373                for i in 0..p {
374                    qs_host[i * p + i] = 1.0;
375                }
376                stream
377                    .memcpy_htod(&qs_host, &mut qs_dev)
378                    .map_err(|e| format!("init Qs identity: {e}"))?;
379            }
380            let qs_tmp_dev = stream
381                .alloc_zeros::<f64>(pp)
382                .map_err(|e| format!("cuda alloc Qs tmp: {e}"))?;
383            let beta_orig_dev = stream
384                .alloc_zeros::<f64>(p)
385                .map_err(|e| format!("cuda alloc beta_orig: {e}"))?;
386            let dir_orig_dev = stream
387                .alloc_zeros::<f64>(p)
388                .map_err(|e| format!("cuda alloc dir_orig: {e}"))?;
389            // Query the POTRF workspace size once using the actual p so we
390            // can size the persistent buffer. This is the only buffer-size
391            // query in the hot path — every Newton step reuses it.
392            let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
393            let potrf_lwork = i32::try_from(potrf_lwork_usize)
394                .map_err(|_| format!("potrf lwork {potrf_lwork_usize} exceeds i32"))?;
395            // Allocate at least 1 element so the device pointer is always
396            // valid; cuSOLVER accepts a zero-length workspace when lwork==0.
397            let alloc_len = potrf_lwork_usize.max(1);
398            let potrf_work_dev = stream
399                .alloc_zeros::<f64>(alloc_len)
400                .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
401            let potrf_info_dev = stream
402                .alloc_zeros::<i32>(1)
403                .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
404            let potrs_info_dev = stream
405                .alloc_zeros::<i32>(1)
406                .map_err(|e| format!("cuda alloc potrs info: {e}"))?;
407            Ok(Self {
408                stream,
409                blas,
410                solver,
411                wx_dev,
412                w_dev,
413                xtwx_dev,
414                h_dev,
415                rhs_dev,
416                penalty_dev,
417                qs_dev,
418                qs_tmp_dev,
419                beta_orig_dev,
420                dir_orig_dev,
421                potrf_work_dev,
422                potrf_lwork,
423                potrf_info_dev,
424                potrs_info_dev,
425                n,
426                p,
427            })
428        }
429    }
430
431    /// Upload a new `Qs` matrix (p×p, row-major host) to `ws.qs_dev`.
432    /// Call once per ρ / σ point before calling `pirls_loop` or any step
433    /// function. When no reparameterisation is active, pass the identity.
434    pub(super) fn upload_qs(
435        ws: &mut SigmaPirlsGpuWorkspace,
436        qs: ArrayView2<'_, f64>,
437    ) -> Result<(), String> {
438        let p = ws.p;
439        if qs.dim() != (p, p) {
440            return Err(format!("upload_qs: Qs shape {:?} != ({p},{p})", qs.dim()));
441        }
442        let qs_col = to_col_major(&qs);
443        ws.stream
444            .memcpy_htod(qs_col.as_ref(), &mut ws.qs_dev)
445            .map_err(|e| format!("upload Qs: {e}"))
446    }
447
448    /// Upload an identity `Qs` (no reparameterisation) for the current ρ point.
449    pub(super) fn upload_qs_identity(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
450        let p = ws.p;
451        let pp = p * p;
452        let mut qs_host = vec![0.0_f64; pp];
453        for i in 0..p {
454            qs_host[i * p + i] = 1.0;
455        }
456        ws.stream
457            .memcpy_htod(&qs_host, &mut ws.qs_dev)
458            .map_err(|e| format!("upload Qs identity: {e}"))
459    }
460
461    /// Apply one fp64 iterative-refinement correction to a Newton step solve.
462    ///
463    /// Compute `r = g − H_step·x` (host, p-vector). When `p ≥ REFINEMENT_MIN_P`
464    /// and `‖r‖/‖g‖ > REFINEMENT_TOL`, apply one POTRS correction and return
465    /// `x + e`. Returns `direction_raw` unchanged when `p` is too small, the
466    /// residual is already tight, or `‖g‖ = 0`.
467    ///
468    /// `H_step·x = penalized_hessian·x + step_lm_delta·x`.
469    fn newton_step_refine_once(
470        solver: &cudarc::cusolver::DnHandle,
471        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
472        p: usize,
473        chol_factor_dev: &CudaSlice<f64>,
474        rhs_dev: &mut CudaSlice<f64>,
475        potrs_info_dev: &mut CudaSlice<i32>,
476        mut direction_raw: Vec<f64>,
477        g: &[f64],
478        penalized_hessian: &ndarray::Array2<f64>,
479        step_lm_delta: f64,
480    ) -> Result<Vec<f64>, String> {
481        use gam_gpu::policy::GpuDispatchPolicy;
482        if p < GpuDispatchPolicy::REFINEMENT_MIN_P {
483            return Ok(direction_raw);
484        }
485        let norm_g = g.iter().map(|v| v * v).sum::<f64>().sqrt();
486        if norm_g == 0.0 {
487            return Ok(direction_raw);
488        }
489        let hx: Vec<f64> = (0..p)
490            .map(|i| {
491                penalized_hessian
492                    .row(i)
493                    .iter()
494                    .zip(direction_raw.iter())
495                    .map(|(hij, xj)| hij * xj)
496                    .sum::<f64>()
497                    + step_lm_delta * direction_raw[i]
498            })
499            .collect();
500        let residual: Vec<f64> = g.iter().zip(hx.iter()).map(|(gi, hxi)| gi - hxi).collect();
501        let rel_res = residual.iter().map(|v| v * v).sum::<f64>().sqrt() / norm_g;
502        if rel_res <= GpuDispatchPolicy::REFINEMENT_TOL {
503            return Ok(direction_raw);
504        }
505        stream
506            .memcpy_htod(&residual, rhs_dev)
507            .map_err(|e| format!("upload residual: {e}"))?;
508        potrs_in_place_reuse(
509            solver,
510            stream,
511            p,
512            1,
513            chol_factor_dev,
514            rhs_dev,
515            potrs_info_dev,
516        )?;
517        let correction = stream
518            .clone_dtoh(rhs_dev)
519            .map_err(|e| format!("download correction: {e}"))?;
520        check_deferred_potrs_info(stream, potrs_info_dev)?;
521        for (xi, ei) in direction_raw.iter_mut().zip(correction.iter()) {
522            *xi += ei;
523        }
524        Ok(direction_raw)
525    }
526
527    /// Drive one PIRLS Newton step on the workspace's CUDA stream.
528    ///
529    /// Build `H = XᵀWX + S + λI`, Cholesky-factor it, solve `H·d = g`,
530    /// return `(H, d, log|H|)`. `input.gradient` is the full descent-direction
531    /// RHS `Xᵀscore − S·β + linear_shift` — the caller is responsible for
532    /// assembling the corrected RHS before calling this function. No negation
533    /// is applied; the returned `direction = H⁻¹·g` is the descent step δ
534    /// directly (#257). The difference vs the one-shot [`solve_step`] is
535    /// purely the execution model: no context creation, no handle creation,
536    /// no design-matrix upload, no per-step buffer allocations.
537    pub(super) fn solve_step_on_stream(
538        shared: &PirlsGpuSharedData,
539        ws: &mut SigmaPirlsGpuWorkspace,
540        input: PirlsStepStreamInput<'_>,
541    ) -> Result<PirlsGpuStep, String> {
542        let n = shared.n;
543        let p = shared.p;
544        if ws.n != n || ws.p != p {
545            return Err(format!(
546                "workspace shape ({}, {}) does not match shared design ({n}, {p})",
547                ws.n, ws.p
548            ));
549        }
550        if input.weights.len() != n {
551            return Err(format!(
552                "weights length {} does not match rows {n}",
553                input.weights.len()
554            ));
555        }
556        if input.penalty_hessian.dim() != (p, p) {
557            return Err(format!(
558                "penalty Hessian shape {:?} does not match p={p}",
559                input.penalty_hessian.dim()
560            ));
561        }
562        if input.gradient.len() != p {
563            return Err(format!(
564                "gradient length {} does not match p={p}",
565                input.gradient.len()
566            ));
567        }
568
569        // Upload per-step weights into the persistent W buffer.
570        let w_slice = input
571            .weights
572            .as_slice()
573            .ok_or("weights must be contiguous")?;
574        ws.stream
575            .memcpy_htod(w_slice, &mut ws.w_dev)
576            .map_err(|e| format!("upload W: {e}"))?;
577
578        // Compute XᵀWX into ws.xtwx_dev.  Two paths:
579        // Fused (p < FUSED_XTWX_P_THRESHOLD): row-sweep kernels, no n*p temp.
580        // Fallback (p >= FUSED_XTWX_P_THRESHOLD): ddgmm + dgemm via wx_dev.
581        let n_i = to_i32(n)?;
582        let p_i = to_i32(p)?;
583        if let Some(ref mut wx_dev) = ws.wx_dev {
584            left_scale_rows(
585                &ws.blas,
586                &ws.stream,
587                n,
588                p,
589                &shared.x_original_dev,
590                &mut ws.w_dev,
591                wx_dev,
592            )?;
593            let cfg = GemmConfig::<f64> {
594                transa: cublasOperation_t::CUBLAS_OP_T,
595                transb: cublasOperation_t::CUBLAS_OP_N,
596                m: p_i,
597                n: p_i,
598                k: n_i,
599                alpha: 1.0,
600                lda: n_i,
601                ldb: n_i,
602                beta: 0.0,
603                ldc: p_i,
604            };
605            // SAFETY: validated i32 dims; shared.x_original_dev and wx_dev are n*p
606            // f64 col-major; ws.xtwx_dev is the p*p output.
607            unsafe {
608                ws.blas
609                    .gemm(cfg, &shared.x_original_dev, wx_dev, &mut ws.xtwx_dev)
610            }
611            .map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
612        } else {
613            launch_xtwx_lower(
614                &ws.stream,
615                &shared.ctx,
616                n,
617                p,
618                &shared.x_original_dev,
619                &ws.w_dev,
620                &mut ws.xtwx_dev,
621            )?;
622            launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
623        }
624
625        // Upload S + step_lm_lambda·I for the Newton solve (LM damping only).
626        let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
627        let penalty_step_view = penalty_step.view();
628        let penalty_step_col = to_col_major(&penalty_step_view);
629        ws.stream
630            .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
631            .map_err(|e| format!("upload penalty: {e}"))?;
632
633        // Apply Qs rotation: H_xtx = Qsᵀ · XᵀWX · Qs (two p×p gemms).
634        // Matches solve_step_on_stream_device_inplace (#269 resident-X arch):
635        // X_original stays device-resident, Qs rotates into transformed frame.
636        {
637            let cfg_aq = GemmConfig::<f64> {
638                transa: cublasOperation_t::CUBLAS_OP_N,
639                transb: cublasOperation_t::CUBLAS_OP_N,
640                m: p_i,
641                n: p_i,
642                k: p_i,
643                alpha: 1.0,
644                lda: p_i,
645                ldb: p_i,
646                beta: 0.0,
647                ldc: p_i,
648            };
649            // SAFETY: xtwx_dev and qs_dev p*p col-major; qs_tmp_dev p*p output.
650            unsafe {
651                ws.blas
652                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
653            }
654            .map_err(|e| format!("dgemm A·Qs (host-input step): {e}"))?;
655        }
656        {
657            let cfg_qt = GemmConfig::<f64> {
658                transa: cublasOperation_t::CUBLAS_OP_T,
659                transb: cublasOperation_t::CUBLAS_OP_N,
660                m: p_i,
661                n: p_i,
662                k: p_i,
663                alpha: 1.0,
664                lda: p_i,
665                ldb: p_i,
666                beta: 0.0,
667                ldc: p_i,
668            };
669            // SAFETY: qs_dev p*p (transposed); qs_tmp_dev p*p; h_dev p*p output.
670            unsafe {
671                ws.blas
672                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
673            }
674            .map_err(|e| format!("dgemm Qsᵀ·A·Qs (host-input step): {e}"))?;
675        }
676        // H_step = Qsᵀ·XᵀWX·Qs + (S + step_lm_lambda·I).
677        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
678
679        // Upload gradient into the persistent RHS buffer.
680        // `input.gradient` is already in transformed coordinates (Qsᵀ-projected
681        // by the caller), so no additional rotation is needed here.
682        let g_slice = input
683            .gradient
684            .as_slice()
685            .ok_or("gradient must be contiguous")?;
686        ws.stream
687            .memcpy_htod(g_slice, &mut ws.rhs_dev)
688            .map_err(|e| format!("upload gradient: {e}"))?;
689
690        // Exported penalised Hessian: H_final = Qsᵀ·XᵀWX·Qs + S + objective_ridge·I.
691        // Apply Qs rotation host-side on the downloaded XᵀWX so LM damping
692        // never contaminates exported EDF / REML curvature / RidgePassport.
693        let xtwx_col = ws
694            .stream
695            .clone_dtoh(&ws.xtwx_dev)
696            .map_err(|e| format!("download XᵀWX (host-input step): {e}"))?;
697        let xtwx_host = from_col_major(&xtwx_col, p, p).ok_or("XᵀWX layout conversion failed")?;
698        let qs_col = ws
699            .stream
700            .clone_dtoh(&ws.qs_dev)
701            .map_err(|e| format!("download Qs (host-input step): {e}"))?;
702        let qs_host =
703            from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (host-input step)")?;
704        let tmp_aq = xtwx_host.dot(&qs_host);
705        let h_rotated = qs_host.t().dot(&tmp_aq);
706        let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
707        let penalized_hessian = h_rotated + &penalty_export;
708
709        // Factor + solve in place on the stream using pre-allocated workspace
710        // and info buffers — no per-step allocation, no per-step info download.
711        potrf_in_place_reuse(
712            &ws.solver,
713            &ws.stream,
714            p,
715            ws.potrf_lwork,
716            &mut ws.h_dev,
717            &mut ws.potrf_work_dev,
718            &mut ws.potrf_info_dev,
719        )?;
720        potrs_in_place_reuse(
721            &ws.solver,
722            &ws.stream,
723            p,
724            1,
725            &ws.h_dev,
726            &mut ws.rhs_dev,
727            &mut ws.potrs_info_dev,
728        )?;
729
730        // Logdet device-side: reduces the previous p² Cholesky-factor
731        // download to a single f64 download. Stage 2's "no per-iteration
732        // host round-trip" budget keeps the p² factor on the device.
733        let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
734
735        // Direction: d = H⁻¹ g (no negation; g is the full corrected RHS, #257).
736        let direction_raw = ws
737            .stream
738            .clone_dtoh(&ws.rhs_dev)
739            .map_err(|e| format!("download direction: {e}"))?;
740        // Check deferred POTRF/POTRS info after the direction download
741        // (which already syncs the stream). Single host round-trip for both
742        // info scalars at end-of-step rather than one per cuSOLVER call.
743        check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
744        check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
745
746        // Iterative refinement on the Qs-rotated system.
747        // penalized_hessian = Qsᵀ·XtWX·Qs + S + objective_ridge·I.
748        // H_step = penalized_hessian + (step_lm_lambda − objective_ridge)·I.
749        let lm_ridge_delta = input.step_lm_lambda - input.objective_ridge;
750        let direction_raw = newton_step_refine_once(
751            &ws.solver,
752            &ws.stream,
753            p,
754            &ws.h_dev,
755            &mut ws.rhs_dev,
756            &mut ws.potrs_info_dev,
757            direction_raw,
758            g_slice,
759            &penalized_hessian,
760            lm_ridge_delta,
761        )?;
762
763        // No negation: `input.gradient` is the full descent-direction RHS
764        // `Xᵀscore − S·β + linear_shift`; solving H·δ = rhs gives δ directly.
765        let direction = Array1::from_vec(direction_raw);
766
767        Ok(PirlsGpuStep {
768            penalized_hessian,
769            direction,
770            logdet,
771        })
772    }
773
774    /// Stage 3.2 device-input PIRLS Newton step.
775    ///
776    /// Identical math to [`solve_step_on_stream`] but reads `w_solver`
777    /// and `grad_eta` straight from device buffers populated by the
778    /// device-side row-reweight kernel (no host upload of weights or
779    /// gradient). Only the penalty matrix still crosses the host
780    /// boundary because the outer REML loop updates Sλ + LM ridge
781    /// between PIRLS steps; the penalty is p×p which is independent of
782    /// n, so for large-scale n it is a negligible transfer.
783    ///
784    /// Outputs match `solve_step_on_stream`: returns the assembled
785    /// penalised Hessian, the Newton descent direction `δ = H⁻¹·rhs`
786    /// where `rhs = Xᵀ·score − S·β + linear_shift` (no negation, #257),
787    /// and the log-determinant computed via the device-side
788    /// `chol_logdet_col_major` kernel.
789    pub(super) fn solve_step_on_stream_device(
790        shared: &PirlsGpuSharedData,
791        ws: &mut SigmaPirlsGpuWorkspace,
792        input: PirlsStepStreamDeviceInput<'_, '_>,
793    ) -> Result<PirlsGpuStep, String> {
794        let n = shared.n;
795        let p = shared.p;
796        if ws.n != n || ws.p != p {
797            return Err(format!(
798                "workspace shape ({}, {}) does not match shared design ({n}, {p})",
799                ws.n, ws.p
800            ));
801        }
802        if input.w_solver_dev.len() != n {
803            return Err(format!(
804                "w_solver_dev length {} does not match n={n}",
805                input.w_solver_dev.len()
806            ));
807        }
808        if input.grad_eta_dev.len() != n {
809            return Err(format!(
810                "grad_eta_dev length {} does not match n={n}",
811                input.grad_eta_dev.len()
812            ));
813        }
814        if input.penalty_hessian.dim() != (p, p) {
815            return Err(format!(
816                "penalty Hessian shape {:?} does not match p={p}",
817                input.penalty_hessian.dim()
818            ));
819        }
820
821        // Compute XᵀWX and Xᵀ·score.  Fused path (p < threshold): no n*p WX.
822        // Fallback (p >= threshold): ddgmm + dgemm + gemv via wx_dev_fb.
823        let n_i = to_i32(n)?;
824        let p_i = to_i32(p)?;
825        if let Some(ref mut wx_dev_fb) = ws.wx_dev {
826            // Large-p fallback.
827            left_scale_rows_borrowed(
828                &ws.blas,
829                &ws.stream,
830                n,
831                p,
832                &shared.x_original_dev,
833                input.w_solver_dev,
834                wx_dev_fb,
835            )?;
836            let gemm_cfg = GemmConfig::<f64> {
837                transa: cublasOperation_t::CUBLAS_OP_T,
838                transb: cublasOperation_t::CUBLAS_OP_N,
839                m: p_i,
840                n: p_i,
841                k: n_i,
842                alpha: 1.0,
843                lda: n_i,
844                ldb: n_i,
845                beta: 0.0,
846                ldc: p_i,
847            };
848            // SAFETY: validated dims; shared.x_original_dev and wx_dev_fb are n*p
849            // f64 col-major; ws.xtwx_dev is p*p; all on ws.stream.
850            unsafe {
851                ws.blas.gemm(
852                    gemm_cfg,
853                    &shared.x_original_dev,
854                    wx_dev_fb,
855                    &mut ws.xtwx_dev,
856                )
857            }
858            .map_err(|e| format!("cublas dgemm XtWX (device-input): {e}"))?;
859            let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
860            let penalty_step_col = to_col_major(&penalty_step);
861            ws.stream
862                .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
863                .map_err(|e| format!("upload penalty (device-input): {e}"))?;
864            // Qs rotation on H: tmp = XᵀWX · Qs, then h_dev = Qsᵀ · tmp.
865            {
866                let cfg_aq = GemmConfig::<f64> {
867                    transa: cublasOperation_t::CUBLAS_OP_N,
868                    transb: cublasOperation_t::CUBLAS_OP_N,
869                    m: p_i,
870                    n: p_i,
871                    k: p_i,
872                    alpha: 1.0,
873                    lda: p_i,
874                    ldb: p_i,
875                    beta: 0.0,
876                    ldc: p_i,
877                };
878                // SAFETY: xtwx_dev and qs_dev p*p col-major; qs_tmp_dev p*p output.
879                unsafe {
880                    ws.blas
881                        .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
882                }
883                .map_err(|e| format!("dgemm A·Qs (device-input large-p): {e}"))?;
884            }
885            {
886                let cfg_qt = GemmConfig::<f64> {
887                    transa: cublasOperation_t::CUBLAS_OP_T,
888                    transb: cublasOperation_t::CUBLAS_OP_N,
889                    m: p_i,
890                    n: p_i,
891                    k: p_i,
892                    alpha: 1.0,
893                    lda: p_i,
894                    ldb: p_i,
895                    beta: 0.0,
896                    ldc: p_i,
897                };
898                // SAFETY: qs_dev p*p (transposed); qs_tmp_dev p*p; h_dev p*p output.
899                unsafe {
900                    ws.blas
901                        .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
902                }
903                .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input large-p): {e}"))?;
904            }
905            geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
906            let gemv_cfg = GemvConfig::<f64> {
907                trans: cublasOperation_t::CUBLAS_OP_T,
908                m: n_i,
909                n: p_i,
910                alpha: 1.0,
911                lda: n_i,
912                incx: 1,
913                beta: 0.0,
914                incy: 1,
915            };
916            // SAFETY: shared.x_original_dev n*p col-major; grad_eta_dev length n; rhs_dev length p.
917            unsafe {
918                ws.blas.gemv(
919                    gemv_cfg,
920                    &shared.x_original_dev,
921                    input.grad_eta_dev,
922                    &mut ws.rhs_dev,
923                )
924            }
925            .map_err(|e| format!("cublas dgemv Xtg (device-input): {e}"))?;
926        } else {
927            // Fused path: row-sweep kernels, no n*p WX buffer.
928            launch_xtwx_lower(
929                &ws.stream,
930                &shared.ctx,
931                n,
932                p,
933                &shared.x_original_dev,
934                input.w_solver_dev,
935                &mut ws.xtwx_dev,
936            )?;
937            launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
938            launch_xtscore(
939                &ws.stream,
940                &shared.ctx,
941                n,
942                p,
943                &shared.x_original_dev,
944                input.grad_eta_dev,
945                &mut ws.rhs_dev,
946            )?;
947            // Qs rotation on H: tmp = XᵀWX · Qs, then h_dev = Qsᵀ · tmp.
948            {
949                let cfg_aq = GemmConfig::<f64> {
950                    transa: cublasOperation_t::CUBLAS_OP_N,
951                    transb: cublasOperation_t::CUBLAS_OP_N,
952                    m: p_i,
953                    n: p_i,
954                    k: p_i,
955                    alpha: 1.0,
956                    lda: p_i,
957                    ldb: p_i,
958                    beta: 0.0,
959                    ldc: p_i,
960                };
961                // SAFETY: xtwx_dev and qs_dev p*p col-major; qs_tmp_dev p*p output.
962                unsafe {
963                    ws.blas
964                        .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
965                }
966                .map_err(|e| format!("dgemm A·Qs (device-input fused): {e}"))?;
967            }
968            {
969                let cfg_qt = GemmConfig::<f64> {
970                    transa: cublasOperation_t::CUBLAS_OP_T,
971                    transb: cublasOperation_t::CUBLAS_OP_N,
972                    m: p_i,
973                    n: p_i,
974                    k: p_i,
975                    alpha: 1.0,
976                    lda: p_i,
977                    ldb: p_i,
978                    beta: 0.0,
979                    ldc: p_i,
980                };
981                // SAFETY: qs_dev p*p (transposed); qs_tmp_dev p*p; h_dev p*p output.
982                unsafe {
983                    ws.blas
984                        .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
985                }
986                .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input fused): {e}"))?;
987            }
988            let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
989            let penalty_step_col = to_col_major(&penalty_step);
990            ws.stream
991                .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
992                .map_err(|e| format!("upload penalty (fused device-input): {e}"))?;
993            geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
994        }
995
996        // Apply rhs correction BEFORE the solve:
997        //   rhs = Qsᵀ·(Xᵀ·score) − S·β + linear_shift  (#257, #260, #269).
998        // First project X_origᵀ·score through Qsᵀ (p×p gemv on device), then
999        // apply the S·β correction host-side and re-upload.
1000        {
1001            // Qsᵀ · rhs_dev (= Xᵀ·score) → beta_orig_dev (scratch p-vector).
1002            let cfg_qts = GemvConfig::<f64> {
1003                trans: cublasOperation_t::CUBLAS_OP_T,
1004                m: p_i,
1005                n: p_i,
1006                alpha: 1.0,
1007                lda: p_i,
1008                incx: 1,
1009                beta: 0.0,
1010                incy: 1,
1011            };
1012            // SAFETY: qs_dev p*p (transposed); rhs_dev length p; beta_orig_dev length p.
1013            unsafe {
1014                ws.blas
1015                    .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
1016            }
1017            .map_err(|e| format!("dgemv Qsᵀ·score (device-input): {e}"))?;
1018            // Swap: rhs_dev ← beta_orig_dev (now holds Qsᵀ·Xᵀ·score).
1019            ws.stream
1020                .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
1021                .map_err(|e| format!("d2d Qsᵀ·score→rhs (device-input): {e}"))?;
1022            // Download rhs and β; apply penalty correction host-side.
1023            let rhs_raw = ws
1024                .stream
1025                .clone_dtoh(&ws.rhs_dev)
1026                .map_err(|e| format!("download Qsᵀscore (device-input): {e}"))?;
1027            let beta_raw = ws
1028                .stream
1029                .clone_dtoh(input.beta_dev)
1030                .map_err(|e| format!("download beta (device-input): {e}"))?;
1031            let mut rhs_host = Array1::from_vec(rhs_raw);
1032            let beta_host = Array1::from_vec(beta_raw);
1033            let s_beta = input.penalty_hessian.dot(&beta_host);
1034            rhs_host -= &s_beta;
1035            rhs_host += &input.linear_shift;
1036            ws.stream
1037                .memcpy_htod(
1038                    rhs_host
1039                        .as_slice()
1040                        .ok_or("rhs_host not contiguous (device-input correction)")?,
1041                    &mut ws.rhs_dev,
1042                )
1043                .map_err(|e| format!("re-upload corrected rhs (device-input): {e}"))?;
1044        }
1045
1046        // Exported penalised Hessian: H_final = Qsᵀ·XᵀWX·Qs + S + objective_ridge·I.
1047        // Apply Qs rotation host-side on the downloaded XᵀWX so LM damping
1048        // never contaminates exported EDF / REML curvature / RidgePassport.
1049        let xtwx_col = ws
1050            .stream
1051            .clone_dtoh(&ws.xtwx_dev)
1052            .map_err(|e| format!("download XᵀWX (device-input): {e}"))?;
1053        let xtwx_host = from_col_major(&xtwx_col, p, p)
1054            .ok_or("XᵀWX layout conversion failed (device-input)")?;
1055        let qs_col = ws
1056            .stream
1057            .clone_dtoh(&ws.qs_dev)
1058            .map_err(|e| format!("download Qs (device-input): {e}"))?;
1059        let qs_host =
1060            from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (device-input)")?;
1061        let tmp_aq = xtwx_host.dot(&qs_host);
1062        let h_rotated = qs_host.t().dot(&tmp_aq);
1063        let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
1064        let penalized_hessian = h_rotated + &penalty_export;
1065
1066        // Factor + solve in place on the stream using pre-allocated workspace
1067        // and info buffers — no per-step allocation, no per-step info download.
1068        potrf_in_place_reuse(
1069            &ws.solver,
1070            &ws.stream,
1071            p,
1072            ws.potrf_lwork,
1073            &mut ws.h_dev,
1074            &mut ws.potrf_work_dev,
1075            &mut ws.potrf_info_dev,
1076        )?;
1077        potrs_in_place_reuse(
1078            &ws.solver,
1079            &ws.stream,
1080            p,
1081            1,
1082            &ws.h_dev,
1083            &mut ws.rhs_dev,
1084            &mut ws.potrs_info_dev,
1085        )?;
1086
1087        let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
1088
1089        let direction_raw = ws
1090            .stream
1091            .clone_dtoh(&ws.rhs_dev)
1092            .map_err(|e| format!("download direction (device-input): {e}"))?;
1093        // Check deferred POTRF/POTRS info after the direction download
1094        // (which already syncs the stream). Single host round-trip for both
1095        // info scalars at end-of-step rather than one per cuSOLVER call.
1096        check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
1097        check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
1098        // No negation: rhs = Xᵀscore − Sβ + linear_shift already gives the
1099        // descent direction δ = H⁻¹·rhs directly (#257).
1100        let direction = Array1::from_vec(direction_raw);
1101
1102        Ok(PirlsGpuStep {
1103            penalized_hessian,
1104            direction,
1105            logdet,
1106        })
1107    }
1108
1109    /// In-place Newton step: rhs = Xᵀ·score − S·β + linear_shift (#257, #260).
1110    ///
1111    /// Solves H·δ = rhs (H = XᵀWX + S + step_lm_lambda·I). On return
1112    /// `ws.rhs_dev` holds the Newton descent direction δ (not negated).
1113    /// The loop copies `ws.rhs_dev` to `direction_dev` via `memcpy_dtod`.
1114    ///
1115    /// On return `ws.h_dev` holds the Cholesky factor; rebuild with
1116    /// `rebuild_h_final` to get the exported penalised Hessian.
1117    ///
1118    /// Returns `logdet = log|H|` computed device-side.
1119    pub(super) fn solve_step_on_stream_device_inplace(
1120        shared: &PirlsGpuSharedData,
1121        ws: &mut SigmaPirlsGpuWorkspace,
1122        input: PirlsStepStreamDeviceInput<'_, '_>,
1123    ) -> Result<f64, String> {
1124        let n = shared.n;
1125        let p = shared.p;
1126        if ws.n != n || ws.p != p {
1127            return Err(format!(
1128                "workspace shape ({}, {}) does not match shared design ({n}, {p})",
1129                ws.n, ws.p
1130            ));
1131        }
1132        if input.w_solver_dev.len() != n {
1133            return Err(format!(
1134                "w_solver_dev length {} does not match n={n}",
1135                input.w_solver_dev.len()
1136            ));
1137        }
1138        if input.grad_eta_dev.len() != n {
1139            return Err(format!(
1140                "grad_eta_dev length {} does not match n={n}",
1141                input.grad_eta_dev.len()
1142            ));
1143        }
1144        if input.penalty_hessian.dim() != (p, p) {
1145            return Err(format!(
1146                "penalty Hessian shape {:?} does not match p={p}",
1147                input.penalty_hessian.dim()
1148            ));
1149        }
1150
1151        if input.linear_shift.len() != p {
1152            return Err(format!(
1153                "linear_shift length {} does not match p={p}",
1154                input.linear_shift.len()
1155            ));
1156        }
1157        let n_i = to_i32(n)?;
1158        let p_i = to_i32(p)?;
1159
1160        // Step 1: A = X_origᵀ diag(w_solver) X_orig → ws.xtwx_dev.
1161        //         score_p = X_origᵀ grad_eta → ws.rhs_dev.
1162        if let Some(ref mut wx_dev_ib) = ws.wx_dev {
1163            // Large-p path: ddgmm then dgemm, then gemv.
1164            left_scale_rows_borrowed(
1165                &ws.blas,
1166                &ws.stream,
1167                n,
1168                p,
1169                &shared.x_original_dev,
1170                input.w_solver_dev,
1171                wx_dev_ib,
1172            )?;
1173            let cfg_xtx = GemmConfig::<f64> {
1174                transa: cublasOperation_t::CUBLAS_OP_T,
1175                transb: cublasOperation_t::CUBLAS_OP_N,
1176                m: p_i,
1177                n: p_i,
1178                k: n_i,
1179                alpha: 1.0,
1180                lda: n_i,
1181                ldb: n_i,
1182                beta: 0.0,
1183                ldc: p_i,
1184            };
1185            // SAFETY: x_original_dev and wx_dev_ib n*p col-major; xtwx_dev p*p; ws.stream.
1186            unsafe {
1187                ws.blas
1188                    .gemm(cfg_xtx, &shared.x_original_dev, wx_dev_ib, &mut ws.xtwx_dev)
1189            }
1190            .map_err(|e| format!("dgemm XtWX inplace (large-p): {e}"))?;
1191            let cfg_xts = GemvConfig::<f64> {
1192                trans: cublasOperation_t::CUBLAS_OP_T,
1193                m: n_i,
1194                n: p_i,
1195                alpha: 1.0,
1196                lda: n_i,
1197                incx: 1,
1198                beta: 0.0,
1199                incy: 1,
1200            };
1201            // SAFETY: x_original_dev n*p col-major; grad_eta_dev length n; rhs_dev length p.
1202            unsafe {
1203                ws.blas.gemv(
1204                    cfg_xts,
1205                    &shared.x_original_dev,
1206                    input.grad_eta_dev,
1207                    &mut ws.rhs_dev,
1208                )
1209            }
1210            .map_err(|e| format!("dgemv Xᵀ·score inplace (large-p): {e}"))?;
1211        } else {
1212            // Fused path: row-sweep kernels, no n*p WX buffer.
1213            launch_xtwx_lower(
1214                &ws.stream,
1215                &shared.ctx,
1216                n,
1217                p,
1218                &shared.x_original_dev,
1219                input.w_solver_dev,
1220                &mut ws.xtwx_dev,
1221            )?;
1222            launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
1223            launch_xtscore(
1224                &ws.stream,
1225                &shared.ctx,
1226                n,
1227                p,
1228                &shared.x_original_dev,
1229                input.grad_eta_dev,
1230                &mut ws.rhs_dev,
1231            )?;
1232        }
1233
1234        // Step 2: H_xtx = Qsᵀ A Qs  (two p×p gemms).
1235        //   tmp = A · Qs → ws.qs_tmp_dev.
1236        {
1237            let cfg_aq = GemmConfig::<f64> {
1238                transa: cublasOperation_t::CUBLAS_OP_N,
1239                transb: cublasOperation_t::CUBLAS_OP_N,
1240                m: p_i,
1241                n: p_i,
1242                k: p_i,
1243                alpha: 1.0,
1244                lda: p_i,
1245                ldb: p_i,
1246                beta: 0.0,
1247                ldc: p_i,
1248            };
1249            // SAFETY: xtwx_dev and qs_dev p*p col-major; qs_tmp_dev p*p output.
1250            unsafe {
1251                ws.blas
1252                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
1253            }
1254            .map_err(|e| format!("dgemm A·Qs inplace: {e}"))?;
1255        }
1256        //   H_xtx = Qsᵀ · tmp → ws.h_dev.
1257        {
1258            let cfg_qt = GemmConfig::<f64> {
1259                transa: cublasOperation_t::CUBLAS_OP_T,
1260                transb: cublasOperation_t::CUBLAS_OP_N,
1261                m: p_i,
1262                n: p_i,
1263                k: p_i,
1264                alpha: 1.0,
1265                lda: p_i,
1266                ldb: p_i,
1267                beta: 0.0,
1268                ldc: p_i,
1269            };
1270            // SAFETY: qs_dev p*p (transposed); qs_tmp_dev p*p; h_dev p*p output.
1271            unsafe {
1272                ws.blas
1273                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
1274            }
1275            .map_err(|e| format!("dgemm Qsᵀ·A·Qs inplace: {e}"))?;
1276        }
1277        // H_step = H_xtx + (S + step_lm_lambda·I).
1278        let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
1279        let penalty_step_col = to_col_major(&penalty_step);
1280        ws.stream
1281            .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
1282            .map_err(|e| format!("upload penalty inplace: {e}"))?;
1283        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
1284
1285        // Step 3: rhs = Qsᵀ score_p − S·β + linear_shift  (#257, #260).
1286        // First project score_p through Qsᵀ on device (p×p gemv):
1287        //   beta_orig_dev = Qsᵀ · rhs_dev,  then swap back.
1288        {
1289            let cfg_qts = GemvConfig::<f64> {
1290                trans: cublasOperation_t::CUBLAS_OP_T,
1291                m: p_i,
1292                n: p_i,
1293                alpha: 1.0,
1294                lda: p_i,
1295                incx: 1,
1296                beta: 0.0,
1297                incy: 1,
1298            };
1299            // SAFETY: qs_dev p*p (transposed); rhs_dev length p; beta_orig_dev length p.
1300            unsafe {
1301                ws.blas
1302                    .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
1303            }
1304            .map_err(|e| format!("dgemv Qsᵀ·score inplace: {e}"))?;
1305            ws.stream
1306                .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
1307                .map_err(|e| format!("d2d Qsᵀ·score→rhs inplace: {e}"))?;
1308        }
1309        // Now download rhs and β (both p-vectors; small, bounded-cost round-trip).
1310        // Apply rhs −= S·β and rhs += linear_shift on the host for correctness.
1311        let rhs_raw = ws
1312            .stream
1313            .clone_dtoh(&ws.rhs_dev)
1314            .map_err(|e| format!("download Qsᵀ·score inplace: {e}"))?;
1315        let beta_raw = ws
1316            .stream
1317            .clone_dtoh(input.beta_dev)
1318            .map_err(|e| format!("download beta inplace: {e}"))?;
1319        let mut rhs_host = Array1::from_vec(rhs_raw);
1320        let beta_host = Array1::from_vec(beta_raw);
1321        // S·β in transformed coordinates (S = input.penalty_hessian in transformed frame).
1322        let s_beta = input.penalty_hessian.dot(&beta_host);
1323        rhs_host -= &s_beta;
1324        rhs_host += &input.linear_shift;
1325        ws.stream
1326            .memcpy_htod(
1327                rhs_host.as_slice().ok_or("rhs_host not contiguous")?,
1328                &mut ws.rhs_dev,
1329            )
1330            .map_err(|e| format!("re-upload corrected rhs inplace: {e}"))?;
1331
1332        // Step 4: Cholesky factor + solve in-place.
1333        potrf_in_place_reuse(
1334            &ws.solver,
1335            &ws.stream,
1336            p,
1337            ws.potrf_lwork,
1338            &mut ws.h_dev,
1339            &mut ws.potrf_work_dev,
1340            &mut ws.potrf_info_dev,
1341        )?;
1342        potrs_in_place_reuse(
1343            &ws.solver,
1344            &ws.stream,
1345            p,
1346            1,
1347            &ws.h_dev,
1348            &mut ws.rhs_dev,
1349            &mut ws.potrs_info_dev,
1350        )?;
1351        let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
1352        check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
1353        check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
1354
1355        // ws.rhs_dev = δ = H⁻¹·(Qsᵀ score_p − Sβ + linear_shift) — descent direction.
1356        // No negation: the corrected RHS directly gives the descent direction (#257).
1357        Ok(logdet)
1358    }
1359
1360    /// Rebuild the penalised Hessian `H = XᵀW_hessianX + S + objective_ridge·I`
1361    /// on device using the accepted `w_hessian` weights and download it once.
1362    /// Called once after PIRLS convergence so the exported Hessian reflects
1363    /// the accepted eta, not a stale mid-loop snapshot.
1364    ///
1365    /// Uses `ws.wx_dev`, `ws.xtwx_dev`, `ws.h_dev`, `ws.penalty_dev` as
1366    /// scratch — all are fair game post-loop.
1367    pub(super) fn rebuild_h_final(
1368        shared: &PirlsGpuSharedData,
1369        ws: &mut SigmaPirlsGpuWorkspace,
1370        w_hessian_dev: &CudaSlice<f64>,
1371        penalty_hessian: ArrayView2<'_, f64>,
1372        objective_ridge: f64,
1373    ) -> Result<Array2<f64>, String> {
1374        let n = shared.n;
1375        let p = shared.p;
1376
1377        // XtWX via fused path (no n*p WX temp) or fallback ddgmm + dgemm.
1378        if let Some(ref mut wx_dev_rh) = ws.wx_dev {
1379            // Large-p fallback: WX = diag(w_hessian) · X.
1380            left_scale_rows_borrowed(
1381                &ws.blas,
1382                &ws.stream,
1383                n,
1384                p,
1385                &shared.x_original_dev,
1386                w_hessian_dev,
1387                wx_dev_rh,
1388            )?;
1389            let n_i = to_i32(n)?;
1390            let p_i = to_i32(p)?;
1391            let gemm_cfg = GemmConfig::<f64> {
1392                transa: cublasOperation_t::CUBLAS_OP_T,
1393                transb: cublasOperation_t::CUBLAS_OP_N,
1394                m: p_i,
1395                n: p_i,
1396                k: n_i,
1397                alpha: 1.0,
1398                lda: n_i,
1399                ldb: n_i,
1400                beta: 0.0,
1401                ldc: p_i,
1402            };
1403            // SAFETY: validated dims; shared.x_original_dev and wx_dev_rh n*p
1404            // col-major; ws.xtwx_dev is p*p; all on ws.stream.
1405            unsafe {
1406                ws.blas.gemm(
1407                    gemm_cfg,
1408                    &shared.x_original_dev,
1409                    wx_dev_rh,
1410                    &mut ws.xtwx_dev,
1411                )
1412            }
1413            .map_err(|e| format!("cublas dgemm XtWX (final H rebuild): {e}"))?;
1414        } else {
1415            // Fused path: xtwx_lower + symmetrize, no n*p temp.
1416            launch_xtwx_lower(
1417                &ws.stream,
1418                &shared.ctx,
1419                n,
1420                p,
1421                &shared.x_original_dev,
1422                w_hessian_dev,
1423                &mut ws.xtwx_dev,
1424            )?;
1425            launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
1426        }
1427
1428        // H_final = Qsᵀ (XtWX) Qs + S + objective_ridge·I.
1429        let p_i = to_i32(p)?;
1430        // tmp = XtWX · Qs → ws.qs_tmp_dev.
1431        {
1432            let cfg_aq = GemmConfig::<f64> {
1433                transa: cublasOperation_t::CUBLAS_OP_N,
1434                transb: cublasOperation_t::CUBLAS_OP_N,
1435                m: p_i,
1436                n: p_i,
1437                k: p_i,
1438                alpha: 1.0,
1439                lda: p_i,
1440                ldb: p_i,
1441                beta: 0.0,
1442                ldc: p_i,
1443            };
1444            // SAFETY: xtwx_dev and qs_dev p*p col-major; qs_tmp_dev p*p output.
1445            unsafe {
1446                ws.blas
1447                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
1448            }
1449            .map_err(|e| format!("dgemm A·Qs (final H rebuild): {e}"))?;
1450        }
1451        // H_xtx = Qsᵀ · tmp → ws.h_dev.
1452        {
1453            let cfg_qt = GemmConfig::<f64> {
1454                transa: cublasOperation_t::CUBLAS_OP_T,
1455                transb: cublasOperation_t::CUBLAS_OP_N,
1456                m: p_i,
1457                n: p_i,
1458                k: p_i,
1459                alpha: 1.0,
1460                lda: p_i,
1461                ldb: p_i,
1462                beta: 0.0,
1463                ldc: p_i,
1464            };
1465            // SAFETY: qs_dev p*p (transposed); qs_tmp_dev p*p; h_dev p*p output.
1466            unsafe {
1467                ws.blas
1468                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
1469            }
1470            .map_err(|e| format!("dgemm Qsᵀ·A·Qs (final H rebuild): {e}"))?;
1471        }
1472        let penalty = penalty_with_ridge(penalty_hessian, objective_ridge);
1473        let penalty_col = to_col_major(&penalty);
1474        ws.stream
1475            .memcpy_htod(penalty_col.as_ref(), &mut ws.penalty_dev)
1476            .map_err(|e| format!("upload penalty (final H rebuild): {e}"))?;
1477        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
1478
1479        // One download — the only H transfer in the entire PIRLS loop.
1480        let h_col = ws
1481            .stream
1482            .clone_dtoh(&ws.h_dev)
1483            .map_err(|e| format!("download H_final: {e}"))?;
1484        from_col_major(&h_col, p, p).ok_or_else(|| "H_final layout conversion failed".to_string())
1485    }
1486
1487    pub(super) fn weighted_crossprod(
1488        x: ArrayView2<'_, f64>,
1489        weights: ArrayView1<'_, f64>,
1490    ) -> Result<Array2<f64>, String> {
1491        let (_, stream) = context_and_stream()?;
1492        let (n, p) = validate_design(x, weights)?;
1493        let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
1494        let x_col = to_col_major(&x);
1495        let x_dev = pinned_htod(&stream, &x_col)?;
1496        let mut w_dev = pinned_htod(
1497            &stream,
1498            weights.as_slice().ok_or("weights must be contiguous")?,
1499        )?;
1500        let mut wx_dev = stream
1501            .alloc_zeros::<f64>(n.checked_mul(p).ok_or("X size overflow")?)
1502            .map_err(|e| format!("cuda alloc WX: {e}"))?;
1503        left_scale_rows(&blas, &stream, n, p, &x_dev, &mut w_dev, &mut wx_dev)?;
1504        let mut h_dev = stream
1505            .alloc_zeros::<f64>(p.checked_mul(p).ok_or("H size overflow")?)
1506            .map_err(|e| format!("cuda alloc H: {e}"))?;
1507        let n_i = to_i32(n)?;
1508        let p_i = to_i32(p)?;
1509        let cfg = GemmConfig::<f64> {
1510            transa: cublasOperation_t::CUBLAS_OP_T,
1511            transb: cublasOperation_t::CUBLAS_OP_N,
1512            m: p_i,
1513            n: p_i,
1514            k: n_i,
1515            alpha: 1.0,
1516            lda: n_i,
1517            ldb: n_i,
1518            beta: 0.0,
1519            ldc: p_i,
1520        };
1521        // SAFETY: cuBLAS dgemm with validated i32 dimensions; x_dev/wx_dev are n*p f64 device
1522        // buffers and h_dev is the p*p output, all allocated above with matching sizes.
1523        unsafe { blas.gemm(cfg, &x_dev, &wx_dev, &mut h_dev) }
1524            .map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
1525        let h_col = stream
1526            .clone_dtoh(&h_dev)
1527            .map_err(|e| format!("download H: {e}"))?;
1528        from_col_major(&h_col, p, p).ok_or_else(|| "H layout conversion failed".to_string())
1529    }
1530
1531    pub(super) fn solve_step(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
1532        // One-shot path for the legacy single-step API: validate, build a
1533        // one-shot shared+workspace, run a single step, drop. This routes
1534        // through `solve_step_on_stream` so there is exactly one math path
1535        // for both the batch-mode cubature executor and the single-step
1536        // test/bench surface.
1537        let (_, p) = validate_design(input.x, input.weights)?;
1538        if input.penalty_hessian.dim() != (p, p) {
1539            return Err(format!(
1540                "penalty Hessian shape {:?} does not match p={p}",
1541                input.penalty_hessian.dim()
1542            ));
1543        }
1544        if input.gradient.len() != p {
1545            return Err(format!(
1546                "gradient length {} does not match p={p}",
1547                input.gradient.len()
1548            ));
1549        }
1550        // The legacy single-step API has no GLM data — `solve_step_on_stream`
1551        // (which this dispatches to) only reads `shared.x_original_dev`.
1552        // The shared upload requires y/prior_w/offset for the loop paths, so
1553        // pass zero placeholders sized to the design's row count; they are
1554        // never read by the one-shot Newton step path.
1555        let n_rows = input.x.nrows();
1556        let zero_n = ndarray::Array1::<f64>::zeros(n_rows);
1557        let shared =
1558            PirlsGpuSharedData::upload_impl(input.x, zero_n.view(), zero_n.view(), zero_n.view())?;
1559        let mut ws = SigmaPirlsGpuWorkspace::allocate_impl(&shared)?;
1560        solve_step_on_stream(
1561            &shared,
1562            &mut ws,
1563            PirlsStepStreamInput {
1564                weights: input.weights,
1565                penalty_hessian: input.penalty_hessian,
1566                gradient: input.gradient,
1567                step_lm_lambda: input.step_lm_lambda,
1568                objective_ridge: input.objective_ridge,
1569            },
1570        )
1571    }
1572
1573    fn validate_design(
1574        x: ArrayView2<'_, f64>,
1575        weights: ArrayView1<'_, f64>,
1576    ) -> Result<(usize, usize), String> {
1577        let (n, p) = x.dim();
1578        if weights.len() != n {
1579            return Err(format!(
1580                "weights length {} does not match rows {n}",
1581                weights.len()
1582            ));
1583        }
1584        if n == 0 || p == 0 {
1585            return Err("empty design cannot be solved on CUDA".to_string());
1586        }
1587        Ok((n, p))
1588    }
1589
    /// Row-scale a device design matrix: `wx := diag(w) · x`.
    ///
    /// This is a thin wrapper over `cublasDdgmm` in `CUBLAS_SIDE_LEFT` mode.
    /// - `x_dev` and `wx_dev` are `n × p` column-major buffers with leading
    ///   dimension `n`.
    /// - `w_dev` supplies the `n` diagonal weights, read with unit stride.
    ///
    /// The call is issued on `blas`'s handle. `stream` is used here only to
    /// obtain the device pointers.
    ///
    /// `w_dev` is taken as `&mut` but is only ever read: it goes through
    /// `device_ptr`, not `device_ptr_mut`. `left_scale_rows_borrowed` is the
    /// shared-reference variant.
    ///
    /// # Errors
    /// Fails if `n` or `p` exceeds `i32`, or if cuBLAS returns a non-success
    /// status.
    fn left_scale_rows(
        blas: &CudaBlas,
        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
        n: usize,
        p: usize,
        x_dev: &CudaSlice<f64>,
        w_dev: &mut CudaSlice<f64>,
        wx_dev: &mut CudaSlice<f64>,
    ) -> Result<(), String> {
        let n_i = to_i32(n)?;
        let p_i = to_i32(p)?;
        let handle = *blas.handle();
        // The `_*_record` guards are bound to named variables rather than `_`,
        // so they live until the end of this function, i.e. past the FFI call.
        // NOTE(review): presumably they record the buffers' use on `stream`
        // when dropped (cudarc `device_ptr` contract); confirm.
        let (x_ptr, _x_record) = x_dev.device_ptr(stream);
        let (w_ptr, _w_record) = w_dev.device_ptr(stream);
        let (wx_ptr, _wx_record) = wx_dev.device_ptr_mut(stream);
        // SAFETY: FFI call into cuBLAS; pointers come from live CudaSlice device buffers sized
        // n*p (x, wx) and n (w), leading dims match column-major layout, handle is valid.
        let status = unsafe {
            cublasDdgmm(
                handle,
                cublasSideMode_t::CUBLAS_SIDE_LEFT,
                n_i,
                p_i,
                x_ptr as *const f64,
                n_i,
                w_ptr as *const f64,
                1,
                wx_ptr as *mut f64,
                n_i,
            )
        };
        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            Ok(())
        } else {
            Err(format!("cublasDdgmm failed with {status:?}"))
        }
    }
1627
    /// Borrowed-input variant of [`left_scale_rows`] used by the Stage 3.2
    /// device-input PIRLS step. Reads weights through `&CudaSlice` so the
    /// caller can keep ownership of the row-reweight buffer across the
    /// PIRLS iteration without an extra device-side copy.
    ///
    /// Computes `wx := diag(w) · x` with `cublasDdgmm` (`CUBLAS_SIDE_LEFT`).
    /// `x_dev` and `wx_dev` are `n × p` column-major with leading dimension
    /// `n`, and `w_dev` has length `n`. The math matches [`left_scale_rows`];
    /// only the error message differs ("(borrowed)").
    ///
    /// # Errors
    /// Fails if `n` or `p` exceeds `i32`, or if cuBLAS returns a non-success
    /// status.
    fn left_scale_rows_borrowed(
        blas: &CudaBlas,
        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
        n: usize,
        p: usize,
        x_dev: &CudaSlice<f64>,
        w_dev: &CudaSlice<f64>,
        wx_dev: &mut CudaSlice<f64>,
    ) -> Result<(), String> {
        let n_i = to_i32(n)?;
        let p_i = to_i32(p)?;
        let handle = *blas.handle();
        // Guards are bound to named variables (not `_`) so they stay alive
        // until the end of the function, i.e. past the FFI call below.
        let (x_ptr, _x_record) = x_dev.device_ptr(stream);
        let (w_ptr, _w_record) = w_dev.device_ptr(stream);
        let (wx_ptr, _wx_record) = wx_dev.device_ptr_mut(stream);
        // SAFETY: FFI call into cuBLAS; pointers come from live CudaSlice
        // device buffers; x is n*p col-major (lda = n), w is length n
        // (stride 1), wx is n*p output (lda = n). Caller-owned w buffer
        // is borrowed read-only here, matching cublasDdgmm's contract.
        let status = unsafe {
            cublasDdgmm(
                handle,
                cublasSideMode_t::CUBLAS_SIDE_LEFT,
                n_i,
                p_i,
                x_ptr as *const f64,
                n_i,
                w_ptr as *const f64,
                1,
                wx_ptr as *mut f64,
                n_i,
            )
        };
        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            Ok(())
        } else {
            Err(format!("cublasDdgmm (borrowed) failed with {status:?}"))
        }
    }
1671
1672    // In-place `a := a + b` for two `p*p` column-major device buffers via
1673    // cublasDgeam. The C API explicitly permits `C = A` (output aliasing the
1674    // first input), but Rust's borrow checker cannot prove that — every
1675    // caller historically passed `&ws.h_dev, &ws.penalty_dev, &mut ws.h_dev`
1676    // and ran into E0502. Forcing the in-place semantics into the wrapper
1677    // signature makes the contract explicit and removes the aliasing-borrow
1678    // class of errors at the call sites.
1679    fn geam_add_inplace(
1680        blas: &CudaBlas,
1681        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1682        p: usize,
1683        a: &mut CudaSlice<f64>,
1684        b: &CudaSlice<f64>,
1685    ) -> Result<(), String> {
1686        let p_i = to_i32(p)?;
1687        let alpha = 1.0_f64;
1688        let beta = 1.0_f64;
1689        let handle = *blas.handle();
1690        let (b_ptr, _b_record) = b.device_ptr(stream);
1691        let (a_ptr, _a_record) = a.device_ptr_mut(stream);
1692        // cublasDgeam with C == A is allowed and computes `A := alpha*A + beta*B`.
1693        let out_ptr = a_ptr;
1694        // SAFETY: FFI call into cuBLAS geam; a, b, out are live p*p device buffers in column-major
1695        // with leading dim p_i, scalars live on host stack, handle is valid.
1696        let status = unsafe {
1697            cublasDgeam(
1698                handle,
1699                cublasOperation_t::CUBLAS_OP_N,
1700                cublasOperation_t::CUBLAS_OP_N,
1701                p_i,
1702                p_i,
1703                &alpha,
1704                a_ptr as *const f64,
1705                p_i,
1706                &beta,
1707                b_ptr as *const f64,
1708                p_i,
1709                out_ptr as *mut f64,
1710                p_i,
1711            )
1712        };
1713        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1714            Ok(())
1715        } else {
1716            Err(format!("cublasDgeam failed with {status:?}"))
1717        }
1718    }
1719
1720    /// Launch the `xtwx_lower` kernel: one thread per lower-tri pair `(j,k)`,
1721    /// iterates over all `n` rows and writes `A[j + k*p]` (col-major lower
1722    /// triangle of `XᵀWX`). Call `launch_symmetrize_lower` afterwards.
1723    fn launch_xtwx_lower(
1724        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1725        ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1726        n: usize,
1727        p: usize,
1728        x_dev: &CudaSlice<f64>,
1729        w_dev: &CudaSlice<f64>,
1730        a_dev: &mut CudaSlice<f64>,
1731    ) -> Result<(), String> {
1732        let module = FUSED_XTWX_CACHE
1733            .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1734            .map_err(|e| format!("fused_xtwx module: {e}"))?;
1735        let func = module
1736            .load_function("xtwx_lower")
1737            .map_err(|e| format!("load xtwx_lower: {e}"))?;
1738        let n_i = to_i32(n)?;
1739        let p_i = to_i32(p)?;
1740        let num_pairs = p * (p + 1) / 2;
1741        let num_pairs_u32 = u32::try_from(num_pairs)
1742            .map_err(|_| format!("xtwx_lower: num_pairs {num_pairs} > u32"))?;
1743        const BLOCK: u32 = 256;
1744        let grid = num_pairs_u32.div_ceil(BLOCK).max(1);
1745        let cfg = cudarc::driver::LaunchConfig {
1746            grid_dim: (grid, 1, 1),
1747            block_dim: (BLOCK, 1, 1),
1748            shared_mem_bytes: 0,
1749        };
1750        let mut builder = stream.launch_builder(&func);
1751        builder.arg(x_dev);
1752        builder.arg(w_dev);
1753        builder.arg(a_dev);
1754        builder.arg(&n_i);
1755        builder.arg(&p_i);
1756        // SAFETY: x_dev is n*p col-major f64; w_dev is length n; a_dev is p*p;
1757        // num_pairs threads each write one lower-tri entry A[j + k*p].
1758        unsafe { builder.launch(cfg) }
1759            .map_err(|e| format!("xtwx_lower launch: {e}"))
1760            .map(|_| ())
1761    }
1762
1763    /// Launch the `xtscore` kernel: one thread per output index `j`,
1764    /// iterates over `n` rows and writes `s[j] = sum_i score[i]*X[i,j]`.
1765    fn launch_xtscore(
1766        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1767        ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1768        n: usize,
1769        p: usize,
1770        x_dev: &CudaSlice<f64>,
1771        score_dev: &CudaSlice<f64>,
1772        s_dev: &mut CudaSlice<f64>,
1773    ) -> Result<(), String> {
1774        let module = FUSED_XTWX_CACHE
1775            .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1776            .map_err(|e| format!("fused_xtwx module (xtscore): {e}"))?;
1777        let func = module
1778            .load_function("xtscore")
1779            .map_err(|e| format!("load xtscore: {e}"))?;
1780        let n_i = to_i32(n)?;
1781        let p_i = to_i32(p)?;
1782        let p_u32 = u32::try_from(p).map_err(|_| format!("xtscore: p {p} > u32"))?;
1783        const BLOCK: u32 = 256;
1784        let grid = p_u32.div_ceil(BLOCK).max(1);
1785        let cfg = cudarc::driver::LaunchConfig {
1786            grid_dim: (grid, 1, 1),
1787            block_dim: (BLOCK, 1, 1),
1788            shared_mem_bytes: 0,
1789        };
1790        let mut builder = stream.launch_builder(&func);
1791        builder.arg(x_dev);
1792        builder.arg(score_dev);
1793        builder.arg(s_dev);
1794        builder.arg(&n_i);
1795        builder.arg(&p_i);
1796        // SAFETY: x_dev is n*p col-major f64; score_dev is length n; s_dev is length p;
1797        // p threads each write one output entry s[j].
1798        unsafe { builder.launch(cfg) }
1799            .map_err(|e| format!("xtscore launch: {e}"))
1800            .map(|_| ())
1801    }
1802
1803    /// Launch the `symmetrize_lower` kernel: one thread per strict lower-tri
1804    /// pair `(j,k)` with `j > k`; copies `A[k + j*p] = A[j + k*p]` to fill
1805    /// the upper triangle from the lower triangle populated by `xtwx_lower`.
1806    fn launch_symmetrize_lower(
1807        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1808        ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1809        p: usize,
1810        a_dev: &mut CudaSlice<f64>,
1811    ) -> Result<(), String> {
1812        if p <= 1 {
1813            return Ok(());
1814        }
1815        let module = FUSED_XTWX_CACHE
1816            .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
1817            .map_err(|e| format!("fused_xtwx module (sym): {e}"))?;
1818        let func = module
1819            .load_function("symmetrize_lower")
1820            .map_err(|e| format!("load symmetrize_lower: {e}"))?;
1821        let p_i = to_i32(p)?;
1822        let num_strict = p * (p - 1) / 2;
1823        let num_strict_u32 = u32::try_from(num_strict)
1824            .map_err(|_| format!("symmetrize_lower: num_strict {num_strict} > u32"))?;
1825        const BLOCK: u32 = 256;
1826        let grid = num_strict_u32.div_ceil(BLOCK).max(1);
1827        let cfg = cudarc::driver::LaunchConfig {
1828            grid_dim: (grid, 1, 1),
1829            block_dim: (BLOCK, 1, 1),
1830            shared_mem_bytes: 0,
1831        };
1832        let mut builder = stream.launch_builder(&func);
1833        builder.arg(a_dev);
1834        builder.arg(&p_i);
1835        // SAFETY: a_dev is p*p col-major f64; each of the num_strict threads
1836        // writes one upper-triangle entry mirrored from the lower triangle.
1837        unsafe { builder.launch(cfg) }
1838            .map_err(|e| format!("symmetrize_lower launch: {e}"))
1839            .map(|_| ())
1840    }
1841
1842    /// Launch the device-side Cholesky-factor logdet kernel and download
1843    /// the single scalar result. Replaces the per-step p² host download of
1844    /// the Cholesky factor that the host-side `cholesky_logdet_from_col_major`
1845    /// required.
1846    fn cholesky_logdet_device(
1847        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
1848        ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
1849        p: usize,
1850        factor_dev: &CudaSlice<f64>,
1851    ) -> Result<f64, String> {
1852        let module = CHOL_LOGDET_CACHE
1853            .get_or_compile(ctx, "pirls_gpu_chol_logdet", CHOL_LOGDET_PTX_SOURCE)
1854            .map_err(|err| format!("chol_logdet module: {err}"))?;
1855        let func = module
1856            .load_function("chol_logdet_col_major")
1857            .map_err(|err| format!("chol_logdet load_function: {err}"))?;
1858        let mut out_dev = stream
1859            .alloc_zeros::<f64>(1)
1860            .map_err(|err| format!("alloc chol_logdet out: {err}"))?;
1861        let p_i = to_i32(p)?;
1862        let cfg = LaunchConfig {
1863            grid_dim: (1, 1, 1),
1864            block_dim: (1, 1, 1),
1865            shared_mem_bytes: 0,
1866        };
1867        let mut builder = stream.launch_builder(&func);
1868        builder.arg(factor_dev);
1869        builder.arg(&p_i);
1870        builder.arg(&mut out_dev);
1871        // SAFETY: serial single-thread kernel reading `p` f64 diagonal
1872        // entries from a live p*p column-major factor and writing one f64
1873        // to `out_dev`; no aliasing, no oob — `p` matches the device buffer
1874        // shape every caller passes in.
1875        unsafe { builder.launch(cfg) }.map_err(|err| format!("chol_logdet launch: {err}"))?;
1876        let out_host = stream
1877            .clone_dtoh(&out_dev)
1878            .map_err(|err| format!("download chol_logdet: {err}"))?;
1879        Ok(out_host[0])
1880    }
1881
1882    fn penalty_with_ridge(penalty: ArrayView2<'_, f64>, ridge: f64) -> Array2<f64> {
1883        let mut out = penalty.to_owned();
1884        if ridge != 0.0 {
1885            for i in 0..out.nrows().min(out.ncols()) {
1886                out[[i, i]] += ridge;
1887            }
1888        }
1889        out
1890    }
1891
1892    fn to_i32(value: usize) -> Result<i32, String> {
1893        i32::try_from(value).map_err(|_| format!("CUDA dimension {value} exceeds i32"))
1894    }
1895
1896    // ────────────────────────────────────────────────────────────────────
1897    // Stage 3.3: full device-resident PIRLS loop driver
1898    // ────────────────────────────────────────────────────────────────────
1899
1900    /// Bundled NVRTC helpers for the Stage 3.3 loop driver: axpy +
1901    /// single-block sum / linf reductions. Cached process-wide.
1902    const PIRLS_LOOP_PTX_SOURCE: &str = r#"
1903extern "C" {
1904    double fabs(double);
1905}
1906
1907extern "C" __global__ void axpy_n(
1908    double alpha,
1909    const double* __restrict__ x,
1910    double* __restrict__ y,
1911    int n
1912) {
1913    int i = blockIdx.x * blockDim.x + threadIdx.x;
1914    if (i >= n) return;
1915    y[i] += alpha * x[i];
1916}
1917
1918extern "C" __global__ void deviance_sum(
1919    const double* __restrict__ d,
1920    int n,
1921    double* __restrict__ out
1922) {
1923    __shared__ double sm[1024];
1924    int tid = threadIdx.x;
1925    int bdim = blockDim.x;
1926    double acc = 0.0;
1927    for (int i = tid; i < n; i += bdim) {
1928        acc += d[i];
1929    }
1930    sm[tid] = acc;
1931    __syncthreads();
1932    for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1933        if (tid < stride) sm[tid] += sm[tid + stride];
1934        __syncthreads();
1935    }
1936    if (tid == 0) out[0] = sm[0];
1937}
1938
1939extern "C" __global__ void linf_norm(
1940    const double* __restrict__ v,
1941    int p,
1942    double* __restrict__ out
1943) {
1944    __shared__ double sm[1024];
1945    int tid = threadIdx.x;
1946    int bdim = blockDim.x;
1947    double acc = 0.0;
1948    for (int i = tid; i < p; i += bdim) {
1949        double a = fabs(v[i]);
1950        if (a > acc) acc = a;
1951    }
1952    sm[tid] = acc;
1953    __syncthreads();
1954    for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1955        if (tid < stride) {
1956            double r = sm[tid + stride];
1957            if (r > sm[tid]) sm[tid] = r;
1958        }
1959        __syncthreads();
1960    }
1961    if (tid == 0) out[0] = sm[0];
1962}
1963
1964extern "C" __global__ void negate_n(
1965    double* __restrict__ v,
1966    int n
1967) {
1968    int i = blockIdx.x * blockDim.x + threadIdx.x;
1969    if (i >= n) return;
1970    v[i] = -v[i];
1971}
1972
1973// OR-reduction over a u32 status array (length n).  Single-block;
1974// same launch config as deviance_sum (1 block of 1024 threads).
1975// out[0] receives the bitwise-OR of all status[i] for i in [0, n).
1976extern "C" __global__ void status_or(
1977    const unsigned int* __restrict__ status,
1978    int n,
1979    unsigned int* __restrict__ out
1980) {
1981    __shared__ unsigned int sm[1024];
1982    int tid = threadIdx.x;
1983    int bdim = blockDim.x;
1984    unsigned int acc = 0u;
1985    for (int i = tid; i < n; i += bdim) {
1986        acc |= status[i];
1987    }
1988    sm[tid] = acc;
1989    __syncthreads();
1990    for (int stride = bdim / 2; stride > 0; stride >>= 1) {
1991        if (tid < stride) sm[tid] |= sm[tid + stride];
1992        __syncthreads();
1993    }
1994    if (tid == 0) out[0] = sm[0];
1995}
1996"#;
1997
1998    static PIRLS_LOOP_CACHE: PtxModuleCache = PtxModuleCache::new();
1999
2000    /// Per-fit device workspace for the Stage 3.3 PIRLS loop driver.
2001    ///
2002    /// Three row-kernel modes occupy separate device buffers:
2003    /// - `row_solve`: solve-row (4 fields), refreshed each Newton iteration.
2004    /// - `alpha_ladder`: candidate-objective (objective[7] + status[7]).
2005    /// - `row_final`: final-row (9 fields), written once at convergence.
2006    pub struct PirlsLoopWorkspace {
2007        pub beta_dev: CudaSlice<f64>,
2008        pub eta_dev: CudaSlice<f64>,
2009        /// Solve-row buffers: `grad_eta`, `w_solver`, `deviance`, `status`.
2010        pub row_solve: crate::gpu_kernels::pirls_row::SolveRowBuffers,
2011        /// Alpha-ladder buffers: `objective[7]`, `status[7]`.
2012        pub alpha_ladder: crate::gpu_kernels::pirls_row::AlphaLadderDevBuffers,
2013        /// Full final-row buffers: all 9 fields, written once at convergence.
2014        pub row_final: crate::gpu_kernels::pirls_row::RowOutputDevBuffers,
2015        pub direction_dev: CudaSlice<f64>,
2016        pub xd_dev: CudaSlice<f64>,
2017        pub scalar_dev: CudaSlice<f64>,
2018        /// Single-element u32 for the `status_or` OR-reduction kernel.
2019        pub status_u32_dev: CudaSlice<u32>,
2020        pub n: usize,
2021        pub p: usize,
2022    }
2023
2024    impl PirlsLoopWorkspace {
2025        pub fn allocate(
2026            shared: &PirlsGpuSharedData,
2027            stream: &std::sync::Arc<cudarc::driver::CudaStream>,
2028        ) -> Result<Self, String> {
2029            let n = shared.n;
2030            let p = shared.p;
2031            let alloc_f64 = |label: &'static str, len: usize| {
2032                stream
2033                    .alloc_zeros::<f64>(len)
2034                    .map_err(|e| format!("pirls loop alloc {label}: {e}"))
2035            };
2036            Ok(Self {
2037                beta_dev: alloc_f64("beta", p)?,
2038                eta_dev: alloc_f64("eta", n)?,
2039                row_solve: crate::gpu_kernels::pirls_row::SolveRowBuffers::allocate(stream, n)
2040                    .map_err(|e| format!("pirls loop alloc row_solve: {e}"))?,
2041                alpha_ladder: crate::gpu_kernels::pirls_row::AlphaLadderDevBuffers::allocate(
2042                    stream,
2043                )
2044                .map_err(|e| format!("pirls loop alloc alpha_ladder: {e}"))?,
2045                row_final: crate::gpu_kernels::pirls_row::RowOutputDevBuffers::allocate(stream, n)
2046                    .map_err(|e| format!("pirls loop alloc row_final: {e}"))?,
2047                direction_dev: alloc_f64("direction", p)?,
2048                xd_dev: alloc_f64("xd", n)?,
2049                scalar_dev: alloc_f64("scalar", 1)?,
2050                status_u32_dev: stream
2051                    .alloc_zeros::<u32>(1)
2052                    .map_err(|e| format!("pirls loop alloc status_u32: {e}"))?,
2053                n,
2054                p,
2055            })
2056        }
2057    }
2058
2059    /// Optional host-side inputs that turn the bare GPU loop result
2060    /// into a full-surface `PirlsLoopOutcome` matching the CPU oracle
2061    /// `fit_model_for_fixed_rho_with_adaptive_kkt`.
2062    ///
2063    /// When supplied, the postpass at loop exit runs the same host-side
2064    /// helpers the CPU oracle uses
2065    /// (`computeworkingweight_derivatives_from_eta`,
2066    /// `compute_observed_hessian_curvature_arrays`,
2067    /// `compute_constraint_kkt_diagnostics`) so the dispatch wirer can
2068    /// plumb every field of `PirlsResult` without doing math.
2069    ///
2070    /// When `None`, the derived fields on `PirlsLoopOutcome`
2071    /// (`finalweights`, `solveweights`, `solve_dmu_deta`,
2072    /// `solve_d2mu_deta2`, `solve_d3mu_deta3`, `solve_c_array`,
2073    /// `solve_d_array`, `status`, `constraint_kkt`, `ridge_passport`,
2074    /// `firth`, `edf`, `beta_transformed`, `derivatives_unsupported`)
2075    /// take safe defaults: empty arrays, `PirlsStatus::Converged` or
2076    /// `MaxIterationsReached` reflecting `converged`, no KKT
2077    /// diagnostics, identity ridge with `objective_ridge` magnitude,
2078    /// `FirthDiagnostics::Inactive`, `edf = NaN`,
2079    /// `beta_transformed = beta`, `derivatives_unsupported = true`.
2080    /// Existing callers that do not need the CPU oracle surface can
2081    /// pass `None` and ignore the derived fields.
2082    pub struct PirlsLoopExtra<'a> {
2083        /// GLM likelihood spec the row kernel was driven by. Needed by
2084        /// `computeworkingweight_derivatives_from_eta` to produce
2085        /// `solve_dmu_deta` / `solve_d2mu_deta2` / `solve_d3mu_deta3`
2086        /// and the score-side `c` / `d` arrays.
2087        pub likelihood: &'a gam_problem::GlmLikelihoodSpec,
2088        /// Inverse link the row kernel was driven by; pairs with
2089        /// `likelihood` for the family-specific derivatives.
2090        pub inverse_link: &'a gam_problem::InverseLink,
2091        /// Response vector `y` (length `n`) — same view passed to the
2092        /// row kernel. Needed for observed-curvature finalization.
2093        pub y: ndarray::ArrayView1<'a, f64>,
2094        /// Prior weights (length `n`) — same view passed to the row
2095        /// kernel. Carried through to the curvature helpers.
2096        pub priorweights: ndarray::ArrayView1<'a, f64>,
2097        /// Observation offset (length `n`). Stored verbatim on the
2098        /// outcome's `final_offset` so the dispatch wirer can populate
2099        /// `PirlsResult::final_offset` without re-allocating.
2100        pub offset: ndarray::ArrayView1<'a, f64>,
2101        /// Linear inequality constraints `A·β ≥ b` in the same
2102        /// coordinate frame as the GPU loop's β. When `Some`, the
2103        /// postpass calls `compute_constraint_kkt_diagnostics` on the
2104        /// converged β + reconstructed penalised gradient and emits
2105        /// the result on `PirlsLoopOutcome::constraint_kkt`. When
2106        /// `None`, no diagnostics are produced.
2107        pub linear_constraints: Option<&'a gam_problem::LinearInequalityConstraints>,
2108        /// Curvature surface the *outer* REML / LAML caller expects on
2109        /// the returned Hessian. The GPU loop runs under whatever
2110        /// `curvature: CurvatureMode` it was invoked with; if this
2111        /// differs (e.g. inner loop ran Fisher for stability but the
2112        /// outer caller demands observed curvature), the postpass
2113        /// promotes `finalweights` / `solve_c_array` / `solve_d_array`
2114        /// via `compute_observed_hessian_curvature_arrays` so the
2115        /// outcome matches the CPU oracle's `exported_laplace_curvature`
2116        /// contract.
2117        pub exported_curvature: crate::pirls::HessianCurvatureKind,
2118        /// Pre-built ridge passport carrying the stabilization
2119        /// magnitude + policy that the dispatch wirer wants stamped on
2120        /// `PirlsResult::ridge_passport`. When `None`, the postpass
2121        /// uses `RidgePassport::scaled_identity(objective_ridge,
2122        /// RidgePolicy::explicit_stabilization_full())`, which mirrors
2123        /// the CPU oracle's default for a no-escalation fit.
2124        pub ridge_passport: Option<gam_problem::RidgePassport>,
2125        /// Firth bias-reduction diagnostics. Today the GPU loop does
2126        /// not implement Firth; pass `None` to land
2127        /// `FirthDiagnostics::Inactive` on the outcome. A future
2128        /// device-side Firth path would populate this with the active
2129        /// Jeffreys-logdet + hat-diagonal vector.
2130        pub firth: Option<crate::pirls::FirthDiagnostics>,
2131        /// Canonical-basis transform `qs` (size `p × p`) that maps
2132        /// transformed-basis β to original coordinates via
2133        /// `beta_original = qs · beta_transformed`. Carried on the
2134        /// struct for callers that need original-coordinate β; the
2135        /// postpass does **not** apply `qs` to the loop's β because
2136        /// the GPU loop already solved in the transformed design
2137        /// `X·Qs`, so the loop's β *is* `beta_transformed`. When
2138        /// `None`, no reparameterization is active and transformed
2139        /// and original coordinates coincide.
2140        pub qs: Option<ndarray::ArrayView2<'a, f64>>,
2141        /// Effective degrees of freedom at the converged mode, when
2142        /// the dispatch wirer has it precomputed (typical case: the
2143        /// outer REML caller passes its own `e_transformed` /
2144        /// diagonal-penalty pre-image and computes EDF host-side).
2145        /// When `None`, the postpass emits `f64::NAN` and sets
2146        /// `derivatives_unsupported = true` — the dispatch wirer can
2147        /// then compute EDF itself from `penalized_hessian` and the
2148        /// caller-side penalty root.
2149        pub edf: Option<f64>,
2150    }
2151
2152    #[derive(Clone, Debug)]
2153    pub struct PirlsLoopOutcome {
2154        pub beta: Array1<f64>,
2155        pub penalized_hessian: Array2<f64>,
2156        pub logdet: f64,
2157        pub deviance: f64,
2158        pub iterations: usize,
2159        pub converged: bool,
2160        /// Final linear predictor η = X·β at the accepted PIRLS step
2161        /// (length `n`). Downloaded once at loop exit.
2162        pub final_eta: Array1<f64>,
2163        /// Mean response μ = g⁻¹(η) at the accepted step, length `n`.
2164        /// Maps to `PirlsResult::finalmu` / `solvemu`.
2165        pub final_mu: Array1<f64>,
2166        /// Score-side gradient contribution `∂ℓ/∂η_i` at the accepted
2167        /// step (length `n`). The CPU oracle uses this to form
2168        /// `score_norm = ‖Xᵀ grad_eta‖₂`.
2169        pub final_grad_eta: Array1<f64>,
2170        /// Hessian-side diagonal working weight `w_hessian_i` at the
2171        /// accepted step. Maps to `PirlsResult::finalweights` when no
2172        /// observed-curvature promotion is requested.
2173        pub final_w_hessian: Array1<f64>,
2174        /// Score-side diagonal working weight `w_solver_i` at the
2175        /// accepted step. Maps to `PirlsResult::solveweights`.
2176        pub final_w_solver: Array1<f64>,
2177        /// Observation offset (length `n`). Echoed from
2178        /// `PirlsLoopExtra::offset` when supplied, otherwise an empty
2179        /// array. Maps to `PirlsResult::final_offset`.
2180        pub final_offset: Array1<f64>,
2181        /// β in the canonical transformed basis. Always equals
2182        /// `beta` because the GPU loop solved in the transformed
2183        /// design `X·Qs`, so the loop's β is already transformed.
2184        /// Maps to `PirlsResult::beta_transformed`.
2185        pub beta_transformed: Array1<f64>,
2186        /// Hessian-side `finalweights` after optional Fisher→observed
2187        /// promotion driven by `extra.exported_curvature`. Empty when
2188        /// `extra` is `None`.
2189        pub finalweights: Array1<f64>,
2190        /// Score-side `solveweights` (= `final_w_solver`) echoed
2191        /// through so the dispatch wirer can stamp directly.
2192        pub solveweights: Array1<f64>,
2193        /// Solve-side `dμ/dη` at the converged η, family-specific.
2194        /// From `computeworkingweight_derivatives_from_eta`. Empty
2195        /// when `extra` is `None`.
2196        pub solve_dmu_deta: Array1<f64>,
2197        /// Solve-side `d²μ/dη²`. Empty when `extra` is `None`.
2198        pub solve_d2mu_deta2: Array1<f64>,
2199        /// Solve-side `d³μ/dη³`. Empty when `extra` is `None`.
2200        pub solve_d3mu_deta3: Array1<f64>,
2201        /// `c_i = dW_i/dη_i` at the converged mode (Fisher or
2202        /// observed depending on `extra.exported_curvature`). Maps to
2203        /// `PirlsResult::solve_c_array`. Empty when `extra` is `None`.
2204        pub solve_c_array: Array1<f64>,
2205        /// `d_i = d²W_i/dη_i²`. Maps to `PirlsResult::solve_d_array`.
2206        /// Empty when `extra` is `None`.
2207        pub solve_d_array: Array1<f64>,
2208        /// `true` when the family's analytic 3rd/4th derivatives are
2209        /// not supported and the c/d arrays are placeholders. Mirrors
2210        /// `PirlsResult::derivatives_unsupported`.
2211        pub derivatives_unsupported: bool,
2212        /// PirlsStatus the dispatch wirer should propagate. Emitted as
2213        /// `Converged` when the loop's tolerance test passed and
2214        /// `final_eta`/`final_mu` are finite; `Unstable` when any of
2215        /// those go non-finite; `MaxIterationsReached` when the loop
2216        /// hit its iteration cap without converging.
2217        pub status: crate::pirls::PirlsStatus,
2218        /// Ridge passport carrying the stabilization δ and policy.
2219        /// When `extra.ridge_passport` is `Some`, this is the supplied
2220        /// value verbatim. Otherwise a default `scaled_identity(
2221        /// objective_ridge, explicit_stabilization_full())` passport.
2222        pub ridge_passport: gam_problem::RidgePassport,
2223        /// Firth diagnostics. `Inactive` unless the caller passes an
2224        /// `Active` value through `extra.firth`.
2225        pub firth: crate::pirls::FirthDiagnostics,
2226        /// KKT diagnostics for `extra.linear_constraints`. `None`
2227        /// either when no constraints are supplied or when the
2228        /// constraint system is empty.
2229        pub constraint_kkt: Option<crate::active_set::ConstraintKktDiagnostics>,
2230        /// Effective degrees of freedom. Echoed from `extra.edf`;
2231        /// `f64::NAN` when not supplied.
2232        pub edf: f64,
2233        /// `prev_deviance − accepted_deviance` at the accepted step
2234        /// that terminated the loop. Matches the CPU oracle's
2235        /// `WorkingModelPirlsResult::last_deviance_change`.
2236        pub last_deviance_change: f64,
2237        /// Number of line-search halvings consumed on the accepted
2238        /// step (`k` when α = `0.5^k`; `0` when α = 1). When the
2239        /// ladder was fully exhausted (`step_search_exhausted`), this
2240        /// is `0` and `last_step_size = 0.0` — no step was committed.
2241        /// Mirrors `WorkingModelPirlsResult::last_step_halving`.
2242        pub last_step_halving: usize,
2243        /// Step size α that was accepted at the final iteration.
2244        /// Mirrors `WorkingModelPirlsResult::last_step_size`.
2245        pub last_step_size: f64,
2246        /// Levenberg-Marquardt damping coefficient (step_lm_lambda) in
2247        /// effect at the last accepted iter. The GPU loop has no
2248        /// on-device ridge escalation (it is a constant per call), so
2249        /// this echoes the input `step_lm_lambda`. Maps to
2250        /// `PirlsResult::final_lm_lambda`.
2251        pub final_lm_lambda: f64,
2252        /// Running minimum of the data-side deviance observed across
2253        /// all accepted Newton steps. The GPU loop only knows the
2254        /// data deviance device-side; the dispatch wirer can add
2255        /// `βᵀ·penalty_hessian·β` at the converged β to obtain the
2256        /// fully penalised running minimum when needed for
2257        /// `PirlsResult::min_penalized_deviance`.
2258        pub min_deviance: f64,
2259        /// `max_i |η_i|` at the accepted final step — the saturation
2260        /// diagnostic the CPU oracle stamps on
2261        /// `PirlsResult::max_abs_eta`. Used by REML's
2262        /// perfect-separation detection.
2263        pub max_abs_eta: f64,
2264        /// Bitwise-OR of all per-row status flags across the n rows at
2265        /// the final accepted PIRLS step. Carries
2266        /// [`crate::gpu_kernels::pirls_row::status_flags`] bits so callers can
2267        /// distinguish saturation (`ETA_CLAMPED`), numerical floor
2268        /// (`MU_FLOORED`), or invalid input (`INVALID_RESPONSE`,
2269        /// `ZERO_PRIOR_WEIGHT`). A value of 0 means no per-row
2270        /// anomaly was detected. Contributes to the `Unstable`
2271        /// classification when forbidden bits are set.
2272        pub per_row_status_or: u32,
2273    }
2274
2275    /// Full device-resident PIRLS loop. Only three scalar (1 f64)
2276    /// downloads per Newton iter (deviance, direction-L∞, candidate
2277    /// deviance per α). β + final H downloaded once at exit.
2278    pub(super) fn pirls_loop(
2279        shared: &PirlsGpuSharedData,
2280        ws: &mut SigmaPirlsGpuWorkspace,
2281        loop_ws: &mut PirlsLoopWorkspace,
2282        family: crate::gpu_kernels::pirls_row::PirlsRowFamily,
2283        curvature: crate::gpu_kernels::pirls_row::CurvatureMode,
2284        // Active Gamma dispersion shape (α > 0). Forwarded to every
2285        // `launch_row_reweight_on_stream` call. Pass `1.0` for non-Gamma fits.
2286        gamma_shape: f64,
2287        beta0_host: ArrayView1<'_, f64>,
2288        penalty_hessian: ArrayView2<'_, f64>,
2289        // Linear shift `b` of the shifted-quadratic penalty
2290        // `βᵀSβ − 2βᵀb + c`. Length `p`. Mirrors
2291        // `PirlsPenalty::linear_shift()` in the CPU oracle. Pass a zero
2292        // vector for fits with no prior-mean shift.
2293        linear_shift: ArrayView1<'_, f64>,
2294        // Constant shift `c` of the shifted-quadratic penalty. Pass
2295        // `0.0` for fits with no prior-mean shift.
2296        constant_shift: f64,
2297        // Temporary LM damping for the Newton solves only; never enters
2298        // RidgePassport / exported Hessian / EDF / penalty term.
2299        lm_ridge: f64,
2300        // Real model-objective ridge; enters RidgePassport / exported
2301        // Hessian / EDF / penalty term.
2302        objective_ridge: f64,
2303        max_iter: usize,
2304        tol: f64,
2305        extra: Option<&PirlsLoopExtra<'_>>,
2306    ) -> Result<PirlsLoopOutcome, String> {
2307        let n = shared.n;
2308        let p = shared.p;
2309        if loop_ws.n != n || loop_ws.p != p {
2310            return Err(format!(
2311                "loop workspace ({}, {}) ≠ shared ({n}, {p})",
2312                loop_ws.n, loop_ws.p
2313            ));
2314        }
2315        if beta0_host.len() != p {
2316            return Err(format!("beta0 length {} ≠ p={p}", beta0_host.len()));
2317        }
2318
2319        if linear_shift.len() != p {
2320            return Err(format!(
2321                "linear_shift length {} ≠ p={p}",
2322                linear_shift.len()
2323            ));
2324        }
2325        if penalty_hessian.dim() != (p, p) {
2326            return Err(format!(
2327                "penalty_hessian shape {:?} ≠ (p={p}, p={p})",
2328                penalty_hessian.dim()
2329            ));
2330        }
2331
2332        ws.stream
2333            .memcpy_htod(
2334                beta0_host.as_slice().ok_or("beta0 not contiguous")?,
2335                &mut loop_ws.beta_dev,
2336            )
2337            .map_err(|e| format!("upload beta0: {e}"))?;
2338
2339        let backend = crate::gpu_kernels::pirls_row::PirlsRowBackend::probe()
2340            .map_err(|e| format!("pirls_row backend: {e}"))?;
2341        let loop_module = PIRLS_LOOP_CACHE
2342            .get_or_compile(&shared.ctx, "pirls_loop", PIRLS_LOOP_PTX_SOURCE)
2343            .map_err(|e| format!("pirls loop module: {e}"))?;
2344        let axpy_func = loop_module
2345            .load_function("axpy_n")
2346            .map_err(|e| format!("load axpy_n: {e}"))?;
2347        let sum_func = loop_module
2348            .load_function("deviance_sum")
2349            .map_err(|e| format!("load deviance_sum: {e}"))?;
2350        let linf_func = loop_module
2351            .load_function("linf_norm")
2352            .map_err(|e| format!("load linf_norm: {e}"))?;
2353        let status_or_func = loop_module
2354            .load_function("status_or")
2355            .map_err(|e| format!("load status_or: {e}"))?;
2356
2357        // beta_orig = Qs · beta  (transforms from transformed to original coords).
2358        // For identity Qs, this is a copy; always goes through ws.beta_orig_dev.
2359        gemv_no_trans(
2360            &ws.blas,
2361            p,
2362            p,
2363            &ws.qs_dev,
2364            &loop_ws.beta_dev,
2365            &mut ws.beta_orig_dev,
2366        )?;
2367        // η = X_original · beta_orig  then η += offset (#258).
2368        gemv_no_trans(
2369            &ws.blas,
2370            n,
2371            p,
2372            &shared.x_original_dev,
2373            &ws.beta_orig_dev,
2374            &mut loop_ws.eta_dev,
2375        )?;
2376        axpy(
2377            &ws.stream,
2378            &axpy_func,
2379            1.0,
2380            &shared.offset_dev,
2381            &mut loop_ws.eta_dev,
2382            n,
2383        )?;
2384        // Initial solve-row pass on the starting η (4-output kernel only).
2385        crate::gpu_kernels::pirls_row::launch_solve_row_on_stream(
2386            backend,
2387            family,
2388            curvature,
2389            gamma_shape,
2390            &ws.stream,
2391            n,
2392            &loop_ws.eta_dev,
2393            &shared.y_dev,
2394            &shared.prior_w_dev,
2395            &mut loop_ws.row_solve,
2396        )
2397        .map_err(|e| format!("solve-row init: {e}"))?;
2398
2399        let mut prev_deviance = reduce_scalar(
2400            &ws.stream,
2401            &sum_func,
2402            &loop_ws.row_solve.deviance,
2403            n,
2404            &mut loop_ws.scalar_dev,
2405            "deviance_init",
2406        )?;
2407        let mut last_logdet = 0.0_f64;
2408        let mut converged = false;
2409
2410        // Host-side mirror of `beta_dev`. Maintained in lock-step with
2411        // every accepted Newton step so we can evaluate the
2412        // shifted-quadratic penalty `βᵀSβ − 2βᵀlinear_shift +
2413        // constant_shift` on the host without an extra `β` DtoH per
2414        // iteration. The initial state is `beta0_host` verbatim.
2415        let mut beta_host: Array1<f64> = beta0_host.to_owned();
2416
2417        // Initial *penalized* objective = data-deviance(β₀) + shifted
2418        // quadratic(β₀). This is the value the line search and
2419        // convergence test compare candidates against — matches the CPU
2420        // oracle's `penalized_objective` in `CandidateScreen`.
2421        let s_beta0 = penalty_hessian.dot(&beta_host);
2422        let penalty_init =
2423            beta_host.dot(&s_beta0) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
2424        let mut prev_objective = prev_deviance + penalty_init;
2425
2426        // Diagnostic scalars surfaced on the outcome so the dispatch
2427        // wirer can populate WorkingModelPirlsResult / PirlsResult
2428        // fields without re-running the loop. They mirror the CPU
2429        // oracle's per-iter tracking in runworking_model_pirls; the
2430        // "deviance change" diagnostic now carries the *penalized*
2431        // objective delta (matches the CPU oracle's convergence-test
2432        // input and what the issue requested).
2433        let mut last_dev_delta = 0.0_f64;
2434        let mut last_halving: usize = 0;
2435        let mut last_step_size = 0.0_f64;
2436        let mut min_dev = prev_deviance;
2437        let mut step_search_exhausted = false;
2438
2439        for it in 0..max_iter {
2440            last_logdet = solve_step_on_stream_device_inplace(
2441                shared,
2442                ws,
2443                PirlsStepStreamDeviceInput {
2444                    w_solver_dev: &loop_ws.row_solve.w_solver,
2445                    grad_eta_dev: &loop_ws.row_solve.grad_eta,
2446                    penalty_hessian,
2447                    step_lm_lambda: lm_ridge,
2448                    objective_ridge,
2449                    beta_dev: &loop_ws.beta_dev,
2450                    linear_shift,
2451                },
2452            )
2453            .map_err(|e| format!("inner step it={it}: {e}"))?;
2454            // ws.rhs_dev holds the Newton descent direction δ = H⁻¹·rhs (#257).
2455            // Copy device-to-device: no host round-trip.
2456            ws.stream
2457                .memcpy_dtod(&ws.rhs_dev, &mut loop_ws.direction_dev)
2458                .map_err(|e| format!("direction d2d copy it={it}: {e}"))?;
2459
2460            let dir_linf = reduce_scalar(
2461                &ws.stream,
2462                &linf_func,
2463                &loop_ws.direction_dev,
2464                p,
2465                &mut loop_ws.scalar_dev,
2466                "dir_linf",
2467            )?;
2468
2469            // dir_orig = Qs · direction (transform direction to original coords).
2470            gemv_no_trans(
2471                &ws.blas,
2472                p,
2473                p,
2474                &ws.qs_dev,
2475                &loop_ws.direction_dev,
2476                &mut ws.dir_orig_dev,
2477            )?;
2478            gemv_no_trans(
2479                &ws.blas,
2480                n,
2481                p,
2482                &shared.x_original_dev,
2483                &ws.dir_orig_dev,
2484                &mut loop_ws.xd_dev,
2485            )?;
2486
2487            // -- Fused alpha-ladder (candidate-objective mode) ----------------
2488            // One kernel launch evaluates eta + alpha_k*xdelta for all k in
2489            // ALPHA_LADDER simultaneously, atomically accumulating per-row
2490            // deviance into objective_dev[k] and OR-accumulating status flags
2491            // into status_dev[k].  A single DtoH of 7+7 scalars selects the
2492            // accepted step -- no per-alpha kernel launch, no full row-output
2493            // write, no per-alpha host scalar sync.
2494            loop_ws
2495                .alpha_ladder
2496                .zero(&ws.stream)
2497                .map_err(|e| format!("ladder zero it={it}: {e}"))?;
2498            crate::gpu_kernels::pirls_row::launch_alpha_ladder_on_stream(
2499                backend,
2500                family,
2501                curvature,
2502                gamma_shape,
2503                &ws.stream,
2504                n,
2505                &loop_ws.eta_dev,
2506                &loop_ws.xd_dev,
2507                &shared.y_dev,
2508                &shared.prior_w_dev,
2509                &mut loop_ws.alpha_ladder,
2510            )
2511            .map_err(|e| format!("alpha-ladder it={it}: {e}"))?;
2512            let obj_host: Vec<f64> = ws
2513                .stream
2514                .clone_dtoh(&loop_ws.alpha_ladder.objective_dev)
2515                .map_err(|e| format!("ladder dtoh obj it={it}: {e}"))?;
2516            let stat_host: Vec<u32> = ws
2517                .stream
2518                .clone_dtoh(&loop_ws.alpha_ladder.status_dev)
2519                .map_err(|e| format!("ladder dtoh stat it={it}: {e}"))?;
2520            // Download the direction (p << n; one DtoH per iteration to
2521            // compute the host-side penalty term and maintain beta_host).
2522            let direction_host: Vec<f64> = ws
2523                .stream
2524                .clone_dtoh(&loop_ws.direction_dev)
2525                .map_err(|e| format!("dtoh direction it={it}: {e}"))?;
2526
2527            // Penalized objective for each candidate step:
2528            //   obj_pen[k] = deviance(eta + alpha_k * xd)
2529            //               + (beta + alpha_k * d)^T S (beta + alpha_k * d)
2530            //               - 2 (beta + alpha_k * d) . linear_shift
2531            //               + constant_shift
2532            // The quadratic in alpha expands as:
2533            //   penalty(beta) + alpha * [2 d^T (S beta - linear_shift)]
2534            //                  + alpha^2 * d^T S d
2535            let dir_view = ndarray::aview1(&direction_host);
2536            let sd = penalty_hessian.dot(&dir_view);
2537            let s_beta = penalty_hessian.dot(&beta_host);
2538            let dtsd = dir_view.dot(&sd);
2539            let linear_coeff = 2.0 * dir_view.dot(&(&s_beta - &linear_shift));
2540            let penalty_beta =
2541                beta_host.dot(&s_beta) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
2542
2543            const FORBIDDEN_LINESEARCH: u32 =
2544                crate::gpu_kernels::pirls_row::status_flags::INVALID_RESPONSE
2545                    | crate::gpu_kernels::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
2546            let mut alpha = 0.0_f64;
2547            let mut accepted_dev = prev_deviance;
2548            let mut accepted_objective = prev_objective;
2549            let mut halving_count: usize = 0;
2550            for (k, (&dev_k, &st)) in obj_host.iter().zip(stat_host.iter()).enumerate() {
2551                let a = crate::gpu_kernels::pirls_row::ALPHA_LADDER[k];
2552                let pen_k = penalty_beta + a * linear_coeff + a * a * dtsd;
2553                let obj_k = dev_k + pen_k;
2554                // Match the CPU oracle's acceptance test (#263):
2555                // `<= prev_objective` is the `CandidateScreen`
2556                // criterion — a step that holds the penalized
2557                // objective steady (e.g. an exact zero-gradient
2558                // direction) must still be accepted so the line
2559                // search does not spuriously exhaust at a
2560                // stationary point.
2561                if obj_k.is_finite() && obj_k <= prev_objective && (st & FORBIDDEN_LINESEARCH) == 0
2562                {
2563                    alpha = a;
2564                    accepted_dev = dev_k;
2565                    accepted_objective = obj_k;
2566                    halving_count = k;
2567                    break;
2568                }
2569            }
2570            if alpha == 0.0 {
2571                // No α in the ladder produced a step lowering the
2572                // *penalized* objective. The previous code (and the
2573                // first draft of this rewrite) silently committed
2574                // α=1 here and merely *flagged* exhaustion — that
2575                // still commits a non-descent step, which is exactly
2576                // what the issue forbids (#263).
2577                //
2578                // Signal exhaustion and exit the inner loop without
2579                // committing β / η / solve-row buffers;
2580                // `build_loop_outcome` then maps
2581                // `step_search_exhausted` to
2582                // `PirlsStatus::LmStepSearchExhausted`, exactly the
2583                // CPU oracle's "no acceptable step direction even
2584                // after damping" signal. The outer REML / LM
2585                // controller can raise damping or reject the outer
2586                // iteration. β / η / prev_deviance / prev_objective
2587                // all stay at their last accepted values; the
2588                // device buffers are likewise untouched.
2589                step_search_exhausted = true;
2590                last_halving = 0;
2591                last_step_size = 0.0;
2592                last_dev_delta = 0.0;
2593                break;
2594            }
2595            step_search_exhausted = false;
2596            // Commit accepted step: beta and eta updated in-place.
2597            axpy(
2598                &ws.stream,
2599                &axpy_func,
2600                alpha,
2601                &loop_ws.direction_dev,
2602                &mut loop_ws.beta_dev,
2603                p,
2604            )?;
2605            axpy(
2606                &ws.stream,
2607                &axpy_func,
2608                alpha,
2609                &loop_ws.xd_dev,
2610                &mut loop_ws.eta_dev,
2611                n,
2612            )?;
2613            // Maintain host-side beta mirror: beta_host += alpha * direction.
2614            for (b, &d) in beta_host.iter_mut().zip(direction_host.iter()) {
2615                *b += alpha * d;
2616            }
2617            // Refresh the 4-output solve-row buffers for the next Newton iter.
2618            crate::gpu_kernels::pirls_row::launch_solve_row_on_stream(
2619                backend,
2620                family,
2621                curvature,
2622                gamma_shape,
2623                &ws.stream,
2624                n,
2625                &loop_ws.eta_dev,
2626                &shared.y_dev,
2627                &shared.prior_w_dev,
2628                &mut loop_ws.row_solve,
2629            )
2630            .map_err(|e| format!("solve-row accepted it={it}: {e}"))?;
2631
2632            let step_norm = alpha.abs() * dir_linf;
2633            let dev_delta = (prev_objective - accepted_objective).abs();
2634            last_dev_delta = dev_delta;
2635            last_halving = halving_count;
2636            last_step_size = alpha;
2637            if accepted_dev < min_dev {
2638                min_dev = accepted_dev;
2639            }
2640
2641            prev_deviance = accepted_dev;
2642            prev_objective = accepted_objective;
2643
2644            if dir_linf <= tol
2645                && step_norm <= tol
2646                && dev_delta <= tol * (1.0 + prev_objective.abs())
2647            {
2648                converged = true;
2649                // Final-row mode: write all 9 output fields once at convergence.
2650                crate::gpu_kernels::pirls_row::launch_row_reweight_on_stream(
2651                    backend,
2652                    family,
2653                    curvature,
2654                    gamma_shape,
2655                    &ws.stream,
2656                    n,
2657                    &loop_ws.eta_dev,
2658                    &shared.y_dev,
2659                    &shared.prior_w_dev,
2660                    &mut loop_ws.row_final,
2661                )
2662                .map_err(|e| format!("final-row converged: {e}"))?;
2663                let h_final = rebuild_h_final(
2664                    shared,
2665                    ws,
2666                    &loop_ws.row_final.w_hessian,
2667                    penalty_hessian,
2668                    objective_ridge,
2669                )
2670                .map_err(|e| format!("rebuild H_final (converged): {e}"))?;
2671                return build_loop_outcome(
2672                    ws,
2673                    loop_ws,
2674                    h_final,
2675                    last_logdet,
2676                    prev_deviance,
2677                    it + 1,
2678                    converged,
2679                    lm_ridge,
2680                    objective_ridge,
2681                    extra,
2682                    LoopDiagnostics {
2683                        last_deviance_change: last_dev_delta,
2684                        last_step_halving: last_halving,
2685                        last_step_size,
2686                        min_deviance: min_dev,
2687                        step_search_exhausted,
2688                    },
2689                    &status_or_func,
2690                );
2691            }
2692        }
2693
2694        // Final-row mode: write all 9 output fields once at max-iter exit.
2695        crate::gpu_kernels::pirls_row::launch_row_reweight_on_stream(
2696            backend,
2697            family,
2698            curvature,
2699            gamma_shape,
2700            &ws.stream,
2701            n,
2702            &loop_ws.eta_dev,
2703            &shared.y_dev,
2704            &shared.prior_w_dev,
2705            &mut loop_ws.row_final,
2706        )
2707        .map_err(|e| format!("final-row max_iter: {e}"))?;
2708        let h_final = rebuild_h_final(
2709            shared,
2710            ws,
2711            &loop_ws.row_final.w_hessian,
2712            penalty_hessian,
2713            objective_ridge,
2714        )
2715        .map_err(|e| format!("rebuild H_final (max_iter): {e}"))?;
2716        build_loop_outcome(
2717            ws,
2718            loop_ws,
2719            h_final,
2720            last_logdet,
2721            prev_deviance,
2722            max_iter,
2723            converged,
2724            lm_ridge,
2725            objective_ridge,
2726            extra,
2727            LoopDiagnostics {
2728                last_deviance_change: last_dev_delta,
2729                last_step_halving: last_halving,
2730                last_step_size,
2731                min_deviance: min_dev,
2732                step_search_exhausted,
2733            },
2734            &status_or_func,
2735        )
2736    }
2737
2738    /// Internal carrier for the scalar diagnostics tracked across the
2739    /// inner Newton loop. Surfaced verbatim on `PirlsLoopOutcome` so the
2740    /// dispatch wirer's plumbing to `WorkingModelPirlsResult` is a
2741    /// direct field copy.
2742    ///
2743    /// `step_search_exhausted` is the GPU mirror of the CPU oracle's
2744    /// `PirlsStatus::LmStepSearchExhausted` signal: the line-search
2745    /// halving ladder produced no step that lowered the *penalized*
2746    /// objective. When true, `build_loop_outcome` promotes the emitted
2747    /// status accordingly so the outer REML / LM controller can raise
2748    /// damping or fail the iteration cleanly instead of being handed a
2749    /// silently non-descent step.
2750    struct LoopDiagnostics {
2751        last_deviance_change: f64,
2752        last_step_halving: usize,
2753        last_step_size: f64,
2754        min_deviance: f64,
2755        step_search_exhausted: bool,
2756    }
2757
2758    /// Build a full-surface [`PirlsLoopOutcome`] from the loop's
2759    /// device-resident state plus optional caller-supplied
2760    /// [`PirlsLoopExtra`] context.
2761    ///
2762    /// Five n-vector DtoH downloads are unavoidable (η, μ, grad_η,
2763    /// w_hessian, w_solver); β is one p-vector download. When `extra`
2764    /// is `Some`, the host-side helpers
2765    /// `computeworkingweight_derivatives_from_eta` and (optionally)
2766    /// `compute_observed_hessian_curvature_arrays` produce the
2767    /// solve-side aux jets and the curvature-promoted Hessian-side
2768    /// weights; `compute_constraint_kkt_diagnostics` runs over the
2769    /// converged β and reconstructed penalised gradient. All of this
2770    /// is bit-identical to the corresponding CPU oracle code paths in
2771    /// `fit_model_for_fixed_rho_with_adaptive_kkt`.
2772    fn build_loop_outcome(
2773        ws: &mut SigmaPirlsGpuWorkspace,
2774        loop_ws: &mut PirlsLoopWorkspace,
2775        penalized_hessian: Array2<f64>,
2776        logdet: f64,
2777        deviance: f64,
2778        iterations: usize,
2779        converged: bool,
2780        step_lm_lambda: f64,
2781        objective_ridge: f64,
2782        extra: Option<&PirlsLoopExtra<'_>>,
2783        diagnostics: LoopDiagnostics,
2784        status_or_func: &cudarc::driver::CudaFunction,
2785    ) -> Result<PirlsLoopOutcome, String> {
2786        let beta = download_vec(&ws.stream, &loop_ws.beta_dev)?;
2787        let final_eta = download_vec(&ws.stream, &loop_ws.eta_dev)?;
2788        let final_mu = download_vec(&ws.stream, &loop_ws.row_final.mu)?;
2789        let final_grad_eta = download_vec(&ws.stream, &loop_ws.row_final.grad_eta)?;
2790        let final_w_hessian = download_vec(&ws.stream, &loop_ws.row_final.w_hessian)?;
2791        let final_w_solver = download_vec(&ws.stream, &loop_ws.row_final.w_solver)?;
2792
2793        // OR-reduce the per-row status flags of the final accepted step.
2794        // Any INVALID_RESPONSE or ZERO_PRIOR_WEIGHT bit that survived to
2795        // the accepted iterate means the line-search fallback swallowed a
2796        // structurally bad candidate; classify as Unstable.
2797        let n_rows = loop_ws.n;
2798        let final_row_status = reduce_status_or(
2799            &ws.stream,
2800            status_or_func,
2801            &loop_ws.row_final.status,
2802            n_rows,
2803            &mut loop_ws.status_u32_dev,
2804            "final_row_status",
2805        )?;
2806        const FORBIDDEN_FINAL: u32 = crate::gpu_kernels::pirls_row::status_flags::INVALID_RESPONSE
2807            | crate::gpu_kernels::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
2808
2809        // Stability classification — Unstable supersedes both
2810        // converged and MaxIterationsReached because a non-finite η /
2811        // μ at the accepted step means the line search swallowed a
2812        // divergence (saturated likelihood / perfect separation).
2813        // Also Unstable when forbidden row-status bits are set.
2814        let eta_finite = final_eta.iter().all(|v| v.is_finite());
2815        let mu_finite = final_mu.iter().all(|v| v.is_finite());
2816        let beta_finite = beta.iter().all(|v| v.is_finite());
2817        let stability_ok =
2818            eta_finite && mu_finite && beta_finite && (final_row_status & FORBIDDEN_FINAL) == 0;
2819        let status = if !stability_ok {
2820            crate::pirls::PirlsStatus::Unstable
2821        } else if converged {
2822            crate::pirls::PirlsStatus::Converged
2823        } else if diagnostics.step_search_exhausted {
2824            // The α-ladder produced no step lowering the *penalized*
2825            // objective — exactly the CPU oracle's "no acceptable step
2826            // direction even after damping" signal. Distinct from the
2827            // iteration-cap exhaustion (MaxIterationsReached) so the
2828            // outer REML / LM controller can react (raise damping / try
2829            // a different curvature) rather than silently accepting an
2830            // ascent step.
2831            crate::pirls::PirlsStatus::LmStepSearchExhausted
2832        } else {
2833            crate::pirls::PirlsStatus::MaxIterationsReached
2834        };
2835
2836        // RidgePassport is built from objective_ridge only — step_lm_lambda
2837        // is a solve-only artefact and must never contaminate EDF / REML.
2838        let default_ridge = gam_problem::RidgePassport::scaled_identity(
2839            objective_ridge,
2840            gam_linalg::RidgePolicy::explicit_stabilization_full(),
2841        );
2842
2843        let max_abs_eta = final_eta.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2844
2845        match extra {
2846            Some(ext) => {
2847                // Family aux jets at the converged η — bit-identical
2848                // to the CPU oracle's post-convergence finalization.
2849                let (score_c, score_d, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
2850                    crate::pirls::computeworkingweight_derivatives_from_eta(
2851                        ext.likelihood,
2852                        ext.inverse_link,
2853                        &final_eta,
2854                        ext.priorweights,
2855                    )
2856                    .map_err(|e| format!("pirls postpass dmu/deta: {e:?}"))?;
2857
2858                let (finalweights, solve_c_array, solve_d_array) = match ext.exported_curvature {
2859                    crate::pirls::HessianCurvatureKind::Observed => {
2860                        crate::pirls::compute_observed_hessian_curvature_arrays(
2861                            ext.likelihood,
2862                            ext.inverse_link,
2863                            &final_eta,
2864                            ext.y,
2865                            &final_w_solver,
2866                            ext.priorweights,
2867                        )
2868                        .map_err(|e| format!("pirls postpass observed curvature: {e:?}"))?
2869                    }
2870                    crate::pirls::HessianCurvatureKind::Fisher => {
2871                        (final_w_solver.clone(), score_c.clone(), score_d.clone())
2872                    }
2873                };
2874
2875                // The GPU loop solves in the transformed design X·Qs, so
2876                // the loop's β is already in transformed coordinates.
2877                // beta_original = qs · beta_transformed (not applied here;
2878                // callers that need original coordinates compute it from
2879                // reparam_result.qs per the PirlsResult contract).
2880                let beta_transformed = beta.clone();
2881
2882                let constraint_kkt = ext.linear_constraints.and_then(|lin| {
2883                    if lin.a.nrows() == 0 {
2884                        return None;
2885                    }
2886                    // Reconstruct the penalised gradient at the
2887                    // converged β: g = Xᵀ(grad_eta) + S β + objective_ridge·β.
2888                    // `penalized_hessian` is already XᵀWX + S + objective_ridge·I
2889                    // (step_lm_lambda was stripped from the export), so
2890                    // H_pen·β ≈ Xᵀ·grad_eta at a KKT-feasible solution.
2891                    let grad = penalized_hessian.dot(&beta);
2892                    Some(
2893                        crate::active_set::compute_constraint_kkt_diagnostics(
2894                            &beta, &grad, lin,
2895                        ),
2896                    )
2897                });
2898
2899                let ridge_passport = ext.ridge_passport.unwrap_or(default_ridge);
2900                let firth = ext
2901                    .firth
2902                    .clone()
2903                    .unwrap_or(crate::pirls::FirthDiagnostics::Inactive);
2904                let edf = ext.edf.unwrap_or(f64::NAN);
2905                // Mirrors CPU oracle's invariant: when
2906                // `computeworkingweight_derivatives_from_eta` returns
2907                // Ok, all five jets are real (not placeholders), so
2908                // this field is `false`. See
2909                // `src/solver/pirls.rs:6634`.
2910                let derivatives_unsupported = false;
2911
2912                Ok(PirlsLoopOutcome {
2913                    beta,
2914                    penalized_hessian,
2915                    logdet,
2916                    deviance,
2917                    iterations,
2918                    converged,
2919                    final_eta,
2920                    final_mu,
2921                    final_grad_eta,
2922                    final_w_hessian,
2923                    final_w_solver: final_w_solver.clone(),
2924                    final_offset: ext.offset.to_owned(),
2925                    beta_transformed,
2926                    finalweights,
2927                    solveweights: final_w_solver,
2928                    solve_dmu_deta,
2929                    solve_d2mu_deta2,
2930                    solve_d3mu_deta3,
2931                    solve_c_array,
2932                    solve_d_array,
2933                    derivatives_unsupported,
2934                    status,
2935                    ridge_passport,
2936                    firth,
2937                    constraint_kkt,
2938                    edf,
2939                    last_deviance_change: diagnostics.last_deviance_change,
2940                    last_step_halving: diagnostics.last_step_halving,
2941                    last_step_size: diagnostics.last_step_size,
2942                    final_lm_lambda: step_lm_lambda,
2943                    min_deviance: diagnostics.min_deviance,
2944                    max_abs_eta,
2945                    per_row_status_or: final_row_status,
2946                })
2947            }
2948            None => {
2949                // No extra context — pirls-dispatch-wirer can do the
2950                // derived-field plumbing host-side if needed. We give
2951                // it `solveweights = final_w_solver` echoed through,
2952                // empty arrays everywhere else, and safe default
2953                // status / passport / firth so the struct is fully
2954                // populated and the wirer's match arms can rely on
2955                // every field being present.
2956                Ok(PirlsLoopOutcome {
2957                    beta: beta.clone(),
2958                    penalized_hessian,
2959                    logdet,
2960                    deviance,
2961                    iterations,
2962                    converged,
2963                    final_eta,
2964                    final_mu,
2965                    final_grad_eta,
2966                    final_w_hessian,
2967                    final_w_solver: final_w_solver.clone(),
2968                    final_offset: Array1::<f64>::zeros(0),
2969                    beta_transformed: beta,
2970                    finalweights: Array1::<f64>::zeros(0),
2971                    solveweights: final_w_solver,
2972                    solve_dmu_deta: Array1::<f64>::zeros(0),
2973                    solve_d2mu_deta2: Array1::<f64>::zeros(0),
2974                    solve_d3mu_deta3: Array1::<f64>::zeros(0),
2975                    solve_c_array: Array1::<f64>::zeros(0),
2976                    solve_d_array: Array1::<f64>::zeros(0),
2977                    derivatives_unsupported: true,
2978                    status,
2979                    ridge_passport: default_ridge,
2980                    firth: crate::pirls::FirthDiagnostics::Inactive,
2981                    constraint_kkt: None,
2982                    edf: f64::NAN,
2983                    last_deviance_change: diagnostics.last_deviance_change,
2984                    last_step_halving: diagnostics.last_step_halving,
2985                    last_step_size: diagnostics.last_step_size,
2986                    final_lm_lambda: step_lm_lambda,
2987                    min_deviance: diagnostics.min_deviance,
2988                    max_abs_eta,
2989                    per_row_status_or: final_row_status,
2990                })
2991            }
2992        }
2993    }
2994
2995    fn gemv_no_trans(
2996        blas: &CudaBlas,
2997        n: usize,
2998        p: usize,
2999        a_dev: &CudaSlice<f64>,
3000        x_dev: &CudaSlice<f64>,
3001        y_dev: &mut CudaSlice<f64>,
3002    ) -> Result<(), String> {
3003        let n_i = to_i32(n)?;
3004        let p_i = to_i32(p)?;
3005        let cfg = GemvConfig::<f64> {
3006            trans: cublasOperation_t::CUBLAS_OP_N,
3007            m: n_i,
3008            n: p_i,
3009            alpha: 1.0,
3010            lda: n_i,
3011            incx: 1,
3012            beta: 0.0,
3013            incy: 1,
3014        };
3015        // SAFETY: a is n×p col-major lda=n; x length p incx=1; y length n incy=1.
3016        unsafe { blas.gemv(cfg, a_dev, x_dev, y_dev) }.map_err(|e| format!("dgemv no-trans: {e}"))
3017    }
3018
3019    fn axpy(
3020        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3021        func: &cudarc::driver::CudaFunction,
3022        alpha: f64,
3023        x_dev: &CudaSlice<f64>,
3024        y_dev: &mut CudaSlice<f64>,
3025        n: usize,
3026    ) -> Result<(), String> {
3027        const THREADS: u32 = 256;
3028        let n_i = to_i32(n)?;
3029        let n_u = u32::try_from(n).map_err(|_| format!("axpy n={n} > u32"))?;
3030        let grid = n_u.div_ceil(THREADS).max(1);
3031        let cfg = LaunchConfig {
3032            grid_dim: (grid, 1, 1),
3033            block_dim: (THREADS, 1, 1),
3034            shared_mem_bytes: 0,
3035        };
3036        let mut builder = stream.launch_builder(func);
3037        builder.arg(&alpha);
3038        builder.arg(x_dev);
3039        builder.arg(y_dev);
3040        builder.arg(&n_i);
3041        // SAFETY: axpy_n signature is (double, const double*, double*, int);
3042        // both vectors length n.
3043        unsafe { builder.launch(cfg) }
3044            .map(|_event_pair| ())
3045            .map_err(|e| format!("axpy launch: {e}"))
3046    }
3047
3048    fn reduce_scalar(
3049        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3050        func: &cudarc::driver::CudaFunction,
3051        src: &CudaSlice<f64>,
3052        len: usize,
3053        scalar_dev: &mut CudaSlice<f64>,
3054        label: &'static str,
3055    ) -> Result<f64, String> {
3056        const THREADS: u32 = 1024;
3057        let len_i = to_i32(len)?;
3058        let cfg = LaunchConfig {
3059            grid_dim: (1, 1, 1),
3060            block_dim: (THREADS, 1, 1),
3061            shared_mem_bytes: 0,
3062        };
3063        let mut builder = stream.launch_builder(func);
3064        builder.arg(src);
3065        builder.arg(&len_i);
3066        builder.arg(&mut *scalar_dev);
3067        // SAFETY: kernel signature (const double*, int, double*). The
3068        // `&mut *scalar_dev` reborrow keeps `scalar_dev` available for the
3069        // download below.
3070        unsafe { builder.launch(cfg) }.map_err(|e| format!("{label} reduce launch: {e}"))?;
3071        let host = stream
3072            .clone_dtoh(scalar_dev)
3073            .map_err(|e| format!("download {label}: {e}"))?;
3074        Ok(host[0])
3075    }
3076
3077    /// OR-reduce a device-resident `u32` status array into a single `u32`.
3078    /// Mirrors [`reduce_scalar`] for `f64` deviance reductions: single-block,
3079    /// 1024-thread launch, one scalar DtoH download.
3080    fn reduce_status_or(
3081        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3082        func: &cudarc::driver::CudaFunction,
3083        src: &CudaSlice<u32>,
3084        len: usize,
3085        status_dev: &mut CudaSlice<u32>,
3086        label: &'static str,
3087    ) -> Result<u32, String> {
3088        const THREADS: u32 = 1024;
3089        let len_i = to_i32(len)?;
3090        let cfg = LaunchConfig {
3091            grid_dim: (1, 1, 1),
3092            block_dim: (THREADS, 1, 1),
3093            shared_mem_bytes: 0,
3094        };
3095        let mut builder = stream.launch_builder(func);
3096        builder.arg(src);
3097        builder.arg(&len_i);
3098        builder.arg(&mut *status_dev);
3099        // SAFETY: status_or kernel signature (const unsigned int*, int,
3100        // unsigned int*). The reborrow keeps `status_dev` available.
3101        unsafe { builder.launch(cfg) }.map_err(|e| format!("{label} or reduce launch: {e}"))?;
3102        let host = stream
3103            .clone_dtoh(status_dev)
3104            .map_err(|e| format!("download {label}: {e}"))?;
3105        Ok(host[0])
3106    }
3107
3108    fn download_vec(
3109        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
3110        dev: &CudaSlice<f64>,
3111    ) -> Result<Array1<f64>, String> {
3112        let host = stream
3113            .clone_dtoh(dev)
3114            .map_err(|e| format!("download vec: {e}"))?;
3115        Ok(Array1::from_vec(host))
3116    }
3117
3118    /// Result of one GPU Gaussian exact penalised least-squares solve.
3119    pub struct GaussianPlsResult {
3120        pub beta: Array1<f64>,
3121        pub penalized_hessian: Array2<f64>,
3122        pub logdet: f64,
3123    }
3124
3125    /// Exact GPU PLS for Gaussian-identity: assembles QsT A Qs + S on host,
3126    /// then runs POTRF/POTRS on device.  Replaces the PIRLS loop for this family.
3127    pub fn solve_gaussian_pls_on_stream(
3128        a_orig: ArrayView2<'_, f64>,
3129        b_orig: ArrayView1<'_, f64>,
3130        s_transformed: ArrayView2<'_, f64>,
3131        linear_shift: ArrayView1<'_, f64>,
3132        prior_mean_target: ArrayView1<'_, f64>,
3133        ridge: f64,
3134        qs: Option<ArrayView2<'_, f64>>,
3135    ) -> Result<GaussianPlsResult, String> {
3136        let p = b_orig.len();
3137        if a_orig.dim() != (p, p) {
3138            return Err(format!("A shape {:?} != ({p},{p})", a_orig.dim()));
3139        }
3140        if s_transformed.dim() != (p, p) {
3141            return Err(format!("S shape {:?} != ({p},{p})", s_transformed.dim()));
3142        }
3143        if linear_shift.len() != p {
3144            return Err(format!("linear_shift len {} != p={p}", linear_shift.len()));
3145        }
3146        if prior_mean_target.len() != p {
3147            return Err(format!(
3148                "prior_mean_target len {} != p={p}",
3149                prior_mean_target.len()
3150            ));
3151        }
3152        if let Some(qs_v) = qs {
3153            if qs_v.dim() != (p, p) {
3154                return Err(format!("qs shape {:?} != ({p},{p})", qs_v.dim()));
3155            }
3156        }
3157        let (h_rotated, rhs_base) = if let Some(qs_v) = qs {
3158            let qs_owned = qs_v.to_owned();
3159            let tmp = a_orig.dot(&qs_owned);
3160            let h = qs_owned.t().dot(&tmp);
3161            let rb = qs_owned.t().dot(&b_orig);
3162            (h, rb)
3163        } else {
3164            (a_orig.to_owned(), b_orig.to_owned())
3165        };
3166        let penalized_hessian: Array2<f64> = &h_rotated + &s_transformed;
3167        let mut regularized = penalized_hessian.clone();
3168        if ridge > 0.0 {
3169            for i in 0..p {
3170                regularized[[i, i]] += ridge;
3171            }
3172        }
3173        let mut rhs_host = rhs_base;
3174        rhs_host += &linear_shift;
3175        if ridge > 0.0 {
3176            rhs_host.scaled_add(ridge, &prior_mean_target);
3177        }
3178        let (ctx, stream) = context_and_stream()?;
3179        let solver = DnHandle::new(stream.clone())
3180            .map_err(|e| format!("cusolver init (gaussian pls): {e}"))?;
3181        let pp = p.checked_mul(p).ok_or("p*p overflow (gaussian pls)")?;
3182        let mut h_dev = stream
3183            .alloc_zeros::<f64>(pp)
3184            .map_err(|e| format!("alloc H (gaussian pls): {e}"))?;
3185        let mut rhs_dev = stream
3186            .alloc_zeros::<f64>(p)
3187            .map_err(|e| format!("alloc rhs (gaussian pls): {e}"))?;
3188        let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
3189        let potrf_lwork = i32::try_from(potrf_lwork_usize)
3190            .map_err(|_| "potrf lwork overflow (gaussian pls)".to_string())?;
3191        let mut potrf_work_dev = stream
3192            .alloc_zeros::<f64>(potrf_lwork_usize.max(1))
3193            .map_err(|e| format!("alloc potrf workspace (gaussian pls): {e}"))?;
3194        let mut potrf_info_dev = stream
3195            .alloc_zeros::<i32>(1)
3196            .map_err(|e| format!("alloc potrf info (gaussian pls): {e}"))?;
3197        let mut potrs_info_dev = stream
3198            .alloc_zeros::<i32>(1)
3199            .map_err(|e| format!("alloc potrs info (gaussian pls): {e}"))?;
3200        let reg_col = to_col_major(&regularized);
3201        stream
3202            .memcpy_htod(reg_col.as_ref(), &mut h_dev)
3203            .map_err(|e| format!("upload H (gaussian pls): {e}"))?;
3204        let rhs_slice = rhs_host
3205            .as_slice()
3206            .ok_or("rhs_host not contiguous (gaussian pls)")?;
3207        stream
3208            .memcpy_htod(rhs_slice, &mut rhs_dev)
3209            .map_err(|e| format!("upload rhs (gaussian pls): {e}"))?;
3210        potrf_in_place_reuse(
3211            &solver,
3212            &stream,
3213            p,
3214            potrf_lwork,
3215            &mut h_dev,
3216            &mut potrf_work_dev,
3217            &mut potrf_info_dev,
3218        )?;
3219        potrs_in_place_reuse(
3220            &solver,
3221            &stream,
3222            p,
3223            1,
3224            &h_dev,
3225            &mut rhs_dev,
3226            &mut potrs_info_dev,
3227        )?;
3228        let logdet = cholesky_logdet_device(&stream, &ctx, p, &h_dev)?;
3229        let beta_raw = stream
3230            .clone_dtoh(&rhs_dev)
3231            .map_err(|e| format!("download beta (gaussian pls): {e}"))?;
3232        check_deferred_potrf_info(&stream, &potrf_info_dev)?;
3233        check_deferred_potrs_info(&stream, &potrs_info_dev)?;
3234        Ok(GaussianPlsResult {
3235            beta: Array1::from_vec(beta_raw),
3236            penalized_hessian,
3237            logdet,
3238        })
3239    }
3240}
3241
3242pub fn weighted_crossprod_gpu(
3243    x: ArrayView2<'_, f64>,
3244    weights: ArrayView1<'_, f64>,
3245) -> Result<Array2<f64>, String> {
3246    #[cfg(not(target_os = "linux"))]
3247    {
3248        return cpu_fallback::weighted_crossprod_cpu(x, weights);
3249    }
3250
3251    #[cfg(target_os = "linux")]
3252    {
3253        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3254            return cpu_fallback::weighted_crossprod_cpu(x, weights);
3255        }
3256        cuda::weighted_crossprod(x, weights)
3257    }
3258}
3259
3260pub fn solve_pirls_step_gpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
3261    #[cfg(not(target_os = "linux"))]
3262    {
3263        return cpu_fallback::solve_step_cpu(input);
3264    }
3265
3266    #[cfg(target_os = "linux")]
3267    {
3268        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3269            return cpu_fallback::solve_step_cpu(input);
3270        }
3271        cuda::solve_step(input)
3272    }
3273}
3274
3275/// Upload X_original, y, prior_w, and offset once per model and return a
3276/// shared device-resident handle reused across all ρ / σ points. All four
3277/// arrays must have the same row-count `n`. The shared handle keeps the
3278/// cached per-ordinal `CudaContext` alive so all peer workspaces bind to
3279/// the same context and can interleave on its asynchronous engines.
3280#[cfg(target_os = "linux")]
3281pub fn upload_shared_pirls_gpu(
3282    x: ndarray::ArrayView2<'_, f64>,
3283    y: ndarray::ArrayView1<'_, f64>,
3284    prior_w: ndarray::ArrayView1<'_, f64>,
3285    offset: ndarray::ArrayView1<'_, f64>,
3286) -> Result<PirlsGpuSharedData, String> {
3287    if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3288        return Err("cuda runtime unavailable; cannot upload shared GPU PIRLS data".to_string());
3289    }
3290    PirlsGpuSharedData::upload_impl(x, y, prior_w, offset)
3291}
3292
3293/// Allocate a per-stream workspace bound to a fresh non-default CUDA
3294/// stream on `shared`'s context. The cuBLAS and cuSOLVER handles are bound
3295/// to the workspace stream so peer workspaces achieve overlapped execution.
3296#[cfg(target_os = "linux")]
3297pub fn allocate_sigma_pirls_workspace(
3298    shared: &PirlsGpuSharedData,
3299) -> Result<SigmaPirlsGpuWorkspace, String> {
3300    SigmaPirlsGpuWorkspace::allocate_impl(shared)
3301}
3302
3303/// Upload the reparameterisation matrix `Qs` (p×p) for the current ρ / σ
3304/// point. Call once per ρ / σ point before calling
3305/// [`pirls_loop_on_stream`]. When no reparameterisation is active, pass an
3306/// identity matrix.
3307#[cfg(target_os = "linux")]
3308pub fn upload_qs_pirls(
3309    ws: &mut SigmaPirlsGpuWorkspace,
3310    qs: ndarray::ArrayView2<'_, f64>,
3311) -> Result<(), String> {
3312    cuda::upload_qs(ws, qs)
3313}
3314
3315/// Upload an identity Qs for the current ρ / σ point. Equivalent to
3316/// [`upload_qs_pirls`] with an identity matrix; avoids host allocation.
3317#[cfg(target_os = "linux")]
3318pub fn upload_qs_identity_pirls(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
3319    cuda::upload_qs_identity(ws)
3320}
3321
3322/// Drive one PIRLS Newton step on the workspace's CUDA stream against the
3323/// device-resident shared design matrix. The math is bit-identical to the
3324/// one-shot [`solve_pirls_step_gpu`]; this entry differs only by
3325/// amortising the design upload and the cuBLAS / cuSOLVER handle creation
3326/// across many sigma fits.
3327#[cfg(target_os = "linux")]
3328pub fn solve_pirls_step_on_stream(
3329    shared: &PirlsGpuSharedData,
3330    ws: &mut SigmaPirlsGpuWorkspace,
3331    input: PirlsStepStreamInput<'_>,
3332) -> Result<PirlsGpuStep, String> {
3333    cuda::solve_step_on_stream(shared, ws, input)
3334}
3335
3336/// Stage 3.2 device-input PIRLS step. Reads `w_solver` and `grad_eta`
3337/// from caller-supplied device buffers (typically populated by
3338/// [`crate::gpu_kernels::pirls_row::launch_row_reweight_on_stream`]) instead of
3339/// uploading them from host arrays. Math is bit-identical to
3340/// [`solve_pirls_step_on_stream`]; this entry differs only by skipping
3341/// the per-iter `weights` and `gradient` host-to-device transfers — only
3342/// the small p×p penalty matrix still crosses the host boundary.
3343#[cfg(target_os = "linux")]
3344pub fn solve_pirls_step_on_stream_device(
3345    shared: &PirlsGpuSharedData,
3346    ws: &mut SigmaPirlsGpuWorkspace,
3347    input: PirlsStepStreamDeviceInput<'_, '_>,
3348) -> Result<PirlsGpuStep, String> {
3349    cuda::solve_step_on_stream_device(shared, ws, input)
3350}
3351
3352/// Stage 3.3 device-resident PIRLS loop driver. See
3353/// [`cuda::pirls_loop`] for the full per-iter contract. Only a few
3354/// 1-f64 scalars cross the host boundary per Newton iteration; β and
3355/// the final penalised Hessian are downloaded once at loop exit.
3356///
3357/// `step_lm_lambda` is the Levenberg–Marquardt damping applied to each
3358/// Newton solve only; it never enters the exported `penalized_hessian`,
3359/// `RidgePassport`, EDF, or penalty term.  `objective_ridge` is the
3360/// real model ridge that enters all of those.
3361#[cfg(target_os = "linux")]
3362pub fn pirls_loop_on_stream(
3363    shared: &PirlsGpuSharedData,
3364    ws: &mut SigmaPirlsGpuWorkspace,
3365    loop_ws: &mut cuda::PirlsLoopWorkspace,
3366    family: crate::gpu_kernels::pirls_row::PirlsRowFamily,
3367    curvature: crate::gpu_kernels::pirls_row::CurvatureMode,
3368    // Active Gamma dispersion shape (α > 0). Pass `1.0` for non-Gamma fits.
3369    gamma_shape: f64,
3370    beta0: ndarray::ArrayView1<'_, f64>,
3371    penalty_hessian: ndarray::ArrayView2<'_, f64>,
3372    // Linear shift `b` for the shifted-quadratic penalty `βᵀSβ−2βᵀb+c`.
3373    // Pass a zero-length or all-zero slice for fits with no prior-mean shift.
3374    linear_shift: ndarray::ArrayView1<'_, f64>,
3375    // Constant shift `c` for the shifted-quadratic penalty. Pass `0.0` when absent.
3376    constant_shift: f64,
3377    step_lm_lambda: f64,
3378    objective_ridge: f64,
3379    max_iter: usize,
3380    tol: f64,
3381    extra: Option<&cuda::PirlsLoopExtra<'_>>,
3382) -> Result<cuda::PirlsLoopOutcome, String> {
3383    cuda::pirls_loop(
3384        shared,
3385        ws,
3386        loop_ws,
3387        family,
3388        curvature,
3389        gamma_shape,
3390        beta0,
3391        penalty_hessian,
3392        linear_shift,
3393        constant_shift,
3394        step_lm_lambda,
3395        objective_ridge,
3396        max_iter,
3397        tol,
3398        extra,
3399    )
3400}
3401
3402/// Allocate a Stage 3.3 PIRLS loop workspace bound to the same stream
3403/// as `ws` against the shared device-resident design matrix.
3404#[cfg(target_os = "linux")]
3405pub fn allocate_pirls_loop_workspace(
3406    shared: &PirlsGpuSharedData,
3407    ws: &SigmaPirlsGpuWorkspace,
3408) -> Result<cuda::PirlsLoopWorkspace, String> {
3409    cuda::PirlsLoopWorkspace::allocate(shared, &ws.stream)
3410}
3411
3412/// GPU exact penalised least-squares for Gaussian-identity models.
3413///
3414/// Public wrapper around [`cuda::solve_gaussian_pls_on_stream`].  Delegates
3415/// immediately if the CUDA runtime is initialised; returns an error otherwise
3416/// so the caller can fall back to the CPU path.
3417#[cfg(target_os = "linux")]
3418pub fn solve_gaussian_pls_gpu(
3419    a_orig: ndarray::ArrayView2<'_, f64>,
3420    b_orig: ndarray::ArrayView1<'_, f64>,
3421    s_transformed: ndarray::ArrayView2<'_, f64>,
3422    linear_shift: ndarray::ArrayView1<'_, f64>,
3423    prior_mean_target: ndarray::ArrayView1<'_, f64>,
3424    ridge: f64,
3425    qs: Option<ndarray::ArrayView2<'_, f64>>,
3426) -> Result<cuda::GaussianPlsResult, String> {
3427    cuda::solve_gaussian_pls_on_stream(
3428        a_orig,
3429        b_orig,
3430        s_transformed,
3431        linear_shift,
3432        prior_mean_target,
3433        ridge,
3434        qs,
3435    )
3436}
3437
3438/// CPU fallback for the PIRLS-step GPU primitives.  When this build has no
3439/// CUDA runtime probed, the GPU entry points must still return numerically
3440/// correct results so that callers can route a single code path through
3441/// `*_gpu` while the canonical policy layer in `crate::gpu` records whether
3442/// device execution was selected. Returning `Err` here would silently force
3443/// every caller to grow an `if cuda { .. } else { .. }` branch and risk
3444/// drifting away from the GPU formula.
3445mod cpu_fallback {
3446    use super::{PirlsGpuInput, PirlsGpuStep};
3447    use gam_linalg::faer_ndarray::FaerCholesky;
3448    use crate::estimate::reml::assembly::xt_diag_x_dense_into;
3449    use faer::Side;
3450    use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
3451
3452    pub(super) fn weighted_crossprod_cpu(
3453        x: ArrayView2<'_, f64>,
3454        weights: ArrayView1<'_, f64>,
3455    ) -> Result<Array2<f64>, String> {
3456        validate(x, weights)?;
3457        let x_owned = x.to_owned();
3458        let w_owned = weights.to_owned();
3459        let mut scratch = Array2::<f64>::zeros(x_owned.dim());
3460        Ok(xt_diag_x_dense_into(&x_owned, &w_owned, &mut scratch))
3461    }
3462
3463    pub(super) fn solve_step_cpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
3464        validate(input.x, input.weights)?;
3465        let (_n, p) = input.x.dim();
3466        if input.penalty_hessian.dim() != (p, p) {
3467            return Err(format!(
3468                "penalty Hessian shape {:?} does not match p={p}",
3469                input.penalty_hessian.dim()
3470            ));
3471        }
3472        if input.gradient.len() != p {
3473            return Err(format!(
3474                "gradient length {} does not match p={p}",
3475                input.gradient.len()
3476            ));
3477        }
3478        let xtwx = weighted_crossprod_cpu(input.x, input.weights)?;
3479        // Exported H_final = XᵀWX + S + objective_ridge·I.
3480        let mut penalized_hessian = xtwx.clone();
3481        penalized_hessian += &input.penalty_hessian;
3482        if input.objective_ridge != 0.0 {
3483            for i in 0..p {
3484                penalized_hessian[[i, i]] += input.objective_ridge;
3485            }
3486        }
3487        // H_step = XᵀWX + S + step_lm_lambda·I for the Newton solve only.
3488        let mut h_step = xtwx;
3489        h_step += &input.penalty_hessian;
3490        if input.step_lm_lambda != 0.0 {
3491            for i in 0..p {
3492                h_step[[i, i]] += input.step_lm_lambda;
3493            }
3494        }
3495        let factor = h_step
3496            .cholesky(Side::Lower)
3497            .map_err(|e| format!("CPU Cholesky failed in PIRLS fallback: {e:?}"))?;
3498        let g = Array1::from_iter(input.gradient.iter().copied());
3499        // No negation: `input.gradient` is the full descent-direction RHS
3500        // `Xᵀscore − S·β + linear_shift`; solving H·δ = rhs gives δ directly (#257).
3501        let direction = factor.solvevec(&g);
3502        // Logdet comes from H_step's Cholesky (the actual factored matrix).
3503        let logdet = 2.0 * factor.diag().iter().map(|v| v.ln()).sum::<f64>();
3504        Ok(PirlsGpuStep {
3505            penalized_hessian,
3506            direction,
3507            logdet,
3508        })
3509    }
3510
3511    fn validate(x: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>) -> Result<(), String> {
3512        let (n, p) = x.dim();
3513        if weights.len() != n {
3514            return Err(format!(
3515                "weights length {} does not match rows {n}",
3516                weights.len()
3517            ));
3518        }
3519        if n == 0 || p == 0 {
3520            return Err("empty design cannot be solved".to_string());
3521        }
3522        Ok(())
3523    }
3524}
3525
3526pub fn cholesky_solve_gpu(
3527    hessian: ArrayView2<'_, f64>,
3528    rhs: ArrayView2<'_, f64>,
3529) -> Result<(Array2<f64>, f64), String> {
3530    gam_gpu::solver::cholesky_solve_gpu(hessian, rhs)
3531}
3532
3533/// Solution-only mixed-precision solve (logdet discarded). Skips the redundant
3534/// fp64 POTRF so the PIRLS Newton direction solve gets the full fp32-factor
3535/// speedup; the solution is fp64-accurate via iterative refinement.
3536pub fn cholesky_solve_only_gpu(
3537    hessian: ArrayView2<'_, f64>,
3538    rhs: ArrayView2<'_, f64>,
3539) -> Result<Array2<f64>, String> {
3540    gam_gpu::solver::cholesky_solve_only_gpu(hessian, rhs)
3541}
3542
3543pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
3544    gam_gpu::solver::cholesky_lower_gpu(hessian)
3545}
3546
3547/// Stage 3.2 V100 parity: the device-input PIRLS step must produce
3548/// numerically identical `(H, direction, logdet)` triples to the
3549/// host-input form when fed the same weights + gradient. This is the
3550/// production caller that satisfies the dead-pub scanner for
3551/// `solve_pirls_step_on_stream_device` and `PirlsStepStreamDeviceInput`.
3552#[cfg(all(test, target_os = "linux"))]
3553mod stream_device_parity_tests {
3554    use super::*;
3555    use ndarray::arr2;
3556
3557    #[test]
3558    fn device_input_step_matches_host_input_step_on_v100() {
3559        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3560            eprintln!("[stream_device_parity] no CUDA runtime — skipping");
3561            return;
3562        }
3563        let x = arr2(&[
3564            [1.0, 0.5, 0.1],
3565            [0.2, -0.3, 1.4],
3566            [0.7, 1.1, -0.2],
3567            [-0.4, 0.9, 0.6],
3568            [0.3, -0.8, 0.5],
3569        ]);
3570        let weights = ndarray::arr1(&[1.0, 0.8, 1.2, 0.9, 1.05]);
3571        // Pick g_eta directly (length n) and derive the equivalent
3572        // host-side gradient via the same Xᵀ projection the
3573        // device-input form does on the GPU.
3574        let g_eta = ndarray::arr1(&[0.10_f64, -0.20, 0.05, 0.30, -0.15]);
3575        let gradient: ndarray::Array1<f64> = x.t().dot(&g_eta);
3576        let penalty = arr2(&[[0.4, 0.0, 0.0], [0.0, 0.9, 0.0], [0.0, 0.0, 1.2]]);
3577        let lm_ridge = 0.1;
3578
3579        let n = x.nrows();
3580        let y_dummy = ndarray::Array1::<f64>::zeros(n);
3581        let prior_w_dummy = ndarray::Array1::<f64>::ones(n);
3582        let offset_dummy = ndarray::Array1::<f64>::zeros(n);
3583        let shared = upload_shared_pirls_gpu(
3584            x.view(),
3585            y_dummy.view(),
3586            prior_w_dummy.view(),
3587            offset_dummy.view(),
3588        )
3589        .expect("upload shared design");
3590        let mut ws_host = allocate_sigma_pirls_workspace(&shared).expect("alloc host-input ws");
3591        let mut ws_dev = allocate_sigma_pirls_workspace(&shared).expect("alloc device-input ws");
3592
3593        let host_step = solve_pirls_step_on_stream(
3594            &shared,
3595            &mut ws_host,
3596            PirlsStepStreamInput {
3597                weights: weights.view(),
3598                penalty_hessian: penalty.view(),
3599                gradient: gradient.view(),
3600                step_lm_lambda: lm_ridge,
3601                objective_ridge: 0.0,
3602            },
3603        )
3604        .expect("host-input step");
3605
3606        let mut w_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc w_dev");
3607        let mut g_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc g_dev");
3608        ws_dev
3609            .stream
3610            .memcpy_htod(weights.as_slice().unwrap(), &mut w_dev)
3611            .expect("upload w_dev");
3612        ws_dev
3613            .stream
3614            .memcpy_htod(g_eta.as_slice().unwrap(), &mut g_dev)
3615            .expect("upload g_dev");
3616
3617        let beta_dev_test = ws_dev
3618            .stream
3619            .alloc_zeros::<f64>(x.ncols())
3620            .expect("alloc beta_dev_test");
3621        let linear_shift_test = ndarray::Array1::<f64>::zeros(x.ncols());
3622        let dev_step = solve_pirls_step_on_stream_device(
3623            &shared,
3624            &mut ws_dev,
3625            PirlsStepStreamDeviceInput {
3626                w_solver_dev: &w_dev,
3627                grad_eta_dev: &g_dev,
3628                penalty_hessian: penalty.view(),
3629                step_lm_lambda: lm_ridge,
3630                objective_ridge: 0.0,
3631                beta_dev: &beta_dev_test,
3632                linear_shift: linear_shift_test.view(),
3633            },
3634        )
3635        .expect("device-input step");
3636
3637        // H + logdet must match to round-off (same XᵀWX, same penalty
3638        // add, same potrf).
3639        for i in 0..3 {
3640            for j in 0..3 {
3641                let diff = (host_step.penalized_hessian[[i, j]]
3642                    - dev_step.penalized_hessian[[i, j]])
3643                .abs();
3644                assert!(diff <= 1e-10, "H[{i},{j}] mismatch: {diff}");
3645            }
3646        }
3647        assert!(
3648            (host_step.logdet - dev_step.logdet).abs() <= 1e-9,
3649            "logdet mismatch: host={} dev={}",
3650            host_step.logdet,
3651            dev_step.logdet
3652        );
3653        // Direction must match because Xᵀ·g_eta = (Xᵀ·X)·α = host
3654        // gradient by construction.
3655        for i in 0..3 {
3656            let diff = (host_step.direction[i] - dev_step.direction[i]).abs();
3657            assert!(diff <= 1e-9, "direction[{i}] mismatch: {diff}");
3658        }
3659    }
3660
3661    /// V100 hill-climb gate: at large-scale (n=80k, p=44,
3662    /// BernoulliLogit/Fisher) the device-resident loop must be ≥10×
3663    /// faster than the CPU reference. Marked `#[ignore]` so it only
3664    /// runs when explicitly invoked (`cargo test -- --ignored
3665    /// hill_climb_loop`); the CI/mac path can't host the GPU work
3666    /// anyway. Uses CPU `row_reweight_cpu` + faer Cholesky as the
3667    /// PIRLS reference loop to avoid dragging in `solver::pirls`'s
3668    /// 13k-line state machine.
3669    #[test]
3670    fn hill_climb_loop_beats_cpu_10x_on_large_scale_logit() {
3671        use crate::gpu_kernels::pirls_row::{
3672            CurvatureMode, PirlsRowFamily, RowInput, row_reweight_cpu,
3673        };
3674        use std::time::Instant;
3675        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3676            eprintln!("[hill_climb] no CUDA runtime — skipping");
3677            return;
3678        }
3679        let n = 80_000_usize;
3680        let p = 44_usize;
3681        // Synthesise X (col-major dense) and y from a known β.
3682        let beta_true: ndarray::Array1<f64> = ndarray::Array1::from_iter(
3683            (0..p).map(|j| 0.05 * ((j as f64) - 0.5 * p as f64) / p as f64),
3684        );
3685        let mut x = ndarray::Array2::<f64>::zeros((n, p));
3686        for i in 0..n {
3687            for j in 0..p {
3688                x[[i, j]] = ((i as f64 + j as f64 * 17.0) * 0.001).sin();
3689            }
3690        }
3691        let eta: ndarray::Array1<f64> = x.dot(&beta_true);
3692        let y: ndarray::Array1<f64> = eta
3693            .iter()
3694            .enumerate()
3695            .map(|(i, &e)| {
3696                let mu = 0.5 * (1.0 + (0.5 * e).tanh());
3697                if (i as f64 * 1.31).fract() < mu {
3698                    1.0
3699                } else {
3700                    0.0
3701                }
3702            })
3703            .collect();
3704        let prior_w = ndarray::Array1::<f64>::ones(n);
3705        let penalty = ndarray::Array2::<f64>::eye(p) * 1e-3;
3706        let beta0 = ndarray::Array1::<f64>::zeros(p);
3707
3708        // GPU timing.
3709        let offset_bench = ndarray::Array1::<f64>::zeros(n);
3710        let shared =
3711            upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_bench.view())
3712                .expect("upload shared design");
3713        let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
3714        let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
3715        let t0 = Instant::now();
3716        // No prior-mean shift in this benchmark — penalty = ½βᵀSβ
3717        // with `s_transformed = penalty`, `linear_shift = 0`,
3718        // `constant_shift = 0`.
3719        let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
3720        drop(
3721            pirls_loop_on_stream(
3722                &shared,
3723                &mut ws,
3724                &mut loop_ws,
3725                PirlsRowFamily::BernoulliLogit,
3726                CurvatureMode::Fisher,
3727                1.0,
3728                beta0.view(),
3729                penalty.view(),
3730                linear_shift_zero.view(),
3731                0.0,
3732                0.0,
3733                0.0,
3734                30,
3735                1e-6,
3736                None,
3737            )
3738            .expect("pirls loop"),
3739        );
3740        let gpu_secs = t0.elapsed().as_secs_f64();
3741
3742        // CPU reference: same PIRLS structure (eta = Xβ; row reweight;
3743        // XᵀWX + Sλ; faer Cholesky; β update with α=1).
3744        let t1 = Instant::now();
3745        let mut beta = ndarray::Array1::<f64>::zeros(p);
3746        for _ in 0..30 {
3747            let eta: ndarray::Array1<f64> = x.dot(&beta);
3748            let mut w = ndarray::Array1::<f64>::zeros(n);
3749            let mut g = ndarray::Array1::<f64>::zeros(n);
3750            for i in 0..n {
3751                let out = row_reweight_cpu(
3752                    PirlsRowFamily::BernoulliLogit,
3753                    CurvatureMode::Fisher,
3754                    RowInput {
3755                        eta: eta[i],
3756                        y: y[i],
3757                        prior_weight: prior_w[i],
3758                    },
3759                    1.0,
3760                );
3761                w[i] = out.w_solver;
3762                g[i] = out.grad_eta;
3763            }
3764            let mut wx_full = x.clone();
3765            for j in 0..p {
3766                for i in 0..n {
3767                    wx_full[[i, j]] *= w[i];
3768                }
3769            }
3770            let h = x.t().dot(&wx_full) + &penalty;
3771            let rhs = x.t().dot(&g);
3772            use gam_linalg::faer_ndarray::FaerCholesky;
3773            let chol = h
3774                .cholesky(faer::Side::Lower)
3775                .expect("CPU PIRLS reference Cholesky");
3776            let d = chol.solvevec(&rhs);
3777            for i in 0..p {
3778                beta[i] -= d[i];
3779            }
3780        }
3781        let cpu_secs = t1.elapsed().as_secs_f64();
3782
3783        let speedup = cpu_secs / gpu_secs;
3784        eprintln!(
3785            "[hill_climb] n={n} p={p} BernoulliLogit/Fisher: gpu={:.3}s cpu={:.3}s speedup={:.2}×",
3786            gpu_secs, cpu_secs, speedup
3787        );
3788        assert!(
3789            speedup >= 10.0,
3790            "GPU PIRLS loop must be ≥10× CPU at large-scale shape; got speedup={speedup:.2}× (gpu={gpu_secs:.3}s cpu={cpu_secs:.3}s)"
3791        );
3792    }
3793
3794    /// Stage 3.3 production caller: end-to-end GPU PIRLS loop on a
3795    /// Gaussian-identity fit reaches OLS β to high precision in a
3796    /// handful of iterations and matches the closed-form
3797    /// `(XᵀX + Sλ)⁻¹·Xᵀy` solution.
3798    #[test]
3799    fn pirls_loop_converges_to_ols_solution_on_gaussian_identity() {
3800        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
3801            eprintln!("[stage_3_3] no CUDA runtime — skipping");
3802            return;
3803        }
3804        let x = arr2(&[
3805            [1.0, 0.5, 0.1],
3806            [0.2, -0.3, 1.4],
3807            [0.7, 1.1, -0.2],
3808            [-0.4, 0.9, 0.6],
3809            [0.3, -0.8, 0.5],
3810            [1.1, 0.2, -0.4],
3811            [-0.6, 0.4, 0.3],
3812            [0.8, -1.0, 0.7],
3813        ]);
3814        let n = x.nrows();
3815        let p = x.ncols();
3816        // y = X·β_true + small wiggle (still in identity link space).
3817        let beta_true = ndarray::arr1(&[0.5_f64, -1.2, 0.3]);
3818        let y: ndarray::Array1<f64> = x.dot(&beta_true);
3819        let prior_w = ndarray::Array1::<f64>::ones(n);
3820        let penalty = ndarray::Array2::<f64>::eye(p) * 1e-4; // tiny ridge
3821        let beta0 = ndarray::Array1::<f64>::zeros(p);
3822
3823        let offset_ols = ndarray::Array1::<f64>::zeros(n);
3824        let shared = upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_ols.view())
3825            .expect("upload shared design");
3826        let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
3827        let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
3828
3829        // No prior-mean shift in this OLS test — `linear_shift = 0`,
3830        // `constant_shift = 0`. `y` / `prior_w` are now uploaded via
3831        // the shared workspace (#258).
3832        let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
3833        let outcome = pirls_loop_on_stream(
3834            &shared,
3835            &mut ws,
3836            &mut loop_ws,
3837            crate::gpu_kernels::pirls_row::PirlsRowFamily::GaussianIdentity,
3838            crate::gpu_kernels::pirls_row::CurvatureMode::Fisher,
3839            1.0,
3840            beta0.view(),
3841            penalty.view(),
3842            linear_shift_zero.view(),
3843            0.0,
3844            0.0,
3845            0.0,
3846            20,
3847            1e-9,
3848            None,
3849        )
3850        .expect("pirls loop");
3851
3852        // Closed-form OLS (with tiny ridge).
3853        let xtx = x.t().dot(&x);
3854        let xty = x.t().dot(&y);
3855        let h_ref = xtx + &penalty;
3856        // Solve via the crate's faer/ndarray bridge.
3857        use gam_linalg::faer_ndarray::FaerCholesky;
3858        let chol = h_ref
3859            .cholesky(faer::Side::Lower)
3860            .expect("OLS reference Cholesky");
3861        let beta_ref: ndarray::Array1<f64> = chol.solvevec(&xty);
3862
3863        // Gaussian-identity PIRLS converges in one Newton iter (linear
3864        // problem); the loop may take a few iters because the line
3865        // search starts at α=1 and the first step is exact. Allow up
3866        // to 5 iters but assert convergence and 1e-6 abs precision.
3867        assert!(
3868            outcome.converged || outcome.iterations <= 5,
3869            "PIRLS loop did not converge in 20 iters on Gaussian-identity (iters={})",
3870            outcome.iterations
3871        );
3872        for i in 0..p {
3873            let diff = (outcome.beta[i] - beta_ref[i]).abs();
3874            assert!(
3875                diff <= 1e-6,
3876                "β[{i}] mismatch: gpu={} ref={} diff={}",
3877                outcome.beta[i],
3878                beta_ref[i],
3879                diff
3880            );
3881        }
3882        // Also check H matches XᵀX + Sλ (no W weighting since identity-link
3883        // canonical-weight = 1 for Gaussian).
3884        for i in 0..p {
3885            for j in 0..p {
3886                let diff = (outcome.penalized_hessian[[i, j]] - h_ref[[i, j]]).abs();
3887                assert!(diff <= 1e-8, "H[{i},{j}] mismatch: {diff}");
3888            }
3889        }
3890    }
3891}
3892
3893/// CPU-fallback contract for the weighted-crossprod GPU dispatcher.
3894///
3895/// `weighted_crossprod_gpu` moved here from `gam-gpu` during the #1521 crate
3896/// carve. On a host with no usable CUDA runtime it must transparently fall back
3897/// to the dense CPU path, return `Ok`, and produce the exact XᵀWX. This guards
3898/// the panic-free / Ok-via-CPU-fallback contract previously (loosely) checked in
3899/// gam-gpu's `cpu_only_host_never_panics_on_gpu_entry_points`, which could no
3900/// longer reach the function after the carve.
3901#[cfg(test)]
3902mod weighted_crossprod_cpu_fallback_tests {
3903    use super::weighted_crossprod_gpu;
3904    use ndarray::{Array1, Array2};
3905
3906    #[test]
3907    fn weighted_crossprod_gpu_cpu_fallback_matches_dense_xtwx() {
3908        // Small, below any GPU dispatch threshold → exercises the CPU fallback
3909        // on a CPU-only host (and stays Ok on a GPU host via the same contract).
3910        let x = Array2::<f64>::from_shape_fn((4, 3), |(i, j)| (i + j) as f64 + 1.0);
3911        let w = Array1::<f64>::from_vec(vec![0.5, 1.0, 1.5, 2.0]);
3912
3913        let got = weighted_crossprod_gpu(x.view(), w.view())
3914            .expect("weighted_crossprod_gpu must return Ok via CPU fallback on a CPU-only host");
3915
3916        // Reference XᵀWX = Σ_k w_k x_k x_kᵀ, formed directly.
3917        let (n, p) = x.dim();
3918        let mut expected = Array2::<f64>::zeros((p, p));
3919        for k in 0..n {
3920            for i in 0..p {
3921                for j in 0..p {
3922                    expected[[i, j]] += w[k] * x[[k, i]] * x[[k, j]];
3923                }
3924            }
3925        }
3926
3927        assert_eq!(got.dim(), (p, p));
3928        for i in 0..p {
3929            for j in 0..p {
3930                let diff = (got[[i, j]] - expected[[i, j]]).abs();
3931                assert!(diff <= 1e-10, "XtWX[{i},{j}] mismatch: got vs expected diff={diff}");
3932            }
3933        }
3934    }
3935}