Skip to main content

gam_gpu/
blas.rs

1//! Device BLAS surface for the cudarc-backed dense kernels.
2//!
3//! The public surface here is the lowest level of the GPU dispatch stack: it
4//! takes ndarray views, copies them to a device buffer, calls a cuBLAS / kernel
//! routine, and returns the host result. On Linux targets (the `cuda_impl`
//! module below is `#[cfg(target_os = "linux")]`) the cudarc-backed
//! implementations always compile — cudarc dynamically loads `libcuda` at
//! runtime via the `fallback-dynamic-loading` feature — and dispatch is gated
//! at runtime on `super::device_runtime::GpuRuntime::global()`: when no device
//! is probed the status enum advertises `CudaUnavailable` and callers fall
//! back to CPU.
10//!
11//! The implementations route through `super::device_runtime::cuda_context_for` and
12//! the cudarc 0.19 cuBLAS API. Any transient backend failure (OOM, launch
13//! error, …) is converted to `None` so the auto-dispatch shim in
14//! `super::linalg` falls back to the CPU fast path without disturbing
15//! numerics.
16
17pub fn blas_backend_status() -> super::CudaBackendStatus {
18    super::cuda_backend_status()
19}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
23    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};
24
25    use crate::driver::{array_from_row_major, from_col_major, to_col_major, to_i32, to_row_major};
26
27    use super::super::device_runtime::GpuRuntime;
28    use cudarc::cublas::sys::{
29        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
30    };
31    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
32    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
33    use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
34    use std::sync::Arc;
35
36    /// Create a fresh stream + cuBLAS handle bound to a specific device
37    /// ordinal. This is the per-ordinal entry point used by multi-GPU fan-out
38    /// (`super::super::pool::scatter_batched` workers): the worker thread has
39    /// already bound that ordinal's context, and the stream/handle created here
40    /// target the same device. The single-device helper below is the
41    /// primary-ordinal specialization.
42    #[inline]
43    pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
44        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
45            .new_stream()
46            .ok()?;
47        let blas = CudaBlas::new(stream.clone()).ok()?;
48        Some((stream, blas))
49    }
50
51    #[inline]
52    fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
53        stream_and_blas_for(runtime.device.ordinal)
54    }
55
56    #[inline]
57    fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
58        v.iter().copied().collect()
59    }
60
61    #[inline]
62    fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
63        let (batch_len, rows, cols) = batch.dim();
64        let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
65        for matrix in batch.axis_iter(Axis(0)) {
66            out.extend(to_col_major(&matrix).iter().copied());
67        }
68        out
69    }
70
71    #[inline]
72    fn from_col_major_batch(
73        data: &[f64],
74        batch: usize,
75        rows: usize,
76        cols: usize,
77    ) -> Option<Array3<f64>> {
78        if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
79            return None;
80        }
81        let mut out = Array3::<f64>::zeros((batch, rows, cols));
82        let matrix_len = rows.checked_mul(cols)?;
83        for batch_idx in 0..batch {
84            let base = batch_idx.checked_mul(matrix_len)?;
85            for col in 0..cols {
86                for row in 0..rows {
87                    out[[batch_idx, row, col]] = data[base + col * rows + row];
88                }
89            }
90        }
91        Some(out)
92    }
93
94    #[inline]
95    fn row_scale_device(
96        blas: &CudaBlas,
97        stream: &Arc<CudaStream>,
98        matrix_dev: &CudaSlice<f64>,
99        weights_dev: &CudaSlice<f64>,
100        scaled_dev: &mut CudaSlice<f64>,
101        rows: usize,
102        cols: usize,
103    ) -> Option<()> {
104        let rows_i = to_i32(rows)?;
105        let cols_i = to_i32(cols)?;
106        let handle = *blas.handle();
107        let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
108        let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
109        let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
110        // SAFETY: all device slices are on this stream/context. `matrix_dev`
111        // and `scaled_dev` are rows×cols column-major matrices with lda/ldc
112        // equal to rows; `weights_dev` has one contiguous value per row.
113        let status = unsafe {
114            cudarc::cublas::sys::cublasDdgmm(
115                handle,
116                cublasSideMode_t::CUBLAS_SIDE_LEFT,
117                rows_i,
118                cols_i,
119                matrix_ptr as *const f64,
120                rows_i,
121                weights_ptr as *const f64,
122                1,
123                scaled_ptr as *mut f64,
124                rows_i,
125            )
126        };
127        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
128            Some(())
129        } else {
130            None
131        }
132    }
133
134    #[inline]
135    fn weighted_crossprod(
136        runtime: &GpuRuntime,
137        left: ArrayView2<'_, f64>,
138        weights: ArrayView1<'_, f64>,
139        right: ArrayView2<'_, f64>,
140    ) -> Option<Array2<f64>> {
141        weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
142    }
143
144    #[inline]
145    fn weighted_crossprod_for(
146        ordinal: usize,
147        left: ArrayView2<'_, f64>,
148        weights: ArrayView1<'_, f64>,
149        right: ArrayView2<'_, f64>,
150    ) -> Option<Array2<f64>> {
151        let (rows, left_cols) = left.dim();
152        let (right_rows, right_cols) = right.dim();
153        if rows == 0
154            || left_cols == 0
155            || right_cols == 0
156            || rows != right_rows
157            || rows != weights.len()
158        {
159            return None;
160        }
161
162        let (stream, blas) = stream_and_blas_for(ordinal)?;
163        // #1412: the symmetric Gram `Xᵀ·diag(w)·X` (xt_diag_x) passes the SAME
164        // array as `left` and `right`. Detect that (identical data pointer +
165        // shape) and stage `X` ONCE instead of column-majoring and H2D-uploading
166        // two byte-identical n×p copies — halving the dominant H2D for the Gram.
167        // The GEMM operands are unchanged (`left_dev` doubles as the row-scale
168        // source), so the result is bit-identical to the two-upload path.
169        let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
170            && left.dim() == right.dim()
171            && left.strides() == right.strides();
172        let left_col = to_col_major(&left);
173        let weights_host = vector_values(weights);
174        let left_dev = stream.clone_htod(&*left_col).ok()?;
175        // Symmetric Gram: `right` IS `left`, so row-scale directly from the
176        // single resident `left_dev` and never upload a second n×p copy. The
177        // asymmetric path uploads `right` as before.
178        let right_dev = if same_operand {
179            None
180        } else {
181            let right_col = to_col_major(&right);
182            Some(stream.clone_htod(&*right_col).ok()?)
183        };
184        let weights_dev = stream.clone_htod(&weights_host).ok()?;
185        let mut weighted_right_dev = stream
186            .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
187            .ok()?;
188        row_scale_device(
189            &blas,
190            &stream,
191            right_dev.as_ref().unwrap_or(&left_dev),
192            &weights_dev,
193            &mut weighted_right_dev,
194            rows,
195            right_cols,
196        )?;
197
198        let mut out_dev = stream
199            .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
200            .ok()?;
201        let cfg = GemmConfig::<f64> {
202            transa: cublasOperation_t::CUBLAS_OP_T,
203            transb: cublasOperation_t::CUBLAS_OP_N,
204            m: to_i32(left_cols)?,
205            n: to_i32(right_cols)?,
206            k: to_i32(rows)?,
207            alpha: 1.0,
208            lda: to_i32(rows)?,
209            ldb: to_i32(rows)?,
210            beta: 0.0,
211            ldc: to_i32(left_cols)?,
212        };
213        // SAFETY: cfg computes leftᵀ (left_cols×rows) times weighted_right
214        // (rows×right_cols) into a left_cols×right_cols column-major output.
215        unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
216        let out_col = stream.clone_dtoh(&out_dev).ok()?;
217        from_col_major(&out_col, left_cols, right_cols)
218    }
219
220    /// #1017 Phase 3: a device-resident design matrix `X` whose `n×p` values are
221    /// uploaded to the device ONCE and reused across many `Xᵀ·diag(w)·X` Gram
222    /// evaluations.
223    ///
224    /// The per-call [`xt_diag_x_cuda`] path re-uploads the full `n×p` `X` (and a
225    /// second copy as the `right` operand) on EVERY call. For the SAE / IRLS
226    /// inner loop — where `X` is frozen across weight updates and the Gram is
227    /// rebuilt once per Newton/PIRLS step — that H2D staging dominates the wall
228    /// clock (measured #1412: the `XtWX` GEMM is ~98% of the pipeline at <20% GPU
229    /// utilisation, i.e. the device is starved by the per-call upload, not the
230    /// arithmetic). Uploading `X` once and crossing only the `n`-vector `w` (and
231    /// the `p×p` result) per call removes that ping-pong: the resident `X` is
232    /// `n·p` doubles vs the per-call `w` of `n` doubles, so the amortised
233    /// transfer per Gram drops by a factor of `p`.
234    pub(crate) struct ResidentWeightedGram {
235        stream: Arc<CudaStream>,
236        blas: CudaBlas,
237        x_dev: CudaSlice<f64>,
238        rows: usize,
239        cols: usize,
240    }
241
242    impl ResidentWeightedGram {
243        /// Upload `x` (`n×p`) to `ordinal` once, column-major, and keep it
244        /// resident. Returns `None` on a degenerate shape or any device failure
245        /// (the caller falls back to the per-call CPU/GPU path).
246        pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
247            let (rows, cols) = x.dim();
248            if rows == 0 || cols == 0 {
249                return None;
250            }
251            let (stream, blas) = stream_and_blas_for(ordinal)?;
252            let x_col = to_col_major(&x);
253            let x_dev = stream.clone_htod(&*x_col).ok()?;
254            Some(Self {
255                stream,
256                blas,
257                x_dev,
258                rows,
259                cols,
260            })
261        }
262
263        #[inline]
264        pub(crate) fn dims(&self) -> (usize, usize) {
265            (self.rows, self.cols)
266        }
267
268        /// Compute `Xᵀ·diag(w)·X` reusing the resident `X`. Only `w` (`n`
269        /// doubles) crosses H2D and only the `p×p` Gram crosses D2H. The
270        /// arithmetic is bit-identical to [`xt_diag_x_cuda`] on the same device
271        /// (same `cublasDdgmm` row-scale + same `gemm` reduction order).
272        pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
273            if w.len() != self.rows {
274                return None;
275            }
276            let weights_host = vector_values(w);
277            let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
278            let mut weighted_dev = self
279                .stream
280                .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
281                .ok()?;
282            row_scale_device(
283                &self.blas,
284                &self.stream,
285                &self.x_dev,
286                &weights_dev,
287                &mut weighted_dev,
288                self.rows,
289                self.cols,
290            )?;
291            let mut out_dev = self
292                .stream
293                .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
294                .ok()?;
295            let cfg = GemmConfig::<f64> {
296                transa: cublasOperation_t::CUBLAS_OP_T,
297                transb: cublasOperation_t::CUBLAS_OP_N,
298                m: to_i32(self.cols)?,
299                n: to_i32(self.cols)?,
300                k: to_i32(self.rows)?,
301                alpha: 1.0,
302                lda: to_i32(self.rows)?,
303                ldb: to_i32(self.rows)?,
304                beta: 0.0,
305                ldc: to_i32(self.cols)?,
306            };
307            // SAFETY: `x_dev` is the resident n×p column-major design; cfg forms
308            // Xᵀ (p×n) · weighted (n×p) → a p×p column-major Gram.
309            unsafe {
310                self.blas
311                    .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
312            }
313            .ok()?;
314            let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
315            from_col_major(&out_col, self.cols, self.cols)
316        }
317
318        /// Compute the resident weighted Gram `G = Xᵀ·diag(w)·X + ridge·I`,
319        /// factor it (cuSOLVER POTRF), and solve `G·β = rhs` — keeping `G`, its
320        /// Cholesky factor, and the RHS all DEVICE-RESIDENT. Only `w` (`n`),
321        /// `rhs` (`p`), and the result `β` (`p`) cross the PCIe boundary; the
322        /// `p×p` Gram is NEVER downloaded.
323        ///
324        /// This is the #1017 Phase-3 ceiling fix for the normal-equations solve:
325        /// the per-call [`gram`] still pays a `p×p` D2H (134 MB at p=4096 — the
326        /// next bottleneck once `X` is resident), whereas the SAE/IRLS inner step
327        /// only needs the `p`-vector `β = (XᵀWX+λ)⁻¹ XᵀWz`. Chaining
328        /// row-scale→GEMM→POTRF→TRSM on-device and returning only `β` removes the
329        /// Gram transfer entirely.
330        ///
331        /// `ridge` (e.g. the penalty diagonal `λ` or a Tikhonov floor) is seeded
332        /// as `ridge·I` on the device and the Gram is GEMM-accumulated onto it
333        /// (`beta = 1`), so the diagonal bump never costs a Gram round-trip.
334        /// Returns `None` on shape mismatch, a non-PD factorisation, or any
335        /// device failure (the caller falls back to the CPU solve).
336        pub(crate) fn solve_psd_normal_equations(
337            &self,
338            w: ArrayView1<'_, f64>,
339            rhs: ArrayView1<'_, f64>,
340            ridge: f64,
341        ) -> Option<Array1<f64>> {
342            if w.len() != self.rows || rhs.len() != self.cols {
343                return None;
344            }
345            let p = self.cols;
346
347            // weighted = diag(w) · X  (resident X row-scaled).
348            let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
349            let mut weighted_dev = self
350                .stream
351                .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
352                .ok()?;
353            row_scale_device(
354                &self.blas,
355                &self.stream,
356                &self.x_dev,
357                &weights_dev,
358                &mut weighted_dev,
359                self.rows,
360                p,
361            )?;
362
363            // Pre-seed G with `ridge·I` on the device, then GEMM-accumulate
364            // `XᵀW X` onto it with `beta = 1.0`. The Gram is formed and stays
365            // device-resident: `ridge·I` is a one-time H2D upload (the only way
366            // to set a diagonal without an NVRTC kernel), and crucially the p×p
367            // Gram is NEVER read back — only `β` returns. `ridge·I` upload is
368            // bandwidth-trivial vs the avoided per-solve Gram download.
369            let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
370            for i in 0..p {
371                ridge_init[i * p + i] = ridge;
372            }
373            let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
374            let cfg = GemmConfig::<f64> {
375                transa: cublasOperation_t::CUBLAS_OP_T,
376                transb: cublasOperation_t::CUBLAS_OP_N,
377                m: to_i32(p)?,
378                n: to_i32(p)?,
379                k: to_i32(self.rows)?,
380                alpha: 1.0,
381                lda: to_i32(self.rows)?,
382                ldb: to_i32(self.rows)?,
383                // Accumulate onto the resident ridge·I seed.
384                beta: 1.0,
385                ldc: to_i32(p)?,
386            };
387            // SAFETY: resident n×p X and the n×p weighted buffer form a p×p Gram;
388            // beta=1 accumulates Xᵀ(WX) onto the resident ridge·I in g_dev.
389            unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
390
391            // POTRF(G) → lower factor L, resident in g_dev.
392            let solver = DnHandle::new(self.stream.clone()).ok()?;
393            let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
394            if info != 0 {
395                // Not positive-definite at pivot `info`; caller falls back.
396                return None;
397            }
398
399            // Solve L Lᵀ β = rhs via two triangular solves, β resident in rhs_dev.
400            let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
401            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; // L y = rhs
402            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; // Lᵀ β = y
403
404            // Download ONLY the p-vector solution.
405            let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
406            Some(Array1::from_vec(beta_host))
407        }
408    }
409
410    /// Single cuSOLVER `DPOTRF` (lower) of a resident `p×p` column-major matrix,
411    /// factored in place. Returns the cuSOLVER `info` (0 = success, k>0 = the
412    /// leading minor of order k is not PD). Mirrors the arrow-Schur frame POTRF.
413    fn potrf_single_dev(
414        solver: &DnHandle,
415        stream: &Arc<CudaStream>,
416        p: usize,
417        matrix: &mut CudaSlice<f64>,
418    ) -> Option<i32> {
419        let p_i = to_i32(p)?;
420        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
421        let mut lwork = 0_i32;
422        {
423            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
424            // SAFETY: buffer-size query against a live p×p column-major matrix.
425            let status = unsafe {
426                cusolver_sys::cusolverDnDpotrf_bufferSize(
427                    solver.cu(),
428                    uplo,
429                    p_i,
430                    mat_ptr as *mut f64,
431                    p_i,
432                    &mut lwork,
433                )
434            };
435            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
436                return None;
437            }
438        }
439        let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
440        let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
441        {
442            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
443            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
444            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
445            // SAFETY: all buffers live on this stream; matrix is p×p column-major.
446            let status = unsafe {
447                cusolver_sys::cusolverDnDpotrf(
448                    solver.cu(),
449                    uplo,
450                    p_i,
451                    mat_ptr as *mut f64,
452                    p_i,
453                    work_ptr as *mut f64,
454                    lwork,
455                    info_ptr as *mut i32,
456                )
457            };
458            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
459                return None;
460            }
461        }
462        let info_host = stream.clone_dtoh(&info_dev).ok()?;
463        info_host.first().copied()
464    }
465
466    /// Triangular solve `op(L)·x = b` for a single `p`-vector RHS against a
467    /// resident lower Cholesky factor `L` (`p×p` column-major), in place over
468    /// `rhs`. `transposed` selects `Lᵀ` (the second back-substitution).
469    fn trsm_single_vec(
470        blas: &CudaBlas,
471        stream: &Arc<CudaStream>,
472        p: usize,
473        l: &CudaSlice<f64>,
474        rhs: &mut CudaSlice<f64>,
475        transposed: bool,
476    ) -> Option<()> {
477        let alpha = 1.0_f64;
478        let p_i = to_i32(p)?;
479        let handle = *blas.handle();
480        let (l_ptr, _l_rec) = l.device_ptr(stream);
481        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
482        // SAFETY: p×p lower factor and a single p-vector RHS, both resident.
483        let status = unsafe {
484            cudarc::cublas::sys::cublasDtrsm_v2(
485                handle,
486                cublasSideMode_t::CUBLAS_SIDE_LEFT,
487                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
488                if transposed {
489                    cublasOperation_t::CUBLAS_OP_T
490                } else {
491                    cublasOperation_t::CUBLAS_OP_N
492                },
493                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
494                p_i,
495                1,
496                &alpha,
497                l_ptr as *const f64,
498                p_i,
499                rhs_ptr as *mut f64,
500                p_i,
501            )
502        };
503        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
504            Some(())
505        } else {
506            None
507        }
508    }
509
510    #[inline]
511    fn assign_block(
512        out: &mut Array2<f64>,
513        row_offset: usize,
514        col_offset: usize,
515        block: &Array2<f64>,
516    ) {
517        let (rows, cols) = block.dim();
518        for col in 0..cols {
519            for row in 0..rows {
520                out[[row_offset + row, col_offset + col]] = block[[row, col]];
521            }
522        }
523    }
524
525    #[inline]
526    fn mirror_upper_to_lower(out: &mut Array2<f64>) {
527        let n = out.nrows();
528        for row in 0..n {
529            for col in 0..row {
530                out[[row, col]] = out[[col, row]];
531            }
532        }
533    }
534
535    #[inline]
536    pub(crate) fn gemm_cuda(
537        runtime: &GpuRuntime,
538        a: ArrayView2<'_, f64>,
539        b: ArrayView2<'_, f64>,
540        trans_a: bool,
541        trans_b: bool,
542    ) -> Option<Array2<f64>> {
543        gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
544    }
545
546    /// Dense GEMM (optionally transposing either operand) on a specific device
547    /// ordinal. The ordinal's context is expected to be bound on the calling
548    /// thread (pool-tiled callers via `super::super::pool::scatter_batched`, or
549    /// the single-device dispatcher through [`gemm_cuda`]). Semantics are
550    /// identical to [`gemm_cuda`]; only the target device differs.
551    #[inline]
552    pub(crate) fn gemm_on_ordinal_cuda(
553        ordinal: usize,
554        a: ArrayView2<'_, f64>,
555        b: ArrayView2<'_, f64>,
556        trans_a: bool,
557        trans_b: bool,
558    ) -> Option<Array2<f64>> {
559        let (a_rows, a_cols) = a.dim();
560        let (b_rows, b_cols) = b.dim();
561        let (m, k_a) = if trans_a {
562            (a_cols, a_rows)
563        } else {
564            (a_rows, a_cols)
565        };
566        let (k_b, n) = if trans_b {
567            (b_cols, b_rows)
568        } else {
569            (b_rows, b_cols)
570        };
571        if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
572            return None;
573        }
574        let (stream, blas) = stream_and_blas_for(ordinal)?;
575        // Host-transpose-free path. The row-major output buffer of
576        // `C = op(A)·op(B)` (shape m×n) is bit-identical to the column-major
577        // buffer of `Cᵀ = op(B)ᵀ·op(A)ᵀ` (shape n×m). A row-major buffer,
578        // reinterpreted column-major, is already the transpose of its logical
579        // matrix — so uploading `a`/`b` row-major (a borrow when C-contiguous)
580        // gives cuBLAS `Aᵀ`/`Bᵀ` for free, and downloading straight into a
581        // row-major `Array2` skips the result permutation too. This removes the
582        // two O(rows·cols) scalar `to_col_major`/`from_col_major` passes that
583        // dominated tall-skinny GEMMs (e.g. the 200000×200 Wahba design
584        // reduction) and made the device path slower than the host SIMD GEMM.
585        //
586        // Uploading the row-major buffer of an `(r×c)` array and declaring it
587        // column-major with leading dim `c` hands cuBLAS exactly that array's
588        // transpose (a `(c×r)` col-major matrix). So with
589        //   X = b's row-major buffer  → col-major X = bᵀ  (rows = b_cols),
590        //   Y = a's row-major buffer  → col-major Y = aᵀ  (rows = a_cols),
591        // cuBLAS's `out = opX(X)·opY(Y)` yields `Cᵀ` when
592        //   opX = trans_b ? T : N,   opY = trans_a ? T : N,
593        //   (M,N,K) = (n, m, k),   lda = b_cols, ldb = a_cols, ldc = n.
594        // The column-major `Cᵀ` buffer (n rows) is bit-identical to the
595        // row-major `C` buffer (m×n), so the download wraps with no permute.
596        let b_rm = to_row_major(&b);
597        let a_rm = to_row_major(&a);
598        let x_dev = stream.clone_htod(&*b_rm).ok()?;
599        let y_dev = stream.clone_htod(&*a_rm).ok()?;
600        let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
601        let cfg = GemmConfig::<f64> {
602            transa: if trans_b {
603                cublasOperation_t::CUBLAS_OP_T
604            } else {
605                cublasOperation_t::CUBLAS_OP_N
606            },
607            transb: if trans_a {
608                cublasOperation_t::CUBLAS_OP_T
609            } else {
610                cublasOperation_t::CUBLAS_OP_N
611            },
612            m: to_i32(n)?,
613            n: to_i32(m)?,
614            k: to_i32(k_a)?,
615            alpha: 1.0,
616            // Leading dim of each physically-stored col-major operand =
617            // its row count: B̌ has `b_cols` rows, Ǎ has `a_cols` rows.
618            lda: to_i32(b_cols)?,
619            ldb: to_i32(a_cols)?,
620            beta: 0.0,
621            ldc: to_i32(n)?,
622        };
623        // SAFETY: dims validated above; buffers carry exactly the row counts
624        // declared as leading dimensions.
625        unsafe { blas.gemm(cfg, &x_dev, &y_dev, &mut out_dev) }.ok()?;
626        // `out_dev` is `Cᵀ` column-major == `C` row-major: wrap with no permute.
627        let out_rm = stream.clone_dtoh(&out_dev).ok()?;
628        array_from_row_major(out_rm, m, n)
629    }
630
631    /// Broadcast-B batched GEMM on a specific device ordinal. The caller
632    /// (`super::super::pool::scatter_batched` worker, or the single-device
633    /// dispatcher) supplies the ordinal whose context is already bound on this
634    /// thread; the stream/handle are created on that same device.
635    #[inline]
636    pub(crate) fn gemm_broadcast_b_batched_cuda(
637        ordinal: usize,
638        a: ArrayView3<'_, f64>,
639        b: ArrayView2<'_, f64>,
640    ) -> Option<Array3<f64>> {
641        let (batch, m, k) = a.dim();
642        let (b_rows, n) = b.dim();
643        if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
644            return None;
645        }
646        let (stream, blas) = stream_and_blas_for(ordinal)?;
647        let a_col = to_col_major_batch(a);
648        let b_col = to_col_major(&b);
649        let a_dev = stream.clone_htod(&a_col).ok()?;
650        let b_dev = stream.clone_htod(&*b_col).ok()?;
651        let mut out_dev = stream
652            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
653            .ok()?;
654        let cfg = StridedBatchedConfig::<f64> {
655            gemm: GemmConfig::<f64> {
656                transa: cublasOperation_t::CUBLAS_OP_N,
657                transb: cublasOperation_t::CUBLAS_OP_N,
658                m: to_i32(m)?,
659                n: to_i32(n)?,
660                k: to_i32(k)?,
661                alpha: 1.0,
662                lda: to_i32(m)?,
663                ldb: to_i32(k)?,
664                beta: 0.0,
665                ldc: to_i32(m)?,
666            },
667            batch_size: to_i32(batch)?,
668            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
669            stride_b: 0,
670            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
671        };
672        // SAFETY: `a_dev` is a stack of batch column-major m×k matrices,
673        // `b_dev` is one shared column-major k×n matrix with zero batch stride,
674        // and `out_dev` is a stack of batch column-major m×n outputs.
675        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
676        let out_col = stream.clone_dtoh(&out_dev).ok()?;
677        from_col_major_batch(&out_col, batch, m, n)
678    }
679
680    /// A·Bᵀ strided-batched GEMM on a specific device ordinal. As with the
681    /// broadcast variant, the ordinal's context is expected to be bound on the
682    /// calling thread (multi-GPU worker or single-device dispatcher).
683    #[inline]
684    pub(crate) fn gemm_abt_strided_batched_cuda(
685        ordinal: usize,
686        a: ArrayView3<'_, f64>,
687        b: ArrayView3<'_, f64>,
688    ) -> Option<Array3<f64>> {
689        let (batch, m, k) = a.dim();
690        let (batch_b, n, k_b) = b.dim();
691        if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
692            return None;
693        }
694        let (stream, blas) = stream_and_blas_for(ordinal)?;
695        let a_col = to_col_major_batch(a);
696        let b_col = to_col_major_batch(b);
697        let a_dev = stream.clone_htod(&a_col).ok()?;
698        let b_dev = stream.clone_htod(&b_col).ok()?;
699        let mut out_dev = stream
700            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
701            .ok()?;
702        let cfg = StridedBatchedConfig::<f64> {
703            gemm: GemmConfig::<f64> {
704                transa: cublasOperation_t::CUBLAS_OP_N,
705                transb: cublasOperation_t::CUBLAS_OP_T,
706                m: to_i32(m)?,
707                n: to_i32(n)?,
708                k: to_i32(k)?,
709                alpha: 1.0,
710                lda: to_i32(m)?,
711                ldb: to_i32(n)?,
712                beta: 0.0,
713                ldc: to_i32(m)?,
714            },
715            batch_size: to_i32(batch)?,
716            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
717            stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
718            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
719        };
720        // SAFETY: each batch item is column-major. The B batch stores n×k
721        // matrices and cuBLAS transposes each to k×n before multiplication.
722        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
723        let out_col = stream.clone_dtoh(&out_dev).ok()?;
724        from_col_major_batch(&out_col, batch, m, n)
725    }
726
727    #[inline]
728    pub(crate) fn gemv_cuda(
729        runtime: &GpuRuntime,
730        a: ArrayView2<'_, f64>,
731        v: ArrayView1<'_, f64>,
732        trans_a: bool,
733    ) -> Option<Array1<f64>> {
734        let (rows, cols) = a.dim();
735        let out_len = if trans_a { cols } else { rows };
736        let needed = if trans_a { rows } else { cols };
737        if out_len == 0 || needed == 0 || v.len() != needed {
738            return None;
739        }
740        let (stream, blas) = stream_and_blas(runtime)?;
741        let a_col = to_col_major(&a);
742        let a_dev = stream.clone_htod(&*a_col).ok()?;
743        let v_host = vector_values(v);
744        let v_dev = stream.clone_htod(&v_host).ok()?;
745        let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
746        let cfg = GemvConfig::<f64> {
747            trans: if trans_a {
748                cublasOperation_t::CUBLAS_OP_T
749            } else {
750                cublasOperation_t::CUBLAS_OP_N
751            },
752            m: to_i32(rows)?,
753            n: to_i32(cols)?,
754            alpha: 1.0,
755            lda: to_i32(rows)?,
756            incx: 1,
757            beta: 0.0,
758            incy: 1,
759        };
760        // SAFETY: dimensions and vector length match the cuBLAS GEMV contract.
761        unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
762        Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
763    }
764
765    #[inline]
766    pub fn xt_diag_x_cuda(
767        runtime: &GpuRuntime,
768        x: ArrayView2<'_, f64>,
769        w: ArrayView1<'_, f64>,
770    ) -> Option<Array2<f64>> {
771        let (rows, cols) = x.dim();
772        if rows == 0 || cols == 0 || rows != w.len() {
773            return None;
774        }
775        weighted_crossprod(runtime, x, w, x)
776    }
777
778    #[inline]
779    pub(crate) fn xt_diag_x_on_ordinal_cuda(
780        ordinal: usize,
781        x: ArrayView2<'_, f64>,
782        w: ArrayView1<'_, f64>,
783    ) -> Option<Array2<f64>> {
784        let (rows, cols) = x.dim();
785        if rows == 0 || cols == 0 || rows != w.len() {
786            return None;
787        }
788        weighted_crossprod_for(ordinal, x, w, x)
789    }
790
791    #[inline]
792    pub fn xt_diag_y_cuda(
793        runtime: &GpuRuntime,
794        x: ArrayView2<'_, f64>,
795        w: ArrayView1<'_, f64>,
796        y: ArrayView2<'_, f64>,
797    ) -> Option<Array2<f64>> {
798        weighted_crossprod(runtime, x, w, y)
799    }
800
801    #[inline]
802    pub(crate) fn joint_hessian_2x2_cuda(
803        runtime: &GpuRuntime,
804        x_a: ArrayView2<'_, f64>,
805        x_b: ArrayView2<'_, f64>,
806        w_aa: ArrayView1<'_, f64>,
807        w_ab: ArrayView1<'_, f64>,
808        w_bb: ArrayView1<'_, f64>,
809    ) -> Option<Array2<f64>> {
810        let (rows, pa) = x_a.dim();
811        let (rows_b, pb) = x_b.dim();
812        let total = pa.checked_add(pb)?;
813        if rows == 0
814            || total == 0
815            || rows != rows_b
816            || rows != w_aa.len()
817            || rows != w_ab.len()
818            || rows != w_bb.len()
819        {
820            return None;
821        }
822
823        let mut out = Array2::<f64>::zeros((total, total));
824        if pa > 0 {
825            let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
826            assign_block(&mut out, 0, 0, &aa);
827        }
828        if pa > 0 && pb > 0 {
829            let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
830            assign_block(&mut out, 0, pa, &ab);
831        }
832        if pb > 0 {
833            let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
834            assign_block(&mut out, pa, pa, &bb);
835        }
836        mirror_upper_to_lower(&mut out);
837        Some(out)
838    }
839
840    #[inline]
841    pub(crate) fn trsm_cuda(
842        runtime: &GpuRuntime,
843        triangular: ArrayView2<'_, f64>,
844        rhs: ArrayView2<'_, f64>,
845        upper: bool,
846    ) -> Option<Array2<f64>> {
847        let (n, n2) = triangular.dim();
848        if n == 0 || n != n2 || rhs.nrows() != n {
849            return None;
850        }
851        let nrhs = rhs.ncols();
852        let (stream, blas) = stream_and_blas(runtime)?;
853        let tri_col = to_col_major(&triangular);
854        let rhs_col = to_col_major(&rhs);
855        let tri_dev = stream.clone_htod(&*tri_col).ok()?;
856        let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
857        let alpha = 1.0_f64;
858        let handle = *blas.handle();
859        {
860            let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
861            let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
862            // SAFETY: triangular is n×n and rhs is n×nrhs in column-major device
863            // buffers. cublasDtrsm overwrites rhs with A^{-1} rhs.
864            let status = unsafe {
865                cudarc::cublas::sys::cublasDtrsm_v2(
866                    handle,
867                    cublasSideMode_t::CUBLAS_SIDE_LEFT,
868                    if upper {
869                        cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
870                    } else {
871                        cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
872                    },
873                    cublasOperation_t::CUBLAS_OP_N,
874                    cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
875                    to_i32(n)?,
876                    to_i32(nrhs)?,
877                    &alpha,
878                    tri_ptr as *const f64,
879                    to_i32(n)?,
880                    rhs_ptr as *mut f64,
881                    to_i32(n)?,
882                )
883            };
884            if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
885                return None;
886            }
887        };
888        let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
889        from_col_major(&out_col, n, nrhs)
890    }
891}
892
893#[cfg(target_os = "linux")]
894pub(crate) use cuda_impl::{
895    ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
896    gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
897};
898// Cross-crate cuBLAS entry points (gam-models BMS Hessian paths call these
899// directly): the #1521 carve promoted the sibling solver/GEMM entry points to
900// `pub` but left these two `pub(crate)`, so they were invisible to their
901// out-of-crate callers (E0603) on the linux cuda build that the workspace
902// `cargo check` config does not exercise. Promote to match the rest of the
903// cross-crate cuBLAS surface.
904#[cfg(target_os = "linux")]
905pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};