Skip to main content

gam_gpu/
blas.rs

1//! Device BLAS surface for the cudarc-backed dense kernels.
2//!
3//! The public surface here is the lowest level of the GPU dispatch stack: it
4//! takes ndarray views, copies them to a device buffer, calls a cuBLAS / kernel
5//! routine, and returns the host result. The cudarc-backed implementations
6//! always compile (cudarc dynamically loads `libcuda` at runtime via the
7//! `fallback-dynamic-loading` feature), and dispatch is gated at runtime on
8//! `super::device_runtime::GpuRuntime::global()` — when no device is probed the
9//! status enum advertises `CudaUnavailable` and callers fall back to CPU.
10//!
11//! The implementations route through `super::device_runtime::cuda_context_for` and
12//! the cudarc 0.19 cuBLAS API. Any transient backend failure (OOM, launch
13//! error, …) is converted to `None` so the auto-dispatch shim in
14//! `super::linalg` falls back to the CPU fast path without disturbing
15//! numerics.
16
17pub fn blas_backend_status() -> super::CudaBackendStatus {
18    super::cuda_backend_status()
19}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
23    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};
24
25    use crate::driver::{
26        array_from_row_major, from_col_major, to_col_major, to_i32, to_row_major,
27    };
28
29    use super::super::device_runtime::GpuRuntime;
30    use cudarc::cublas::sys::{
31        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
32    };
33    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
34    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
35    use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
36    use std::sync::Arc;
37
38    /// Create a fresh stream + cuBLAS handle bound to a specific device
39    /// ordinal. This is the per-ordinal entry point used by multi-GPU fan-out
40    /// (`super::super::pool::scatter_batched` workers): the worker thread has
41    /// already bound that ordinal's context, and the stream/handle created here
42    /// target the same device. The single-device helper below is the
43    /// primary-ordinal specialization.
44    #[inline]
45    pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
46        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
47            .new_stream()
48            .ok()?;
49        let blas = CudaBlas::new(stream.clone()).ok()?;
50        Some((stream, blas))
51    }
52
53    #[inline]
54    fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
55        stream_and_blas_for(runtime.device.ordinal)
56    }
57
58    #[inline]
59    fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
60        v.iter().copied().collect()
61    }
62
63    #[inline]
64    fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
65        let (batch_len, rows, cols) = batch.dim();
66        let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
67        for matrix in batch.axis_iter(Axis(0)) {
68            out.extend(to_col_major(&matrix).iter().copied());
69        }
70        out
71    }
72
73    #[inline]
74    fn from_col_major_batch(
75        data: &[f64],
76        batch: usize,
77        rows: usize,
78        cols: usize,
79    ) -> Option<Array3<f64>> {
80        if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
81            return None;
82        }
83        let mut out = Array3::<f64>::zeros((batch, rows, cols));
84        let matrix_len = rows.checked_mul(cols)?;
85        for batch_idx in 0..batch {
86            let base = batch_idx.checked_mul(matrix_len)?;
87            for col in 0..cols {
88                for row in 0..rows {
89                    out[[batch_idx, row, col]] = data[base + col * rows + row];
90                }
91            }
92        }
93        Some(out)
94    }
95
96    #[inline]
97    fn row_scale_device(
98        blas: &CudaBlas,
99        stream: &Arc<CudaStream>,
100        matrix_dev: &CudaSlice<f64>,
101        weights_dev: &CudaSlice<f64>,
102        scaled_dev: &mut CudaSlice<f64>,
103        rows: usize,
104        cols: usize,
105    ) -> Option<()> {
106        let rows_i = to_i32(rows)?;
107        let cols_i = to_i32(cols)?;
108        let handle = *blas.handle();
109        let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
110        let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
111        let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
112        // SAFETY: all device slices are on this stream/context. `matrix_dev`
113        // and `scaled_dev` are rows×cols column-major matrices with lda/ldc
114        // equal to rows; `weights_dev` has one contiguous value per row.
115        let status = unsafe {
116            cudarc::cublas::sys::cublasDdgmm(
117                handle,
118                cublasSideMode_t::CUBLAS_SIDE_LEFT,
119                rows_i,
120                cols_i,
121                matrix_ptr as *const f64,
122                rows_i,
123                weights_ptr as *const f64,
124                1,
125                scaled_ptr as *mut f64,
126                rows_i,
127            )
128        };
129        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
130            Some(())
131        } else {
132            None
133        }
134    }
135
136    #[inline]
137    fn weighted_crossprod(
138        runtime: &GpuRuntime,
139        left: ArrayView2<'_, f64>,
140        weights: ArrayView1<'_, f64>,
141        right: ArrayView2<'_, f64>,
142    ) -> Option<Array2<f64>> {
143        weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
144    }
145
146    #[inline]
147    fn weighted_crossprod_for(
148        ordinal: usize,
149        left: ArrayView2<'_, f64>,
150        weights: ArrayView1<'_, f64>,
151        right: ArrayView2<'_, f64>,
152    ) -> Option<Array2<f64>> {
153        let (rows, left_cols) = left.dim();
154        let (right_rows, right_cols) = right.dim();
155        if rows == 0
156            || left_cols == 0
157            || right_cols == 0
158            || rows != right_rows
159            || rows != weights.len()
160        {
161            return None;
162        }
163
164        let (stream, blas) = stream_and_blas_for(ordinal)?;
165        // #1412: the symmetric Gram `Xᵀ·diag(w)·X` (xt_diag_x) passes the SAME
166        // array as `left` and `right`. Detect that (identical data pointer +
167        // shape) and stage `X` ONCE instead of column-majoring and H2D-uploading
168        // two byte-identical n×p copies — halving the dominant H2D for the Gram.
169        // The GEMM operands are unchanged (`left_dev` doubles as the row-scale
170        // source), so the result is bit-identical to the two-upload path.
171        let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
172            && left.dim() == right.dim()
173            && left.strides() == right.strides();
174        let left_col = to_col_major(&left);
175        let weights_host = vector_values(weights);
176        let left_dev = stream.clone_htod(&*left_col).ok()?;
177        // Symmetric Gram: `right` IS `left`, so row-scale directly from the
178        // single resident `left_dev` and never upload a second n×p copy. The
179        // asymmetric path uploads `right` as before.
180        let right_dev = if same_operand {
181            None
182        } else {
183            let right_col = to_col_major(&right);
184            Some(stream.clone_htod(&*right_col).ok()?)
185        };
186        let weights_dev = stream.clone_htod(&weights_host).ok()?;
187        let mut weighted_right_dev = stream
188            .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
189            .ok()?;
190        row_scale_device(
191            &blas,
192            &stream,
193            right_dev.as_ref().unwrap_or(&left_dev),
194            &weights_dev,
195            &mut weighted_right_dev,
196            rows,
197            right_cols,
198        )?;
199
200        let mut out_dev = stream
201            .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
202            .ok()?;
203        let cfg = GemmConfig::<f64> {
204            transa: cublasOperation_t::CUBLAS_OP_T,
205            transb: cublasOperation_t::CUBLAS_OP_N,
206            m: to_i32(left_cols)?,
207            n: to_i32(right_cols)?,
208            k: to_i32(rows)?,
209            alpha: 1.0,
210            lda: to_i32(rows)?,
211            ldb: to_i32(rows)?,
212            beta: 0.0,
213            ldc: to_i32(left_cols)?,
214        };
215        // SAFETY: cfg computes leftᵀ (left_cols×rows) times weighted_right
216        // (rows×right_cols) into a left_cols×right_cols column-major output.
217        unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
218        let out_col = stream.clone_dtoh(&out_dev).ok()?;
219        from_col_major(&out_col, left_cols, right_cols)
220    }
221
222    /// #1017 Phase 3: a device-resident design matrix `X` whose `n×p` values are
223    /// uploaded to the device ONCE and reused across many `Xᵀ·diag(w)·X` Gram
224    /// evaluations.
225    ///
226    /// The per-call [`xt_diag_x_cuda`] path re-uploads the full `n×p` `X` (and a
227    /// second copy as the `right` operand) on EVERY call. For the SAE / IRLS
228    /// inner loop — where `X` is frozen across weight updates and the Gram is
229    /// rebuilt once per Newton/PIRLS step — that H2D staging dominates the wall
230    /// clock (measured #1412: the `XtWX` GEMM is ~98% of the pipeline at <20% GPU
231    /// utilisation, i.e. the device is starved by the per-call upload, not the
232    /// arithmetic). Uploading `X` once and crossing only the `n`-vector `w` (and
233    /// the `p×p` result) per call removes that ping-pong: the resident `X` is
234    /// `n·p` doubles vs the per-call `w` of `n` doubles, so the amortised
235    /// transfer per Gram drops by a factor of `p`.
236    pub(crate) struct ResidentWeightedGram {
237        stream: Arc<CudaStream>,
238        blas: CudaBlas,
239        x_dev: CudaSlice<f64>,
240        rows: usize,
241        cols: usize,
242    }
243
244    impl ResidentWeightedGram {
245        /// Upload `x` (`n×p`) to `ordinal` once, column-major, and keep it
246        /// resident. Returns `None` on a degenerate shape or any device failure
247        /// (the caller falls back to the per-call CPU/GPU path).
248        pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
249            let (rows, cols) = x.dim();
250            if rows == 0 || cols == 0 {
251                return None;
252            }
253            let (stream, blas) = stream_and_blas_for(ordinal)?;
254            let x_col = to_col_major(&x);
255            let x_dev = stream.clone_htod(&*x_col).ok()?;
256            Some(Self {
257                stream,
258                blas,
259                x_dev,
260                rows,
261                cols,
262            })
263        }
264
265        #[inline]
266        pub(crate) fn dims(&self) -> (usize, usize) {
267            (self.rows, self.cols)
268        }
269
270        /// Compute `Xᵀ·diag(w)·X` reusing the resident `X`. Only `w` (`n`
271        /// doubles) crosses H2D and only the `p×p` Gram crosses D2H. The
272        /// arithmetic is bit-identical to [`xt_diag_x_cuda`] on the same device
273        /// (same `cublasDdgmm` row-scale + same `gemm` reduction order).
274        pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
275            if w.len() != self.rows {
276                return None;
277            }
278            let weights_host = vector_values(w);
279            let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
280            let mut weighted_dev = self
281                .stream
282                .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
283                .ok()?;
284            row_scale_device(
285                &self.blas,
286                &self.stream,
287                &self.x_dev,
288                &weights_dev,
289                &mut weighted_dev,
290                self.rows,
291                self.cols,
292            )?;
293            let mut out_dev = self
294                .stream
295                .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
296                .ok()?;
297            let cfg = GemmConfig::<f64> {
298                transa: cublasOperation_t::CUBLAS_OP_T,
299                transb: cublasOperation_t::CUBLAS_OP_N,
300                m: to_i32(self.cols)?,
301                n: to_i32(self.cols)?,
302                k: to_i32(self.rows)?,
303                alpha: 1.0,
304                lda: to_i32(self.rows)?,
305                ldb: to_i32(self.rows)?,
306                beta: 0.0,
307                ldc: to_i32(self.cols)?,
308            };
309            // SAFETY: `x_dev` is the resident n×p column-major design; cfg forms
310            // Xᵀ (p×n) · weighted (n×p) → a p×p column-major Gram.
311            unsafe {
312                self.blas
313                    .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
314            }
315            .ok()?;
316            let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
317            from_col_major(&out_col, self.cols, self.cols)
318        }
319
320        /// Compute the resident weighted Gram `G = Xᵀ·diag(w)·X + ridge·I`,
321        /// factor it (cuSOLVER POTRF), and solve `G·β = rhs` — keeping `G`, its
322        /// Cholesky factor, and the RHS all DEVICE-RESIDENT. Only `w` (`n`),
323        /// `rhs` (`p`), and the result `β` (`p`) cross the PCIe boundary; the
324        /// `p×p` Gram is NEVER downloaded.
325        ///
326        /// This is the #1017 Phase-3 ceiling fix for the normal-equations solve:
327        /// the per-call [`gram`] still pays a `p×p` D2H (134 MB at p=4096 — the
328        /// next bottleneck once `X` is resident), whereas the SAE/IRLS inner step
329        /// only needs the `p`-vector `β = (XᵀWX+λ)⁻¹ XᵀWz`. Chaining
330        /// row-scale→GEMM→POTRF→TRSM on-device and returning only `β` removes the
331        /// Gram transfer entirely.
332        ///
333        /// `ridge` (e.g. the penalty diagonal `λ` or a Tikhonov floor) is seeded
334        /// as `ridge·I` on the device and the Gram is GEMM-accumulated onto it
335        /// (`beta = 1`), so the diagonal bump never costs a Gram round-trip.
336        /// Returns `None` on shape mismatch, a non-PD factorisation, or any
337        /// device failure (the caller falls back to the CPU solve).
338        pub(crate) fn solve_psd_normal_equations(
339            &self,
340            w: ArrayView1<'_, f64>,
341            rhs: ArrayView1<'_, f64>,
342            ridge: f64,
343        ) -> Option<Array1<f64>> {
344            if w.len() != self.rows || rhs.len() != self.cols {
345                return None;
346            }
347            let p = self.cols;
348
349            // weighted = diag(w) · X  (resident X row-scaled).
350            let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
351            let mut weighted_dev = self
352                .stream
353                .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
354                .ok()?;
355            row_scale_device(
356                &self.blas,
357                &self.stream,
358                &self.x_dev,
359                &weights_dev,
360                &mut weighted_dev,
361                self.rows,
362                p,
363            )?;
364
365            // Pre-seed G with `ridge·I` on the device, then GEMM-accumulate
366            // `XᵀW X` onto it with `beta = 1.0`. The Gram is formed and stays
367            // device-resident: `ridge·I` is a one-time H2D upload (the only way
368            // to set a diagonal without an NVRTC kernel), and crucially the p×p
369            // Gram is NEVER read back — only `β` returns. `ridge·I` upload is
370            // bandwidth-trivial vs the avoided per-solve Gram download.
371            let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
372            for i in 0..p {
373                ridge_init[i * p + i] = ridge;
374            }
375            let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
376            let cfg = GemmConfig::<f64> {
377                transa: cublasOperation_t::CUBLAS_OP_T,
378                transb: cublasOperation_t::CUBLAS_OP_N,
379                m: to_i32(p)?,
380                n: to_i32(p)?,
381                k: to_i32(self.rows)?,
382                alpha: 1.0,
383                lda: to_i32(self.rows)?,
384                ldb: to_i32(self.rows)?,
385                // Accumulate onto the resident ridge·I seed.
386                beta: 1.0,
387                ldc: to_i32(p)?,
388            };
389            // SAFETY: resident n×p X and the n×p weighted buffer form a p×p Gram;
390            // beta=1 accumulates Xᵀ(WX) onto the resident ridge·I in g_dev.
391            unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
392
393            // POTRF(G) → lower factor L, resident in g_dev.
394            let solver = DnHandle::new(self.stream.clone()).ok()?;
395            let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
396            if info != 0 {
397                // Not positive-definite at pivot `info`; caller falls back.
398                return None;
399            }
400
401            // Solve L Lᵀ β = rhs via two triangular solves, β resident in rhs_dev.
402            let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
403            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; // L y = rhs
404            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; // Lᵀ β = y
405
406            // Download ONLY the p-vector solution.
407            let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
408            Some(Array1::from_vec(beta_host))
409        }
410    }
411
412    /// Single cuSOLVER `DPOTRF` (lower) of a resident `p×p` column-major matrix,
413    /// factored in place. Returns the cuSOLVER `info` (0 = success, k>0 = the
414    /// leading minor of order k is not PD). Mirrors the arrow-Schur frame POTRF.
415    fn potrf_single_dev(
416        solver: &DnHandle,
417        stream: &Arc<CudaStream>,
418        p: usize,
419        matrix: &mut CudaSlice<f64>,
420    ) -> Option<i32> {
421        let p_i = to_i32(p)?;
422        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
423        let mut lwork = 0_i32;
424        {
425            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
426            // SAFETY: buffer-size query against a live p×p column-major matrix.
427            let status = unsafe {
428                cusolver_sys::cusolverDnDpotrf_bufferSize(
429                    solver.cu(),
430                    uplo,
431                    p_i,
432                    mat_ptr as *mut f64,
433                    p_i,
434                    &mut lwork,
435                )
436            };
437            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
438                return None;
439            }
440        }
441        let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
442        let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
443        {
444            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
445            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
446            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
447            // SAFETY: all buffers live on this stream; matrix is p×p column-major.
448            let status = unsafe {
449                cusolver_sys::cusolverDnDpotrf(
450                    solver.cu(),
451                    uplo,
452                    p_i,
453                    mat_ptr as *mut f64,
454                    p_i,
455                    work_ptr as *mut f64,
456                    lwork,
457                    info_ptr as *mut i32,
458                )
459            };
460            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
461                return None;
462            }
463        }
464        let info_host = stream.clone_dtoh(&info_dev).ok()?;
465        info_host.first().copied()
466    }
467
468    /// Triangular solve `op(L)·x = b` for a single `p`-vector RHS against a
469    /// resident lower Cholesky factor `L` (`p×p` column-major), in place over
470    /// `rhs`. `transposed` selects `Lᵀ` (the second back-substitution).
471    fn trsm_single_vec(
472        blas: &CudaBlas,
473        stream: &Arc<CudaStream>,
474        p: usize,
475        l: &CudaSlice<f64>,
476        rhs: &mut CudaSlice<f64>,
477        transposed: bool,
478    ) -> Option<()> {
479        let alpha = 1.0_f64;
480        let p_i = to_i32(p)?;
481        let handle = *blas.handle();
482        let (l_ptr, _l_rec) = l.device_ptr(stream);
483        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
484        // SAFETY: p×p lower factor and a single p-vector RHS, both resident.
485        let status = unsafe {
486            cudarc::cublas::sys::cublasDtrsm_v2(
487                handle,
488                cublasSideMode_t::CUBLAS_SIDE_LEFT,
489                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
490                if transposed {
491                    cublasOperation_t::CUBLAS_OP_T
492                } else {
493                    cublasOperation_t::CUBLAS_OP_N
494                },
495                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
496                p_i,
497                1,
498                &alpha,
499                l_ptr as *const f64,
500                p_i,
501                rhs_ptr as *mut f64,
502                p_i,
503            )
504        };
505        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
506            Some(())
507        } else {
508            None
509        }
510    }
511
512    #[inline]
513    fn assign_block(
514        out: &mut Array2<f64>,
515        row_offset: usize,
516        col_offset: usize,
517        block: &Array2<f64>,
518    ) {
519        let (rows, cols) = block.dim();
520        for col in 0..cols {
521            for row in 0..rows {
522                out[[row_offset + row, col_offset + col]] = block[[row, col]];
523            }
524        }
525    }
526
527    #[inline]
528    fn mirror_upper_to_lower(out: &mut Array2<f64>) {
529        let n = out.nrows();
530        for row in 0..n {
531            for col in 0..row {
532                out[[row, col]] = out[[col, row]];
533            }
534        }
535    }
536
537    #[inline]
538    pub(crate) fn gemm_cuda(
539        runtime: &GpuRuntime,
540        a: ArrayView2<'_, f64>,
541        b: ArrayView2<'_, f64>,
542        trans_a: bool,
543        trans_b: bool,
544    ) -> Option<Array2<f64>> {
545        gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
546    }
547
548    /// Dense GEMM (optionally transposing either operand) on a specific device
549    /// ordinal. The ordinal's context is expected to be bound on the calling
550    /// thread (pool-tiled callers via `super::super::pool::scatter_batched`, or
551    /// the single-device dispatcher through [`gemm_cuda`]). Semantics are
552    /// identical to [`gemm_cuda`]; only the target device differs.
553    #[inline]
554    pub(crate) fn gemm_on_ordinal_cuda(
555        ordinal: usize,
556        a: ArrayView2<'_, f64>,
557        b: ArrayView2<'_, f64>,
558        trans_a: bool,
559        trans_b: bool,
560    ) -> Option<Array2<f64>> {
561        let (a_rows, a_cols) = a.dim();
562        let (b_rows, b_cols) = b.dim();
563        let (m, k_a) = if trans_a {
564            (a_cols, a_rows)
565        } else {
566            (a_rows, a_cols)
567        };
568        let (k_b, n) = if trans_b {
569            (b_cols, b_rows)
570        } else {
571            (b_rows, b_cols)
572        };
573        if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
574            return None;
575        }
576        let (stream, blas) = stream_and_blas_for(ordinal)?;
577        // Host-transpose-free path. The row-major output buffer of
578        // `C = op(A)·op(B)` (shape m×n) is bit-identical to the column-major
579        // buffer of `Cᵀ = op(B)ᵀ·op(A)ᵀ` (shape n×m). A row-major buffer,
580        // reinterpreted column-major, is already the transpose of its logical
581        // matrix — so uploading `a`/`b` row-major (a borrow when C-contiguous)
582        // gives cuBLAS `Aᵀ`/`Bᵀ` for free, and downloading straight into a
583        // row-major `Array2` skips the result permutation too. This removes the
584        // two O(rows·cols) scalar `to_col_major`/`from_col_major` passes that
585        // dominated tall-skinny GEMMs (e.g. the 200000×200 Wahba design
586        // reduction) and made the device path slower than the host SIMD GEMM.
587        //
588        // Uploading the row-major buffer of an `(r×c)` array and declaring it
589        // column-major with leading dim `c` hands cuBLAS exactly that array's
590        // transpose (a `(c×r)` col-major matrix). So with
591        //   X = b's row-major buffer  → col-major X = bᵀ  (rows = b_cols),
592        //   Y = a's row-major buffer  → col-major Y = aᵀ  (rows = a_cols),
593        // cuBLAS's `out = opX(X)·opY(Y)` yields `Cᵀ` when
594        //   opX = trans_b ? T : N,   opY = trans_a ? T : N,
595        //   (M,N,K) = (n, m, k),   lda = b_cols, ldb = a_cols, ldc = n.
596        // The column-major `Cᵀ` buffer (n rows) is bit-identical to the
597        // row-major `C` buffer (m×n), so the download wraps with no permute.
598        let b_rm = to_row_major(&b);
599        let a_rm = to_row_major(&a);
600        let x_dev = stream.clone_htod(&*b_rm).ok()?;
601        let y_dev = stream.clone_htod(&*a_rm).ok()?;
602        let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
603        let cfg = GemmConfig::<f64> {
604            transa: if trans_b {
605                cublasOperation_t::CUBLAS_OP_T
606            } else {
607                cublasOperation_t::CUBLAS_OP_N
608            },
609            transb: if trans_a {
610                cublasOperation_t::CUBLAS_OP_T
611            } else {
612                cublasOperation_t::CUBLAS_OP_N
613            },
614            m: to_i32(n)?,
615            n: to_i32(m)?,
616            k: to_i32(k_a)?,
617            alpha: 1.0,
618            // Leading dim of each physically-stored col-major operand =
619            // its row count: B̌ has `b_cols` rows, Ǎ has `a_cols` rows.
620            lda: to_i32(b_cols)?,
621            ldb: to_i32(a_cols)?,
622            beta: 0.0,
623            ldc: to_i32(n)?,
624        };
625        // SAFETY: dims validated above; buffers carry exactly the row counts
626        // declared as leading dimensions.
627        unsafe { blas.gemm(cfg, &x_dev, &y_dev, &mut out_dev) }.ok()?;
628        // `out_dev` is `Cᵀ` column-major == `C` row-major: wrap with no permute.
629        let out_rm = stream.clone_dtoh(&out_dev).ok()?;
630        array_from_row_major(out_rm, m, n)
631    }
632
633    /// Broadcast-B batched GEMM on a specific device ordinal. The caller
634    /// (`super::super::pool::scatter_batched` worker, or the single-device
635    /// dispatcher) supplies the ordinal whose context is already bound on this
636    /// thread; the stream/handle are created on that same device.
637    #[inline]
638    pub(crate) fn gemm_broadcast_b_batched_cuda(
639        ordinal: usize,
640        a: ArrayView3<'_, f64>,
641        b: ArrayView2<'_, f64>,
642    ) -> Option<Array3<f64>> {
643        let (batch, m, k) = a.dim();
644        let (b_rows, n) = b.dim();
645        if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
646            return None;
647        }
648        let (stream, blas) = stream_and_blas_for(ordinal)?;
649        let a_col = to_col_major_batch(a);
650        let b_col = to_col_major(&b);
651        let a_dev = stream.clone_htod(&a_col).ok()?;
652        let b_dev = stream.clone_htod(&*b_col).ok()?;
653        let mut out_dev = stream
654            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
655            .ok()?;
656        let cfg = StridedBatchedConfig::<f64> {
657            gemm: GemmConfig::<f64> {
658                transa: cublasOperation_t::CUBLAS_OP_N,
659                transb: cublasOperation_t::CUBLAS_OP_N,
660                m: to_i32(m)?,
661                n: to_i32(n)?,
662                k: to_i32(k)?,
663                alpha: 1.0,
664                lda: to_i32(m)?,
665                ldb: to_i32(k)?,
666                beta: 0.0,
667                ldc: to_i32(m)?,
668            },
669            batch_size: to_i32(batch)?,
670            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
671            stride_b: 0,
672            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
673        };
674        // SAFETY: `a_dev` is a stack of batch column-major m×k matrices,
675        // `b_dev` is one shared column-major k×n matrix with zero batch stride,
676        // and `out_dev` is a stack of batch column-major m×n outputs.
677        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
678        let out_col = stream.clone_dtoh(&out_dev).ok()?;
679        from_col_major_batch(&out_col, batch, m, n)
680    }
681
682    /// A·Bᵀ strided-batched GEMM on a specific device ordinal. As with the
683    /// broadcast variant, the ordinal's context is expected to be bound on the
684    /// calling thread (multi-GPU worker or single-device dispatcher).
685    #[inline]
686    pub(crate) fn gemm_abt_strided_batched_cuda(
687        ordinal: usize,
688        a: ArrayView3<'_, f64>,
689        b: ArrayView3<'_, f64>,
690    ) -> Option<Array3<f64>> {
691        let (batch, m, k) = a.dim();
692        let (batch_b, n, k_b) = b.dim();
693        if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
694            return None;
695        }
696        let (stream, blas) = stream_and_blas_for(ordinal)?;
697        let a_col = to_col_major_batch(a);
698        let b_col = to_col_major_batch(b);
699        let a_dev = stream.clone_htod(&a_col).ok()?;
700        let b_dev = stream.clone_htod(&b_col).ok()?;
701        let mut out_dev = stream
702            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
703            .ok()?;
704        let cfg = StridedBatchedConfig::<f64> {
705            gemm: GemmConfig::<f64> {
706                transa: cublasOperation_t::CUBLAS_OP_N,
707                transb: cublasOperation_t::CUBLAS_OP_T,
708                m: to_i32(m)?,
709                n: to_i32(n)?,
710                k: to_i32(k)?,
711                alpha: 1.0,
712                lda: to_i32(m)?,
713                ldb: to_i32(n)?,
714                beta: 0.0,
715                ldc: to_i32(m)?,
716            },
717            batch_size: to_i32(batch)?,
718            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
719            stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
720            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
721        };
722        // SAFETY: each batch item is column-major. The B batch stores n×k
723        // matrices and cuBLAS transposes each to k×n before multiplication.
724        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
725        let out_col = stream.clone_dtoh(&out_dev).ok()?;
726        from_col_major_batch(&out_col, batch, m, n)
727    }
728
729    #[inline]
730    pub(crate) fn gemv_cuda(
731        runtime: &GpuRuntime,
732        a: ArrayView2<'_, f64>,
733        v: ArrayView1<'_, f64>,
734        trans_a: bool,
735    ) -> Option<Array1<f64>> {
736        let (rows, cols) = a.dim();
737        let out_len = if trans_a { cols } else { rows };
738        let needed = if trans_a { rows } else { cols };
739        if out_len == 0 || needed == 0 || v.len() != needed {
740            return None;
741        }
742        let (stream, blas) = stream_and_blas(runtime)?;
743        let a_col = to_col_major(&a);
744        let a_dev = stream.clone_htod(&*a_col).ok()?;
745        let v_host = vector_values(v);
746        let v_dev = stream.clone_htod(&v_host).ok()?;
747        let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
748        let cfg = GemvConfig::<f64> {
749            trans: if trans_a {
750                cublasOperation_t::CUBLAS_OP_T
751            } else {
752                cublasOperation_t::CUBLAS_OP_N
753            },
754            m: to_i32(rows)?,
755            n: to_i32(cols)?,
756            alpha: 1.0,
757            lda: to_i32(rows)?,
758            incx: 1,
759            beta: 0.0,
760            incy: 1,
761        };
762        // SAFETY: dimensions and vector length match the cuBLAS GEMV contract.
763        unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
764        Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
765    }
766
767    #[inline]
768    pub fn xt_diag_x_cuda(
769        runtime: &GpuRuntime,
770        x: ArrayView2<'_, f64>,
771        w: ArrayView1<'_, f64>,
772    ) -> Option<Array2<f64>> {
773        let (rows, cols) = x.dim();
774        if rows == 0 || cols == 0 || rows != w.len() {
775            return None;
776        }
777        weighted_crossprod(runtime, x, w, x)
778    }
779
780    #[inline]
781    pub(crate) fn xt_diag_x_on_ordinal_cuda(
782        ordinal: usize,
783        x: ArrayView2<'_, f64>,
784        w: ArrayView1<'_, f64>,
785    ) -> Option<Array2<f64>> {
786        let (rows, cols) = x.dim();
787        if rows == 0 || cols == 0 || rows != w.len() {
788            return None;
789        }
790        weighted_crossprod_for(ordinal, x, w, x)
791    }
792
793    #[inline]
794    pub fn xt_diag_y_cuda(
795        runtime: &GpuRuntime,
796        x: ArrayView2<'_, f64>,
797        w: ArrayView1<'_, f64>,
798        y: ArrayView2<'_, f64>,
799    ) -> Option<Array2<f64>> {
800        weighted_crossprod(runtime, x, w, y)
801    }
802
803    #[inline]
804    pub(crate) fn joint_hessian_2x2_cuda(
805        runtime: &GpuRuntime,
806        x_a: ArrayView2<'_, f64>,
807        x_b: ArrayView2<'_, f64>,
808        w_aa: ArrayView1<'_, f64>,
809        w_ab: ArrayView1<'_, f64>,
810        w_bb: ArrayView1<'_, f64>,
811    ) -> Option<Array2<f64>> {
812        let (rows, pa) = x_a.dim();
813        let (rows_b, pb) = x_b.dim();
814        let total = pa.checked_add(pb)?;
815        if rows == 0
816            || total == 0
817            || rows != rows_b
818            || rows != w_aa.len()
819            || rows != w_ab.len()
820            || rows != w_bb.len()
821        {
822            return None;
823        }
824
825        let mut out = Array2::<f64>::zeros((total, total));
826        if pa > 0 {
827            let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
828            assign_block(&mut out, 0, 0, &aa);
829        }
830        if pa > 0 && pb > 0 {
831            let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
832            assign_block(&mut out, 0, pa, &ab);
833        }
834        if pb > 0 {
835            let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
836            assign_block(&mut out, pa, pa, &bb);
837        }
838        mirror_upper_to_lower(&mut out);
839        Some(out)
840    }
841
842    #[inline]
843    pub(crate) fn trsm_cuda(
844        runtime: &GpuRuntime,
845        triangular: ArrayView2<'_, f64>,
846        rhs: ArrayView2<'_, f64>,
847        upper: bool,
848    ) -> Option<Array2<f64>> {
849        let (n, n2) = triangular.dim();
850        if n == 0 || n != n2 || rhs.nrows() != n {
851            return None;
852        }
853        let nrhs = rhs.ncols();
854        let (stream, blas) = stream_and_blas(runtime)?;
855        let tri_col = to_col_major(&triangular);
856        let rhs_col = to_col_major(&rhs);
857        let tri_dev = stream.clone_htod(&*tri_col).ok()?;
858        let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
859        let alpha = 1.0_f64;
860        let handle = *blas.handle();
861        {
862            let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
863            let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
864            // SAFETY: triangular is n×n and rhs is n×nrhs in column-major device
865            // buffers. cublasDtrsm overwrites rhs with A^{-1} rhs.
866            let status = unsafe {
867                cudarc::cublas::sys::cublasDtrsm_v2(
868                    handle,
869                    cublasSideMode_t::CUBLAS_SIDE_LEFT,
870                    if upper {
871                        cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
872                    } else {
873                        cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
874                    },
875                    cublasOperation_t::CUBLAS_OP_N,
876                    cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
877                    to_i32(n)?,
878                    to_i32(nrhs)?,
879                    &alpha,
880                    tri_ptr as *const f64,
881                    to_i32(n)?,
882                    rhs_ptr as *mut f64,
883                    to_i32(n)?,
884                )
885            };
886            if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
887                return None;
888            }
889        };
890        let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
891        from_col_major(&out_col, n, nrhs)
892    }
893}
894
895#[cfg(target_os = "linux")]
896pub(crate) use cuda_impl::{
897    ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
898    gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
899};
900// Cross-crate cuBLAS entry points (gam-models BMS Hessian paths call these
901// directly): the #1521 carve promoted the sibling solver/GEMM entry points to
902// `pub` but left these two `pub(crate)`, so they were invisible to their
903// out-of-crate callers (E0603) on the linux cuda build that the workspace
// `cargo check` config does not exercise. They are therefore re-exported as
// `pub` here, matching the rest of the cross-crate cuBLAS surface.
906#[cfg(target_os = "linux")]
907pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};