Skip to main content

gam_gpu/
linalg_dispatch.rs

//! Automatic GPU dispatch shim for dense linear-algebra hot kernels.
//!
//! The `try_*` entry points in this module are invoked from
//! `gam_linalg::faer_ndarray` before the CPU fast path runs. The decision to
//! send a kernel to a device is fully automatic and never requires a
//! user-facing flag — it depends only on:
//!
//!   1. `GpuRuntime::global()` returning `Some(_)` (a device was probed at
//!      process startup).
//!   2. The kernel being large enough to amortize launch/PCIe overhead, per
//!      the thresholds in `policy::GpuDispatchPolicy`.
//!   3. cudarc successfully loading `libcuda` dynamically at process startup
//!      via its `fallback-dynamic-loading` feature. When the loader fails
//!      (no driver or toolkit installed), `GpuRuntime::probe()` returns
//!      `Ok(None)` and every `try_*` returns `None`, so the caller falls
//!      through to the existing faer CPU kernel.
//!
//! The wiring lives here so `solver/pirls.rs` and the family Hessian
//! assemblers can stay backend-agnostic: they call
//! `gam_linalg::faer_ndarray::fast_*` and get GPU acceleration automatically
//! whenever it is profitable.
21
22use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
23
24use super::device_runtime::GpuRuntime;
25
26pub struct CudaGemmDispatch;
27
28impl gam_linalg::gpu_hook::GpuGemmDispatch for CudaGemmDispatch {
29    fn try_fast_atb(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
30        try_fast_atb(a, b)
31    }
32
33    fn try_fast_ab(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
34        try_fast_ab(a, b)
35    }
36
37    fn try_fast_av(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
38        try_fast_av(a, v)
39    }
40
41    fn try_fast_atv(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
42        try_fast_atv(a, v)
43    }
44
45    fn try_fast_xt_diag_x(
46        &self,
47        x: ArrayView2<'_, f64>,
48        w: ArrayView1<'_, f64>,
49    ) -> Option<Array2<f64>> {
50        try_fast_xt_diag_x(x, w)
51    }
52
53    fn try_fast_xt_diag_y(
54        &self,
55        x: ArrayView2<'_, f64>,
56        w: ArrayView1<'_, f64>,
57        y: ArrayView2<'_, f64>,
58    ) -> Option<Array2<f64>> {
59        try_fast_xt_diag_y(x, w, y)
60    }
61
62    fn try_fast_joint_hessian_2x2(
63        &self,
64        x_a: ArrayView2<'_, f64>,
65        x_b: ArrayView2<'_, f64>,
66        w_aa: ArrayView1<'_, f64>,
67        w_ab: ArrayView1<'_, f64>,
68        w_bb: ArrayView1<'_, f64>,
69    ) -> Option<Array2<f64>> {
70        try_fast_joint_hessian_2x2(x_a, x_b, w_aa, w_ab, w_bb)
71    }
72
73    fn device_count(&self) -> usize {
74        GpuRuntime::global().map_or(0, |rt| rt.device_count())
75    }
76
77    fn try_fast_ab_broadcast_b_batched(
78        &self,
79        a3: ArrayView3<'_, f64>,
80        b: ArrayView2<'_, f64>,
81    ) -> Option<Array3<f64>> {
82        try_fast_ab_broadcast_b_batched(a3, b)
83    }
84}
85
/// Workload descriptor handed to [`route_through_gpu`], which picks the
/// matching size gate from [`super::policy::GpuDispatchPolicy`].
#[derive(Clone, Copy, Debug)]
pub enum DispatchOp {
    /// Dense matrix product producing an `m×n` output with reduction depth `k`.
    Gemm { m: usize, n: usize, k: usize },
    /// `batch` independent `m×n×k` matrix products.
    BatchedGemm {
        batch: usize,
        m: usize,
        n: usize,
        k: usize,
    },
    /// Dense Cholesky factorization of `batch` matrices of width `p`.
    Potrf { p: usize, batch: usize },
    /// Many small Cholesky factorizations sharing one small width `p` (≲ 32).
    /// Intended for `cusolverDnDpotrfBatched`, with the factors kept resident
    /// on the device for the triangular solves that follow (Arrow-Schur,
    /// Stage-3 PIRLS).
    SmallDenseBatchedPotrf { p: usize, batch: usize },
    /// Triangular solve against an `m×m` triangle with `n` right-hand sides.
    Trsm { m: usize, n: usize },
    /// Matrix–vector (single-column) product of an `m×k` matrix.
    Gemv { m: usize, k: usize },
    /// `Xᵀ · diag(w) · X` over `n` rows and `p` columns.
    XtDiagX { n: usize, p: usize },
    /// `Xᵀ · diag(w) · Y` with design width `px` and response width `q`.
    XtDiagY { n: usize, px: usize, q: usize },
    /// Two-block joint Hessian over `n` rows with design widths `pa` and `pb`.
    JointHessian2x2 { n: usize, pa: usize, pb: usize },
}

impl DispatchOp {
    /// Conservative floating-point-operation estimate for this workload, used
    /// by the generic `gemm_min_flops` admission gate.
    ///
    /// All arithmetic happens in `u128` so realistic shapes cannot overflow.
    #[inline]
    pub const fn flops(self) -> u128 {
        const fn wide(x: usize) -> u128 {
            x as u128
        }
        match self {
            Self::Gemm { m, n, k } => 2 * wide(m) * wide(n) * wide(k),
            Self::BatchedGemm { batch, m, n, k } => 2 * wide(batch) * wide(m) * wide(n) * wide(k),
            Self::Gemv { m, k } => 2 * wide(m) * wide(k),
            // Cholesky costs ~p³/3 per factor, dense or small-batched alike.
            Self::Potrf { p, batch } | Self::SmallDenseBatchedPotrf { p, batch } => {
                wide(batch) * wide(p).pow(3) / 3
            }
            Self::Trsm { m, n } => wide(m) * wide(m) * wide(n),
            Self::XtDiagX { n, p } => 2 * wide(n) * wide(p) * wide(p),
            Self::XtDiagY { n, px, q } => 2 * wide(n) * wide(px) * wide(q),
            Self::JointHessian2x2 { n, pa, pb } => {
                let span = wide(pa) + wide(pb);
                2 * wide(n) * span * span
            }
        }
    }
}
140
141/// Returns `Some(runtime)` when both a device is available and the workload
142/// is large enough per policy. The caller can then attempt the actual device
143/// kernel; any backend failure is expected to return `None` from the lower
144/// layer and the CPU fast path resumes.
145#[inline]
146#[must_use]
147pub fn route_through_gpu(op: DispatchOp) -> Option<&'static GpuRuntime> {
148    let runtime = GpuRuntime::global()?;
149    let policy = &runtime.policy;
150    let admit = match op {
151        DispatchOp::Gemm { m, n, k } => {
152            op.flops() >= (policy.gemm_min_flops as u128) && m.min(n).min(k) > 0
153        }
154        DispatchOp::BatchedGemm { batch, m, n, k } => {
155            op.flops() >= (policy.gemm_min_flops as u128) && batch > 1 && m.min(n).min(k) > 0
156        }
157        DispatchOp::Gemv { m, k } => {
158            op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && k > 0
159        }
160        DispatchOp::Potrf { p, batch } => {
161            p > 0
162                && batch > 0
163                && (p >= policy.potrf_min_p
164                    || (batch > 1 && op.flops() >= policy.gemm_min_flops as u128))
165        }
166        DispatchOp::SmallDenseBatchedPotrf { p, batch } => {
167            p > 0
168                && p <= policy.small_dense_batched_potrf_max_p
169                && batch >= policy.small_dense_batched_potrf_min_batch
170        }
171        DispatchOp::Trsm { m, n } => {
172            op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && n > 0
173        }
174        DispatchOp::XtDiagX { n, p } => policy.xtwx_target_is_gpu(n, p, true),
175        DispatchOp::XtDiagY { n, px, q } => policy.xtwy_target_is_gpu(n, px, q, true),
176        DispatchOp::JointHessian2x2 { n, pa, pb } => {
177            n > 0 && (pa + pb) > 0 && op.flops() >= policy.gemm_min_flops as u128
178        }
179    };
180    if admit { Some(runtime) } else { None }
181}
182
183/// Minimum batch size before a batched kernel is worth splitting across more
184/// than one device. Below this the per-tile launch + extra H2D/D2H staging on a
185/// second device costs more than the GEMM time it saves, so a small batch stays
186/// on the single primary device. This is a fixed, conservatively-large constant
187/// (magic-by-default; no flag) — multi-GPU only kicks in for genuinely large
188/// batches such as large-scale Arrow-Schur / Stage-3 blocks.
189#[cfg(target_os = "linux")]
190const MULTI_GPU_BATCH_FLOOR: usize = 64;
191
192/// True when the pool has >1 usable device and `batch` is large enough that
193/// splitting the batch dimension across devices is worthwhile.
194#[cfg(target_os = "linux")]
195#[inline]
196fn should_split_batch(batch: usize) -> bool {
197    GpuRuntime::global().is_some_and(|rt| rt.device_count() > 1) && batch >= MULTI_GPU_BATCH_FLOOR
198}
199
200#[inline]
201#[must_use]
202pub fn try_fast_ab_broadcast_b_batched(
203    a: ArrayView3<'_, f64>,
204    b: ArrayView2<'_, f64>,
205) -> Option<Array3<f64>> {
206    let (batch, m, k) = a.dim();
207    let (bk, n) = b.dim();
208    if k != bk || batch == 0 || m == 0 || n == 0 {
209        return None;
210    }
211    #[cfg(not(target_os = "linux"))]
212    {
213        return None;
214    }
215    #[cfg(target_os = "linux")]
216    {
217        let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
218        if should_split_batch(batch) {
219            if let Some(out) = scatter_broadcast_b_batched(runtime, a, b, m, n) {
220                return Some(out);
221            }
222            // A multi-GPU tile failed; fall through to the single-device path so
223            // the whole batch is still produced on the primary device.
224        }
225        cuda_backend::gemm_broadcast_b_batched(runtime.device.ordinal, a, b)
226    }
227}
228
229/// Multi-GPU broadcast-B batched GEMM: split the batch dimension across all
230/// devices via [`scatter_batched`], running one cuBLAS strided-batched GEMM per
231/// device tile (each on its own bound ordinal). `b` is shared (broadcast) across
232/// every tile. Returns `None` if any tile fails so the caller falls back to the
233/// single-device path.
234#[cfg(target_os = "linux")]
235fn scatter_broadcast_b_batched(
236    runtime: &GpuRuntime,
237    a: ArrayView3<'_, f64>,
238    b: ArrayView2<'_, f64>,
239    m: usize,
240    n: usize,
241) -> Option<Array3<f64>> {
242    let batch = a.dim().0;
243    // One slot per batch item; the slot carries its own input matrix so the
244    // per-tile closure is range-agnostic and owns disjoint memory.
245    let mut items: Vec<(Array2<f64>, Option<Array2<f64>>)> = (0..batch)
246        .map(|i| (a.index_axis(ndarray::Axis(0), i).to_owned(), None))
247        .collect();
248    super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
249        let tile_batch = tile.len();
250        if tile_batch == 0 {
251            return Some(());
252        }
253        let k = b.dim().0;
254        let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
255        for (idx, (a_i, _)) in tile.iter().enumerate() {
256            a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
257        }
258        let out = cuda_backend::gemm_broadcast_b_batched(ordinal, a_tile.view(), b)?;
259        for (idx, (_, slot)) in tile.iter_mut().enumerate() {
260            *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
261        }
262        Some(())
263    })?;
264    stitch_batched(items, m, n)
265}
266
267#[inline]
268#[must_use]
269pub fn try_fast_abt_strided_batched(
270    a: ArrayView3<'_, f64>,
271    b: ArrayView3<'_, f64>,
272) -> Option<Array3<f64>> {
273    let (batch, m, k) = a.dim();
274    let (batch_b, n, k_b) = b.dim();
275    if batch != batch_b || k != k_b || batch == 0 || m == 0 || n == 0 {
276        return None;
277    }
278    #[cfg(not(target_os = "linux"))]
279    {
280        return None;
281    }
282    #[cfg(target_os = "linux")]
283    {
284        let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
285        if should_split_batch(batch) {
286            if let Some(out) = scatter_abt_strided_batched(runtime, a, b, m, n) {
287                return Some(out);
288            }
289        }
290        cuda_backend::gemm_abt_strided_batched(runtime.device.ordinal, a, b)
291    }
292}
293
294/// Multi-GPU A·Bᵀ strided-batched GEMM: split the batch dimension across all
295/// devices, running one strided-batched GEMM per device tile. Both `a` and `b`
296/// are batched (one matrix per batch item), so each slot carries its own
297/// `(a_i, b_i)` pair. Returns `None` on any tile failure.
298#[cfg(target_os = "linux")]
299fn scatter_abt_strided_batched(
300    runtime: &GpuRuntime,
301    a: ArrayView3<'_, f64>,
302    b: ArrayView3<'_, f64>,
303    m: usize,
304    n: usize,
305) -> Option<Array3<f64>> {
306    let batch = a.dim().0;
307    let mut items: Vec<(Array2<f64>, Array2<f64>, Option<Array2<f64>>)> = (0..batch)
308        .map(|i| {
309            (
310                a.index_axis(ndarray::Axis(0), i).to_owned(),
311                b.index_axis(ndarray::Axis(0), i).to_owned(),
312                None,
313            )
314        })
315        .collect();
316    super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
317        let tile_batch = tile.len();
318        if tile_batch == 0 {
319            return Some(());
320        }
321        let k = tile[0].0.dim().1;
322        let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
323        let mut b_tile = Array3::<f64>::zeros((tile_batch, n, k));
324        for (idx, (a_i, b_i, _)) in tile.iter().enumerate() {
325            a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
326            b_tile.index_axis_mut(ndarray::Axis(0), idx).assign(b_i);
327        }
328        let out = cuda_backend::gemm_abt_strided_batched(ordinal, a_tile.view(), b_tile.view())?;
329        for (idx, (_, _, slot)) in tile.iter_mut().enumerate() {
330            *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
331        }
332        Some(())
333    })?;
334    let slots: Vec<((), Option<Array2<f64>>)> =
335        items.into_iter().map(|(_, _, slot)| ((), slot)).collect();
336    stitch_batched(slots, m, n)
337}
338
339/// Reassemble per-batch output slots (filled by the device tiles) into a single
340/// `batch × m × n` array. Returns `None` if any slot is still empty (a tile
341/// silently skipped its item), which forces the single-device fallback.
342#[cfg(target_os = "linux")]
343fn stitch_batched<L>(
344    items: Vec<(L, Option<Array2<f64>>)>,
345    m: usize,
346    n: usize,
347) -> Option<Array3<f64>> {
348    let batch = items.len();
349    let mut out = Array3::<f64>::zeros((batch, m, n));
350    for (idx, (_, slot)) in items.into_iter().enumerate() {
351        let block = slot?;
352        if block.dim() != (m, n) {
353            return None;
354        }
355        out.index_axis_mut(ndarray::Axis(0), idx).assign(&block);
356    }
357    Some(out)
358}
359
// ---------------------------------------------------------------------------
// Dispatch entry points. Each takes views to keep the call site allocation-
// free and returns `Some(result)` iff the GPU actually produced one; on `None`
// the caller resumes its CPU fast path.
//
// CUDA is reached at runtime through cudarc's dynamic loader. Each entry point
// admits only profitable workloads and returns `None` when no CUDA runtime
// path is available or the backend reports failure.
// ---------------------------------------------------------------------------
369
370#[inline]
371#[must_use]
372pub fn try_fast_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
373    let (m, k) = a.dim();
374    let (kb, n) = b.dim();
375    if k != kb {
376        return None;
377    }
378    // Record every dispatch attempt — including ones that fall back to CPU
379    // because either the runtime is unavailable or the workload is below
380    // policy threshold. The diagnostics snapshot is what downstream telemetry
381    // uses to attribute CPU vs GPU time, so it must reflect *attempts*, not
382    // just successful device launches.
383    let runtime = route_through_gpu(DispatchOp::Gemm { m, n, k });
384    let used_gpu = runtime.is_some();
385    super::profile::record(super::profile::KernelStat {
386        name: "try_fast_ab",
387        n: m,
388        p: n,
389        k,
390        flops_est: (DispatchOp::Gemm { m, n, k }.flops().min(usize::MAX as u128)) as usize,
391        gpu_ms: if used_gpu { Some(0.0) } else { None },
392        ..Default::default()
393    });
394    #[cfg(not(target_os = "linux"))]
395    {
396        None
397    }
398    #[cfg(target_os = "linux")]
399    {
400        let runtime = runtime?;
401        cuda_backend::gemm(runtime, a, b, false, false)
402    }
403}
404
405#[inline]
406#[must_use]
407pub fn try_fast_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
408    let (n_a, p) = a.dim();
409    let (n_b, q) = b.dim();
410    if n_a != n_b || p == 0 || q == 0 {
411        return None;
412    }
413    #[cfg(not(target_os = "linux"))]
414    {
415        return None;
416    }
417    #[cfg(target_os = "linux")]
418    {
419        let runtime = route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
420        cuda_backend::gemm(runtime, a, b, true, false)
421    }
422}
423
424/// `Aᵀ·B` on a specific device ordinal, for pool-tiled callers that already own
425/// the ordinal (the worker thread has bound that ordinal's context). Semantics
426/// are identical to [`try_fast_atb`] — `a` is `m×k`, `b` is `m×n`, output is the
427/// `k×n` product `aᵀ·b` — but the kernel is pinned to `ordinal` instead of the
428/// probe-selected primary device. Returns `None` when CUDA is unavailable, the
429/// shape is below policy threshold, or the backend reports a transient failure,
430/// so the caller runs its CPU fallback. f64 only.
431#[inline]
432#[must_use]
433pub fn try_fast_atb_on_ordinal(
434    ordinal: usize,
435    a: ArrayView2<'_, f64>,
436    b: ArrayView2<'_, f64>,
437) -> Option<Array2<f64>> {
438    let (n_a, p) = a.dim();
439    let (n_b, q) = b.dim();
440    if n_a != n_b || p == 0 || q == 0 {
441        return None;
442    }
443    #[cfg(not(target_os = "linux"))]
444    {
445        // No CUDA off Linux, so the per-ordinal fast path is unavailable. Read
446        // `ordinal` once (the cross-platform signature must carry it for the
447        // Linux branch below) and decline so the caller runs its CPU AtB. Unlike
448        // `a`/`b` — already consumed by `.dim()` above — `ordinal` is otherwise
449        // untouched on this target, and `warnings = "deny"` rejects a dead bind.
450        log::trace!(
451            "try_fast_atb_on_ordinal: CUDA unavailable off Linux; declining ordinal {ordinal}"
452        );
453        return None;
454    }
455    #[cfg(target_os = "linux")]
456    {
457        // The size/policy gate is identical to `try_fast_atb`; only the target
458        // device differs. We still consult `route_through_gpu` so a below-floor
459        // shape declines to the caller's CPU path rather than paying PCIe cost.
460        //
461        // Arrow-Schur's `tile_schur_partial` reaches this gate after stacking
462        // its per-row factors into one transpose tile GEMM:
463        // `(total_d x k)^T * (total_d x k)`.
464        // At the SAE shape n=2000, p=2048, M=12, K=8, that is
465        // 2*(n*M)*p^2 = 201_326_592_000 flops for one stacked tile, or
466        // 1_610_612_736_000 flops across K=8 batches, so admission must be
467        // keyed on work rather than the observation row count.
468        route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
469        cuda_backend::gemm_on_ordinal(ordinal, a, b, true, false)
470    }
471}
472
473#[inline]
474#[must_use]
475pub fn try_fast_av(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
476    let (m, k) = a.dim();
477    if k != v.len() || m == 0 || k == 0 {
478        return None;
479    }
480    #[cfg(not(target_os = "linux"))]
481    {
482        return None;
483    }
484    #[cfg(target_os = "linux")]
485    {
486        let runtime = route_through_gpu(DispatchOp::Gemv { m, k })?;
487        cuda_backend::gemv(runtime, a, v, false)
488    }
489}
490
491#[inline]
492#[must_use]
493pub fn try_fast_atv(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
494    let (n, p) = a.dim();
495    if n != v.len() || n == 0 || p == 0 {
496        return None;
497    }
498    #[cfg(not(target_os = "linux"))]
499    {
500        return None;
501    }
502    #[cfg(target_os = "linux")]
503    {
504        let runtime = route_through_gpu(DispatchOp::Gemv { m: p, k: n })?;
505        cuda_backend::gemv(runtime, a, v, true)
506    }
507}
508
509#[inline]
510#[must_use]
511pub fn try_fast_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
512    let (n, p) = x.dim();
513    if n != w.len() || n == 0 || p == 0 {
514        return None;
515    }
516    #[cfg(not(target_os = "linux"))]
517    {
518        return None;
519    }
520    #[cfg(target_os = "linux")]
521    {
522        let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
523        cuda_backend::xt_diag_x(runtime, x, w)
524    }
525}
526
527/// #1017 Phase 3: a device-resident design matrix for repeated `Xᵀ·diag(w)·X`
528/// Gram evaluations that uploads `X` to the device ONCE.
529///
530/// The per-call [`try_fast_xt_diag_x`] re-uploads the full `n×p` `X` on every
531/// call. The SAE / IRLS inner loop holds `X` fixed and rebuilds the Gram once
532/// per Newton/PIRLS weight update, so the repeated H2D of `X` is pure waste —
533/// measured on an A100 (#1412) it makes the `XtWX` GEMM ~98% of the pipeline at
534/// <20% device utilisation (the device is starved by staging, not arithmetic).
535/// This handle uploads `X` once at construction; each [`Self::gram`] crosses
536/// only the `n`-vector `w` H2D and the `p×p` Gram D2H, so the per-Gram transfer
537/// shrinks by a factor of `p`.
538///
539/// Admission keys on the same work-based [`DispatchOp::XtDiagX`] gate as the
540/// per-call path (so it engages exactly when the Gram is GPU-profitable) and the
541/// numerics are bit-identical to [`try_fast_xt_diag_x`] on the same device
542/// (same `cublasDdgmm` row-scale + `gemm` reduction order). On a non-CUDA host,
543/// a below-threshold shape, or any device failure, [`Self::try_new`] returns
544/// `None` and the caller keeps its CPU/per-call path — residency never changes
545/// the result, only where (and how often) `X` is staged.
546pub struct ResidentDesignGram {
547    #[cfg(target_os = "linux")]
548    inner: super::blas::ResidentWeightedGram,
549    #[cfg(not(target_os = "linux"))]
550    _never: std::convert::Infallible,
551}
552
553impl ResidentDesignGram {
554    /// Upload `x` (`n×p`) to the device once. Returns `None` when CUDA is
555    /// unavailable, the shape is below the GPU Gram threshold, or the upload
556    /// fails.
557    #[must_use]
558    pub fn try_new(x: ArrayView2<'_, f64>) -> Option<Self> {
559        let (n, p) = x.dim();
560        if n == 0 || p == 0 {
561            return None;
562        }
563        #[cfg(not(target_os = "linux"))]
564        {
565            None
566        }
567        #[cfg(target_os = "linux")]
568        {
569            let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
570            let inner = super::blas::ResidentWeightedGram::new(runtime.device.ordinal, x)?;
571            Some(Self { inner })
572        }
573    }
574
575    /// Compute `Xᵀ·diag(w)·X` reusing the resident `X`. `w` must have one entry
576    /// per design row. Returns `None` on a shape mismatch or device failure.
577    #[must_use]
578    pub fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
579        #[cfg(not(target_os = "linux"))]
580        {
581            // SAFETY: off CUDA, `try_new` always returns `None`, so no `Self` of
582            // this type is ever constructed and this method is statically
583            // unreachable. Returning a benign `None` would silently launder that
584            // impossibility into a "GPU declined" sentinel, so fail loudly. The
585            // `w.len()` use also consumes the parameter on this target.
586            panic!(
587                "ResidentDesignGram cannot be constructed off CUDA (w.len()={})",
588                w.len()
589            )
590        }
591        #[cfg(target_os = "linux")]
592        {
593            self.inner.gram(w)
594        }
595    }
596
597    /// Solve the penalized normal equations `(Xᵀ·diag(w)·X + ridge·I)·β = rhs`
598    /// with the Gram, its Cholesky factor, and the RHS all kept DEVICE-RESIDENT —
599    /// only `w` (`n`), `rhs` (`p`), and the solution `β` (`p`) cross the bus.
600    ///
601    /// This is the #1017 Phase-3 fix for the next ceiling after [`Self::gram`]:
602    /// the bare Gram still pays a `p×p` D2H (134 MB at p=4096), but the SAE/IRLS
603    /// inner step only needs `β`, so chaining row-scale→GEMM→POTRF→TRSM on-device
604    /// and returning only the `p`-vector removes that transfer entirely. Returns
605    /// `None` on a shape mismatch, a non-PD Gram, or any device failure — the
606    /// caller then runs the CPU normal-equations solve. The numerics match a
607    /// host `Cholesky((XᵀWX+ridge·I))` solve up to IEEE-754 reduction order.
608    #[must_use]
609    pub fn solve_normal_equations(
610        &self,
611        w: ArrayView1<'_, f64>,
612        rhs: ArrayView1<'_, f64>,
613        ridge: f64,
614    ) -> Option<Array1<f64>> {
615        #[cfg(not(target_os = "linux"))]
616        {
617            // SAFETY: statically unreachable off CUDA (see `gram`); fail loudly.
618            panic!(
619                "ResidentDesignGram cannot be constructed off CUDA (w.len()={}, rhs.len()={}, ridge={ridge})",
620                w.len(),
621                rhs.len()
622            )
623        }
624        #[cfg(target_os = "linux")]
625        {
626            self.inner.solve_psd_normal_equations(w, rhs, ridge)
627        }
628    }
629
630    /// `(n, p)` of the resident design.
631    #[must_use]
632    pub fn dims(&self) -> (usize, usize) {
633        #[cfg(not(target_os = "linux"))]
634        {
635            // SAFETY: statically unreachable off CUDA (see `gram`) — no `Self`
636            // is ever constructed on this target; fail loudly rather than
637            // return a benign sentinel.
638            panic!("ResidentDesignGram cannot be constructed off CUDA")
639        }
640        #[cfg(target_os = "linux")]
641        {
642            self.inner.dims()
643        }
644    }
645}
646
/// Row-chunks carved per device for the spectral leverage stream.
///
/// With fewer chunks than devices, `super::pool::balanced_partition` would
/// leave the surplus devices idle. Modest oversubscription amortizes the
/// per-tile launch without bloating staging memory. Fixed value; no flag.
#[cfg(target_os = "linux")]
const LEVERAGE_CHUNKS_PER_DEVICE: usize = 4;

/// Byte-balanced row-chunk height for the spectral leverage stream.
///
/// Mirrors the CPU `byte_balanced_row_chunk` sizing: aim for ≈8 MiB of `f64`
/// per chunk given `cols` values per row, never go below 512 rows, and never
/// exceed `n_rows` (treated as at least 1). `cols == 0` is treated as 1.
#[cfg(target_os = "linux")]
#[inline]
fn leverage_chunk_rows(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * cols.max(1);
    let rows_in_budget = TARGET_BYTES / row_bytes;
    // Floor first, then cap. This order matters when `n_rows` < the floor.
    std::cmp::min(std::cmp::max(rows_in_budget, MIN_CHUNK_ROWS), n_rows.max(1))
}
669
670/// GPU-offloaded spectral leverage diagonal `h[i] = ‖(X G)_{i,:}‖²`.
671///
672/// `G` is the `(p × rank)` spectral factor with `G_ε(H) = G Gᵀ`; the per-row
673/// leverage is the squared norm of the i-th row of `X G`. This is the dominant
674/// n-dependent cost of every REML outer evaluation at large scale (issue
675/// #922), and historically ran only on the CPU while the device pool idled.
676///
677/// The row dimension is split into byte-balanced chunks scattered across the
678/// whole device pool via [`super::pool::scatter_batched`] — the same
679/// whole-solve row-block granularity as Arrow-Schur — and each tile runs one
680/// cuBLAS GEMM `X_chunk · G` on its bound ordinal before reducing row-wise
681/// sum-of-squares. The arithmetic is identical f64 to the CPU faer path (modulo
682/// IEEE-754 reduction order); on no device, a below-threshold shape, or any
683/// tile failure the function returns `None` and the caller runs its
684/// deterministic CPU stream.
685#[inline]
686#[must_use]
687pub fn try_fast_spectral_leverage_diagonal(
688    x: &gam_linalg::matrix::DesignMatrix,
689    g: ArrayView2<'_, f64>,
690) -> Option<Array1<f64>> {
691    let n = x.nrows();
692    let p = x.ncols();
693    let rank = g.ncols();
694    if n == 0 || p == 0 || rank == 0 || g.nrows() != p {
695        return None;
696    }
697    #[cfg(not(target_os = "linux"))]
698    {
699        return None;
700    }
701    #[cfg(target_os = "linux")]
702    {
703        // n·p² gate is shared with the X^T diag(w) X reduction — the leverage
704        // diagonal is the same O(n·p·rank)-class dense pass over the design.
705        let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
706        let device_count = runtime.device_count().max(1);
707        let byte_chunk = leverage_chunk_rows(p + rank, n);
708        let target_chunks = device_count
709            .saturating_mul(LEVERAGE_CHUNKS_PER_DEVICE)
710            .max(1);
711        let chunk_rows = byte_chunk.min(n.div_ceil(target_chunks).max(1)).max(1);
712
713        // One slot per row-chunk; the slot carries its row range and receives
714        // its own output buffer so each tile owns disjoint memory.
715        let mut tiles: Vec<(std::ops::Range<usize>, Option<Array1<f64>>)> = Vec::new();
716        let mut start = 0usize;
717        while start < n {
718            let end = (start + chunk_rows).min(n);
719            tiles.push((start..end, None));
720            start = end;
721        }
722
723        super::pool::scatter_batched(runtime, &mut tiles, |ordinal, tile| {
724            for (range, slot) in tile.iter_mut() {
725                let rows = x.try_row_chunk(range.clone()).ok()?;
726                let xg = cuda_backend::gemm_on_ordinal(ordinal, rows.view(), g, false, false)?;
727                let mut out = Array1::<f64>::zeros(range.end - range.start);
728                for (local, row) in xg.outer_iter().enumerate() {
729                    out[local] = row.iter().map(|&v| v * v).sum();
730                }
731                *slot = Some(out);
732            }
733            Some(())
734        })?;
735
736        let mut h = Array1::<f64>::zeros(n);
737        for (range, slot) in tiles {
738            let vals = slot?;
739            if vals.len() != range.end - range.start {
740                return None;
741            }
742            h.slice_mut(ndarray::s![range]).assign(&vals);
743        }
744        Some(h)
745    }
746}
747
748#[inline]
749#[must_use]
750pub fn try_fast_xt_diag_y(
751    x: ArrayView2<'_, f64>,
752    w: ArrayView1<'_, f64>,
753    y: ArrayView2<'_, f64>,
754) -> Option<Array2<f64>> {
755    let (n, px) = x.dim();
756    let (n_y, q) = y.dim();
757    if n != n_y || n != w.len() || n == 0 || px == 0 || q == 0 {
758        return None;
759    }
760    #[cfg(not(target_os = "linux"))]
761    {
762        return None;
763    }
764    #[cfg(target_os = "linux")]
765    {
766        let runtime = route_through_gpu(DispatchOp::XtDiagY { n, px, q })?;
767        cuda_backend::xt_diag_y(runtime, x, w, y)
768    }
769}
770
771#[inline]
772#[must_use]
773pub fn try_fast_joint_hessian_2x2(
774    x_a: ArrayView2<'_, f64>,
775    x_b: ArrayView2<'_, f64>,
776    w_aa: ArrayView1<'_, f64>,
777    w_ab: ArrayView1<'_, f64>,
778    w_bb: ArrayView1<'_, f64>,
779) -> Option<Array2<f64>> {
780    let (n, pa) = x_a.dim();
781    let (n_b, pb) = x_b.dim();
782    if n != n_b || n != w_aa.len() || n != w_ab.len() || n != w_bb.len() || pa + pb == 0 {
783        return None;
784    }
785    #[cfg(not(target_os = "linux"))]
786    {
787        return None;
788    }
789    #[cfg(target_os = "linux")]
790    {
791        let runtime = route_through_gpu(DispatchOp::JointHessian2x2 { n, pa, pb })?;
792        cuda_backend::joint_hessian_2x2(runtime, x_a, x_b, w_aa, w_ab, w_bb)
793    }
794}
795
796#[inline]
797#[must_use]
798pub fn try_cholesky_lower_inplace(a: &mut Array2<f64>) -> Option<()> {
799    let p = a.nrows();
800    if p != a.ncols() {
801        return None;
802    }
803    #[cfg(not(target_os = "linux"))]
804    {
805        return None;
806    }
807    #[cfg(target_os = "linux")]
808    {
809        let runtime = route_through_gpu(DispatchOp::Potrf { p, batch: 1 })?;
810        let lower = cuda_backend::cholesky_lower(runtime, a.view())?;
811        *a = lower;
812        Some(())
813    }
814}
815
816#[inline]
817#[must_use]
818pub fn try_cholesky_batched_lower_inplace(matrices: &mut [Array2<f64>]) -> Option<()> {
819    let first = matrices.first()?;
820    let p = first.nrows();
821    if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
822        return None;
823    }
824    #[cfg(not(target_os = "linux"))]
825    {
826        return None;
827    }
828    #[cfg(target_os = "linux")]
829    {
830        let batch = matrices.len();
831        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p, batch })
832            .or_else(|| route_through_gpu(DispatchOp::Potrf { p, batch }))?;
833        if should_split_batch(batch) {
834            // `matrices` is already the per-item slice, so the batch dimension
835            // tiles directly onto `scatter_batched`: each device factors its own
836            // contiguous block of matrices in place. On any tile failure the
837            // whole batch is re-run on the primary device for determinism (the
838            // factored tiles are overwritten by the single-device pass).
839            let split = super::pool::scatter_batched(runtime, matrices, |ordinal, tile| {
840                cuda_backend::cholesky_batched_lower(ordinal, tile)
841            });
842            if split.is_some() {
843                return Some(());
844            }
845        }
846        cuda_backend::cholesky_batched_lower(runtime.device.ordinal, matrices)
847    }
848}
849
850#[inline]
851#[must_use]
852pub fn try_solve_lower_triangular_matrix(
853    lower: ArrayView2<'_, f64>,
854    rhs: ArrayView2<'_, f64>,
855) -> Option<Array2<f64>> {
856    let (m, n) = rhs.dim();
857    if m == 0 || n == 0 || lower.nrows() != m {
858        return None;
859    }
860    #[cfg(not(target_os = "linux"))]
861    {
862        return None;
863    }
864    #[cfg(target_os = "linux")]
865    {
866        let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
867        cuda_backend::trsm(runtime, lower, rhs, false)
868    }
869}
870
871#[inline]
872#[must_use]
873pub fn try_solve_upper_triangular_matrix(
874    upper: ArrayView2<'_, f64>,
875    rhs: ArrayView2<'_, f64>,
876) -> Option<Array2<f64>> {
877    let (m, n) = rhs.dim();
878    if m == 0 || n == 0 || upper.nrows() != m {
879        return None;
880    }
881    #[cfg(not(target_os = "linux"))]
882    {
883        return None;
884    }
885    #[cfg(target_os = "linux")]
886    {
887        let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
888        cuda_backend::trsm(runtime, upper, rhs, true)
889    }
890}
891
#[cfg(test)]
mod tests {
    use super::{DispatchOp, route_through_gpu};
    use crate::device_runtime::GpuRuntime;

    /// When a CUDA runtime is present:
    /// - every SAE-shaped dense reduction must clear the runtime GEMM work
    ///   floor and be admitted by `route_through_gpu`;
    /// - a uniform batch of small blocks must reach the small-dense batched
    ///   POTRF gate.
    ///
    /// Without a runtime the check is skipped (with a note on stderr), not failed.
    #[test]
    fn sae_shape_dispatch_ops_route_when_cuda_runtime_is_present() {
        let runtime = match GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!("[sae dispatch gate] no CUDA runtime - skipping branch-admission check");
                return;
            }
        };
        let floor = runtime.policy.gemm_min_flops;

        // Fixture dimensions.
        let rows = 2_000usize;
        let width = 2_048usize;
        let blocks = 12usize;
        let per_block = 8usize;
        let secondary = blocks * per_block;

        let dense_reduction_ops = [
            DispatchOp::XtDiagX { n: rows, p: width },
            DispatchOp::XtDiagY {
                n: rows,
                px: width,
                q: secondary,
            },
            DispatchOp::JointHessian2x2 {
                n: rows,
                pa: width,
                pb: secondary,
            },
            DispatchOp::Gemm {
                m: width,
                n: width,
                k: rows * blocks,
            },
        ];

        for op in dense_reduction_ops {
            let flops = op.flops();
            assert!(
                flops >= floor as u128,
                "SAE dispatch fixture must clear the runtime GEMM work floor: op={op:?}, flops={}, floor={}",
                flops,
                floor
            );
            let routed = route_through_gpu(op).is_some();
            assert!(
                routed,
                "SAE dispatch fixture should route to GPU when CUDA is present: {op:?}"
            );
        }

        let batched_potrf = DispatchOp::SmallDenseBatchedPotrf {
            p: blocks,
            batch: rows,
        };
        assert!(
            route_through_gpu(batched_potrf).is_some(),
            "uniform SAE row blocks should reach the small-dense batched POTRF gate"
        );
    }
}
943
944// ---------------------------------------------------------------------------
945// Backend selection. The wrappers keep CUDA types out of solver modules while
946// delegating to cudarc-backed BLAS, solver, and custom kernel implementations.
947// ---------------------------------------------------------------------------
948
#[cfg(target_os = "linux")]
mod cuda_backend {
    //! CUDA-backed implementations of the dispatch entry points.
    //!
    //! Most functions here forward directly to `super::super::blas`. The two
    //! Cholesky entry points are implemented in this module on top of
    //! cuSOLVER:
    //! - `cholesky_lower` delegates its POTRF to the shared core in
    //!   `crate::solver`;
    //! - `cholesky_batched_lower` calls `cusolverDnDpotrfBatched` directly.
    //!
    //! Errors reported by cudarc or cuSOLVER are mapped to `None` (via
    //! `.ok()?` and `check_cusolver`), so the dispatch layer can fall back to
    //! the CPU fast path.
    //!
    //! NOTE(review): whether results match the CPU kernels bit-for-bit or only
    //! up to floating-point reduction order depends on the `blas` kernels and
    //! is not established by this module.

    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};

    use super::super::device_runtime::GpuRuntime;
    use crate::driver::{from_col_major, to_col_major, to_i32};
    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
    use cudarc::driver::{DevicePtrMut, sys as driver_sys};

    /// Dense GEMM on `runtime`'s device; forwards to `blas::gemm_cuda`.
    /// `trans_a` / `trans_b` are passed through unchanged (presumably they
    /// select the transposed operand — confirm in `blas`).
    #[inline]
    pub(super) fn gemm(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_cuda(runtime, a, b, trans_a, trans_b)
    }

    /// GEMM on an explicit device `ordinal` rather than `runtime`'s device;
    /// forwards to `blas::gemm_on_ordinal_cuda`. The row-tiled path in this
    /// file calls it from inside `scatter_batched` workers.
    #[inline]
    pub(super) fn gemm_on_ordinal(
        ordinal: usize,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_on_ordinal_cuda(ordinal, a, b, trans_a, trans_b)
    }

    /// Dense matrix–vector product on `runtime`'s device; forwards to
    /// `blas::gemv_cuda` with `trans_a` passed through unchanged.
    #[inline]
    pub(super) fn gemv(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        v: ArrayView1<'_, f64>,
        trans_a: bool,
    ) -> Option<Array1<f64>> {
        super::super::blas::gemv_cuda(runtime, a, v, trans_a)
    }

    /// Batched GEMM of the 3-D `a` against one shared 2-D `b` on device
    /// `ordinal`; forwards to `blas::gemm_broadcast_b_batched_cuda`
    /// (presumably each batch slice of `a` times `b` — confirm in `blas`).
    #[inline]
    pub(super) fn gemm_broadcast_b_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView2<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_broadcast_b_batched_cuda(ordinal, a, b)
    }

    /// Strided-batched GEMM on device `ordinal`; forwards to
    /// `blas::gemm_abt_strided_batched_cuda` (by name, `A_i · B_iᵀ` per batch
    /// item — confirm in `blas`).
    #[inline]
    pub(super) fn gemm_abt_strided_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView3<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_abt_strided_batched_cuda(ordinal, a, b)
    }

    /// Forwards to `blas::xt_diag_x_cuda` (by name, `Xᵀ · diag(w) · X`).
    #[inline]
    pub(super) fn xt_diag_x(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_x_cuda(runtime, x, w)
    }

    /// Forwards to `blas::xt_diag_y_cuda` (by name, `Xᵀ · diag(w) · Y`).
    #[inline]
    pub(super) fn xt_diag_y(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
        y: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_y_cuda(runtime, x, w, y)
    }

    /// Forwards to `blas::joint_hessian_2x2_cuda`.
    /// NOTE(review): `w_aa` / `w_ab` / `w_bb` presumably weight the aa, ab and
    /// bb blocks of the joint Hessian — confirm in `blas`.
    #[inline]
    pub(super) fn joint_hessian_2x2(
        runtime: &GpuRuntime,
        x_a: ArrayView2<'_, f64>,
        x_b: ArrayView2<'_, f64>,
        w_aa: ArrayView1<'_, f64>,
        w_ab: ArrayView1<'_, f64>,
        w_bb: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::joint_hessian_2x2_cuda(runtime, x_a, x_b, w_aa, w_ab, w_bb)
    }

    /// Triangular solve on `runtime`'s device; forwards to `blas::trsm_cuda`.
    /// `upper` selects which triangle of `triangular` is used. The lower-solve
    /// entry point passes `false`; the upper-solve entry point passes `true`.
    #[inline]
    pub(super) fn trsm(
        runtime: &GpuRuntime,
        triangular: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
        upper: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::trsm_cuda(runtime, triangular, rhs, upper)
    }

    /// Lower Cholesky factor of one square matrix on `runtime`'s device.
    ///
    /// Returns `None` for empty or non-square input, and for any failure in
    /// stream or handle creation, host↔device transfer, or POTRF. On success
    /// the returned matrix has its strict upper triangle set to zero.
    #[inline]
    pub(super) fn cholesky_lower(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        let (p, p2) = a.dim();
        if p == 0 || p != p2 {
            return None;
        }
        // A fresh stream and cuSOLVER handle are created for this call and
        // dropped when it returns.
        let stream = super::super::device_runtime::cuda_context_for(runtime.device.ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        // cuSOLVER works on column-major storage.
        let a_col = to_col_major(&a);
        let mut a_dev = stream.clone_htod(&*a_col).ok()?;
        potrf_lower_in_place(&solver, &stream, p, &mut a_dev)?;
        let factor_col = stream.clone_dtoh(&a_dev).ok()?;
        let mut lower = from_col_major(&factor_col, p, p)?;
        // POTRF with the lower fill mode does not produce the strict upper
        // triangle as part of the factor; clear it so the result is exactly L.
        for row in 0..p {
            for col in (row + 1)..p {
                lower[[row, col]] = 0.0;
            }
        }
        Some(lower)
    }

    /// Batched lower-Cholesky on a specific device ordinal. The ordinal's
    /// context is expected to be bound on the calling thread (multi-GPU
    /// `scatter_batched` worker or the single-device dispatcher).
    ///
    /// All matrices must be non-empty and share one `p × p` shape. The whole
    /// batch is handled in three steps:
    /// 1. upload it as one contiguous column-major buffer;
    /// 2. factor it with a single `cusolverDnDpotrfBatched` call;
    /// 3. copy it back.
    ///
    /// The per-matrix `info` codes are checked before any host matrix is
    /// written. A failure detected up to that point returns `None` with
    /// `matrices` untouched.
    #[inline]
    pub(super) fn cholesky_batched_lower(
        ordinal: usize,
        matrices: &mut [Array2<f64>],
    ) -> Option<()> {
        let first = matrices.first()?;
        let p = first.nrows();
        if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
            return None;
        }

        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        // Pack the matrices back to back in column-major order. Matrix `i`
        // starts at element `i * p * p`; overflow is checked so it cannot wrap.
        let matrix_len = p.checked_mul(p)?;
        let mut batch_col = Vec::with_capacity(matrices.len().checked_mul(matrix_len)?);
        for matrix in matrices.iter() {
            batch_col.extend(to_col_major(&matrix.view()).iter().copied());
        }
        let mut matrices_dev = stream.clone_htod(&batch_col).ok()?;
        // The batched POTRF takes a device array of per-matrix device
        // pointers. Build it on the host as base address + i * bytes-per-matrix,
        // then upload it below.
        let matrix_ptrs = {
            let (base_ptr, _matrix_record) = matrices_dev.device_ptr_mut(&stream);
            let bytes_per_matrix = driver_sys::CUdeviceptr::try_from(
                matrix_len.checked_mul(std::mem::size_of::<f64>())?,
            )
            .ok()?;
            let mut matrix_ptrs = Vec::with_capacity(matrices.len());
            for idx in 0..matrices.len() {
                let offset = driver_sys::CUdeviceptr::try_from(idx).ok()? * bytes_per_matrix;
                matrix_ptrs.push(base_ptr + offset);
            }
            matrix_ptrs
        };
        let mut matrix_ptrs_dev = stream.clone_htod(&matrix_ptrs).ok()?;
        let mut info_dev = stream.alloc_zeros::<i32>(matrices.len()).ok()?;
        let p_i = to_i32(p)?;
        let batch_i = to_i32(matrices.len())?;
        // Every upload, the POTRF launch and the downloads below are issued on
        // the same `stream`; the `_record` guards from `device_ptr_mut` are
        // dropped at the end of their scopes.
        {
            let (ptrs_ptr, _ptrs_record) = matrix_ptrs_dev.device_ptr_mut(&stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(&stream);
            // SAFETY: `ptrs_ptr` points to a device array of batch pointers,
            // each pointer targets a live p×p column-major matrix in
            // `matrices_dev`, and `info_dev` has one entry per batch item.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            };
            check_cusolver(status)?;
        }
        // Any nonzero `info` (e.g. a member that is not positive definite)
        // fails the whole batch before any host matrix is overwritten.
        let info_host = stream.clone_dtoh(&info_dev).ok()?;
        if info_host.iter().any(|info| *info != 0) {
            return None;
        }
        let factored_col = stream.clone_dtoh(&matrices_dev).ok()?;
        // NOTE(review): if `from_col_major` ever returned `None` part-way
        // through this loop, the earlier matrices would already be replaced.
        // The lengths were validated above, so this is believed unreachable.
        for (idx, matrix) in matrices.iter_mut().enumerate() {
            let start = idx.checked_mul(matrix_len)?;
            let end = start.checked_add(matrix_len)?;
            let mut lower = from_col_major(&factored_col[start..end], p, p)?;
            // The strict upper triangle is not part of the batched factor;
            // zero it so each result is exactly L.
            for row in 0..p {
                for col in (row + 1)..p {
                    lower[[row, col]] = 0.0;
                }
            }
            *matrix = lower;
        }
        Some(())
    }

    /// Single-matrix lower Cholesky POTRF. Thin `Result → Option` adapter over
    /// the shared precision-generic core in `solver.rs`
    /// ([`crate::solver::potrf_in_place_generic`]) so the cuSOLVER
    /// bufferSize/POTRF/info scaffold lives in exactly one place. The batched
    /// variant (`cusolverDnDpotrfBatched`) above is kept separate by design.
    fn potrf_lower_in_place(
        solver: &DnHandle,
        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
        p: usize,
        a: &mut cudarc::driver::CudaSlice<f64>,
    ) -> Option<()> {
        crate::solver::potrf_in_place_generic::<f64>(solver, stream, p, a).ok()
    }

    /// Map a raw cuSOLVER status code to `Some(())` on
    /// `CUSOLVER_STATUS_SUCCESS` and `None` for every other status.
    #[inline]
    fn check_cusolver(status: cusolver_sys::cusolverStatus_t) -> Option<()> {
        if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            Some(())
        } else {
            None
        }
    }
}